Merge "Implement an SSH module"

This commit is contained in:
Jenkins 2014-03-14 17:26:59 +00:00 committed by Gerrit Code Review
commit b7c0e0abae
4 changed files with 1124 additions and 9 deletions

View File

@ -1,5 +1,7 @@
iso8601>=0.1.9
Jinja2
# python3 branch of paramiko
https://github.com/paramiko/paramiko/archive/python3.zip#egg=paramiko-1.13.0
pbr>=0.5.21,<1.0
python-novaclient==2.15.0
# pythonwhois with python 3.3 readiness patch

View File

@ -21,15 +21,6 @@ class SatoriException(Exception):
"""
def __init__(self, message):
"""Store error message."""
self.message = message
super(SatoriException, self).__init__()
def __str__(self):
"""Display error message."""
return repr(self.message)
class SatoriInvalidNetloc(SatoriException):
@ -39,3 +30,8 @@ class SatoriInvalidNetloc(SatoriException):
class SatoriShellException(SatoriException):
"""Invalid shell parameters."""
class GetPTYRetryFailure(SatoriException):
"""Tried to re-run command with get_pty to no avail."""

406
satori/ssh.py Normal file
View File

@ -0,0 +1,406 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# pylint: disable=R0902, R0913
"""SSH Module for connecting to and automating remote commands.
Supports proxying, as in `ssh -A`
"""
import ast
import logging
import os
import re
import tempfile
import time
import paramiko
import six
from satori import errors
LOG = logging.getLogger(__name__)
MIN_PASSWORD_PROMPT_LEN = 8
MAX_PASSWORD_PROMPT_LEN = 64
TEMPFILE_PREFIX = ".satori.tmp.key."
TTY_REQUIRED = [
"you must have a tty to run sudo",
"is not a tty",
"no tty present",
]
def connect(*args, **kwargs):
"""Connect to a remote device over SSH."""
try:
return SSH.get_client(*args, **kwargs)
except TypeError as exc:
msg = "got an unexpected"
if msg in str(exc):
message = "%s " + str(exc)[str(exc).index(msg):]
raise exc.__class__(message % "connect()")
raise
class AcceptMissingHostKey(paramiko.client.MissingHostKeyPolicy):
"""Allow connections to hosts whose fingerprints are not on record."""
# pylint: disable=R0903
def missing_host_key(self, client, hostname, key):
"""Add missing host key."""
# pylint: disable=W0212
client._host_keys.add(hostname, key.get_name(), key)
class SSH(paramiko.SSHClient):
"""Connects to devices via SSH to execute commands."""
def __init__(self, host, password=None, username="root",
private_key=None, key_filename=None, port=22,
timeout=20, proxy=None, options=None):
"""Create an instance of the SSH class.
:param str host: The ip address or host name of the server
to connect to
:param str password: A password to use for authentication
or for unlocking a private key
:param username: The username to authenticate as
:param private_key: Private SSH Key string to use
(instead of using a filename)
:param key_filename: a private key filename (path)
:param port: tcp/ip port to use (defaults to 22)
:param float timeout: an optional timeout (in seconds) for the
TCP connection
:param socket proxy: an existing SSH instance to use
for proxying
:param dict options: A dictionary used to set ssh options
(when proxying).
e.g. for `ssh -o StrictHostKeyChecking=no`,
you would provide
(.., options={'StrictHostKeyChecking': 'no'})
Conversion of booleans is also supported,
(.., options={'StrictHostKeyChecking': False})
is equivalent.
"""
self.password = password
self.host = host
self.username = username
self.private_key = private_key
self.key_filename = key_filename
self.port = port
self.timeout = timeout
self._platform_info = None
self.options = options or {}
self.proxy = proxy
self.sock = None
if self.proxy:
if not isinstance(self.proxy, SSH):
raise TypeError("'proxy' must be a satori.ssh.SSH instance. "
"( instances of this type are returned by "
"satori.ssh.connect() )")
super(SSH, self).__init__()
@staticmethod
def _get_pkey(private_key):
"""Return a paramiko.pkey.PKey from private key string."""
key_classes = [paramiko.rsakey.RSAKey,
paramiko.dsskey.DSSKey,
paramiko.ecdsakey.ECDSAKey, ]
keyfile = six.StringIO(private_key)
for cls in key_classes:
keyfile.seek(0)
try:
pkey = cls.from_private_key(keyfile)
except paramiko.SSHException:
continue
else:
keytype = cls
LOG.info("Valid SSH Key provided (%s)", keytype.__name__)
return pkey
raise paramiko.SSHException("Is not a valid private key")
@classmethod
def get_client(cls, *args, **kwargs):
"""Return an ssh client object from this module."""
return cls(*args, **kwargs)
@property
def platform_info(self):
"""Return distro, version, architecture."""
if not self._platform_info:
command = ('python -c '
'"""import sys,platform as p;'
'plat=list(p.dist()+(p.machine(),));'
'sys.stdout.write(str(plat))"""')
output = self.remote_execute(command)
stdout = re.split('\n|\r\n', output['stdout'])[-1].strip()
plat = ast.literal_eval(stdout)
self._platform_info = {'dist': plat[0].lower(), 'version': plat[1],
'arch': plat[3]}
LOG.debug("Remote platform info: %s", self._platform_info)
return self._platform_info
def connect(self, use_password=False): # pylint: disable=W0221
"""Attempt an SSH connection through paramiko.SSHClient.connect .
The order for authentication attempts is:
- private_key
- key_filename
- any key discoverable in ~/.ssh/
- username/password
:param use_password: Skip SSH keys when authenticating.
"""
self.load_system_host_keys()
if self.proxy:
# lazy load
self.sock = self._get_proxy_socket(self.proxy)
if self.options.get('StrictHostKeyChecking') in (False, "no"):
self.set_missing_host_key_policy(AcceptMissingHostKey())
try:
if self.private_key is not None and not use_password:
pkey = self._get_pkey(self.private_key)
LOG.debug("Trying supplied private key string")
return super(SSH, self).connect(
self.host,
timeout=self.timeout,
port=self.port,
username=self.username,
pkey=pkey,
sock=self.sock)
elif self.key_filename is not None and not use_password:
LOG.debug("Trying key file: %s",
os.path.expanduser(self.key_filename))
return super(SSH, self).connect(
self.host, timeout=self.timeout, port=self.port,
username=self.username,
key_filename=os.path.expanduser(self.key_filename),
sock=self.sock)
else:
return super(SSH, self).connect(
self.host, port=self.port,
username=self.username,
password=self.password,
sock=self.sock)
except paramiko.PasswordRequiredException as exc:
#Looks like we have cert issues, so try password auth if we can
if self.password and not use_password: # dont recurse twice
LOG.debug("Retrying with password credentials")
return self.connect(use_password=True)
else:
raise exc
except paramiko.BadHostKeyException as exc:
msg = (
"ssh://%s@%s:%d failed: %s. You might have a bad key "
"entry on your server, but this is a security issue and "
"won't be handled automatically. To fix this you can remove "
"the host entry for this host from the /.ssh/known_hosts file"
% (self.username, self.host, self.port, exc))
LOG.info(msg)
raise exc
except Exception as exc:
LOG.info('ssh://%s@%s:%d failed. %s',
self.username, self.host, self.port, exc)
raise exc
def test_connection(self):
"""Connect to an ssh server and verify that it responds.
The order for authentication attempts is:
(1) private_key
(2) key_filename
(3) any key discoverable in ~/.ssh/
(4) username/password
"""
LOG.debug("Checking for a response from ssh://%s@%s:%d.",
self.username, self.host, self.port)
try:
self.connect()
LOG.debug("ssh://%s@%s:%d is up.",
self.username, self.host, self.port)
return True
except Exception as exc:
LOG.info("ssh://%s@%s:%d failed. %s",
self.username, self.host, self.port, exc)
return False
finally:
self.close()
def _handle_tty_required(self, results, get_pty):
"""Determine whether the result implies a tty request."""
if any(m in str(k) for m in TTY_REQUIRED for k in results.values()):
LOG.info('%s requires TTY for sudo. Using TTY mode.',
self.host)
if get_pty is True: # if this is *already* True
raise errors.GetPTYRetryFailure(
"Running command with get_pty=True FAILED: %s@%s:%d"
% (self.username, self.host, self.port))
else:
return True
return False
def _handle_password_prompt(self, stdin, stdout):
"""Determine whether the remote host is prompting for a password.
Respond to the prompt through stdin if applicable.
"""
if not stdout.channel.closed:
buflen = len(stdout.channel.in_buffer)
# min and max determined from max username length
# and a set of encountered linux password prompts
if MIN_PASSWORD_PROMPT_LEN < buflen < MAX_PASSWORD_PROMPT_LEN:
prompt = stdout.channel.recv(buflen)
if all(m in prompt.lower()
for m in ['password', ':']):
LOG.warning("%s@%s encountered prompt! of length "
" [%s] {%s}",
self.username, self.host, buflen, prompt)
stdin.write("%s\n" % self.password)
stdin.flush()
return True
else:
LOG.warning("Nearly a False-Positive on "
"password prompt detection. [%s] {%s}",
buflen, prompt)
stdout.channel.send(prompt)
return False
def remote_execute(self, command, with_exit_code=False, get_pty=False):
"""Execute an ssh command on a remote host.
Tries cert auth first and falls back
to password auth if password provided.
:param command: Shell command to be executed by this function.
:param with_exit_code: Include the exit_code in the return body.
:param get_pty: Request a pseudo-terminal from the server.
:returns: a dict with stdin, stdout,
and (optionally) the exit code of the call.
"""
LOG.debug("Executing '%s' on ssh://%s@%s:%s.",
command, self.username, self.host, self.port)
try:
self.connect()
results = None
chan = self.get_transport().open_session()
if get_pty:
chan.get_pty()
stdin = chan.makefile('wb')
stdout = chan.makefile('rb')
stderr = chan.makefile_stderr('rb')
chan.exec_command(command)
LOG.debug('ssh://%s@%s:%d responded.', self.username, self.host,
self.port)
time.sleep(.25)
self._handle_password_prompt(stdin, stdout)
results = {
'stdout': stdout.read().strip(),
'stderr': stderr.read()
}
exit_code = chan.recv_exit_status()
if with_exit_code:
results.update({'exit_code': exit_code})
chan.close()
if self._handle_tty_required(results, get_pty):
return self.remote_execute(
command, with_exit_code=with_exit_code, get_pty=True)
return results
except Exception as exc:
LOG.info("ssh://%s@%s:%d failed. %s", self.username, self.host,
self.port, exc)
raise
finally:
self.close()
def _get_proxy_socket(self, proxy):
"""Return a wrapped subprocess running ProxyCommand-driven programs.
Create a new CommandProxy instance.
Can be created from an existing SSH instance.
For proxy clients, please specify a private key filename.
To use an ssh proxy, you must use an SSH Key,
since a ProxyCommand cannot be passed a password.
"""
if proxy.password:
LOG.warning("Proxying through a client which is authorized by "
"a password is not currently implemented. Please "
"use an ssh key.")
proxy.load_system_host_keys()
if proxy.options.get('StrictHostKeyChecking') in (False, "no"):
proxy.set_missing_host_key_policy(AcceptMissingHostKey())
if proxy.private_key and not proxy.key_filename:
tempkeyfile = tempfile.NamedTemporaryFile(
mode='w+', prefix=TEMPFILE_PREFIX,
dir=os.path.expanduser('~/'), delete=True)
tempkeyfile.write(proxy.private_key)
proxy.key_filename = tempkeyfile.name
pxd = {
'bastion': proxy.host,
'user': proxy.username,
'port': '-p %s' % proxy.port,
'options': ('-o ConnectTimeout=%s ' % proxy.timeout),
'target_host': self.host,
'target_port': self.port,
}
proxycommand = "ssh {options} -A {user}@{bastion} "
if proxy.key_filename:
proxy.key_filename = os.path.expanduser(proxy.key_filename)
proxy.key_filename = os.path.abspath(proxy.key_filename)
pxd.update({'identity': '-i %s' % proxy.key_filename})
proxycommand += "{identity} "
if proxy.options:
for key, val in sorted(proxy.options.items()):
if isinstance(val, bool):
# turns booleans into `ssh -o` compat "yes" or "no"
if val is True:
val = "yes"
if val is False:
val = "no"
pxd['options'] += '-o %s=%s ' % (key, val)
proxycommand += "nc {target_host} {target_port}"
return paramiko.ProxyCommand(proxycommand.format(**pxd))
# Share SSH.__init__'s docstring
connect.__doc__ = SSH.__init__.__doc__
try:
SSH.__dict__['get_client'].__doc__ = SSH.__dict__['__init__'].__doc__
except AttributeError:
SSH.get_client.__func__.__doc__ = SSH.__init__.__doc__

711
satori/tests/test_ssh.py Normal file
View File

@ -0,0 +1,711 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# pylint: disable=C0111, C0103, W0212, R0904
"""Satori SSH Module Tests."""
import os
import unittest
import mock
import paramiko
from satori import errors
from satori import ssh
from satori.tests import utils
class TestTTYRequired(utils.TestCase):
"""Test response to tty demand."""
def setUp(self):
super(TestTTYRequired, self).setUp()
self.client = ssh.SSH('123.456.789.0', password='test_password')
self.stdout = mock.MagicMock()
self.stdin = mock.MagicMock()
def test_valid_demand(self):
"""Ensure that anticipated requests for tty's return True."""
for substring in ssh.TTY_REQUIRED:
results = {'stdout': "xyz" + substring + "zyx"}
self.assertTrue(self.client._handle_tty_required(results, False))
def test_normal_response(self):
"""Ensure standard response returns False."""
examples = ["hello", "#75-Ubuntu SMP Tue Jun 18 17:59:38 UTC 2013",
("fatal: Not a git repository "
"(or any of the parent directories): .git")]
for substring in examples:
results = {'stderr': '', 'stdout': substring}
self.assertFalse(self.client._handle_tty_required(results, False))
def test_no_recurse(self):
"""Avoid infinte loop by raising GetPTYRetryFailure.
When retrying with get_pty in response to one of TTY_REQUIRED
"""
for substring in ssh.TTY_REQUIRED:
results = {'stdout': substring}
self.assertRaises(errors.GetPTYRetryFailure,
self.client._handle_tty_required,
results, True)
class TestConnectHelper(utils.TestCase):
def test_connect_helper(self):
self.assertIsInstance(ssh.connect("123.456.789.0"), ssh.SSH)
def test_throws_typeerror_well(self):
self.assertRaises(TypeError, ssh.connect,
("123.456.789.0",), invalidkey="bad")
def test_throws_typeerror_well_with_message(self):
try:
ssh.connect("123.456.789.0", invalidkey="bad")
except TypeError as exc:
self.assertEqual("connect() got an unexpected keyword "
"argument 'invalidkey'", str(exc))
def test_throws_error_no_host(self):
self.assertRaises(TypeError, ssh.connect)
class TestSSHKeys(utils.TestCase):
def setUp(self):
super(TestSSHKeys, self).setUp()
self.invalidkey = """
-----BEGIN RSA PRIVATE KEY-----
MJK7hkKYHUNJKDHNF)980BN456bjnkl_0Hj08,l$IRJSDLKjhkl/jFJVSLx2doRZ
-----END RSA PRIVATE KEY-----
"""
self.ecdsakey = """
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIiZdMfDf+lScOkujN1+zAKDJ9PQRquCVZoXfS+6hDlToAoGCCqGSM49
AwEHoUQDQgAE/qUj+vxnhIrkTR/ayYx9ZC/9JanJGyXkOe3Oe6WT/FJ9vBbfThTF
U9+i43I3TONq+nWbhFKBj8XR4NKReaYeBw==
-----END EC PRIVATE KEY-----
"""
self.rsakey = """
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAwbbaH5m0yLIVAi1i4aJ5uKprPM93x6b/KkH5N4QmZoXGOFId
v0G64Sanz1VZkCWXiyivgkT6/y0+M0Ok8UK24UO6YNBSFGKboan/OMNETTIqXzmV
liVYkQTf2zrBPWofjeDnzMndy7AD5iylJ6cNAksFM+sLt0MQcOeCmbOX8E6+AGZr
JLj8orJgGJKU9jN5tnMlgtDP9BVrrbi7wX0kqb42OMtM6AuMUBDtAM2QSpTJa0JL
mFOLfe6PYOLdQaJsnaoV+Wu4eBdY91h8COmhOKZv5VMYalOSDQnsKgngDW9iOoFs
Uou7W8Wk3FXusbDwAvakWKmQtDF8SIgMLqygTwIDAQABAoIBAQCe5FkuKmm7ZTcO
PiQpZ5fn/QFRM+vP/A64nrzI6MCGv5vDfreftU6Qd6CV1DBOqEcRgiHT/LjUrkui
yQ12R36yb1dlKfrpdaiqhkIuURypJUjUKuuj6KYo7ZKgxCTVN0MCoUQBGmOvO4U3
O8+MIt3sz5RI7bcCbyQBOCRL5p/uH3soWoG+6u2W17M4otLT0xJGX5eU0AoCYfOi
Vd9Ot3j687k6KtZajy2hZIccuGNRwFeKSIAN9U7FEy4fgxkIMrc/wqArKmZLNui1
SkVP3UHlbGVAI5ZDLzdcyxXPRWz1FBtJYiITtQCVKTv5LFCxFjlIWML2qJMB2GTW
0+t1WhEhAoGBAOFdh14qn0i5v7DztpkS665vQ9F8n7RN0b17yK5wNmfhd4gYK/ym
hCPUk0+JfPNQuhhhzoDXWICiCHRqNVT0ZzkyY0E2aTYLYxbeKkiCOccqJXxtxiI+
6KneRMV3mKaJXJLz8G0YepB2Qhv4JkNsR1yiA5EqIs0Cr9Jafg9tHQsrAoGBANwL
5lYjNHu51WVdjv2Db4oV26fRqAloc0//CBCl9IESM7m9A7zPTboMMijSqEuz3qXJ
Fd5++B/b1Rkt4EunJNcE+XRJ9cI7MKE1kYKz6oiSN4X4eHQSDmlpS9DBcAEjTJ8r
c+5DsPMSkz6qMxbG+FZB1SvVflFZe9dO8Ba7oR1tAoGAa+97keIf/5jW8k0HOzEQ
p66qcH6bjqNmvLW4W7Nqmz4lHY1WI98skmyRURqsOWyEdIEDgjmhLZptKjRj7phP
h9lWKDmDEltJzf4BilC0k2rgIUQCDQzMKe9GSL0K41gOemNS1y1OJjo9V1/2E3yc
gQUnaDMiD8Ylpz2n+oNr0ZkCgYBqDK4g+2yS6JgI91MvqQW7lhc7xRZoGlfgyPe5
FlJFVmFpdcf0WjCKptARzpzfhzuZyNTqW2T37bnBHdQIgfCGVFZpDjAMQPyJ5UhQ
pqc01Ms/nOVogz9A3Ed2v5NcaQfHemiv/x2ruFsQi3R92LzczXOQYZ80U50Uwm2B
d0IJ7QKBgD39jFiz7U4XEK/knRWUBUNq8QSGF5UuzO404z/+6J2KlFeNiDe+aH0c
cdi+/PhkDkMXfW6eQdvgFYs277uss4M+4F8fWb2KVvPTuZXmTf6qntFoZNuL1oIv
kn+fI2noF0ET7ktofoPEeD2/ya0B9/XecUqDJcVofoVO2pxMn12A
-----END RSA PRIVATE KEY-----
"""
self.dsakey = """
-----BEGIN DSA PRIVATE KEY-----
MIIBuwIBAAKBgQC+WvLRuPNDPVfZwKYqJYuD6XXjrUU4KIdLWmRO9qOtq0UR1kOQ
/4rhjgb2TyujW6RzPnqPc9eUv84Z3gKawAdZv5/vKbp6tpMn86Y42r0Ohy63DEgM
XyBfWxbZm0RBmLy3bCUefMOBngnODIhrTt2o+ip5ve5JMctDvjkWBVnZiQIVAMlh
6gd7IC68FwynC4f/p8+zpx9pAoGARjTQeKxBBDDfxySYDN0maXHMR21RF/gklecO
x6sH1MEDtOupQk0/uIPvolH0Jh+PK+NAv0GBZ96PDrF5z0S6MyQ5eHWGtwW4NFqk
ZGHTriy+8qc4OhtyS3dpXQu40Ad2o1ap1v806RwM8iw1OfBa94h/vreedO0ij2Fe
7aKEci4CgYAITw+ySCskHakn1GTG952MKxlMo7Mx++dYnCoFxsMwXFlwIrpzyhhC
Qk11sEgcAOZ2HiRVhwaz4BivNV5iuwUeIeKJc12W4+FU+Lh533hFOcSAYbBr1Crl
e+YpaOHRjLel0Nb5Cil4qEQaWQDmWvQb958IQQgzC9NhnR7NRNkfrgIVAKfMMZKz
57plimt3W9YoDAATyr6i
-----END DSA PRIVATE KEY-----
"""
def test_invalid_key_raises_sshexception(self):
self.assertRaises(
paramiko.SSHException, ssh.SSH._get_pkey, self.invalidkey)
def test_valid_ecdsa_returns_pkey_obj(self):
self.assertIsInstance(ssh.SSH._get_pkey(self.ecdsakey), paramiko.PKey)
def test_valid_rsa_returns_pkey_obj(self):
self.assertIsInstance(ssh.SSH._get_pkey(self.rsakey), paramiko.PKey)
def test_valid_ds_returns_pkey_obj(self):
self.assertIsInstance(ssh.SSH._get_pkey(self.dsakey), paramiko.PKey)
@mock.patch.object(ssh, 'LOG')
def test_valid_ecdsa_logs_key_class(self, mock_LOG):
ssh.SSH._get_pkey(self.ecdsakey)
mock_LOG.info.assert_called_with(
'Valid SSH Key provided (%s)', 'ECDSAKey')
@mock.patch.object(ssh, 'LOG')
def test_valid_rsa_logs_key_class(self, mock_LOG):
ssh.SSH._get_pkey(self.rsakey)
mock_LOG.info.assert_called_with(
'Valid SSH Key provided (%s)', 'RSAKey')
@mock.patch.object(ssh, 'LOG')
def test_valid_dsa_logs_key_class(self, mock_LOG):
ssh.SSH._get_pkey(self.dsakey)
mock_LOG.info.assert_called_with(
'Valid SSH Key provided (%s)', 'DSSKey')
class TestSSHConnect(TestSSHKeys):
def setUp(self):
super(TestSSHConnect, self).setUp()
self.host = '123.456.789.0'
self.client = ssh.SSH(self.host, username='test-user')
paramiko.SSHClient.connect = mock.MagicMock()
def test_connect_no_auth_attrs(self):
"""Test connect call without auth attributes."""
self.client.connect()
paramiko.SSHClient.connect.assert_called_once_with(
'123.456.789.0', username='test-user',
password=None, sock=None, port=22)
def test_connect_with_password(self):
self.client.password = 'test-password'
self.client.connect()
paramiko.SSHClient.connect.assert_called_once_with(
'123.456.789.0', username='test-user',
password='test-password', sock=None, port=22)
def test_connect_invalid_private_key_string(self):
self.client.private_key = self.invalidkey
self.assertRaises(paramiko.SSHException, self.client.connect)
def test_connect_valid_private_key_string(self):
validkeys = [self.rsakey, self.dsakey, self.ecdsakey]
for key in validkeys:
self.client.private_key = key
self.client.connect()
pkey_kwarg_value = (paramiko.SSHClient.
connect.call_args[1]['pkey'])
self.assertIsInstance(pkey_kwarg_value, paramiko.PKey)
paramiko.SSHClient.connect.assert_called_with(
'123.456.789.0', username='test-user',
pkey=pkey_kwarg_value, sock=None, port=22, timeout=20)
def test_key_filename(self):
self.client.key_filename = "~/not/a/real/path"
expanded_path = os.path.expanduser(self.client.key_filename)
self.client.connect()
paramiko.SSHClient.connect.assert_called_once_with(
'123.456.789.0', username='test-user',
key_filename=expanded_path,
sock=None, port=22, timeout=20)
def test_use_password_on_exc_negative(self):
"""Do this without self.password. """
paramiko.SSHClient.connect.side_effect = (
paramiko.PasswordRequiredException)
self.assertRaises(paramiko.PasswordRequiredException,
self.client.connect)
@mock.patch.object(ssh, 'LOG')
def test_logging_use_password_on_exc_positive(self, mock_LOG):
self.client.password = 'test-password'
paramiko.SSHClient.connect.side_effect = (
paramiko.PasswordRequiredException)
self.assertRaises(paramiko.PasswordRequiredException,
self.client.connect)
mock_LOG.debug.assert_called_with('Retrying with password credentials')
@mock.patch.object(ssh, 'LOG')
def test_logging_when_badhostkey(self, mock_LOG):
"""Test when raising BadHostKeyException."""
self.client.private_key = self.rsakey
paramiko.SSHClient.connect.side_effect = (
paramiko.BadHostKeyException(None, None, None))
self.assertRaises(paramiko.BadHostKeyException,
self.client.connect)
mock_LOG.info.assert_called_with(
"ssh://test-user@123.456.789.0:22 failed: "
"Host key for server None does not match!. "
"You might have a bad key entry on your server, "
"but this is a security issue and won't be handled "
"automatically. To fix this you can remove the "
"host entry for this host from the /.ssh/known_hosts file")
@mock.patch.object(ssh, 'LOG')
def test_logging_when_reraising_other_exc(self, mock_LOG):
self.client.private_key = self.rsakey
paramiko.SSHClient.connect.side_effect = Exception
self.assertRaises(Exception,
self.client.connect)
err = mock_LOG.info.call_args[0][-1]
mock_LOG.info.assert_called_with(
'ssh://%s@%s:%d failed. %s',
'test-user', '123.456.789.0', 22, err)
def test_reraising_other_exc(self):
self.client.private_key = self.rsakey
paramiko.SSHClient.connect.side_effect = (
paramiko.BadHostKeyException(None, None, None))
self.assertRaises(paramiko.BadHostKeyException,
self.client.connect)
def test_default_user_is_root(self):
self.client = ssh.SSH('123.456.789.0')
self.client.connect()
default = paramiko.SSHClient.connect.call_args[1]['username']
self.assertEqual(default, 'root')
def test_missing_host_key_policy(self):
client = ssh.connect(
"123.456.789.0", options={'StrictHostKeyChecking': 'no'})
client.connect()
self.assertIsInstance(
client._policy, ssh.AcceptMissingHostKey)
def test_adds_missing_host_key(self):
client = ssh.connect(
"123.456.789.0", options={'StrictHostKeyChecking': 'no'})
client.connect()
pkey = client._get_pkey(self.rsakey)
client._policy.missing_host_key(
client,
"123.456.789.0",
pkey)
expected = {'123.456.789.0': {
'ssh-rsa': pkey}}
self.assertEqual(expected, client._host_keys)
class TestTestConnection(TestSSHKeys):
def setUp(self):
super(TestTestConnection, self).setUp()
self.host = '123.456.789.0'
self.client = ssh.SSH(self.host, username='test-user')
paramiko.SSHClient.connect = mock.MagicMock()
def test_test_connection(self):
self.assertTrue(self.client.test_connection())
def test_test_connection_fail_invalid_key(self):
self.client.private_key = self.invalidkey
self.assertFalse(self.client.test_connection())
def test_test_connection_valid_key(self):
self.client.private_key = self.dsakey
self.assertTrue(self.client.test_connection())
def test_test_connection_fail_other(self):
paramiko.SSHClient.connect.side_effect = Exception
self.assertFalse(self.client.test_connection())
@mock.patch.object(ssh, 'LOG')
def test_test_connection_logging(self, mock_LOG):
self.client.test_connection()
mock_LOG.debug.assert_called_with(
'ssh://%s@%s:%d is up.', 'test-user', self.host, 22)
class TestGetProxySocket(TestSSHKeys):
def setUp(self):
super(TestGetProxySocket, self).setUp()
paramiko.ProxyCommand = mock.MagicMock()
paramiko.SSHClient.connect = mock.MagicMock()
self.proxy = ssh.SSH('proxy.address', username='proxy-user')
self.client = ssh.SSH('123.546.789.0', username='client-user')
self.mutable = [True]
def tearDown(self):
super(TestGetProxySocket, self).tearDown()
self.proxy.close()
self.client.close()
def test_get_proxy_socket(self):
self.client._get_proxy_socket(self.proxy)
paramiko.ProxyCommand.assert_called_once_with(
'ssh -o ConnectTimeout=20 '
' -A proxy-user@proxy.address nc 123.546.789.0 22')
def test_get_proxy_socket_private_key(self):
self.proxy.private_key = self.rsakey
self.client._get_proxy_socket(self.proxy)
self.assertTrue(paramiko.ProxyCommand.called)
def tempfile_spotted(self):
home = os.path.expanduser('~/')
filist = [k for k in os.listdir(home)
if k.startswith(ssh.TEMPFILE_PREFIX)]
while all(self.mutable):
new = [k for k in os.listdir(home)
if k.startswith(ssh.TEMPFILE_PREFIX)]
if len(new) > len(filist):
return set(new).difference(set(filist)).pop()
return False
def test_get_proxy_file_unseen(self):
self.proxy.key_filename = "~/not/a/real/path"
expanded_path = os.path.expanduser(self.proxy.key_filename)
self.client._get_proxy_socket(self.proxy)
paramiko.ProxyCommand.assert_called_once_with(
'ssh -o ConnectTimeout=20 '
' -A proxy-user@proxy.address -i %s '
'nc 123.546.789.0 22' % expanded_path)
class TestRemoteExecute(TestSSHKeys):
def setUp(self):
super(TestRemoteExecute, self).setUp()
paramiko.ProxyCommand = mock.MagicMock()
paramiko.SSHClient.connect = mock.MagicMock()
self.client = ssh.SSH('123.456.789.0', username='client-user')
self.client._handle_password_prompt = mock.Mock(return_value=False)
self.mock_chan = mock.MagicMock()
mock_transport = mock.MagicMock()
mock_transport.open_session.return_value = self.mock_chan
self.client.get_transport = mock.MagicMock(
return_value=mock_transport)
self.mock_chan.exec_command = mock.MagicMock()
self.mock_chan.makefile.side_effect = self.mkfile
self.mock_chan.makefile_stderr.side_effect = (
lambda x: self.mkfile(x, err=True))
self.example_command = 'echo hello'
self.example_output = 'hello'
def mkfile(self, arg, err=False, stdoutput=None):
if arg == 'rb' and not err:
stdout = mock.MagicMock()
stdout.read.return_value = stdoutput or self.example_output
stdout.read.return_value += "\n"
return stdout
if arg == 'wb' and not err:
stdin = mock.MagicMock()
stdin.read.return_value = ''
return stdin
if err is True:
stderr = mock.MagicMock()
stderr.read.return_value = ''
return stderr
def test_remote_execute_proper_primitive(self):
self.client._handle_tty_required = mock.Mock(return_value=False)
commands = ['echo hello', 'uname -a', 'rev ~/.bash*']
for cmd in commands:
self.client.remote_execute(cmd)
self.mock_chan.exec_command.assert_called_with(cmd)
def test_remote_execute_no_exit_code(self):
self.client._handle_tty_required = mock.Mock(return_value=False)
self.mock_chan.recv_exit_status.return_value = 0
actual_output = self.client.remote_execute(self.example_command)
expected_output = {'stdout': self.example_output,
'stderr': ''}
self.assertEqual(expected_output, actual_output)
def test_remote_execute_with_exit_code(self):
self.client._handle_tty_required = mock.Mock(return_value=False)
self.mock_chan.recv_exit_status.return_value = 0
actual_output = self.client.remote_execute(
self.example_command, with_exit_code=True)
expected_output = {'stdout': self.example_output,
'stderr': '',
'exit_code': 0}
self.assertEqual(expected_output, actual_output)
def test_remote_execute_tty_required(self):
for i, substring in enumerate(ssh.TTY_REQUIRED):
self.mock_chan.makefile.side_effect = lambda x: self.mkfile(
x, stdoutput="xyz" + substring + "zyx")
self.assertRaises(
errors.GetPTYRetryFailure,
self.client.remote_execute,
'sudo echo_hello')
self.assertEqual(i + 1, self.mock_chan.get_pty.call_count)
def test_get_platform_info(self):
platinfo = ['Ubuntu', '12.04', 'precise', 'x86_64']
fields = ['dist', 'version', 'remove', 'arch']
expected_result = dict(zip(fields, [v.lower() for v in platinfo]))
expected_result.pop('remove')
self.mock_chan.makefile.side_effect = lambda x: self.mkfile(
x, stdoutput=str(platinfo))
self.assertEqual(expected_result, self.client.platform_info)
self.assertEqual(expected_result, self.client.platform_info)
class TestRemoteExecuteWithProxy(TestSSHKeys):
def setUp(self):
super(TestRemoteExecuteWithProxy, self).setUp()
paramiko.ProxyCommand = mock.MagicMock()
paramiko.SSHClient.connect = mock.MagicMock()
proxy = ssh.SSH('proxy-address', username='proxy-user')
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=proxy)
self.client._handle_tty_required = mock.Mock(return_value=False)
self.client._handle_password_prompt = mock.Mock(return_value=False)
self.mock_chan = mock.MagicMock()
mock_transport = mock.MagicMock()
mock_transport.open_session.return_value = self.mock_chan
self.client.get_transport = mock.MagicMock(
return_value=mock_transport)
self.mock_chan.exec_command = mock.MagicMock()
self.mock_chan.makefile.side_effect = self.mkfile
self.mock_chan.makefile_stderr.side_effect = (
lambda x: self.mkfile(x, err=True))
self.example_command = 'echo hello'
self.example_output = 'hello'
def mkfile(self, arg, err=False):
if arg == 'rb' and not err:
stdout = mock.MagicMock()
stdout.read.return_value = 'hello\n'
return stdout
if arg == 'wb' and not err:
stdin = mock.MagicMock()
stdin.read.return_value = ''
return stdin
if err is True:
stderr = mock.MagicMock()
stderr.read.return_value = ''
return stderr
def test_remote_execute_with_proxy(self):
commands = ['echo hello', 'uname -a', 'rev ~/.bash*']
for cmd in commands:
self.client.remote_execute(cmd)
self.mock_chan.exec_command.assert_called_with(cmd)
class TestProxy(TestSSHKeys):
"""self.client in this class is instantiated with a proxy."""
def setUp(self):
super(TestProxy, self).setUp()
paramiko.ProxyCommand = mock.MagicMock()
paramiko.SSHClient.connect = mock.MagicMock()
self.proxy = ssh.SSH('proxy.address', username='proxy-user')
def tearDown(self):
super(TestProxy, self).tearDown()
self.proxy.close()
def test_test_connection(self):
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
self.assertTrue(self.client.test_connection())
def test_test_connection_fail_invalid_key(self):
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
self.client.private_key = self.invalidkey
self.assertFalse(self.client.test_connection())
def test_test_connection_valid_key(self):
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
self.client.private_key = self.dsakey
self.assertTrue(self.client.test_connection())
def test_test_connection_fail_other(self):
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
paramiko.SSHClient.connect.side_effect = Exception
self.assertFalse(self.client.test_connection())
def test_connect_with_proxy_socket(self):
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
self.client.connect()
paramiko.ProxyCommand.assert_called_with(
'ssh -o ConnectTimeout=20 '
' -A proxy-user@proxy.address nc 123.456.789.0 22')
def test_connect_with_proxy_socket_and_options(self):
options = {
'BatchMode': 'yes',
'CheckHostIP': 'yes',
'ChallengeResponseAuthentication': 'no',
'Ciphers': 'cast128-cbc,aes256-ctr',
}
self.proxy = ssh.SSH(
'proxy.address', username='proxy-user', options=options)
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
self.client.connect()
paramiko.ProxyCommand.assert_called_with(
'ssh -o ConnectTimeout=20 -o BatchMode=yes '
'-o ChallengeResponseAuthentication=no '
'-o CheckHostIP=yes -o Ciphers=cast128-cbc,aes256-ctr '
'-A proxy-user@proxy.address nc 123.456.789.0 22')
def test_connect_with_proxy_socket_and_bool_options(self):
"""Should create the same ProxyCommand call as bool_options."""
options = {
'BatchMode': True,
'CheckHostIP': True,
'ChallengeResponseAuthentication': False,
'Ciphers': 'cast128-cbc,aes256-ctr',
}
self.proxy = ssh.SSH(
'proxy.address', username='proxy-user', options=options)
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
self.client.connect()
paramiko.ProxyCommand.assert_called_with(
'ssh -o ConnectTimeout=20 -o BatchMode=yes '
'-o ChallengeResponseAuthentication=no '
'-o CheckHostIP=yes -o Ciphers=cast128-cbc,aes256-ctr '
'-A proxy-user@proxy.address nc 123.456.789.0 22')
def test_connect_with_proxy_no_host_raises(self):
proxy = {'this': 'is not a proxy'}
self.assertRaises(
TypeError,
ssh.SSH, ('123.456.789.0',),
username='client-user', proxy=proxy)
def test_connect_with_proxy_socket_private_key(self):
"""Test when proxy inits with private key string.
Overrides self.client and self.proxy from setUp.
"""
self.proxy.private_key = self.rsakey
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
self.client.connect()
self.assertTrue(paramiko.ProxyCommand.called)
ident = ' -i '
begin = paramiko.ProxyCommand.call_args[0][0].index(ident)
end = paramiko.ProxyCommand.call_args[0][0].find(
' ', begin + len(ident))
tempfilepath = (paramiko.ProxyCommand.
call_args[0][0][begin + len(ident):end])
paramiko.ProxyCommand.assert_called_with(
'ssh -o ConnectTimeout=20 '
' -A proxy-user@proxy.address '
'-i %s nc 123.456.789.0 22' % tempfilepath)
class TestPasswordPrompt(utils.TestCase):
def setUp(self):
super(TestPasswordPrompt, self).setUp()
ssh.LOG = mock.MagicMock()
self.client = ssh.SSH('123.456.789.0', password='test_password')
self.stdout = mock.MagicMock()
self.stdin = mock.MagicMock()
def test_channel_closed(self):
"""If the channel is closed, there's no prompt."""
self.stdout.channel.closed = True
self.assertFalse(
self.client._handle_password_prompt(self.stdin, self.stdout))
def test_password_prompt_buflen_too_short(self):
"""Stdout chan buflen is too short to be a password prompt."""
self.stdout.channel.closed = False
self.stdout.channel.in_buffer = "a" * (ssh.MIN_PASSWORD_PROMPT_LEN - 1)
self.assertFalse(
self.client._handle_password_prompt(self.stdin, self.stdout))
def test_password_prompt_buflen_too_long(self):
"""Stdout chan buflen is too long to be a password prompt."""
self.stdout.channel.closed = False
self.stdout.channel.in_buffer = "a" * (ssh.MAX_PASSWORD_PROMPT_LEN + 1)
self.assertFalse(
self.client._handle_password_prompt(self.stdin, self.stdout))
def test_common_password_prompt(self):
"""Ensure that a couple commonly seen prompts have success."""
self.stdout.channel.closed = False
self.stdout.channel.in_buffer = "[sudo] password for user:"
self.stdout.channel.recv.return_value = self.stdout.channel.in_buffer
self.assertTrue(
self.client._handle_password_prompt(self.stdin, self.stdout))
self.stdout.channel.in_buffer = "Password:"
self.stdout.channel.recv.return_value = self.stdout.channel.in_buffer
self.assertTrue(
self.client._handle_password_prompt(self.stdin, self.stdout))
def test_password_prompt_other_prompt(self):
"""Pass buflen check, fail on substring check."""
self.stdout.channel.closed = False
self.stdout.channel.in_buffer = "Welcome to <hostname>:"
self.stdout.channel.recv.return_value = self.stdout.channel.in_buffer
self.assertFalse(
self.client._handle_password_prompt(self.stdin, self.stdout))
def test_logging_encountered_prompt(self):
self.stdout.channel.closed = False
self.stdout.channel.in_buffer = "[sudo] password for user:"
self.stdout.channel.recv.return_value = self.stdout.channel.in_buffer
self.client._handle_password_prompt(self.stdin, self.stdout)
ssh.LOG.warning.assert_called_with(
'%s@%s encountered prompt! of length [%s] {%s}', "root",
'123.456.789.0', 25, '[sudo] password for user:')
def test_logging_nearly_false_positive(self):
"""Assert that a close-call on a false-positive logs a warning."""
other_prompt = "Welcome to <hostname>:"
self.stdout.channel.closed = False
self.stdout.channel.in_buffer = other_prompt
self.stdout.channel.recv.return_value = self.stdout.channel.in_buffer
self.client._handle_password_prompt(self.stdin, self.stdout)
ssh.LOG.warning.assert_called_with(
'Nearly a False-Positive on password prompt detection. [%s] {%s}',
22, other_prompt)
def test_password_given_to_prompt(self):
self.stdout.channel.closed = False
self.stdout.channel.in_buffer = "[sudo] password for user:"
self.stdout.channel.recv.return_value = self.stdout.channel.in_buffer
self.client._handle_password_prompt(self.stdin, self.stdout)
self.stdin.write.assert_called_with(self.client.password + '\n')
def test_password_given_returns_true(self):
self.stdout.channel.closed = False
self.stdout.channel.in_buffer = "[sudo] password for user:"
self.stdout.channel.recv.return_value = self.stdout.channel.in_buffer
self.assertTrue(
self.client._handle_password_prompt(self.stdin, self.stdout))
if __name__ == "__main__":
unittest.main()