reimplement ssh proxying using ssh transport
Removes fragile, bulky code and provides a more robust solution, as well as negating the depedency on a local ssh program. Coincident updates: - renames keyword argument 'proxy' to 'gateway' and updates close() method to handle the closure of that connection as well - renames underlying RemoteShell client to _client attribute - uses this attribute to access methods Change-Id: Ie1ce9f19fbe5bb4341fb6721e3069c1d267be95f Implements: blueprint reimplement-ssh-proxy
This commit is contained in:
parent
c382a3e860
commit
dca1d8d28a
|
@ -111,7 +111,6 @@ class LocalShell(ShellMixin):
|
|||
self.user = user
|
||||
self.password = password
|
||||
self.interactive = interactive
|
||||
# TODO(samstav): Implement handle_password_prompt for popen
|
||||
|
||||
# properties
|
||||
self._platform_info = None
|
||||
|
@ -157,7 +156,10 @@ class RemoteShell(ShellMixin):
|
|||
|
||||
"""Execute shell commands on a remote machine over ssh."""
|
||||
|
||||
def __init__(self, address, **kwargs):
|
||||
def __init__(self, address, password=None, username=None,
|
||||
private_key=None, key_filename=None, port=None,
|
||||
timeout=None, gateway=None, options=None, interactive=False,
|
||||
**kwargs):
|
||||
"""An interface for executing shell commands on remote machines.
|
||||
|
||||
:param str host: The ip address or host name of the server
|
||||
|
@ -171,7 +173,7 @@ class RemoteShell(ShellMixin):
|
|||
:param port: tcp/ip port to use (defaults to 22)
|
||||
:param float timeout: an optional timeout (in seconds) for the
|
||||
TCP connection
|
||||
:param socket proxy: an existing SSH instance to use
|
||||
:param socket gateway: an existing SSH instance to use
|
||||
for proxying
|
||||
:param dict options: A dictionary used to set ssh options
|
||||
(when proxying).
|
||||
|
@ -183,16 +185,35 @@ class RemoteShell(ShellMixin):
|
|||
is equivalent.
|
||||
:keyword interactive: If true, prompt for password if missing.
|
||||
"""
|
||||
self.sshclient = ssh.connect(address, **kwargs)
|
||||
self.host = self.sshclient.host
|
||||
self.port = self.sshclient.port
|
||||
if kwargs:
|
||||
LOG.warning("Satori RemoteClient received unrecognized "
|
||||
"keyword arguments: %s", kwargs.keys())
|
||||
|
||||
self._client = ssh.connect(
|
||||
address, password=password, username=username,
|
||||
private_key=private_key, key_filename=key_filename, port=port,
|
||||
timeout=timeout, gateway=gateway, options=options,
|
||||
interactive=interactive)
|
||||
self.host = self._client.host
|
||||
self.port = self._client.port
|
||||
|
||||
@property
|
||||
def platform_info(self):
|
||||
"""Return distro, version, architecture."""
|
||||
return self.sshclient.platform_info
|
||||
return self._client.platform_info
|
||||
|
||||
def execute(self, command, wd=None, with_exit_code=None):
|
||||
def connect(self):
|
||||
"""Connect to the remote host."""
|
||||
return self._client.connect()
|
||||
|
||||
def test_connection(self):
|
||||
"""Test the connection to the remote host."""
|
||||
return self._client.test_connection()
|
||||
|
||||
def execute(self, command, **kwargs):
|
||||
"""Execute given command over ssh."""
|
||||
return self.sshclient.remote_execute(
|
||||
command, wd=wd, with_exit_code=with_exit_code)
|
||||
return self._client.remote_execute(command, **kwargs)
|
||||
|
||||
def close(self):
|
||||
"""Close the connection to the remote host."""
|
||||
return self._client.close()
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
"""SSH Module for connecting to and automating remote commands.
|
||||
|
||||
Supports proxying, as in `ssh -A`
|
||||
Supports proxying through an ssh tunnel ('gateway' keyword argument.)
|
||||
|
||||
To control the behavior of the SSH client, use the specific connect_with_*
|
||||
calls. The .connect() call behaves like the ssh command and attempts a number
|
||||
|
@ -31,7 +31,6 @@ import getpass
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import paramiko
|
||||
|
@ -102,7 +101,7 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
|
|||
# pylint: disable=R0913
|
||||
def __init__(self, host, password=None, username="root",
|
||||
private_key=None, key_filename=None, port=22,
|
||||
timeout=20, proxy=None, options=None, interactive=False):
|
||||
timeout=20, gateway=None, options=None, interactive=False):
|
||||
"""Create an instance of the SSH class.
|
||||
|
||||
:param str host: The ip address or host name of the server
|
||||
|
@ -116,7 +115,7 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
|
|||
:param port: tcp/ip port to use (defaults to 22)
|
||||
:param float timeout: an optional timeout (in seconds) for the
|
||||
TCP connection
|
||||
:param socket proxy: an existing SSH instance to use
|
||||
:param socket gateway: an existing SSH instance to use
|
||||
for proxying
|
||||
:param dict options: A dictionary used to set ssh options
|
||||
(when proxying).
|
||||
|
@ -137,13 +136,13 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
|
|||
self.timeout = timeout
|
||||
self._platform_info = None
|
||||
self.options = options or {}
|
||||
self.proxy = proxy
|
||||
self.gateway = gateway
|
||||
self.sock = None
|
||||
self.interactive = interactive
|
||||
|
||||
if self.proxy:
|
||||
if not isinstance(self.proxy, SSH):
|
||||
raise TypeError("'proxy' must be a satori.ssh.SSH instance. "
|
||||
if self.gateway:
|
||||
if not isinstance(self.gateway, SSH):
|
||||
raise TypeError("'gateway' must be a satori.ssh.SSH instance. "
|
||||
"( instances of this type are returned by "
|
||||
"satori.ssh.connect() )")
|
||||
|
||||
|
@ -224,13 +223,16 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
|
|||
"""Set up client and connect to target."""
|
||||
self.load_system_host_keys()
|
||||
|
||||
if self.proxy:
|
||||
# lazy load
|
||||
self.sock = self._get_proxy_socket(self.proxy)
|
||||
|
||||
if self.options.get('StrictHostKeyChecking') in (False, "no"):
|
||||
self.set_missing_host_key_policy(AcceptMissingHostKey())
|
||||
|
||||
if self.gateway:
|
||||
# lazy load
|
||||
if not self.gateway.get_transport():
|
||||
self.gateway.connect()
|
||||
self.sock = self.gateway.get_transport().open_channel(
|
||||
'direct-tcpip', (self.host, self.port), ('', 0))
|
||||
|
||||
return super(SSH, self).connect(
|
||||
self.host,
|
||||
timeout=kwargs.pop('timeout', self.timeout),
|
||||
|
@ -250,6 +252,11 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
|
|||
- username/password (will prompt if the password is not supplied and
|
||||
interactive is true)
|
||||
"""
|
||||
# idempotency
|
||||
if self.get_transport():
|
||||
if self.get_transport().is_active():
|
||||
return
|
||||
|
||||
if self.private_key:
|
||||
try:
|
||||
return self.connect_with_key()
|
||||
|
@ -305,6 +312,15 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
|
|||
finally:
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
"""Close the connection to the remote host.
|
||||
|
||||
If an ssh tunnel is being used, close that first.
|
||||
"""
|
||||
if self.gateway:
|
||||
self.gateway.close()
|
||||
return super(SSH, self).close()
|
||||
|
||||
def _handle_tty_required(self, results, get_pty):
|
||||
"""Determine whether the result implies a tty request."""
|
||||
if any(m in str(k) for m in TTY_REQUIRED for k in results.values()):
|
||||
|
@ -417,60 +433,6 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
|
|||
finally:
|
||||
self.close()
|
||||
|
||||
def _get_proxy_socket(self, proxy):
|
||||
"""Return a wrapped subprocess running ProxyCommand-driven programs.
|
||||
|
||||
Create a new CommandProxy instance.
|
||||
Can be created from an existing SSH instance.
|
||||
For proxy clients, please specify a private key filename.
|
||||
|
||||
To use an ssh proxy, you must use an SSH Key,
|
||||
since a ProxyCommand cannot be passed a password.
|
||||
"""
|
||||
if proxy.password:
|
||||
LOG.warning("Proxying through a client which is authorized by "
|
||||
"a password is not currently implemented. Please "
|
||||
"use an ssh key.")
|
||||
|
||||
proxy.load_system_host_keys()
|
||||
if proxy.options.get('StrictHostKeyChecking') in (False, "no"):
|
||||
proxy.set_missing_host_key_policy(AcceptMissingHostKey())
|
||||
|
||||
if proxy.private_key and not proxy.key_filename:
|
||||
tempkeyfile = tempfile.NamedTemporaryFile(
|
||||
mode='w+', prefix=TEMPFILE_PREFIX,
|
||||
dir=os.path.expanduser('~/'), delete=True)
|
||||
tempkeyfile.write(proxy.private_key)
|
||||
proxy.key_filename = tempkeyfile.name
|
||||
|
||||
pxd = {
|
||||
'bastion': proxy.host,
|
||||
'user': proxy.username,
|
||||
'port': '-p %s' % proxy.port,
|
||||
'options': ('-o ConnectTimeout=%s ' % proxy.timeout),
|
||||
'target_host': self.host,
|
||||
'target_port': self.port,
|
||||
}
|
||||
proxycommand = "ssh {options} -A {user}@{bastion} "
|
||||
|
||||
if proxy.key_filename:
|
||||
proxy.key_filename = os.path.expanduser(proxy.key_filename)
|
||||
proxy.key_filename = os.path.abspath(proxy.key_filename)
|
||||
pxd.update({'identity': '-i %s' % proxy.key_filename})
|
||||
proxycommand += "{identity} "
|
||||
|
||||
if proxy.options:
|
||||
for key, val in sorted(proxy.options.items()):
|
||||
if isinstance(val, bool):
|
||||
# turns booleans into `ssh -o` compat "yes" or "no"
|
||||
if val is True:
|
||||
val = "yes"
|
||||
if val is False:
|
||||
val = "no"
|
||||
pxd['options'] += '-o %s=%s ' % (key, val)
|
||||
|
||||
proxycommand += "nc {target_host} {target_port}"
|
||||
return paramiko.ProxyCommand(proxycommand.format(**pxd))
|
||||
|
||||
# Share SSH.__init__'s docstring
|
||||
connect.__doc__ = SSH.__init__.__doc__
|
||||
|
|
|
@ -125,7 +125,7 @@ class TestRemoteShell(TestBashModule):
|
|||
def test_execute(self):
|
||||
self.remoteshell.execute(self.testrun.command)
|
||||
self.mock_execute.assert_called_once_with(
|
||||
self.testrun.command, wd=None, with_exit_code=None)
|
||||
self.testrun.command)
|
||||
|
||||
def test_execute_resultdict(self):
|
||||
resultdict = self.remoteshell.execute(self.testrun.command)
|
||||
|
|
|
@ -550,54 +550,6 @@ class TestTestConnection(SSHTestBase):
|
|||
'ssh://%s@%s:%d is up.', 'test-user', self.host, 22)
|
||||
|
||||
|
||||
class TestGetProxySocket(SSHTestBase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestGetProxySocket, self).setUp()
|
||||
self.proxy_patcher = mock.patch.object(paramiko, "ProxyCommand")
|
||||
self.proxy_patcher.start()
|
||||
self.proxy = ssh.SSH('proxy.address', username='proxy-user')
|
||||
self.client = ssh.SSH('123.546.789.0', username='client-user')
|
||||
self.mutable = [True]
|
||||
|
||||
def tearDown(self):
|
||||
self.proxy_patcher.stop()
|
||||
super(TestGetProxySocket, self).tearDown()
|
||||
|
||||
def test_get_proxy_socket(self):
|
||||
self.client._get_proxy_socket(self.proxy)
|
||||
paramiko.ProxyCommand.assert_called_once_with(
|
||||
'ssh -o ConnectTimeout=20 '
|
||||
' -A proxy-user@proxy.address nc 123.546.789.0 22')
|
||||
|
||||
def test_get_proxy_socket_private_key(self):
|
||||
self.proxy.private_key = self.rsakey
|
||||
self.client._get_proxy_socket(self.proxy)
|
||||
self.assertTrue(paramiko.ProxyCommand.called)
|
||||
|
||||
def tempfile_spotted(self):
|
||||
home = os.path.expanduser('~/')
|
||||
filist = [k for k in os.listdir(home)
|
||||
if k.startswith(ssh.TEMPFILE_PREFIX)]
|
||||
while all(self.mutable):
|
||||
new = [k for k in os.listdir(home)
|
||||
if k.startswith(ssh.TEMPFILE_PREFIX)]
|
||||
if len(new) > len(filist):
|
||||
return set(new).difference(set(filist)).pop()
|
||||
return False
|
||||
|
||||
def test_get_proxy_file_unseen(self):
|
||||
|
||||
self.proxy.key_filename = "~/not/a/real/path"
|
||||
expanded_path = os.path.expanduser(self.proxy.key_filename)
|
||||
self.client._get_proxy_socket(self.proxy)
|
||||
|
||||
paramiko.ProxyCommand.assert_called_once_with(
|
||||
'ssh -o ConnectTimeout=20 '
|
||||
' -A proxy-user@proxy.address -i %s '
|
||||
'nc 123.546.789.0 22' % expanded_path)
|
||||
|
||||
|
||||
class TestRemoteExecute(SSHTestBase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -685,160 +637,29 @@ class TestRemoteExecute(SSHTestBase):
|
|||
self.assertEqual(expected_result, self.client.platform_info)
|
||||
|
||||
|
||||
class TestRemoteExecuteWithProxy(SSHTestBase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestRemoteExecuteWithProxy, self).setUp()
|
||||
self.proxy_patcher = mock.patch.object(paramiko, "ProxyCommand")
|
||||
self.proxy_patcher.start()
|
||||
proxy = ssh.SSH('proxy-address', username='proxy-user')
|
||||
self.client = ssh.SSH(
|
||||
'123.456.789.0', username='client-user', proxy=proxy)
|
||||
self.client._handle_tty_required = mock.Mock(return_value=False)
|
||||
self.client._handle_password_prompt = mock.Mock(return_value=False)
|
||||
|
||||
self.mock_chan = mock.MagicMock()
|
||||
mock_transport = mock.MagicMock()
|
||||
mock_transport.open_session.return_value = self.mock_chan
|
||||
self.client.get_transport = mock.MagicMock(
|
||||
return_value=mock_transport)
|
||||
|
||||
self.mock_chan.exec_command = mock.MagicMock()
|
||||
self.mock_chan.makefile.side_effect = self.mkfile
|
||||
self.mock_chan.makefile_stderr.side_effect = (
|
||||
lambda x: self.mkfile(x, err=True))
|
||||
|
||||
self.example_command = 'echo hello'
|
||||
self.example_output = 'hello'
|
||||
|
||||
def tearDown(self):
|
||||
self.proxy_patcher.stop()
|
||||
super(TestRemoteExecuteWithProxy, self).tearDown()
|
||||
|
||||
def mkfile(self, arg, err=False):
|
||||
if arg == 'rb' and not err:
|
||||
stdout = mock.MagicMock()
|
||||
stdout.read.return_value = 'hello\n'
|
||||
return stdout
|
||||
if arg == 'wb' and not err:
|
||||
stdin = mock.MagicMock()
|
||||
stdin.read.return_value = ''
|
||||
return stdin
|
||||
if err is True:
|
||||
stderr = mock.MagicMock()
|
||||
stderr.read.return_value = ''
|
||||
return stderr
|
||||
|
||||
def test_remote_execute_with_proxy(self):
|
||||
commands = ['echo hello', 'uname -a', 'rev ~/.bash*']
|
||||
for cmd in commands:
|
||||
self.client.remote_execute(cmd)
|
||||
self.mock_chan.exec_command.assert_called_with(cmd)
|
||||
|
||||
|
||||
class TestProxy(SSHTestBase):
|
||||
|
||||
"""self.client in this class is instantiated with a proxy."""
|
||||
|
||||
def setUp(self):
|
||||
super(TestProxy, self).setUp()
|
||||
self.proxy_patcher = mock.patch.object(paramiko, "ProxyCommand")
|
||||
self.proxy_patcher.start()
|
||||
self.proxy = ssh.SSH('proxy.address', username='proxy-user')
|
||||
self.gateway = ssh.SSH('gateway.address', username='gateway-user')
|
||||
|
||||
def tearDown(self):
|
||||
self.proxy_patcher.stop()
|
||||
super(TestProxy, self).tearDown()
|
||||
|
||||
def test_test_connection(self):
|
||||
self.client = ssh.SSH(
|
||||
'123.456.789.0', username='client-user', proxy=self.proxy)
|
||||
self.assertTrue(self.client.test_connection())
|
||||
|
||||
def test_test_connection_valid_key(self):
|
||||
self.client = ssh.SSH(
|
||||
'123.456.789.0', username='client-user', proxy=self.proxy)
|
||||
self.client.private_key = self.dsakey
|
||||
self.assertTrue(self.client.test_connection())
|
||||
|
||||
def test_test_connection_fail_other(self):
|
||||
self.client = ssh.SSH(
|
||||
'123.456.789.0', username='client-user', proxy=self.proxy)
|
||||
'123.456.789.0', username='client-user', gateway=self.gateway)
|
||||
self.mock_connect.side_effect = Exception
|
||||
self.assertFalse(self.client.test_connection())
|
||||
|
||||
def test_connect_with_proxy_socket(self):
|
||||
self.client = ssh.SSH(
|
||||
'123.456.789.0', username='client-user', proxy=self.proxy)
|
||||
self.client.connect()
|
||||
paramiko.ProxyCommand.assert_called_with(
|
||||
'ssh -o ConnectTimeout=20 '
|
||||
' -A proxy-user@proxy.address nc 123.456.789.0 22')
|
||||
|
||||
def test_connect_with_proxy_socket_and_options(self):
|
||||
options = {
|
||||
'BatchMode': 'yes',
|
||||
'CheckHostIP': 'yes',
|
||||
'ChallengeResponseAuthentication': 'no',
|
||||
'Ciphers': 'cast128-cbc,aes256-ctr',
|
||||
}
|
||||
self.proxy = ssh.SSH(
|
||||
'proxy.address', username='proxy-user', options=options)
|
||||
self.client = ssh.SSH(
|
||||
'123.456.789.0', username='client-user', proxy=self.proxy)
|
||||
self.client.connect()
|
||||
paramiko.ProxyCommand.assert_called_with(
|
||||
'ssh -o ConnectTimeout=20 -o BatchMode=yes '
|
||||
'-o ChallengeResponseAuthentication=no '
|
||||
'-o CheckHostIP=yes -o Ciphers=cast128-cbc,aes256-ctr '
|
||||
'-A proxy-user@proxy.address nc 123.456.789.0 22')
|
||||
|
||||
def test_connect_with_proxy_socket_and_bool_options(self):
|
||||
"""Should create the same ProxyCommand call as bool_options."""
|
||||
options = {
|
||||
'BatchMode': True,
|
||||
'CheckHostIP': True,
|
||||
'ChallengeResponseAuthentication': False,
|
||||
'Ciphers': 'cast128-cbc,aes256-ctr',
|
||||
}
|
||||
self.proxy = ssh.SSH(
|
||||
'proxy.address', username='proxy-user', options=options)
|
||||
self.client = ssh.SSH(
|
||||
'123.456.789.0', username='client-user', proxy=self.proxy)
|
||||
self.client.connect()
|
||||
paramiko.ProxyCommand.assert_called_with(
|
||||
'ssh -o ConnectTimeout=20 -o BatchMode=yes '
|
||||
'-o ChallengeResponseAuthentication=no '
|
||||
'-o CheckHostIP=yes -o Ciphers=cast128-cbc,aes256-ctr '
|
||||
'-A proxy-user@proxy.address nc 123.456.789.0 22')
|
||||
|
||||
def test_connect_with_proxy_no_host_raises(self):
|
||||
proxy = {'this': 'is not a proxy'}
|
||||
gateway = {'this': 'is not a gateway'}
|
||||
self.assertRaises(
|
||||
TypeError,
|
||||
ssh.SSH, ('123.456.789.0',),
|
||||
username='client-user', proxy=proxy)
|
||||
|
||||
def test_connect_with_proxy_socket_private_key(self):
|
||||
"""Test when proxy inits with private key string.
|
||||
|
||||
Overrides self.client and self.proxy from setUp.
|
||||
"""
|
||||
self.proxy.private_key = self.rsakey
|
||||
self.client = ssh.SSH(
|
||||
'123.456.789.0', username='client-user', proxy=self.proxy)
|
||||
self.client.connect()
|
||||
self.assertTrue(paramiko.ProxyCommand.called)
|
||||
ident = ' -i '
|
||||
begin = paramiko.ProxyCommand.call_args[0][0].index(ident)
|
||||
end = paramiko.ProxyCommand.call_args[0][0].find(
|
||||
' ', begin + len(ident))
|
||||
tempfilepath = (paramiko.ProxyCommand.
|
||||
call_args[0][0][begin + len(ident):end])
|
||||
paramiko.ProxyCommand.assert_called_with(
|
||||
'ssh -o ConnectTimeout=20 '
|
||||
' -A proxy-user@proxy.address '
|
||||
'-i %s nc 123.456.789.0 22' % tempfilepath)
|
||||
username='client-user', gateway=gateway)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue