From 255de881facfc62c7764331acc08034d783999c9 Mon Sep 17 00:00:00 2001 From: Alexey Stepanov Date: Tue, 31 May 2016 11:03:56 +0300 Subject: [PATCH] SSHClient rework for SSH Manager integration: part 1 1. Methods optimization 2. Added docstrings Change-Id: Ia929e151cef1edefa5238a2019e57b365ce74da3 blueprint: sshmanager-integration --- devops/error.py | 4 +- devops/helpers/ssh_client.py | 260 +++++++--- devops/tests/helpers/test_ssh_client.py | 622 +++++++++++------------- tox.ini | 2 +- 4 files changed, 477 insertions(+), 411 deletions(-) diff --git a/devops/error.py b/devops/error.py index d1a93fa0..d30aa157 100644 --- a/devops/error.py +++ b/devops/error.py @@ -40,9 +40,9 @@ class DevopsCalledProcessError(DevopsError): expected=self.expected )) if self.stdout: - message += "\n\tSTDOUT: {}".format(self.stdout) + message += "\n\tSTDOUT:\n{}".format(self.stdout) if self.stderr: - message += "\n\tSTDERR: {}".format(self.stderr) + message += "\n\tSTDERR:\n{}".format(self.stderr) super(DevopsCalledProcessError, self).__init__(message) @property diff --git a/devops/helpers/ssh_client.py b/devops/helpers/ssh_client.py index 751b3121..bf1e14dc 100644 --- a/devops/helpers/ssh_client.py +++ b/devops/helpers/ssh_client.py @@ -17,8 +17,7 @@ import posixpath import stat import paramiko -# noinspection PyUnresolvedReferences -from six.moves import cStringIO +import six from devops.error import DevopsCalledProcessError from devops.helpers.retry import retry @@ -27,13 +26,14 @@ from devops import logger class SSHClient(object): class get_sudo(object): + """Context manager for call commands with sudo""" def __init__(self, ssh): self.ssh = ssh def __enter__(self): self.ssh.sudo_mode = True - def __exit__(self, exc_type, value, traceback): + def __exit__(self, exc_type, exc_val, exc_tb): self.ssh.sudo_mode = False def __init__(self, host, port=22, username=None, password=None, @@ -58,25 +58,10 @@ class SSHClient(object): def password(self): return self.__password - @password.setter - def password(self, new_val): - self.__password = new_val - self.reconnect() - @property def private_keys(self): return self.__private_keys - @private_keys.setter - def private_keys(self, new_val): - self.__private_keys = new_val - self.reconnect() - - @private_keys.deleter - def private_keys(self): - self.__private_keys = [] - self.reconnect() - @property def private_key(self): return self.__actual_pkey @@ -85,11 +70,15 @@ class SSHClient(object): def public_key(self): if self.private_key is None: return None - key = paramiko.RSAKey(file_obj=cStringIO(self.private_key)) + key = self.private_key return '{0} {1}'.format(key.get_name(), key.get_base64()) @property def _sftp(self): + """SFTP channel access for inheritance + + :raises: paramiko.SSHException + """ if self.__sftp is not None: return self.__sftp logger.warning('SFTP is not connected, try to reconnect') @@ -115,7 +104,7 @@ class SSHClient(object): def __enter__(self): return self - def __exit__(self, *err): + def __exit__(self, exc_type, exc_val, exc_tb): self.clear() @retry(count=3, delay=3) @@ -152,67 +141,162 @@ class SSHClient(object): self.connect() self._connect_sftp() - def check_call(self, command, verbose=False, excpected=0): + def check_call( + self, + command, verbose=False, + expected=None, raise_on_err=True): + """Execute command and check for return code + + :type command: str + :type verbose: bool + :type expected: list + :type raise_on_err: bool + :rtype: dict + :raises: DevopsCalledProcessError + """ + if expected is None: + expected = [0] ret = self.execute(command, verbose) - if ret['exit_code'] != excpected: - raise DevopsCalledProcessError( - command, ret['exit_code'], - expected=excpected, - stdout=ret['stdout_str'], - stderr=ret['stderr_str']) + if ret['exit_code'] not in expected: + message = ( + "Command '{cmd}' returned exit code {code} while " + "expected {expected}\n" + "\tSTDOUT:\n" + "{stdout}" + "\n\tSTDERR:\n" + "{stderr}".format( + cmd=command, + code=ret['exit_code'], + expected=expected, + stdout=ret['stdout_str'], + stderr=ret['stderr_str'] + )) + logger.error(message) + if raise_on_err: + raise DevopsCalledProcessError( + command, ret['exit_code'], + expected=expected, + stdout=ret['stdout_str'], + stderr=ret['stderr_str']) return ret - def check_stderr(self, command, verbose=False): - ret = self.check_call(command, verbose) + def check_stderr(self, command, verbose=False, raise_on_err=True): + """Execute command expecting return code 0 and empty STDERR + + :type command: str + :type verbose: bool + :type raise_on_err: bool + :rtype: dict + :raises: DevopsCalledProcessError + """ + ret = self.check_call(command, verbose, raise_on_err=raise_on_err) if ret['stderr']: - raise DevopsCalledProcessError(command, ret['exit_code'], - stdout=ret['stdout_str'], - stderr=ret['stderr_str']) + message = ( + "Command '{cmd}' STDERR while not expected\n" + "\texit code: {code}\n" + "\tSTDOUT:\n" + "{stdout}" + "\n\tSTDERR:\n" + "{stderr}".format( + cmd=command, + code=ret['exit_code'], + stdout=ret['stdout_str'], + stderr=ret['stderr_str'] + )) + logger.error(message) + if raise_on_err: + raise DevopsCalledProcessError(command, ret['exit_code'], + stdout=ret['stdout_str'], + stderr=ret['stderr_str']) return ret @classmethod - def execute_together(cls, remotes, command): + def execute_together( + cls, remotes, command, expected=None, raise_on_err=True): + """Execute command on multiple remotes in async mode + + :type remotes: list + :type command: str + :type expected: list + :type raise_on_err: bool + :raises: DevopsCalledProcessError + """ + if expected is None: + expected = [0] futures = {} errors = {} - for remote in remotes: + for remote in set(remotes): # Use distinct remotes chan, _, _, _ = remote.execute_async(command) futures[remote] = chan for remote, chan in futures.items(): ret = chan.recv_exit_status() chan.close() - if ret != 0: + if ret not in expected: errors[remote.host] = ret - if errors: + if errors and raise_on_err: raise DevopsCalledProcessError(command, errors) def execute(self, command, verbose=False): + """Execute command and wait for return code + + :type command: str + :type verbose: bool + :rtype: dict + """ chan, _, stderr, stdout = self.execute_async(command) + + # noinspection PyDictCreation result = { - 'stdout': [], - 'stderr': [], - 'exit_code': 0 + 'exit_code': chan.recv_exit_status() } - for line in stdout: - result['stdout'].append(line) - if verbose: - logger.info(line) - for line in stderr: - result['stderr'].append(line) - if verbose: - logger.info(line) - result['exit_code'] = chan.recv_exit_status() + result['stdout'] = stdout.readlines() + result['stderr'] = stderr.readlines() + chan.close() - result['stdout_str'] = ''.join(result['stdout']).strip() - result['stderr_str'] = ''.join(result['stderr']).strip() + + result['stdout_str'] = self._get_str_from_list(result['stdout']) + result['stderr_str'] = self._get_str_from_list(result['stderr']) + + if verbose: + logger.info( + '{cmd} execution results:\n' + 'Exit code: {code}\n' + 'STDOUT:\n' + '{stdout}\n' + 'STDERR:\n' + '{stderr}'.format( + cmd=command, + code=result['exit_code'], + stdout=result['stdout_str'], + stderr=result['stderr_str'] + )) + return result + @staticmethod + def _get_str_from_list(src): + """Join data in list to the string, with python 2&3 compatibility. + + :type src: list + :rtype: str + """ + if six.PY2: + return b''.join(src).strip() + else: + return b''.join(src).strip().decode(encoding='utf-8') + def execute_async(self, command): + """Execute command in async mode and return channel with IO objects + + :type command: str + :rtype: tuple + """ logger.debug("Executing command: '{}'".format(command.rstrip())) chan = self._ssh.get_transport().open_session() stdin = chan.makefile('wb') stdout = chan.makefile('rb') stderr = chan.makefile_stderr('rb') - cmd = "%s\n" % command + cmd = "{}\n".format(command) if self.sudo_mode: cmd = 'sudo -S bash -c "%s"' % cmd.replace('"', '\\"') chan.exec_command(cmd) @@ -225,7 +309,7 @@ class SSHClient(object): def execute_through_host( self, - target_host, + hostname, cmd, username=None, password=None, @@ -237,21 +321,15 @@ class SSHClient(object): key = self.private_key intermediate_channel = self._ssh.get_transport().open_channel( - 'direct-tcpip', (target_host, target_port), (self.host, 0)) - transport = paramiko.Transport(intermediate_channel) - transport.start_client() - logger.info("Passing authentication to: {}".format(target_host)) - if password is None and key is None: - logger.debug('auth_none') - transport.auth_none(username=username) - elif key is not None: - logger.debug('auth_publickey') - transport.auth_publickey(username=username, key=key) - else: - logger.debug('auth_password') - transport.auth_password(username=username, password=password) + kind='direct-tcpip', + dest_addr=(hostname, target_port), + src_addr=(self.host, 0)) + transport = paramiko.Transport(sock=intermediate_channel) - logger.debug("Opening session") + # start client and authenticate transport + transport.connect(username=username, password=password, pkey=key) + + # open ssh session channel = transport.open_session() # Make proxy objects for read @@ -266,29 +344,48 @@ class SSHClient(object): result = {} result['exit_code'] = channel.recv_exit_status() - result['stdout'] = stdout.read() - result['stderr'] = stderr.read() + result['stdout'] = stdout.readlines() + result['stderr'] = stderr.readlines() channel.close() - result['stdout_str'] = ''.join(result['stdout']).strip() - result['stderr_str'] = ''.join(result['stderr']).strip() + result['stdout_str'] = self._get_str_from_list(result['stdout']) + result['stderr_str'] = self._get_str_from_list(result['stderr']) return result def mkdir(self, path): + """run 'mkdir -p path' on remote + + :type path: str + """ if self.exists(path): return logger.debug("Creating directory: {}".format(path)) self.execute("mkdir -p {}\n".format(path)) def rm_rf(self, path): + """run 'rm -rf path' on remote + + :type path: str + """ logger.debug("rm -rf {}".format(path)) - self.execute("rm -rf %s" % path) + self.execute("rm -rf {}".format(path)) def open(self, path, mode='r'): + """Open file on remote using SFTP session + + :type path: str + :type mode: str + :return: file.open() stream + """ return self._sftp.open(path, mode) def upload(self, source, target): + """Upload file(s) from source to target using SFTP session + + :type source: str + :type target: str + """ logger.debug("Copying '%s' -> '%s'", source, target) if self.isdir(target): @@ -315,6 +412,12 @@ class SSHClient(object): self._sftp.put(local_path, remote_path) def download(self, destination, target): + """Download file(s) to target from destination + + :type destination: str + :type target: str + :rtype: bool + """ logger.debug( "Copying '%s' -> '%s' from remote to local host", destination, target @@ -337,6 +440,11 @@ class SSHClient(object): return os.path.exists(target) def exists(self, path): + """Check for file existence using SFTP session + + :type path: str + :rtype: bool + """ try: self._sftp.lstat(path) return True @@ -344,6 +452,11 @@ class SSHClient(object): return False def isfile(self, path): + """Check, that path is file using SFTP session + + :type path: str + :rtype: bool + """ try: attrs = self._sftp.lstat(path) return attrs.st_mode & stat.S_IFREG != 0 @@ -351,6 +464,11 @@ class SSHClient(object): return False def isdir(self, path): + """Check, that path is directory using SFTP session + + :type path: str + :rtype: bool + """ try: attrs = self._sftp.lstat(path) return attrs.st_mode & stat.S_IFDIR != 0 diff --git a/devops/tests/helpers/test_ssh_client.py b/devops/tests/helpers/test_ssh_client.py index 1e7dfeb7..5b22e0b2 100644 --- a/devops/tests/helpers/test_ssh_client.py +++ b/devops/tests/helpers/test_ssh_client.py @@ -14,7 +14,6 @@ # pylint: disable=no-self-use -from contextlib import closing from os.path import basename import posixpath import stat @@ -22,8 +21,7 @@ from unittest import TestCase import mock import paramiko -# noinspection PyUnresolvedReferences -from six.moves import cStringIO +from six import PY2 from devops.error import DevopsCalledProcessError from devops.helpers.ssh_client import SSHClient @@ -32,18 +30,14 @@ from devops.helpers.ssh_client import SSHClient def gen_private_keys(amount=1): keys = [] for _ in range(amount): - with closing(cStringIO()) as output: - paramiko.RSAKey.generate(1024).write_private_key(output) - keys.append(output.getvalue()) + keys.append(paramiko.RSAKey.generate(1024)) return keys def gen_public_key(private_key=None): if private_key is None: - key = paramiko.RSAKey.generate(1024) - else: - key = paramiko.RSAKey(file_obj=cStringIO(private_key)) - return '{0} {1}'.format(key.get_name(), key.get_base64()) + private_key = paramiko.RSAKey.generate(1024) + return '{0} {1}'.format(private_key.get_name(), private_key.get_base64()) host = '127.0.0.1' @@ -104,56 +98,6 @@ class TestSSHClient(TestCase): sftp = ssh._sftp self.assertEqual(sftp, client().open_sftp()) - def test_init_change_login_passwd(self, client, policy, logger): - _ssh = mock.call() - - def check_expected_on_connect(pwd): - client.assert_called_once() - policy.assert_called_once() - - expected_calls = [ - _ssh, - _ssh.set_missing_host_key_policy('AutoAddPolicy'), - _ssh.connect( - host, password=pwd, - port=port, username=username), - _ssh.open_sftp() - ] - - self.assertIn(expected_calls, client.mock_calls) - - self.check_defaults(ssh, host, port, username, pwd, - private_keys) - self.assertIsNone(ssh.private_key) - self.assertIsNone(ssh.public_key) - - self.assertIn( - mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format( - host, port, username, pwd - )), - logger.mock_calls - ) - sftp = ssh._sftp - self.assertEqual(sftp, client().open_sftp()) - - ssh = SSHClient( - host=host, - port=port, - username=username, - password=password, - private_keys=private_keys) - - check_expected_on_connect(pwd=password) - - new_password = 'password2' - - client.reset_mock() - policy.reset_mock() - logger.reset_mock() - ssh.password = new_password - - check_expected_on_connect(pwd=new_password) - def test_init_keys(self, client, policy, logger): _ssh = mock.call() @@ -270,24 +214,24 @@ class TestSSHClient(TestCase): sftp = ssh._sftp self.assertEqual(sftp, _sftp) - def init_ssh(self, client, policy, logger): - ssh = SSHClient( + +@mock.patch('devops.helpers.ssh_client.logger', autospec=True) +@mock.patch( + 'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy') +@mock.patch('paramiko.SSHClient', autospec=True) +class TestExecute(TestCase): + @staticmethod + def get_ssh(): + """SSHClient object builder for execution tests + + :rtype: SSHClient + """ + return SSHClient( host=host, port=port, username=username, - password=password, - private_keys=private_keys) - - client.assert_called_once() - policy.assert_called_once() - - self.assertIn( - mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format( - host, port, username, password - )), - logger.mock_calls - ) - return ssh + password=password + ) def test_execute_async(self, client, policy, logger): chan = mock.Mock() @@ -299,7 +243,7 @@ class TestSSHClient(TestCase): _ssh.attach_mock(get_transport, 'get_transport') client.return_value = _ssh - ssh = self.init_ssh(client, policy, logger) + ssh = self.get_ssh() result = ssh.execute_async(command=command) get_transport.assert_called_once() @@ -328,7 +272,7 @@ class TestSSHClient(TestCase): _ssh.attach_mock(get_transport, 'get_transport') client.return_value = _ssh - ssh = self.init_ssh(client, policy, logger) + ssh = self.get_ssh() ssh.sudo_mode = True result = ssh.execute_async(command=command) @@ -358,7 +302,7 @@ class TestSSHClient(TestCase): _ssh.attach_mock(get_transport, 'get_transport') client.return_value = _ssh - ssh = self.init_ssh(client, policy, logger) + ssh = self.get_ssh() self.assertFalse(ssh.sudo_mode) with SSHClient.get_sudo(ssh): self.assertTrue(ssh.sudo_mode) @@ -381,60 +325,150 @@ class TestSSHClient(TestCase): logger.mock_calls ) - @mock.patch( - 'devops.helpers.ssh_client.SSHClient.execute_async') - def test_execute(self, execute_async, client, policy, logger): - stderr = [' ', '0', '1', ' '] - stdout = [' ', '2', '3', ' '] - exit_code = 0 + def test_execute_async_sudo_password( + self, client, policy, logger): + stdin = mock.Mock(name='stdin') + stdout = mock.Mock(name='stdout') + stdout_channel = mock.Mock() + stdout_channel.configure_mock(closed=False) + stdout.attach_mock(stdout_channel, 'channel') + makefile = mock.Mock(side_effect=[stdin, stdout]) + chan = mock.Mock() + chan.attach_mock(makefile, 'makefile') + open_session = mock.Mock(return_value=chan) + transport = mock.Mock() + transport.attach_mock(open_session, 'open_session') + get_transport = mock.Mock(return_value=transport) + _ssh = mock.Mock() + _ssh.attach_mock(get_transport, 'get_transport') + client.return_value = _ssh + + ssh = self.get_ssh() + ssh.sudo_mode = True + + result = ssh.execute_async(command=command) + get_transport.assert_called_once() + open_session.assert_called_once() + # raise ValueError(closed.mock_calls) + stdin.assert_has_calls((mock.call.flush(), )) + + self.assertIn(chan, result) + chan.assert_has_calls(( + mock.call.makefile('wb'), + mock.call.makefile('rb'), + mock.call.makefile_stderr('rb'), + mock.call.exec_command('sudo -S bash -c "{}\n"'.format(command)) + )) + self.assertIn( + mock.call.debug( + "Executing command: '{}'".format(command.rstrip())), + logger.mock_calls + ) + + @staticmethod + def get_patched_execute_async_retval(ec=0, stderr_val=True): + stderr = mock.Mock() + stdout = mock.Mock() + + stderr_readlines = mock.Mock( + return_value=[b' \n', b'0\n', b'1\n', b' \n'] if stderr_val else [] + ) + stdout_readlines = mock.Mock( + return_value=[b' \n', b'2\n', b'3\n', b' \n']) + + stderr.attach_mock(stderr_readlines, 'readlines') + stdout.attach_mock(stdout_readlines, 'readlines') + + exit_code = ec chan = mock.Mock() recv_exit_status = mock.Mock(return_value=exit_code) chan.attach_mock(recv_exit_status, 'recv_exit_status') - execute_async.return_value = chan, '', stderr, stdout + return chan, '', stderr, stdout - ssh = self.init_ssh(client, policy, logger) + @mock.patch( + 'devops.helpers.ssh_client.SSHClient.execute_async') + def test_execute(self, execute_async, client, policy, logger): + chan, _stdin, stderr, stdout = self.get_patched_execute_async_retval() + execute_async.return_value = chan, _stdin, stderr, stdout + + stderr_lst = stderr.readlines() + stdout_lst = stdout.readlines() + + expected = { + 'exit_code': chan.recv_exit_status(), + 'stderr': stderr_lst, + 'stdout': stdout_lst} + if PY2: + expected['stderr_str'] = b''.join(stderr_lst).strip() + expected['stdout_str'] = b''.join(stdout_lst).strip() + else: + expected['stderr_str'] = b''.join(stderr_lst).strip().decode( + encoding='utf-8') + expected['stdout_str'] = b''.join(stdout_lst).strip().decode( + encoding='utf-8') + + ssh = self.get_ssh() + + logger.reset_mock() + + result = ssh.execute(command=command, verbose=True) - result = ssh.execute(command=command) self.assertEqual( result, - { - 'stderr_str': ''.join(stderr).strip(), - 'stdout_str': ''.join(stdout).strip(), - 'exit_code': exit_code, - 'stderr': stderr, - 'stdout': stdout} + expected ) execute_async.assert_called_once_with(command) chan.assert_has_calls(( mock.call.recv_exit_status(), mock.call.close())) + logger.assert_has_calls(( + mock.call.info( + '{cmd} execution results:\n' + 'Exit code: {code}\n' + 'STDOUT:\n' + '{stdout}\n' + 'STDERR:\n' + '{stderr}'.format( + cmd=command, + code=result['exit_code'], + stdout=result['stdout_str'], + stderr=result['stderr_str'] + )), + )) @mock.patch( 'devops.helpers.ssh_client.SSHClient.execute_async') def test_execute_together(self, execute_async, client, policy, logger): - stderr = [' ', '0', '1', ' '] - stdout = [' ', '2', '3', ' '] - exit_code = 0 - chan = mock.Mock() - recv_exit_status = mock.Mock(return_value=exit_code) - chan.attach_mock(recv_exit_status, 'recv_exit_status') - execute_async.return_value = chan, '', stderr, stdout + chan, _stdin, stderr, stdout = self.get_patched_execute_async_retval() + execute_async.return_value = chan, _stdin, stderr, stdout + + stderr_lst = stderr.readlines() + stdout_lst = stdout.readlines() + + expected = { + 'exit_code': chan.recv_exit_status(), + 'stderr': stderr_lst, + 'stdout': stdout_lst} + if PY2: + expected['stderr_str'] = b''.join(stderr_lst).strip() + expected['stdout_str'] = b''.join(stdout_lst).strip() + else: + expected['stderr_str'] = b''.join(stderr_lst).strip().decode( + encoding='utf-8') + expected['stdout_str'] = b''.join(stdout_lst).strip().decode( + encoding='utf-8') + host2 = '127.0.0.2' - ssh = SSHClient( - host=host, - port=port, - username=username, - password=password, - private_keys=private_keys) + ssh = self.get_ssh() ssh2 = SSHClient( host=host2, port=port, username=username, - password=password, - private_keys=private_keys) + password=password + ) - remotes = ssh, ssh2 + remotes = [ssh, ssh2] SSHClient.execute_together( remotes=remotes, command=command) @@ -447,35 +481,35 @@ class TestSSHClient(TestCase): mock.call.close() )) + SSHClient.execute_together( + remotes=remotes, command=command, expected=[1], raise_on_err=False) + + with self.assertRaises(DevopsCalledProcessError): + SSHClient.execute_together( + remotes=remotes, command=command, expected=[1]) + @mock.patch( 'devops.helpers.ssh_client.SSHClient.execute') def test_check_call(self, execute, client, policy, logger): - stderr = [' ', '0', '1', ' '] - stdout = [' ', '2', '3', ' '] exit_code = 0 return_value = { - 'stderr_str': ''.join(stderr).strip(), - 'stdout_str': ''.join(stdout).strip(), + 'stderr_str': '0\n1', + 'stdout_str': '2\n3', 'exit_code': exit_code, - 'stderr': stderr, - 'stdout': stdout} + 'stderr': [b' \n', b'0\n', b'1\n', b' \n'], + 'stdout': [b' \n', b'2\n', b'3\n', b' \n']} execute.return_value = return_value verbose = False - ssh = self.init_ssh(client, policy, logger) + ssh = self.get_ssh() result = ssh.check_call(command=command, verbose=verbose) execute.assert_called_once_with(command, verbose) self.assertEqual(result, return_value) exit_code = 1 - return_value = { - 'stderr_str': ''.join(stderr).strip(), - 'stdout_str': ''.join(stdout).strip(), - 'exit_code': exit_code, - 'stderr': stderr, - 'stdout': stdout} + return_value['exit_code'] = exit_code execute.reset_mock() execute.return_value = return_value with self.assertRaises(DevopsCalledProcessError): @@ -485,41 +519,47 @@ class TestSSHClient(TestCase): @mock.patch( 'devops.helpers.ssh_client.SSHClient.check_call') def test_check_stderr(self, check_call, client, policy, logger): - stdout = [' ', '0', '1', ' '] - stderr = [] - exit_code = 0 return_value = { - 'stderr_str': ''.join(stderr).strip(), - 'stdout_str': ''.join(stdout).strip(), - 'exit_code': exit_code, - 'stderr': stderr, - 'stdout': stdout} + 'stderr_str': '', + 'stdout_str': '2\n3', + 'exit_code': 0, + 'stderr': [], + 'stdout': [b' \n', b'2\n', b'3\n', b' \n']} check_call.return_value = return_value verbose = False + raise_on_err = True - ssh = self.init_ssh(client, policy, logger) + ssh = self.get_ssh() - result = ssh.check_stderr(command=command, verbose=verbose) - check_call.assert_called_once_with(command, verbose) + result = ssh.check_stderr( + command=command, verbose=verbose, raise_on_err=raise_on_err) + check_call.assert_called_once_with( + command, verbose, raise_on_err=raise_on_err) self.assertEqual(result, return_value) - stderr = [' ', '2', '3', ' '] - return_value = { - 'stderr_str': ''.join(stderr).strip(), - 'stdout_str': ''.join(stdout).strip(), - 'exit_code': exit_code, - 'stderr': stderr, - 'stdout': stdout} + return_value['stderr_str'] = '0\n1' + return_value['stderr'] = [b' \n', b'0\n', b'1\n', b' \n'] + check_call.reset_mock() check_call.return_value = return_value with self.assertRaises(DevopsCalledProcessError): - ssh.check_stderr(command=command, verbose=verbose) - check_call.assert_called_once_with(command, verbose) + ssh.check_stderr( + command=command, verbose=verbose, raise_on_err=raise_on_err) + check_call.assert_called_once_with( + command, verbose, raise_on_err=raise_on_err) - def prepare_execute_through_host( - self, transp, client, policy, logger): + +@mock.patch('devops.helpers.ssh_client.logger', autospec=True) +@mock.patch( + 'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy') +@mock.patch('paramiko.SSHClient', autospec=True) +@mock.patch('paramiko.Transport', autospec=True) +class TestExecuteThrowHost(TestCase): + @staticmethod + def prepare_execute_through_host(transp, client, exit_code): intermediate_channel = mock.Mock() + open_channel = mock.Mock(return_value=intermediate_channel) intermediate_transport = mock.Mock() intermediate_transport.attach_mock(open_channel, 'open_channel') @@ -532,22 +572,16 @@ class TestSSHClient(TestCase): transport = mock.Mock() transp.return_value = transport - stderr = [' ', '0', '1', ' '] - stdout = [' ', '2', '3', ' '] - exit_code = 0 - return_value = { - 'stderr_str': ''.join(stderr).strip(), - 'stdout_str': ''.join(stdout).strip(), - 'exit_code': exit_code, - 'stderr': stderr, - 'stdout': stdout} - recv_exit_status = mock.Mock(return_value=exit_code) makefile = mock.Mock() - makefile.attach_mock(mock.Mock(return_value=stdout), 'read') + makefile.attach_mock(mock.Mock( + return_value=[b' \n', b'2\n', b'3\n', b' \n']), + 'readlines') makefile_stderr = mock.Mock() - makefile_stderr.attach_mock(mock.Mock(return_value=stderr), 'read') + makefile_stderr.attach_mock( + mock.Mock(return_value=[b' \n', b'0\n', b'1\n', b' \n']), + 'readlines') channel = mock.Mock() channel.attach_mock(mock.Mock(return_value=makefile), 'makefile') channel.attach_mock(mock.Mock( @@ -556,100 +590,34 @@ class TestSSHClient(TestCase): open_session = mock.Mock(return_value=channel) transport.attach_mock(open_session, 'open_session') - ssh = self.init_ssh(client, policy, logger) return ( - ssh, return_value, open_session, transport, channel, get_transport, + open_session, transport, channel, get_transport, open_channel, intermediate_channel ) - @mock.patch('paramiko.Transport', autospec=True) def test_execute_through_host_no_creds( self, transp, client, policy, logger): - target = '10.0.0.2' - - ( - ssh, return_value, open_session, transport, channel, get_transport, - open_channel, intermediate_channel - ) = self.prepare_execute_through_host(transp, client, policy, logger) - - result = ssh.execute_through_host(target, command) - self.assertEqual(result, return_value) - get_transport.assert_called_once() - open_channel.assert_called_once() - transp.assert_called_once_with(intermediate_channel) - open_session.assert_called_once() - transport.assert_has_calls(( - mock.call.start_client(), - mock.call.auth_password(username=username, password=password), - mock.call.open_session() - )) - channel.assert_has_calls(( - mock.call.makefile('rb'), - mock.call.makefile_stderr('rb'), - mock.call.exec_command('ls ~ '), - mock.call.recv_exit_status(), - mock.call.close() - )) - - @mock.patch('paramiko.Transport', autospec=True) - def test_execute_through_host_no_creds_key( - self, transp, client, policy, logger): - target = '10.0.0.2' - private_keys = gen_private_keys(1) - - intermediate_channel = mock.Mock() - open_channel = mock.Mock(return_value=intermediate_channel) - intermediate_transport = mock.Mock() - intermediate_transport.attach_mock(open_channel, 'open_channel') - get_transport = mock.Mock(return_value=intermediate_transport) - - _ssh = mock.Mock() - _ssh.attach_mock(get_transport, 'get_transport') - client.return_value = _ssh - - transport = mock.Mock() - transp.return_value = transport - - stderr = [' ', '0', '1', ' '] - stdout = [' ', '2', '3', ' '] + target = '127.0.0.2' exit_code = 0 return_value = { - 'stderr_str': ''.join(stderr).strip(), - 'stdout_str': ''.join(stdout).strip(), + 'stderr_str': '0\n1', + 'stdout_str': '2\n3', 'exit_code': exit_code, - 'stderr': stderr, - 'stdout': stdout} + 'stderr': [b' \n', b'0\n', b'1\n', b' \n'], + 'stdout': [b' \n', b'2\n', b'3\n', b' \n']} - recv_exit_status = mock.Mock(return_value=exit_code) - - makefile = mock.Mock() - makefile.attach_mock(mock.Mock(return_value=stdout), 'read') - makefile_stderr = mock.Mock() - makefile_stderr.attach_mock(mock.Mock(return_value=stderr), 'read') - channel = mock.Mock() - channel.attach_mock(mock.Mock(return_value=makefile), 'makefile') - channel.attach_mock(mock.Mock( - return_value=makefile_stderr), 'makefile_stderr') - channel.attach_mock(recv_exit_status, 'recv_exit_status') - open_session = mock.Mock(return_value=channel) - transport.attach_mock(open_session, 'open_session') + ( + open_session, transport, channel, get_transport, + open_channel, intermediate_channel + ) = self.prepare_execute_through_host( + transp, client, exit_code=exit_code) ssh = SSHClient( host=host, port=port, username=username, - password=password, - private_keys=private_keys) - - client.assert_called_once() - policy.assert_called_once() - - self.assertIn( - mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format( - host, port, username, password - )), - logger.mock_calls - ) + password=password + ) result = ssh.execute_through_host(target, command) self.assertEqual(result, return_value) @@ -658,9 +626,7 @@ class TestSSHClient(TestCase): transp.assert_called_once_with(intermediate_channel) open_session.assert_called_once() transport.assert_has_calls(( - mock.call.start_client(), - mock.call.auth_publickey( - username=username, key=private_keys[0]), + mock.call.connect(username=username, password=password, pkey=None), mock.call.open_session() )) channel.assert_has_calls(( @@ -671,93 +637,43 @@ class TestSSHClient(TestCase): mock.call.close() )) - @mock.patch('paramiko.Transport', autospec=True) - def test_execute_through_host_password( + def test_execute_through_host_auth( self, transp, client, policy, logger): - target = '10.0.0.2' _login = 'cirros' _password = 'cubswin:)' - ( - ssh, return_value, open_session, transport, channel, get_transport, - open_channel, intermediate_channel - ) = self.prepare_execute_through_host(transp, client, policy, logger) - - result = ssh.execute_through_host( - target, command, username=_login, password=_password) - self.assertEqual(result, return_value) - get_transport.assert_called_once() - open_channel.assert_called_once() - transp.assert_called_once_with(intermediate_channel) - open_session.assert_called_once() - transport.assert_has_calls(( - mock.call.start_client(), - mock.call.auth_password(username=_login, password=_password), - mock.call.open_session() - )) - channel.assert_has_calls(( - mock.call.makefile('rb'), - mock.call.makefile_stderr('rb'), - mock.call.exec_command('ls ~ '), - mock.call.recv_exit_status(), - mock.call.close() - )) - - @mock.patch('paramiko.Transport', autospec=True) - def test_execute_through_host_no_auth( - self, transp, client, policy, logger): - target = '10.0.0.2' - _login = 'cirros' + target = '127.0.0.2' + exit_code = 0 + return_value = { + 'stderr_str': '0\n1', + 'stdout_str': '2\n3', + 'exit_code': exit_code, + 'stderr': [b' \n', b'0\n', b'1\n', b' \n'], + 'stdout': [b' \n', b'2\n', b'3\n', b' \n']} ( - ssh, return_value, open_session, transport, channel, get_transport, + open_session, transport, channel, get_transport, open_channel, intermediate_channel - ) = self.prepare_execute_through_host(transp, client, policy, logger) + ) = self.prepare_execute_through_host( + transp, client, exit_code=exit_code) - result = ssh.execute_through_host( - target, command, username=_login) - self.assertEqual(result, return_value) - get_transport.assert_called_once() - open_channel.assert_called_once() - transp.assert_called_once_with(intermediate_channel) - open_session.assert_called_once() - transport.assert_has_calls(( - mock.call.start_client(), - mock.call.auth_none(username=_login), - mock.call.open_session() - )) - channel.assert_has_calls(( - mock.call.makefile('rb'), - mock.call.makefile_stderr('rb'), - mock.call.exec_command('ls ~ '), - mock.call.recv_exit_status(), - mock.call.close() - )) - - @mock.patch('paramiko.Transport', autospec=True) - def test_execute_through_host_key( - self, transp, client, policy, logger): - target = '10.0.0.2' - _login = 'cirros' - _password = 'cubswin:)' - key = gen_private_keys(1)[0] - - ( - ssh, return_value, open_session, transport, channel, get_transport, - open_channel, intermediate_channel - ) = self.prepare_execute_through_host(transp, client, policy, logger) + ssh = SSHClient( + host=host, + port=port, + username=username, + password=password + ) result = ssh.execute_through_host( target, command, - username=_login, password=_password, key=key) + username=_login, password=_password) self.assertEqual(result, return_value) get_transport.assert_called_once() open_channel.assert_called_once() transp.assert_called_once_with(intermediate_channel) open_session.assert_called_once() transport.assert_has_calls(( - mock.call.start_client(), - mock.call.auth_publickey(username=_login, key=key), + mock.call.connect(username=_login, password=_password, pkey=None), mock.call.open_session() )) channel.assert_has_calls(( @@ -768,7 +684,14 @@ class TestSSHClient(TestCase): mock.call.close() )) - def prepare_sftp_file_tests(self, client, policy, logger): + +@mock.patch('devops.helpers.ssh_client.logger', autospec=True) +@mock.patch( + 'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy') +@mock.patch('paramiko.SSHClient', autospec=True) +class TestSftp(TestCase): + @staticmethod + def prepare_sftp_file_tests(client): _ssh = mock.Mock() client.return_value = _ssh _sftp = mock.Mock() @@ -779,24 +702,12 @@ class TestSSHClient(TestCase): host=host, port=port, username=username, - password=password, - private_keys=private_keys) - - client.assert_called_once() - policy.assert_called_once() - - self.check_defaults(ssh, host, port, username, password, private_keys) - - self.assertIn( - mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format( - host, port, username, password - )), - logger.mock_calls - ) + password=password + ) return ssh, _sftp def test_exists(self, client, policy, logger): - ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) + ssh, _sftp = self.prepare_sftp_file_tests(client) lstat = mock.Mock() _sftp.attach_mock(lstat, 'lstat') path = '/etc' @@ -818,7 +729,7 @@ class TestSSHClient(TestCase): def __init__(self, mode): self.st_mode = mode - ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) + ssh, _sftp = self.prepare_sftp_file_tests(client) lstat = mock.Mock() _sftp.attach_mock(lstat, 'lstat') lstat.return_value = Attrs(stat.S_IFREG) @@ -848,7 +759,7 @@ class TestSSHClient(TestCase): def __init__(self, mode): self.st_mode = mode - ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) + ssh, _sftp = self.prepare_sftp_file_tests(client) lstat = mock.Mock() _sftp.attach_mock(lstat, 'lstat') lstat.return_value = Attrs(stat.S_IFDIR) @@ -875,11 +786,16 @@ class TestSSHClient(TestCase): @mock.patch('devops.helpers.ssh_client.SSHClient.exists') @mock.patch('devops.helpers.ssh_client.SSHClient.execute') def test_mkdir(self, execute, exists, client, policy, logger): - exists.return_value = False + exists.side_effect = [False, True] path = '~/tst' - ssh = self.init_ssh(client, policy, logger) + ssh = SSHClient( + host=host, + port=port, + username=username, + password=password + ) # Path not exists ssh.mkdir(path) @@ -888,7 +804,6 @@ class TestSSHClient(TestCase): # Path exists exists.reset_mock() - exists.return_value = True execute.reset_mock() ssh.mkdir(path) @@ -899,14 +814,19 @@ class TestSSHClient(TestCase): def test_rm_rf(self, execute, client, policy, logger): path = '~/tst' - ssh = self.init_ssh(client, policy, logger) + ssh = SSHClient( + host=host, + port=port, + username=username, + password=password + ) # Path not exists ssh.rm_rf(path) execute.assert_called_once_with("rm -rf {}".format(path)) def test_open(self, client, policy, logger): - ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) + ssh, _sftp = self.prepare_sftp_file_tests(client) fopen = mock.Mock(return_value=True) _sftp.attach_mock(fopen, 'open') @@ -924,11 +844,11 @@ class TestSSHClient(TestCase): self, isdir, remote_isdir, exists, remote_exists, client, policy, logger ): - ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) + ssh, _sftp = self.prepare_sftp_file_tests(client) isdir.return_value = True - exists.return_value = True - remote_isdir.return_value = False - remote_exists.return_value = True + exists.side_effect = [True, False, False] + remote_isdir.side_effect = [False, False, True] + remote_exists.side_effect = [True, False, False] dst = '/etc/environment' target = '/tmp/environment' @@ -942,12 +862,40 @@ class TestSSHClient(TestCase): mock.call.get(dst, posixpath.join(target, basename(dst))), )) + # Negative scenarios + logger.reset_mock() + result = ssh.download(destination=dst, target=target) + logger.assert_has_calls(( + mock.call.debug( + "Copying '%s' -> '%s' from remote to local host", + '/etc/environment', + '/tmp/environment'), + mock.call.debug( + "Can't download %s because it doesn't exist", + '/etc/environment' + ), + )) + self.assertFalse(result) + + logger.reset_mock() + result = ssh.download(destination=dst, target=target) + logger.assert_has_calls(( + mock.call.debug( + "Copying '%s' -> '%s' from remote to local host", + '/etc/environment', + '/tmp/environment'), + mock.call.debug( + "Can't download %s because it is a directory", + '/etc/environment' + ), + )) + @mock.patch('devops.helpers.ssh_client.SSHClient.isdir') @mock.patch('os.path.isdir', autospec=True) def test_upload_file( self, isdir, remote_isdir, client, policy, logger ): - ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) + ssh, _sftp = self.prepare_sftp_file_tests(client) isdir.return_value = False remote_isdir.return_value = False target = '/etc/environment' @@ -970,7 +918,7 @@ class TestSSHClient(TestCase): isdir, remote_isdir, walk, mkdir, exists, client, policy, logger ): - ssh, _sftp = self.prepare_sftp_file_tests(client, policy, logger) + ssh, _sftp = self.prepare_sftp_file_tests(client) isdir.return_value = True remote_isdir.return_value = True exists.return_value = True diff --git a/tox.ini b/tox.ini index 5d820337..e583e83d 100644 --- a/tox.ini +++ b/tox.ini @@ -28,7 +28,7 @@ deps = -r{toxinidir}/test-requirements.txt commands = py.test --cov-config .coveragerc --cov-report html --cov=devops devops/tests - coverage report --fail-under 70 + coverage report --fail-under 73 [testenv:pep8]