reimplement ssh proxying using ssh transport

Removes fragile, bulky code and provides a more robust solution,
as well as negating the depedency on a local ssh program.

Coincident updates:
- renames keyword argument 'proxy' to 'gateway' and updates
  close() method to handle the closure of that connection as well
- renames underlying RemoteShell client to _client attribute
    - uses this attribute to access methods

Change-Id: Ie1ce9f19fbe5bb4341fb6721e3069c1d267be95f
Implements: blueprint reimplement-ssh-proxy
This commit is contained in:
Samuel Stavinoha 2014-05-23 22:14:06 +00:00
parent c382a3e860
commit dca1d8d28a
4 changed files with 64 additions and 260 deletions

View File

@ -111,7 +111,6 @@ class LocalShell(ShellMixin):
self.user = user
self.password = password
self.interactive = interactive
# TODO(samstav): Implement handle_password_prompt for popen
# properties
self._platform_info = None
@ -157,7 +156,10 @@ class RemoteShell(ShellMixin):
"""Execute shell commands on a remote machine over ssh."""
def __init__(self, address, **kwargs):
def __init__(self, address, password=None, username=None,
private_key=None, key_filename=None, port=None,
timeout=None, gateway=None, options=None, interactive=False,
**kwargs):
"""An interface for executing shell commands on remote machines.
:param str host: The ip address or host name of the server
@ -171,7 +173,7 @@ class RemoteShell(ShellMixin):
:param port: tcp/ip port to use (defaults to 22)
:param float timeout: an optional timeout (in seconds) for the
TCP connection
:param socket proxy: an existing SSH instance to use
:param socket gateway: an existing SSH instance to use
for proxying
:param dict options: A dictionary used to set ssh options
(when proxying).
@ -183,16 +185,35 @@ class RemoteShell(ShellMixin):
is equivalent.
:keyword interactive: If true, prompt for password if missing.
"""
self.sshclient = ssh.connect(address, **kwargs)
self.host = self.sshclient.host
self.port = self.sshclient.port
if kwargs:
LOG.warning("Satori RemoteClient received unrecognized "
"keyword arguments: %s", kwargs.keys())
self._client = ssh.connect(
address, password=password, username=username,
private_key=private_key, key_filename=key_filename, port=port,
timeout=timeout, gateway=gateway, options=options,
interactive=interactive)
self.host = self._client.host
self.port = self._client.port
@property
def platform_info(self):
"""Return distro, version, architecture."""
return self.sshclient.platform_info
return self._client.platform_info
def execute(self, command, wd=None, with_exit_code=None):
def connect(self):
"""Connect to the remote host."""
return self._client.connect()
def test_connection(self):
"""Test the connection to the remote host."""
return self._client.test_connection()
def execute(self, command, **kwargs):
"""Execute given command over ssh."""
return self.sshclient.remote_execute(
command, wd=wd, with_exit_code=with_exit_code)
return self._client.remote_execute(command, **kwargs)
def close(self):
"""Close the connection to the remote host."""
return self._client.close()

View File

@ -12,7 +12,7 @@
"""SSH Module for connecting to and automating remote commands.
Supports proxying, as in `ssh -A`
Supports proxying through an ssh tunnel ('gateway' keyword argument.)
To control the behavior of the SSH client, use the specific connect_with_*
calls. The .connect() call behaves like the ssh command and attempts a number
@ -31,7 +31,6 @@ import getpass
import logging
import os
import re
import tempfile
import time
import paramiko
@ -102,7 +101,7 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
# pylint: disable=R0913
def __init__(self, host, password=None, username="root",
private_key=None, key_filename=None, port=22,
timeout=20, proxy=None, options=None, interactive=False):
timeout=20, gateway=None, options=None, interactive=False):
"""Create an instance of the SSH class.
:param str host: The ip address or host name of the server
@ -116,7 +115,7 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
:param port: tcp/ip port to use (defaults to 22)
:param float timeout: an optional timeout (in seconds) for the
TCP connection
:param socket proxy: an existing SSH instance to use
:param socket gateway: an existing SSH instance to use
for proxying
:param dict options: A dictionary used to set ssh options
(when proxying).
@ -137,13 +136,13 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
self.timeout = timeout
self._platform_info = None
self.options = options or {}
self.proxy = proxy
self.gateway = gateway
self.sock = None
self.interactive = interactive
if self.proxy:
if not isinstance(self.proxy, SSH):
raise TypeError("'proxy' must be a satori.ssh.SSH instance. "
if self.gateway:
if not isinstance(self.gateway, SSH):
raise TypeError("'gateway' must be a satori.ssh.SSH instance. "
"( instances of this type are returned by "
"satori.ssh.connect() )")
@ -224,13 +223,16 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
"""Set up client and connect to target."""
self.load_system_host_keys()
if self.proxy:
# lazy load
self.sock = self._get_proxy_socket(self.proxy)
if self.options.get('StrictHostKeyChecking') in (False, "no"):
self.set_missing_host_key_policy(AcceptMissingHostKey())
if self.gateway:
# lazy load
if not self.gateway.get_transport():
self.gateway.connect()
self.sock = self.gateway.get_transport().open_channel(
'direct-tcpip', (self.host, self.port), ('', 0))
return super(SSH, self).connect(
self.host,
timeout=kwargs.pop('timeout', self.timeout),
@ -250,6 +252,11 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
- username/password (will prompt if the password is not supplied and
interactive is true)
"""
# idempotency
if self.get_transport():
if self.get_transport().is_active():
return
if self.private_key:
try:
return self.connect_with_key()
@ -305,6 +312,15 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
finally:
self.close()
def close(self):
"""Close the connection to the remote host.
If an ssh tunnel is being used, close that first.
"""
if self.gateway:
self.gateway.close()
return super(SSH, self).close()
def _handle_tty_required(self, results, get_pty):
"""Determine whether the result implies a tty request."""
if any(m in str(k) for m in TTY_REQUIRED for k in results.values()):
@ -417,60 +433,6 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
finally:
self.close()
def _get_proxy_socket(self, proxy):
"""Return a wrapped subprocess running ProxyCommand-driven programs.
Create a new CommandProxy instance.
Can be created from an existing SSH instance.
For proxy clients, please specify a private key filename.
To use an ssh proxy, you must use an SSH Key,
since a ProxyCommand cannot be passed a password.
"""
if proxy.password:
LOG.warning("Proxying through a client which is authorized by "
"a password is not currently implemented. Please "
"use an ssh key.")
proxy.load_system_host_keys()
if proxy.options.get('StrictHostKeyChecking') in (False, "no"):
proxy.set_missing_host_key_policy(AcceptMissingHostKey())
if proxy.private_key and not proxy.key_filename:
tempkeyfile = tempfile.NamedTemporaryFile(
mode='w+', prefix=TEMPFILE_PREFIX,
dir=os.path.expanduser('~/'), delete=True)
tempkeyfile.write(proxy.private_key)
proxy.key_filename = tempkeyfile.name
pxd = {
'bastion': proxy.host,
'user': proxy.username,
'port': '-p %s' % proxy.port,
'options': ('-o ConnectTimeout=%s ' % proxy.timeout),
'target_host': self.host,
'target_port': self.port,
}
proxycommand = "ssh {options} -A {user}@{bastion} "
if proxy.key_filename:
proxy.key_filename = os.path.expanduser(proxy.key_filename)
proxy.key_filename = os.path.abspath(proxy.key_filename)
pxd.update({'identity': '-i %s' % proxy.key_filename})
proxycommand += "{identity} "
if proxy.options:
for key, val in sorted(proxy.options.items()):
if isinstance(val, bool):
# turns booleans into `ssh -o` compat "yes" or "no"
if val is True:
val = "yes"
if val is False:
val = "no"
pxd['options'] += '-o %s=%s ' % (key, val)
proxycommand += "nc {target_host} {target_port}"
return paramiko.ProxyCommand(proxycommand.format(**pxd))
# Share SSH.__init__'s docstring
connect.__doc__ = SSH.__init__.__doc__

View File

@ -125,7 +125,7 @@ class TestRemoteShell(TestBashModule):
def test_execute(self):
self.remoteshell.execute(self.testrun.command)
self.mock_execute.assert_called_once_with(
self.testrun.command, wd=None, with_exit_code=None)
self.testrun.command)
def test_execute_resultdict(self):
resultdict = self.remoteshell.execute(self.testrun.command)

View File

@ -550,54 +550,6 @@ class TestTestConnection(SSHTestBase):
'ssh://%s@%s:%d is up.', 'test-user', self.host, 22)
class TestGetProxySocket(SSHTestBase):
def setUp(self):
super(TestGetProxySocket, self).setUp()
self.proxy_patcher = mock.patch.object(paramiko, "ProxyCommand")
self.proxy_patcher.start()
self.proxy = ssh.SSH('proxy.address', username='proxy-user')
self.client = ssh.SSH('123.546.789.0', username='client-user')
self.mutable = [True]
def tearDown(self):
self.proxy_patcher.stop()
super(TestGetProxySocket, self).tearDown()
def test_get_proxy_socket(self):
self.client._get_proxy_socket(self.proxy)
paramiko.ProxyCommand.assert_called_once_with(
'ssh -o ConnectTimeout=20 '
' -A proxy-user@proxy.address nc 123.546.789.0 22')
def test_get_proxy_socket_private_key(self):
self.proxy.private_key = self.rsakey
self.client._get_proxy_socket(self.proxy)
self.assertTrue(paramiko.ProxyCommand.called)
def tempfile_spotted(self):
home = os.path.expanduser('~/')
filist = [k for k in os.listdir(home)
if k.startswith(ssh.TEMPFILE_PREFIX)]
while all(self.mutable):
new = [k for k in os.listdir(home)
if k.startswith(ssh.TEMPFILE_PREFIX)]
if len(new) > len(filist):
return set(new).difference(set(filist)).pop()
return False
def test_get_proxy_file_unseen(self):
self.proxy.key_filename = "~/not/a/real/path"
expanded_path = os.path.expanduser(self.proxy.key_filename)
self.client._get_proxy_socket(self.proxy)
paramiko.ProxyCommand.assert_called_once_with(
'ssh -o ConnectTimeout=20 '
' -A proxy-user@proxy.address -i %s '
'nc 123.546.789.0 22' % expanded_path)
class TestRemoteExecute(SSHTestBase):
def setUp(self):
@ -685,160 +637,29 @@ class TestRemoteExecute(SSHTestBase):
self.assertEqual(expected_result, self.client.platform_info)
class TestRemoteExecuteWithProxy(SSHTestBase):
def setUp(self):
super(TestRemoteExecuteWithProxy, self).setUp()
self.proxy_patcher = mock.patch.object(paramiko, "ProxyCommand")
self.proxy_patcher.start()
proxy = ssh.SSH('proxy-address', username='proxy-user')
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=proxy)
self.client._handle_tty_required = mock.Mock(return_value=False)
self.client._handle_password_prompt = mock.Mock(return_value=False)
self.mock_chan = mock.MagicMock()
mock_transport = mock.MagicMock()
mock_transport.open_session.return_value = self.mock_chan
self.client.get_transport = mock.MagicMock(
return_value=mock_transport)
self.mock_chan.exec_command = mock.MagicMock()
self.mock_chan.makefile.side_effect = self.mkfile
self.mock_chan.makefile_stderr.side_effect = (
lambda x: self.mkfile(x, err=True))
self.example_command = 'echo hello'
self.example_output = 'hello'
def tearDown(self):
self.proxy_patcher.stop()
super(TestRemoteExecuteWithProxy, self).tearDown()
def mkfile(self, arg, err=False):
if arg == 'rb' and not err:
stdout = mock.MagicMock()
stdout.read.return_value = 'hello\n'
return stdout
if arg == 'wb' and not err:
stdin = mock.MagicMock()
stdin.read.return_value = ''
return stdin
if err is True:
stderr = mock.MagicMock()
stderr.read.return_value = ''
return stderr
def test_remote_execute_with_proxy(self):
commands = ['echo hello', 'uname -a', 'rev ~/.bash*']
for cmd in commands:
self.client.remote_execute(cmd)
self.mock_chan.exec_command.assert_called_with(cmd)
class TestProxy(SSHTestBase):
"""self.client in this class is instantiated with a proxy."""
def setUp(self):
super(TestProxy, self).setUp()
self.proxy_patcher = mock.patch.object(paramiko, "ProxyCommand")
self.proxy_patcher.start()
self.proxy = ssh.SSH('proxy.address', username='proxy-user')
self.gateway = ssh.SSH('gateway.address', username='gateway-user')
def tearDown(self):
self.proxy_patcher.stop()
super(TestProxy, self).tearDown()
def test_test_connection(self):
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
self.assertTrue(self.client.test_connection())
def test_test_connection_valid_key(self):
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
self.client.private_key = self.dsakey
self.assertTrue(self.client.test_connection())
def test_test_connection_fail_other(self):
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
'123.456.789.0', username='client-user', gateway=self.gateway)
self.mock_connect.side_effect = Exception
self.assertFalse(self.client.test_connection())
def test_connect_with_proxy_socket(self):
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
self.client.connect()
paramiko.ProxyCommand.assert_called_with(
'ssh -o ConnectTimeout=20 '
' -A proxy-user@proxy.address nc 123.456.789.0 22')
def test_connect_with_proxy_socket_and_options(self):
options = {
'BatchMode': 'yes',
'CheckHostIP': 'yes',
'ChallengeResponseAuthentication': 'no',
'Ciphers': 'cast128-cbc,aes256-ctr',
}
self.proxy = ssh.SSH(
'proxy.address', username='proxy-user', options=options)
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
self.client.connect()
paramiko.ProxyCommand.assert_called_with(
'ssh -o ConnectTimeout=20 -o BatchMode=yes '
'-o ChallengeResponseAuthentication=no '
'-o CheckHostIP=yes -o Ciphers=cast128-cbc,aes256-ctr '
'-A proxy-user@proxy.address nc 123.456.789.0 22')
def test_connect_with_proxy_socket_and_bool_options(self):
"""Should create the same ProxyCommand call as bool_options."""
options = {
'BatchMode': True,
'CheckHostIP': True,
'ChallengeResponseAuthentication': False,
'Ciphers': 'cast128-cbc,aes256-ctr',
}
self.proxy = ssh.SSH(
'proxy.address', username='proxy-user', options=options)
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
self.client.connect()
paramiko.ProxyCommand.assert_called_with(
'ssh -o ConnectTimeout=20 -o BatchMode=yes '
'-o ChallengeResponseAuthentication=no '
'-o CheckHostIP=yes -o Ciphers=cast128-cbc,aes256-ctr '
'-A proxy-user@proxy.address nc 123.456.789.0 22')
def test_connect_with_proxy_no_host_raises(self):
proxy = {'this': 'is not a proxy'}
gateway = {'this': 'is not a gateway'}
self.assertRaises(
TypeError,
ssh.SSH, ('123.456.789.0',),
username='client-user', proxy=proxy)
def test_connect_with_proxy_socket_private_key(self):
"""Test when proxy inits with private key string.
Overrides self.client and self.proxy from setUp.
"""
self.proxy.private_key = self.rsakey
self.client = ssh.SSH(
'123.456.789.0', username='client-user', proxy=self.proxy)
self.client.connect()
self.assertTrue(paramiko.ProxyCommand.called)
ident = ' -i '
begin = paramiko.ProxyCommand.call_args[0][0].index(ident)
end = paramiko.ProxyCommand.call_args[0][0].find(
' ', begin + len(ident))
tempfilepath = (paramiko.ProxyCommand.
call_args[0][0][begin + len(ident):end])
paramiko.ProxyCommand.assert_called_with(
'ssh -o ConnectTimeout=20 '
' -A proxy-user@proxy.address '
'-i %s nc 123.456.789.0 22' % tempfilepath)
username='client-user', gateway=gateway)
if __name__ == "__main__":