Sshutils refactoring
Paramiko sftp client is not eventlet friendly. Also it doesn't work with some servers, so upload/download methods has been removed. Downloading or uploading can be done in other way (see examples). Added ability to pass process stdin, which makes possible to do on remote host anything at all. Much more clean and usable API. Change-Id: If357a2878c2c60646a975c386a0fe2f1616aec95
This commit is contained in:
parent
3d420c286b
commit
936ba2efd8
@ -20,7 +20,6 @@ import random
|
||||
from rally.benchmark.scenarios.cinder import utils as cinder_utils
|
||||
from rally.benchmark.scenarios.nova import utils
|
||||
from rally.benchmark.scenarios import utils as scenario_utils
|
||||
from rally.benchmark import utils as benchmark_utils
|
||||
from rally.benchmark import validation
|
||||
from rally import exceptions as rally_exceptions
|
||||
from rally.openstack.common.gettextutils import _ # noqa
|
||||
@ -102,43 +101,23 @@ class NovaServers(utils.NovaScenario,
|
||||
)
|
||||
server_ip = [ip for ip in server.addresses[network] if
|
||||
ip['version'] == ip_version][0]['addr']
|
||||
ssh = sshutils.SSH(ip=server_ip, port=port, user=username,
|
||||
key=self.clients('ssh_key_pair')['private'],
|
||||
key_type='string')
|
||||
|
||||
for retry in range(retries):
|
||||
try:
|
||||
LOG.debug(_('Execute script on server attempt '
|
||||
'%(retry)i/%(retries)i') % dict(retry=retry,
|
||||
retries=retries))
|
||||
streams = list(ssh.execute_script(script=script,
|
||||
interpreter=interpreter,
|
||||
get_stdout=True,
|
||||
get_stderr=True))
|
||||
|
||||
#NOTE(hughsaunders): Decode JSON script output
|
||||
streams[sshutils.SSH.STDOUT_INDEX]\
|
||||
= json.loads(streams[sshutils.SSH.STDOUT_INDEX])
|
||||
break
|
||||
except (rally_exceptions.SSHError,
|
||||
rally_exceptions.TimeoutException, IOError) as e:
|
||||
LOG.debug(_('Error running script on instance via SSH. '
|
||||
'%(id)s/%(ip)s Attempt:%(retry)i, '
|
||||
'Error: %(error)s') % dict(
|
||||
id=server.id, ip=server_ip, retry=retry,
|
||||
error=benchmark_utils.format_exc(e)))
|
||||
self.sleep_between(5, 5)
|
||||
except ValueError:
|
||||
LOG.error(_('Script %(script)s did not output valid JSON. ')
|
||||
% dict(script=script))
|
||||
ssh = sshutils.SSH(username, server_ip, port=port,
|
||||
pkey=self.clients('ssh_key_pair')['private'])
|
||||
ssh.wait()
|
||||
code, out, err = ssh.execute(interpreter, stdin=open(script, 'rb'))
|
||||
if code:
|
||||
LOG.error(_('Error running script on instance via SSH. '
|
||||
'Error: %s') % err)
|
||||
try:
|
||||
out = json.loads(out)
|
||||
except ValueError:
|
||||
LOG.warning(_('Script %s did not output valid JSON. ') % script)
|
||||
|
||||
self._delete_server(server)
|
||||
LOG.debug(_('Output streams from in-instance script execution: '
|
||||
'stdout: %(stdout)s, stderr: $(stderr)s') % dict(
|
||||
stdout=str(streams[sshutils.SSH.STDOUT_INDEX]),
|
||||
stderr=str(streams[sshutils.SSH.STDERR_INDEX])))
|
||||
return dict(data=streams[sshutils.SSH.STDOUT_INDEX],
|
||||
errors=streams[sshutils.SSH.STDERR_INDEX])
|
||||
stdout=out, stderr=err))
|
||||
return {'data': out, 'errors': err}
|
||||
|
||||
@validation.add_validator(validation.flavor_exists("flavor_id"))
|
||||
@validation.add_validator(validation.image_exists("image_id"))
|
||||
|
@ -14,7 +14,7 @@
|
||||
# under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import StringIO
|
||||
|
||||
from rally.deploy import engine
|
||||
from rally import objects
|
||||
@ -78,7 +78,7 @@ class DevstackEngine(engine.EngineFactory):
|
||||
def prepare_server(self, server):
|
||||
script_path = os.path.abspath(os.path.join(os.path.dirname(__file__),
|
||||
'devstack', 'install.sh'))
|
||||
server.ssh.execute_script(script_path)
|
||||
server.ssh.run('/bin/sh -e', stdin=open(script_path, 'rb'))
|
||||
|
||||
@utils.log_deploy_wrapper(LOG.info, _("Deploy devstack"))
|
||||
def deploy(self):
|
||||
@ -103,18 +103,15 @@ class DevstackEngine(engine.EngineFactory):
|
||||
@utils.log_deploy_wrapper(LOG.info, _("Configure devstack"))
|
||||
def configure_devstack(self, server):
|
||||
devstack_repo = self.config.get('devstack_repo', DEVSTACK_REPO)
|
||||
server.ssh.execute('git', 'clone', devstack_repo)
|
||||
fd, config_path = tempfile.mkstemp()
|
||||
config_file = open(config_path, "w")
|
||||
server.ssh.run('git clone %s' % devstack_repo)
|
||||
localrc = StringIO.StringIO()
|
||||
for k, v in self.localrc.iteritems():
|
||||
config_file.write('%s=%s\n' % (k, v))
|
||||
config_file.close()
|
||||
os.close(fd)
|
||||
server.ssh.upload(config_path, "~/devstack/localrc")
|
||||
os.unlink(config_path)
|
||||
localrc.write('%s=%s\n' % (k, v))
|
||||
localrc.seek(0)
|
||||
server.ssh.run("cat > ~/devstack/localrc", stdin=localrc)
|
||||
return True
|
||||
|
||||
@utils.log_deploy_wrapper(LOG.info, _("Run devstack"))
|
||||
def start_devstack(self, server):
|
||||
server.ssh.execute('~/devstack/stack.sh')
|
||||
server.ssh.run('~/devstack/stack.sh')
|
||||
return True
|
||||
|
@ -16,6 +16,7 @@
|
||||
import abc
|
||||
import jsonschema
|
||||
|
||||
|
||||
from rally import exceptions
|
||||
from rally import sshutils
|
||||
from rally import utils
|
||||
@ -32,7 +33,8 @@ class Server(utils.ImmutableMixin):
|
||||
self.user = user
|
||||
self.key = key
|
||||
self.password = password
|
||||
self.ssh = sshutils.SSH(host, user, port, key)
|
||||
self.ssh = sshutils.SSH(user, host, key_filename=key, port=port,
|
||||
password=password)
|
||||
super(Server, self).__init__()
|
||||
|
||||
def get_credentials(self):
|
||||
|
@ -16,7 +16,7 @@
|
||||
import netaddr
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import StringIO
|
||||
import time
|
||||
|
||||
from rally import exceptions
|
||||
@ -30,17 +30,15 @@ LOG = logging.getLogger(__name__)
|
||||
INET_ADDR_RE = re.compile(r' *inet ((\d+\.){3}\d+)\/\d+ .*')
|
||||
|
||||
|
||||
def _get_script_path(filename):
|
||||
return os.path.abspath(os.path.join(os.path.dirname(__file__),
|
||||
def _get_script(filename):
|
||||
path = os.path.abspath(os.path.join(os.path.dirname(__file__),
|
||||
'lxc', filename))
|
||||
return open(path, 'rb')
|
||||
|
||||
|
||||
def _write_script_from_template(template_filename, **kwargs):
|
||||
template = open(_get_script_path(template_filename)).read()
|
||||
new_file = tempfile.NamedTemporaryFile(delete=False)
|
||||
new_file.write(template.format(**kwargs))
|
||||
new_file.close()
|
||||
return new_file.name
|
||||
def _get_script_from_template(template_filename, **kwargs):
|
||||
template = _get_script(template_filename).read()
|
||||
return StringIO.StringIO(template.format(**kwargs))
|
||||
|
||||
|
||||
class LxcHost(object):
|
||||
@ -85,39 +83,35 @@ class LxcHost(object):
|
||||
'LXC_DHCP_RANGE': dhcp_range,
|
||||
'LXC_DHCP_MAX': self.network.size - 3,
|
||||
}
|
||||
config = tempfile.NamedTemporaryFile(delete=False)
|
||||
config = StringIO.StringIO()
|
||||
for name, value in values.iteritems():
|
||||
config.write('%(name)s="%(value)s"\n' % {'name': name,
|
||||
'value': value})
|
||||
config.close()
|
||||
self.server.ssh.upload(config.name, '/tmp/.lxc_default')
|
||||
os.unlink(config.name)
|
||||
config.seek(0)
|
||||
self.server.ssh.run('cat > /tmp/.lxc_default', stdin=config)
|
||||
|
||||
script = _get_script_path('lxc-install.sh')
|
||||
self.server.ssh.execute_script(script)
|
||||
self.server.ssh.run('/bin/sh', stdin=_get_script('lxc-install.sh'))
|
||||
self.create_local_tunnels()
|
||||
self.create_remote_tunnels()
|
||||
|
||||
def create_local_tunnels(self):
|
||||
"""Create tunel on lxc host side."""
|
||||
for tunnel_to in self.config['tunnel_to']:
|
||||
script = _write_script_from_template('tunnel-local.sh',
|
||||
net=self.network,
|
||||
local=self.server.host,
|
||||
remote=tunnel_to)
|
||||
self.server.ssh.execute_script(script)
|
||||
os.unlink(script)
|
||||
script = _get_script_from_template('tunnel-local.sh',
|
||||
net=self.network,
|
||||
local=self.server.host,
|
||||
remote=tunnel_to)
|
||||
self.server.ssh.run('/bin/sh -e', stdin=script)
|
||||
|
||||
def create_remote_tunnels(self):
|
||||
"""Create tunel on remote side."""
|
||||
for tunnel_to in self.config['tunnel_to']:
|
||||
script = _write_script_from_template('tunnel-remote.sh',
|
||||
net=self.network,
|
||||
local=tunnel_to,
|
||||
remote=self.server.host)
|
||||
script = _get_script_from_template('tunnel-remote.sh',
|
||||
net=self.network,
|
||||
local=tunnel_to,
|
||||
remote=self.server.host)
|
||||
server = self._get_server_with_ip(tunnel_to)
|
||||
server.ssh.execute_script(script)
|
||||
os.unlink(script)
|
||||
server.ssh.run('/bin/sh -e', stdin=script)
|
||||
|
||||
def delete_tunnels(self):
|
||||
for tunnel_to in self.config['tunnel_to']:
|
||||
@ -130,7 +124,9 @@ class LxcHost(object):
|
||||
|
||||
cmd = 'lxc-attach -n %s ip addr list dev eth0' % name
|
||||
for attempt in range(1, 16):
|
||||
stdout = self.server.ssh.execute(cmd, get_stdout=True)[0]
|
||||
code, stdout = self.server.ssh.execute(cmd)[:2]
|
||||
if code:
|
||||
continue
|
||||
for line in stdout.splitlines():
|
||||
m = INET_ADDR_RE.match(line)
|
||||
if m:
|
||||
@ -140,9 +136,10 @@ class LxcHost(object):
|
||||
raise exceptions.TimeoutException(msg)
|
||||
|
||||
def create_container(self, name, distribution):
|
||||
self.server.ssh.execute('lxc-create', '-B', self.backingstore,
|
||||
'-n', name,
|
||||
'-t', distribution)
|
||||
args = {'backingstore': self.backingstore,
|
||||
'name': name, 'distribution': distribution}
|
||||
self.server.ssh.run('lxc-create -B %(backingstore)s -n %(name)s'
|
||||
' -t %(distribution)s' % args)
|
||||
self.configure_container(name)
|
||||
self.containers.append(name)
|
||||
|
||||
@ -152,28 +149,27 @@ class LxcHost(object):
|
||||
if self.backingstore == 'btrfs':
|
||||
cmd.append('--snapshot')
|
||||
cmd.extend(['-o', source, '-n', name])
|
||||
self.server.ssh.execute(*cmd)
|
||||
self.server.ssh.execute(' '.join(cmd))
|
||||
self.configure_container(name)
|
||||
self.containers.append(name)
|
||||
|
||||
def configure_container(self, name):
|
||||
path = os.path.join(self.path, name, 'rootfs')
|
||||
configure_script = _get_script_path('configure_container.sh')
|
||||
self.server.ssh.upload(configure_script, '/tmp/.rally_cont_conf.sh')
|
||||
self.server.ssh.execute('/bin/sh', '/tmp/.rally_cont_conf.sh', path)
|
||||
conf_script = _get_script('configure_container.sh')
|
||||
self.server.ssh.run('/bin/sh -e -s %s' % path, stdin=conf_script)
|
||||
|
||||
def start_containers(self):
|
||||
for name in self.containers:
|
||||
self.server.ssh.execute('lxc-start -d -n %s' % name)
|
||||
self.server.ssh.run('lxc-start -d -n %s' % name)
|
||||
|
||||
def stop_containers(self):
|
||||
for name in self.containers:
|
||||
self.server.ssh.execute('lxc-stop -n %s' % name)
|
||||
self.server.ssh.run('lxc-stop -n %s' % name)
|
||||
|
||||
def destroy_containers(self):
|
||||
for name in self.containers:
|
||||
self.server.ssh.execute('lxc-stop -n %s' % name)
|
||||
self.server.ssh.execute('lxc-destroy -n %s' % name)
|
||||
self.server.ssh.run('lxc-stop -n %s' % name)
|
||||
self.server.ssh.run('lxc-destroy -n %s' % name)
|
||||
|
||||
def get_server_object(self, name, wait=True):
|
||||
"""Create Server object for container."""
|
||||
@ -257,6 +253,7 @@ class LxcProvider(provider.ProviderFactory):
|
||||
host.prepare()
|
||||
ip = str(network.ip).replace('.', '-') if network else '0'
|
||||
first_name = '%s-000-%s' % (name_prefix, ip)
|
||||
|
||||
host.create_container(first_name, distribution)
|
||||
for i in range(1, self.config.get('containers_per_host', 1)):
|
||||
name = '%s-%03d-%s' % (name_prefix, i, ip)
|
||||
|
@ -1,3 +1,3 @@
|
||||
ip tun add t{net.ip} mode ipip local {local} remote {remote} || true
|
||||
ip tun add t{net.ip} mode ipip local {local} remote {remote}
|
||||
ip link set t{net.ip} up
|
||||
ip route add {net} dev t{net.ip} src {local} || true
|
||||
ip route add {net} dev t{net.ip} src {local}
|
||||
|
@ -13,160 +13,235 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import os
|
||||
|
||||
"""High level ssh library.
|
||||
|
||||
Usage examples:
|
||||
|
||||
Execute command and get output:
|
||||
|
||||
ssh = sshclient.SSH('root', 'example.com', port=33)
|
||||
status, stdout, stderr = ssh.execute('ps ax')
|
||||
if status:
|
||||
raise Exception('Command failed with non-zero status.')
|
||||
print stdout.splitlines()
|
||||
|
||||
Execute command with huge output:
|
||||
|
||||
class PseudoFile(object):
|
||||
def write(chunk):
|
||||
if 'error' in chunk:
|
||||
email_admin(chunk)
|
||||
|
||||
ssh = sshclient.SSH('root', 'example.com')
|
||||
ssh.run('tail -f /var/log/syslog', stdout=PseudoFile(), timeout=False)
|
||||
|
||||
Execute local script on remote side:
|
||||
|
||||
ssh = sshclient.SSH('user', 'example.com')
|
||||
status, out, err = ssh.execute('/bin/sh -s arg1 arg2',
|
||||
stdin=open('~/myscript.sh', 'r'))
|
||||
|
||||
Upload file:
|
||||
|
||||
ssh = sshclient.SSH('user', 'example.com')
|
||||
ssh.run('cat > ~/upload/file.gz', stdin=open('/store/file.gz', 'rb'))
|
||||
|
||||
Eventlet:
|
||||
|
||||
eventlet.monkey_patch(select=True, time=True)
|
||||
or
|
||||
eventlet.monkey_patch()
|
||||
or
|
||||
sshclient = eventlet.import_patched("opentstack.common.sshclient")
|
||||
|
||||
"""
|
||||
|
||||
import paramiko
|
||||
import random
|
||||
import select
|
||||
import socket
|
||||
import string
|
||||
import StringIO
|
||||
import time
|
||||
|
||||
from rally import exceptions
|
||||
from rally.openstack.common.gettextutils import _
|
||||
from rally.openstack.common import log as logging
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SSHTimeout(SSHError):
|
||||
pass
|
||||
|
||||
|
||||
class SSH(object):
|
||||
"""SSH common functions."""
|
||||
STDOUT_INDEX = 0
|
||||
STDERR_INDEX = 1
|
||||
"""Represent ssh connection."""
|
||||
|
||||
def __init__(self, ip, user, port=22, key=None, key_type="file",
|
||||
timeout=1800):
|
||||
"""Initialize SSH client with ip, username and the default values.
|
||||
def __init__(self, user, host, port=22, pkey=None,
|
||||
key_filename=None, password=None):
|
||||
"""Initialize SSH client.
|
||||
|
||||
:param user: ssh username
|
||||
:param host: hostname or ip address of remote ssh server
|
||||
:param port: remote ssh port
|
||||
:param pkey: RSA or DSS private key string or file object
|
||||
:param key_filename: private key filename
|
||||
:param password: password
|
||||
|
||||
timeout - the timeout for execution of the command
|
||||
key - path to private key file, or string containing actual key
|
||||
key_type - "file" for key path, "string" for actual key
|
||||
"""
|
||||
self.ip = ip
|
||||
self.port = port
|
||||
|
||||
self.user = user
|
||||
self.timeout = timeout
|
||||
self.client = None
|
||||
self.key = key
|
||||
self.key_type = key_type
|
||||
if not self.key:
|
||||
#Guess location of user's private key if no key is specified.
|
||||
self.key = os.path.expanduser('~/.ssh/id_rsa')
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.pkey = self._get_pkey(pkey) if pkey else None
|
||||
self.password = password
|
||||
self.key_filename = key_filename
|
||||
self._client = False
|
||||
|
||||
def _get_ssh_connection(self):
|
||||
self.client = paramiko.SSHClient()
|
||||
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
connect_params = {
|
||||
'hostname': self.ip,
|
||||
'port': self.port,
|
||||
'username': self.user
|
||||
}
|
||||
def _get_pkey(self, key):
|
||||
if isinstance(key, basestring):
|
||||
key = StringIO.StringIO(key)
|
||||
for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey):
|
||||
try:
|
||||
return key_class.from_private_key(key)
|
||||
except paramiko.SSHException:
|
||||
pass
|
||||
raise SSHError('Invalid pkey')
|
||||
|
||||
# NOTE(hughsaunders): Set correct paramiko parameter names for each
|
||||
# method of supplying a key.
|
||||
if self.key_type == 'file':
|
||||
connect_params['key_filename'] = self.key
|
||||
else:
|
||||
connect_params['pkey'] = paramiko.RSAKey(
|
||||
file_obj=StringIO.StringIO(self.key))
|
||||
def _get_client(self):
|
||||
if self._client:
|
||||
return self._client
|
||||
try:
|
||||
self._client = paramiko.SSHClient()
|
||||
self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
self._client.connect(self.host, username=self.user,
|
||||
port=self.port, pkey=self.pkey,
|
||||
key_filename=self.key_filename,
|
||||
password=self.password)
|
||||
return self._client
|
||||
except paramiko.SSHException as e:
|
||||
message = _("Paramiko exception %(exception_type)s was raised "
|
||||
"during connect. Exception value is: %(exception)r")
|
||||
raise SSHError(message % {'exception': e,
|
||||
'exception_type': type(e)})
|
||||
|
||||
self.client.connect(**connect_params)
|
||||
def close(self):
|
||||
self._client.close()
|
||||
self._client = False
|
||||
|
||||
def _is_timed_out(self, start_time, timeout=None):
|
||||
timeout = timeout if timeout else self.timeout
|
||||
return (time.time() - timeout) > start_time
|
||||
|
||||
def execute(self, *cmd, **kwargs):
|
||||
"""Execute the specified command on the server.
|
||||
|
||||
Return tuple (stdout, stderr).
|
||||
|
||||
:param *cmd: Command and arguments to be executed.
|
||||
:param get_stdout: Collect stdout data. Boolean.
|
||||
:param get_stderr: Collect stderr data. Boolean.
|
||||
def run(self, cmd, stdin=None, stdout=None, stderr=None,
|
||||
raise_on_error=True, timeout=3600):
|
||||
"""Execute specified command on the server.
|
||||
|
||||
:param cmd: Command to be executed.
|
||||
:param stdin: Open file or string to pass to stdin.
|
||||
:param stdout: Open file to connect to stdout.
|
||||
:param stderr: Open file to connect to stderr.
|
||||
:param raise_on_error: If False then exit code will be return. If True
|
||||
then exception will be raized if non-zero code.
|
||||
:param timeout: Timeout in seconds for command execution.
|
||||
Default 1 hour. No timeout if set to 0.
|
||||
"""
|
||||
get_stdout = kwargs.get("get_stdout", False)
|
||||
get_stderr = kwargs.get("get_stderr", False)
|
||||
stdout = ''
|
||||
stderr = ''
|
||||
for chunk in self.execute_generator(*cmd, get_stdout=get_stdout,
|
||||
get_stderr=get_stderr):
|
||||
if chunk[0] == 1:
|
||||
stdout += chunk[1]
|
||||
elif chunk[0] == 2:
|
||||
stderr += chunk[1]
|
||||
return (stdout, stderr)
|
||||
|
||||
def execute_generator(self, *cmd, **kwargs):
|
||||
"""Execute the specified command on the server.
|
||||
client = self._get_client()
|
||||
|
||||
Return generator. Stdout and stderr data can be collected by chunks.
|
||||
if isinstance(stdin, basestring):
|
||||
stdin = StringIO.StringIO(stdin)
|
||||
|
||||
:param *cmd: Command and arguments to be executed.
|
||||
:param get_stdout: Collect stdout data. Boolean.
|
||||
:param get_stderr: Collect stderr data. Boolean.
|
||||
return self._run(client, cmd, stdin=stdin, stdout=stdout,
|
||||
stderr=stderr, raise_on_error=raise_on_error,
|
||||
timeout=timeout)
|
||||
|
||||
"""
|
||||
get_stdout = kwargs.get("get_stdout", True)
|
||||
get_stderr = kwargs.get("get_stderr", True)
|
||||
self._get_ssh_connection()
|
||||
cmd = ' '.join(cmd)
|
||||
transport = self.client.get_transport()
|
||||
def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
|
||||
raise_on_error=True, timeout=3600):
|
||||
|
||||
transport = client.get_transport()
|
||||
session = transport.open_session()
|
||||
session.exec_command(cmd)
|
||||
start_time = time.time()
|
||||
|
||||
data_to_send = ''
|
||||
stderr_data = None
|
||||
|
||||
# If we have data to be sent to stdin then `select' should also
|
||||
# check for stdin availability.
|
||||
if stdin and not stdin.closed:
|
||||
writes = [session]
|
||||
else:
|
||||
writes = []
|
||||
|
||||
while True:
|
||||
errors = select.select([session], [], [], 4)[2]
|
||||
# Block until data can be read/write.
|
||||
r, w, e = select.select([session], writes, [session], 1)
|
||||
|
||||
if session.recv_ready():
|
||||
data = session.recv(4096)
|
||||
LOG.debug(data)
|
||||
if get_stdout:
|
||||
yield (1, data)
|
||||
LOG.debug(_('stdout: %r') % data)
|
||||
if stdout is not None:
|
||||
stdout.write(data)
|
||||
continue
|
||||
|
||||
if session.recv_stderr_ready():
|
||||
data = session.recv_stderr(4096)
|
||||
LOG.debug(data)
|
||||
if get_stderr:
|
||||
yield (2, data)
|
||||
stderr_data = session.recv_stderr(4096)
|
||||
LOG.debug(_('stderr: %r') % stderr_data)
|
||||
if stderr is not None:
|
||||
stderr.write(stderr_data)
|
||||
continue
|
||||
|
||||
if errors or session.exit_status_ready():
|
||||
if session.send_ready():
|
||||
if stdin is not None and not stdin.closed:
|
||||
if not data_to_send:
|
||||
data_to_send = stdin.read(4096)
|
||||
if not data_to_send:
|
||||
stdin.close()
|
||||
session.shutdown_write()
|
||||
writes = []
|
||||
continue
|
||||
sent_bytes = session.send(data_to_send)
|
||||
data_to_send = data_to_send[sent_bytes:]
|
||||
|
||||
if session.exit_status_ready():
|
||||
break
|
||||
|
||||
if self._is_timed_out(start_time):
|
||||
raise exceptions.TimeoutException('SSH Timeout')
|
||||
if timeout and (time.time() - timeout) > start_time:
|
||||
args = {'cmd': cmd, 'host': self.host}
|
||||
raise SSHTimeout(_('Timeout executing command '
|
||||
'"%(cmd)s" on host %(host)s') % args)
|
||||
if e:
|
||||
raise SSHError('Socket error.')
|
||||
|
||||
exit_status = session.recv_exit_status()
|
||||
if 0 != exit_status:
|
||||
raise exceptions.SSHError(
|
||||
'SSHExecCommandFailed with exit_status %s'
|
||||
% exit_status)
|
||||
self.client.close()
|
||||
if 0 != exit_status and raise_on_error:
|
||||
fmt = _('Command "%(cmd)s" failed with exit_status %(status)d.')
|
||||
details = fmt % {'cmd': cmd, 'status': exit_status}
|
||||
if stderr_data:
|
||||
details += _(' Last stderr data: "%s".') % stderr_data
|
||||
raise SSHError(details)
|
||||
return exit_status
|
||||
|
||||
def upload(self, source, destination):
|
||||
"""Upload the specified file to the server."""
|
||||
if destination.startswith('~'):
|
||||
destination = '/home/' + self.user + destination[1:]
|
||||
self._get_ssh_connection()
|
||||
ftp = self.client.open_sftp()
|
||||
ftp.put(os.path.expanduser(source), destination)
|
||||
ftp.close()
|
||||
def execute(self, cmd, stdin=None, timeout=3600):
|
||||
"""Execute the specified command on the server.
|
||||
|
||||
def execute_script(self, script, interpreter='/bin/sh',
|
||||
get_stdout=False, get_stderr=False):
|
||||
"""Execute the specified local script on the remote server."""
|
||||
destination = '/tmp/' + ''.join(
|
||||
random.choice(string.lowercase) for i in range(16))
|
||||
:param cmd: Command to be executed.
|
||||
:param stdin: Open file to be sent on process stdin.
|
||||
:param timeout: Timeout for execution of the command.
|
||||
|
||||
self.upload(script, destination)
|
||||
streams = self.execute('%s %s' % (interpreter, destination),
|
||||
get_stdout=get_stdout, get_stderr=get_stderr)
|
||||
self.execute('rm %s' % destination)
|
||||
return streams
|
||||
Return tuple (exit_status, stdout, stderr)
|
||||
|
||||
"""
|
||||
stdout = StringIO.StringIO()
|
||||
stderr = StringIO.StringIO()
|
||||
|
||||
exit_status = self.run(cmd, stderr=stderr,
|
||||
stdout=stdout, stdin=stdin,
|
||||
timeout=timeout, raise_on_error=False)
|
||||
stdout.seek(0)
|
||||
stderr.seek(0)
|
||||
return (exit_status, stdout.read(), stderr.read())
|
||||
|
||||
def wait(self, timeout=120, interval=1):
|
||||
"""Wait for the host will be available via ssh."""
|
||||
@ -174,10 +249,8 @@ class SSH(object):
|
||||
while True:
|
||||
try:
|
||||
return self.execute('uname')
|
||||
except (socket.error, exceptions.SSHError) as e:
|
||||
LOG.debug(
|
||||
_('Ssh is still unavailable. (Exception was: %s)') % e)
|
||||
except (socket.error, SSHError) as e:
|
||||
LOG.debug(_('Ssh is still unavailable: %r') % e)
|
||||
time.sleep(interval)
|
||||
if self._is_timed_out(start_time, timeout):
|
||||
raise exceptions.TimeoutException(
|
||||
_('SSH Timeout waiting for "%s"') % self.ip)
|
||||
if time.time() > (start_time + timeout):
|
||||
raise SSHTimeout(_('Timeout waiting for "%s"') % self.host)
|
||||
|
@ -29,10 +29,15 @@ class NovaServersTestCase(test.TestCase):
|
||||
|
||||
@mock.patch("json.loads")
|
||||
@mock.patch("rally.benchmark.base.Scenario.clients")
|
||||
@mock.patch("rally.sshutils.SSH.execute_script")
|
||||
def _verify_boot_runcommand_delete_server(self, mock_ssh_execute_script,
|
||||
@mock.patch("rally.sshutils.SSH.execute")
|
||||
@mock.patch("rally.sshutils.SSH.wait")
|
||||
@mock.patch("rally.sshutils.SSH._get_pkey")
|
||||
@mock.patch("rally.benchmark.scenarios.nova.servers.open", create=True)
|
||||
def _verify_boot_runcommand_delete_server(self, mock_open, mock__get_pkey,
|
||||
mock_wait, mock_execute,
|
||||
mock_base_clients,
|
||||
mock_json_loads):
|
||||
mock_open.return_value = "fake_script"
|
||||
fake_server = fakes.FakeServer()
|
||||
fake_server.addresses = dict(
|
||||
private=[dict(
|
||||
@ -40,12 +45,13 @@ class NovaServersTestCase(test.TestCase):
|
||||
addr="1.2.3.4"
|
||||
)]
|
||||
)
|
||||
|
||||
scenario = servers.NovaServers()
|
||||
|
||||
scenario._boot_server = mock.MagicMock(return_value=fake_server)
|
||||
scenario._generate_random_name = mock.MagicMock(return_value="name")
|
||||
scenario._delete_server = mock.MagicMock()
|
||||
mock_ssh_execute_script.return_value = ('stdout', 'stderr')
|
||||
mock_execute.return_value = (0, 'stdout', 'stderr')
|
||||
mock_base_clients.return_value = dict(private='private-key-string')
|
||||
|
||||
scenario.boot_runcommand_delete_server("img", 0, "script_path",
|
||||
@ -54,10 +60,8 @@ class NovaServersTestCase(test.TestCase):
|
||||
scenario._boot_server.assert_called_once_with("name", "img", 0,
|
||||
fakearg="f",
|
||||
key_name='rally_ssh_key')
|
||||
mock_ssh_execute_script.assert_called_once_with(
|
||||
script="script_path", interpreter="/bin/bash", get_stdout=True,
|
||||
get_stderr=True)
|
||||
|
||||
mock_execute.assert_called_once_with("/bin/bash", stdin="fake_script")
|
||||
mock_open.assert_called_once_with("script_path", "rb")
|
||||
mock_json_loads.assert_called_once_with('stdout')
|
||||
scenario._delete_server.assert_called_once_with(fake_server)
|
||||
|
||||
|
@ -54,12 +54,16 @@ class DevstackEngineTestCase(test.BaseTestCase):
|
||||
def test_construct(self):
|
||||
self.assertEqual(self.engine.localrc['ADMIN_PASSWORD'], 'secret')
|
||||
|
||||
def test_prepare_server(self):
|
||||
@mock.patch('rally.deploy.engines.devstack.open', create=True)
|
||||
def test_prepare_server(self, m_open):
|
||||
m_open.return_value = 'fake_file'
|
||||
server = mock.Mock()
|
||||
self.engine.prepare_server(server)
|
||||
filename = server.ssh.execute_script.mock_calls[0][1][0]
|
||||
server.ssh.run.assert_called_once_with('/bin/sh -e', stdin='fake_file')
|
||||
filename = m_open.mock_calls[0][1][0]
|
||||
self.assertTrue(filename.endswith('rally/deploy/engines/'
|
||||
'devstack/install.sh'))
|
||||
self.assertEqual([mock.call(filename, 'rb')], m_open.mock_calls)
|
||||
|
||||
@mock.patch('rally.deploy.engines.devstack.open', create=True)
|
||||
@mock.patch('rally.serverprovider.provider.Server')
|
||||
@ -87,34 +91,26 @@ class DevstackEngineTestCase(test.BaseTestCase):
|
||||
'tenant_name': 'admin',
|
||||
})
|
||||
|
||||
@mock.patch('rally.deploy.engines.devstack.os')
|
||||
@mock.patch('rally.deploy.engines.devstack.tempfile')
|
||||
@mock.patch('rally.deploy.engines.devstack.open', create=True)
|
||||
def test_configure_devstack(self, m_open, m_tmpf, m_os):
|
||||
m_tmpf.mkstemp.return_value = (42, 'tmpnam')
|
||||
fake_file = mock.Mock()
|
||||
m_open.return_value = fake_file
|
||||
@mock.patch('rally.deploy.engines.devstack.StringIO.StringIO')
|
||||
def test_configure_devstack(self, m_sio):
|
||||
m_sio.return_value = fake_localrc = mock.Mock()
|
||||
server = mock.Mock()
|
||||
self.engine.localrc = {'k1': 'v1', 'k2': 'v2'}
|
||||
|
||||
self.engine.configure_devstack(server)
|
||||
|
||||
calls = [
|
||||
mock.call.ssh.execute('git', 'clone', DEVSTACK_REPO),
|
||||
mock.call.ssh.upload('tmpnam', '~/devstack/localrc'),
|
||||
mock.call.ssh.run('git clone https://github.com/'
|
||||
'openstack-dev/devstack.git'),
|
||||
mock.call.ssh.run('cat > ~/devstack/localrc', stdin=fake_localrc)
|
||||
]
|
||||
self.assertEqual(calls, server.mock_calls)
|
||||
fake_file.asser_has_calls([
|
||||
fake_localrc.asser_has_calls([
|
||||
mock.call.write('k1=v1\n'),
|
||||
mock.call.write('k2=v2\n'),
|
||||
])
|
||||
os_calls = [
|
||||
mock.call.close(42),
|
||||
mock.call.unlink('tmpnam'),
|
||||
]
|
||||
self.assertEqual(os_calls, m_os.mock_calls)
|
||||
|
||||
def test_start_devstack(self):
|
||||
server = mock.Mock()
|
||||
self.assertTrue(self.engine.start_devstack(server))
|
||||
server.ssh.execute.assert_called_once_with('~/devstack/stack.sh')
|
||||
server.ssh.run.assert_called_once_with('~/devstack/stack.sh')
|
||||
|
@ -27,29 +27,25 @@ MOD_NAME = 'rally.serverprovider.providers.lxc.'
|
||||
|
||||
class HelperFunctionsTestCase(test.BaseTestCase):
|
||||
|
||||
def test__get_script_path(self):
|
||||
full_path = lxc._get_script_path('script.sh')
|
||||
self.assertTrue(full_path.endswith('rally/serverprovider/'
|
||||
'providers/lxc/script.sh'))
|
||||
@mock.patch(MOD_NAME + 'open', create=True, return_value='fake_script')
|
||||
def test__get_script(self, m_open):
|
||||
script = lxc._get_script('script.sh')
|
||||
self.assertEqual('fake_script', script)
|
||||
path = m_open.mock_calls[0][1][0]
|
||||
mode = m_open.mock_calls[0][1][1]
|
||||
self.assertTrue(path.endswith('rally/serverprovider/providers'
|
||||
'/lxc/script.sh'))
|
||||
self.assertEqual('rb', mode)
|
||||
|
||||
@mock.patch(MOD_NAME + '_get_script_path', return_value='fake_path')
|
||||
@mock.patch(MOD_NAME + 'tempfile')
|
||||
@mock.patch(MOD_NAME + 'open', create=True)
|
||||
def test__write_script_from_template(self, m_open, m_tempfile, m_gsp):
|
||||
fake_tempfile = mock.Mock()
|
||||
m_tempfile.NamedTemporaryFile.return_value = fake_tempfile
|
||||
fake_file = mock.Mock()
|
||||
fake_data = mock.Mock()
|
||||
fake_data.format.return_value = 'fake_formatted_data'
|
||||
fake_file.read.return_value = fake_data
|
||||
m_open.return_value = fake_file
|
||||
retval = lxc._write_script_from_template('script', key='value')
|
||||
m_gsp.assert_called_once_with('script')
|
||||
m_open.assert_called_once_with('fake_path')
|
||||
m_tempfile.NamedTemporaryFile.assert_called_once_with(delete=False)
|
||||
fake_data.format.assert_called_once_with(key='value')
|
||||
fake_tempfile.write.assert_called_once_with('fake_formatted_data')
|
||||
self.assertEqual(fake_tempfile.name, retval)
|
||||
@mock.patch(MOD_NAME + '_get_script', return_value='fake_script')
|
||||
@mock.patch(MOD_NAME + 'StringIO.StringIO')
|
||||
def test__get_script_from_template(self, m_sio, m_gs):
|
||||
m_gs.return_value = fake_script = mock.Mock()
|
||||
fake_script.read.return_value = 'fake_data {k1} {k2}'
|
||||
m_sio.return_value = 'fake_formatted_script'
|
||||
script = lxc._get_script_from_template('fake_tpl', k1='v1', k2='v2')
|
||||
self.assertEqual('fake_formatted_script', script)
|
||||
m_sio.assert_called_once_with('fake_data v1 v2')
|
||||
|
||||
|
||||
class LxcHostTestCase(test.BaseTestCase):
|
||||
@ -81,14 +77,12 @@ class LxcHostTestCase(test.BaseTestCase):
|
||||
self.server.ssh.execute.side_effect = exceptions.SSHError()
|
||||
self.assertEqual('dir', self.host.backingstore)
|
||||
|
||||
@mock.patch(MOD_NAME + '_get_script_path', return_value='fake_sp')
|
||||
@mock.patch(MOD_NAME + 'os.unlink')
|
||||
@mock.patch(MOD_NAME + 'tempfile')
|
||||
def test_prepare(self, m_tempfile, m_unlink, m_gsp):
|
||||
@mock.patch(MOD_NAME + 'StringIO.StringIO')
|
||||
@mock.patch(MOD_NAME + '_get_script', return_value='fake_script')
|
||||
def test_prepare(self, m_gs, m_sio):
|
||||
m_sio.return_value = fake_conf = mock.Mock()
|
||||
self.host.create_local_tunnels = mock.Mock()
|
||||
self.host.create_remote_tunnels = mock.Mock()
|
||||
fake_tempfile = mock.Mock()
|
||||
m_tempfile.NamedTemporaryFile.return_value = fake_tempfile
|
||||
|
||||
self.host.prepare()
|
||||
|
||||
@ -102,49 +96,38 @@ class LxcHostTestCase(test.BaseTestCase):
|
||||
mock.call('USE_LXC_BRIDGE="true"\n')
|
||||
]
|
||||
for call in write_calls:
|
||||
fake_tempfile.write.assert_has_calls(call)
|
||||
self.server.ssh.upload.assert_called_once_with(fake_tempfile.name,
|
||||
'/tmp/.lxc_default')
|
||||
self.server.ssh.execute_script.assert_called_once_with('fake_sp')
|
||||
m_unlink.assert_called_once_with(fake_tempfile.name)
|
||||
fake_conf.write.assert_has_calls(call)
|
||||
ssh_calls = [mock.call.run('cat > /tmp/.lxc_default', stdin=fake_conf),
|
||||
mock.call.run('/bin/sh', stdin='fake_script')]
|
||||
self.assertEqual(ssh_calls, self.server.ssh.mock_calls)
|
||||
self.host.create_local_tunnels.assert_called_once()
|
||||
self.host.create_remote_tunnels.assert_called_once()
|
||||
|
||||
@mock.patch(MOD_NAME + 'os.unlink')
|
||||
@mock.patch(MOD_NAME + '_write_script_from_template')
|
||||
def test_create_local_tunnels(self, m_ws, m_unlink):
|
||||
m_ws.side_effect = ['1', '2']
|
||||
@mock.patch(MOD_NAME + '_get_script_from_template')
|
||||
def test_create_local_tunnels(self, m_gs, m_unlink):
|
||||
m_gs.side_effect = ['s1', 's2']
|
||||
self.host.create_local_tunnels()
|
||||
ws_calls = [
|
||||
gs_calls = [
|
||||
mock.call('tunnel-local.sh', local='fake_server_ip',
|
||||
net=netaddr.IPNetwork('10.1.1.0/24'), remote='1.1.1.1'),
|
||||
mock.call('tunnel-local.sh', local='fake_server_ip',
|
||||
net=netaddr.IPNetwork('10.1.1.0/24'), remote='2.2.2.2'),
|
||||
]
|
||||
self.assertEqual(ws_calls, m_ws.mock_calls)
|
||||
self.assertEqual([mock.call('1'), mock.call('2')],
|
||||
self.server.ssh.execute_script.mock_calls)
|
||||
self.assertEqual(gs_calls, m_gs.mock_calls)
|
||||
self.assertEqual([mock.call('/bin/sh -e', stdin='s1'),
|
||||
mock.call('/bin/sh -e', stdin='s2')],
|
||||
self.server.ssh.run.mock_calls)
|
||||
|
||||
@mock.patch(MOD_NAME + 'os.unlink')
|
||||
@mock.patch(MOD_NAME + '_write_script_from_template')
|
||||
def test_create_remote_tunnels(self, m_ws, m_unlink):
|
||||
m_ws.side_effect = ['1', '2']
|
||||
@mock.patch(MOD_NAME + '_get_script_from_template')
|
||||
def test_create_remote_tunnels(self, m_get_script):
|
||||
m_get_script.side_effect = ['s1', 's2']
|
||||
fake_server = mock.Mock()
|
||||
self.host._get_server_with_ip = mock.Mock(return_value=fake_server)
|
||||
|
||||
self.host.create_remote_tunnels()
|
||||
|
||||
ws_calls = [
|
||||
mock.call('tunnel-remote.sh', local='1.1.1.1',
|
||||
net=netaddr.IPNetwork('10.1.1.0/24'),
|
||||
remote='fake_server_ip'),
|
||||
mock.call('tunnel-remote.sh', local='2.2.2.2',
|
||||
net=netaddr.IPNetwork('10.1.1.0/24'),
|
||||
remote='fake_server_ip'),
|
||||
]
|
||||
self.assertEqual(ws_calls, m_ws.mock_calls)
|
||||
self.assertEqual([mock.call('1'), mock.call('2')],
|
||||
fake_server.ssh.execute_script.mock_calls)
|
||||
self.assertEqual([mock.call('/bin/sh -e', stdin='s1'),
|
||||
mock.call('/bin/sh -e', stdin='s2')],
|
||||
fake_server.ssh.run.mock_calls)
|
||||
|
||||
def test_delete_tunnels(self):
|
||||
s1 = mock.Mock()
|
||||
@ -162,64 +145,58 @@ class LxcHostTestCase(test.BaseTestCase):
|
||||
def test_get_ip(self, m_sleep):
|
||||
s1 = 'link/ether fe:54:00:d3:f5:98 brd ff:ff:ff:ff:ff:ff'
|
||||
s2 = s1 + '\n inet 10.20.0.1/24 scope global br1'
|
||||
self.host.server.ssh.execute.side_effect = [(s1, ''), (s2, '')]
|
||||
self.host.server.ssh.execute.side_effect = [(0, s1, ''), (0, s2, '')]
|
||||
ip = self.host.get_ip('name')
|
||||
self.assertEqual('10.20.0.1', ip)
|
||||
self.assertEqual([mock.call('lxc-attach -n name ip addr list dev eth0',
|
||||
get_stdout=True)] * 2,
|
||||
self.assertEqual([mock.call('lxc-attach -n name ip'
|
||||
' addr list dev eth0')] * 2,
|
||||
self.host.server.ssh.execute.mock_calls)
|
||||
|
||||
def test_create_container(self):
|
||||
self.host.configure_container = mock.Mock()
|
||||
self.host._backingstore = 'btrfs'
|
||||
self.host.create_container('name', 'dist')
|
||||
self.server.ssh.execute.assert_called_once_with(
|
||||
'lxc-create', '-B', 'btrfs', '-n', 'name', '-t', 'dist')
|
||||
self.server.ssh.run.assert_called_once_with(
|
||||
'lxc-create -B btrfs -n name -t dist')
|
||||
self.assertEqual(['name'], self.host.containers)
|
||||
self.host.configure_container.assert_called_once_with('name')
|
||||
|
||||
#check with no btrfs
|
||||
self.host._backingstore = 'dir'
|
||||
self.host.create_container('name', 'dist')
|
||||
self.assertEqual(mock.call('lxc-create', '-B', 'dir', '-n',
|
||||
'name', '-t', 'dist'),
|
||||
self.server.ssh.execute.mock_calls[1])
|
||||
self.assertEqual(mock.call('lxc-create -B dir -n name -t dist'),
|
||||
self.server.ssh.run.mock_calls[1])
|
||||
|
||||
def test_create_clone(self):
|
||||
self.host._backingstore = 'btrfs'
|
||||
self.host.configure_container = mock.Mock()
|
||||
self.host.create_clone('name', 'src')
|
||||
self.server.ssh.execute.assert_called_once_with('lxc-clone',
|
||||
'--snapshot',
|
||||
'-o', 'src',
|
||||
'-n', 'name')
|
||||
self.server.ssh.execute.assert_called_once_with('lxc-clone --snapshot'
|
||||
' -o src -n name')
|
||||
self.assertEqual(['name'], self.host.containers)
|
||||
|
||||
#check with no btrfs
|
||||
self.host._backingstore = 'dir'
|
||||
self.host.create_clone('name', 'src')
|
||||
self.assertEqual(mock.call('lxc-clone', '-o', 'src', '-n', 'name'),
|
||||
self.assertEqual(mock.call('lxc-clone -o src -n name'),
|
||||
self.server.ssh.execute.mock_calls[1])
|
||||
|
||||
@mock.patch(MOD_NAME + 'os.path.join')
|
||||
@mock.patch(MOD_NAME + '_get_script_path')
|
||||
def test_configure_container(self, m_gsp, m_join):
|
||||
m_gsp.return_value = 'fake_script'
|
||||
@mock.patch(MOD_NAME + '_get_script')
|
||||
def test_configure_container(self, m_gs, m_join):
|
||||
m_gs.return_value = 'fake_script'
|
||||
m_join.return_value = 'fake_path'
|
||||
self.server.ssh.execute.return_value = 0, '', ''
|
||||
self.host.configure_container('name')
|
||||
calls = [
|
||||
mock.call.upload('fake_script', '/tmp/.rally_cont_conf.sh'),
|
||||
mock.call.execute('/bin/sh', '/tmp/.rally_cont_conf.sh',
|
||||
'fake_path'),
|
||||
]
|
||||
self.assertEqual(calls, self.server.ssh.mock_calls)
|
||||
self.server.ssh.run.assert_called_once_with(
|
||||
'/bin/sh -e -s fake_path', stdin='fake_script')
|
||||
|
||||
def test_start_containers(self):
|
||||
self.host.containers = ['c1', 'c2']
|
||||
self.host.start_containers()
|
||||
calls = [mock.call('lxc-start -d -n c1'),
|
||||
mock.call('lxc-start -d -n c2')]
|
||||
self.assertEqual(calls, self.server.ssh.execute.mock_calls)
|
||||
self.assertEqual(calls, self.server.ssh.run.mock_calls)
|
||||
|
||||
def test_stop_containers(self):
|
||||
self.host.containers = ['c1', 'c2']
|
||||
@ -228,7 +205,7 @@ class LxcHostTestCase(test.BaseTestCase):
|
||||
mock.call('lxc-stop -n c1'),
|
||||
mock.call('lxc-stop -n c2'),
|
||||
]
|
||||
self.assertEqual(calls, self.server.ssh.execute.mock_calls)
|
||||
self.assertEqual(calls, self.server.ssh.run.mock_calls)
|
||||
|
||||
def test_destroy_containers(self):
|
||||
self.host.containers = ['c1', 'c2']
|
||||
@ -237,7 +214,7 @@ class LxcHostTestCase(test.BaseTestCase):
|
||||
mock.call('lxc-stop -n c1'), mock.call('lxc-destroy -n c1'),
|
||||
mock.call('lxc-stop -n c2'), mock.call('lxc-destroy -n c2'),
|
||||
]
|
||||
self.assertEqual(calls, self.server.ssh.execute.mock_calls)
|
||||
self.assertEqual(calls, self.server.ssh.run.mock_calls)
|
||||
|
||||
@mock.patch(MOD_NAME + 'provider.Server.from_credentials')
|
||||
def test_get_server_object(self, m_fc):
|
||||
|
@ -14,101 +14,251 @@
|
||||
# under the License.
|
||||
|
||||
import mock
|
||||
import os
|
||||
|
||||
from rally import exceptions
|
||||
from rally import sshutils
|
||||
from tests import test
|
||||
|
||||
|
||||
class FakeParamikoException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SSHTestCase(test.TestCase):
|
||||
"""Test all small SSH methods."""
|
||||
|
||||
def setUp(self):
|
||||
super(SSHTestCase, self).setUp()
|
||||
self.ssh = sshutils.SSH('example.net', 'root')
|
||||
self.channel = mock.Mock()
|
||||
self.channel.recv.return_value = 'ok'
|
||||
self.channel.recv_stderr.return_value = 'error'
|
||||
self.channel.recv_exit_status.return_value = 0
|
||||
self.transport = mock.Mock()
|
||||
self.transport.open_session = mock.MagicMock(return_value=self.channel)
|
||||
self.policy = mock.Mock()
|
||||
self.client = mock.Mock()
|
||||
self.client.get_transport = mock.MagicMock(return_value=self.transport)
|
||||
self.ssh = sshutils.SSH('root', 'example.net')
|
||||
|
||||
self.channel.exit_status_ready.return_value = True
|
||||
self.channel.recv_ready.side_effect = [True, False, False]
|
||||
self.channel.recv_stderr_ready.side_effect = [True, False, False]
|
||||
@mock.patch('rally.sshutils.SSH._get_pkey')
|
||||
def test_construct(self, m_pkey):
|
||||
m_pkey.return_value = 'pkey'
|
||||
ssh = sshutils.SSH('root', 'example.net', port=33, pkey='key',
|
||||
key_filename='kf', password='secret')
|
||||
m_pkey.assert_called_once_with('key')
|
||||
self.assertEqual('root', ssh.user)
|
||||
self.assertEqual('example.net', ssh.host)
|
||||
self.assertEqual(33, ssh.port)
|
||||
self.assertEqual('pkey', ssh.pkey)
|
||||
self.assertEqual('kf', ssh.key_filename)
|
||||
self.assertEqual('secret', ssh.password)
|
||||
|
||||
def test_construct_default(self):
|
||||
self.assertEqual('root', self.ssh.user)
|
||||
self.assertEqual('example.net', self.ssh.host)
|
||||
self.assertEqual(22, self.ssh.port)
|
||||
self.assertIsNone(self.ssh.pkey)
|
||||
self.assertIsNone(self.ssh.key_filename)
|
||||
self.assertIsNone(self.ssh.password)
|
||||
|
||||
@mock.patch('rally.sshutils.paramiko')
|
||||
@mock.patch('rally.sshutils.select')
|
||||
def test_generator(self, st, pk):
|
||||
pk.SSHClient.return_value = self.client
|
||||
st.select.return_value = ([], [], [])
|
||||
|
||||
chunks = list(self.ssh.execute_generator('ps ax'))
|
||||
self.assertEqual([(1, 'ok'), (2, 'error')], chunks)
|
||||
def test__get_pkey_invalid(self, m_paramiko):
|
||||
m_paramiko.SSHException = FakeParamikoException
|
||||
rsa = m_paramiko.rsakey.RSAKey
|
||||
dss = m_paramiko.dsskey.DSSKey
|
||||
rsa.from_private_key.side_effect = m_paramiko.SSHException
|
||||
dss.from_private_key.side_effect = m_paramiko.SSHException
|
||||
self.assertRaises(sshutils.SSHError, self.ssh._get_pkey, 'key')
|
||||
|
||||
@mock.patch('rally.sshutils.StringIO')
|
||||
@mock.patch('rally.sshutils.paramiko')
|
||||
@mock.patch('rally.sshutils.select')
|
||||
def test_execute(self, st, pk):
|
||||
pk.SSHClient.return_value = self.client
|
||||
st.select.return_value = ([], [], [])
|
||||
stdout, stderr = self.ssh.execute('uname')
|
||||
|
||||
self.assertEqual('', stdout)
|
||||
self.assertEqual('', stderr)
|
||||
expected = [mock.call.exec_command('uname'),
|
||||
mock.call.recv_ready(),
|
||||
mock.call.recv(4096),
|
||||
mock.call.recv_ready(),
|
||||
mock.call.recv_stderr_ready(),
|
||||
mock.call.recv_stderr(4096),
|
||||
mock.call.recv_ready(),
|
||||
mock.call.recv_stderr_ready(),
|
||||
mock.call.exit_status_ready(),
|
||||
mock.call.recv_exit_status()]
|
||||
|
||||
self.assertEqual(expected, self.channel.mock_calls)
|
||||
def test__get_pkey_dss(self, m_paramiko, m_stringio):
|
||||
m_paramiko.SSHException = FakeParamikoException
|
||||
m_stringio.StringIO.return_value = 'string_key'
|
||||
m_paramiko.dsskey.DSSKey.from_private_key.return_value = 'dss_key'
|
||||
rsa = m_paramiko.rsakey.RSAKey
|
||||
rsa.from_private_key.side_effect = m_paramiko.SSHException
|
||||
key = self.ssh._get_pkey('key')
|
||||
dss_calls = m_paramiko.dsskey.DSSKey.from_private_key.mock_calls
|
||||
self.assertEqual([mock.call('string_key')], dss_calls)
|
||||
self.assertEqual(key, 'dss_key')
|
||||
m_stringio.StringIO.assert_called_once_with('key')
|
||||
|
||||
@mock.patch('rally.sshutils.StringIO')
|
||||
@mock.patch('rally.sshutils.paramiko')
|
||||
def test_upload_file(self, pk):
|
||||
pk.AutoAddPolicy.return_value = self.policy
|
||||
self.ssh.upload('/tmp/s', '/tmp/d')
|
||||
def test__get_pkey_rsa(self, m_paramiko, m_stringio):
|
||||
m_paramiko.SSHException = FakeParamikoException
|
||||
m_stringio.StringIO.return_value = 'string_key'
|
||||
m_paramiko.rsakey.RSAKey.from_private_key.return_value = 'rsa_key'
|
||||
dss = m_paramiko.dsskey.DSSKey
|
||||
dss.from_private_key.side_effect = m_paramiko.SSHException
|
||||
key = self.ssh._get_pkey('key')
|
||||
rsa_calls = m_paramiko.rsakey.RSAKey.from_private_key.mock_calls
|
||||
self.assertEqual([mock.call('string_key')], rsa_calls)
|
||||
self.assertEqual(key, 'rsa_key')
|
||||
m_stringio.StringIO.assert_called_once_with('key')
|
||||
|
||||
expected = [mock.call.set_missing_host_key_policy(self.policy),
|
||||
mock.call.connect(hostname='example.net', username='root',
|
||||
key_filename=os.path.expanduser(
|
||||
'~/.ssh/id_rsa'), port=22),
|
||||
mock.call.open_sftp(),
|
||||
mock.call.open_sftp().put('/tmp/s', '/tmp/d'),
|
||||
mock.call.open_sftp().close()]
|
||||
@mock.patch('rally.sshutils.SSH._get_pkey')
|
||||
@mock.patch('rally.sshutils.paramiko')
|
||||
def test__get_client(self, m_paramiko, m_pkey):
|
||||
m_pkey.return_value = 'key'
|
||||
fake_client = mock.Mock()
|
||||
m_paramiko.SSHClient.return_value = fake_client
|
||||
m_paramiko.AutoAddPolicy.return_value = 'autoadd'
|
||||
|
||||
self.assertEqual(pk.SSHClient().mock_calls, expected)
|
||||
ssh = sshutils.SSH('admin', 'example.net', pkey='key')
|
||||
client = ssh._get_client()
|
||||
|
||||
@mock.patch('rally.sshutils.SSH.execute')
|
||||
@mock.patch('rally.sshutils.SSH.upload')
|
||||
@mock.patch('rally.sshutils.random.choice')
|
||||
def test_execute_script_new(self, rc, up, ex):
|
||||
rc.return_value = 'a'
|
||||
self.ssh.execute_script('/bin/script')
|
||||
self.assertEqual(fake_client, client)
|
||||
client_calls = [
|
||||
mock.call.set_missing_host_key_policy('autoadd'),
|
||||
mock.call.connect('example.net', username='admin',
|
||||
port=22, pkey='key', key_filename=None,
|
||||
password=None),
|
||||
]
|
||||
self.assertEqual(client_calls, client.mock_calls)
|
||||
|
||||
up.assert_called_once_with('/bin/script', '/tmp/aaaaaaaaaaaaaaaa')
|
||||
ex.assert_has_calls([
|
||||
mock.call('/bin/sh /tmp/aaaaaaaaaaaaaaaa',
|
||||
get_stderr=False, get_stdout=False),
|
||||
mock.call('rm /tmp/aaaaaaaaaaaaaaaa')
|
||||
])
|
||||
def test_close(self):
|
||||
with mock.patch.object(self.ssh, '_client') as m_client:
|
||||
self.ssh.close()
|
||||
m_client.close.assert_called_once()
|
||||
self.assertFalse(self.ssh._client)
|
||||
|
||||
@mock.patch('rally.sshutils.SSH.execute')
|
||||
def test_wait(self, ex):
|
||||
self.ssh.wait()
|
||||
@mock.patch('rally.sshutils.StringIO')
|
||||
def test_execute(self, m_stringio):
|
||||
m_stringio.StringIO.side_effect = stdio = [mock.Mock(), mock.Mock()]
|
||||
stdio[0].read.return_value = 'stdout fake data'
|
||||
stdio[1].read.return_value = 'stderr fake data'
|
||||
with mock.patch.object(self.ssh, 'run', return_value=0) as m_run:
|
||||
status, stdout, stderr = self.ssh.execute('cmd',
|
||||
stdin='fake_stdin',
|
||||
timeout=43)
|
||||
m_run.assert_called_once_with('cmd', stdin='fake_stdin',
|
||||
stdout=stdio[0],
|
||||
stderr=stdio[1], timeout=43,
|
||||
raise_on_error=False)
|
||||
self.assertEqual(0, status)
|
||||
self.assertEqual('stdout fake data', stdout)
|
||||
self.assertEqual('stderr fake data', stderr)
|
||||
|
||||
@mock.patch('rally.sshutils.time')
|
||||
@mock.patch('rally.sshutils.SSH.execute')
|
||||
def test_wait_timeout(self, ex, mock_time):
|
||||
mock_time.time.side_effect = [1, 10]
|
||||
ex.side_effect = exceptions.SSHError
|
||||
self.assertRaises(exceptions.TimeoutException,
|
||||
self.ssh.wait, 1, 1)
|
||||
mock_time.sleep.assert_called_once_with(1)
|
||||
def test_wait_timeout(self, m_time):
|
||||
m_time.time.side_effect = [1, 50, 150]
|
||||
self.ssh.execute = mock.Mock(side_effect=[sshutils.SSHError,
|
||||
sshutils.SSHError,
|
||||
0])
|
||||
self.assertRaises(sshutils.SSHTimeout, self.ssh.wait)
|
||||
self.assertEqual([mock.call('uname')] * 2, self.ssh.execute.mock_calls)
|
||||
|
||||
@mock.patch('rally.sshutils.time')
|
||||
def test_wait(self, m_time):
|
||||
m_time.time.side_effect = [1, 50, 100]
|
||||
self.ssh.execute = mock.Mock(side_effect=[sshutils.SSHError,
|
||||
sshutils.SSHError,
|
||||
0])
|
||||
self.ssh.wait()
|
||||
self.assertEqual([mock.call('uname')] * 3, self.ssh.execute.mock_calls)
|
||||
|
||||
|
||||
class SSHRunTestCase(test.TestCase):
|
||||
"""Test SSH.run method in different aspects.
|
||||
|
||||
Also tested method 'execute'.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
super(SSHRunTestCase, self).setUp()
|
||||
|
||||
self.fake_client = mock.Mock()
|
||||
self.fake_session = mock.Mock()
|
||||
self.fake_transport = mock.Mock()
|
||||
|
||||
self.fake_transport.open_session.return_value = self.fake_session
|
||||
self.fake_client.get_transport.return_value = self.fake_transport
|
||||
|
||||
self.fake_session.recv_ready.return_value = False
|
||||
self.fake_session.recv_stderr_ready.return_value = False
|
||||
self.fake_session.send_ready.return_value = False
|
||||
self.fake_session.exit_status_ready.return_value = True
|
||||
self.fake_session.recv_exit_status.return_value = 0
|
||||
|
||||
self.ssh = sshutils.SSH('admin', 'example.net')
|
||||
self.ssh._get_client = mock.Mock(return_value=self.fake_client)
|
||||
|
||||
@mock.patch('rally.sshutils.select')
|
||||
def test_execute(self, m_select):
|
||||
m_select.select.return_value = ([], [], [])
|
||||
self.fake_session.recv_ready.side_effect = [1, 0, 0]
|
||||
self.fake_session.recv_stderr_ready.side_effect = [1, 0]
|
||||
self.fake_session.recv.return_value = 'ok'
|
||||
self.fake_session.recv_stderr.return_value = 'error'
|
||||
self.fake_session.exit_status_ready.return_value = 1
|
||||
self.fake_session.recv_exit_status.return_value = 127
|
||||
self.assertEqual((127, 'ok', 'error'), self.ssh.execute('cmd'))
|
||||
self.fake_session.exec_command.assert_called_once_with('cmd')
|
||||
|
||||
@mock.patch('rally.sshutils.select')
|
||||
def test_run(self, m_select):
|
||||
m_select.select.return_value = ([], [], [])
|
||||
self.assertEqual(0, self.ssh.run('cmd'))
|
||||
|
||||
@mock.patch('rally.sshutils.select')
|
||||
def test_run_nonzero_status(self, m_select):
|
||||
m_select.select.return_value = ([], [], [])
|
||||
self.fake_session.recv_exit_status.return_value = 1
|
||||
self.assertRaises(sshutils.SSHError, self.ssh.run, 'cmd')
|
||||
self.assertEqual(1, self.ssh.run('cmd', raise_on_error=False))
|
||||
|
||||
@mock.patch('rally.sshutils.select')
|
||||
def test_run_stdout(self, m_select):
|
||||
m_select.select.return_value = ([], [], [])
|
||||
self.fake_session.recv_ready.side_effect = [True, True, False]
|
||||
self.fake_session.recv.side_effect = ['ok1', 'ok2']
|
||||
stdout = mock.Mock()
|
||||
self.ssh.run('cmd', stdout=stdout)
|
||||
self.assertEqual([mock.call('ok1'), mock.call('ok2')],
|
||||
stdout.write.mock_calls)
|
||||
|
||||
@mock.patch('rally.sshutils.select')
|
||||
def test_run_stderr(self, m_select):
|
||||
m_select.select.return_value = ([], [], [])
|
||||
self.fake_session.recv_stderr_ready.side_effect = [True, False]
|
||||
self.fake_session.recv_stderr.return_value = 'error'
|
||||
stderr = mock.Mock()
|
||||
self.ssh.run('cmd', stderr=stderr)
|
||||
stderr.write.assert_called_once_with('error')
|
||||
|
||||
@mock.patch('rally.sshutils.select')
|
||||
def test_run_stdin(self, m_select):
|
||||
"""Test run method with stdin.
|
||||
|
||||
Third send call was called with 'e2' because only 3 bytes was sent
|
||||
by second call. So remainig 2 bytes of 'line2' was sent by third call.
|
||||
"""
|
||||
m_select.select.return_value = ([], [], [])
|
||||
self.fake_session.exit_status_ready.side_effect = [0, 0, 0, True]
|
||||
self.fake_session.send_ready.return_value = True
|
||||
self.fake_session.send.side_effect = [5, 3, 2]
|
||||
fake_stdin = mock.Mock()
|
||||
fake_stdin.read.side_effect = ['line1', 'line2', '']
|
||||
fake_stdin.closed = False
|
||||
|
||||
def close():
|
||||
fake_stdin.closed = True
|
||||
fake_stdin.close = mock.Mock(side_effect=close)
|
||||
self.ssh.run('cmd', stdin=fake_stdin)
|
||||
call = mock.call
|
||||
send_calls = [call('line1'), call('line2'), call('e2')]
|
||||
self.assertEqual(send_calls, self.fake_session.send.mock_calls)
|
||||
|
||||
@mock.patch('rally.sshutils.select')
|
||||
def test_run_select_error(self, m_select):
|
||||
self.fake_session.exit_status_ready.return_value = False
|
||||
m_select.select.return_value = ([], [], [True])
|
||||
self.assertRaises(sshutils.SSHError, self.ssh.run, 'cmd')
|
||||
|
||||
@mock.patch('rally.sshutils.time')
|
||||
@mock.patch('rally.sshutils.select')
|
||||
def test_run_timemout(self, m_select, m_time):
|
||||
m_time.time.side_effect = [1, 3700]
|
||||
m_select.select.return_value = ([], [], [])
|
||||
self.fake_session.exit_status_ready.return_value = False
|
||||
self.assertRaises(sshutils.SSHTimeout, self.ssh.run, 'cmd')
|
||||
|
||||
@mock.patch('rally.sshutils.select')
|
||||
def test__run_client_closed_on_error(self, m_select):
|
||||
m_select.select.return_value = ([], [], [])
|
||||
self.fake_session.recv_ready.return_value = True
|
||||
self.fake_session.recv.side_effect = IOError
|
||||
self.assertRaises(IOError, self.ssh._run, self.fake_client, 'cmd')
|
||||
self.fake_client.close.assert_called_once()
|
||||
|
Loading…
Reference in New Issue
Block a user