Sshutils refactoring

Paramiko sftp client is not eventlet friendly. Also it doesn't work
with some servers, so upload/download methods has been removed.
Downloading or uploading can be done in other way (see examples).

Added ability to pass process stdin, which makes possible to
do on remote host anything at all.

Much more clean and usable API.

Change-Id: If357a2878c2c60646a975c386a0fe2f1616aec95
This commit is contained in:
Sergey Skripnick 2013-12-24 15:29:15 +02:00
parent 3d420c286b
commit 936ba2efd8
10 changed files with 557 additions and 382 deletions

View File

@ -20,7 +20,6 @@ import random
from rally.benchmark.scenarios.cinder import utils as cinder_utils from rally.benchmark.scenarios.cinder import utils as cinder_utils
from rally.benchmark.scenarios.nova import utils from rally.benchmark.scenarios.nova import utils
from rally.benchmark.scenarios import utils as scenario_utils from rally.benchmark.scenarios import utils as scenario_utils
from rally.benchmark import utils as benchmark_utils
from rally.benchmark import validation from rally.benchmark import validation
from rally import exceptions as rally_exceptions from rally import exceptions as rally_exceptions
from rally.openstack.common.gettextutils import _ # noqa from rally.openstack.common.gettextutils import _ # noqa
@ -102,43 +101,23 @@ class NovaServers(utils.NovaScenario,
) )
server_ip = [ip for ip in server.addresses[network] if server_ip = [ip for ip in server.addresses[network] if
ip['version'] == ip_version][0]['addr'] ip['version'] == ip_version][0]['addr']
ssh = sshutils.SSH(ip=server_ip, port=port, user=username, ssh = sshutils.SSH(username, server_ip, port=port,
key=self.clients('ssh_key_pair')['private'], pkey=self.clients('ssh_key_pair')['private'])
key_type='string') ssh.wait()
code, out, err = ssh.execute(interpreter, stdin=open(script, 'rb'))
for retry in range(retries): if code:
try: LOG.error(_('Error running script on instance via SSH. '
LOG.debug(_('Execute script on server attempt ' 'Error: %s') % err)
'%(retry)i/%(retries)i') % dict(retry=retry, try:
retries=retries)) out = json.loads(out)
streams = list(ssh.execute_script(script=script, except ValueError:
interpreter=interpreter, LOG.warning(_('Script %s did not output valid JSON. ') % script)
get_stdout=True,
get_stderr=True))
#NOTE(hughsaunders): Decode JSON script output
streams[sshutils.SSH.STDOUT_INDEX]\
= json.loads(streams[sshutils.SSH.STDOUT_INDEX])
break
except (rally_exceptions.SSHError,
rally_exceptions.TimeoutException, IOError) as e:
LOG.debug(_('Error running script on instance via SSH. '
'%(id)s/%(ip)s Attempt:%(retry)i, '
'Error: %(error)s') % dict(
id=server.id, ip=server_ip, retry=retry,
error=benchmark_utils.format_exc(e)))
self.sleep_between(5, 5)
except ValueError:
LOG.error(_('Script %(script)s did not output valid JSON. ')
% dict(script=script))
self._delete_server(server) self._delete_server(server)
LOG.debug(_('Output streams from in-instance script execution: ' LOG.debug(_('Output streams from in-instance script execution: '
'stdout: %(stdout)s, stderr: $(stderr)s') % dict( 'stdout: %(stdout)s, stderr: $(stderr)s') % dict(
stdout=str(streams[sshutils.SSH.STDOUT_INDEX]), stdout=out, stderr=err))
stderr=str(streams[sshutils.SSH.STDERR_INDEX]))) return {'data': out, 'errors': err}
return dict(data=streams[sshutils.SSH.STDOUT_INDEX],
errors=streams[sshutils.SSH.STDERR_INDEX])
@validation.add_validator(validation.flavor_exists("flavor_id")) @validation.add_validator(validation.flavor_exists("flavor_id"))
@validation.add_validator(validation.image_exists("image_id")) @validation.add_validator(validation.image_exists("image_id"))

View File

@ -14,7 +14,7 @@
# under the License. # under the License.
import os import os
import tempfile import StringIO
from rally.deploy import engine from rally.deploy import engine
from rally import objects from rally import objects
@ -78,7 +78,7 @@ class DevstackEngine(engine.EngineFactory):
def prepare_server(self, server): def prepare_server(self, server):
script_path = os.path.abspath(os.path.join(os.path.dirname(__file__), script_path = os.path.abspath(os.path.join(os.path.dirname(__file__),
'devstack', 'install.sh')) 'devstack', 'install.sh'))
server.ssh.execute_script(script_path) server.ssh.run('/bin/sh -e', stdin=open(script_path, 'rb'))
@utils.log_deploy_wrapper(LOG.info, _("Deploy devstack")) @utils.log_deploy_wrapper(LOG.info, _("Deploy devstack"))
def deploy(self): def deploy(self):
@ -103,18 +103,15 @@ class DevstackEngine(engine.EngineFactory):
@utils.log_deploy_wrapper(LOG.info, _("Configure devstack")) @utils.log_deploy_wrapper(LOG.info, _("Configure devstack"))
def configure_devstack(self, server): def configure_devstack(self, server):
devstack_repo = self.config.get('devstack_repo', DEVSTACK_REPO) devstack_repo = self.config.get('devstack_repo', DEVSTACK_REPO)
server.ssh.execute('git', 'clone', devstack_repo) server.ssh.run('git clone %s' % devstack_repo)
fd, config_path = tempfile.mkstemp() localrc = StringIO.StringIO()
config_file = open(config_path, "w")
for k, v in self.localrc.iteritems(): for k, v in self.localrc.iteritems():
config_file.write('%s=%s\n' % (k, v)) localrc.write('%s=%s\n' % (k, v))
config_file.close() localrc.seek(0)
os.close(fd) server.ssh.run("cat > ~/devstack/localrc", stdin=localrc)
server.ssh.upload(config_path, "~/devstack/localrc")
os.unlink(config_path)
return True return True
@utils.log_deploy_wrapper(LOG.info, _("Run devstack")) @utils.log_deploy_wrapper(LOG.info, _("Run devstack"))
def start_devstack(self, server): def start_devstack(self, server):
server.ssh.execute('~/devstack/stack.sh') server.ssh.run('~/devstack/stack.sh')
return True return True

View File

@ -16,6 +16,7 @@
import abc import abc
import jsonschema import jsonschema
from rally import exceptions from rally import exceptions
from rally import sshutils from rally import sshutils
from rally import utils from rally import utils
@ -32,7 +33,8 @@ class Server(utils.ImmutableMixin):
self.user = user self.user = user
self.key = key self.key = key
self.password = password self.password = password
self.ssh = sshutils.SSH(host, user, port, key) self.ssh = sshutils.SSH(user, host, key_filename=key, port=port,
password=password)
super(Server, self).__init__() super(Server, self).__init__()
def get_credentials(self): def get_credentials(self):

View File

@ -16,7 +16,7 @@
import netaddr import netaddr
import os import os
import re import re
import tempfile import StringIO
import time import time
from rally import exceptions from rally import exceptions
@ -30,17 +30,15 @@ LOG = logging.getLogger(__name__)
INET_ADDR_RE = re.compile(r' *inet ((\d+\.){3}\d+)\/\d+ .*') INET_ADDR_RE = re.compile(r' *inet ((\d+\.){3}\d+)\/\d+ .*')
def _get_script_path(filename): def _get_script(filename):
return os.path.abspath(os.path.join(os.path.dirname(__file__), path = os.path.abspath(os.path.join(os.path.dirname(__file__),
'lxc', filename)) 'lxc', filename))
return open(path, 'rb')
def _write_script_from_template(template_filename, **kwargs): def _get_script_from_template(template_filename, **kwargs):
template = open(_get_script_path(template_filename)).read() template = _get_script(template_filename).read()
new_file = tempfile.NamedTemporaryFile(delete=False) return StringIO.StringIO(template.format(**kwargs))
new_file.write(template.format(**kwargs))
new_file.close()
return new_file.name
class LxcHost(object): class LxcHost(object):
@ -85,39 +83,35 @@ class LxcHost(object):
'LXC_DHCP_RANGE': dhcp_range, 'LXC_DHCP_RANGE': dhcp_range,
'LXC_DHCP_MAX': self.network.size - 3, 'LXC_DHCP_MAX': self.network.size - 3,
} }
config = tempfile.NamedTemporaryFile(delete=False) config = StringIO.StringIO()
for name, value in values.iteritems(): for name, value in values.iteritems():
config.write('%(name)s="%(value)s"\n' % {'name': name, config.write('%(name)s="%(value)s"\n' % {'name': name,
'value': value}) 'value': value})
config.close() config.seek(0)
self.server.ssh.upload(config.name, '/tmp/.lxc_default') self.server.ssh.run('cat > /tmp/.lxc_default', stdin=config)
os.unlink(config.name)
script = _get_script_path('lxc-install.sh') self.server.ssh.run('/bin/sh', stdin=_get_script('lxc-install.sh'))
self.server.ssh.execute_script(script)
self.create_local_tunnels() self.create_local_tunnels()
self.create_remote_tunnels() self.create_remote_tunnels()
def create_local_tunnels(self): def create_local_tunnels(self):
"""Create tunel on lxc host side.""" """Create tunel on lxc host side."""
for tunnel_to in self.config['tunnel_to']: for tunnel_to in self.config['tunnel_to']:
script = _write_script_from_template('tunnel-local.sh', script = _get_script_from_template('tunnel-local.sh',
net=self.network, net=self.network,
local=self.server.host, local=self.server.host,
remote=tunnel_to) remote=tunnel_to)
self.server.ssh.execute_script(script) self.server.ssh.run('/bin/sh -e', stdin=script)
os.unlink(script)
def create_remote_tunnels(self): def create_remote_tunnels(self):
"""Create tunel on remote side.""" """Create tunel on remote side."""
for tunnel_to in self.config['tunnel_to']: for tunnel_to in self.config['tunnel_to']:
script = _write_script_from_template('tunnel-remote.sh', script = _get_script_from_template('tunnel-remote.sh',
net=self.network, net=self.network,
local=tunnel_to, local=tunnel_to,
remote=self.server.host) remote=self.server.host)
server = self._get_server_with_ip(tunnel_to) server = self._get_server_with_ip(tunnel_to)
server.ssh.execute_script(script) server.ssh.run('/bin/sh -e', stdin=script)
os.unlink(script)
def delete_tunnels(self): def delete_tunnels(self):
for tunnel_to in self.config['tunnel_to']: for tunnel_to in self.config['tunnel_to']:
@ -130,7 +124,9 @@ class LxcHost(object):
cmd = 'lxc-attach -n %s ip addr list dev eth0' % name cmd = 'lxc-attach -n %s ip addr list dev eth0' % name
for attempt in range(1, 16): for attempt in range(1, 16):
stdout = self.server.ssh.execute(cmd, get_stdout=True)[0] code, stdout = self.server.ssh.execute(cmd)[:2]
if code:
continue
for line in stdout.splitlines(): for line in stdout.splitlines():
m = INET_ADDR_RE.match(line) m = INET_ADDR_RE.match(line)
if m: if m:
@ -140,9 +136,10 @@ class LxcHost(object):
raise exceptions.TimeoutException(msg) raise exceptions.TimeoutException(msg)
def create_container(self, name, distribution): def create_container(self, name, distribution):
self.server.ssh.execute('lxc-create', '-B', self.backingstore, args = {'backingstore': self.backingstore,
'-n', name, 'name': name, 'distribution': distribution}
'-t', distribution) self.server.ssh.run('lxc-create -B %(backingstore)s -n %(name)s'
' -t %(distribution)s' % args)
self.configure_container(name) self.configure_container(name)
self.containers.append(name) self.containers.append(name)
@ -152,28 +149,27 @@ class LxcHost(object):
if self.backingstore == 'btrfs': if self.backingstore == 'btrfs':
cmd.append('--snapshot') cmd.append('--snapshot')
cmd.extend(['-o', source, '-n', name]) cmd.extend(['-o', source, '-n', name])
self.server.ssh.execute(*cmd) self.server.ssh.execute(' '.join(cmd))
self.configure_container(name) self.configure_container(name)
self.containers.append(name) self.containers.append(name)
def configure_container(self, name): def configure_container(self, name):
path = os.path.join(self.path, name, 'rootfs') path = os.path.join(self.path, name, 'rootfs')
configure_script = _get_script_path('configure_container.sh') conf_script = _get_script('configure_container.sh')
self.server.ssh.upload(configure_script, '/tmp/.rally_cont_conf.sh') self.server.ssh.run('/bin/sh -e -s %s' % path, stdin=conf_script)
self.server.ssh.execute('/bin/sh', '/tmp/.rally_cont_conf.sh', path)
def start_containers(self): def start_containers(self):
for name in self.containers: for name in self.containers:
self.server.ssh.execute('lxc-start -d -n %s' % name) self.server.ssh.run('lxc-start -d -n %s' % name)
def stop_containers(self): def stop_containers(self):
for name in self.containers: for name in self.containers:
self.server.ssh.execute('lxc-stop -n %s' % name) self.server.ssh.run('lxc-stop -n %s' % name)
def destroy_containers(self): def destroy_containers(self):
for name in self.containers: for name in self.containers:
self.server.ssh.execute('lxc-stop -n %s' % name) self.server.ssh.run('lxc-stop -n %s' % name)
self.server.ssh.execute('lxc-destroy -n %s' % name) self.server.ssh.run('lxc-destroy -n %s' % name)
def get_server_object(self, name, wait=True): def get_server_object(self, name, wait=True):
"""Create Server object for container.""" """Create Server object for container."""
@ -257,6 +253,7 @@ class LxcProvider(provider.ProviderFactory):
host.prepare() host.prepare()
ip = str(network.ip).replace('.', '-') if network else '0' ip = str(network.ip).replace('.', '-') if network else '0'
first_name = '%s-000-%s' % (name_prefix, ip) first_name = '%s-000-%s' % (name_prefix, ip)
host.create_container(first_name, distribution) host.create_container(first_name, distribution)
for i in range(1, self.config.get('containers_per_host', 1)): for i in range(1, self.config.get('containers_per_host', 1)):
name = '%s-%03d-%s' % (name_prefix, i, ip) name = '%s-%03d-%s' % (name_prefix, i, ip)

View File

@ -1,3 +1,3 @@
ip tun add t{net.ip} mode ipip local {local} remote {remote} || true ip tun add t{net.ip} mode ipip local {local} remote {remote}
ip link set t{net.ip} up ip link set t{net.ip} up
ip route add {net} dev t{net.ip} src {local} || true ip route add {net} dev t{net.ip} src {local}

View File

@ -13,160 +13,235 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import os
"""High level ssh library.
Usage examples:
Execute command and get output:
ssh = sshclient.SSH('root', 'example.com', port=33)
status, stdout, stderr = ssh.execute('ps ax')
if status:
raise Exception('Command failed with non-zero status.')
print stdout.splitlines()
Execute command with huge output:
class PseudoFile(object):
def write(chunk):
if 'error' in chunk:
email_admin(chunk)
ssh = sshclient.SSH('root', 'example.com')
ssh.run('tail -f /var/log/syslog', stdout=PseudoFile(), timeout=False)
Execute local script on remote side:
ssh = sshclient.SSH('user', 'example.com')
status, out, err = ssh.execute('/bin/sh -s arg1 arg2',
stdin=open('~/myscript.sh', 'r'))
Upload file:
ssh = sshclient.SSH('user', 'example.com')
ssh.run('cat > ~/upload/file.gz', stdin=open('/store/file.gz', 'rb'))
Eventlet:
eventlet.monkey_patch(select=True, time=True)
or
eventlet.monkey_patch()
or
sshclient = eventlet.import_patched("opentstack.common.sshclient")
"""
import paramiko import paramiko
import random
import select import select
import socket import socket
import string
import StringIO import StringIO
import time import time
from rally import exceptions
from rally.openstack.common.gettextutils import _ from rally.openstack.common.gettextutils import _
from rally.openstack.common import log as logging from rally.openstack.common import log as logging
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class SSHError(Exception):
pass
class SSHTimeout(SSHError):
pass
class SSH(object): class SSH(object):
"""SSH common functions.""" """Represent ssh connection."""
STDOUT_INDEX = 0
STDERR_INDEX = 1
def __init__(self, ip, user, port=22, key=None, key_type="file", def __init__(self, user, host, port=22, pkey=None,
timeout=1800): key_filename=None, password=None):
"""Initialize SSH client with ip, username and the default values. """Initialize SSH client.
:param user: ssh username
:param host: hostname or ip address of remote ssh server
:param port: remote ssh port
:param pkey: RSA or DSS private key string or file object
:param key_filename: private key filename
:param password: password
timeout - the timeout for execution of the command
key - path to private key file, or string containing actual key
key_type - "file" for key path, "string" for actual key
""" """
self.ip = ip
self.port = port
self.user = user self.user = user
self.timeout = timeout self.host = host
self.client = None self.port = port
self.key = key self.pkey = self._get_pkey(pkey) if pkey else None
self.key_type = key_type self.password = password
if not self.key: self.key_filename = key_filename
#Guess location of user's private key if no key is specified. self._client = False
self.key = os.path.expanduser('~/.ssh/id_rsa')
def _get_ssh_connection(self): def _get_pkey(self, key):
self.client = paramiko.SSHClient() if isinstance(key, basestring):
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) key = StringIO.StringIO(key)
connect_params = { for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
'hostname': self.ip, try:
'port': self.port, return key_class.from_private_key(key)
'username': self.user except paramiko.SSHException:
} pass
raise SSHError('Invalid pkey')
# NOTE(hughsaunders): Set correct paramiko parameter names for each def _get_client(self):
# method of supplying a key. if self._client:
if self.key_type == 'file': return self._client
connect_params['key_filename'] = self.key try:
else: self._client = paramiko.SSHClient()
connect_params['pkey'] = paramiko.RSAKey( self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
file_obj=StringIO.StringIO(self.key)) self._client.connect(self.host, username=self.user,
port=self.port, pkey=self.pkey,
key_filename=self.key_filename,
password=self.password)
return self._client
except paramiko.SSHException as e:
message = _("Paramiko exception %(exception_type)s was raised "
"during connect. Exception value is: %(exception)r")
raise SSHError(message % {'exception': e,
'exception_type': type(e)})
self.client.connect(**connect_params) def close(self):
self._client.close()
self._client = False
def _is_timed_out(self, start_time, timeout=None): def run(self, cmd, stdin=None, stdout=None, stderr=None,
timeout = timeout if timeout else self.timeout raise_on_error=True, timeout=3600):
return (time.time() - timeout) > start_time """Execute specified command on the server.
def execute(self, *cmd, **kwargs):
"""Execute the specified command on the server.
Return tuple (stdout, stderr).
:param *cmd: Command and arguments to be executed.
:param get_stdout: Collect stdout data. Boolean.
:param get_stderr: Collect stderr data. Boolean.
:param cmd: Command to be executed.
:param stdin: Open file or string to pass to stdin.
:param stdout: Open file to connect to stdout.
:param stderr: Open file to connect to stderr.
:param raise_on_error: If False then exit code will be return. If True
then exception will be raized if non-zero code.
:param timeout: Timeout in seconds for command execution.
Default 1 hour. No timeout if set to 0.
""" """
get_stdout = kwargs.get("get_stdout", False)
get_stderr = kwargs.get("get_stderr", False)
stdout = ''
stderr = ''
for chunk in self.execute_generator(*cmd, get_stdout=get_stdout,
get_stderr=get_stderr):
if chunk[0] == 1:
stdout += chunk[1]
elif chunk[0] == 2:
stderr += chunk[1]
return (stdout, stderr)
def execute_generator(self, *cmd, **kwargs): client = self._get_client()
"""Execute the specified command on the server.
Return generator. Stdout and stderr data can be collected by chunks. if isinstance(stdin, basestring):
stdin = StringIO.StringIO(stdin)
:param *cmd: Command and arguments to be executed. return self._run(client, cmd, stdin=stdin, stdout=stdout,
:param get_stdout: Collect stdout data. Boolean. stderr=stderr, raise_on_error=raise_on_error,
:param get_stderr: Collect stderr data. Boolean. timeout=timeout)
""" def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
get_stdout = kwargs.get("get_stdout", True) raise_on_error=True, timeout=3600):
get_stderr = kwargs.get("get_stderr", True)
self._get_ssh_connection() transport = client.get_transport()
cmd = ' '.join(cmd)
transport = self.client.get_transport()
session = transport.open_session() session = transport.open_session()
session.exec_command(cmd) session.exec_command(cmd)
start_time = time.time() start_time = time.time()
data_to_send = ''
stderr_data = None
# If we have data to be sent to stdin then `select' should also
# check for stdin availability.
if stdin and not stdin.closed:
writes = [session]
else:
writes = []
while True: while True:
errors = select.select([session], [], [], 4)[2] # Block until data can be read/write.
r, w, e = select.select([session], writes, [session], 1)
if session.recv_ready(): if session.recv_ready():
data = session.recv(4096) data = session.recv(4096)
LOG.debug(data) LOG.debug(_('stdout: %r') % data)
if get_stdout: if stdout is not None:
yield (1, data) stdout.write(data)
continue continue
if session.recv_stderr_ready(): if session.recv_stderr_ready():
data = session.recv_stderr(4096) stderr_data = session.recv_stderr(4096)
LOG.debug(data) LOG.debug(_('stderr: %r') % stderr_data)
if get_stderr: if stderr is not None:
yield (2, data) stderr.write(stderr_data)
continue continue
if errors or session.exit_status_ready(): if session.send_ready():
if stdin is not None and not stdin.closed:
if not data_to_send:
data_to_send = stdin.read(4096)
if not data_to_send:
stdin.close()
session.shutdown_write()
writes = []
continue
sent_bytes = session.send(data_to_send)
data_to_send = data_to_send[sent_bytes:]
if session.exit_status_ready():
break break
if self._is_timed_out(start_time): if timeout and (time.time() - timeout) > start_time:
raise exceptions.TimeoutException('SSH Timeout') args = {'cmd': cmd, 'host': self.host}
raise SSHTimeout(_('Timeout executing command '
'"%(cmd)s" on host %(host)s') % args)
if e:
raise SSHError('Socket error.')
exit_status = session.recv_exit_status() exit_status = session.recv_exit_status()
if 0 != exit_status: if 0 != exit_status and raise_on_error:
raise exceptions.SSHError( fmt = _('Command "%(cmd)s" failed with exit_status %(status)d.')
'SSHExecCommandFailed with exit_status %s' details = fmt % {'cmd': cmd, 'status': exit_status}
% exit_status) if stderr_data:
self.client.close() details += _(' Last stderr data: "%s".') % stderr_data
raise SSHError(details)
return exit_status
def upload(self, source, destination): def execute(self, cmd, stdin=None, timeout=3600):
"""Upload the specified file to the server.""" """Execute the specified command on the server.
if destination.startswith('~'):
destination = '/home/' + self.user + destination[1:]
self._get_ssh_connection()
ftp = self.client.open_sftp()
ftp.put(os.path.expanduser(source), destination)
ftp.close()
def execute_script(self, script, interpreter='/bin/sh', :param cmd: Command to be executed.
get_stdout=False, get_stderr=False): :param stdin: Open file to be sent on process stdin.
"""Execute the specified local script on the remote server.""" :param timeout: Timeout for execution of the command.
destination = '/tmp/' + ''.join(
random.choice(string.lowercase) for i in range(16))
self.upload(script, destination) Return tuple (exit_status, stdout, stderr)
streams = self.execute('%s %s' % (interpreter, destination),
get_stdout=get_stdout, get_stderr=get_stderr) """
self.execute('rm %s' % destination) stdout = StringIO.StringIO()
return streams stderr = StringIO.StringIO()
exit_status = self.run(cmd, stderr=stderr,
stdout=stdout, stdin=stdin,
timeout=timeout, raise_on_error=False)
stdout.seek(0)
stderr.seek(0)
return (exit_status, stdout.read(), stderr.read())
def wait(self, timeout=120, interval=1): def wait(self, timeout=120, interval=1):
"""Wait for the host will be available via ssh.""" """Wait for the host will be available via ssh."""
@ -174,10 +249,8 @@ class SSH(object):
while True: while True:
try: try:
return self.execute('uname') return self.execute('uname')
except (socket.error, exceptions.SSHError) as e: except (socket.error, SSHError) as e:
LOG.debug( LOG.debug(_('Ssh is still unavailable: %r') % e)
_('Ssh is still unavailable. (Exception was: %s)') % e)
time.sleep(interval) time.sleep(interval)
if self._is_timed_out(start_time, timeout): if time.time() > (start_time + timeout):
raise exceptions.TimeoutException( raise SSHTimeout(_('Timeout waiting for "%s"') % self.host)
_('SSH Timeout waiting for "%s"') % self.ip)

View File

@ -29,10 +29,15 @@ class NovaServersTestCase(test.TestCase):
@mock.patch("json.loads") @mock.patch("json.loads")
@mock.patch("rally.benchmark.base.Scenario.clients") @mock.patch("rally.benchmark.base.Scenario.clients")
@mock.patch("rally.sshutils.SSH.execute_script") @mock.patch("rally.sshutils.SSH.execute")
def _verify_boot_runcommand_delete_server(self, mock_ssh_execute_script, @mock.patch("rally.sshutils.SSH.wait")
@mock.patch("rally.sshutils.SSH._get_pkey")
@mock.patch("rally.benchmark.scenarios.nova.servers.open", create=True)
def _verify_boot_runcommand_delete_server(self, mock_open, mock__get_pkey,
mock_wait, mock_execute,
mock_base_clients, mock_base_clients,
mock_json_loads): mock_json_loads):
mock_open.return_value = "fake_script"
fake_server = fakes.FakeServer() fake_server = fakes.FakeServer()
fake_server.addresses = dict( fake_server.addresses = dict(
private=[dict( private=[dict(
@ -40,12 +45,13 @@ class NovaServersTestCase(test.TestCase):
addr="1.2.3.4" addr="1.2.3.4"
)] )]
) )
scenario = servers.NovaServers() scenario = servers.NovaServers()
scenario._boot_server = mock.MagicMock(return_value=fake_server) scenario._boot_server = mock.MagicMock(return_value=fake_server)
scenario._generate_random_name = mock.MagicMock(return_value="name") scenario._generate_random_name = mock.MagicMock(return_value="name")
scenario._delete_server = mock.MagicMock() scenario._delete_server = mock.MagicMock()
mock_ssh_execute_script.return_value = ('stdout', 'stderr') mock_execute.return_value = (0, 'stdout', 'stderr')
mock_base_clients.return_value = dict(private='private-key-string') mock_base_clients.return_value = dict(private='private-key-string')
scenario.boot_runcommand_delete_server("img", 0, "script_path", scenario.boot_runcommand_delete_server("img", 0, "script_path",
@ -54,10 +60,8 @@ class NovaServersTestCase(test.TestCase):
scenario._boot_server.assert_called_once_with("name", "img", 0, scenario._boot_server.assert_called_once_with("name", "img", 0,
fakearg="f", fakearg="f",
key_name='rally_ssh_key') key_name='rally_ssh_key')
mock_ssh_execute_script.assert_called_once_with( mock_execute.assert_called_once_with("/bin/bash", stdin="fake_script")
script="script_path", interpreter="/bin/bash", get_stdout=True, mock_open.assert_called_once_with("script_path", "rb")
get_stderr=True)
mock_json_loads.assert_called_once_with('stdout') mock_json_loads.assert_called_once_with('stdout')
scenario._delete_server.assert_called_once_with(fake_server) scenario._delete_server.assert_called_once_with(fake_server)

View File

@ -54,12 +54,16 @@ class DevstackEngineTestCase(test.BaseTestCase):
def test_construct(self): def test_construct(self):
self.assertEqual(self.engine.localrc['ADMIN_PASSWORD'], 'secret') self.assertEqual(self.engine.localrc['ADMIN_PASSWORD'], 'secret')
def test_prepare_server(self): @mock.patch('rally.deploy.engines.devstack.open', create=True)
def test_prepare_server(self, m_open):
m_open.return_value = 'fake_file'
server = mock.Mock() server = mock.Mock()
self.engine.prepare_server(server) self.engine.prepare_server(server)
filename = server.ssh.execute_script.mock_calls[0][1][0] server.ssh.run.assert_called_once_with('/bin/sh -e', stdin='fake_file')
filename = m_open.mock_calls[0][1][0]
self.assertTrue(filename.endswith('rally/deploy/engines/' self.assertTrue(filename.endswith('rally/deploy/engines/'
'devstack/install.sh')) 'devstack/install.sh'))
self.assertEqual([mock.call(filename, 'rb')], m_open.mock_calls)
@mock.patch('rally.deploy.engines.devstack.open', create=True) @mock.patch('rally.deploy.engines.devstack.open', create=True)
@mock.patch('rally.serverprovider.provider.Server') @mock.patch('rally.serverprovider.provider.Server')
@ -87,34 +91,26 @@ class DevstackEngineTestCase(test.BaseTestCase):
'tenant_name': 'admin', 'tenant_name': 'admin',
}) })
@mock.patch('rally.deploy.engines.devstack.os') @mock.patch('rally.deploy.engines.devstack.StringIO.StringIO')
@mock.patch('rally.deploy.engines.devstack.tempfile') def test_configure_devstack(self, m_sio):
@mock.patch('rally.deploy.engines.devstack.open', create=True) m_sio.return_value = fake_localrc = mock.Mock()
def test_configure_devstack(self, m_open, m_tmpf, m_os):
m_tmpf.mkstemp.return_value = (42, 'tmpnam')
fake_file = mock.Mock()
m_open.return_value = fake_file
server = mock.Mock() server = mock.Mock()
self.engine.localrc = {'k1': 'v1', 'k2': 'v2'} self.engine.localrc = {'k1': 'v1', 'k2': 'v2'}
self.engine.configure_devstack(server) self.engine.configure_devstack(server)
calls = [ calls = [
mock.call.ssh.execute('git', 'clone', DEVSTACK_REPO), mock.call.ssh.run('git clone https://github.com/'
mock.call.ssh.upload('tmpnam', '~/devstack/localrc'), 'openstack-dev/devstack.git'),
mock.call.ssh.run('cat > ~/devstack/localrc', stdin=fake_localrc)
] ]
self.assertEqual(calls, server.mock_calls) self.assertEqual(calls, server.mock_calls)
fake_file.asser_has_calls([ fake_localrc.asser_has_calls([
mock.call.write('k1=v1\n'), mock.call.write('k1=v1\n'),
mock.call.write('k2=v2\n'), mock.call.write('k2=v2\n'),
]) ])
os_calls = [
mock.call.close(42),
mock.call.unlink('tmpnam'),
]
self.assertEqual(os_calls, m_os.mock_calls)
def test_start_devstack(self): def test_start_devstack(self):
server = mock.Mock() server = mock.Mock()
self.assertTrue(self.engine.start_devstack(server)) self.assertTrue(self.engine.start_devstack(server))
server.ssh.execute.assert_called_once_with('~/devstack/stack.sh') server.ssh.run.assert_called_once_with('~/devstack/stack.sh')

View File

@ -27,29 +27,25 @@ MOD_NAME = 'rally.serverprovider.providers.lxc.'
class HelperFunctionsTestCase(test.BaseTestCase): class HelperFunctionsTestCase(test.BaseTestCase):
def test__get_script_path(self): @mock.patch(MOD_NAME + 'open', create=True, return_value='fake_script')
full_path = lxc._get_script_path('script.sh') def test__get_script(self, m_open):
self.assertTrue(full_path.endswith('rally/serverprovider/' script = lxc._get_script('script.sh')
'providers/lxc/script.sh')) self.assertEqual('fake_script', script)
path = m_open.mock_calls[0][1][0]
mode = m_open.mock_calls[0][1][1]
self.assertTrue(path.endswith('rally/serverprovider/providers'
'/lxc/script.sh'))
self.assertEqual('rb', mode)
@mock.patch(MOD_NAME + '_get_script_path', return_value='fake_path') @mock.patch(MOD_NAME + '_get_script', return_value='fake_script')
@mock.patch(MOD_NAME + 'tempfile') @mock.patch(MOD_NAME + 'StringIO.StringIO')
@mock.patch(MOD_NAME + 'open', create=True) def test__get_script_from_template(self, m_sio, m_gs):
def test__write_script_from_template(self, m_open, m_tempfile, m_gsp): m_gs.return_value = fake_script = mock.Mock()
fake_tempfile = mock.Mock() fake_script.read.return_value = 'fake_data {k1} {k2}'
m_tempfile.NamedTemporaryFile.return_value = fake_tempfile m_sio.return_value = 'fake_formatted_script'
fake_file = mock.Mock() script = lxc._get_script_from_template('fake_tpl', k1='v1', k2='v2')
fake_data = mock.Mock() self.assertEqual('fake_formatted_script', script)
fake_data.format.return_value = 'fake_formatted_data' m_sio.assert_called_once_with('fake_data v1 v2')
fake_file.read.return_value = fake_data
m_open.return_value = fake_file
retval = lxc._write_script_from_template('script', key='value')
m_gsp.assert_called_once_with('script')
m_open.assert_called_once_with('fake_path')
m_tempfile.NamedTemporaryFile.assert_called_once_with(delete=False)
fake_data.format.assert_called_once_with(key='value')
fake_tempfile.write.assert_called_once_with('fake_formatted_data')
self.assertEqual(fake_tempfile.name, retval)
class LxcHostTestCase(test.BaseTestCase): class LxcHostTestCase(test.BaseTestCase):
@ -81,14 +77,12 @@ class LxcHostTestCase(test.BaseTestCase):
self.server.ssh.execute.side_effect = exceptions.SSHError() self.server.ssh.execute.side_effect = exceptions.SSHError()
self.assertEqual('dir', self.host.backingstore) self.assertEqual('dir', self.host.backingstore)
@mock.patch(MOD_NAME + '_get_script_path', return_value='fake_sp') @mock.patch(MOD_NAME + 'StringIO.StringIO')
@mock.patch(MOD_NAME + 'os.unlink') @mock.patch(MOD_NAME + '_get_script', return_value='fake_script')
@mock.patch(MOD_NAME + 'tempfile') def test_prepare(self, m_gs, m_sio):
def test_prepare(self, m_tempfile, m_unlink, m_gsp): m_sio.return_value = fake_conf = mock.Mock()
self.host.create_local_tunnels = mock.Mock() self.host.create_local_tunnels = mock.Mock()
self.host.create_remote_tunnels = mock.Mock() self.host.create_remote_tunnels = mock.Mock()
fake_tempfile = mock.Mock()
m_tempfile.NamedTemporaryFile.return_value = fake_tempfile
self.host.prepare() self.host.prepare()
@ -102,49 +96,38 @@ class LxcHostTestCase(test.BaseTestCase):
mock.call('USE_LXC_BRIDGE="true"\n') mock.call('USE_LXC_BRIDGE="true"\n')
] ]
for call in write_calls: for call in write_calls:
fake_tempfile.write.assert_has_calls(call) fake_conf.write.assert_has_calls(call)
self.server.ssh.upload.assert_called_once_with(fake_tempfile.name, ssh_calls = [mock.call.run('cat > /tmp/.lxc_default', stdin=fake_conf),
'/tmp/.lxc_default') mock.call.run('/bin/sh', stdin='fake_script')]
self.server.ssh.execute_script.assert_called_once_with('fake_sp') self.assertEqual(ssh_calls, self.server.ssh.mock_calls)
m_unlink.assert_called_once_with(fake_tempfile.name)
self.host.create_local_tunnels.assert_called_once() self.host.create_local_tunnels.assert_called_once()
self.host.create_remote_tunnels.assert_called_once() self.host.create_remote_tunnels.assert_called_once()
@mock.patch(MOD_NAME + 'os.unlink') @mock.patch(MOD_NAME + 'os.unlink')
@mock.patch(MOD_NAME + '_write_script_from_template') @mock.patch(MOD_NAME + '_get_script_from_template')
def test_create_local_tunnels(self, m_ws, m_unlink): def test_create_local_tunnels(self, m_gs, m_unlink):
m_ws.side_effect = ['1', '2'] m_gs.side_effect = ['s1', 's2']
self.host.create_local_tunnels() self.host.create_local_tunnels()
ws_calls = [ gs_calls = [
mock.call('tunnel-local.sh', local='fake_server_ip', mock.call('tunnel-local.sh', local='fake_server_ip',
net=netaddr.IPNetwork('10.1.1.0/24'), remote='1.1.1.1'), net=netaddr.IPNetwork('10.1.1.0/24'), remote='1.1.1.1'),
mock.call('tunnel-local.sh', local='fake_server_ip', mock.call('tunnel-local.sh', local='fake_server_ip',
net=netaddr.IPNetwork('10.1.1.0/24'), remote='2.2.2.2'), net=netaddr.IPNetwork('10.1.1.0/24'), remote='2.2.2.2'),
] ]
self.assertEqual(ws_calls, m_ws.mock_calls) self.assertEqual(gs_calls, m_gs.mock_calls)
self.assertEqual([mock.call('1'), mock.call('2')], self.assertEqual([mock.call('/bin/sh -e', stdin='s1'),
self.server.ssh.execute_script.mock_calls) mock.call('/bin/sh -e', stdin='s2')],
self.server.ssh.run.mock_calls)
@mock.patch(MOD_NAME + 'os.unlink') @mock.patch(MOD_NAME + '_get_script_from_template')
@mock.patch(MOD_NAME + '_write_script_from_template') def test_create_remote_tunnels(self, m_get_script):
def test_create_remote_tunnels(self, m_ws, m_unlink): m_get_script.side_effect = ['s1', 's2']
m_ws.side_effect = ['1', '2']
fake_server = mock.Mock() fake_server = mock.Mock()
self.host._get_server_with_ip = mock.Mock(return_value=fake_server) self.host._get_server_with_ip = mock.Mock(return_value=fake_server)
self.host.create_remote_tunnels() self.host.create_remote_tunnels()
self.assertEqual([mock.call('/bin/sh -e', stdin='s1'),
ws_calls = [ mock.call('/bin/sh -e', stdin='s2')],
mock.call('tunnel-remote.sh', local='1.1.1.1', fake_server.ssh.run.mock_calls)
net=netaddr.IPNetwork('10.1.1.0/24'),
remote='fake_server_ip'),
mock.call('tunnel-remote.sh', local='2.2.2.2',
net=netaddr.IPNetwork('10.1.1.0/24'),
remote='fake_server_ip'),
]
self.assertEqual(ws_calls, m_ws.mock_calls)
self.assertEqual([mock.call('1'), mock.call('2')],
fake_server.ssh.execute_script.mock_calls)
def test_delete_tunnels(self): def test_delete_tunnels(self):
s1 = mock.Mock() s1 = mock.Mock()
@ -162,64 +145,58 @@ class LxcHostTestCase(test.BaseTestCase):
def test_get_ip(self, m_sleep): def test_get_ip(self, m_sleep):
s1 = 'link/ether fe:54:00:d3:f5:98 brd ff:ff:ff:ff:ff:ff' s1 = 'link/ether fe:54:00:d3:f5:98 brd ff:ff:ff:ff:ff:ff'
s2 = s1 + '\n inet 10.20.0.1/24 scope global br1' s2 = s1 + '\n inet 10.20.0.1/24 scope global br1'
self.host.server.ssh.execute.side_effect = [(s1, ''), (s2, '')] self.host.server.ssh.execute.side_effect = [(0, s1, ''), (0, s2, '')]
ip = self.host.get_ip('name') ip = self.host.get_ip('name')
self.assertEqual('10.20.0.1', ip) self.assertEqual('10.20.0.1', ip)
self.assertEqual([mock.call('lxc-attach -n name ip addr list dev eth0', self.assertEqual([mock.call('lxc-attach -n name ip'
get_stdout=True)] * 2, ' addr list dev eth0')] * 2,
self.host.server.ssh.execute.mock_calls) self.host.server.ssh.execute.mock_calls)
def test_create_container(self): def test_create_container(self):
self.host.configure_container = mock.Mock() self.host.configure_container = mock.Mock()
self.host._backingstore = 'btrfs' self.host._backingstore = 'btrfs'
self.host.create_container('name', 'dist') self.host.create_container('name', 'dist')
self.server.ssh.execute.assert_called_once_with( self.server.ssh.run.assert_called_once_with(
'lxc-create', '-B', 'btrfs', '-n', 'name', '-t', 'dist') 'lxc-create -B btrfs -n name -t dist')
self.assertEqual(['name'], self.host.containers) self.assertEqual(['name'], self.host.containers)
self.host.configure_container.assert_called_once_with('name') self.host.configure_container.assert_called_once_with('name')
#check with no btrfs #check with no btrfs
self.host._backingstore = 'dir' self.host._backingstore = 'dir'
self.host.create_container('name', 'dist') self.host.create_container('name', 'dist')
self.assertEqual(mock.call('lxc-create', '-B', 'dir', '-n', self.assertEqual(mock.call('lxc-create -B dir -n name -t dist'),
'name', '-t', 'dist'), self.server.ssh.run.mock_calls[1])
self.server.ssh.execute.mock_calls[1])
def test_create_clone(self): def test_create_clone(self):
self.host._backingstore = 'btrfs' self.host._backingstore = 'btrfs'
self.host.configure_container = mock.Mock() self.host.configure_container = mock.Mock()
self.host.create_clone('name', 'src') self.host.create_clone('name', 'src')
self.server.ssh.execute.assert_called_once_with('lxc-clone', self.server.ssh.execute.assert_called_once_with('lxc-clone --snapshot'
'--snapshot', ' -o src -n name')
'-o', 'src',
'-n', 'name')
self.assertEqual(['name'], self.host.containers) self.assertEqual(['name'], self.host.containers)
#check with no btrfs #check with no btrfs
self.host._backingstore = 'dir' self.host._backingstore = 'dir'
self.host.create_clone('name', 'src') self.host.create_clone('name', 'src')
self.assertEqual(mock.call('lxc-clone', '-o', 'src', '-n', 'name'), self.assertEqual(mock.call('lxc-clone -o src -n name'),
self.server.ssh.execute.mock_calls[1]) self.server.ssh.execute.mock_calls[1])
@mock.patch(MOD_NAME + 'os.path.join') @mock.patch(MOD_NAME + 'os.path.join')
@mock.patch(MOD_NAME + '_get_script_path') @mock.patch(MOD_NAME + '_get_script')
def test_configure_container(self, m_gsp, m_join): def test_configure_container(self, m_gs, m_join):
m_gsp.return_value = 'fake_script' m_gs.return_value = 'fake_script'
m_join.return_value = 'fake_path' m_join.return_value = 'fake_path'
self.server.ssh.execute.return_value = 0, '', ''
self.host.configure_container('name') self.host.configure_container('name')
calls = [ self.server.ssh.run.assert_called_once_with(
mock.call.upload('fake_script', '/tmp/.rally_cont_conf.sh'), '/bin/sh -e -s fake_path', stdin='fake_script')
mock.call.execute('/bin/sh', '/tmp/.rally_cont_conf.sh',
'fake_path'),
]
self.assertEqual(calls, self.server.ssh.mock_calls)
def test_start_containers(self): def test_start_containers(self):
self.host.containers = ['c1', 'c2'] self.host.containers = ['c1', 'c2']
self.host.start_containers() self.host.start_containers()
calls = [mock.call('lxc-start -d -n c1'), calls = [mock.call('lxc-start -d -n c1'),
mock.call('lxc-start -d -n c2')] mock.call('lxc-start -d -n c2')]
self.assertEqual(calls, self.server.ssh.execute.mock_calls) self.assertEqual(calls, self.server.ssh.run.mock_calls)
def test_stop_containers(self): def test_stop_containers(self):
self.host.containers = ['c1', 'c2'] self.host.containers = ['c1', 'c2']
@ -228,7 +205,7 @@ class LxcHostTestCase(test.BaseTestCase):
mock.call('lxc-stop -n c1'), mock.call('lxc-stop -n c1'),
mock.call('lxc-stop -n c2'), mock.call('lxc-stop -n c2'),
] ]
self.assertEqual(calls, self.server.ssh.execute.mock_calls) self.assertEqual(calls, self.server.ssh.run.mock_calls)
def test_destroy_containers(self): def test_destroy_containers(self):
self.host.containers = ['c1', 'c2'] self.host.containers = ['c1', 'c2']
@ -237,7 +214,7 @@ class LxcHostTestCase(test.BaseTestCase):
mock.call('lxc-stop -n c1'), mock.call('lxc-destroy -n c1'), mock.call('lxc-stop -n c1'), mock.call('lxc-destroy -n c1'),
mock.call('lxc-stop -n c2'), mock.call('lxc-destroy -n c2'), mock.call('lxc-stop -n c2'), mock.call('lxc-destroy -n c2'),
] ]
self.assertEqual(calls, self.server.ssh.execute.mock_calls) self.assertEqual(calls, self.server.ssh.run.mock_calls)
@mock.patch(MOD_NAME + 'provider.Server.from_credentials') @mock.patch(MOD_NAME + 'provider.Server.from_credentials')
def test_get_server_object(self, m_fc): def test_get_server_object(self, m_fc):

View File

@ -14,101 +14,251 @@
# under the License. # under the License.
import mock import mock
import os
from rally import exceptions
from rally import sshutils from rally import sshutils
from tests import test from tests import test
class FakeParamikoException(Exception):
pass
class SSHTestCase(test.TestCase): class SSHTestCase(test.TestCase):
"""Test all small SSH methods."""
def setUp(self): def setUp(self):
super(SSHTestCase, self).setUp() super(SSHTestCase, self).setUp()
self.ssh = sshutils.SSH('example.net', 'root') self.ssh = sshutils.SSH('root', 'example.net')
self.channel = mock.Mock()
self.channel.recv.return_value = 'ok'
self.channel.recv_stderr.return_value = 'error'
self.channel.recv_exit_status.return_value = 0
self.transport = mock.Mock()
self.transport.open_session = mock.MagicMock(return_value=self.channel)
self.policy = mock.Mock()
self.client = mock.Mock()
self.client.get_transport = mock.MagicMock(return_value=self.transport)
self.channel.exit_status_ready.return_value = True @mock.patch('rally.sshutils.SSH._get_pkey')
self.channel.recv_ready.side_effect = [True, False, False] def test_construct(self, m_pkey):
self.channel.recv_stderr_ready.side_effect = [True, False, False] m_pkey.return_value = 'pkey'
ssh = sshutils.SSH('root', 'example.net', port=33, pkey='key',
key_filename='kf', password='secret')
m_pkey.assert_called_once_with('key')
self.assertEqual('root', ssh.user)
self.assertEqual('example.net', ssh.host)
self.assertEqual(33, ssh.port)
self.assertEqual('pkey', ssh.pkey)
self.assertEqual('kf', ssh.key_filename)
self.assertEqual('secret', ssh.password)
def test_construct_default(self):
self.assertEqual('root', self.ssh.user)
self.assertEqual('example.net', self.ssh.host)
self.assertEqual(22, self.ssh.port)
self.assertIsNone(self.ssh.pkey)
self.assertIsNone(self.ssh.key_filename)
self.assertIsNone(self.ssh.password)
@mock.patch('rally.sshutils.paramiko') @mock.patch('rally.sshutils.paramiko')
@mock.patch('rally.sshutils.select') def test__get_pkey_invalid(self, m_paramiko):
def test_generator(self, st, pk): m_paramiko.SSHException = FakeParamikoException
pk.SSHClient.return_value = self.client rsa = m_paramiko.rsakey.RSAKey
st.select.return_value = ([], [], []) dss = m_paramiko.dsskey.DSSKey
rsa.from_private_key.side_effect = m_paramiko.SSHException
chunks = list(self.ssh.execute_generator('ps ax')) dss.from_private_key.side_effect = m_paramiko.SSHException
self.assertEqual([(1, 'ok'), (2, 'error')], chunks) self.assertRaises(sshutils.SSHError, self.ssh._get_pkey, 'key')
@mock.patch('rally.sshutils.StringIO')
@mock.patch('rally.sshutils.paramiko') @mock.patch('rally.sshutils.paramiko')
@mock.patch('rally.sshutils.select') def test__get_pkey_dss(self, m_paramiko, m_stringio):
def test_execute(self, st, pk): m_paramiko.SSHException = FakeParamikoException
pk.SSHClient.return_value = self.client m_stringio.StringIO.return_value = 'string_key'
st.select.return_value = ([], [], []) m_paramiko.dsskey.DSSKey.from_private_key.return_value = 'dss_key'
stdout, stderr = self.ssh.execute('uname') rsa = m_paramiko.rsakey.RSAKey
rsa.from_private_key.side_effect = m_paramiko.SSHException
self.assertEqual('', stdout) key = self.ssh._get_pkey('key')
self.assertEqual('', stderr) dss_calls = m_paramiko.dsskey.DSSKey.from_private_key.mock_calls
expected = [mock.call.exec_command('uname'), self.assertEqual([mock.call('string_key')], dss_calls)
mock.call.recv_ready(), self.assertEqual(key, 'dss_key')
mock.call.recv(4096), m_stringio.StringIO.assert_called_once_with('key')
mock.call.recv_ready(),
mock.call.recv_stderr_ready(),
mock.call.recv_stderr(4096),
mock.call.recv_ready(),
mock.call.recv_stderr_ready(),
mock.call.exit_status_ready(),
mock.call.recv_exit_status()]
self.assertEqual(expected, self.channel.mock_calls)
@mock.patch('rally.sshutils.StringIO')
@mock.patch('rally.sshutils.paramiko') @mock.patch('rally.sshutils.paramiko')
def test_upload_file(self, pk): def test__get_pkey_rsa(self, m_paramiko, m_stringio):
pk.AutoAddPolicy.return_value = self.policy m_paramiko.SSHException = FakeParamikoException
self.ssh.upload('/tmp/s', '/tmp/d') m_stringio.StringIO.return_value = 'string_key'
m_paramiko.rsakey.RSAKey.from_private_key.return_value = 'rsa_key'
dss = m_paramiko.dsskey.DSSKey
dss.from_private_key.side_effect = m_paramiko.SSHException
key = self.ssh._get_pkey('key')
rsa_calls = m_paramiko.rsakey.RSAKey.from_private_key.mock_calls
self.assertEqual([mock.call('string_key')], rsa_calls)
self.assertEqual(key, 'rsa_key')
m_stringio.StringIO.assert_called_once_with('key')
expected = [mock.call.set_missing_host_key_policy(self.policy), @mock.patch('rally.sshutils.SSH._get_pkey')
mock.call.connect(hostname='example.net', username='root', @mock.patch('rally.sshutils.paramiko')
key_filename=os.path.expanduser( def test__get_client(self, m_paramiko, m_pkey):
'~/.ssh/id_rsa'), port=22), m_pkey.return_value = 'key'
mock.call.open_sftp(), fake_client = mock.Mock()
mock.call.open_sftp().put('/tmp/s', '/tmp/d'), m_paramiko.SSHClient.return_value = fake_client
mock.call.open_sftp().close()] m_paramiko.AutoAddPolicy.return_value = 'autoadd'
self.assertEqual(pk.SSHClient().mock_calls, expected) ssh = sshutils.SSH('admin', 'example.net', pkey='key')
client = ssh._get_client()
@mock.patch('rally.sshutils.SSH.execute') self.assertEqual(fake_client, client)
@mock.patch('rally.sshutils.SSH.upload') client_calls = [
@mock.patch('rally.sshutils.random.choice') mock.call.set_missing_host_key_policy('autoadd'),
def test_execute_script_new(self, rc, up, ex): mock.call.connect('example.net', username='admin',
rc.return_value = 'a' port=22, pkey='key', key_filename=None,
self.ssh.execute_script('/bin/script') password=None),
]
self.assertEqual(client_calls, client.mock_calls)
up.assert_called_once_with('/bin/script', '/tmp/aaaaaaaaaaaaaaaa') def test_close(self):
ex.assert_has_calls([ with mock.patch.object(self.ssh, '_client') as m_client:
mock.call('/bin/sh /tmp/aaaaaaaaaaaaaaaa', self.ssh.close()
get_stderr=False, get_stdout=False), m_client.close.assert_called_once()
mock.call('rm /tmp/aaaaaaaaaaaaaaaa') self.assertFalse(self.ssh._client)
])
@mock.patch('rally.sshutils.SSH.execute') @mock.patch('rally.sshutils.StringIO')
def test_wait(self, ex): def test_execute(self, m_stringio):
self.ssh.wait() m_stringio.StringIO.side_effect = stdio = [mock.Mock(), mock.Mock()]
stdio[0].read.return_value = 'stdout fake data'
stdio[1].read.return_value = 'stderr fake data'
with mock.patch.object(self.ssh, 'run', return_value=0) as m_run:
status, stdout, stderr = self.ssh.execute('cmd',
stdin='fake_stdin',
timeout=43)
m_run.assert_called_once_with('cmd', stdin='fake_stdin',
stdout=stdio[0],
stderr=stdio[1], timeout=43,
raise_on_error=False)
self.assertEqual(0, status)
self.assertEqual('stdout fake data', stdout)
self.assertEqual('stderr fake data', stderr)
@mock.patch('rally.sshutils.time') @mock.patch('rally.sshutils.time')
@mock.patch('rally.sshutils.SSH.execute') def test_wait_timeout(self, m_time):
def test_wait_timeout(self, ex, mock_time): m_time.time.side_effect = [1, 50, 150]
mock_time.time.side_effect = [1, 10] self.ssh.execute = mock.Mock(side_effect=[sshutils.SSHError,
ex.side_effect = exceptions.SSHError sshutils.SSHError,
self.assertRaises(exceptions.TimeoutException, 0])
self.ssh.wait, 1, 1) self.assertRaises(sshutils.SSHTimeout, self.ssh.wait)
mock_time.sleep.assert_called_once_with(1) self.assertEqual([mock.call('uname')] * 2, self.ssh.execute.mock_calls)
@mock.patch('rally.sshutils.time')
def test_wait(self, m_time):
m_time.time.side_effect = [1, 50, 100]
self.ssh.execute = mock.Mock(side_effect=[sshutils.SSHError,
sshutils.SSHError,
0])
self.ssh.wait()
self.assertEqual([mock.call('uname')] * 3, self.ssh.execute.mock_calls)
class SSHRunTestCase(test.TestCase):
"""Test SSH.run method in different aspects.
Also tested method 'execute'.
"""
def setUp(self):
super(SSHRunTestCase, self).setUp()
self.fake_client = mock.Mock()
self.fake_session = mock.Mock()
self.fake_transport = mock.Mock()
self.fake_transport.open_session.return_value = self.fake_session
self.fake_client.get_transport.return_value = self.fake_transport
self.fake_session.recv_ready.return_value = False
self.fake_session.recv_stderr_ready.return_value = False
self.fake_session.send_ready.return_value = False
self.fake_session.exit_status_ready.return_value = True
self.fake_session.recv_exit_status.return_value = 0
self.ssh = sshutils.SSH('admin', 'example.net')
self.ssh._get_client = mock.Mock(return_value=self.fake_client)
@mock.patch('rally.sshutils.select')
def test_execute(self, m_select):
m_select.select.return_value = ([], [], [])
self.fake_session.recv_ready.side_effect = [1, 0, 0]
self.fake_session.recv_stderr_ready.side_effect = [1, 0]
self.fake_session.recv.return_value = 'ok'
self.fake_session.recv_stderr.return_value = 'error'
self.fake_session.exit_status_ready.return_value = 1
self.fake_session.recv_exit_status.return_value = 127
self.assertEqual((127, 'ok', 'error'), self.ssh.execute('cmd'))
self.fake_session.exec_command.assert_called_once_with('cmd')
@mock.patch('rally.sshutils.select')
def test_run(self, m_select):
m_select.select.return_value = ([], [], [])
self.assertEqual(0, self.ssh.run('cmd'))
@mock.patch('rally.sshutils.select')
def test_run_nonzero_status(self, m_select):
m_select.select.return_value = ([], [], [])
self.fake_session.recv_exit_status.return_value = 1
self.assertRaises(sshutils.SSHError, self.ssh.run, 'cmd')
self.assertEqual(1, self.ssh.run('cmd', raise_on_error=False))
@mock.patch('rally.sshutils.select')
def test_run_stdout(self, m_select):
m_select.select.return_value = ([], [], [])
self.fake_session.recv_ready.side_effect = [True, True, False]
self.fake_session.recv.side_effect = ['ok1', 'ok2']
stdout = mock.Mock()
self.ssh.run('cmd', stdout=stdout)
self.assertEqual([mock.call('ok1'), mock.call('ok2')],
stdout.write.mock_calls)
@mock.patch('rally.sshutils.select')
def test_run_stderr(self, m_select):
m_select.select.return_value = ([], [], [])
self.fake_session.recv_stderr_ready.side_effect = [True, False]
self.fake_session.recv_stderr.return_value = 'error'
stderr = mock.Mock()
self.ssh.run('cmd', stderr=stderr)
stderr.write.assert_called_once_with('error')
@mock.patch('rally.sshutils.select')
def test_run_stdin(self, m_select):
"""Test run method with stdin.
Third send call was called with 'e2' because only 3 bytes was sent
by second call. So remainig 2 bytes of 'line2' was sent by third call.
"""
m_select.select.return_value = ([], [], [])
self.fake_session.exit_status_ready.side_effect = [0, 0, 0, True]
self.fake_session.send_ready.return_value = True
self.fake_session.send.side_effect = [5, 3, 2]
fake_stdin = mock.Mock()
fake_stdin.read.side_effect = ['line1', 'line2', '']
fake_stdin.closed = False
def close():
fake_stdin.closed = True
fake_stdin.close = mock.Mock(side_effect=close)
self.ssh.run('cmd', stdin=fake_stdin)
call = mock.call
send_calls = [call('line1'), call('line2'), call('e2')]
self.assertEqual(send_calls, self.fake_session.send.mock_calls)
@mock.patch('rally.sshutils.select')
def test_run_select_error(self, m_select):
self.fake_session.exit_status_ready.return_value = False
m_select.select.return_value = ([], [], [True])
self.assertRaises(sshutils.SSHError, self.ssh.run, 'cmd')
@mock.patch('rally.sshutils.time')
@mock.patch('rally.sshutils.select')
def test_run_timemout(self, m_select, m_time):
m_time.time.side_effect = [1, 3700]
m_select.select.return_value = ([], [], [])
self.fake_session.exit_status_ready.return_value = False
self.assertRaises(sshutils.SSHTimeout, self.ssh.run, 'cmd')
@mock.patch('rally.sshutils.select')
def test__run_client_closed_on_error(self, m_select):
m_select.select.return_value = ([], [], [])
self.fake_session.recv_ready.return_value = True
self.fake_session.recv.side_effect = IOError
self.assertRaises(IOError, self.ssh._run, self.fake_client, 'cmd')
self.fake_client.close.assert_called_once()