Sshutils refactoring

Paramiko sftp client is not eventlet friendly. Also it doesn't work
with some servers, so upload/download methods has been removed.
Downloading or uploading can be done in other way (see examples).

Added ability to pass process stdin, which makes possible to
do on remote host anything at all.

Much more clean and usable API.

Change-Id: If357a2878c2c60646a975c386a0fe2f1616aec95
This commit is contained in:
Sergey Skripnick 2013-12-24 15:29:15 +02:00
parent 3d420c286b
commit 936ba2efd8
10 changed files with 557 additions and 382 deletions

View File

@ -20,7 +20,6 @@ import random
from rally.benchmark.scenarios.cinder import utils as cinder_utils
from rally.benchmark.scenarios.nova import utils
from rally.benchmark.scenarios import utils as scenario_utils
from rally.benchmark import utils as benchmark_utils
from rally.benchmark import validation
from rally import exceptions as rally_exceptions
from rally.openstack.common.gettextutils import _ # noqa
@ -102,43 +101,23 @@ class NovaServers(utils.NovaScenario,
)
server_ip = [ip for ip in server.addresses[network] if
ip['version'] == ip_version][0]['addr']
ssh = sshutils.SSH(ip=server_ip, port=port, user=username,
key=self.clients('ssh_key_pair')['private'],
key_type='string')
for retry in range(retries):
try:
LOG.debug(_('Execute script on server attempt '
'%(retry)i/%(retries)i') % dict(retry=retry,
retries=retries))
streams = list(ssh.execute_script(script=script,
interpreter=interpreter,
get_stdout=True,
get_stderr=True))
#NOTE(hughsaunders): Decode JSON script output
streams[sshutils.SSH.STDOUT_INDEX]\
= json.loads(streams[sshutils.SSH.STDOUT_INDEX])
break
except (rally_exceptions.SSHError,
rally_exceptions.TimeoutException, IOError) as e:
LOG.debug(_('Error running script on instance via SSH. '
'%(id)s/%(ip)s Attempt:%(retry)i, '
'Error: %(error)s') % dict(
id=server.id, ip=server_ip, retry=retry,
error=benchmark_utils.format_exc(e)))
self.sleep_between(5, 5)
except ValueError:
LOG.error(_('Script %(script)s did not output valid JSON. ')
% dict(script=script))
ssh = sshutils.SSH(username, server_ip, port=port,
pkey=self.clients('ssh_key_pair')['private'])
ssh.wait()
code, out, err = ssh.execute(interpreter, stdin=open(script, 'rb'))
if code:
LOG.error(_('Error running script on instance via SSH. '
'Error: %s') % err)
try:
out = json.loads(out)
except ValueError:
LOG.warning(_('Script %s did not output valid JSON. ') % script)
self._delete_server(server)
LOG.debug(_('Output streams from in-instance script execution: '
'stdout: %(stdout)s, stderr: $(stderr)s') % dict(
stdout=str(streams[sshutils.SSH.STDOUT_INDEX]),
stderr=str(streams[sshutils.SSH.STDERR_INDEX])))
return dict(data=streams[sshutils.SSH.STDOUT_INDEX],
errors=streams[sshutils.SSH.STDERR_INDEX])
stdout=out, stderr=err))
return {'data': out, 'errors': err}
@validation.add_validator(validation.flavor_exists("flavor_id"))
@validation.add_validator(validation.image_exists("image_id"))

View File

@ -14,7 +14,7 @@
# under the License.
import os
import tempfile
import StringIO
from rally.deploy import engine
from rally import objects
@ -78,7 +78,7 @@ class DevstackEngine(engine.EngineFactory):
def prepare_server(self, server):
script_path = os.path.abspath(os.path.join(os.path.dirname(__file__),
'devstack', 'install.sh'))
server.ssh.execute_script(script_path)
server.ssh.run('/bin/sh -e', stdin=open(script_path, 'rb'))
@utils.log_deploy_wrapper(LOG.info, _("Deploy devstack"))
def deploy(self):
@ -103,18 +103,15 @@ class DevstackEngine(engine.EngineFactory):
@utils.log_deploy_wrapper(LOG.info, _("Configure devstack"))
def configure_devstack(self, server):
devstack_repo = self.config.get('devstack_repo', DEVSTACK_REPO)
server.ssh.execute('git', 'clone', devstack_repo)
fd, config_path = tempfile.mkstemp()
config_file = open(config_path, "w")
server.ssh.run('git clone %s' % devstack_repo)
localrc = StringIO.StringIO()
for k, v in self.localrc.iteritems():
config_file.write('%s=%s\n' % (k, v))
config_file.close()
os.close(fd)
server.ssh.upload(config_path, "~/devstack/localrc")
os.unlink(config_path)
localrc.write('%s=%s\n' % (k, v))
localrc.seek(0)
server.ssh.run("cat > ~/devstack/localrc", stdin=localrc)
return True
@utils.log_deploy_wrapper(LOG.info, _("Run devstack"))
def start_devstack(self, server):
server.ssh.execute('~/devstack/stack.sh')
server.ssh.run('~/devstack/stack.sh')
return True

View File

@ -16,6 +16,7 @@
import abc
import jsonschema
from rally import exceptions
from rally import sshutils
from rally import utils
@ -32,7 +33,8 @@ class Server(utils.ImmutableMixin):
self.user = user
self.key = key
self.password = password
self.ssh = sshutils.SSH(host, user, port, key)
self.ssh = sshutils.SSH(user, host, key_filename=key, port=port,
password=password)
super(Server, self).__init__()
def get_credentials(self):

View File

@ -16,7 +16,7 @@
import netaddr
import os
import re
import tempfile
import StringIO
import time
from rally import exceptions
@ -30,17 +30,15 @@ LOG = logging.getLogger(__name__)
INET_ADDR_RE = re.compile(r' *inet ((\d+\.){3}\d+)\/\d+ .*')
def _get_script_path(filename):
return os.path.abspath(os.path.join(os.path.dirname(__file__),
def _get_script(filename):
path = os.path.abspath(os.path.join(os.path.dirname(__file__),
'lxc', filename))
return open(path, 'rb')
def _write_script_from_template(template_filename, **kwargs):
template = open(_get_script_path(template_filename)).read()
new_file = tempfile.NamedTemporaryFile(delete=False)
new_file.write(template.format(**kwargs))
new_file.close()
return new_file.name
def _get_script_from_template(template_filename, **kwargs):
template = _get_script(template_filename).read()
return StringIO.StringIO(template.format(**kwargs))
class LxcHost(object):
@ -85,39 +83,35 @@ class LxcHost(object):
'LXC_DHCP_RANGE': dhcp_range,
'LXC_DHCP_MAX': self.network.size - 3,
}
config = tempfile.NamedTemporaryFile(delete=False)
config = StringIO.StringIO()
for name, value in values.iteritems():
config.write('%(name)s="%(value)s"\n' % {'name': name,
'value': value})
config.close()
self.server.ssh.upload(config.name, '/tmp/.lxc_default')
os.unlink(config.name)
config.seek(0)
self.server.ssh.run('cat > /tmp/.lxc_default', stdin=config)
script = _get_script_path('lxc-install.sh')
self.server.ssh.execute_script(script)
self.server.ssh.run('/bin/sh', stdin=_get_script('lxc-install.sh'))
self.create_local_tunnels()
self.create_remote_tunnels()
def create_local_tunnels(self):
"""Create tunel on lxc host side."""
for tunnel_to in self.config['tunnel_to']:
script = _write_script_from_template('tunnel-local.sh',
net=self.network,
local=self.server.host,
remote=tunnel_to)
self.server.ssh.execute_script(script)
os.unlink(script)
script = _get_script_from_template('tunnel-local.sh',
net=self.network,
local=self.server.host,
remote=tunnel_to)
self.server.ssh.run('/bin/sh -e', stdin=script)
def create_remote_tunnels(self):
"""Create tunel on remote side."""
for tunnel_to in self.config['tunnel_to']:
script = _write_script_from_template('tunnel-remote.sh',
net=self.network,
local=tunnel_to,
remote=self.server.host)
script = _get_script_from_template('tunnel-remote.sh',
net=self.network,
local=tunnel_to,
remote=self.server.host)
server = self._get_server_with_ip(tunnel_to)
server.ssh.execute_script(script)
os.unlink(script)
server.ssh.run('/bin/sh -e', stdin=script)
def delete_tunnels(self):
for tunnel_to in self.config['tunnel_to']:
@ -130,7 +124,9 @@ class LxcHost(object):
cmd = 'lxc-attach -n %s ip addr list dev eth0' % name
for attempt in range(1, 16):
stdout = self.server.ssh.execute(cmd, get_stdout=True)[0]
code, stdout = self.server.ssh.execute(cmd)[:2]
if code:
continue
for line in stdout.splitlines():
m = INET_ADDR_RE.match(line)
if m:
@ -140,9 +136,10 @@ class LxcHost(object):
raise exceptions.TimeoutException(msg)
def create_container(self, name, distribution):
self.server.ssh.execute('lxc-create', '-B', self.backingstore,
'-n', name,
'-t', distribution)
args = {'backingstore': self.backingstore,
'name': name, 'distribution': distribution}
self.server.ssh.run('lxc-create -B %(backingstore)s -n %(name)s'
' -t %(distribution)s' % args)
self.configure_container(name)
self.containers.append(name)
@ -152,28 +149,27 @@ class LxcHost(object):
if self.backingstore == 'btrfs':
cmd.append('--snapshot')
cmd.extend(['-o', source, '-n', name])
self.server.ssh.execute(*cmd)
self.server.ssh.execute(' '.join(cmd))
self.configure_container(name)
self.containers.append(name)
def configure_container(self, name):
path = os.path.join(self.path, name, 'rootfs')
configure_script = _get_script_path('configure_container.sh')
self.server.ssh.upload(configure_script, '/tmp/.rally_cont_conf.sh')
self.server.ssh.execute('/bin/sh', '/tmp/.rally_cont_conf.sh', path)
conf_script = _get_script('configure_container.sh')
self.server.ssh.run('/bin/sh -e -s %s' % path, stdin=conf_script)
def start_containers(self):
for name in self.containers:
self.server.ssh.execute('lxc-start -d -n %s' % name)
self.server.ssh.run('lxc-start -d -n %s' % name)
def stop_containers(self):
for name in self.containers:
self.server.ssh.execute('lxc-stop -n %s' % name)
self.server.ssh.run('lxc-stop -n %s' % name)
def destroy_containers(self):
for name in self.containers:
self.server.ssh.execute('lxc-stop -n %s' % name)
self.server.ssh.execute('lxc-destroy -n %s' % name)
self.server.ssh.run('lxc-stop -n %s' % name)
self.server.ssh.run('lxc-destroy -n %s' % name)
def get_server_object(self, name, wait=True):
"""Create Server object for container."""
@ -257,6 +253,7 @@ class LxcProvider(provider.ProviderFactory):
host.prepare()
ip = str(network.ip).replace('.', '-') if network else '0'
first_name = '%s-000-%s' % (name_prefix, ip)
host.create_container(first_name, distribution)
for i in range(1, self.config.get('containers_per_host', 1)):
name = '%s-%03d-%s' % (name_prefix, i, ip)

View File

@ -1,3 +1,3 @@
ip tun add t{net.ip} mode ipip local {local} remote {remote} || true
ip tun add t{net.ip} mode ipip local {local} remote {remote}
ip link set t{net.ip} up
ip route add {net} dev t{net.ip} src {local} || true
ip route add {net} dev t{net.ip} src {local}

View File

@ -13,160 +13,235 @@
# License for the specific language governing permissions and limitations
# under the License.
import os
"""High level ssh library.
Usage examples:
Execute command and get output:
ssh = sshclient.SSH('root', 'example.com', port=33)
status, stdout, stderr = ssh.execute('ps ax')
if status:
raise Exception('Command failed with non-zero status.')
print stdout.splitlines()
Execute command with huge output:
class PseudoFile(object):
def write(chunk):
if 'error' in chunk:
email_admin(chunk)
ssh = sshclient.SSH('root', 'example.com')
ssh.run('tail -f /var/log/syslog', stdout=PseudoFile(), timeout=False)
Execute local script on remote side:
ssh = sshclient.SSH('user', 'example.com')
status, out, err = ssh.execute('/bin/sh -s arg1 arg2',
stdin=open('~/myscript.sh', 'r'))
Upload file:
ssh = sshclient.SSH('user', 'example.com')
ssh.run('cat > ~/upload/file.gz', stdin=open('/store/file.gz', 'rb'))
Eventlet:
eventlet.monkey_patch(select=True, time=True)
or
eventlet.monkey_patch()
or
sshclient = eventlet.import_patched("opentstack.common.sshclient")
"""
import paramiko
import random
import select
import socket
import string
import StringIO
import time
from rally import exceptions
from rally.openstack.common.gettextutils import _
from rally.openstack.common import log as logging
LOG = logging.getLogger(__name__)
class SSHError(Exception):
pass
class SSHTimeout(SSHError):
pass
class SSH(object):
"""SSH common functions."""
STDOUT_INDEX = 0
STDERR_INDEX = 1
"""Represent ssh connection."""
def __init__(self, ip, user, port=22, key=None, key_type="file",
timeout=1800):
"""Initialize SSH client with ip, username and the default values.
def __init__(self, user, host, port=22, pkey=None,
key_filename=None, password=None):
"""Initialize SSH client.
:param user: ssh username
:param host: hostname or ip address of remote ssh server
:param port: remote ssh port
:param pkey: RSA or DSS private key string or file object
:param key_filename: private key filename
:param password: password
timeout - the timeout for execution of the command
key - path to private key file, or string containing actual key
key_type - "file" for key path, "string" for actual key
"""
self.ip = ip
self.port = port
self.user = user
self.timeout = timeout
self.client = None
self.key = key
self.key_type = key_type
if not self.key:
#Guess location of user's private key if no key is specified.
self.key = os.path.expanduser('~/.ssh/id_rsa')
self.host = host
self.port = port
self.pkey = self._get_pkey(pkey) if pkey else None
self.password = password
self.key_filename = key_filename
self._client = False
def _get_ssh_connection(self):
self.client = paramiko.SSHClient()
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
connect_params = {
'hostname': self.ip,
'port': self.port,
'username': self.user
}
def _get_pkey(self, key):
if isinstance(key, basestring):
key = StringIO.StringIO(key)
for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
try:
return key_class.from_private_key(key)
except paramiko.SSHException:
pass
raise SSHError('Invalid pkey')
# NOTE(hughsaunders): Set correct paramiko parameter names for each
# method of supplying a key.
if self.key_type == 'file':
connect_params['key_filename'] = self.key
else:
connect_params['pkey'] = paramiko.RSAKey(
file_obj=StringIO.StringIO(self.key))
def _get_client(self):
if self._client:
return self._client
try:
self._client = paramiko.SSHClient()
self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self._client.connect(self.host, username=self.user,
port=self.port, pkey=self.pkey,
key_filename=self.key_filename,
password=self.password)
return self._client
except paramiko.SSHException as e:
message = _("Paramiko exception %(exception_type)s was raised "
"during connect. Exception value is: %(exception)r")
raise SSHError(message % {'exception': e,
'exception_type': type(e)})
self.client.connect(**connect_params)
def close(self):
self._client.close()
self._client = False
def _is_timed_out(self, start_time, timeout=None):
timeout = timeout if timeout else self.timeout
return (time.time() - timeout) > start_time
def execute(self, *cmd, **kwargs):
"""Execute the specified command on the server.
Return tuple (stdout, stderr).
:param *cmd: Command and arguments to be executed.
:param get_stdout: Collect stdout data. Boolean.
:param get_stderr: Collect stderr data. Boolean.
def run(self, cmd, stdin=None, stdout=None, stderr=None,
raise_on_error=True, timeout=3600):
"""Execute specified command on the server.
:param cmd: Command to be executed.
:param stdin: Open file or string to pass to stdin.
:param stdout: Open file to connect to stdout.
:param stderr: Open file to connect to stderr.
:param raise_on_error: If False then exit code will be return. If True
then exception will be raized if non-zero code.
:param timeout: Timeout in seconds for command execution.
Default 1 hour. No timeout if set to 0.
"""
get_stdout = kwargs.get("get_stdout", False)
get_stderr = kwargs.get("get_stderr", False)
stdout = ''
stderr = ''
for chunk in self.execute_generator(*cmd, get_stdout=get_stdout,
get_stderr=get_stderr):
if chunk[0] == 1:
stdout += chunk[1]
elif chunk[0] == 2:
stderr += chunk[1]
return (stdout, stderr)
def execute_generator(self, *cmd, **kwargs):
"""Execute the specified command on the server.
client = self._get_client()
Return generator. Stdout and stderr data can be collected by chunks.
if isinstance(stdin, basestring):
stdin = StringIO.StringIO(stdin)
:param *cmd: Command and arguments to be executed.
:param get_stdout: Collect stdout data. Boolean.
:param get_stderr: Collect stderr data. Boolean.
return self._run(client, cmd, stdin=stdin, stdout=stdout,
stderr=stderr, raise_on_error=raise_on_error,
timeout=timeout)
"""
get_stdout = kwargs.get("get_stdout", True)
get_stderr = kwargs.get("get_stderr", True)
self._get_ssh_connection()
cmd = ' '.join(cmd)
transport = self.client.get_transport()
def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
raise_on_error=True, timeout=3600):
transport = client.get_transport()
session = transport.open_session()
session.exec_command(cmd)
start_time = time.time()
data_to_send = ''
stderr_data = None
# If we have data to be sent to stdin then `select' should also
# check for stdin availability.
if stdin and not stdin.closed:
writes = [session]
else:
writes = []
while True:
errors = select.select([session], [], [], 4)[2]
# Block until data can be read/write.
r, w, e = select.select([session], writes, [session], 1)
if session.recv_ready():
data = session.recv(4096)
LOG.debug(data)
if get_stdout:
yield (1, data)
LOG.debug(_('stdout: %r') % data)
if stdout is not None:
stdout.write(data)
continue
if session.recv_stderr_ready():
data = session.recv_stderr(4096)
LOG.debug(data)
if get_stderr:
yield (2, data)
stderr_data = session.recv_stderr(4096)
LOG.debug(_('stderr: %r') % stderr_data)
if stderr is not None:
stderr.write(stderr_data)
continue
if errors or session.exit_status_ready():
if session.send_ready():
if stdin is not None and not stdin.closed:
if not data_to_send:
data_to_send = stdin.read(4096)
if not data_to_send:
stdin.close()
session.shutdown_write()
writes = []
continue
sent_bytes = session.send(data_to_send)
data_to_send = data_to_send[sent_bytes:]
if session.exit_status_ready():
break
if self._is_timed_out(start_time):
raise exceptions.TimeoutException('SSH Timeout')
if timeout and (time.time() - timeout) > start_time:
args = {'cmd': cmd, 'host': self.host}
raise SSHTimeout(_('Timeout executing command '
'"%(cmd)s" on host %(host)s') % args)
if e:
raise SSHError('Socket error.')
exit_status = session.recv_exit_status()
if 0 != exit_status:
raise exceptions.SSHError(
'SSHExecCommandFailed with exit_status %s'
% exit_status)
self.client.close()
if 0 != exit_status and raise_on_error:
fmt = _('Command "%(cmd)s" failed with exit_status %(status)d.')
details = fmt % {'cmd': cmd, 'status': exit_status}
if stderr_data:
details += _(' Last stderr data: "%s".') % stderr_data
raise SSHError(details)
return exit_status
def upload(self, source, destination):
"""Upload the specified file to the server."""
if destination.startswith('~'):
destination = '/home/' + self.user + destination[1:]
self._get_ssh_connection()
ftp = self.client.open_sftp()
ftp.put(os.path.expanduser(source), destination)
ftp.close()
def execute(self, cmd, stdin=None, timeout=3600):
"""Execute the specified command on the server.
def execute_script(self, script, interpreter='/bin/sh',
get_stdout=False, get_stderr=False):
"""Execute the specified local script on the remote server."""
destination = '/tmp/' + ''.join(
random.choice(string.lowercase) for i in range(16))
:param cmd: Command to be executed.
:param stdin: Open file to be sent on process stdin.
:param timeout: Timeout for execution of the command.
self.upload(script, destination)
streams = self.execute('%s %s' % (interpreter, destination),
get_stdout=get_stdout, get_stderr=get_stderr)
self.execute('rm %s' % destination)
return streams
Return tuple (exit_status, stdout, stderr)
"""
stdout = StringIO.StringIO()
stderr = StringIO.StringIO()
exit_status = self.run(cmd, stderr=stderr,
stdout=stdout, stdin=stdin,
timeout=timeout, raise_on_error=False)
stdout.seek(0)
stderr.seek(0)
return (exit_status, stdout.read(), stderr.read())
def wait(self, timeout=120, interval=1):
"""Wait for the host will be available via ssh."""
@ -174,10 +249,8 @@ class SSH(object):
while True:
try:
return self.execute('uname')
except (socket.error, exceptions.SSHError) as e:
LOG.debug(
_('Ssh is still unavailable. (Exception was: %s)') % e)
except (socket.error, SSHError) as e:
LOG.debug(_('Ssh is still unavailable: %r') % e)
time.sleep(interval)
if self._is_timed_out(start_time, timeout):
raise exceptions.TimeoutException(
_('SSH Timeout waiting for "%s"') % self.ip)
if time.time() > (start_time + timeout):
raise SSHTimeout(_('Timeout waiting for "%s"') % self.host)

View File

@ -29,10 +29,15 @@ class NovaServersTestCase(test.TestCase):
@mock.patch("json.loads")
@mock.patch("rally.benchmark.base.Scenario.clients")
@mock.patch("rally.sshutils.SSH.execute_script")
def _verify_boot_runcommand_delete_server(self, mock_ssh_execute_script,
@mock.patch("rally.sshutils.SSH.execute")
@mock.patch("rally.sshutils.SSH.wait")
@mock.patch("rally.sshutils.SSH._get_pkey")
@mock.patch("rally.benchmark.scenarios.nova.servers.open", create=True)
def _verify_boot_runcommand_delete_server(self, mock_open, mock__get_pkey,
mock_wait, mock_execute,
mock_base_clients,
mock_json_loads):
mock_open.return_value = "fake_script"
fake_server = fakes.FakeServer()
fake_server.addresses = dict(
private=[dict(
@ -40,12 +45,13 @@ class NovaServersTestCase(test.TestCase):
addr="1.2.3.4"
)]
)
scenario = servers.NovaServers()
scenario._boot_server = mock.MagicMock(return_value=fake_server)
scenario._generate_random_name = mock.MagicMock(return_value="name")
scenario._delete_server = mock.MagicMock()
mock_ssh_execute_script.return_value = ('stdout', 'stderr')
mock_execute.return_value = (0, 'stdout', 'stderr')
mock_base_clients.return_value = dict(private='private-key-string')
scenario.boot_runcommand_delete_server("img", 0, "script_path",
@ -54,10 +60,8 @@ class NovaServersTestCase(test.TestCase):
scenario._boot_server.assert_called_once_with("name", "img", 0,
fakearg="f",
key_name='rally_ssh_key')
mock_ssh_execute_script.assert_called_once_with(
script="script_path", interpreter="/bin/bash", get_stdout=True,
get_stderr=True)
mock_execute.assert_called_once_with("/bin/bash", stdin="fake_script")
mock_open.assert_called_once_with("script_path", "rb")
mock_json_loads.assert_called_once_with('stdout')
scenario._delete_server.assert_called_once_with(fake_server)

View File

@ -54,12 +54,16 @@ class DevstackEngineTestCase(test.BaseTestCase):
def test_construct(self):
self.assertEqual(self.engine.localrc['ADMIN_PASSWORD'], 'secret')
def test_prepare_server(self):
@mock.patch('rally.deploy.engines.devstack.open', create=True)
def test_prepare_server(self, m_open):
m_open.return_value = 'fake_file'
server = mock.Mock()
self.engine.prepare_server(server)
filename = server.ssh.execute_script.mock_calls[0][1][0]
server.ssh.run.assert_called_once_with('/bin/sh -e', stdin='fake_file')
filename = m_open.mock_calls[0][1][0]
self.assertTrue(filename.endswith('rally/deploy/engines/'
'devstack/install.sh'))
self.assertEqual([mock.call(filename, 'rb')], m_open.mock_calls)
@mock.patch('rally.deploy.engines.devstack.open', create=True)
@mock.patch('rally.serverprovider.provider.Server')
@ -87,34 +91,26 @@ class DevstackEngineTestCase(test.BaseTestCase):
'tenant_name': 'admin',
})
@mock.patch('rally.deploy.engines.devstack.os')
@mock.patch('rally.deploy.engines.devstack.tempfile')
@mock.patch('rally.deploy.engines.devstack.open', create=True)
def test_configure_devstack(self, m_open, m_tmpf, m_os):
m_tmpf.mkstemp.return_value = (42, 'tmpnam')
fake_file = mock.Mock()
m_open.return_value = fake_file
@mock.patch('rally.deploy.engines.devstack.StringIO.StringIO')
def test_configure_devstack(self, m_sio):
m_sio.return_value = fake_localrc = mock.Mock()
server = mock.Mock()
self.engine.localrc = {'k1': 'v1', 'k2': 'v2'}
self.engine.configure_devstack(server)
calls = [
mock.call.ssh.execute('git', 'clone', DEVSTACK_REPO),
mock.call.ssh.upload('tmpnam', '~/devstack/localrc'),
mock.call.ssh.run('git clone https://github.com/'
'openstack-dev/devstack.git'),
mock.call.ssh.run('cat > ~/devstack/localrc', stdin=fake_localrc)
]
self.assertEqual(calls, server.mock_calls)
fake_file.asser_has_calls([
fake_localrc.asser_has_calls([
mock.call.write('k1=v1\n'),
mock.call.write('k2=v2\n'),
])
os_calls = [
mock.call.close(42),
mock.call.unlink('tmpnam'),
]
self.assertEqual(os_calls, m_os.mock_calls)
def test_start_devstack(self):
server = mock.Mock()
self.assertTrue(self.engine.start_devstack(server))
server.ssh.execute.assert_called_once_with('~/devstack/stack.sh')
server.ssh.run.assert_called_once_with('~/devstack/stack.sh')

View File

@ -27,29 +27,25 @@ MOD_NAME = 'rally.serverprovider.providers.lxc.'
class HelperFunctionsTestCase(test.BaseTestCase):
def test__get_script_path(self):
full_path = lxc._get_script_path('script.sh')
self.assertTrue(full_path.endswith('rally/serverprovider/'
'providers/lxc/script.sh'))
@mock.patch(MOD_NAME + 'open', create=True, return_value='fake_script')
def test__get_script(self, m_open):
script = lxc._get_script('script.sh')
self.assertEqual('fake_script', script)
path = m_open.mock_calls[0][1][0]
mode = m_open.mock_calls[0][1][1]
self.assertTrue(path.endswith('rally/serverprovider/providers'
'/lxc/script.sh'))
self.assertEqual('rb', mode)
@mock.patch(MOD_NAME + '_get_script_path', return_value='fake_path')
@mock.patch(MOD_NAME + 'tempfile')
@mock.patch(MOD_NAME + 'open', create=True)
def test__write_script_from_template(self, m_open, m_tempfile, m_gsp):
fake_tempfile = mock.Mock()
m_tempfile.NamedTemporaryFile.return_value = fake_tempfile
fake_file = mock.Mock()
fake_data = mock.Mock()
fake_data.format.return_value = 'fake_formatted_data'
fake_file.read.return_value = fake_data
m_open.return_value = fake_file
retval = lxc._write_script_from_template('script', key='value')
m_gsp.assert_called_once_with('script')
m_open.assert_called_once_with('fake_path')
m_tempfile.NamedTemporaryFile.assert_called_once_with(delete=False)
fake_data.format.assert_called_once_with(key='value')
fake_tempfile.write.assert_called_once_with('fake_formatted_data')
self.assertEqual(fake_tempfile.name, retval)
@mock.patch(MOD_NAME + '_get_script', return_value='fake_script')
@mock.patch(MOD_NAME + 'StringIO.StringIO')
def test__get_script_from_template(self, m_sio, m_gs):
m_gs.return_value = fake_script = mock.Mock()
fake_script.read.return_value = 'fake_data {k1} {k2}'
m_sio.return_value = 'fake_formatted_script'
script = lxc._get_script_from_template('fake_tpl', k1='v1', k2='v2')
self.assertEqual('fake_formatted_script', script)
m_sio.assert_called_once_with('fake_data v1 v2')
class LxcHostTestCase(test.BaseTestCase):
@ -81,14 +77,12 @@ class LxcHostTestCase(test.BaseTestCase):
self.server.ssh.execute.side_effect = exceptions.SSHError()
self.assertEqual('dir', self.host.backingstore)
@mock.patch(MOD_NAME + '_get_script_path', return_value='fake_sp')
@mock.patch(MOD_NAME + 'os.unlink')
@mock.patch(MOD_NAME + 'tempfile')
def test_prepare(self, m_tempfile, m_unlink, m_gsp):
@mock.patch(MOD_NAME + 'StringIO.StringIO')
@mock.patch(MOD_NAME + '_get_script', return_value='fake_script')
def test_prepare(self, m_gs, m_sio):
m_sio.return_value = fake_conf = mock.Mock()
self.host.create_local_tunnels = mock.Mock()
self.host.create_remote_tunnels = mock.Mock()
fake_tempfile = mock.Mock()
m_tempfile.NamedTemporaryFile.return_value = fake_tempfile
self.host.prepare()
@ -102,49 +96,38 @@ class LxcHostTestCase(test.BaseTestCase):
mock.call('USE_LXC_BRIDGE="true"\n')
]
for call in write_calls:
fake_tempfile.write.assert_has_calls(call)
self.server.ssh.upload.assert_called_once_with(fake_tempfile.name,
'/tmp/.lxc_default')
self.server.ssh.execute_script.assert_called_once_with('fake_sp')
m_unlink.assert_called_once_with(fake_tempfile.name)
fake_conf.write.assert_has_calls(call)
ssh_calls = [mock.call.run('cat > /tmp/.lxc_default', stdin=fake_conf),
mock.call.run('/bin/sh', stdin='fake_script')]
self.assertEqual(ssh_calls, self.server.ssh.mock_calls)
self.host.create_local_tunnels.assert_called_once()
self.host.create_remote_tunnels.assert_called_once()
@mock.patch(MOD_NAME + 'os.unlink')
@mock.patch(MOD_NAME + '_write_script_from_template')
def test_create_local_tunnels(self, m_ws, m_unlink):
m_ws.side_effect = ['1', '2']
@mock.patch(MOD_NAME + '_get_script_from_template')
def test_create_local_tunnels(self, m_gs, m_unlink):
m_gs.side_effect = ['s1', 's2']
self.host.create_local_tunnels()
ws_calls = [
gs_calls = [
mock.call('tunnel-local.sh', local='fake_server_ip',
net=netaddr.IPNetwork('10.1.1.0/24'), remote='1.1.1.1'),
mock.call('tunnel-local.sh', local='fake_server_ip',
net=netaddr.IPNetwork('10.1.1.0/24'), remote='2.2.2.2'),
]
self.assertEqual(ws_calls, m_ws.mock_calls)
self.assertEqual([mock.call('1'), mock.call('2')],
self.server.ssh.execute_script.mock_calls)
self.assertEqual(gs_calls, m_gs.mock_calls)
self.assertEqual([mock.call('/bin/sh -e', stdin='s1'),
mock.call('/bin/sh -e', stdin='s2')],
self.server.ssh.run.mock_calls)
@mock.patch(MOD_NAME + 'os.unlink')
@mock.patch(MOD_NAME + '_write_script_from_template')
def test_create_remote_tunnels(self, m_ws, m_unlink):
m_ws.side_effect = ['1', '2']
@mock.patch(MOD_NAME + '_get_script_from_template')
def test_create_remote_tunnels(self, m_get_script):
m_get_script.side_effect = ['s1', 's2']
fake_server = mock.Mock()
self.host._get_server_with_ip = mock.Mock(return_value=fake_server)
self.host.create_remote_tunnels()
ws_calls = [
mock.call('tunnel-remote.sh', local='1.1.1.1',
net=netaddr.IPNetwork('10.1.1.0/24'),
remote='fake_server_ip'),
mock.call('tunnel-remote.sh', local='2.2.2.2',
net=netaddr.IPNetwork('10.1.1.0/24'),
remote='fake_server_ip'),
]
self.assertEqual(ws_calls, m_ws.mock_calls)
self.assertEqual([mock.call('1'), mock.call('2')],
fake_server.ssh.execute_script.mock_calls)
self.assertEqual([mock.call('/bin/sh -e', stdin='s1'),
mock.call('/bin/sh -e', stdin='s2')],
fake_server.ssh.run.mock_calls)
def test_delete_tunnels(self):
s1 = mock.Mock()
@ -162,64 +145,58 @@ class LxcHostTestCase(test.BaseTestCase):
def test_get_ip(self, m_sleep):
s1 = 'link/ether fe:54:00:d3:f5:98 brd ff:ff:ff:ff:ff:ff'
s2 = s1 + '\n inet 10.20.0.1/24 scope global br1'
self.host.server.ssh.execute.side_effect = [(s1, ''), (s2, '')]
self.host.server.ssh.execute.side_effect = [(0, s1, ''), (0, s2, '')]
ip = self.host.get_ip('name')
self.assertEqual('10.20.0.1', ip)
self.assertEqual([mock.call('lxc-attach -n name ip addr list dev eth0',
get_stdout=True)] * 2,
self.assertEqual([mock.call('lxc-attach -n name ip'
' addr list dev eth0')] * 2,
self.host.server.ssh.execute.mock_calls)
def test_create_container(self):
self.host.configure_container = mock.Mock()
self.host._backingstore = 'btrfs'
self.host.create_container('name', 'dist')
self.server.ssh.execute.assert_called_once_with(
'lxc-create', '-B', 'btrfs', '-n', 'name', '-t', 'dist')
self.server.ssh.run.assert_called_once_with(
'lxc-create -B btrfs -n name -t dist')
self.assertEqual(['name'], self.host.containers)
self.host.configure_container.assert_called_once_with('name')
#check with no btrfs
self.host._backingstore = 'dir'
self.host.create_container('name', 'dist')
self.assertEqual(mock.call('lxc-create', '-B', 'dir', '-n',
'name', '-t', 'dist'),
self.server.ssh.execute.mock_calls[1])
self.assertEqual(mock.call('lxc-create -B dir -n name -t dist'),
self.server.ssh.run.mock_calls[1])
def test_create_clone(self):
self.host._backingstore = 'btrfs'
self.host.configure_container = mock.Mock()
self.host.create_clone('name', 'src')
self.server.ssh.execute.assert_called_once_with('lxc-clone',
'--snapshot',
'-o', 'src',
'-n', 'name')
self.server.ssh.execute.assert_called_once_with('lxc-clone --snapshot'
' -o src -n name')
self.assertEqual(['name'], self.host.containers)
#check with no btrfs
self.host._backingstore = 'dir'
self.host.create_clone('name', 'src')
self.assertEqual(mock.call('lxc-clone', '-o', 'src', '-n', 'name'),
self.assertEqual(mock.call('lxc-clone -o src -n name'),
self.server.ssh.execute.mock_calls[1])
@mock.patch(MOD_NAME + 'os.path.join')
@mock.patch(MOD_NAME + '_get_script_path')
def test_configure_container(self, m_gsp, m_join):
m_gsp.return_value = 'fake_script'
@mock.patch(MOD_NAME + '_get_script')
def test_configure_container(self, m_gs, m_join):
m_gs.return_value = 'fake_script'
m_join.return_value = 'fake_path'
self.server.ssh.execute.return_value = 0, '', ''
self.host.configure_container('name')
calls = [
mock.call.upload('fake_script', '/tmp/.rally_cont_conf.sh'),
mock.call.execute('/bin/sh', '/tmp/.rally_cont_conf.sh',
'fake_path'),
]
self.assertEqual(calls, self.server.ssh.mock_calls)
self.server.ssh.run.assert_called_once_with(
'/bin/sh -e -s fake_path', stdin='fake_script')
def test_start_containers(self):
self.host.containers = ['c1', 'c2']
self.host.start_containers()
calls = [mock.call('lxc-start -d -n c1'),
mock.call('lxc-start -d -n c2')]
self.assertEqual(calls, self.server.ssh.execute.mock_calls)
self.assertEqual(calls, self.server.ssh.run.mock_calls)
def test_stop_containers(self):
self.host.containers = ['c1', 'c2']
@ -228,7 +205,7 @@ class LxcHostTestCase(test.BaseTestCase):
mock.call('lxc-stop -n c1'),
mock.call('lxc-stop -n c2'),
]
self.assertEqual(calls, self.server.ssh.execute.mock_calls)
self.assertEqual(calls, self.server.ssh.run.mock_calls)
def test_destroy_containers(self):
self.host.containers = ['c1', 'c2']
@ -237,7 +214,7 @@ class LxcHostTestCase(test.BaseTestCase):
mock.call('lxc-stop -n c1'), mock.call('lxc-destroy -n c1'),
mock.call('lxc-stop -n c2'), mock.call('lxc-destroy -n c2'),
]
self.assertEqual(calls, self.server.ssh.execute.mock_calls)
self.assertEqual(calls, self.server.ssh.run.mock_calls)
@mock.patch(MOD_NAME + 'provider.Server.from_credentials')
def test_get_server_object(self, m_fc):

View File

@ -14,101 +14,251 @@
# under the License.
import mock
import os
from rally import exceptions
from rally import sshutils
from tests import test
class FakeParamikoException(Exception):
pass
class SSHTestCase(test.TestCase):
"""Test all small SSH methods."""
def setUp(self):
super(SSHTestCase, self).setUp()
self.ssh = sshutils.SSH('example.net', 'root')
self.channel = mock.Mock()
self.channel.recv.return_value = 'ok'
self.channel.recv_stderr.return_value = 'error'
self.channel.recv_exit_status.return_value = 0
self.transport = mock.Mock()
self.transport.open_session = mock.MagicMock(return_value=self.channel)
self.policy = mock.Mock()
self.client = mock.Mock()
self.client.get_transport = mock.MagicMock(return_value=self.transport)
self.ssh = sshutils.SSH('root', 'example.net')
self.channel.exit_status_ready.return_value = True
self.channel.recv_ready.side_effect = [True, False, False]
self.channel.recv_stderr_ready.side_effect = [True, False, False]
@mock.patch('rally.sshutils.SSH._get_pkey')
def test_construct(self, m_pkey):
m_pkey.return_value = 'pkey'
ssh = sshutils.SSH('root', 'example.net', port=33, pkey='key',
key_filename='kf', password='secret')
m_pkey.assert_called_once_with('key')
self.assertEqual('root', ssh.user)
self.assertEqual('example.net', ssh.host)
self.assertEqual(33, ssh.port)
self.assertEqual('pkey', ssh.pkey)
self.assertEqual('kf', ssh.key_filename)
self.assertEqual('secret', ssh.password)
def test_construct_default(self):
self.assertEqual('root', self.ssh.user)
self.assertEqual('example.net', self.ssh.host)
self.assertEqual(22, self.ssh.port)
self.assertIsNone(self.ssh.pkey)
self.assertIsNone(self.ssh.key_filename)
self.assertIsNone(self.ssh.password)
@mock.patch('rally.sshutils.paramiko')
@mock.patch('rally.sshutils.select')
def test_generator(self, st, pk):
pk.SSHClient.return_value = self.client
st.select.return_value = ([], [], [])
chunks = list(self.ssh.execute_generator('ps ax'))
self.assertEqual([(1, 'ok'), (2, 'error')], chunks)
def test__get_pkey_invalid(self, m_paramiko):
m_paramiko.SSHException = FakeParamikoException
rsa = m_paramiko.rsakey.RSAKey
dss = m_paramiko.dsskey.DSSKey
rsa.from_private_key.side_effect = m_paramiko.SSHException
dss.from_private_key.side_effect = m_paramiko.SSHException
self.assertRaises(sshutils.SSHError, self.ssh._get_pkey, 'key')
@mock.patch('rally.sshutils.StringIO')
@mock.patch('rally.sshutils.paramiko')
@mock.patch('rally.sshutils.select')
def test_execute(self, st, pk):
pk.SSHClient.return_value = self.client
st.select.return_value = ([], [], [])
stdout, stderr = self.ssh.execute('uname')
self.assertEqual('', stdout)
self.assertEqual('', stderr)
expected = [mock.call.exec_command('uname'),
mock.call.recv_ready(),
mock.call.recv(4096),
mock.call.recv_ready(),
mock.call.recv_stderr_ready(),
mock.call.recv_stderr(4096),
mock.call.recv_ready(),
mock.call.recv_stderr_ready(),
mock.call.exit_status_ready(),
mock.call.recv_exit_status()]
self.assertEqual(expected, self.channel.mock_calls)
def test__get_pkey_dss(self, m_paramiko, m_stringio):
m_paramiko.SSHException = FakeParamikoException
m_stringio.StringIO.return_value = 'string_key'
m_paramiko.dsskey.DSSKey.from_private_key.return_value = 'dss_key'
rsa = m_paramiko.rsakey.RSAKey
rsa.from_private_key.side_effect = m_paramiko.SSHException
key = self.ssh._get_pkey('key')
dss_calls = m_paramiko.dsskey.DSSKey.from_private_key.mock_calls
self.assertEqual([mock.call('string_key')], dss_calls)
self.assertEqual(key, 'dss_key')
m_stringio.StringIO.assert_called_once_with('key')
@mock.patch('rally.sshutils.StringIO')
@mock.patch('rally.sshutils.paramiko')
def test_upload_file(self, pk):
pk.AutoAddPolicy.return_value = self.policy
self.ssh.upload('/tmp/s', '/tmp/d')
def test__get_pkey_rsa(self, m_paramiko, m_stringio):
m_paramiko.SSHException = FakeParamikoException
m_stringio.StringIO.return_value = 'string_key'
m_paramiko.rsakey.RSAKey.from_private_key.return_value = 'rsa_key'
dss = m_paramiko.dsskey.DSSKey
dss.from_private_key.side_effect = m_paramiko.SSHException
key = self.ssh._get_pkey('key')
rsa_calls = m_paramiko.rsakey.RSAKey.from_private_key.mock_calls
self.assertEqual([mock.call('string_key')], rsa_calls)
self.assertEqual(key, 'rsa_key')
m_stringio.StringIO.assert_called_once_with('key')
expected = [mock.call.set_missing_host_key_policy(self.policy),
mock.call.connect(hostname='example.net', username='root',
key_filename=os.path.expanduser(
'~/.ssh/id_rsa'), port=22),
mock.call.open_sftp(),
mock.call.open_sftp().put('/tmp/s', '/tmp/d'),
mock.call.open_sftp().close()]
@mock.patch('rally.sshutils.SSH._get_pkey')
@mock.patch('rally.sshutils.paramiko')
def test__get_client(self, m_paramiko, m_pkey):
m_pkey.return_value = 'key'
fake_client = mock.Mock()
m_paramiko.SSHClient.return_value = fake_client
m_paramiko.AutoAddPolicy.return_value = 'autoadd'
self.assertEqual(pk.SSHClient().mock_calls, expected)
ssh = sshutils.SSH('admin', 'example.net', pkey='key')
client = ssh._get_client()
@mock.patch('rally.sshutils.SSH.execute')
@mock.patch('rally.sshutils.SSH.upload')
@mock.patch('rally.sshutils.random.choice')
def test_execute_script_new(self, rc, up, ex):
rc.return_value = 'a'
self.ssh.execute_script('/bin/script')
self.assertEqual(fake_client, client)
client_calls = [
mock.call.set_missing_host_key_policy('autoadd'),
mock.call.connect('example.net', username='admin',
port=22, pkey='key', key_filename=None,
password=None),
]
self.assertEqual(client_calls, client.mock_calls)
up.assert_called_once_with('/bin/script', '/tmp/aaaaaaaaaaaaaaaa')
ex.assert_has_calls([
mock.call('/bin/sh /tmp/aaaaaaaaaaaaaaaa',
get_stderr=False, get_stdout=False),
mock.call('rm /tmp/aaaaaaaaaaaaaaaa')
])
def test_close(self):
with mock.patch.object(self.ssh, '_client') as m_client:
self.ssh.close()
m_client.close.assert_called_once()
self.assertFalse(self.ssh._client)
@mock.patch('rally.sshutils.SSH.execute')
def test_wait(self, ex):
self.ssh.wait()
@mock.patch('rally.sshutils.StringIO')
def test_execute(self, m_stringio):
m_stringio.StringIO.side_effect = stdio = [mock.Mock(), mock.Mock()]
stdio[0].read.return_value = 'stdout fake data'
stdio[1].read.return_value = 'stderr fake data'
with mock.patch.object(self.ssh, 'run', return_value=0) as m_run:
status, stdout, stderr = self.ssh.execute('cmd',
stdin='fake_stdin',
timeout=43)
m_run.assert_called_once_with('cmd', stdin='fake_stdin',
stdout=stdio[0],
stderr=stdio[1], timeout=43,
raise_on_error=False)
self.assertEqual(0, status)
self.assertEqual('stdout fake data', stdout)
self.assertEqual('stderr fake data', stderr)
@mock.patch('rally.sshutils.time')
@mock.patch('rally.sshutils.SSH.execute')
def test_wait_timeout(self, ex, mock_time):
mock_time.time.side_effect = [1, 10]
ex.side_effect = exceptions.SSHError
self.assertRaises(exceptions.TimeoutException,
self.ssh.wait, 1, 1)
mock_time.sleep.assert_called_once_with(1)
def test_wait_timeout(self, m_time):
m_time.time.side_effect = [1, 50, 150]
self.ssh.execute = mock.Mock(side_effect=[sshutils.SSHError,
sshutils.SSHError,
0])
self.assertRaises(sshutils.SSHTimeout, self.ssh.wait)
self.assertEqual([mock.call('uname')] * 2, self.ssh.execute.mock_calls)
@mock.patch('rally.sshutils.time')
def test_wait(self, m_time):
m_time.time.side_effect = [1, 50, 100]
self.ssh.execute = mock.Mock(side_effect=[sshutils.SSHError,
sshutils.SSHError,
0])
self.ssh.wait()
self.assertEqual([mock.call('uname')] * 3, self.ssh.execute.mock_calls)
class SSHRunTestCase(test.TestCase):
"""Test SSH.run method in different aspects.
Also tested method 'execute'.
"""
def setUp(self):
super(SSHRunTestCase, self).setUp()
self.fake_client = mock.Mock()
self.fake_session = mock.Mock()
self.fake_transport = mock.Mock()
self.fake_transport.open_session.return_value = self.fake_session
self.fake_client.get_transport.return_value = self.fake_transport
self.fake_session.recv_ready.return_value = False
self.fake_session.recv_stderr_ready.return_value = False
self.fake_session.send_ready.return_value = False
self.fake_session.exit_status_ready.return_value = True
self.fake_session.recv_exit_status.return_value = 0
self.ssh = sshutils.SSH('admin', 'example.net')
self.ssh._get_client = mock.Mock(return_value=self.fake_client)
@mock.patch('rally.sshutils.select')
def test_execute(self, m_select):
m_select.select.return_value = ([], [], [])
self.fake_session.recv_ready.side_effect = [1, 0, 0]
self.fake_session.recv_stderr_ready.side_effect = [1, 0]
self.fake_session.recv.return_value = 'ok'
self.fake_session.recv_stderr.return_value = 'error'
self.fake_session.exit_status_ready.return_value = 1
self.fake_session.recv_exit_status.return_value = 127
self.assertEqual((127, 'ok', 'error'), self.ssh.execute('cmd'))
self.fake_session.exec_command.assert_called_once_with('cmd')
@mock.patch('rally.sshutils.select')
def test_run(self, m_select):
m_select.select.return_value = ([], [], [])
self.assertEqual(0, self.ssh.run('cmd'))
@mock.patch('rally.sshutils.select')
def test_run_nonzero_status(self, m_select):
m_select.select.return_value = ([], [], [])
self.fake_session.recv_exit_status.return_value = 1
self.assertRaises(sshutils.SSHError, self.ssh.run, 'cmd')
self.assertEqual(1, self.ssh.run('cmd', raise_on_error=False))
@mock.patch('rally.sshutils.select')
def test_run_stdout(self, m_select):
m_select.select.return_value = ([], [], [])
self.fake_session.recv_ready.side_effect = [True, True, False]
self.fake_session.recv.side_effect = ['ok1', 'ok2']
stdout = mock.Mock()
self.ssh.run('cmd', stdout=stdout)
self.assertEqual([mock.call('ok1'), mock.call('ok2')],
stdout.write.mock_calls)
@mock.patch('rally.sshutils.select')
def test_run_stderr(self, m_select):
m_select.select.return_value = ([], [], [])
self.fake_session.recv_stderr_ready.side_effect = [True, False]
self.fake_session.recv_stderr.return_value = 'error'
stderr = mock.Mock()
self.ssh.run('cmd', stderr=stderr)
stderr.write.assert_called_once_with('error')
@mock.patch('rally.sshutils.select')
def test_run_stdin(self, m_select):
"""Test run method with stdin.
Third send call was called with 'e2' because only 3 bytes was sent
by second call. So remainig 2 bytes of 'line2' was sent by third call.
"""
m_select.select.return_value = ([], [], [])
self.fake_session.exit_status_ready.side_effect = [0, 0, 0, True]
self.fake_session.send_ready.return_value = True
self.fake_session.send.side_effect = [5, 3, 2]
fake_stdin = mock.Mock()
fake_stdin.read.side_effect = ['line1', 'line2', '']
fake_stdin.closed = False
def close():
fake_stdin.closed = True
fake_stdin.close = mock.Mock(side_effect=close)
self.ssh.run('cmd', stdin=fake_stdin)
call = mock.call
send_calls = [call('line1'), call('line2'), call('e2')]
self.assertEqual(send_calls, self.fake_session.send.mock_calls)
@mock.patch('rally.sshutils.select')
def test_run_select_error(self, m_select):
self.fake_session.exit_status_ready.return_value = False
m_select.select.return_value = ([], [], [True])
self.assertRaises(sshutils.SSHError, self.ssh.run, 'cmd')
@mock.patch('rally.sshutils.time')
@mock.patch('rally.sshutils.select')
def test_run_timemout(self, m_select, m_time):
m_time.time.side_effect = [1, 3700]
m_select.select.return_value = ([], [], [])
self.fake_session.exit_status_ready.return_value = False
self.assertRaises(sshutils.SSHTimeout, self.ssh.run, 'cmd')
@mock.patch('rally.sshutils.select')
def test__run_client_closed_on_error(self, m_select):
m_select.select.return_value = ([], [], [])
self.fake_session.recv_ready.return_value = True
self.fake_session.recv.side_effect = IOError
self.assertRaises(IOError, self.ssh._run, self.fake_client, 'cmd')
self.fake_client.close.assert_called_once()