# Copyright (c) 2019 Red Hat, Inc. # # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. from __future__ import absolute_import import io import time from oslo_log import log import tobiko from tobiko.shell.sh import _command from tobiko.shell.sh import _exception from tobiko.shell.sh import _io LOG = log.getLogger(__name__) MAX_TIMEOUT = 3600. # 1 hour def process(command=None, environment=None, timeout=None, shell=None, stdin=None, stdout=None, stderr=None, ssh_client=None, sudo=None, **kwargs): kwargs.update(command=command, environment=environment, timeout=timeout, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, sudo=sudo) timeout = kwargs['timeout'] if timeout is not None: if timeout < 0.: raise ValueError("Invalid timeout for executing process: " "{!r}".format(timeout)) try: from tobiko.shell.sh import _ssh from tobiko.shell import ssh except ImportError: if ssh_client: raise else: if ssh_client is None: ssh_client = ssh.ssh_proxy_client() if ssh_client: return _ssh.ssh_process(ssh_client=ssh_client, **kwargs) from tobiko.shell.sh import _local return _local.local_process(**kwargs) class Parameters(object): def __init__(self, **kwargs): cls = type(self) for name, value in kwargs.items(): if value is not None: if not hasattr(cls, name): raise ValueError('Invalid parameter: {!s}'.format(name)) setattr(self, name, value) class ShellProcessParameters(Parameters): command = None environment = None current_dir = None timeout = None shell = None stdin = 
False stdout = True stderr = True buffer_size = io.DEFAULT_BUFFER_SIZE poll_interval = 1. sudo = None network_namespace = None class ShellProcessFixture(tobiko.SharedFixture): parameters = None command = None timeout = None process = None stdin = None stdout = None stderr = None _exit_status = None def __init__(self, **kwargs): super(ShellProcessFixture, self).__init__() self.parameters = self.init_parameters(**kwargs) def init_parameters(self, **kwargs): return ShellProcessParameters(**kwargs) def execute(self): return tobiko.setup_fixture(self) def setup_fixture(self): parameters = self.parameters self.setup_command() if parameters.timeout: self.setup_timeout() self.setup_process() if parameters.stdin: self.setup_stdin() if parameters.stdout: self.setup_stdout() if parameters.stderr: self.setup_stderr() def setup_command(self): command = _command.shell_command(self.parameters.command) network_namespace = self.parameters.network_namespace sudo = self.parameters.sudo shell = self.parameters.shell if shell: if shell is True: shell = default_shell_command() else: shell = _command.shell_command(shell) command = shell + [str(command)] else: command = _command.shell_command(command) if network_namespace: if sudo is None: sudo = True command = network_namespace_command(network_namespace, command) if sudo: if sudo is True: sudo = default_sudo_command() else: sudo = _command.shell_command(sudo) command = sudo + command self.command = command def setup_timeout(self): self.timeout = shell_process_timeout(self.parameters.timeout) def setup_process(self): if self._exit_status: del self._exit_status self.process = self.create_process() self.addCleanup(self.close) def setup_stdin(self): raise NotImplementedError def setup_stdout(self): raise NotImplementedError def setup_stderr(self): raise NotImplementedError def create_process(self): raise NotImplementedError def close_stdin(self): stdin = self.stdin if stdin is not None: try: stdin.closed or stdin.close() except Exception: 
LOG.exception("Error closing STDIN stream: %r", self.stdin) def close_stdout(self): stdout = self.stdout if stdout is not None: try: stdout.closed or stdout.close() except Exception: LOG.exception("Error closing STDOUT stream: %r", self.stdout) def close_stderr(self): stderr = self.stderr if stderr is not None: try: stderr.closed or stderr.close() except Exception: LOG.exception("Error closing STDERR stream: %r", self.stderr) def close(self, timeout=None): self.close_stdin() try: # Drain all incoming data from STDOUT and STDERR self.wait(timeout=timeout) finally: self._terminate() def _terminate(self): self.close_stdout() self.close_stderr() exit_status = None try: exit_status = self.get_exit_status() finally: if exit_status is None: try: self.kill() except Exception: LOG.exception('Error killing process: %r', self.command) def __getattr__(self, name): try: # Get attributes from parameters class return getattr(self.parameters, name) except AttributeError: message = "object {!r} has not attribute {!r}".format(self, name) raise AttributeError(message) def kill(self): raise NotImplementedError def poll_exit_status(self): raise NotImplementedError def get_exit_status(self, timeout=None): time_left, timeout = get_time_left([self.timeout, timeout]) if time_left > 0.: exit_status = self._get_exit_status(time_left=time_left) if exit_status is not None: return exit_status ex = _exception.ShellTimeoutExpired( command=str(self.command), timeout=timeout and timeout.timeout or None, stdin=str_from_stream(self.stdin), stdout=str_from_stream(self.stdout), stderr=str_from_stream(self.stderr)) LOG.debug("Timed out while waiting for command termination:\n%s", self.command) raise ex def _get_exit_status(self, time_left): raise NotImplementedError @property def exit_status(self): exit_status = self._exit_status if exit_status is None: exit_status = self.poll_exit_status() if exit_status is not None: self._exit_status = exit_status return exit_status @property def is_running(self): 
return self.exit_status is None def check_is_running(self): exit_status = self.exit_status if exit_status is not None: raise _exception.ShellProcessTeriminated( command=str(self.command), exit_status=int(exit_status), stdin=str_from_stream(self.stdin), stdout=str_from_stream(self.stdout), stderr=str_from_stream(self.stderr)) def check_stdin_is_opened(self): if self.stdin.closed: raise _exception.ShellStdinClosed( command=str(self.command), stdin=str_from_stream(self.stdin), stdout=str_from_stream(self.stdout), stderr=str_from_stream(self.stderr)) def send_all(self, data, **kwargs): self.communicate(stdin=data, **kwargs) self.stdin.flush() def receive_all(self, **kwargs): self.communicate(receive_all=True, **kwargs) def wait(self, timeout=None, receive_all=True, **kwargs): self.communicate(timeout=timeout, receive_all=receive_all, **kwargs) def communicate(self, stdin=None, stdout=True, stderr=True, timeout=None, receive_all=False, buffer_size=None): timeout = shell_process_timeout(timeout=timeout) # Avoid waiting for data in the first loop poll_interval = 0. streams = _io.select_opened_files([stdin and self.stdin, stdout and self.stdout, stderr and self.stderr]) while self._is_communicating(streams=streams, send=stdin, receive=receive_all): # Remove closed streams streams = _io.select_opened_files(streams) # Select ready streams read_ready, write_ready = _io.select_files( files=streams, timeout=poll_interval) if read_ready or write_ready: # Avoid waiting for data the next time poll_interval = 0. 
if self.stdin in write_ready: # Write data to remote STDIN stdin = self._write_to_stdin(stdin) if not stdin: streams.remove(self.stdin) if self.stdout in read_ready: # Read data from remote STDOUT stdout = self._read_from_stdout(buffer_size=buffer_size) if not stdout: streams.remove(self.stdout) if self.stderr in read_ready: # Read data from remote STDERR stderr = self._read_from_stderr(buffer_size=buffer_size) if not stderr: streams.remove(self.stderr) else: # Wait for data in the following loops poll_interval = min(self.poll_interval, self.check_timeout(timeout=timeout)) LOG.debug('Waiting for process (%s): %r', self.command, streams) def _is_communicating(self, streams, send, receive): if send and self.stdin in streams: return True elif receive and {self.stdout, self.stderr} & streams: return True else: return False def _write_to_stdin(self, data, check=True): """Write data to STDIN""" if check: self.check_stdin_is_opened() sent_bytes = self.stdin.write(data) if sent_bytes: return data[sent_bytes:] or None else: LOG.debug("%r closed by peer on %r", self.stdin, self) self.stdin.close() return data def _read_from_stdout(self, buffer_size=None): """Read data from remote stream""" # Read data from remote stream chunk = self.stdout.read(buffer_size) if chunk: return chunk else: LOG.debug("%r closed by peer on %r", self.stdout, self) self.stdout.close() return None def _read_from_stderr(self, buffer_size=None): """Read data from remote stream""" # Read data from remote stream chunk = self.stderr.read(buffer_size) if chunk: return chunk else: LOG.debug("%r closed by peer on %r", self.stderr, self) self.stderr.close() return None def check_timeout(self, timeout=None, now=None): time_left, timeout = get_time_left([self.timeout, timeout], now=now) if time_left <= 0.: ex = _exception.ShellTimeoutExpired( command=str(self.command), timeout=timeout and timeout.timeout or None, stdin=str_from_stream(self.stdin), stdout=str_from_stream(self.stdout), 
stderr=str_from_stream(self.stderr)) raise ex return time_left def check_exit_status(self, expected_status=0): exit_status = self.poll_exit_status() if exit_status is None: time_left = self.check_timeout() ex = _exception.ShellProcessNotTeriminated( command=str(self.command), time_left=time_left, stdin=self.stdin, stdout=self.stdout, stderr=self.stderr) raise ex exit_status = int(exit_status) if expected_status != exit_status: ex = _exception.ShellCommandFailed( command=str(self.command), exit_status=exit_status, stdin=str_from_stream(self.stdin), stdout=str_from_stream(self.stdout), stderr=str_from_stream(self.stderr)) raise ex def merge_dictionaries(*dictionaries): merged = {} for d in dictionaries: if d: merged.update(d) return merged def shell_process_timeout(timeout): if isinstance(timeout, ShellProcessTimeout): return timeout else: return ShellProcessTimeout(timeout=timeout) def get_time_left(timeouts, now=None): now = now or time.time() min_time_left = float(MAX_TIMEOUT) min_timeout = None for timeout in timeouts: if timeout is not None: timeout = shell_process_timeout(timeout=timeout) time_left = timeout.time_left(now=now) if time_left < min_time_left: min_time_left = time_left min_timeout = timeout return min_time_left, min_timeout class ShellProcessTimeout(object): timeout = MAX_TIMEOUT def __init__(self, timeout=None, start_time=None): if timeout is None: timeout = self.timeout else: self.timeout = float(timeout) start_time = start_time and float(start_time) or time.time() self.start_time = start_time self.end_time = start_time + timeout def __float__(self): return self.timeout def time_left(self, now=None): now = now or time.time() return self.end_time - now def is_expired(self, now=None): raise self.time_left(now=now) <= 0. 
def str_from_stream(stream):
    """Return ``str(stream)``, or None when there is no stream."""
    if stream is None:
        return None
    return str(stream)


def default_shell_command():
    """Return the default shell command taken from tobiko configuration."""
    from tobiko import config
    CONF = config.CONF
    # This used to read the ``sudo`` option, exactly like
    # default_sudo_command(), so ``shell=True`` prefixed commands with sudo
    # instead of a shell.
    # NOTE(review): assumes the [shell] group declares a ``command`` option;
    # confirm the option name against tobiko's shell configuration.
    return _command.shell_command(CONF.tobiko.shell.command)


def default_sudo_command():
    """Return the default sudo command taken from tobiko configuration."""
    from tobiko import config
    CONF = config.CONF
    return _command.shell_command(CONF.tobiko.shell.sudo)


def network_namespace_command(network_namespace, command):
    """Prefix ``command`` with ``/sbin/ip netns exec <network_namespace>``."""
    netns_prefix = _command.shell_command(
        ['/sbin/ip', 'netns', 'exec', network_namespace])
    return netns_prefix + command