Add support for remote unix socket forwarding

Change-Id: I04ca8083a9af5974d659698018f8a285d6c1173d
This commit is contained in:
Federico Ressi 2019-10-29 11:51:13 +01:00
parent 428b96e2da
commit 205e130f1b
5 changed files with 230 additions and 62 deletions

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from urllib3 import connection
from urllib3 import connectionpool
import tobiko
from tobiko.shell import ssh
@ -69,23 +68,11 @@ class HTTPConnectionPool(connectionpool.HTTPConnectionPool):
ConnectionCls = HTTPConnection
forwarder = None
ssh_client = None
def __init__(self, host, port, ssh_client=None, **kwargs):
if ssh_client is None:
ssh_client = ssh.ssh_proxy_client() or False
self.ssh_client = ssh_client
if ssh_client:
self.forwarder = forwarder = ssh.SSHTunnelForwarderFixture(
ssh_client=ssh_client)
forward_address = forwarder.put_forwarding(host, port)
tobiko.setup_fixture(forwarder)
kwargs['forward_address'] = forward_address
super(HTTPConnectionPool, self).__init__(host=host,
port=port,
**kwargs)
forward_address = ssh.get_forward_port_address(address=(host, port),
ssh_client=ssh_client)
super(HTTPConnectionPool, self).__init__(
host=host, port=port, forward_address=forward_address, **kwargs)
class HTTPSConnectionPool(HTTPConnectionPool,

View File

@ -31,4 +31,7 @@ ssh_proxy_client = _client.ssh_proxy_client
SSHConnectFailure = _client.SSHConnectFailure
gather_ssh_connect_parameters = _client.gather_ssh_connect_parameters
get_port_forward_url = _forward.get_forward_url
get_forward_port_address = _forward.get_forward_port_address
SSHTunnelForwarderFixture = _forward.SSHTunnelForwarderFixture
SSHTunnelForwarder = _forward.SSHTunnelForwarder

View File

@ -71,7 +71,7 @@ def positive_int(value):
return value
def key_filename(value):
def get_key_filename(value):
if isinstance(value, six.string_types):
value = [value]
return [os.path.realpath(os.path.expanduser(v)) for v in value]
@ -96,7 +96,7 @@ SSH_CONNECT_PARAMETERS = {
#: The filename, or list of filenames, of optional private key(s) and/or
#: certs to try for authentication
'key_filename': key_filename,
'key_filename': get_key_filename,
#: An optional timeout (in seconds) for the TCP connect
'timeout': positive_float,
@ -329,6 +329,34 @@ class SSHClientFixture(tobiko.SharedFixture):
def connect(self):
return tobiko.setup_fixture(self).client
def get_ssh_command(self, host=None, username=None, port=None,
command=None, config_files=None, host_config=None,
proxy_command=None, key_filename=None, **options):
connect_parameters = self.setup_connect_parameters()
host = host or connect_parameters.get('hostname')
username = username or connect_parameters.get('username')
port = port or connect_parameters.get('port')
config_files = config_files or connect_parameters.get('config_files')
if not host_config:
_host_config = self.setup_host_config()
if hasattr(_host_config, 'host_config'):
_host_config = host_config
key_filename = key_filename or connect_parameters.get('key_filename')
proxy_command = (proxy_command or
connect_parameters.get('proxy_command'))
if not proxy_command and self.proxy_client:
proxy_command = self.proxy_client.get_ssh_command()
return _command.ssh_command(host=host,
username=username,
port=port,
command=command,
config_files=config_files,
host_config=host_config,
proxy_command=proxy_command,
key_filename=key_filename,
**options)
UNDEFINED_CLIENT = 'UNDEFINED_CLIENT'
@ -425,7 +453,7 @@ def ssh_connect(hostname, username=None, port=None, connection_interval=None,
return client, proxy_sock
def ssh_proxy_sock(hostname, port=None, command=None, client=None,
def ssh_proxy_sock(hostname=None, port=None, command=None, client=None,
source_address=None):
if not command:
if client:
@ -441,7 +469,8 @@ def ssh_proxy_sock(hostname, port=None, command=None, client=None,
# Apply connect parameters to proxy command
if not isinstance(command, six.string_types):
command = subprocess.list2cmdline(command)
command = command.format(hostname=hostname, port=(port or 22))
if hostname:
command = command.format(hostname=hostname, port=(port or 22))
if client:
if isinstance(client, SSHClientFixture):
# Connect to proxy server

View File

@ -33,7 +33,7 @@ def ssh_login(hostname, username=None, port=None):
def ssh_command(host, username=None, port=None, command=None,
config_files=None, host_config=None, proxy_command=None,
**options):
key_filename=None, **options):
host_config = host_config or _config.ssh_host_config(
host=host, config_files=config_files)
@ -50,11 +50,15 @@ def ssh_command(host, username=None, port=None, command=None,
port = port or host_config.port
if port:
command += ['-p', port]
command += ['-p', str(port)]
if key_filename:
command += ['-i', key_filename]
if proxy_command:
if not isinstance(proxy_command, six.string_types):
proxy_command = subprocess.list2cmdline(proxy_command)
proxy_command = subprocess.list2cmdline([str(a)
for a in proxy_command])
options['ProxyCommand'] = proxy_command
for name, value in host_config.host_config.items():

View File

@ -19,17 +19,78 @@ import collections
import contextlib
import socket
import netaddr
from oslo_log import log
import six
from six.moves import urllib
import sshtunnel
import tobiko
from tobiko.shell.ssh import _client
LOG = log.getLogger(__name__)
def get_forward_port_address(address, ssh_client=None, manager=None):
if ssh_client is None:
ssh_client = _client.ssh_proxy_client()
manager = manager or DEFAULT_SSH_PORT_FORWARD_MANAGER
return manager.get_forward_port_address(address, ssh_client=ssh_client)
def get_forward_url(url, ssh_client=None, manager=None):
url = parse_url(url)
if ssh_client is None:
ssh_client = _client.ssh_proxy_client()
manager = manager or DEFAULT_SSH_PORT_FORWARD_MANAGER
address = binding_address(url)
forward_address = get_forward_port_address(address, ssh_client=ssh_client,
manager=manager)
return binding_url(forward_address)
class SSHPortForwardManager(object):
def __init__(self):
self.forward_addresses = {}
self.forwarders = {}
def get_forward_port_address(self, address, ssh_client):
try:
return self.forward_addresses[address, ssh_client]
except KeyError:
pass
forwarder = self.get_forwarder(address, ssh_client=ssh_client)
if forwarder:
forward_address = forwarder.get_forwarding(address)
else:
forward_address = address
self.forward_addresses[address, ssh_client] = forward_address
return forward_address
def get_forwarder(self, address, ssh_client):
try:
return self.forwarders[address, ssh_client]
except KeyError:
pass
if ssh_client:
tobiko.check_valid_type(ssh_client, _client.SSHClientFixture)
forwarder = SSHTunnelForwarderFixture(ssh_client=ssh_client)
forwarder.put_forwarding(address)
tobiko.setup_fixture(forwarder)
else:
forwarder = None
self.forwarders[address, ssh_client] = forwarder
return forwarder
DEFAULT_SSH_PORT_FORWARD_MANAGER = SSHPortForwardManager()
class SSHTunnelForwarderFixture(tobiko.SharedFixture):
forwarder = None
@ -39,14 +100,17 @@ class SSHTunnelForwarderFixture(tobiko.SharedFixture):
self.ssh_client = ssh_client
self._forwarding = collections.OrderedDict()
def put_forwarding(self, remote_address, remote_port=None,
local_address=None, local_port=None):
remote = AddressPair.create(remote_address, remote_port)
local = AddressPair.create(local_address, local_port)
def put_forwarding(self, remote, local=None):
if not local:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
with contextlib.closing(sock):
sock.bind(('127.0.0.1', 0))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
hostname, port = sock.getsockname()
local = hostname, port
return self._forwarding.setdefault(remote, local)
def get_forwarding(self, remote_address, remote_port=None):
remote = AddressPair.create(remote_address, remote_port)
def get_forwarding(self, remote):
return self._forwarding.get(remote)
def setup_fixture(self):
@ -64,7 +128,6 @@ class SSHTunnelForwarderFixture(tobiko.SharedFixture):
self.addCleanup(self.cleanup_forwarder)
forwarder.start()
self.ssh_client.addCleanup(self)
return forwarder
def cleanup_forwarder(self):
@ -74,6 +137,48 @@ class SSHTunnelForwarderFixture(tobiko.SharedFixture):
forwarder.stop()
# pylint: disable=protected-access
SSHForwardHandler = sshtunnel._ForwardHandler
# pylint: enable=protected-access
class SSHUnixForwardHandler(SSHForwardHandler):
transport = None
def handle(self):
uid = sshtunnel.get_connection_id()
self.info = '#{0} <-- {1}'.format(uid, self.client_address or
self.server.local_address)
remote_address = self.remote_address
assert isinstance(remote_address, six.string_types)
command = 'sudo nc -U "{}"'.format(remote_address)
chan = self.transport.open_session()
chan.exec_command(command)
self.logger.log(sshtunnel.TRACE_LEVEL,
'{0} connected'.format(self.info))
try:
self._redirect(chan)
except socket.error:
# Sometimes a RST is sent and a socket error is raised, treat this
# exception. It was seen that a 3way FIN is processed later on, so
# no need to make an ordered close of the connection here or raise
# the exception beyond this point...
self.logger.log(sshtunnel.TRACE_LEVEL,
'{0} sending RST'.format(self.info))
except Exception as e:
self.logger.log(sshtunnel.TRACE_LEVEL,
'{0} error: {1}'.format(self.info, repr(e)))
finally:
chan.close()
self.request.close()
self.logger.log(sshtunnel.TRACE_LEVEL,
'{0} connection closed.'.format(self.info))
class SSHTunnelForwarder(sshtunnel.SSHTunnelForwarder):
daemon_forward_servers = True #: flag tunnel threads in daemon mode
@ -125,36 +230,76 @@ class SSHTunnelForwarder(sshtunnel.SSHTunnelForwarder):
assert not self.is_active
super(SSHTunnelForwarder, self)._stop_transport()
@staticmethod
def _get_binds(bind_address, bind_addresses, is_remote=False):
addr_kind = 'remote' if is_remote else 'local'
class AddressPair(collections.namedtuple('AddressPair', ['host', 'port'])):
@classmethod
def create(cls, address=None, port=None):
port = port and int(port) or None
address = address or '127.0.0.1'
if isinstance(address, netaddr.IPAddress):
if port is None:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
with contextlib.closing(sock):
sock.bind((str(address), 0))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return cls(*sock.getsockname())
if not bind_address and not bind_addresses:
if is_remote:
raise ValueError("No {0} bind addresses specified. Use "
"'{0}_bind_address' or '{0}_bind_addresses'"
" argument".format(addr_kind))
else:
return cls(str(address), port)
elif isinstance(address, urllib.parse.ParseResult):
return cls(address.hostname or address.path, address.port or None)
elif isinstance(address, six.string_types):
try:
return cls.create(netaddr.IPAddress(address), port)
except ValueError:
pass
if port is None:
return cls.create(urllib.parse.urlparse(address))
else:
return cls(address.lower(), port)
elif isinstance(address, collections.Sequence):
return cls.create(*address)
return []
elif bind_address and bind_addresses:
raise ValueError("You can't use both '{0}_bind_address' and "
"'{0}_bind_addresses' arguments. Use one of "
"them.".format(addr_kind))
if bind_address:
bind_addresses = [bind_address]
if not is_remote:
# Add random port if missing in local bind
for (i, local_bind) in enumerate(bind_addresses):
if isinstance(local_bind, tuple) and len(local_bind) == 1:
bind_addresses[i] = (local_bind[0], 0)
# check_addresses(bind_addresses, is_remote)
return bind_addresses
message = ("Invalid address pair parameters: "
"address={!r}, port={!r}").format(address, port)
raise TypeError(message)
def _make_ssh_forward_handler_class(self, remote_address_):
"""
Make SSH Handler class
"""
if isinstance(remote_address_, tuple):
return super(
SSHTunnelForwarder, self)._make_ssh_forward_handler_class(
remote_address_)
class Handler(SSHUnixForwardHandler):
transport = self._transport
remote_address = remote_address_
logger = self.logger
return Handler
def parse_url(url):
if isinstance(url, urllib.parse.ParseResult):
return url
else:
return urllib.parse.urlparse(url)
def binding_address(url):
url = parse_url(url)
if url.netloc:
# Retains only scheme and netloc
return (url.hostname, url.port)
elif url.path:
# Retains only scheme and path
return url.path
raise ValueError('Invalid URL: {!r}'.format(url))
def binding_url(address):
if isinstance(address, tuple):
try:
hostname, = address
except ValueError:
hostname, port = address
return 'tcp://{hostname}:{port}'.format(hostname=hostname,
port=port)
elif isinstance(address, six.string_types):
return 'unix://{path}'.format(path=address)
raise TypeError('Invalid address type: {!r}'.format(address))