From aae58e25dacca47b01947b384a84464a48a091e1 Mon Sep 17 00:00:00 2001 From: Alexey Stepanov Date: Fri, 27 May 2016 19:44:50 +0300 Subject: [PATCH] SSHClient rework for SSH Manager integration 1. Implemented SSHAuth 2. Old initialization API has been marked as deprecated 3. SFTP is started on demand with 3 retries 4. Reworked unit test to cover 100% 5. Added docstrings 6. Remove cyclic SSH session initialization in helper 7. Code is ready for adopted memorize pattern blueprint: sshmanager-integration Change-Id: I49d0aa635ba3f3125ab17531c0790a0106b87fea --- .coveragerc | 2 +- devops/helpers/helpers.py | 16 +- devops/helpers/ssh_client.py | 380 ++++++++--- devops/models/environment.py | 29 +- devops/models/node.py | 6 +- devops/tests/helpers/test_helpers.py | 12 +- devops/tests/helpers/test_ssh_client.py | 850 +++++++++++++++++++----- 7 files changed, 1044 insertions(+), 251 deletions(-) diff --git a/.coveragerc b/.coveragerc index 60dcd04e..22748a24 100644 --- a/.coveragerc +++ b/.coveragerc @@ -6,4 +6,4 @@ omit = devops/migrations/* devops/driver/dummy/* devops/settings.py - devops/test_settings.py + devops/test_settings.py \ No newline at end of file diff --git a/devops/helpers/helpers.py b/devops/helpers/helpers.py index 283fca92..74d70374 100644 --- a/devops/helpers/helpers.py +++ b/devops/helpers/helpers.py @@ -37,6 +37,7 @@ from six.moves import xmlrpc_client from devops.error import AuthenticationError from devops.error import DevopsError from devops.error import TimeoutError +from devops.helpers.ssh_client import SSHAuth from devops.helpers.ssh_client import SSHClient from devops import logger from devops.settings import KEYSTONE_CREDS @@ -151,8 +152,9 @@ def wait_ssh_cmd( password=SSH_CREDENTIALS['password'], timeout=0): ssh_client = SSHClient(host=host, port=port, - username=username, - password=password) + auth=SSHAuth( + username=username, + password=password)) wait(lambda: not ssh_client.execute(check_cmd)['exit_code'], timeout=timeout) @@ -198,10 +200,12 @@ def get_node_remote(env, node_name, login=SSH_SLAVE_CREDENTIALS['login'], name=node_name).interfaces[0].mac_address) wait(lambda: tcp_ping(ip, 22), timeout=180, timeout_msg="Node {ip} is not accessible by SSH.".format(ip=ip)) - return SSHClient(ip, - username=login, - password=password, - private_keys=get_private_keys(env)) + return SSHClient( + ip, + auth=SSHAuth( + username=login, + password=password, + keys=get_private_keys(env))) def get_admin_ip(env): diff --git a/devops/helpers/ssh_client.py b/devops/helpers/ssh_client.py index bf1e14dc..d977f3dc 100644 --- a/devops/helpers/ssh_client.py +++ b/devops/helpers/ssh_client.py @@ -15,6 +15,7 @@ import os import posixpath import stat +from warnings import warn import paramiko import six @@ -24,7 +25,163 @@ from devops.helpers.retry import retry from devops import logger +class SSHAuth(object): + __slots__ = ['__username', '__password', '__key', '__keys'] + + def __init__( + self, + username=None, password=None, key=None, keys=None): + """SSH authorisation object + + Used to authorize SSHClient. + Single SSHAuth object is associated with single host:port. + Password and key is private, other data is read-only. + + :type username: str + :type password: str + :type key: paramiko.RSAKey + :type keys: list + """ + self.__username = username + self.__password = password + self.__key = key + self.__keys = [None] + if key is not None: + self.__keys.append(key) + if keys is not None: + for key in keys: + if key not in self.__keys: + self.__keys.append(key) + + @property + def username(self): + """Username for auth + + :rtype: str + """ + return self.__username + + @staticmethod + def __get_public_key(key): + """Internal method for get public key from private + + :type key: paramiko.RSAKey + """ + if key is None: + return None + return '{0} {1}'.format(key.get_name(), key.get_base64()) + + @property + def public_key(self): + """public key for stored private key if presents else None + + :rtype: str + """ + return self.__get_public_key(self.__key) + + def enter_password(self, tgt): + """Enter password to STDIN + + Note: required for 'sudo' call + + :type tgt: file + :rtype: str + """ + return tgt.write('{}\n'.format(self.__password)) + + def connect(self, client, hostname=None, port=22, log=True): + """Connect SSH client object using credentials + + :type client: + paramiko.client.SSHClient + paramiko.transport.Transport + :type log: bool + :raises paramiko.AuthenticationException + """ + kwargs = { + 'username': self.username, + 'password': self.__password} + if hostname is not None: + kwargs['hostname'] = hostname + kwargs['port'] = port + + keys = [self.__key] + keys.extend([k for k in self.__keys if k != self.__key]) + + for key in keys: + kwargs['pkey'] = key + try: + client.connect(**kwargs) + if self.__key != key: + self.__key = key + logger.debug( + 'Main key has been updated, public key is: \n' + '{}'.format(self.public_key)) + return + except paramiko.PasswordRequiredException: + if self.__password is None: + logger.exception('No password has been set!') + raise + else: + logger.critical( + 'Unexpected PasswordRequiredException, ' + 'when password is set!') + raise + except paramiko.AuthenticationException: + continue + msg = 'Connection using stored authentication info failed!' + if log: + logger.exception( + 'Connection using stored authentication info failed!') + raise paramiko.AuthenticationException(msg) + + def __hash__(self): + return hash(( + self.__class__, + self.username, + self.__password, + tuple(self.__keys) + )) + + def __eq__(self, other): + return hash(self) == hash(other) + + def __repr__(self): + _key = ( + None if self.__key is None else + ''.format(self.public_key) + ) + _keys = [] + for k in self.__keys: + if k == self.__key: + continue + _keys.append( + ''.format( + self.__get_public_key(key=k)) if k is not None else None) + + return ( + '{cls}(username={username}, ' + 'password=<*masked*>, key={key}, keys={keys})'.format( + cls=self.__class__.__name__, + username=self.username, + key=_key, + keys=_keys) + ) + + def __str__(self): + return ( + '{cls} for {username}'.format( + cls=self.__class__.__name__, + username=self.username, + ) + ) + + class SSHClient(object): + __slots__ = [ + '__hostname', '__port', '__auth', '__ssh', '__sftp', 'sudo_mode' + ] + class get_sudo(object): """Context manager for call commands with sudo""" def __init__(self, ssh): @@ -36,42 +193,128 @@ class SSHClient(object): def __exit__(self, exc_type, exc_val, exc_tb): self.ssh.sudo_mode = False - def __init__(self, host, port=22, username=None, password=None, - private_keys=None): - self.host = str(host) - self.port = int(port) - self.username = username - self.__password = password - if not private_keys: - private_keys = [] - self.__private_keys = private_keys - self.__actual_pkey = None + def __hash__(self): + return hash(( + self.__class__, + self.hostname, + self.port, + self.auth)) + + def __init__( + self, + host, port=22, + username=None, password=None, private_keys=None, + auth=None + ): + """SSHClient helper + + :type host: str + :type port: int + :type username: str + :type password: str + :type private_keys: list + :type auth: SSHAuth + """ + self.__hostname = host + self.__port = port self.sudo_mode = False - self.sudo = self.get_sudo(self) - self._ssh = None + self.__ssh = paramiko.SSHClient() + self.__ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) self.__sftp = None - self.reconnect() + self.__auth = auth + + if auth is None: + msg = ( + 'SSHClient(host={host}, port={port}, username={username}): ' + 'initialization by username/password/private_keys ' + 'is deprecated in favor of SSHAuth usage. ' + 'Please update your code'.format( + host=host, port=port, username=username + )) + warn(msg, DeprecationWarning) + logger.debug(msg) + + self.__auth = SSHAuth( + username=username, + password=password, + keys=private_keys + ) + + self.__connect() + if auth is None: + logger.info( + '{0}:{1}> SSHAuth was made from old style creds: ' + '{2}'.format(self.hostname, self.port, self.auth)) @property - def password(self): - return self.__password + def auth(self): + """Internal authorisation object + + Attention: this public property is mainly for inheritance, + debug and information purposes. + Calls outside SSHClient and child classes is sign of incorrect design. + Change is completely disallowed. + + :rtype: SSHAuth + """ + return self.__auth @property - def private_keys(self): - return self.__private_keys + def hostname(self): + """Connected remote host name + + :rtype: str + """ + return self.__hostname @property - def private_key(self): - return self.__actual_pkey + def port(self): + """Connected remote port number + + :rtype: int + """ + return self.__port + + def __repr__(self): + return '{cls}(host={host}, port={port}, auth={auth!r})'.format( + cls=self.__class__.__name__, host=self.hostname, port=self.port, + auth=self.auth + ) + + def __str__(self): + return '{cls}(host={host}, port={port}) for user {user}'.format( + cls=self.__class__.__name__, host=self.hostname, port=self.port, + user=self.auth.username + ) @property - def public_key(self): - if self.private_key is None: - return None - key = self.private_key - return '{0} {1}'.format(key.get_name(), key.get_base64()) + def _ssh(self): + """ssh client object getter for inheritance support only + + Attention: ssh client object creation and change + is allowed only by __init__ and reconnect call. + + :rtype: paramiko.SSHClient + """ + return self.__ssh + + @retry(count=3, delay=3) + def __connect(self): + """Main method for connection open""" + self.auth.connect( + client=self.__ssh, + hostname=self.hostname, port=self.port, + log=True) + + @retry(3, delay=0) + def __connect_sftp(self): + """SFTP connection opener""" + try: + self.__sftp = self.__ssh.open_sftp() + except paramiko.SSHException: + logger.warning('SFTP enable failed! SSH only is accessible.') @property def _sftp(self): @@ -81,25 +324,28 @@ class SSHClient(object): """ if self.__sftp is not None: return self.__sftp - logger.warning('SFTP is not connected, try to reconnect') - self._connect_sftp() + logger.debug('SFTP is not connected, try to connect...') + self.__connect_sftp() if self.__sftp is not None: return self.__sftp raise paramiko.SSHException('SFTP connection failed') def clear(self): - if self.__sftp is not None: - try: - self.__sftp.close() - except Exception: - logger.exception("Could not close sftp connection") + """Clear SSH and SFTP sessions""" try: - self._ssh.close() + self.__ssh.close() + self.__sftp = None except Exception: logger.exception("Could not close ssh connection") + if self.__sftp is not None: + try: + self.__sftp.close() + except Exception: + logger.exception("Could not close sftp connection") def __del__(self): - self.clear() + self.__ssh.close() + self.__sftp = None def __enter__(self): return self @@ -107,39 +353,14 @@ class SSHClient(object): def __exit__(self, exc_type, exc_val, exc_tb): self.clear() - @retry(count=3, delay=3) - def connect(self): - logger.debug( - "Connect to '{0}:{1}' as '{2}:{3}'".format( - self.host, self.port, self.username, self.password)) - for private_key in self.private_keys: - try: - self._ssh.connect( - self.host, port=self.port, username=self.username, - password=self.password, pkey=private_key) - self.__actual_pkey = private_key - return - except paramiko.AuthenticationException: - continue - if self.private_keys: - logger.error("Authentication with keys failed") - - self.__actual_pkey = None - self._ssh.connect( - self.host, port=self.port, username=self.username, - password=self.password) - - def _connect_sftp(self): - try: - self.__sftp = self._ssh.open_sftp() - except paramiko.SSHException: - logger.warning('SFTP enable failed! SSH only is accessible.') - def reconnect(self): - self._ssh = paramiko.SSHClient() - self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - self.connect() - self._connect_sftp() + """Reconnect SSH and SFTP session""" + self.clear() + + self.__ssh = paramiko.SSHClient() + self.__ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + self.__connect() def check_call( self, @@ -232,7 +453,7 @@ class SSHClient(object): ret = chan.recv_exit_status() chan.close() if ret not in expected: - errors[remote.host] = ret + errors[remote.hostname] = ret if errors and raise_on_err: raise DevopsCalledProcessError(command, errors) @@ -301,7 +522,7 @@ class SSHClient(object): cmd = 'sudo -S bash -c "%s"' % cmd.replace('"', '\\"') chan.exec_command(cmd) if stdout.channel.closed is False: - stdin.write('%s\n' % self.password) + self.auth.enter_password(stdin) stdin.flush() else: chan.exec_command(cmd) @@ -311,23 +532,27 @@ class SSHClient(object): self, hostname, cmd, - username=None, - password=None, - key=None, + auth=None, target_port=22): - if username is None and password is None and key is None: - username = self.username - password = self.__password - key = self.private_key + """Execute command on remote host through currently connected host + + :type hostname: str + :type cmd: str + :type auth: SSHAuth + :type target_port: int + :rtype: dict + """ + if auth is None: + auth = self.auth intermediate_channel = self._ssh.get_transport().open_channel( kind='direct-tcpip', dest_addr=(hostname, target_port), - src_addr=(self.host, 0)) + src_addr=(self.hostname, 0)) transport = paramiko.Transport(sock=intermediate_channel) # start client and authenticate transport - transport.connect(username=username, password=password, pkey=key) + auth.connect(transport) # open ssh session channel = transport.open_session() @@ -336,7 +561,6 @@ class SSHClient(object): stdout = channel.makefile('rb') stderr = channel.makefile_stderr('rb') - logger.info("Executing command: {}".format(cmd)) channel.exec_command(cmd) # TODO(astepanov): make a logic for controlling channel state @@ -474,3 +698,5 @@ class SSHClient(object): return attrs.st_mode & stat.S_IFDIR != 0 except IOError: return False + +__all__ = ['SSHAuth', 'SSHClient'] diff --git a/devops/models/environment.py b/devops/models/environment.py index d19c5887..8901a2ef 100644 --- a/devops/models/environment.py +++ b/devops/models/environment.py @@ -26,6 +26,7 @@ from devops.error import DevopsEnvironmentError from devops.error import DevopsError from devops.error import DevopsObjNotFound from devops.helpers.network import IpNetworksPool +from devops.helpers.ssh_client import SSHAuth from devops.helpers.ssh_client import SSHClient from devops.helpers.templates import create_devops_config from devops.helpers.templates import get_devops_config @@ -362,30 +363,36 @@ class Environment(BaseModel): key=lambda node: node.name )[0] return admin.remote( - self.admin_net, - login=login, - password=password) + self.admin_net, auth=SSHAuth( + username=login, + password=password)) # LEGACY, for fuel-qa compatibility # @logwrap def get_ssh_to_remote(self, ip, login=settings.SSH_SLAVE_CREDENTIALS['login'], password=settings.SSH_SLAVE_CREDENTIALS['password']): + warn('LEGACY, for fuel-qa compatibility', DeprecationWarning) keys = [] + remote = self.get_admin_remote() for key_string in ['/root/.ssh/id_rsa', '/root/.ssh/bootstrap.rsa']: - if self.get_admin_remote().isfile(key_string): - with self.get_admin_remote().open(key_string) as f: + if remote.isfile(key_string): + with remote.open(key_string) as f: keys.append(RSAKey.from_private_key(f)) - return SSHClient(ip, - username=login, - password=password, - private_keys=keys) + return SSHClient( + ip, + auth=SSHAuth( + username=login, + password=password, + keys=keys)) # LEGACY, for fuel-qa compatibility # @logwrap - def get_ssh_to_remote_by_key(self, ip, keyfile): + @staticmethod + def get_ssh_to_remote_by_key(ip, keyfile): + warn('LEGACY, for fuel-qa compatibility', DeprecationWarning) try: with open(keyfile) as f: keys = [RSAKey.from_private_key(f)] @@ -393,7 +400,7 @@ class Environment(BaseModel): logger.warning('Loading of SSH key from file failed. Trying to use' ' SSH agent ...') keys = Agent().get_keys() - return SSHClient(ip, private_keys=keys) + return SSHClient(ip, auth=SSHAuth(keys=keys)) # LEGACY, TO REMOVE (for fuel-qa compatibility) def nodes(self): # migrated from EnvironmentModel.nodes() diff --git a/devops/models/node.py b/devops/models/node.py index 543f2cbd..f2dc2d22 100644 --- a/devops/models/node.py +++ b/devops/models/node.py @@ -270,7 +270,9 @@ class Node(six.with_metaclass(ExtendableNodeType, ParamedModel, BaseModel)): interface = self.get_interface_by_nailgun_network_name(name) return interface.address_set.first().ip_address - def remote(self, network_name, login, password=None, private_keys=None): + def remote( + self, network_name, login=None, password=None, private_keys=None, + auth=None): """Create SSH-connection to the network :rtype : SSHClient @@ -278,7 +280,7 @@ class Node(six.with_metaclass(ExtendableNodeType, ParamedModel, BaseModel)): return SSHClient( self.get_ip_address_by_network_name(network_name), username=login, - password=password, private_keys=private_keys) + password=password, private_keys=private_keys, auth=auth) def await(self, network_name, timeout=120, by_port=22): wait_pass( diff --git a/devops/tests/helpers/test_helpers.py b/devops/tests/helpers/test_helpers.py index f10283ba..e711f245 100644 --- a/devops/tests/helpers/test_helpers.py +++ b/devops/tests/helpers/test_helpers.py @@ -28,6 +28,7 @@ from six.moves import xrange from devops import error from devops.helpers import helpers +from devops.helpers.ssh_client import SSHAuth class TestHelpersHelpers(unittest.TestCase): @@ -198,7 +199,8 @@ class TestHelpersHelpers(unittest.TestCase): helpers.wait_ssh_cmd( host, port, check_cmd, username, password, timeout) ssh.assert_called_once_with( - host=host, port=port, username=username, password=password + host=host, port=port, + auth=SSHAuth(username=username, password=password) ) wait.assert_called_once() # Todo: cover ssh_client.execute @@ -292,13 +294,12 @@ class TestHelpersHelpers(unittest.TestCase): raise Exception() uri = 'http://127.0.0.1' - srv.return_value = Success() + srv.side_effect = [Success(), Success(), Fail()] result = helpers.xmlrpcmethod(uri, 'success') self.assertTrue(result) srv.assert_called_once_with(uri) srv.reset_mock() - srv.return_value = Success() self.assertRaises( AttributeError, helpers.xmlrpcmethod, @@ -307,7 +308,6 @@ class TestHelpersHelpers(unittest.TestCase): srv.assert_called_once_with(uri) srv.reset_mock() - srv.return_value = Fail() self.assertRaises( AttributeError, helpers.xmlrpcmethod, @@ -323,11 +323,15 @@ class TestHelpersHelpers(unittest.TestCase): rand.assert_called_once_with(5) def test_deepgetattr(self): + # pylint: disable=attribute-defined-outside-init class Tst(object): one = 1 + tst = Tst() tst2 = Tst() tst2.two = Tst() + # pylint: enable=attribute-defined-outside-init + result = helpers.deepgetattr(tst, 'one') self.assertEqual(result, 1) result = helpers.deepgetattr(tst2, 'two.one') diff --git a/devops/tests/helpers/test_ssh_client.py b/devops/tests/helpers/test_ssh_client.py index 5b22e0b2..0d05e296 100644 --- a/devops/tests/helpers/test_ssh_client.py +++ b/devops/tests/helpers/test_ssh_client.py @@ -14,6 +14,7 @@ # pylint: disable=no-self-use +from contextlib import closing from os.path import basename import posixpath import stat @@ -21,9 +22,12 @@ from unittest import TestCase import mock import paramiko +# noinspection PyUnresolvedReferences +from six.moves import cStringIO from six import PY2 from devops.error import DevopsCalledProcessError +from devops.helpers.ssh_client import SSHAuth from devops.helpers.ssh_client import SSHClient @@ -48,20 +52,136 @@ private_keys = [] command = 'ls ~ ' +class TestSSHAuth(TestCase): + def init_checks(self, username=None, password=None, key=None, keys=None): + """shared positive init checks + + :type username: str + :type password: str + :type key: paramiko.RSAKey + :type keys: list + """ + auth = SSHAuth( + username=username, + password=password, + key=key, + keys=keys + ) + + int_keys = [None] + if key is not None: + int_keys.append(key) + if keys is not None: + for k in keys: + if k not in int_keys: + int_keys.append(k) + + self.assertEqual(auth.username, username) + with closing(cStringIO()) as tgt: + auth.enter_password(tgt) + self.assertEqual(tgt.getvalue(), '{}\n'.format(password)) + self.assertEqual( + auth.public_key, + gen_public_key(key) if key is not None else None) + + _key = ( + None if auth.public_key is None else + ''.format(auth.public_key) + ) + _keys = [] + for k in int_keys: + if k == key: + continue + _keys.append( + ''.format( + gen_public_key(k)) if k is not None else None) + + self.assertEqual( + repr(auth), + "{cls}(" + "username={username}, " + "password=<*masked*>, " + "key={key}, " + "keys={keys})".format( + cls=SSHAuth.__name__, + username=auth.username, + key=_key, + keys=_keys + ) + ) + self.assertEqual( + str(auth), + '{cls} for {username}'.format( + cls=SSHAuth.__name__, + username=auth.username, + ) + ) + + def test_init_username_only(self): + self.init_checks( + username=username + ) + + def test_init_username_password(self): + self.init_checks( + username=username, + password=password + ) + + def test_init_username_key(self): + self.init_checks( + username=username, + key=gen_private_keys(1).pop() + ) + + def test_init_username_password_key(self): + self.init_checks( + username=username, + password=password, + key=gen_private_keys(1).pop() + ) + + def test_init_username_password_keys(self): + self.init_checks( + username=username, + password=password, + keys=gen_private_keys(2) + ) + + def test_init_username_password_key_keys(self): + self.init_checks( + username=username, + password=password, + key=gen_private_keys(1).pop(), + keys=gen_private_keys(2) + ) + + +@mock.patch('devops.helpers.retry.sleep', autospec=True) @mock.patch('devops.helpers.ssh_client.logger', autospec=True) @mock.patch( 'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy') @mock.patch('paramiko.SSHClient', autospec=True) -class TestSSHClient(TestCase): - def check_defaults( - self, obj, host, port, username, password, private_keys): - self.assertEqual(obj.host, host) - self.assertEqual(obj.port, port) - self.assertEqual(obj.username, username) - self.assertEqual(obj.password, password) - self.assertEqual(obj.private_keys, private_keys) +class TestSSHClientInit(TestCase): + def init_checks( + self, + client, policy, logger, + host=None, port=22, + username=None, password=None, private_keys=None, + auth=None + ): + """shared checks for positive cases - def test_init_passwd(self, client, policy, logger): + :type client: mock.Mock + :type policy: mock.Mock + :type logger: mock.Mock + :type host: str + :type port: int + :type username: str + :type password: str + :type private_keys: list + :type auth: SSHAuth + """ _ssh = mock.call() ssh = SSHClient( @@ -69,150 +189,572 @@ class TestSSHClient(TestCase): port=port, username=username, password=password, - private_keys=private_keys) - + private_keys=private_keys, + auth=auth + ) client.assert_called_once() policy.assert_called_once() - expected_calls = [ - _ssh, - _ssh.set_missing_host_key_policy('AutoAddPolicy'), - _ssh.connect( - host, password=password, - port=port, username=username), - _ssh.open_sftp() - ] + if auth is None: + if private_keys is None or len(private_keys) == 0: + logger.assert_has_calls(( + mock.call.debug( + 'SSHClient(' + 'host={host}, port={port}, username={username}): ' + 'initialization by username/password/private_keys ' + 'is deprecated in favor of SSHAuth usage. ' + 'Please update your code'.format( + host=host, port=port, username=username + )), + mock.call.info( + '{0}:{1}> SSHAuth was made from old style creds: ' + 'SSHAuth for {2}'.format(host, port, username)) + )) + else: + logger.assert_has_calls(( + mock.call.debug( + 'SSHClient(' + 'host={host}, port={port}, username={username}): ' + 'initialization by username/password/private_keys ' + 'is deprecated in favor of SSHAuth usage. ' + 'Please update your code'.format( + host=host, port=port, username=username + )), + mock.call.debug( + 'Main key has been updated, public key is: \n' + '{}'.format(ssh.auth.public_key)), + mock.call.info( + '{0}:{1}> SSHAuth was made from old style creds: ' + 'SSHAuth for {2}'.format(host, port, username)) + )) + else: + logger.assert_not_called() - self.assertIn(expected_calls, client.mock_calls) - - self.check_defaults(ssh, host, port, username, password, private_keys) - self.assertIsNone(ssh.private_key) - self.assertIsNone(ssh.public_key) - - self.assertIn( - mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format( - host, port, username, password - )), - logger.mock_calls - ) - sftp = ssh._sftp - self.assertEqual(sftp, client().open_sftp()) - - def test_init_keys(self, client, policy, logger): - _ssh = mock.call() - - private_keys = gen_private_keys(1) - - ssh = SSHClient( - host=host, - port=port, - username=username, - password=password, - private_keys=private_keys) - - client.assert_called_once() - policy.assert_called_once() - - expected_calls = [ - _ssh, - _ssh.set_missing_host_key_policy('AutoAddPolicy'), - _ssh.connect( - host, password=password, pkey=private_keys[0], - port=port, username=username), - _ssh.open_sftp() - ] - - self.assertIn(expected_calls, client.mock_calls) - - self.check_defaults(ssh, host, port, username, password, private_keys) - self.assertEqual(ssh.private_key, private_keys[0]) - self.assertEqual(ssh.public_key, gen_public_key(private_keys[0])) - - self.assertIn( - mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format( - host, port, username, password - )), - logger.mock_calls - ) - - def test_init_as_context(self, client, policy, logger): - _ssh = mock.call() - - private_keys = gen_private_keys(1) - - with SSHClient( - host=host, - port=port, - username=username, - password=password, - private_keys=private_keys) as ssh: - - client.assert_called_once() - policy.assert_called_once() - - expected_calls = [ - _ssh, - _ssh.set_missing_host_key_policy('AutoAddPolicy'), - _ssh.connect( - host, password=password, pkey=private_keys[0], - port=port, username=username), - _ssh.open_sftp() - ] + if auth is None: + if private_keys is None or len(private_keys) == 0: + pkey = None + expected_calls = [ + _ssh, + _ssh.set_missing_host_key_policy('AutoAddPolicy'), + _ssh.connect( + hostname=host, password=password, + pkey=pkey, + port=port, username=username), + ] + else: + pkey = private_keys[0] + expected_calls = [ + _ssh, + _ssh.set_missing_host_key_policy('AutoAddPolicy'), + _ssh.connect( + hostname=host, password=password, + pkey=None, + port=port, username=username), + _ssh.connect( + hostname=host, password=password, + pkey=pkey, + port=port, username=username), + ] self.assertIn(expected_calls, client.mock_calls) - self.check_defaults(ssh, host, port, username, password, - private_keys) + self.assertEqual( + ssh.auth, + SSHAuth( + username=username, + password=password, + keys=private_keys + ) + ) + else: + self.assertEqual(ssh.auth, auth) - self.assertIn( - mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format( - host, port, username, password - )), - logger.mock_calls + sftp = ssh._sftp + self.assertEqual(sftp, client().open_sftp()) + + self.assertEqual(ssh._ssh, client()) + + self.assertEqual(ssh.hostname, host) + self.assertEqual(ssh.port, port) + + self.assertEqual( + repr(ssh), + '{cls}(host={host}, port={port}, auth={auth!r})'.format( + cls=ssh.__class__.__name__, host=ssh.hostname, + port=ssh.port, + auth=ssh.auth + ) + ) + + def test_init_host(self, client, policy, logger, sleep): + """Test with host only set""" + self.init_checks( + client, policy, logger, + host=host) + + def test_init_alternate_port(self, client, policy, logger, sleep): + """Test with alternate port""" + self.init_checks( + client, policy, logger, + host=host, + port=2222 + ) + + def test_init_username(self, client, policy, logger, sleep): + """Test with username only set from creds""" + self.init_checks( + client, policy, logger, + host=host, + username=username + ) + + def test_init_username_password(self, client, policy, logger, sleep): + """Test with username and password set from creds""" + self.init_checks( + client, policy, logger, + host=host, + username=username, + password=password ) - def test_init_fail_sftp(self, client, policy, logger): - _ssh = mock.Mock() - client.return_value = _ssh - open_sftp = mock.Mock(parent=_ssh, side_effect=paramiko.SSHException) - _ssh.attach_mock(open_sftp, 'open_sftp') - warning = mock.Mock(parent=logger) - logger.attach_mock(warning, 'warning') - - ssh = SSHClient( + def test_init_username_password_empty_keys( + self, client, policy, logger, sleep): + """Test with username, password and empty keys set from creds""" + self.init_checks( + client, policy, logger, host=host, - port=port, username=username, password=password, - private_keys=private_keys) + private_keys=[] + ) + + def test_init_username_single_key(self, client, policy, logger, sleep): + """Test with username and single key set from creds""" + connect = mock.Mock( + side_effect=[ + paramiko.AuthenticationException, mock.Mock() + ]) + _ssh = mock.Mock() + _ssh.attach_mock(connect, 'connect') + client.return_value = _ssh + + self.init_checks( + client, policy, logger, + host=host, + username=username, + private_keys=gen_private_keys(1) + ) + + def test_init_username_password_single_key( + self, client, policy, logger, sleep): + """Test with username, password and single key set from creds""" + connect = mock.Mock( + side_effect=[ + paramiko.AuthenticationException, mock.Mock() + ]) + _ssh = mock.Mock() + _ssh.attach_mock(connect, 'connect') + client.return_value = _ssh + + self.init_checks( + client, policy, logger, + host=host, + username=username, + password=password, + private_keys=gen_private_keys(1) + ) + + def test_init_username_multiple_keys(self, client, policy, logger, sleep): + """Test with username and multiple keys set from creds""" + connect = mock.Mock( + side_effect=[ + paramiko.AuthenticationException, mock.Mock() + ]) + _ssh = mock.Mock() + _ssh.attach_mock(connect, 'connect') + client.return_value = _ssh + + self.init_checks( + client, policy, logger, + host=host, + username=username, + private_keys=gen_private_keys(2) + ) + + def test_init_username_password_multiple_keys( + self, client, policy, logger, sleep): + """Test with username, password and multiple keys set from creds""" + connect = mock.Mock( + side_effect=[ + paramiko.AuthenticationException, mock.Mock() + ]) + _ssh = mock.Mock() + _ssh.attach_mock(connect, 'connect') + client.return_value = _ssh + + connect = mock.Mock( + side_effect=[ + paramiko.AuthenticationException, mock.Mock() + ]) + _ssh = mock.Mock() + _ssh.attach_mock(connect, 'connect') + client.return_value = _ssh + + self.init_checks( + client, policy, logger, + host=host, + username=username, + password=password, + private_keys=gen_private_keys(2) + ) + + def test_init_auth( + self, client, policy, logger, sleep): + self.init_checks( + client, policy, logger, + host=host, + auth=SSHAuth( + username=username, + password=password, + key=gen_private_keys(1).pop() + ) + ) + + def test_init_auth_break( + self, client, policy, logger, sleep): + self.init_checks( + client, policy, logger, + host=host, + username='Invalid', + password='Invalid', + private_keys=gen_private_keys(1), + auth=SSHAuth( + username=username, + password=password, + key=gen_private_keys(1).pop() + ) + ) + + def test_init_context( + self, client, policy, logger, sleep): + with SSHClient(host=host, auth=SSHAuth()) as ssh: + client.assert_called_once() + policy.assert_called_once() + + logger.assert_not_called() + + self.assertEqual(ssh.auth, SSHAuth()) + + sftp = ssh._sftp + self.assertEqual(sftp, client().open_sftp()) + + self.assertEqual(ssh._ssh, client()) + + self.assertEqual(ssh.hostname, host) + self.assertEqual(ssh.port, port) + + def test_init_clear_failed( + self, client, policy, logger, sleep): + """Test reconnect + + :type client: mock.Mock + :type policy: mock.Mock + :type logger: mock.Mock + """ + _ssh = mock.Mock() + _ssh.attach_mock( + mock.Mock( + side_effect=[ + Exception('Mocked SSH close()'), + mock.Mock() + ]), + 'close') + _sftp = mock.Mock() + _sftp.attach_mock( + mock.Mock( + side_effect=[ + Exception('Mocked SFTP close()'), + mock.Mock() + ]), + 'close') + client.return_value = _ssh + _ssh.attach_mock(mock.Mock(return_value=_sftp), 'open_sftp') + + ssh = SSHClient(host=host, auth=SSHAuth()) + client.assert_called_once() + policy.assert_called_once() + + logger.assert_not_called() + + self.assertEqual(ssh.auth, SSHAuth()) + + sftp = ssh._sftp + self.assertEqual(sftp, _sftp) + + self.assertEqual(ssh._ssh, _ssh) + + self.assertEqual(ssh.hostname, host) + self.assertEqual(ssh.port, port) + + logger.reset_mock() + + ssh.clear() + + logger.assert_has_calls(( + mock.call.exception('Could not close ssh connection'), + mock.call.exception('Could not close sftp connection'), + )) + + def test_init_reconnect( + self, client, policy, logger, sleep): + """Test reconnect + + :type client: mock.Mock + :type policy: mock.Mock + :type logger: mock.Mock + """ + ssh = SSHClient(host=host, auth=SSHAuth()) + client.assert_called_once() + policy.assert_called_once() + + logger.assert_not_called() + + self.assertEqual(ssh.auth, SSHAuth()) + + sftp = ssh._sftp + self.assertEqual(sftp, client().open_sftp()) + + self.assertEqual(ssh._ssh, client()) + + client.reset_mock() + policy.reset_mock() + + self.assertEqual(ssh.hostname, host) + self.assertEqual(ssh.port, port) + + ssh.reconnect() + + _ssh = mock.call() + + expected_calls = [ + _ssh.close(), + _ssh, + _ssh.set_missing_host_key_policy('AutoAddPolicy'), + _ssh.connect( + hostname='127.0.0.1', + password=None, + pkey=None, + port=22, + username=None), + ] + self.assertIn( + expected_calls, + client.mock_calls + ) client.assert_called_once() policy.assert_called_once() - self.check_defaults(ssh, host, port, username, password, private_keys) + logger.assert_not_called() - warning.assert_called_once_with( - 'SFTP enable failed! SSH only is accessible.' + self.assertEqual(ssh.auth, SSHAuth()) + + sftp = ssh._sftp + self.assertEqual(sftp, client().open_sftp()) + + self.assertEqual(ssh._ssh, client()) + + def test_init_password_required( + self, client, policy, logger, sleep): + connect = mock.Mock(side_effect=paramiko.PasswordRequiredException) + _ssh = mock.Mock() + _ssh.attach_mock(connect, 'connect') + client.return_value = _ssh + + with self.assertRaises(paramiko.PasswordRequiredException): + SSHClient(host=host, auth=SSHAuth()) + logger.assert_has_calls(( + mock.call.exception('No password has been set!'), + )) + + def test_init_password_broken( + self, client, policy, logger, sleep): + connect = mock.Mock(side_effect=paramiko.PasswordRequiredException) + _ssh = mock.Mock() + _ssh.attach_mock(connect, 'connect') + client.return_value = _ssh + + with self.assertRaises(paramiko.PasswordRequiredException): + SSHClient(host=host, auth=SSHAuth(password=password)) + + logger.assert_has_calls(( + mock.call.critical( + 'Unexpected PasswordRequiredException, ' + 'when password is set!' + ), + )) + + def test_init_auth_impossible_password( + self, client, policy, logger, sleep): + connect = mock.Mock(side_effect=paramiko.AuthenticationException) + + _ssh = mock.Mock() + _ssh.attach_mock(connect, 'connect') + client.return_value = _ssh + + with self.assertRaises(paramiko.AuthenticationException): + SSHClient(host=host, auth=SSHAuth(password=password)) + + logger.assert_has_calls( + ( + mock.call.exception( + 'Connection using stored authentication info failed!'), + ) * 3 ) + def test_init_auth_impossible_key( + self, client, policy, logger, sleep): + connect = mock.Mock(side_effect=paramiko.AuthenticationException) + + _ssh = mock.Mock() + _ssh.attach_mock(connect, 'connect') + client.return_value = _ssh + + with self.assertRaises(paramiko.AuthenticationException): + SSHClient( + host=host, + auth=SSHAuth(key=gen_private_keys(1).pop()) + ) + + logger.assert_has_calls( + ( + mock.call.exception( + 'Connection using stored authentication info failed!'), + ) * 3 + ) + + def test_init_auth_pass_no_key( + self, client, policy, logger, sleep): + connect = mock.Mock( + side_effect=[ + paramiko.AuthenticationException, + mock.Mock() + ]) + + _ssh = mock.Mock() + _ssh.attach_mock(connect, 'connect') + client.return_value = _ssh + key = gen_private_keys(1).pop() + + ssh = SSHClient( + host=host, + auth=SSHAuth( + username=username, + password=password, + key=key + ) + ) + + client.assert_called_once() + policy.assert_called_once() + + logger.assert_has_calls(( + mock.call.debug( + 'Main key has been updated, public key is: \nNone'), + )) + + self.assertEqual( + ssh.auth, + SSHAuth( + username=username, + password=password, + keys=[key] + ) + ) + + sftp = ssh._sftp + self.assertEqual(sftp, client().open_sftp()) + + self.assertEqual(ssh._ssh, client()) + + def test_init_auth_brute_impossible( + self, client, policy, logger, sleep): + connect = mock.Mock(side_effect=paramiko.AuthenticationException) + + _ssh = mock.Mock() + _ssh.attach_mock(connect, 'connect') + client.return_value = _ssh + + with self.assertRaises(paramiko.AuthenticationException): + SSHClient( + host=host, + username=username, + private_keys=gen_private_keys(2)) + + logger.assert_has_calls( + ( + mock.call.debug( + 'SSHClient(' + 'host={host}, port={port}, username={username}): ' + 'initialization by username/password/private_keys ' + 'is deprecated in favor of SSHAuth usage. ' + 'Please update your code'.format( + host=host, port=port, username=username + )), + ) + ( + mock.call.exception( + 'Connection using stored authentication info failed!'), + ) * 3 + ) + + def test_init_no_sftp( + self, client, policy, logger, sleep): + open_sftp = mock.Mock(side_effect=paramiko.SSHException) + + _ssh = mock.Mock() + _ssh.attach_mock(open_sftp, 'open_sftp') + client.return_value = _ssh + + ssh = SSHClient(host=host, auth=SSHAuth(password=password)) + + with self.assertRaises(paramiko.SSHException): + # pylint: disable=pointless-statement + # noinspection PyStatementEffect + ssh._sftp + # pylint: enable=pointless-statement + logger.assert_has_calls(( + mock.call.debug('SFTP is not connected, try to connect...'), + mock.call.warning( + 'SFTP enable failed! SSH only is accessible.'), + )) + + def test_init_sftp_repair( + self, client, policy, logger, sleep): + _sftp = mock.Mock() + open_sftp = mock.Mock( + side_effect=[ + paramiko.SSHException, + _sftp, _sftp]) + + _ssh = mock.Mock() + _ssh.attach_mock(open_sftp, 'open_sftp') + client.return_value = _ssh + + ssh = SSHClient(host=host, auth=SSHAuth(password=password)) + with self.assertRaises(paramiko.SSHException): # pylint: disable=pointless-statement # noinspection PyStatementEffect ssh._sftp # pylint: enable=pointless-statement - warning.assert_has_calls([ - mock.call('SFTP enable failed! SSH only is accessible.'), - mock.call('SFTP is not connected, try to reconnect'), - mock.call('SFTP enable failed! SSH only is accessible.')]) + logger.reset_mock() - # Unblock sftp connection - # (reset_mock is not possible to use in this case) - _sftp = mock.Mock() - open_sftp = mock.Mock(parent=_ssh, return_value=_sftp) - _ssh.attach_mock(open_sftp, 'open_sftp') sftp = ssh._sftp - self.assertEqual(sftp, _sftp) + self.assertEqual(sftp, open_sftp()) + logger.assert_has_calls(( + mock.call.debug('SFTP is not connected, try to connect...'), + )) @mock.patch('devops.helpers.ssh_client.logger', autospec=True) @@ -229,9 +771,10 @@ class TestExecute(TestCase): return SSHClient( host=host, port=port, - username=username, - password=password - ) + auth=SSHAuth( + username=username, + password=password + )) def test_execute_async(self, client, policy, logger): chan = mock.Mock() @@ -325,8 +868,9 @@ class TestExecute(TestCase): logger.mock_calls ) + @mock.patch('devops.helpers.ssh_client.SSHAuth.enter_password') def test_execute_async_sudo_password( - self, client, policy, logger): + self, enter_password, client, policy, logger): stdin = mock.Mock(name='stdin') stdout = mock.Mock(name='stdout') stdout_channel = mock.Mock() @@ -350,6 +894,7 @@ class TestExecute(TestCase): get_transport.assert_called_once() open_session.assert_called_once() # raise ValueError(closed.mock_calls) + enter_password.assert_called_once_with(stdin) stdin.assert_has_calls((mock.call.flush(), )) self.assertIn(chan, result) @@ -365,8 +910,7 @@ class TestExecute(TestCase): logger.mock_calls ) - @staticmethod - def get_patched_execute_async_retval(ec=0, stderr_val=True): + def get_patched_execute_async_retval(self, ec=0, stderr_val=True): stderr = mock.Mock() stdout = mock.Mock() @@ -464,9 +1008,10 @@ class TestExecute(TestCase): ssh2 = SSHClient( host=host2, port=port, - username=username, - password=password - ) + auth=SSHAuth( + username=username, + password=password + )) remotes = [ssh, ssh2] @@ -615,9 +1160,10 @@ class TestExecuteThrowHost(TestCase): ssh = SSHClient( host=host, port=port, - username=username, - password=password - ) + auth=SSHAuth( + username=username, + password=password + )) result = ssh.execute_through_host(target, command) self.assertEqual(result, return_value) @@ -660,13 +1206,14 @@ class TestExecuteThrowHost(TestCase): ssh = SSHClient( host=host, port=port, - username=username, - password=password - ) + auth=SSHAuth( + username=username, + password=password + )) result = ssh.execute_through_host( target, command, - username=_login, password=_password) + auth=SSHAuth(username=_login, password=_password)) self.assertEqual(result, return_value) get_transport.assert_called_once() open_channel.assert_called_once() @@ -701,9 +1248,10 @@ class TestSftp(TestCase): ssh = SSHClient( host=host, port=port, - username=username, - password=password - ) + auth=SSHAuth( + username=username, + password=password + )) return ssh, _sftp def test_exists(self, client, policy, logger): @@ -793,9 +1341,10 @@ class TestSftp(TestCase): ssh = SSHClient( host=host, port=port, - username=username, - password=password - ) + auth=SSHAuth( + username=username, + password=password + )) # Path not exists ssh.mkdir(path) @@ -817,9 +1366,10 @@ class TestSftp(TestCase): ssh = SSHClient( host=host, port=port, - username=username, - password=password - ) + auth=SSHAuth( + username=username, + password=password + )) # Path not exists ssh.rm_rf(path)