SSHClient rework for SSH Manager integration

1. Implemented SSHAuth
2. Old initialization API has been marked as deprecated
3. SFTP is started on demand with 3 retries
4. Reworked unit test to cover 100%
5. Added docstrings
6. Remove cyclic SSH session initialization in helper
7. Code is ready for adopted memorize pattern

blueprint: sshmanager-integration

Change-Id: I49d0aa635ba3f3125ab17531c0790a0106b87fea
This commit is contained in:
Alexey Stepanov 2016-05-27 19:44:50 +03:00
parent 621b1bdc5f
commit aae58e25da
7 changed files with 1044 additions and 251 deletions

View File

@ -37,6 +37,7 @@ from six.moves import xmlrpc_client
from devops.error import AuthenticationError
from devops.error import DevopsError
from devops.error import TimeoutError
from devops.helpers.ssh_client import SSHAuth
from devops.helpers.ssh_client import SSHClient
from devops import logger
from devops.settings import KEYSTONE_CREDS
@ -151,8 +152,9 @@ def wait_ssh_cmd(
password=SSH_CREDENTIALS['password'],
timeout=0):
ssh_client = SSHClient(host=host, port=port,
auth=SSHAuth(
username=username,
password=password)
password=password))
wait(lambda: not ssh_client.execute(check_cmd)['exit_code'],
timeout=timeout)
@ -198,10 +200,12 @@ def get_node_remote(env, node_name, login=SSH_SLAVE_CREDENTIALS['login'],
name=node_name).interfaces[0].mac_address)
wait(lambda: tcp_ping(ip, 22), timeout=180,
timeout_msg="Node {ip} is not accessible by SSH.".format(ip=ip))
return SSHClient(ip,
return SSHClient(
ip,
auth=SSHAuth(
username=login,
password=password,
private_keys=get_private_keys(env))
keys=get_private_keys(env)))
def get_admin_ip(env):

View File

@ -15,6 +15,7 @@
import os
import posixpath
import stat
from warnings import warn
import paramiko
import six
@ -24,7 +25,163 @@ from devops.helpers.retry import retry
from devops import logger
class SSHAuth(object):
__slots__ = ['__username', '__password', '__key', '__keys']
def __init__(
self,
username=None, password=None, key=None, keys=None):
"""SSH authorisation object
Used to authorize SSHClient.
Single SSHAuth object is associated with single host:port.
Password and key is private, other data is read-only.
:type username: str
:type password: str
:type key: paramiko.RSAKey
:type keys: list
"""
self.__username = username
self.__password = password
self.__key = key
self.__keys = [None]
if key is not None:
self.__keys.append(key)
if keys is not None:
for key in keys:
if key not in self.__keys:
self.__keys.append(key)
@property
def username(self):
"""Username for auth
:rtype: str
"""
return self.__username
@staticmethod
def __get_public_key(key):
"""Internal method for get public key from private
:type key: paramiko.RSAKey
"""
if key is None:
return None
return '{0} {1}'.format(key.get_name(), key.get_base64())
@property
def public_key(self):
"""public key for stored private key if presents else None
:rtype: str
"""
return self.__get_public_key(self.__key)
def enter_password(self, tgt):
"""Enter password to STDIN
Note: required for 'sudo' call
:type tgt: file
:rtype: str
"""
return tgt.write('{}\n'.format(self.__password))
def connect(self, client, hostname=None, port=22, log=True):
"""Connect SSH client object using credentials
:type client:
paramiko.client.SSHClient
paramiko.transport.Transport
:type log: bool
:raises paramiko.AuthenticationException
"""
kwargs = {
'username': self.username,
'password': self.__password}
if hostname is not None:
kwargs['hostname'] = hostname
kwargs['port'] = port
keys = [self.__key]
keys.extend([k for k in self.__keys if k != self.__key])
for key in keys:
kwargs['pkey'] = key
try:
client.connect(**kwargs)
if self.__key != key:
self.__key = key
logger.debug(
'Main key has been updated, public key is: \n'
'{}'.format(self.public_key))
return
except paramiko.PasswordRequiredException:
if self.__password is None:
logger.exception('No password has been set!')
raise
else:
logger.critical(
'Unexpected PasswordRequiredException, '
'when password is set!')
raise
except paramiko.AuthenticationException:
continue
msg = 'Connection using stored authentication info failed!'
if log:
logger.exception(
'Connection using stored authentication info failed!')
raise paramiko.AuthenticationException(msg)
def __hash__(self):
return hash((
self.__class__,
self.username,
self.__password,
tuple(self.__keys)
))
def __eq__(self, other):
return hash(self) == hash(other)
def __repr__(self):
_key = (
None if self.__key is None else
'<private for pub: {}>'.format(self.public_key)
)
_keys = []
for k in self.__keys:
if k == self.__key:
continue
_keys.append(
'<private for pub: {}>'.format(
self.__get_public_key(key=k)) if k is not None else None)
return (
'{cls}(username={username}, '
'password=<*masked*>, key={key}, keys={keys})'.format(
cls=self.__class__.__name__,
username=self.username,
key=_key,
keys=_keys)
)
def __str__(self):
return (
'{cls} for {username}'.format(
cls=self.__class__.__name__,
username=self.username,
)
)
class SSHClient(object):
__slots__ = [
'__hostname', '__port', '__auth', '__ssh', '__sftp', 'sudo_mode'
]
class get_sudo(object):
"""Context manager for call commands with sudo"""
def __init__(self, ssh):
@ -36,42 +193,128 @@ class SSHClient(object):
def __exit__(self, exc_type, exc_val, exc_tb):
self.ssh.sudo_mode = False
def __init__(self, host, port=22, username=None, password=None,
private_keys=None):
self.host = str(host)
self.port = int(port)
self.username = username
self.__password = password
if not private_keys:
private_keys = []
self.__private_keys = private_keys
self.__actual_pkey = None
def __hash__(self):
return hash((
self.__class__,
self.hostname,
self.port,
self.auth))
def __init__(
self,
host, port=22,
username=None, password=None, private_keys=None,
auth=None
):
"""SSHClient helper
:type host: str
:type port: int
:type username: str
:type password: str
:type private_keys: list
:type auth: SSHAuth
"""
self.__hostname = host
self.__port = port
self.sudo_mode = False
self.sudo = self.get_sudo(self)
self._ssh = None
self.__ssh = paramiko.SSHClient()
self.__ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.__sftp = None
self.reconnect()
self.__auth = auth
if auth is None:
msg = (
'SSHClient(host={host}, port={port}, username={username}): '
'initialization by username/password/private_keys '
'is deprecated in favor of SSHAuth usage. '
'Please update your code'.format(
host=host, port=port, username=username
))
warn(msg, DeprecationWarning)
logger.debug(msg)
self.__auth = SSHAuth(
username=username,
password=password,
keys=private_keys
)
self.__connect()
if auth is None:
logger.info(
'{0}:{1}> SSHAuth was made from old style creds: '
'{2}'.format(self.hostname, self.port, self.auth))
@property
def password(self):
return self.__password
def auth(self):
"""Internal authorisation object
Attention: this public property is mainly for inheritance,
debug and information purposes.
Calls outside SSHClient and child classes is sign of incorrect design.
Change is completely disallowed.
:rtype: SSHAuth
"""
return self.__auth
@property
def private_keys(self):
return self.__private_keys
def hostname(self):
"""Connected remote host name
:rtype: str
"""
return self.__hostname
@property
def private_key(self):
return self.__actual_pkey
def port(self):
"""Connected remote port number
:rtype: int
"""
return self.__port
def __repr__(self):
return '{cls}(host={host}, port={port}, auth={auth!r})'.format(
cls=self.__class__.__name__, host=self.hostname, port=self.port,
auth=self.auth
)
def __str__(self):
return '{cls}(host={host}, port={port}) for user {user}'.format(
cls=self.__class__.__name__, host=self.hostname, port=self.port,
user=self.auth.username
)
@property
def public_key(self):
if self.private_key is None:
return None
key = self.private_key
return '{0} {1}'.format(key.get_name(), key.get_base64())
def _ssh(self):
"""ssh client object getter for inheritance support only
Attention: ssh client object creation and change
is allowed only by __init__ and reconnect call.
:rtype: paramiko.SSHClient
"""
return self.__ssh
@retry(count=3, delay=3)
def __connect(self):
"""Main method for connection open"""
self.auth.connect(
client=self.__ssh,
hostname=self.hostname, port=self.port,
log=True)
@retry(3, delay=0)
def __connect_sftp(self):
"""SFTP connection opener"""
try:
self.__sftp = self.__ssh.open_sftp()
except paramiko.SSHException:
logger.warning('SFTP enable failed! SSH only is accessible.')
@property
def _sftp(self):
@ -81,25 +324,28 @@ class SSHClient(object):
"""
if self.__sftp is not None:
return self.__sftp
logger.warning('SFTP is not connected, try to reconnect')
self._connect_sftp()
logger.debug('SFTP is not connected, try to connect...')
self.__connect_sftp()
if self.__sftp is not None:
return self.__sftp
raise paramiko.SSHException('SFTP connection failed')
def clear(self):
"""Clear SSH and SFTP sessions"""
try:
self.__ssh.close()
self.__sftp = None
except Exception:
logger.exception("Could not close ssh connection")
if self.__sftp is not None:
try:
self.__sftp.close()
except Exception:
logger.exception("Could not close sftp connection")
try:
self._ssh.close()
except Exception:
logger.exception("Could not close ssh connection")
def __del__(self):
self.clear()
self.__ssh.close()
self.__sftp = None
def __enter__(self):
return self
@ -107,39 +353,14 @@ class SSHClient(object):
def __exit__(self, exc_type, exc_val, exc_tb):
self.clear()
@retry(count=3, delay=3)
def connect(self):
logger.debug(
"Connect to '{0}:{1}' as '{2}:{3}'".format(
self.host, self.port, self.username, self.password))
for private_key in self.private_keys:
try:
self._ssh.connect(
self.host, port=self.port, username=self.username,
password=self.password, pkey=private_key)
self.__actual_pkey = private_key
return
except paramiko.AuthenticationException:
continue
if self.private_keys:
logger.error("Authentication with keys failed")
self.__actual_pkey = None
self._ssh.connect(
self.host, port=self.port, username=self.username,
password=self.password)
def _connect_sftp(self):
try:
self.__sftp = self._ssh.open_sftp()
except paramiko.SSHException:
logger.warning('SFTP enable failed! SSH only is accessible.')
def reconnect(self):
self._ssh = paramiko.SSHClient()
self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.connect()
self._connect_sftp()
"""Reconnect SSH and SFTP session"""
self.clear()
self.__ssh = paramiko.SSHClient()
self.__ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.__connect()
def check_call(
self,
@ -232,7 +453,7 @@ class SSHClient(object):
ret = chan.recv_exit_status()
chan.close()
if ret not in expected:
errors[remote.host] = ret
errors[remote.hostname] = ret
if errors and raise_on_err:
raise DevopsCalledProcessError(command, errors)
@ -301,7 +522,7 @@ class SSHClient(object):
cmd = 'sudo -S bash -c "%s"' % cmd.replace('"', '\\"')
chan.exec_command(cmd)
if stdout.channel.closed is False:
stdin.write('%s\n' % self.password)
self.auth.enter_password(stdin)
stdin.flush()
else:
chan.exec_command(cmd)
@ -311,23 +532,27 @@ class SSHClient(object):
self,
hostname,
cmd,
username=None,
password=None,
key=None,
auth=None,
target_port=22):
if username is None and password is None and key is None:
username = self.username
password = self.__password
key = self.private_key
"""Execute command on remote host through currently connected host
:type hostname: str
:type cmd: str
:type auth: SSHAuth
:type target_port: int
:rtype: dict
"""
if auth is None:
auth = self.auth
intermediate_channel = self._ssh.get_transport().open_channel(
kind='direct-tcpip',
dest_addr=(hostname, target_port),
src_addr=(self.host, 0))
src_addr=(self.hostname, 0))
transport = paramiko.Transport(sock=intermediate_channel)
# start client and authenticate transport
transport.connect(username=username, password=password, pkey=key)
auth.connect(transport)
# open ssh session
channel = transport.open_session()
@ -336,7 +561,6 @@ class SSHClient(object):
stdout = channel.makefile('rb')
stderr = channel.makefile_stderr('rb')
logger.info("Executing command: {}".format(cmd))
channel.exec_command(cmd)
# TODO(astepanov): make a logic for controlling channel state
@ -474,3 +698,5 @@ class SSHClient(object):
return attrs.st_mode & stat.S_IFDIR != 0
except IOError:
return False
__all__ = ['SSHAuth', 'SSHClient']

View File

@ -26,6 +26,7 @@ from devops.error import DevopsEnvironmentError
from devops.error import DevopsError
from devops.error import DevopsObjNotFound
from devops.helpers.network import IpNetworksPool
from devops.helpers.ssh_client import SSHAuth
from devops.helpers.ssh_client import SSHClient
from devops.helpers.templates import create_devops_config
from devops.helpers.templates import get_devops_config
@ -362,30 +363,36 @@ class Environment(BaseModel):
key=lambda node: node.name
)[0]
return admin.remote(
self.admin_net,
login=login,
password=password)
self.admin_net, auth=SSHAuth(
username=login,
password=password))
# LEGACY, for fuel-qa compatibility
# @logwrap
def get_ssh_to_remote(self, ip,
login=settings.SSH_SLAVE_CREDENTIALS['login'],
password=settings.SSH_SLAVE_CREDENTIALS['password']):
warn('LEGACY, for fuel-qa compatibility', DeprecationWarning)
keys = []
remote = self.get_admin_remote()
for key_string in ['/root/.ssh/id_rsa',
'/root/.ssh/bootstrap.rsa']:
if self.get_admin_remote().isfile(key_string):
with self.get_admin_remote().open(key_string) as f:
if remote.isfile(key_string):
with remote.open(key_string) as f:
keys.append(RSAKey.from_private_key(f))
return SSHClient(ip,
return SSHClient(
ip,
auth=SSHAuth(
username=login,
password=password,
private_keys=keys)
keys=keys))
# LEGACY, for fuel-qa compatibility
# @logwrap
def get_ssh_to_remote_by_key(self, ip, keyfile):
@staticmethod
def get_ssh_to_remote_by_key(ip, keyfile):
warn('LEGACY, for fuel-qa compatibility', DeprecationWarning)
try:
with open(keyfile) as f:
keys = [RSAKey.from_private_key(f)]
@ -393,7 +400,7 @@ class Environment(BaseModel):
logger.warning('Loading of SSH key from file failed. Trying to use'
' SSH agent ...')
keys = Agent().get_keys()
return SSHClient(ip, private_keys=keys)
return SSHClient(ip, auth=SSHAuth(keys=keys))
# LEGACY, TO REMOVE (for fuel-qa compatibility)
def nodes(self): # migrated from EnvironmentModel.nodes()

View File

@ -270,7 +270,9 @@ class Node(six.with_metaclass(ExtendableNodeType, ParamedModel, BaseModel)):
interface = self.get_interface_by_nailgun_network_name(name)
return interface.address_set.first().ip_address
def remote(self, network_name, login, password=None, private_keys=None):
def remote(
self, network_name, login=None, password=None, private_keys=None,
auth=None):
"""Create SSH-connection to the network
:rtype : SSHClient
@ -278,7 +280,7 @@ class Node(six.with_metaclass(ExtendableNodeType, ParamedModel, BaseModel)):
return SSHClient(
self.get_ip_address_by_network_name(network_name),
username=login,
password=password, private_keys=private_keys)
password=password, private_keys=private_keys, auth=auth)
def await(self, network_name, timeout=120, by_port=22):
wait_pass(

View File

@ -28,6 +28,7 @@ from six.moves import xrange
from devops import error
from devops.helpers import helpers
from devops.helpers.ssh_client import SSHAuth
class TestHelpersHelpers(unittest.TestCase):
@ -198,7 +199,8 @@ class TestHelpersHelpers(unittest.TestCase):
helpers.wait_ssh_cmd(
host, port, check_cmd, username, password, timeout)
ssh.assert_called_once_with(
host=host, port=port, username=username, password=password
host=host, port=port,
auth=SSHAuth(username=username, password=password)
)
wait.assert_called_once()
# Todo: cover ssh_client.execute
@ -292,13 +294,12 @@ class TestHelpersHelpers(unittest.TestCase):
raise Exception()
uri = 'http://127.0.0.1'
srv.return_value = Success()
srv.side_effect = [Success(), Success(), Fail()]
result = helpers.xmlrpcmethod(uri, 'success')
self.assertTrue(result)
srv.assert_called_once_with(uri)
srv.reset_mock()
srv.return_value = Success()
self.assertRaises(
AttributeError,
helpers.xmlrpcmethod,
@ -307,7 +308,6 @@ class TestHelpersHelpers(unittest.TestCase):
srv.assert_called_once_with(uri)
srv.reset_mock()
srv.return_value = Fail()
self.assertRaises(
AttributeError,
helpers.xmlrpcmethod,
@ -323,11 +323,15 @@ class TestHelpersHelpers(unittest.TestCase):
rand.assert_called_once_with(5)
def test_deepgetattr(self):
# pylint: disable=attribute-defined-outside-init
class Tst(object):
one = 1
tst = Tst()
tst2 = Tst()
tst2.two = Tst()
# pylint: enable=attribute-defined-outside-init
result = helpers.deepgetattr(tst, 'one')
self.assertEqual(result, 1)
result = helpers.deepgetattr(tst2, 'two.one')

View File

@ -14,6 +14,7 @@
# pylint: disable=no-self-use
from contextlib import closing
from os.path import basename
import posixpath
import stat
@ -21,9 +22,12 @@ from unittest import TestCase
import mock
import paramiko
# noinspection PyUnresolvedReferences
from six.moves import cStringIO
from six import PY2
from devops.error import DevopsCalledProcessError
from devops.helpers.ssh_client import SSHAuth
from devops.helpers.ssh_client import SSHClient
@ -48,20 +52,136 @@ private_keys = []
command = 'ls ~ '
class TestSSHAuth(TestCase):
def init_checks(self, username=None, password=None, key=None, keys=None):
"""shared positive init checks
:type username: str
:type password: str
:type key: paramiko.RSAKey
:type keys: list
"""
auth = SSHAuth(
username=username,
password=password,
key=key,
keys=keys
)
int_keys = [None]
if key is not None:
int_keys.append(key)
if keys is not None:
for k in keys:
if k not in int_keys:
int_keys.append(k)
self.assertEqual(auth.username, username)
with closing(cStringIO()) as tgt:
auth.enter_password(tgt)
self.assertEqual(tgt.getvalue(), '{}\n'.format(password))
self.assertEqual(
auth.public_key,
gen_public_key(key) if key is not None else None)
_key = (
None if auth.public_key is None else
'<private for pub: {}>'.format(auth.public_key)
)
_keys = []
for k in int_keys:
if k == key:
continue
_keys.append(
'<private for pub: {}>'.format(
gen_public_key(k)) if k is not None else None)
self.assertEqual(
repr(auth),
"{cls}("
"username={username}, "
"password=<*masked*>, "
"key={key}, "
"keys={keys})".format(
cls=SSHAuth.__name__,
username=auth.username,
key=_key,
keys=_keys
)
)
self.assertEqual(
str(auth),
'{cls} for {username}'.format(
cls=SSHAuth.__name__,
username=auth.username,
)
)
def test_init_username_only(self):
self.init_checks(
username=username
)
def test_init_username_password(self):
self.init_checks(
username=username,
password=password
)
def test_init_username_key(self):
self.init_checks(
username=username,
key=gen_private_keys(1).pop()
)
def test_init_username_password_key(self):
self.init_checks(
username=username,
password=password,
key=gen_private_keys(1).pop()
)
def test_init_username_password_keys(self):
self.init_checks(
username=username,
password=password,
keys=gen_private_keys(2)
)
def test_init_username_password_key_keys(self):
self.init_checks(
username=username,
password=password,
key=gen_private_keys(1).pop(),
keys=gen_private_keys(2)
)
@mock.patch('devops.helpers.retry.sleep', autospec=True)
@mock.patch('devops.helpers.ssh_client.logger', autospec=True)
@mock.patch(
'paramiko.AutoAddPolicy', autospec=True, return_value='AutoAddPolicy')
@mock.patch('paramiko.SSHClient', autospec=True)
class TestSSHClient(TestCase):
def check_defaults(
self, obj, host, port, username, password, private_keys):
self.assertEqual(obj.host, host)
self.assertEqual(obj.port, port)
self.assertEqual(obj.username, username)
self.assertEqual(obj.password, password)
self.assertEqual(obj.private_keys, private_keys)
class TestSSHClientInit(TestCase):
def init_checks(
self,
client, policy, logger,
host=None, port=22,
username=None, password=None, private_keys=None,
auth=None
):
"""shared checks for positive cases
def test_init_passwd(self, client, policy, logger):
:type client: mock.Mock
:type policy: mock.Mock
:type logger: mock.Mock
:type host: str
:type port: int
:type username: str
:type password: str
:type private_keys: list
:type auth: SSHAuth
"""
_ssh = mock.call()
ssh = SSHClient(
@ -69,150 +189,572 @@ class TestSSHClient(TestCase):
port=port,
username=username,
password=password,
private_keys=private_keys)
private_keys=private_keys,
auth=auth
)
client.assert_called_once()
policy.assert_called_once()
if auth is None:
if private_keys is None or len(private_keys) == 0:
logger.assert_has_calls((
mock.call.debug(
'SSHClient('
'host={host}, port={port}, username={username}): '
'initialization by username/password/private_keys '
'is deprecated in favor of SSHAuth usage. '
'Please update your code'.format(
host=host, port=port, username=username
)),
mock.call.info(
'{0}:{1}> SSHAuth was made from old style creds: '
'SSHAuth for {2}'.format(host, port, username))
))
else:
logger.assert_has_calls((
mock.call.debug(
'SSHClient('
'host={host}, port={port}, username={username}): '
'initialization by username/password/private_keys '
'is deprecated in favor of SSHAuth usage. '
'Please update your code'.format(
host=host, port=port, username=username
)),
mock.call.debug(
'Main key has been updated, public key is: \n'
'{}'.format(ssh.auth.public_key)),
mock.call.info(
'{0}:{1}> SSHAuth was made from old style creds: '
'SSHAuth for {2}'.format(host, port, username))
))
else:
logger.assert_not_called()
if auth is None:
if private_keys is None or len(private_keys) == 0:
pkey = None
expected_calls = [
_ssh,
_ssh.set_missing_host_key_policy('AutoAddPolicy'),
_ssh.connect(
host, password=password,
hostname=host, password=password,
pkey=pkey,
port=port, username=username),
]
else:
pkey = private_keys[0]
expected_calls = [
_ssh,
_ssh.set_missing_host_key_policy('AutoAddPolicy'),
_ssh.connect(
hostname=host, password=password,
pkey=None,
port=port, username=username),
_ssh.connect(
hostname=host, password=password,
pkey=pkey,
port=port, username=username),
_ssh.open_sftp()
]
self.assertIn(expected_calls, client.mock_calls)
self.check_defaults(ssh, host, port, username, password, private_keys)
self.assertIsNone(ssh.private_key)
self.assertIsNone(ssh.public_key)
self.assertIn(
mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format(
host, port, username, password
)),
logger.mock_calls
self.assertEqual(
ssh.auth,
SSHAuth(
username=username,
password=password,
keys=private_keys
)
)
else:
self.assertEqual(ssh.auth, auth)
sftp = ssh._sftp
self.assertEqual(sftp, client().open_sftp())
def test_init_keys(self, client, policy, logger):
_ssh = mock.call()
self.assertEqual(ssh._ssh, client())
private_keys = gen_private_keys(1)
self.assertEqual(ssh.hostname, host)
self.assertEqual(ssh.port, port)
ssh = SSHClient(
host=host,
port=port,
username=username,
password=password,
private_keys=private_keys)
client.assert_called_once()
policy.assert_called_once()
expected_calls = [
_ssh,
_ssh.set_missing_host_key_policy('AutoAddPolicy'),
_ssh.connect(
host, password=password, pkey=private_keys[0],
port=port, username=username),
_ssh.open_sftp()
]
self.assertIn(expected_calls, client.mock_calls)
self.check_defaults(ssh, host, port, username, password, private_keys)
self.assertEqual(ssh.private_key, private_keys[0])
self.assertEqual(ssh.public_key, gen_public_key(private_keys[0]))
self.assertIn(
mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format(
host, port, username, password
)),
logger.mock_calls
self.assertEqual(
repr(ssh),
'{cls}(host={host}, port={port}, auth={auth!r})'.format(
cls=ssh.__class__.__name__, host=ssh.hostname,
port=ssh.port,
auth=ssh.auth
)
)
def test_init_as_context(self, client, policy, logger):
_ssh = mock.call()
def test_init_host(self, client, policy, logger, sleep):
"""Test with host only set"""
self.init_checks(
client, policy, logger,
host=host)
private_keys = gen_private_keys(1)
with SSHClient(
def test_init_alternate_port(self, client, policy, logger, sleep):
"""Test with alternate port"""
self.init_checks(
client, policy, logger,
host=host,
port=port,
username=username,
password=password,
private_keys=private_keys) as ssh:
client.assert_called_once()
policy.assert_called_once()
expected_calls = [
_ssh,
_ssh.set_missing_host_key_policy('AutoAddPolicy'),
_ssh.connect(
host, password=password, pkey=private_keys[0],
port=port, username=username),
_ssh.open_sftp()
]
self.assertIn(expected_calls, client.mock_calls)
self.check_defaults(ssh, host, port, username, password,
private_keys)
self.assertIn(
mock.call.debug("Connect to '{0}:{1}' as '{2}:{3}'".format(
host, port, username, password
)),
logger.mock_calls
port=2222
)
def test_init_fail_sftp(self, client, policy, logger):
def test_init_username(self, client, policy, logger, sleep):
"""Test with username only set from creds"""
self.init_checks(
client, policy, logger,
host=host,
username=username
)
def test_init_username_password(self, client, policy, logger, sleep):
"""Test with username and password set from creds"""
self.init_checks(
client, policy, logger,
host=host,
username=username,
password=password
)
def test_init_username_password_empty_keys(
self, client, policy, logger, sleep):
"""Test with username, password and empty keys set from creds"""
self.init_checks(
client, policy, logger,
host=host,
username=username,
password=password,
private_keys=[]
)
def test_init_username_single_key(self, client, policy, logger, sleep):
"""Test with username and single key set from creds"""
connect = mock.Mock(
side_effect=[
paramiko.AuthenticationException, mock.Mock()
])
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
open_sftp = mock.Mock(parent=_ssh, side_effect=paramiko.SSHException)
_ssh.attach_mock(open_sftp, 'open_sftp')
warning = mock.Mock(parent=logger)
logger.attach_mock(warning, 'warning')
ssh = SSHClient(
self.init_checks(
client, policy, logger,
host=host,
username=username,
private_keys=gen_private_keys(1)
)
def test_init_username_password_single_key(
self, client, policy, logger, sleep):
"""Test with username, password and single key set from creds"""
connect = mock.Mock(
side_effect=[
paramiko.AuthenticationException, mock.Mock()
])
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
self.init_checks(
client, policy, logger,
host=host,
port=port,
username=username,
password=password,
private_keys=private_keys)
private_keys=gen_private_keys(1)
)
def test_init_username_multiple_keys(self, client, policy, logger, sleep):
"""Test with username and multiple keys set from creds"""
connect = mock.Mock(
side_effect=[
paramiko.AuthenticationException, mock.Mock()
])
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
self.init_checks(
client, policy, logger,
host=host,
username=username,
private_keys=gen_private_keys(2)
)
def test_init_username_password_multiple_keys(
self, client, policy, logger, sleep):
"""Test with username, password and multiple keys set from creds"""
connect = mock.Mock(
side_effect=[
paramiko.AuthenticationException, mock.Mock()
])
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
connect = mock.Mock(
side_effect=[
paramiko.AuthenticationException, mock.Mock()
])
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
self.init_checks(
client, policy, logger,
host=host,
username=username,
password=password,
private_keys=gen_private_keys(2)
)
def test_init_auth(
self, client, policy, logger, sleep):
self.init_checks(
client, policy, logger,
host=host,
auth=SSHAuth(
username=username,
password=password,
key=gen_private_keys(1).pop()
)
)
def test_init_auth_break(
self, client, policy, logger, sleep):
self.init_checks(
client, policy, logger,
host=host,
username='Invalid',
password='Invalid',
private_keys=gen_private_keys(1),
auth=SSHAuth(
username=username,
password=password,
key=gen_private_keys(1).pop()
)
)
def test_init_context(
self, client, policy, logger, sleep):
with SSHClient(host=host, auth=SSHAuth()) as ssh:
client.assert_called_once()
policy.assert_called_once()
logger.assert_not_called()
self.assertEqual(ssh.auth, SSHAuth())
sftp = ssh._sftp
self.assertEqual(sftp, client().open_sftp())
self.assertEqual(ssh._ssh, client())
self.assertEqual(ssh.hostname, host)
self.assertEqual(ssh.port, port)
def test_init_clear_failed(
self, client, policy, logger, sleep):
"""Test reconnect
:type client: mock.Mock
:type policy: mock.Mock
:type logger: mock.Mock
"""
_ssh = mock.Mock()
_ssh.attach_mock(
mock.Mock(
side_effect=[
Exception('Mocked SSH close()'),
mock.Mock()
]),
'close')
_sftp = mock.Mock()
_sftp.attach_mock(
mock.Mock(
side_effect=[
Exception('Mocked SFTP close()'),
mock.Mock()
]),
'close')
client.return_value = _ssh
_ssh.attach_mock(mock.Mock(return_value=_sftp), 'open_sftp')
ssh = SSHClient(host=host, auth=SSHAuth())
client.assert_called_once()
policy.assert_called_once()
logger.assert_not_called()
self.assertEqual(ssh.auth, SSHAuth())
sftp = ssh._sftp
self.assertEqual(sftp, _sftp)
self.assertEqual(ssh._ssh, _ssh)
self.assertEqual(ssh.hostname, host)
self.assertEqual(ssh.port, port)
logger.reset_mock()
ssh.clear()
logger.assert_has_calls((
mock.call.exception('Could not close ssh connection'),
mock.call.exception('Could not close sftp connection'),
))
def test_init_reconnect(
self, client, policy, logger, sleep):
"""Test reconnect
:type client: mock.Mock
:type policy: mock.Mock
:type logger: mock.Mock
"""
ssh = SSHClient(host=host, auth=SSHAuth())
client.assert_called_once()
policy.assert_called_once()
logger.assert_not_called()
self.assertEqual(ssh.auth, SSHAuth())
sftp = ssh._sftp
self.assertEqual(sftp, client().open_sftp())
self.assertEqual(ssh._ssh, client())
client.reset_mock()
policy.reset_mock()
self.assertEqual(ssh.hostname, host)
self.assertEqual(ssh.port, port)
ssh.reconnect()
_ssh = mock.call()
expected_calls = [
_ssh.close(),
_ssh,
_ssh.set_missing_host_key_policy('AutoAddPolicy'),
_ssh.connect(
hostname='127.0.0.1',
password=None,
pkey=None,
port=22,
username=None),
]
self.assertIn(
expected_calls,
client.mock_calls
)
client.assert_called_once()
policy.assert_called_once()
self.check_defaults(ssh, host, port, username, password, private_keys)
logger.assert_not_called()
warning.assert_called_once_with(
'SFTP enable failed! SSH only is accessible.'
self.assertEqual(ssh.auth, SSHAuth())
sftp = ssh._sftp
self.assertEqual(sftp, client().open_sftp())
self.assertEqual(ssh._ssh, client())
def test_init_password_required(
self, client, policy, logger, sleep):
connect = mock.Mock(side_effect=paramiko.PasswordRequiredException)
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
with self.assertRaises(paramiko.PasswordRequiredException):
SSHClient(host=host, auth=SSHAuth())
logger.assert_has_calls((
mock.call.exception('No password has been set!'),
))
def test_init_password_broken(
self, client, policy, logger, sleep):
connect = mock.Mock(side_effect=paramiko.PasswordRequiredException)
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
with self.assertRaises(paramiko.PasswordRequiredException):
SSHClient(host=host, auth=SSHAuth(password=password))
logger.assert_has_calls((
mock.call.critical(
'Unexpected PasswordRequiredException, '
'when password is set!'
),
))
def test_init_auth_impossible_password(
self, client, policy, logger, sleep):
connect = mock.Mock(side_effect=paramiko.AuthenticationException)
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
with self.assertRaises(paramiko.AuthenticationException):
SSHClient(host=host, auth=SSHAuth(password=password))
logger.assert_has_calls(
(
mock.call.exception(
'Connection using stored authentication info failed!'),
) * 3
)
def test_init_auth_impossible_key(
self, client, policy, logger, sleep):
connect = mock.Mock(side_effect=paramiko.AuthenticationException)
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
with self.assertRaises(paramiko.AuthenticationException):
SSHClient(
host=host,
auth=SSHAuth(key=gen_private_keys(1).pop())
)
logger.assert_has_calls(
(
mock.call.exception(
'Connection using stored authentication info failed!'),
) * 3
)
def test_init_auth_pass_no_key(
self, client, policy, logger, sleep):
connect = mock.Mock(
side_effect=[
paramiko.AuthenticationException,
mock.Mock()
])
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
key = gen_private_keys(1).pop()
ssh = SSHClient(
host=host,
auth=SSHAuth(
username=username,
password=password,
key=key
)
)
client.assert_called_once()
policy.assert_called_once()
logger.assert_has_calls((
mock.call.debug(
'Main key has been updated, public key is: \nNone'),
))
self.assertEqual(
ssh.auth,
SSHAuth(
username=username,
password=password,
keys=[key]
)
)
sftp = ssh._sftp
self.assertEqual(sftp, client().open_sftp())
self.assertEqual(ssh._ssh, client())
def test_init_auth_brute_impossible(
self, client, policy, logger, sleep):
connect = mock.Mock(side_effect=paramiko.AuthenticationException)
_ssh = mock.Mock()
_ssh.attach_mock(connect, 'connect')
client.return_value = _ssh
with self.assertRaises(paramiko.AuthenticationException):
SSHClient(
host=host,
username=username,
private_keys=gen_private_keys(2))
logger.assert_has_calls(
(
mock.call.debug(
'SSHClient('
'host={host}, port={port}, username={username}): '
'initialization by username/password/private_keys '
'is deprecated in favor of SSHAuth usage. '
'Please update your code'.format(
host=host, port=port, username=username
)),
) + (
mock.call.exception(
'Connection using stored authentication info failed!'),
) * 3
)
def test_init_no_sftp(
self, client, policy, logger, sleep):
open_sftp = mock.Mock(side_effect=paramiko.SSHException)
_ssh = mock.Mock()
_ssh.attach_mock(open_sftp, 'open_sftp')
client.return_value = _ssh
ssh = SSHClient(host=host, auth=SSHAuth(password=password))
with self.assertRaises(paramiko.SSHException):
# pylint: disable=pointless-statement
# noinspection PyStatementEffect
ssh._sftp
# pylint: enable=pointless-statement
logger.assert_has_calls((
mock.call.debug('SFTP is not connected, try to connect...'),
mock.call.warning(
'SFTP enable failed! SSH only is accessible.'),
))
def test_init_sftp_repair(
self, client, policy, logger, sleep):
_sftp = mock.Mock()
open_sftp = mock.Mock(
side_effect=[
paramiko.SSHException,
_sftp, _sftp])
_ssh = mock.Mock()
_ssh.attach_mock(open_sftp, 'open_sftp')
client.return_value = _ssh
ssh = SSHClient(host=host, auth=SSHAuth(password=password))
with self.assertRaises(paramiko.SSHException):
# pylint: disable=pointless-statement
# noinspection PyStatementEffect
ssh._sftp
# pylint: enable=pointless-statement
warning.assert_has_calls([
mock.call('SFTP enable failed! SSH only is accessible.'),
mock.call('SFTP is not connected, try to reconnect'),
mock.call('SFTP enable failed! SSH only is accessible.')])
logger.reset_mock()
# Unblock sftp connection
# (reset_mock is not possible to use in this case)
_sftp = mock.Mock()
open_sftp = mock.Mock(parent=_ssh, return_value=_sftp)
_ssh.attach_mock(open_sftp, 'open_sftp')
sftp = ssh._sftp
self.assertEqual(sftp, _sftp)
self.assertEqual(sftp, open_sftp())
logger.assert_has_calls((
mock.call.debug('SFTP is not connected, try to connect...'),
))
@mock.patch('devops.helpers.ssh_client.logger', autospec=True)
@ -229,9 +771,10 @@ class TestExecute(TestCase):
return SSHClient(
host=host,
port=port,
auth=SSHAuth(
username=username,
password=password
)
))
def test_execute_async(self, client, policy, logger):
chan = mock.Mock()
@ -325,8 +868,9 @@ class TestExecute(TestCase):
logger.mock_calls
)
@mock.patch('devops.helpers.ssh_client.SSHAuth.enter_password')
def test_execute_async_sudo_password(
self, client, policy, logger):
self, enter_password, client, policy, logger):
stdin = mock.Mock(name='stdin')
stdout = mock.Mock(name='stdout')
stdout_channel = mock.Mock()
@ -350,6 +894,7 @@ class TestExecute(TestCase):
get_transport.assert_called_once()
open_session.assert_called_once()
# raise ValueError(closed.mock_calls)
enter_password.assert_called_once_with(stdin)
stdin.assert_has_calls((mock.call.flush(), ))
self.assertIn(chan, result)
@ -365,8 +910,7 @@ class TestExecute(TestCase):
logger.mock_calls
)
@staticmethod
def get_patched_execute_async_retval(ec=0, stderr_val=True):
def get_patched_execute_async_retval(self, ec=0, stderr_val=True):
stderr = mock.Mock()
stdout = mock.Mock()
@ -464,9 +1008,10 @@ class TestExecute(TestCase):
ssh2 = SSHClient(
host=host2,
port=port,
auth=SSHAuth(
username=username,
password=password
)
))
remotes = [ssh, ssh2]
@ -615,9 +1160,10 @@ class TestExecuteThrowHost(TestCase):
ssh = SSHClient(
host=host,
port=port,
auth=SSHAuth(
username=username,
password=password
)
))
result = ssh.execute_through_host(target, command)
self.assertEqual(result, return_value)
@ -660,13 +1206,14 @@ class TestExecuteThrowHost(TestCase):
ssh = SSHClient(
host=host,
port=port,
auth=SSHAuth(
username=username,
password=password
)
))
result = ssh.execute_through_host(
target, command,
username=_login, password=_password)
auth=SSHAuth(username=_login, password=_password))
self.assertEqual(result, return_value)
get_transport.assert_called_once()
open_channel.assert_called_once()
@ -701,9 +1248,10 @@ class TestSftp(TestCase):
ssh = SSHClient(
host=host,
port=port,
auth=SSHAuth(
username=username,
password=password
)
))
return ssh, _sftp
def test_exists(self, client, policy, logger):
@ -793,9 +1341,10 @@ class TestSftp(TestCase):
ssh = SSHClient(
host=host,
port=port,
auth=SSHAuth(
username=username,
password=password
)
))
# Path not exists
ssh.mkdir(path)
@ -817,9 +1366,10 @@ class TestSftp(TestCase):
ssh = SSHClient(
host=host,
port=port,
auth=SSHAuth(
username=username,
password=password
)
))
# Path not exists
ssh.rm_rf(path)