SSHClient rework for SSH Manager integration: part 1

1. Methods optimization
2. Added docstrings

Change-Id: Ia929e151cef1edefa5238a2019e57b365ce74da3
blueprint: sshmanager-integration
This commit is contained in:
Alexey Stepanov 2016-05-31 11:03:56 +03:00
parent 67e4f9d811
commit 255de881fa
4 changed files with 477 additions and 411 deletions

View File

@ -40,9 +40,9 @@ class DevopsCalledProcessError(DevopsError):
expected=self.expected expected=self.expected
)) ))
if self.stdout: if self.stdout:
message += "\n\tSTDOUT: {}".format(self.stdout) message += "\n\tSTDOUT:\n{}".format(self.stdout)
if self.stderr: if self.stderr:
message += "\n\tSTDERR: {}".format(self.stderr) message += "\n\tSTDERR:\n{}".format(self.stderr)
super(DevopsCalledProcessError, self).__init__(message) super(DevopsCalledProcessError, self).__init__(message)
@property @property

View File

@ -17,8 +17,7 @@ import posixpath
import stat import stat
import paramiko import paramiko
# noinspection PyUnresolvedReferences import six
from six.moves import cStringIO
from devops.error import DevopsCalledProcessError from devops.error import DevopsCalledProcessError
from devops.helpers.retry import retry from devops.helpers.retry import retry
@ -27,13 +26,14 @@ from devops import logger
class SSHClient(object): class SSHClient(object):
class get_sudo(object): class get_sudo(object):
"""Context manager for call commands with sudo"""
def __init__(self, ssh): def __init__(self, ssh):
self.ssh = ssh self.ssh = ssh
def __enter__(self): def __enter__(self):
self.ssh.sudo_mode = True self.ssh.sudo_mode = True
def __exit__(self, exc_type, value, traceback): def __exit__(self, exc_type, exc_val, exc_tb):
self.ssh.sudo_mode = False self.ssh.sudo_mode = False
def __init__(self, host, port=22, username=None, password=None, def __init__(self, host, port=22, username=None, password=None,
@ -58,25 +58,10 @@ class SSHClient(object):
def password(self): def password(self):
return self.__password return self.__password
@password.setter
def password(self, new_val):
self.__password = new_val
self.reconnect()
@property @property
def private_keys(self): def private_keys(self):
return self.__private_keys return self.__private_keys
@private_keys.setter
def private_keys(self, new_val):
self.__private_keys = new_val
self.reconnect()
@private_keys.deleter
def private_keys(self):
self.__private_keys = []
self.reconnect()
@property @property
def private_key(self): def private_key(self):
return self.__actual_pkey return self.__actual_pkey
@ -85,11 +70,15 @@ class SSHClient(object):
def public_key(self): def public_key(self):
if self.private_key is None: if self.private_key is None:
return None return None
key = paramiko.RSAKey(file_obj=cStringIO(self.private_key)) key = self.private_key
return '{0} {1}'.format(key.get_name(), key.get_base64()) return '{0} {1}'.format(key.get_name(), key.get_base64())
@property @property
def _sftp(self): def _sftp(self):
"""SFTP channel access for inheritance
:raises: paramiko.SSHException
"""
if self.__sftp is not None: if self.__sftp is not None:
return self.__sftp return self.__sftp
logger.warning('SFTP is not connected, try to reconnect') logger.warning('SFTP is not connected, try to reconnect')
@ -115,7 +104,7 @@ class SSHClient(object):
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, *err): def __exit__(self, exc_type, exc_val, exc_tb):
self.clear() self.clear()
@retry(count=3, delay=3) @retry(count=3, delay=3)
@ -152,67 +141,162 @@ class SSHClient(object):
self.connect() self.connect()
self._connect_sftp() self._connect_sftp()
def check_call(self, command, verbose=False, excpected=0): def check_call(
self,
command, verbose=False,
expected=None, raise_on_err=True):
"""Execute command and check for return code
:type command: str
:type verbose: bool
:type expected: list
:type raise_on_err: bool
:rtype: dict
:raises: DevopsCalledProcessError
"""
if expected is None:
expected = [0]
ret = self.execute(command, verbose) ret = self.execute(command, verbose)
if ret['exit_code'] != excpected: if ret['exit_code'] not in expected:
raise DevopsCalledProcessError( message = (
command, ret['exit_code'], "Command '{cmd}' returned exit code {code} while "
expected=excpected, "expected {expected}\n"
stdout=ret['stdout_str'], "\tSTDOUT:\n"
stderr=ret['stderr_str']) "{stdout}"
"\n\tSTDERR:\n"
"{stderr}".format(
cmd=command,
code=ret['exit_code'],
expected=expected,
stdout=ret['stdout_str'],
stderr=ret['stderr_str']
))
logger.error(message)
if raise_on_err:
raise DevopsCalledProcessError(
command, ret['exit_code'],
expected=expected,
stdout=ret['stdout_str'],
stderr=ret['stderr_str'])
return ret return ret
def check_stderr(self, command, verbose=False): def check_stderr(self, command, verbose=False, raise_on_err=True):
ret = self.check_call(command, verbose) """Execute command expecting return code 0 and empty STDERR
:type command: str
:type verbose: bool
:type raise_on_err: bool
:rtype: dict
:raises: DevopsCalledProcessError
"""
ret = self.check_call(command, verbose, raise_on_err=raise_on_err)
if ret['stderr']: if ret['stderr']:
raise DevopsCalledProcessError(command, ret['exit_code'], message = (
stdout=ret['stdout_str'], "Command '{cmd}' STDERR while not expected\n"
stderr=ret['stderr_str']) "\texit code: {code}\n"
"\tSTDOUT:\n"
"{stdout}"
"\n\tSTDERR:\n"
"{stderr}".format(
cmd=command,
code=ret['exit_code'],
stdout=ret['stdout_str'],
stderr=ret['stderr_str']
))
logger.error(message)
if raise_on_err:
raise DevopsCalledProcessError(command, ret['exit_code'],
stdout=ret['stdout_str'],
stderr=ret['stderr_str'])
return ret return ret
@classmethod @classmethod
def execute_together(cls, remotes, command): def execute_together(
cls, remotes, command, expected=None, raise_on_err=True):
"""Execute command on multiple remotes in async mode
:type remotes: list
:type command: str
:type expected: list
:type raise_on_err: bool
:raises: DevopsCalledProcessError
"""
if expected is None:
expected = [0]
futures = {} futures = {}
errors = {} errors = {}
for remote in remotes: for remote in set(remotes): # Use distinct remotes
chan, _, _, _ = remote.execute_async(command) chan, _, _, _ = remote.execute_async(command)
futures[remote] = chan futures[remote] = chan
for remote, chan in futures.items(): for remote, chan in futures.items():
ret = chan.recv_exit_status() ret = chan.recv_exit_status()
chan.close() chan.close()
if ret != 0: if ret not in expected:
errors[remote.host] = ret errors[remote.host] = ret
if errors: if errors and raise_on_err:
raise DevopsCalledProcessError(command, errors) raise DevopsCalledProcessError(command, errors)
def execute(self, command, verbose=False): def execute(self, command, verbose=False):
"""Execute command and wait for return code
:type command: str
:type verbose: bool
:rtype: dict
"""
chan, _, stderr, stdout = self.execute_async(command) chan, _, stderr, stdout = self.execute_async(command)
# noinspection PyDictCreation
result = { result = {
'stdout': [], 'exit_code': chan.recv_exit_status()
'stderr': [],
'exit_code': 0
} }
for line in stdout: result['stdout'] = stdout.readlines()
result['stdout'].append(line) result['stderr'] = stderr.readlines()
if verbose:
logger.info(line)
for line in stderr:
result['stderr'].append(line)
if verbose:
logger.info(line)
result['exit_code'] = chan.recv_exit_status()
chan.close() chan.close()
result['stdout_str'] = ''.join(result['stdout']).strip()
result['stderr_str'] = ''.join(result['stderr']).strip() result['stdout_str'] = self._get_str_from_list(result['stdout'])
result['stderr_str'] = self._get_str_from_list(result['stderr'])
if verbose:
logger.info(
'{cmd} execution results:\n'
'Exit code: {code}\n'
'STDOUT:\n'
'{stdout}\n'
'STDERR:\n'
'{stderr}'.format(
cmd=command,
code=result['exit_code'],
stdout=result['stdout_str'],
stderr=result['stderr_str']
))
return result return result
@staticmethod
def _get_str_from_list(src):
"""Join data in list to the string, with python 2&3 compatibility.
:type src: list
:rtype: str
"""
if six.PY2:
return b''.join(src).strip()
else:
return b''.join(src).strip().decode(encoding='utf-8')
def execute_async(self, command): def execute_async(self, command):
"""Execute command in async mode and return channel with IO objects
:type command: str
:rtype: tuple
"""
logger.debug("Executing command: '{}'".format(command.rstrip())) logger.debug("Executing command: '{}'".format(command.rstrip()))
chan = self._ssh.get_transport().open_session() chan = self._ssh.get_transport().open_session()
stdin = chan.makefile('wb') stdin = chan.makefile('wb')
stdout = chan.makefile('rb') stdout = chan.makefile('rb')
stderr = chan.makefile_stderr('rb') stderr = chan.makefile_stderr('rb')
cmd = "%s\n" % command cmd = "{}\n".format(command)
if self.sudo_mode: if self.sudo_mode:
cmd = 'sudo -S bash -c "%s"' % cmd.replace('"', '\\"') cmd = 'sudo -S bash -c "%s"' % cmd.replace('"', '\\"')
chan.exec_command(cmd) chan.exec_command(cmd)
@ -225,7 +309,7 @@ class SSHClient(object):
def execute_through_host( def execute_through_host(
self, self,
target_host, hostname,
cmd, cmd,
username=None, username=None,
password=None, password=None,
@ -237,21 +321,15 @@ class SSHClient(object):
key = self.private_key key = self.private_key
intermediate_channel = self._ssh.get_transport().open_channel( intermediate_channel = self._ssh.get_transport().open_channel(
'direct-tcpip', (target_host, target_port), (self.host, 0)) kind='direct-tcpip',
transport = paramiko.Transport(intermediate_channel) dest_addr=(hostname, target_port),
transport.start_client() src_addr=(self.host, 0))
logger.info("Passing authentication to: {}".format(target_host)) transport = paramiko.Transport(sock=intermediate_channel)
if password is None and key is None:
logger.debug('auth_none')
transport.auth_none(username=username)
elif key is not None:
logger.debug('auth_publickey')
transport.auth_publickey(username=username, key=key)
else:
logger.debug('auth_password')
transport.auth_password(username=username, password=password)
logger.debug("Opening session") # start client and authenticate transport
transport.connect(username=username, password=password, pkey=key)
# open ssh session
channel = transport.open_session() channel = transport.open_session()
# Make proxy objects for read # Make proxy objects for read
@ -266,29 +344,48 @@ class SSHClient(object):
result = {} result = {}
result['exit_code'] = channel.recv_exit_status() result['exit_code'] = channel.recv_exit_status()
result['stdout'] = stdout.read() result['stdout'] = stdout.readlines()
result['stderr'] = stderr.read() result['stderr'] = stderr.readlines()
channel.close() channel.close()
result['stdout_str'] = ''.join(result['stdout']).strip() result['stdout_str'] = self._get_str_from_list(result['stdout'])
result['stderr_str'] = ''.join(result['stderr']).strip() result['stderr_str'] = self._get_str_from_list(result['stderr'])
return result return result
def mkdir(self, path): def mkdir(self, path):
"""run 'mkdir -p path' on remote
:type path: str
"""
if self.exists(path): if self.exists(path):
return return
logger.debug("Creating directory: {}".format(path)) logger.debug("Creating directory: {}".format(path))
self.execute("mkdir -p {}\n".format(path)) self.execute("mkdir -p {}\n".format(path))
def rm_rf(self, path): def rm_rf(self, path):
"""run 'rm -rf path' on remote
:type path: str
"""
logger.debug("rm -rf {}".format(path)) logger.debug("rm -rf {}".format(path))
self.execute("rm -rf %s" % path) self.execute("rm -rf {}".format(path))
def open(self, path, mode='r'): def open(self, path, mode='r'):
"""Open file on remote using SFTP session
:type path: str
:type mode: str
:return: file.open() stream
"""
return self._sftp.open(path, mode) return self._sftp.open(path, mode)
def upload(self, source, target): def upload(self, source, target):
"""Upload file(s) from source to target using SFTP session
:type source: str
:type target: str
"""
logger.debug("Copying '%s' -> '%s'", source, target) logger.debug("Copying '%s' -> '%s'", source, target)
if self.isdir(target): if self.isdir(target):
@ -315,6 +412,12 @@ class SSHClient(object):
self._sftp.put(local_path, remote_path) self._sftp.put(local_path, remote_path)
def download(self, destination, target): def download(self, destination, target):
"""Download file(s) to target from destination
:type destination: str
:type target: str
:rtype: bool
"""
logger.debug( logger.debug(
"Copying '%s' -> '%s' from remote to local host", "Copying '%s' -> '%s' from remote to local host",
destination, target destination, target
@ -337,6 +440,11 @@ class SSHClient(object):
return os.path.exists(target) return os.path.exists(target)
def exists(self, path): def exists(self, path):
"""Check for file existence using SFTP session
:type path: str
:rtype: bool
"""
try: try:
self._sftp.lstat(path) self._sftp.lstat(path)
return True return True
@ -344,6 +452,11 @@ class SSHClient(object):
return False return False
def isfile(self, path): def isfile(self, path):
"""Check, that path is file using SFTP session
:type path: str
:rtype: bool
"""
try: try:
attrs = self._sftp.lstat(path) attrs = self._sftp.lstat(path)
return attrs.st_mode & stat.S_IFREG != 0 return attrs.st_mode & stat.S_IFREG != 0
@ -351,6 +464,11 @@ class SSHClient(object):
return False return False
def isdir(self, path): def isdir(self, path):
"""Check, that path is directory using SFTP session
:type path: str
:rtype: bool
"""
try: try:
attrs = self._sftp.lstat(path) attrs = self._sftp.lstat(path)
return attrs.st_mode & stat.S_IFDIR != 0 return attrs.st_mode & stat.S_IFDIR != 0

View File

@ -14,7 +14,6 @@
# pylint: disable=no-self-use # pylint: disable=no-self-use
from contextlib import closing
from os.path import basename from os.path import basename
import posixpath import posixpath
import stat import stat
@ -22,8 +21,7 @@ from unittest import TestCase
import mock import mock
import paramiko import paramiko
# noinspection PyUnresolvedReferences from six import PY2
from six.moves import cStringIO
from devops.error import DevopsCalledProcessError from devops.error import DevopsCalledProcessError
from devops.helpers.ssh_client import SSHClient from devops.helpers.ssh_client import SSHClient
@ -32,18 +30,14 @@ from devops.helpers.ssh_client import SSHClient
def gen_private_keys(amount=1): def gen_private_keys(amount=1):
keys = [] keys = []
for _ in range(amount): for _ in range(amount):
with closing(cStringIO()) as output: keys.append(paramiko.RSAKey.generate(1024))
paramiko.RSAKey.generate(1024).write_private_key(output)
keys.append(output.getvalue())
return keys return keys
def gen_public_key(private_key=None): def gen_public_key(private_key=None):
if private_key is None: if private_key is None:
key = paramiko.RSAKey.generate(1024) private_key = paramiko.RSAKey.generate(1024)
else: return '{0} {1}'.format(private_key.get_name(), private_key.get_base64())
key = paramiko.RSAKey(file_obj=cStringIO(private_key))
return '{0} {1}'.format(key.get_name(), key.get_base64())
host = '127.0.0.1' host = '127.0.0.1'
@ -104,56 +98,6 @@ class TestSSHClient(TestCase):
sftp = ssh._sftp sftp = ssh._sftp
self.assertEqual(sftp, client().open_sftp()) self.assertEqual(sftp, client().open_sftp())
def test_init_change_login_passwd(self, client, policy, logger):
_ssh = mock.call()
def check_expected_on_connect(pwd):
client.assert_called_once()
policy.assert_called_once()
expected_calls = [
_ssh,
_ssh.set_missing_host_key_policy('AutoAddPolicy'),
_ssh.connect(
host, password=pwd,
port=port, username=username),
_ssh.open_sftp()
]
self.assertIn(expected_calls, client.mock_calls)
self.check_defaults(ssh, host, port, username, pwd,
private_keys)
self.assertIsNone(ssh.private_key)
self.assertIsNone(ssh.public_key)
self.assertIn(
mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format(
host, port, username, pwd
)),
logger.mock_calls
)
sftp = ssh._sftp
self.assertEqual(sftp, client().open_sftp())
ssh = SSHClient(
host=host,
port=port,
username=username,
password=password,
private_keys=private_keys)
check_expected_on_connect(pwd=password)
new_password = 'password2'
client.reset_mock()
policy.reset_mock()
logger.reset_mock()
ssh.password = new_password
check_expected_on_connect(pwd=new_password)
def test_init_keys(self, client, policy, logger): def test_init_keys(self, client, policy, logger):
_ssh = mock.call() _ssh = mock.call()
@ -270,24 +214,24 @@ class TestSSHClient(TestCase):
sftp = ssh._sftp sftp = ssh._sftp
self.assertEqual(sftp, _sftp) self.assertEqual(sftp, _sftp)
def init_ssh(self, client, policy, logger):
ssh = SSHClient( @mock.patch('devops.helpers.ssh_client.logger', autospec=True)
@mock.patch(
'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy')
@mock.patch('paramiko.SSHClient', autospec=True)
class TestExecute(TestCase):
@staticmethod
def get_ssh():
"""SSHClient object builder for execution tests
:rtype: SSHClient
"""
return SSHClient(
host=host, host=host,
port=port, port=port,
username=username, username=username,
password=password, password=password
private_keys=private_keys) )
client.assert_called_once()
policy.assert_called_once()
self.assertIn(
mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format(
host, port, username, password
)),
logger.mock_calls
)
return ssh
def test_execute_async(self, client, policy, logger): def test_execute_async(self, client, policy, logger):
chan = mock.Mock() chan = mock.Mock()
@ -299,7 +243,7 @@ class TestSSHClient(TestCase):
_ssh.attach_mock(get_transport, 'get_transport') _ssh.attach_mock(get_transport, 'get_transport')
client.return_value = _ssh client.return_value = _ssh
ssh = self.init_ssh(client, policy, logger) ssh = self.get_ssh()
result = ssh.execute_async(command=command) result = ssh.execute_async(command=command)
get_transport.assert_called_once() get_transport.assert_called_once()
@ -328,7 +272,7 @@ class TestSSHClient(TestCase):
_ssh.attach_mock(get_transport, 'get_transport') _ssh.attach_mock(get_transport, 'get_transport')
client.return_value = _ssh client.return_value = _ssh
ssh = self.init_ssh(client, policy, logger) ssh = self.get_ssh()
ssh.sudo_mode = True ssh.sudo_mode = True
result = ssh.execute_async(command=command) result = ssh.execute_async(command=command)
@ -358,7 +302,7 @@ class TestSSHClient(TestCase):
_ssh.attach_mock(get_transport, 'get_transport') _ssh.attach_mock(get_transport, 'get_transport')
client.return_value = _ssh client.return_value = _ssh
ssh = self.init_ssh(client, policy, logger) ssh = self.get_ssh()
self.assertFalse(ssh.sudo_mode) self.assertFalse(ssh.sudo_mode)
with SSHClient.get_sudo(ssh): with SSHClient.get_sudo(ssh):
self.assertTrue(ssh.sudo_mode) self.assertTrue(ssh.sudo_mode)
@ -381,60 +325,150 @@ class TestSSHClient(TestCase):
logger.mock_calls logger.mock_calls
) )
@mock.patch( def test_execute_async_sudo_password(
'devops.helpers.ssh_client.SSHClient.execute_async') self, client, policy, logger):
def test_execute(self, execute_async, client, policy, logger): stdin = mock.Mock(name='stdin')
stderr = [' ', '0', '1', ' '] stdout = mock.Mock(name='stdout')
stdout = [' ', '2', '3', ' '] stdout_channel = mock.Mock()
exit_code = 0 stdout_channel.configure_mock(closed=False)
stdout.attach_mock(stdout_channel, 'channel')
makefile = mock.Mock(side_effect=[stdin, stdout])
chan = mock.Mock()
chan.attach_mock(makefile, 'makefile')
open_session = mock.Mock(return_value=chan)
transport = mock.Mock()
transport.attach_mock(open_session, 'open_session')
get_transport = mock.Mock(return_value=transport)
_ssh = mock.Mock()
_ssh.attach_mock(get_transport, 'get_transport')
client.return_value = _ssh
ssh = self.get_ssh()
ssh.sudo_mode = True
result = ssh.execute_async(command=command)
get_transport.assert_called_once()
open_session.assert_called_once()
# raise ValueError(closed.mock_calls)
stdin.assert_has_calls((mock.call.flush(), ))
self.assertIn(chan, result)
chan.assert_has_calls((
mock.call.makefile('wb'),
mock.call.makefile('rb'),
mock.call.makefile_stderr('rb'),
mock.call.exec_command('sudo -S bash -c "{}\n"'.format(command))
))
self.assertIn(
mock.call.debug(
"Executing command: '{}'".format(command.rstrip())),
logger.mock_calls
)
@staticmethod
def get_patched_execute_async_retval(ec=0, stderr_val=True):
stderr = mock.Mock()
stdout = mock.Mock()
stderr_readlines = mock.Mock(
return_value=[b' \n', b'0\n', b'1\n', b' \n'] if stderr_val else []
)
stdout_readlines = mock.Mock(
return_value=[b' \n', b'2\n', b'3\n', b' \n'])
stderr.attach_mock(stderr_readlines, 'readlines')
stdout.attach_mock(stdout_readlines, 'readlines')
exit_code = ec
chan = mock.Mock() chan = mock.Mock()
recv_exit_status = mock.Mock(return_value=exit_code) recv_exit_status = mock.Mock(return_value=exit_code)
chan.attach_mock(recv_exit_status, 'recv_exit_status') chan.attach_mock(recv_exit_status, 'recv_exit_status')
execute_async.return_value = chan, '', stderr, stdout return chan, '', stderr, stdout
ssh = self.init_ssh(client, policy, logger) @mock.patch(
'devops.helpers.ssh_client.SSHClient.execute_async')
def test_execute(self, execute_async, client, policy, logger):
chan, _stdin, stderr, stdout = self.get_patched_execute_async_retval()
execute_async.return_value = chan, _stdin, stderr, stdout
stderr_lst = stderr.readlines()
stdout_lst = stdout.readlines()
expected = {
'exit_code': chan.recv_exit_status(),
'stderr': stderr_lst,
'stdout': stdout_lst}
if PY2:
expected['stderr_str'] = b''.join(stderr_lst).strip()
expected['stdout_str'] = b''.join(stdout_lst).strip()
else:
expected['stderr_str'] = b''.join(stderr_lst).strip().decode(
encoding='utf-8')
expected['stdout_str'] = b''.join(stdout_lst).strip().decode(
encoding='utf-8')
ssh = self.get_ssh()
logger.reset_mock()
result = ssh.execute(command=command, verbose=True)
result = ssh.execute(command=command)
self.assertEqual( self.assertEqual(
result, result,
{ expected
'stderr_str': ''.join(stderr).strip(),
'stdout_str': ''.join(stdout).strip(),
'exit_code': exit_code,
'stderr': stderr,
'stdout': stdout}
) )
execute_async.assert_called_once_with(command) execute_async.assert_called_once_with(command)
chan.assert_has_calls(( chan.assert_has_calls((
mock.call.recv_exit_status(), mock.call.recv_exit_status(),
mock.call.close())) mock.call.close()))
logger.assert_has_calls((
mock.call.info(
'{cmd} execution results:\n'
'Exit code: {code}\n'
'STDOUT:\n'
'{stdout}\n'
'STDERR:\n'
'{stderr}'.format(
cmd=command,
code=result['exit_code'],
stdout=result['stdout_str'],
stderr=result['stderr_str']
)),
))
@mock.patch( @mock.patch(
'devops.helpers.ssh_client.SSHClient.execute_async') 'devops.helpers.ssh_client.SSHClient.execute_async')
def test_execute_together(self, execute_async, client, policy, logger): def test_execute_together(self, execute_async, client, policy, logger):
stderr = [' ', '0', '1', ' '] chan, _stdin, stderr, stdout = self.get_patched_execute_async_retval()
stdout = [' ', '2', '3', ' '] execute_async.return_value = chan, _stdin, stderr, stdout
exit_code = 0
chan = mock.Mock() stderr_lst = stderr.readlines()
recv_exit_status = mock.Mock(return_value=exit_code) stdout_lst = stdout.readlines()
chan.attach_mock(recv_exit_status, 'recv_exit_status')
execute_async.return_value = chan, '', stderr, stdout expected = {
'exit_code': chan.recv_exit_status(),
'stderr': stderr_lst,
'stdout': stdout_lst}
if PY2:
expected['stderr_str'] = b''.join(stderr_lst).strip()
expected['stdout_str'] = b''.join(stdout_lst).strip()
else:
expected['stderr_str'] = b''.join(stderr_lst).strip().decode(
encoding='utf-8')
expected['stdout_str'] = b''.join(stdout_lst).strip().decode(
encoding='utf-8')
host2 = '127.0.0.2' host2 = '127.0.0.2'
ssh = SSHClient( ssh = self.get_ssh()
host=host,
port=port,
username=username,
password=password,
private_keys=private_keys)
ssh2 = SSHClient( ssh2 = SSHClient(
host=host2, host=host2,
port=port, port=port,
username=username, username=username,
password=password, password=password
private_keys=private_keys) )
remotes = ssh, ssh2 remotes = [ssh, ssh2]
SSHClient.execute_together( SSHClient.execute_together(
remotes=remotes, command=command) remotes=remotes, command=command)
@ -447,35 +481,35 @@ class TestSSHClient(TestCase):
mock.call.close() mock.call.close()
)) ))
SSHClient.execute_together(
remotes=remotes, command=command, expected=[1], raise_on_err=False)
with self.assertRaises(DevopsCalledProcessError):
SSHClient.execute_together(
remotes=remotes, command=command, expected=[1])
@mock.patch( @mock.patch(
'devops.helpers.ssh_client.SSHClient.execute') 'devops.helpers.ssh_client.SSHClient.execute')
def test_check_call(self, execute, client, policy, logger): def test_check_call(self, execute, client, policy, logger):
stderr = [' ', '0', '1', ' ']
stdout = [' ', '2', '3', ' ']
exit_code = 0 exit_code = 0
return_value = { return_value = {
'stderr_str': ''.join(stderr).strip(), 'stderr_str': '0\n1',
'stdout_str': ''.join(stdout).strip(), 'stdout_str': '2\n3',
'exit_code': exit_code, 'exit_code': exit_code,
'stderr': stderr, 'stderr': [b' \n', b'0\n', b'1\n', b' \n'],
'stdout': stdout} 'stdout': [b' \n', b'2\n', b'3\n', b' \n']}
execute.return_value = return_value execute.return_value = return_value
verbose = False verbose = False
ssh = self.init_ssh(client, policy, logger) ssh = self.get_ssh()
result = ssh.check_call(command=command, verbose=verbose) result = ssh.check_call(command=command, verbose=verbose)
execute.assert_called_once_with(command, verbose) execute.assert_called_once_with(command, verbose)
self.assertEqual(result, return_value) self.assertEqual(result, return_value)
exit_code = 1 exit_code = 1
return_value = { return_value['exit_code'] = exit_code
'stderr_str': ''.join(stderr).strip(),
'stdout_str': ''.join(stdout).strip(),
'exit_code': exit_code,
'stderr': stderr,
'stdout': stdout}
execute.reset_mock() execute.reset_mock()
execute.return_value = return_value execute.return_value = return_value
with self.assertRaises(DevopsCalledProcessError): with self.assertRaises(DevopsCalledProcessError):
@ -485,41 +519,47 @@ class TestSSHClient(TestCase):
@mock.patch( @mock.patch(
'devops.helpers.ssh_client.SSHClient.check_call') 'devops.helpers.ssh_client.SSHClient.check_call')
def test_check_stderr(self, check_call, client, policy, logger): def test_check_stderr(self, check_call, client, policy, logger):
stdout = [' ', '0', '1', ' ']
stderr = []
exit_code = 0
return_value = { return_value = {
'stderr_str': ''.join(stderr).strip(), 'stderr_str': '',
'stdout_str': ''.join(stdout).strip(), 'stdout_str': '2\n3',
'exit_code': exit_code, 'exit_code': 0,
'stderr': stderr, 'stderr': [],
'stdout': stdout} 'stdout': [b' \n', b'2\n', b'3\n', b' \n']}
check_call.return_value = return_value check_call.return_value = return_value
verbose = False verbose = False
raise_on_err = True
ssh = self.init_ssh(client, policy, logger) ssh = self.get_ssh()
result = ssh.check_stderr(command=command, verbose=verbose) result = ssh.check_stderr(
check_call.assert_called_once_with(command, verbose) command=command, verbose=verbose, raise_on_err=raise_on_err)
check_call.assert_called_once_with(
command, verbose, raise_on_err=raise_on_err)
self.assertEqual(result, return_value) self.assertEqual(result, return_value)
stderr = [' ', '2', '3', ' '] return_value['stderr_str'] = '0\n1'
return_value = { return_value['stderr'] = [b' \n', b'0\n', b'1\n', b' \n']
'stderr_str': ''.join(stderr).strip(),
'stdout_str': ''.join(stdout).strip(),
'exit_code': exit_code,
'stderr': stderr,
'stdout': stdout}
check_call.reset_mock() check_call.reset_mock()
check_call.return_value = return_value check_call.return_value = return_value
with self.assertRaises(DevopsCalledProcessError): with self.assertRaises(DevopsCalledProcessError):
ssh.check_stderr(command=command, verbose=verbose) ssh.check_stderr(
check_call.assert_called_once_with(command, verbose) command=command, verbose=verbose, raise_on_err=raise_on_err)
check_call.assert_called_once_with(
command, verbose, raise_on_err=raise_on_err)
def prepare_execute_through_host(
self, transp, client, policy, logger): @mock.patch('devops.helpers.ssh_client.logger', autospec=True)
@mock.patch(
'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy')
@mock.patch('paramiko.SSHClient', autospec=True)
@mock.patch('paramiko.Transport', autospec=True)
class TestExecuteThrowHost(TestCase):
@staticmethod
def prepare_execute_through_host(transp, client, exit_code):
intermediate_channel = mock.Mock() intermediate_channel = mock.Mock()
open_channel = mock.Mock(return_value=intermediate_channel) open_channel = mock.Mock(return_value=intermediate_channel)
intermediate_transport = mock.Mock() intermediate_transport = mock.Mock()
intermediate_transport.attach_mock(open_channel, 'open_channel') intermediate_transport.attach_mock(open_channel, 'open_channel')
@ -532,22 +572,16 @@ class TestSSHClient(TestCase):
transport = mock.Mock() transport = mock.Mock()
transp.return_value = transport transp.return_value = transport
stderr = [' ', '0', '1', ' ']
stdout = [' ', '2', '3', ' ']
exit_code = 0
return_value = {
'stderr_str': ''.join(stderr).strip(),
'stdout_str': ''.join(stdout).strip(),
'exit_code': exit_code,
'stderr': stderr,
'stdout': stdout}
recv_exit_status = mock.Mock(return_value=exit_code) recv_exit_status = mock.Mock(return_value=exit_code)
makefile = mock.Mock() makefile = mock.Mock()
makefile.attach_mock(mock.Mock(return_value=stdout), 'read') makefile.attach_mock(mock.Mock(
return_value=[b' \n', b'2\n', b'3\n', b' \n']),
'readlines')
makefile_stderr = mock.Mock() makefile_stderr = mock.Mock()
makefile_stderr.attach_mock(mock.Mock(return_value=stderr), 'read') makefile_stderr.attach_mock(
mock.Mock(return_value=[b' \n', b'0\n', b'1\n', b' \n']),
'readlines')
channel = mock.Mock() channel = mock.Mock()
channel.attach_mock(mock.Mock(return_value=makefile), 'makefile') channel.attach_mock(mock.Mock(return_value=makefile), 'makefile')
channel.attach_mock(mock.Mock( channel.attach_mock(mock.Mock(
@ -556,100 +590,34 @@ class TestSSHClient(TestCase):
open_session = mock.Mock(return_value=channel) open_session = mock.Mock(return_value=channel)
transport.attach_mock(open_session, 'open_session') transport.attach_mock(open_session, 'open_session')
ssh = self.init_ssh(client, policy, logger)
return ( return (
ssh, return_value, open_session, transport, channel, get_transport, open_session, transport, channel, get_transport,
open_channel, intermediate_channel open_channel, intermediate_channel
) )
@mock.patch('paramiko.Transport', autospec=True)
def test_execute_through_host_no_creds( def test_execute_through_host_no_creds(
self, transp, client, policy, logger): self, transp, client, policy, logger):
target = '10.0.0.2' target = '127.0.0.2'
(
ssh, return_value, open_session, transport, channel, get_transport,
open_channel, intermediate_channel
) = self.prepare_execute_through_host(transp, client, policy, logger)
result = ssh.execute_through_host(target, command)
self.assertEqual(result, return_value)
get_transport.assert_called_once()
open_channel.assert_called_once()
transp.assert_called_once_with(intermediate_channel)
open_session.assert_called_once()
transport.assert_has_calls((
mock.call.start_client(),
mock.call.auth_password(username=username, password=password),
mock.call.open_session()
))
channel.assert_has_calls((
mock.call.makefile('rb'),
mock.call.makefile_stderr('rb'),
mock.call.exec_command('ls ~ '),
mock.call.recv_exit_status(),
mock.call.close()
))
@mock.patch('paramiko.Transport', autospec=True)
def test_execute_through_host_no_creds_key(
self, transp, client, policy, logger):
target = '10.0.0.2'
private_keys = gen_private_keys(1)
intermediate_channel = mock.Mock()
open_channel = mock.Mock(return_value=intermediate_channel)
intermediate_transport = mock.Mock()
intermediate_transport.attach_mock(open_channel, 'open_channel')
get_transport = mock.Mock(return_value=intermediate_transport)
_ssh = mock.Mock()
_ssh.attach_mock(get_transport, 'get_transport')
client.return_value = _ssh
transport = mock.Mock()
transp.return_value = transport
stderr = [' ', '0', '1', ' ']
stdout = [' ', '2', '3', ' ']
exit_code = 0 exit_code = 0
return_value = { return_value = {
'stderr_str': ''.join(stderr).strip(), 'stderr_str': '0\n1',
'stdout_str': ''.join(stdout).strip(), 'stdout_str': '2\n3',
'exit_code': exit_code, 'exit_code': exit_code,
'stderr': stderr, 'stderr': [b' \n', b'0\n', b'1\n', b' \n'],
'stdout': stdout} 'stdout': [b' \n', b'2\n', b'3\n', b' \n']}
recv_exit_status = mock.Mock(return_value=exit_code) (
open_session, transport, channel, get_transport,
makefile = mock.Mock() open_channel, intermediate_channel
makefile.attach_mock(mock.Mock(return_value=stdout), 'read') ) = self.prepare_execute_through_host(
makefile_stderr = mock.Mock() transp, client, exit_code=exit_code)
makefile_stderr.attach_mock(mock.Mock(return_value=stderr), 'read')
channel = mock.Mock()
channel.attach_mock(mock.Mock(return_value=makefile), 'makefile')
channel.attach_mock(mock.Mock(
return_value=makefile_stderr), 'makefile_stderr')
channel.attach_mock(recv_exit_status, 'recv_exit_status')
open_session = mock.Mock(return_value=channel)
transport.attach_mock(open_session, 'open_session')
ssh = SSHClient( ssh = SSHClient(
host=host, host=host,
port=port, port=port,
username=username, username=username,
password=password, password=password
private_keys=private_keys) )
client.assert_called_once()
policy.assert_called_once()
self.assertIn(
mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format(
host, port, username, password
)),
logger.mock_calls
)
result = ssh.execute_through_host(target, command) result = ssh.execute_through_host(target, command)
self.assertEqual(result, return_value) self.assertEqual(result, return_value)
@ -658,9 +626,7 @@ class TestSSHClient(TestCase):
transp.assert_called_once_with(intermediate_channel) transp.assert_called_once_with(intermediate_channel)
open_session.assert_called_once() open_session.assert_called_once()
transport.assert_has_calls(( transport.assert_has_calls((
mock.call.start_client(), mock.call.connect(username=username, password=password, pkey=None),
mock.call.auth_publickey(
username=username, key=private_keys[0]),
mock.call.open_session() mock.call.open_session()
)) ))
channel.assert_has_calls(( channel.assert_has_calls((
@ -671,93 +637,43 @@ class TestSSHClient(TestCase):
mock.call.close() mock.call.close()
)) ))
@mock.patch('paramiko.Transport', autospec=True) def test_execute_through_host_auth(
def test_execute_through_host_password(
self, transp, client, policy, logger): self, transp, client, policy, logger):
target = '10.0.0.2'
_login = 'cirros' _login = 'cirros'
_password = 'cubswin:)' _password = 'cubswin:)'
( target = '127.0.0.2'
ssh, return_value, open_session, transport, channel, get_transport, exit_code = 0
open_channel, intermediate_channel return_value = {
) = self.prepare_execute_through_host(transp, client, policy, logger) 'stderr_str': '0\n1',
'stdout_str': '2\n3',
result = ssh.execute_through_host( 'exit_code': exit_code,
target, command, username=_login, password=_password) 'stderr': [b' \n', b'0\n', b'1\n', b' \n'],
self.assertEqual(result, return_value) 'stdout': [b' \n', b'2\n', b'3\n', b' \n']}
get_transport.assert_called_once()
open_channel.assert_called_once()
transp.assert_called_once_with(intermediate_channel)
open_session.assert_called_once()
transport.assert_has_calls((
mock.call.start_client(),
mock.call.auth_password(username=_login, password=_password),
mock.call.open_session()
))
channel.assert_has_calls((
mock.call.makefile('rb'),
mock.call.makefile_stderr('rb'),
mock.call.exec_command('ls ~ '),
mock.call.recv_exit_status(),
mock.call.close()
))
@mock.patch('paramiko.Transport', autospec=True)
def test_execute_through_host_no_auth(
self, transp, client, policy, logger):
target = '10.0.0.2'
_login = 'cirros'
( (
ssh, return_value, open_session, transport, channel, get_transport, open_session, transport, channel, get_transport,
open_channel, intermediate_channel open_channel, intermediate_channel
) = self.prepare_execute_through_host(transp, client, policy, logger) ) = self.prepare_execute_through_host(
transp, client, exit_code=exit_code)
result = ssh.execute_through_host( ssh = SSHClient(
target, command, username=_login) host=host,
self.assertEqual(result, return_value) port=port,
get_transport.assert_called_once() username=username,
open_channel.assert_called_once() password=password
transp.assert_called_once_with(intermediate_channel) )
open_session.assert_called_once()
transport.assert_has_calls((
mock.call.start_client(),
mock.call.auth_none(username=_login),
mock.call.open_session()
))
channel.assert_has_calls((
mock.call.makefile('rb'),
mock.call.makefile_stderr('rb'),
mock.call.exec_command('ls ~ '),
mock.call.recv_exit_status(),
mock.call.close()
))
@mock.patch('paramiko.Transport', autospec=True)
def test_execute_through_host_key(
self, transp, client, policy, logger):
target = '10.0.0.2'
_login = 'cirros'
_password = 'cubswin:)'
key = gen_private_keys(1)[0]
(
ssh, return_value, open_session, transport, channel, get_transport,
open_channel, intermediate_channel
) = self.prepare_execute_through_host(transp, client, policy, logger)
result = ssh.execute_through_host( result = ssh.execute_through_host(
target, command, target, command,
username=_login, password=_password, key=key) username=_login, password=_password)
self.assertEqual(result, return_value) self.assertEqual(result, return_value)
get_transport.assert_called_once() get_transport.assert_called_once()
open_channel.assert_called_once() open_channel.assert_called_once()
transp.assert_called_once_with(intermediate_channel) transp.assert_called_once_with(intermediate_channel)
open_session.assert_called_once() open_session.assert_called_once()
transport.assert_has_calls(( transport.assert_has_calls((
mock.call.start_client(), mock.call.connect(username=_login, password=_password, pkey=None),
mock.call.auth_publickey(username=_login, key=key),
mock.call.open_session() mock.call.open_session()
)) ))
channel.assert_has_calls(( channel.assert_has_calls((
@ -768,7 +684,14 @@ class TestSSHClient(TestCase):
mock.call.close() mock.call.close()
)) ))
def prepare_sftp_file_tests(self, client, policy, logger):
@mock.patch('devops.helpers.ssh_client.logger', autospec=True)
@mock.patch(
'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy')
@mock.patch('paramiko.SSHClient', autospec=True)
class TestSftp(TestCase):
@staticmethod
def prepare_sftp_file_tests(client):
_ssh = mock.Mock() _ssh = mock.Mock()
client.return_value = _ssh client.return_value = _ssh
_sftp = mock.Mock() _sftp = mock.Mock()
@ -779,24 +702,12 @@ class TestSSHClient(TestCase):
host=host, host=host,
port=port, port=port,
username=username, username=username,
password=password, password=password
private_keys=private_keys) )
client.assert_called_once()
policy.assert_called_once()
self.check_defaults(ssh, host, port, username, password, private_keys)
self.assertIn(
mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format(
host, port, username, password
)),
logger.mock_calls
)
return ssh, _sftp return ssh, _sftp
def test_exists(self, client, policy, logger): def test_exists(self, client, policy, logger):
ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) ssh, _sftp = self.prepare_sftp_file_tests(client)
lstat = mock.Mock() lstat = mock.Mock()
_sftp.attach_mock(lstat, 'lstat') _sftp.attach_mock(lstat, 'lstat')
path = '/etc' path = '/etc'
@ -818,7 +729,7 @@ class TestSSHClient(TestCase):
def __init__(self, mode): def __init__(self, mode):
self.st_mode = mode self.st_mode = mode
ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) ssh, _sftp = self.prepare_sftp_file_tests(client)
lstat = mock.Mock() lstat = mock.Mock()
_sftp.attach_mock(lstat, 'lstat') _sftp.attach_mock(lstat, 'lstat')
lstat.return_value = Attrs(stat.S_IFREG) lstat.return_value = Attrs(stat.S_IFREG)
@ -848,7 +759,7 @@ class TestSSHClient(TestCase):
def __init__(self, mode): def __init__(self, mode):
self.st_mode = mode self.st_mode = mode
ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) ssh, _sftp = self.prepare_sftp_file_tests(client)
lstat = mock.Mock() lstat = mock.Mock()
_sftp.attach_mock(lstat, 'lstat') _sftp.attach_mock(lstat, 'lstat')
lstat.return_value = Attrs(stat.S_IFDIR) lstat.return_value = Attrs(stat.S_IFDIR)
@ -875,11 +786,16 @@ class TestSSHClient(TestCase):
@mock.patch('devops.helpers.ssh_client.SSHClient.exists') @mock.patch('devops.helpers.ssh_client.SSHClient.exists')
@mock.patch('devops.helpers.ssh_client.SSHClient.execute') @mock.patch('devops.helpers.ssh_client.SSHClient.execute')
def test_mkdir(self, execute, exists, client, policy, logger): def test_mkdir(self, execute, exists, client, policy, logger):
exists.return_value = False exists.side_effect = [False, True]
path = '~/tst' path = '~/tst'
ssh = self.init_ssh(client, policy, logger) ssh = SSHClient(
host=host,
port=port,
username=username,
password=password
)
# Path not exists # Path not exists
ssh.mkdir(path) ssh.mkdir(path)
@ -888,7 +804,6 @@ class TestSSHClient(TestCase):
# Path exists # Path exists
exists.reset_mock() exists.reset_mock()
exists.return_value = True
execute.reset_mock() execute.reset_mock()
ssh.mkdir(path) ssh.mkdir(path)
@ -899,14 +814,19 @@ class TestSSHClient(TestCase):
def test_rm_rf(self, execute, client, policy, logger): def test_rm_rf(self, execute, client, policy, logger):
path = '~/tst' path = '~/tst'
ssh = self.init_ssh(client, policy, logger) ssh = SSHClient(
host=host,
port=port,
username=username,
password=password
)
# Path not exists # Path not exists
ssh.rm_rf(path) ssh.rm_rf(path)
execute.assert_called_once_with("rm -rf {}".format(path)) execute.assert_called_once_with("rm -rf {}".format(path))
def test_open(self, client, policy, logger): def test_open(self, client, policy, logger):
ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) ssh, _sftp = self.prepare_sftp_file_tests(client)
fopen = mock.Mock(return_value=True) fopen = mock.Mock(return_value=True)
_sftp.attach_mock(fopen, 'open') _sftp.attach_mock(fopen, 'open')
@ -924,11 +844,11 @@ class TestSSHClient(TestCase):
self, self,
isdir, remote_isdir, exists, remote_exists, client, policy, logger isdir, remote_isdir, exists, remote_exists, client, policy, logger
): ):
ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) ssh, _sftp = self.prepare_sftp_file_tests(client)
isdir.return_value = True isdir.return_value = True
exists.return_value = True exists.side_effect = [True, False, False]
remote_isdir.return_value = False remote_isdir.side_effect = [False, False, True]
remote_exists.return_value = True remote_exists.side_effect = [True, False, False]
dst = '/etc/environment' dst = '/etc/environment'
target = '/tmp/environment' target = '/tmp/environment'
@ -942,12 +862,40 @@ class TestSSHClient(TestCase):
mock.call.get(dst, posixpath.join(target, basename(dst))), mock.call.get(dst, posixpath.join(target, basename(dst))),
)) ))
# Negative scenarios
logger.reset_mock()
result = ssh.download(destination=dst, target=target)
logger.assert_has_calls((
mock.call.debug(
"Copying '%s' -> '%s' from remote to local host",
'/etc/environment',
'/tmp/environment'),
mock.call.debug(
"Can't download %s because it doesn't exist",
'/etc/environment'
),
))
self.assertFalse(result)
logger.reset_mock()
result = ssh.download(destination=dst, target=target)
logger.assert_has_calls((
mock.call.debug(
"Copying '%s' -> '%s' from remote to local host",
'/etc/environment',
'/tmp/environment'),
mock.call.debug(
"Can't download %s because it is a directory",
'/etc/environment'
),
))
@mock.patch('devops.helpers.ssh_client.SSHClient.isdir') @mock.patch('devops.helpers.ssh_client.SSHClient.isdir')
@mock.patch('os.path.isdir', autospec=True) @mock.patch('os.path.isdir', autospec=True)
def test_upload_file( def test_upload_file(
self, isdir, remote_isdir, client, policy, logger self, isdir, remote_isdir, client, policy, logger
): ):
ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) ssh, _sftp = self.prepare_sftp_file_tests(client)
isdir.return_value = False isdir.return_value = False
remote_isdir.return_value = False remote_isdir.return_value = False
target = '/etc/environment' target = '/etc/environment'
@ -970,7 +918,7 @@ class TestSSHClient(TestCase):
isdir, remote_isdir, walk, mkdir, exists, isdir, remote_isdir, walk, mkdir, exists,
client, policy, logger client, policy, logger
): ):
ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) ssh, _sftp = self.prepare_sftp_file_tests(client)
isdir.return_value = True isdir.return_value = True
remote_isdir.return_value = True remote_isdir.return_value = True
exists.return_value = True exists.return_value = True

View File

@ -28,7 +28,7 @@ deps =
-r{toxinidir}/test-requirements.txt -r{toxinidir}/test-requirements.txt
commands = commands =
py.test --cov-config .coveragerc --cov-report html --cov=devops devops/tests py.test --cov-config .coveragerc --cov-report html --cov=devops devops/tests
coverage report --fail-under 70 coverage report --fail-under 73
[testenv:pep8] [testenv:pep8]