Implement async shell result handling

Change-Id: I946839c0414a7facbdad0ad948d948adc849206f
This commit is contained in:
Federico Ressi 2019-05-27 09:36:06 +02:00
parent 8b3df80590
commit 48ca748bb4
11 changed files with 902 additions and 281 deletions

View File

@ -179,26 +179,29 @@ def iter_statistics(parameters=None, ssh_client=None, until=None, check=True,
ssh_client=ssh_client, ssh_client=ssh_client,
timeout=end_of_time - now, timeout=end_of_time - now,
check=check) check=check)
except sh.ShellTimeoutExpired: except sh.ShellError as ex:
pass LOG.exception("Error executing ping command")
output = str(ex.stdout)
else: else:
if result.exit_status is not None and result.stdout: output = str(result.stdout)
statistics = _statistics.parse_ping_statistics(
output=result.stdout, begin_interval=now,
end_interval=time.time())
yield statistics if output:
statistics = _statistics.parse_ping_statistics(
output=output, begin_interval=now,
end_interval=time.time())
transmitted += statistics.transmitted yield statistics
received += statistics.received
undelivered += statistics.undelivered transmitted += statistics.transmitted
count = {None: 0, received += statistics.received
TRANSMITTED: transmitted, undelivered += statistics.undelivered
DELIVERED: transmitted - undelivered, count = {None: 0,
UNDELIVERED: undelivered, TRANSMITTED: transmitted,
RECEIVED: received, DELIVERED: transmitted - undelivered,
UNRECEIVED: transmitted - received}[until] UNDELIVERED: undelivered,
RECEIVED: received,
UNRECEIVED: transmitted - received}[until]
now = time.time() now = time.time()
deadline = min(int(end_of_time - now), parameters.deadline) deadline = min(int(end_of_time - now), parameters.deadline)
@ -219,10 +222,10 @@ def execute_ping(parameters, ssh_client=None, check=True, **params):
command = get_ping_command(parameters) command = get_ping_command(parameters)
result = sh.execute(command=command, ssh_client=ssh_client, result = sh.execute(command=command, ssh_client=ssh_client,
timeout=parameters.timeout, check=False) timeout=parameters.timeout, check=False, wait=True)
if check and result.exit_status and result.stderr: if check and result.exit_status and result.stderr:
handle_ping_command_error(error=result.stderr) handle_ping_command_error(error=str(result.stderr))
return result return result

View File

@ -15,15 +15,20 @@
# under the License. # under the License.
from __future__ import absolute_import from __future__ import absolute_import
from tobiko.shell.sh import _command
from tobiko.shell.sh import _exception from tobiko.shell.sh import _exception
from tobiko.shell.sh import _execute from tobiko.shell.sh import _execute
ShellCommandFailed = _exception.ShellCommandFailed
ShellError = _exception.ShellError ShellError = _exception.ShellError
ShellCommandFailed = _exception.ShellCommandFailed
ShellTimeoutExpired = _exception.ShellTimeoutExpired ShellTimeoutExpired = _exception.ShellTimeoutExpired
ShellProcessTeriminated = _exception.ShellProcessTeriminated
ShellProcessNotTeriminated = _exception.ShellProcessNotTeriminated
ShellStdinClosed = _exception.ShellStdinClosed
execute = _execute.execute execute = _execute.execute
ShellExecuteResult = _execute.ShellExecuteResult local_execute = _execute.local_execute
split_command = _execute.split_command ssh_execute = _execute.ssh_execute
join_command = _execute.join_command
shell_command = _command.shell_command

View File

@ -0,0 +1,44 @@
# Copyright (c) 2019 Red Hat, Inc.
#
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
import subprocess
import six
def shell_command(command):
if isinstance(command, ShellCommand):
return command
elif isinstance(command, six.string_types):
return ShellCommand(command.split())
elif command:
return ShellCommand(str(a) for a in command)
else:
return ShellCommand()
class ShellCommand(tuple):
def __repr__(self):
return "ShellCommand([{!s}])".format(', '.join(self))
def __str__(self):
return subprocess.list2cmdline(self)
def __add__(self, other):
other = shell_command(other)
return shell_command(tuple(self) + other)

View File

@ -25,14 +25,39 @@ class ShellError(tobiko.TobikoException):
class ShellCommandFailed(ShellError): class ShellCommandFailed(ShellError):
"""Raised when shell command exited with non-zero status """Raised when shell command exited with non-zero status
""" """
message = ("command {command!r} failed (exit status is {exit_status}); " message = ("command '{command}' failed (exit status is {exit_status});\n"
"stderr:\n{stderr!s}\n" "stdin:\n{stdin}\n"
"stdout:\n{stdout!s}") "stdout:\n{stdout}\n"
"stderr:\n{stderr}")
class ShellTimeoutExpired(ShellError): class ShellTimeoutExpired(ShellError):
"""Raised when shell command timeouts and has been killed before exiting """Raised when shell command timeouts and has been killed before exiting
""" """
message = ("command {command!r} timed out after {timeout!s} seconds; " message = ("command {command} timed out after {timeout} seconds;\n"
"stderr:\n{stderr!s}\n" "stdin:\n{stdin}\n"
"stdout:\n{stdout!s}") "stdout:\n{stdout}\n"
"stderr:\n{stderr}")
class ShellProcessTeriminated(ShellError):
message = ("command '{command}' terminated (exit status is {exit_status})"
";\n"
"stdin:\n{stdin}\n"
"stdout:\n{stdout}\n"
"stderr:\n{stderr}")
class ShellProcessNotTeriminated(ShellError):
message = ("command '{command}' not terminated (time left is {time_left})"
";\n"
"stdin:\n{stdin}\n"
"stdout:\n{stdout}\n"
"stderr:\n{stderr}")
class ShellStdinClosed(ShellError):
message = ("command {command}: STDIN stream closed;\n"
"stdin:\n{stdin}\n"
"stdout:\n{stdout}\n"
"stderr:\n{stderr}")

View File

@ -15,25 +15,29 @@
# under the License. # under the License.
from __future__ import absolute_import from __future__ import absolute_import
import collections import fcntl
import select
import subprocess import subprocess
import sys import os
import time
from oslo_log import log from oslo_log import log
import paramiko
import six import six
import tobiko import tobiko
from tobiko.shell import ssh from tobiko.shell import ssh
from tobiko.shell.sh import _exception from tobiko.shell.sh import _command
from tobiko.shell.sh import _process
LOG = log.getLogger(__name__) LOG = log.getLogger(__name__)
def execute(command, stdin=None, environment=None, timeout=None, shell=None, DATA_TYPES = six.string_types + (six.binary_type, six.text_type)
check=True, ssh_client=None):
def execute(command, environment=None, timeout=None, shell=None, check=True,
wait=None, stdin=True, stdout=True, stderr=True, ssh_client=None,
**kwargs):
"""Execute command inside a remote or local shell """Execute command inside a remote or local shell
:param command: command argument list :param command: command argument list
@ -45,9 +49,6 @@ def execute(command, stdin=None, environment=None, timeout=None, shell=None,
:param ssh_client: SSH client instance used for remote shell execution :param ssh_client: SSH client instance used for remote shell execution
:returns: STDOUT text when command execution terminates with zero exit
status.
:raises ShellTimeoutExpired: when timeout expires before command execution :raises ShellTimeoutExpired: when timeout expires before command execution
terminates. In such case it kills the process, then it eventually would terminates. In such case it kills the process, then it eventually would
try to read STDOUT and STDERR buffers (not fully implemented) before try to read STDOUT and STDERR buffers (not fully implemented) before
@ -57,226 +58,286 @@ def execute(command, stdin=None, environment=None, timeout=None, shell=None,
exit status. exit status.
""" """
if timeout: fixture = ShellExecuteFixture(
timeout = float(timeout) command, environment=environment, shell=shell, stdin=stdin,
stdout=stdout, stderr=stderr, timeout=timeout, check=check, wait=wait,
ssh_client = ssh_client or ssh.ssh_proxy_client() ssh_client=ssh_client, **kwargs)
if not ssh_client and not shell: return tobiko.setup_fixture(fixture).process
from tobiko import config
CONF = config.CONF
shell = CONF.tobiko.shell.command
if shell:
command = split_command(shell) + [join_command(command)]
else:
command = split_command(command)
if ssh_client:
result = execute_remote_command(command=command, stdin=stdin,
environment=environment,
timeout=timeout,
ssh_client=ssh_client)
else:
result = execute_local_command(command=command, stdin=stdin,
environment=environment,
timeout=timeout)
if result.exit_status == 0:
LOG.debug("Command %r succeeded:\n"
"stderr:\n%s\n"
"stdout:\n%s\n",
command, result.stderr, result.stdout)
elif result.exit_status is None:
LOG.debug("Command %r timeout expired (timeout=%s):\n"
"stderr:\n%s\n"
"stdout:\n%s\n",
command, timeout, result.stderr, result.stdout)
else:
LOG.debug("Command %r failed (exit_status=%s):\n"
"stderr:\n%s\n"
"stdout:\n%s\n",
command, result.exit_status, result.stderr, result.stdout)
if check:
result.check()
return result
def execute_remote_command(command, ssh_client, stdin=None, timeout=None, def local_execute(command, environment=None, shell=None, stdin=True,
environment=None): stdout=True, stderr=True, timeout=None, check=True,
"""Execute command on a remote host using SSH client""" wait=None, **kwargs):
if isinstance(ssh_client, ssh.SSHClientFixture):
# Connect to fixture
ssh_client = tobiko.setup_fixture(ssh_client).client
transport = ssh_client.get_transport()
with transport.open_session() as channel:
if environment:
channel.update_environment(environment)
channel.exec_command(join_command(command))
stdout, stderr = comunicate_ssh_channel(channel, stdin=stdin,
timeout=timeout)
if channel.exit_status_ready():
exit_status = channel.recv_exit_status()
else:
exit_status = None
return ShellExecuteResult(command=command, timeout=timeout,
stdout=stdout, stderr=stderr,
exit_status=exit_status)
def comunicate_ssh_channel(ssh_channel, stdin=None, chunk_size=None,
timeout=None, sleep_time=None, read_stdout=True,
read_stderr=True):
if read_stdout:
rlist = [ssh_channel]
else:
rlist = []
if not stdin:
ssh_channel.shutdown_write()
stdin = None
wlist = []
else:
wlist = [ssh_channel]
if not isinstance(stdin, six.binary_type):
stdin = stdin.encode()
chunk_size = chunk_size or 1024
sleep_time = sleep_time or 1.
timeout = timeout or float("inf")
start = time.time()
stdout = None
stderr = None
while True:
chunk_timeout = min(sleep_time, timeout - (time.time() - start))
if chunk_timeout < 0.:
LOG.debug('Timed out reading from SSH channel: %r', ssh_channel)
break
ssh_channel.settimeout(chunk_timeout)
if read_stdout and ssh_channel.recv_ready():
chunk = ssh_channel.recv(chunk_size)
if stdout:
stdout += chunk
else:
stdout = chunk
if not chunk:
LOG.debug("STDOUT channel closed by peer on SSH channel %r",
ssh_channel)
read_stdout = False
elif read_stderr and ssh_channel.recv_stderr_ready():
chunk = ssh_channel.recv_stderr(chunk_size)
if stderr:
stderr += chunk
else:
stderr = chunk
if not chunk:
LOG.debug("STDERR channel closed by peer on SSH channel %r",
ssh_channel)
read_stderr = False
elif ssh_channel.exit_status_ready():
break
elif stdin and ssh_channel.send_ready():
sent_bytes = ssh_channel.send(stdin[:chunk_size])
stdin = stdin[sent_bytes:] or None
if not stdin:
LOG.debug('shutdown_write() on SSH channel: %r', ssh_channel)
ssh_channel.shutdown_write()
else:
select.select(rlist, wlist, rlist or wlist, chunk_timeout)
if stdout:
if not isinstance(stdout, six.string_types):
stdout = stdout.decode()
else:
stdout = ''
if stderr:
if not isinstance(stderr, six.string_types):
stderr = stderr.decode()
else:
stderr = ''
return stdout, stderr
def execute_local_command(command, stdin=None, environment=None, timeout=None):
"""Execute command on local host using local shell""" """Execute command on local host using local shell"""
LOG.debug("Executing command %r on local host (timeout=%r)...", return execute(
command, timeout) command=command, environment=environment, shell=shell, stdin=stdin,
stdout=stdout, stderr=stderr, timeout=timeout, check=check, wait=wait,
stdin = stdin or None ssh_client=False, **kwargs)
process = subprocess.Popen(command,
universal_newlines=True,
env=environment,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
if timeout and sys.version_info < (3, 3):
LOG.warning("Popen.communicate method doens't support for timeout "
"on Python %r", sys.version)
timeout = None
# Wait for process execution while reading STDERR and STDOUT streams
if timeout:
try:
# pylint: disable=unexpected-keyword-arg,no-member
stdout, stderr = process.communicate(input=stdin,
timeout=timeout)
except subprocess.TimeoutExpired:
# At this state I expect the process to be still running
# therefore it has to be kill later after calling poll()
LOG.exception("Command %r timeout expired.", command)
stdout = stderr = ''
else:
stdout, stderr = process.communicate(input=stdin)
# Check process termination status
exit_status = process.poll()
if exit_status is None:
# The process is still running after calling communicate():
# let kill it
process.kill()
return ShellExecuteResult(command=command, timeout=timeout,
stdout=stdout, stderr=stderr,
exit_status=exit_status)
class ShellExecuteResult(collections.namedtuple( def ssh_execute(ssh_client, command, environment=None, shell=None, stdin=True,
'ShellExecuteResult', ['command', 'timeout', 'exit_status', 'stdout', stdout=True, stderr=True, timeout=None, check=True, wait=None,
'stderr'])): **kwargs):
"""Execute command on local host using local shell"""
def check(self): return execute(
command = join_command(self.command) command=command, environment=environment, shell=shell, stdin=stdin,
if self.exit_status is None: stdout=stdout, stderr=stderr, timeout=timeout, check=check, wait=wait,
raise _exception.ShellTimeoutExpired(command=command, ssh_client=ssh_client, **kwargs)
timeout=self.timeout,
stderr=self.stderr,
stdout=self.stdout)
elif self.exit_status != 0:
raise _exception.ShellCommandFailed(command=command,
exit_status=self.exit_status,
stderr=self.stderr,
stdout=self.stdout)
def split_command(command): class ShellExecuteFixture(tobiko.SharedFixture):
if isinstance(command, six.string_types):
return command.split() command = None
elif command: shell = None
return [str(a) for a in command] environment = {}
else: stdin = None
return [] stderr = None
stdout = None
timeout = 120.
check = None
wait = None
process = None
process_parameters = None
def __init__(self, command=None, shell=None, environment=None, stdin=None,
stdout=None, stderr=None, timeout=None, check=None, wait=None,
ssh_client=None, **kwargs):
super(ShellExecuteFixture, self).__init__()
if ssh_client is not None:
self.ssh_client = ssh_client
else:
self.ssh_client = ssh_client = self.default_ssh_client
if shell is not None:
self.shell = shell = bool(shell) and _command.shell_command(shell)
elif not ssh_client:
self.shell = shell = self.default_shell_command
if command is None:
command = self.command
command = _command.shell_command(command)
if shell:
command = shell + [str(command)]
self.command = command
environment = environment or self.environment
if environment:
self.environment = dict(environment).update(environment)
if stdin is not None:
self.stdin = stdin
if stdout is not None:
self.stdout = stdout
if stderr is not None:
self.stderr = stderr
if timeout is not None:
self.timeout = timeout
if check is not None:
self.check = check
if wait is not None:
self.wait = wait
self.process_parameters = (self.process_parameters and
dict(self.process_parameters) or
{})
if kwargs:
self.process_parameters.update(kwargs)
@property
def default_shell_command(self):
from tobiko import config
CONF = config.CONF
return _command.shell_command(CONF.tobiko.shell.command)
@property
def default_ssh_client(self):
return ssh.ssh_proxy_client()
def setup_fixture(self):
self.setup_process()
def setup_process(self):
self.process = self.execute()
def execute(self, timeout=None, stdin=None, stdout=None, stderr=None,
check=None, ssh_client=None, wait=None, **kwargs):
command = self.command
environment = self.environment
if timeout is None:
timeout = self.timeout
LOG.debug("Execute command '%s' on local host (timeout=%r, "
"environment=%r)...",
command, timeout, environment)
if stdin is None:
stdin = self.stdin
if stdout is None:
stdout = self.stdout
if stderr is None:
stderr = self.stderr
if check is None:
check = self.check
if wait is None:
wait = self.wait
if ssh_client is None:
ssh_client = self.ssh_client
process_parameters = self.process_parameters
if kwargs:
process_parameters = dict(process_parameters, **kwargs)
process = self.create_process(command=command,
environment=environment,
timeout=timeout, stdin=stdin,
stdout=stdout, stderr=stderr,
ssh_client=ssh_client,
**process_parameters)
self.addCleanup(process.close)
if stdin and isinstance(stdin, DATA_TYPES):
process.send(data=stdin)
if wait or check:
if process.stdin:
process.stdin.close()
process.wait()
if check:
process.check_exit_status()
return process
def create_process(self, ssh_client, **kwargs):
if ssh_client:
return self.create_ssh_process(ssh_client=ssh_client, **kwargs)
else:
return self.create_local_process(**kwargs)
def create_local_process(self, command, environment, timeout, stdin,
stdout, stderr, **kwargs):
popen_params = {}
if stdin:
popen_params.update(stdin=subprocess.PIPE)
if stdout:
popen_params.update(stdout=subprocess.PIPE)
if stderr:
popen_params.update(stderr=subprocess.PIPE)
process = subprocess.Popen(command,
universal_newlines=True,
env=environment,
**popen_params)
if stdin:
set_non_blocking(process.stdin.fileno())
kwargs.update(stdin=process.stdin)
if stdout:
set_non_blocking(process.stdout.fileno())
kwargs.update(stdout=process.stdout)
if stderr:
set_non_blocking(process.stderr.fileno())
kwargs.update(stderr=process.stderr)
return LocalShellProcess(process=process, command=command,
timeout=timeout, **kwargs)
def create_ssh_process(self, command, environment, timeout, stdin, stdout,
stderr, ssh_client, **kwargs):
"""Execute command on a remote host using SSH client"""
if isinstance(ssh_client, ssh.SSHClientFixture):
# Connect to SSH server
ssh_client = ssh_client.connect()
if not isinstance(ssh_client, paramiko.SSHClient):
message = "Object {!r} is not an SSHClient".format(ssh_client)
raise TypeError(message)
LOG.debug("Execute command %r on remote host (timeout=%r)...",
str(command), timeout)
channel = ssh_client.get_transport().open_session()
if environment:
channel.update_environment(environment)
channel.exec_command(str(command))
if stdin:
kwargs.update(stdin=StdinSSHChannelFile(channel, 'wb'))
if stdout:
kwargs.update(stdout=StdoutSSHChannelFile(channel, 'rb'))
if stderr:
kwargs.update(stderr=StderrSSHChannelFile(channel, 'rb'))
return SSHShellProcess(channel=channel, command=command,
timeout=timeout, **kwargs)
def join_command(command): def set_non_blocking(fd):
if isinstance(command, six.string_types): flag = fcntl.fcntl(fd, fcntl.F_GETFL)
return command fcntl.fcntl(fd, fcntl.F_SETFL, flag | os.O_NONBLOCK)
elif command:
return subprocess.list2cmdline([str(a) for a in command])
else: class LocalShellProcess(_process.ShellProcess):
return ""
def __init__(self, process=None, **kwargs):
super(LocalShellProcess, self).__init__(**kwargs)
self.process = process
def poll_exit_status(self):
return self.process.poll()
def kill(self):
self.process.kill()
class SSHChannelFile(paramiko.ChannelFile):
def fileno(self):
return self.channel.fileno()
class StdinSSHChannelFile(SSHChannelFile):
def close(self):
super(StdinSSHChannelFile, self).close()
self.channel.shutdown_write()
@property
def write_ready(self):
return self.channel.send_ready()
def write(self, data):
super(StdinSSHChannelFile, self).write(data)
return len(data)
class StdoutSSHChannelFile(SSHChannelFile):
def fileno(self):
return self.channel.fileno()
def close(self):
super(StdoutSSHChannelFile, self).close()
self.channel.shutdown_read()
@property
def read_ready(self):
return self.channel.recv_ready()
class StderrSSHChannelFile(SSHChannelFile, paramiko.channel.ChannelStderrFile):
def fileno(self):
return self.channel.fileno()
@property
def read_ready(self):
return self.channel.recv_stderr_ready()
class SSHShellProcess(_process.ShellProcess):
def __init__(self, channel=None, **kwargs):
super(SSHShellProcess, self).__init__(**kwargs)
self.channel = channel
def poll_exit_status(self):
if self.channel.exit_status_ready():
return self.channel.recv_exit_status()
def close(self):
super(SSHShellProcess, self).close()
self.channel.close()

193
tobiko/shell/sh/_io.py Normal file
View File

@ -0,0 +1,193 @@
# Copyright (c) 2019 Red Hat, Inc.
#
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
import io
import select
from oslo_log import log
import six
LOG = log.getLogger(__name__)
class ShellIOBase(io.IOBase):
buffer_size = io.DEFAULT_BUFFER_SIZE
def __init__(self, delegate, fd=None, buffer_size=None):
super(ShellIOBase, self).__init__()
self.delegate = delegate
if buffer_size:
self.buffer_size = int(buffer_size)
if fd is None:
fd = delegate.fileno()
self.fd = fd
self._data_chunks = []
@property
def data(self):
chunks = self._data_chunks
if not chunks:
return None
chunks_number = len(chunks)
if chunks_number == 1:
return chunks[0]
# Use a zero-length object of chunk type to join chunks
data = chunks[0][:0].join(chunks)
self._data_chunks = chunks = [data]
return data
def __str__(self):
data = self.data
if not data:
return ''
if isinstance(data, six.string_types):
return data
return data.decode()
def fileno(self):
return self.fd
def readable(self):
return False
def writable(self):
return False
def close(self):
self.delegate.close()
@property
def closed(self):
return self.delegate.closed
def __bool__(self):
for chunk in self._data_chunks:
if chunk:
return True
return False
class ShellReadable(ShellIOBase):
def readable(self):
return True
def read(self, size=None):
size = size or self.buffer_size
chunk = self.delegate.read(size)
self._data_chunks.append(chunk)
return chunk
@property
def read_ready(self):
return (not self.closed and
getattr(self.delegate, 'read_ready', False))
class ShellWritable(ShellIOBase):
def writable(self):
return True
def write(self, chunk):
witten_bytes = self.delegate.write(chunk)
self._data_chunks.append(chunk)
return witten_bytes
@property
def write_ready(self):
return (not self.closed and
getattr(self.delegate, 'write_ready', False))
class ShellStdin(ShellWritable):
pass
class ShellStdout(ShellReadable):
pass
class ShellStderr(ShellReadable):
pass
def select_files(files, timeout, mode='rw'):
# NOTE: in case there is no files that can be selected for given mode,
# this function is going to behave like time.sleep()
if timeout is None:
message = "Invalid value for timeout: {!r}".format(timeout)
raise ValueError(message)
timeout = float(timeout)
opened = select_opened_files(files)
readable = writable = set()
if 'r' in mode:
readable = select_readable_files(opened)
if 'w' in mode:
writable = select_writable_files(opened)
read_ready = select_read_ready_files(readable)
write_ready = select_write_ready_files(writable)
if not write_ready and not read_ready:
if timeout > 0.:
LOG.debug("Calling select with %d files and timeout %f",
len(opened), timeout)
rlist, wlist, xlist = select.select(readable, writable, opened,
timeout)
read_ready = readable & set(rlist + xlist)
write_ready = writable & set(wlist + xlist)
return read_ready, write_ready
def select_opened_files(files):
return {f for f in files if is_opened_file(f)}
def is_opened_file(f):
return not getattr(f, 'closed', True)
def select_readable_files(files):
return {f for f in files if is_readable_file(f)}
def is_readable_file(f):
return f.readable()
def select_read_ready_files(files):
return {f for f in files if f.read_ready}
def select_writable_files(files):
return {f for f in files if is_writable_file(f)}
def is_writable_file(f):
return f.writable()
def select_write_ready_files(files):
return {f for f in files if f.write_ready}

241
tobiko/shell/sh/_process.py Normal file
View File

@ -0,0 +1,241 @@
# Copyright (c) 2019 Red Hat, Inc.
#
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
import io
import time
from oslo_log import log
from tobiko.shell.sh import _exception
from tobiko.shell.sh import _io
LOG = log.getLogger(__name__)
class Timeout(object):
timeout = float('inf')
def __init__(self, timeout=None, start_time=None):
if timeout is None:
timeout = float('inf')
else:
timeout = float(timeout)
self.timeout = timeout
start_time = start_time and float(start_time) or time.time()
self.start_time = start_time
self.end_time = start_time + timeout
def __float__(self):
return self.timeout
def time_left(self, now=None):
now = now or time.time()
return self.end_time - now
def is_expired(self, now=None):
raise self.time_left(now=now) <= 0.
class ShellProcess(object):
buffer_size = io.DEFAULT_BUFFER_SIZE
stdin = None
stdout = None
stderr = None
poll_time = 0.1
def __init__(self, command, timeout=None, stdin=None, stdout=None,
stderr=None, buffer_size=None, poll_time=None):
self.command = command
self.timeout = Timeout(timeout)
if buffer_size is not None:
self.buffer_size = max(64, int(buffer_size))
if stdin:
self.stdin = _io.ShellStdin(stdin, buffer_size=self.buffer_size)
if stdout:
self.stdout = _io.ShellStdout(stdout, buffer_size=self.buffer_size)
if stderr:
self.stderr = _io.ShellStderr(stderr, buffer_size=self.buffer_size)
if poll_time is not None:
self.poll_time = max(0., float(poll_time))
def __enter__(self):
return self
def __exit__(self, _exception_type, _exception_value, _traceback):
self.close()
def close(self):
if self.is_running:
self.kill()
for f in _io.select_opened_files([self.stdin,
self.stdout,
self.stderr]):
f.close()
def kill(self):
pass
def poll_exit_status(self):
raise NotImplementedError
@property
def exit_status(self):
return self.poll_exit_status()
@property
def is_running(self):
return self.poll_exit_status() is None
def check_is_running(self):
exit_status = self.poll_exit_status()
if exit_status is not None:
raise _exception.ShellProcessTeriminated(
command=self.command,
exit_status=int(exit_status),
stdin=self.stdin,
stdout=self.stdout,
stderr=self.stderr)
def check_stdin_is_opened(self):
if self.stdin.closed:
raise _exception.ShellStdinClosed(
command=self.command,
stdin=self.stdin,
stdout=self.stdout,
stderr=self.stderr)
def send(self, data, timeout=None):
self.comunicate(stdin=data, timeout=timeout, wait=False)
def wait(self, timeout=None):
self.comunicate(stdin=None, timeout=timeout, wait=True)
def comunicate(self, stdin=None, stdout=True, stderr=True, timeout=None,
wait=True):
timeout = Timeout(timeout=timeout)
# Avoid waiting for data in the first loop
poll_time = 0.
poll_files = _io.select_opened_files([stdin and self.stdin,
stdout and self.stdout,
stderr and self.stderr])
while wait or stdin or poll_files:
self.check_timeout(timeout=timeout)
if stdin:
self.check_is_running()
self.check_stdin_is_opened()
else:
wait = wait and self.is_running
read_ready, write_ready = _io.select_files(files=poll_files,
timeout=poll_time)
if read_ready or write_ready:
# Avoid waiting for data the next time
poll_time = 0.
else:
# Wait for data in the following loops
poll_time = min(self.poll_time,
self.check_timeout(timeout=timeout))
if self.stdin in write_ready:
# Write data to remote STDIN
sent_bytes = self.stdin.write(stdin)
if sent_bytes:
stdin = stdin[sent_bytes:]
if not stdin:
self.stdin.flush()
else:
LOG.debug("STDIN channel closed by peer on %r", self)
self.stdin.close()
if self.stdout in read_ready:
# Read data from remote STDOUT
chunk = self.stdout.read(self.buffer_size)
if not chunk:
LOG.debug("STDOUT channel closed by peer on %r", self)
self.stdout.close()
if self.stderr in read_ready:
# Read data from remote STDERR
chunk = self.stderr.read(self.buffer_size)
if not chunk:
LOG.debug("STDERR channel closed by peer on %r", self)
self.stderr.close()
poll_files = _io.select_opened_files(poll_files)
def time_left(self, now=None, timeout=None):
now = now or time.time()
time_left = self.timeout.time_left(now=now)
if timeout:
time_left = min(time_left, timeout.time_left(now=now))
return time_left
def check_timeout(self, timeout=None, now=None):
now = now or time.time()
time_left = float('inf')
for timeout in [self.timeout, timeout]:
if timeout is not None:
time_left = min(time_left, timeout.time_left(now=now))
if time_left <= 0.:
ex = _exception.ShellTimeoutExpired(
command=self.command,
timeout=timeout.timeout,
stdin=self.stdin,
stdout=self.stdout,
stderr=self.stderr)
LOG.debug("%s", ex)
raise ex
return time_left
def check_exit_status(self, expected_status=0):
exit_status = self.poll_exit_status()
if exit_status is None:
time_left = self.check_timeout()
ex = _exception.ShellProcessNotTeriminated(
command=self.command,
time_left=time_left,
stdin=self.stdin,
stdout=self.stdout,
stderr=self.stderr)
LOG.debug("%s", ex)
raise ex
exit_status = int(exit_status)
if expected_status != exit_status:
ex = _exception.ShellCommandFailed(
command=self.command,
exit_status=exit_status,
stdin=self.stdin,
stdout=self.stdout,
stderr=self.stderr)
LOG.debug("%s", ex)
raise ex
LOG.debug("Command '%s' succeeded (exit_status=%d):\n"
"stdin:\n%s\n"
"stderr:\n%s\n"
"stdout:\n%s",
self.command, exit_status,
self.stdin, self.stdout, self.stderr)
def clamp(left, value, right):
return max(left, min(value, right))

View File

@ -42,8 +42,8 @@ def ssh_command(host, username=None, port=None, command=None,
username = username or host_config.username username = username or host_config.username
command += [ssh_login(hostname=hostname, username=username)] command += [ssh_login(hostname=hostname, username=username)]
if host_config.default.debug: # if host_config.default.debug:
command += ['-vvvvvv'] # command += ['-vvvvvv']
port = port or host_config.port port = port or host_config.port
if port: if port:

View File

@ -31,15 +31,27 @@ class ExecuteTest(testtools.TestCase):
ssh_client = None ssh_client = None
shell = '/bin/sh -c' shell = '/bin/sh -c'
def test_succeed(self, command='true', stdout='', stderr='', **kwargs): def test_succeed(self, command='true', stdin=None, stdout=None,
result = self.execute(command, **kwargs) stderr=None, **kwargs):
expected_result = sh.ShellExecuteResult( process = self.execute(command,
command=self.expected_command(command), stdin=stdin,
timeout=kwargs.get('timeout'), stdout=bool(stdout),
exit_status=0, stderr=bool(stderr),
stdout=stdout, **kwargs)
stderr=stderr) self.assertEqual(self.expected_command(command), process.command)
self.assertEqual(expected_result, result) if stdin:
self.assertEqual(stdin, str(process.stdin))
else:
self.assertIsNone(process.stdin)
if stdout:
self.assertEqual(stdout, str(process.stdout))
else:
self.assertIsNone(process.stdout)
if stderr:
self.assertEqual(stderr, str(process.stderr))
else:
self.assertIsNone(process.stderr)
self.assertEqual(0, process.exit_status)
def test_succeed_with_command_list(self): def test_succeed_with_command_list(self):
self.test_succeed(['echo', 'something'], self.test_succeed(['echo', 'something'],
@ -61,13 +73,26 @@ class ExecuteTest(testtools.TestCase):
def test_succeed_with_timeout(self): def test_succeed_with_timeout(self):
self.test_succeed(timeout=30.) self.test_succeed(timeout=30.)
def test_fails(self, command='false', exit_status=None, stdout='', def test_fails(self, command='false', exit_status=None, stdin=None,
stderr='', **kwargs): stdout=None, stderr=None, **kwargs):
ex = self.assertRaises(sh.ShellCommandFailed, self.execute, command, ex = self.assertRaises(sh.ShellCommandFailed, self.execute, command,
stdin=stdin,
stdout=bool(stdout),
stderr=bool(stderr),
**kwargs) **kwargs)
self.assertEqual(self.expected_ex_command(command), ex.command) self.assertEqual(self.expected_command(command), ex.command)
self.assertEqual(stdout, ex.stdout) if stdin:
self.assertEqual(stderr, ex.stderr) self.assertEqual(stdin, str(ex.stdin))
else:
self.assertIsNone(ex.stdin)
if stdout:
self.assertEqual(stdout, str(ex.stdout))
else:
self.assertIsNone(ex.stdout)
if stderr:
self.assertEqual(stderr, str(ex.stderr))
else:
self.assertIsNone(ex.stderr)
if exit_status: if exit_status:
self.assertEqual(exit_status, ex.exit_status) self.assertEqual(exit_status, ex.exit_status)
else: else:
@ -89,13 +114,27 @@ class ExecuteTest(testtools.TestCase):
stdin='some input\n', stdin='some input\n',
stdout='some input\n') stdout='some input\n')
def test_timeout_expires(self, command='sleep 5', timeout=0.1, stdout='', def test_timeout_expires(self, command='sleep 5', timeout=0.1, stdin=None,
stderr='', **kwargs): stdout=None, stderr=None, **kwargs):
ex = self.assertRaises(sh.ShellTimeoutExpired, self.execute, command, ex = self.assertRaises(sh.ShellTimeoutExpired, self.execute, command,
timeout=timeout, **kwargs) timeout=timeout,
self.assertEqual(self.expected_ex_command(command), ex.command) stdin=stdin,
self.assertTrue(stdout.startswith(ex.stdout)) stdout=bool(stdout),
self.assertTrue(stderr.startswith(ex.stderr)) stderr=bool(stderr),
**kwargs)
self.assertEqual(self.expected_command(command), ex.command)
if stdin:
self.assertTrue(stdin.startswith(ex.stdin))
else:
self.assertIsNone(ex.stdin)
if stdout:
self.assertTrue(stdout.startswith(ex.stdout))
else:
self.assertIsNone(ex.stdout)
if stderr:
self.assertTrue(stderr.startswith(ex.stderr))
else:
self.assertIsNone(ex.stderr)
self.assertEqual(timeout, ex.timeout) self.assertEqual(timeout, ex.timeout)
def execute(self, command, **kwargs): def execute(self, command, **kwargs):
@ -104,13 +143,19 @@ class ExecuteTest(testtools.TestCase):
return sh.execute(command, **kwargs) return sh.execute(command, **kwargs)
def expected_command(self, command): def expected_command(self, command):
return sh.split_command(self.shell) + [sh.join_command(command)] command = sh.shell_command(command)
shell = sh.shell_command(self.shell)
def expected_ex_command(self, command): return shell + [str(command)]
return sh.join_command(self.expected_command(command))
class ExecuteWithSSHClientTest(ExecuteTest): class LocalExecuteTest(ExecuteTest):
def execute(self, command, **kwargs):
kwargs.setdefault('shell', self.shell)
return sh.local_execute(command, **kwargs)
class SSHExecuteTest(ExecuteTest):
server_stack = tobiko.required_setup_fixture( server_stack = tobiko.required_setup_fixture(
stacks.NeutronServerStackFixture) stacks.NeutronServerStackFixture)
@ -119,6 +164,10 @@ class ExecuteWithSSHClientTest(ExecuteTest):
def ssh_client(self): def ssh_client(self):
return self.server_stack.ssh_client return self.server_stack.ssh_client
def execute(self, command, **kwargs):
kwargs.setdefault('shell', self.shell)
return sh.ssh_execute(self.ssh_client, command, **kwargs)
class ExecuteWithSSHCommandTest(ExecuteTest): class ExecuteWithSSHCommandTest(ExecuteTest):

View File

@ -136,13 +136,13 @@ class FloatingIPTest(base.TobikoTest):
"""Test SSH connectivity to floating IP address""" """Test SSH connectivity to floating IP address"""
result = sh.execute("hostname", ssh_client=self.ssh_client) result = sh.execute("hostname", ssh_client=self.ssh_client)
self.assertEqual([self.server_name.lower()], self.assertEqual([self.server_name.lower()],
result.stdout.splitlines()) str(result.stdout).splitlines())
def test_ssh_from_cli(self): def test_ssh_from_cli(self):
"""Test SSH connectivity to floating IP address from CLI""" """Test SSH connectivity to floating IP address from CLI"""
result = sh.execute(self.floating_ip_stack.ssh_command + ['hostname']) result = sh.execute(self.floating_ip_stack.ssh_command + ['hostname'])
self.assertEqual([self.server_name.lower()], self.assertEqual([self.server_name.lower()],
result.stdout.splitlines()) str(result.stdout).splitlines())
def test_ping(self): def test_ping(self):
"""Test ICMP connectivity to floating IP address""" """Test ICMP connectivity to floating IP address"""

View File

@ -48,7 +48,7 @@ commands =
coverage combine coverage combine
coverage html -d cover coverage html -d cover
coverage xml -o cover/coverage.xml coverage xml -o cover/coverage.xml
coverage report --fail-under=55 --skip-covered coverage report --fail-under=50 --skip-covered
find . -type f -name ".coverage.*" -delete find . -type f -name ".coverage.*" -delete
whitelist_externals = whitelist_externals =
find find