Add support for remote unix socket forwarding
Change-Id: I04ca8083a9af5974d659698018f8a285d6c1173d
This commit is contained in:
parent
428b96e2da
commit
205e130f1b
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
||||
from urllib3 import connection
|
||||
from urllib3 import connectionpool
|
||||
|
||||
import tobiko
|
||||
from tobiko.shell import ssh
|
||||
|
||||
|
||||
@ -69,23 +68,11 @@ class HTTPConnectionPool(connectionpool.HTTPConnectionPool):
|
||||
|
||||
ConnectionCls = HTTPConnection
|
||||
|
||||
forwarder = None
|
||||
ssh_client = None
|
||||
|
||||
def __init__(self, host, port, ssh_client=None, **kwargs):
|
||||
if ssh_client is None:
|
||||
ssh_client = ssh.ssh_proxy_client() or False
|
||||
self.ssh_client = ssh_client
|
||||
if ssh_client:
|
||||
self.forwarder = forwarder = ssh.SSHTunnelForwarderFixture(
|
||||
ssh_client=ssh_client)
|
||||
forward_address = forwarder.put_forwarding(host, port)
|
||||
tobiko.setup_fixture(forwarder)
|
||||
kwargs['forward_address'] = forward_address
|
||||
|
||||
super(HTTPConnectionPool, self).__init__(host=host,
|
||||
port=port,
|
||||
**kwargs)
|
||||
forward_address = ssh.get_forward_port_address(address=(host, port),
|
||||
ssh_client=ssh_client)
|
||||
super(HTTPConnectionPool, self).__init__(
|
||||
host=host, port=port, forward_address=forward_address, **kwargs)
|
||||
|
||||
|
||||
class HTTPSConnectionPool(HTTPConnectionPool,
|
||||
|
@ -31,4 +31,7 @@ ssh_proxy_client = _client.ssh_proxy_client
|
||||
SSHConnectFailure = _client.SSHConnectFailure
|
||||
gather_ssh_connect_parameters = _client.gather_ssh_connect_parameters
|
||||
|
||||
get_port_forward_url = _forward.get_forward_url
|
||||
get_forward_port_address = _forward.get_forward_port_address
|
||||
SSHTunnelForwarderFixture = _forward.SSHTunnelForwarderFixture
|
||||
SSHTunnelForwarder = _forward.SSHTunnelForwarder
|
||||
|
@ -71,7 +71,7 @@ def positive_int(value):
|
||||
return value
|
||||
|
||||
|
||||
def key_filename(value):
|
||||
def get_key_filename(value):
|
||||
if isinstance(value, six.string_types):
|
||||
value = [value]
|
||||
return [os.path.realpath(os.path.expanduser(v)) for v in value]
|
||||
@ -96,7 +96,7 @@ SSH_CONNECT_PARAMETERS = {
|
||||
|
||||
#: The filename, or list of filenames, of optional private key(s) and/or
|
||||
#: certs to try for authentication
|
||||
'key_filename': key_filename,
|
||||
'key_filename': get_key_filename,
|
||||
|
||||
#: An optional timeout (in seconds) for the TCP connect
|
||||
'timeout': positive_float,
|
||||
@ -329,6 +329,34 @@ class SSHClientFixture(tobiko.SharedFixture):
|
||||
def connect(self):
|
||||
return tobiko.setup_fixture(self).client
|
||||
|
||||
def get_ssh_command(self, host=None, username=None, port=None,
|
||||
command=None, config_files=None, host_config=None,
|
||||
proxy_command=None, key_filename=None, **options):
|
||||
connect_parameters = self.setup_connect_parameters()
|
||||
host = host or connect_parameters.get('hostname')
|
||||
username = username or connect_parameters.get('username')
|
||||
port = port or connect_parameters.get('port')
|
||||
config_files = config_files or connect_parameters.get('config_files')
|
||||
if not host_config:
|
||||
_host_config = self.setup_host_config()
|
||||
if hasattr(_host_config, 'host_config'):
|
||||
_host_config = host_config
|
||||
key_filename = key_filename or connect_parameters.get('key_filename')
|
||||
proxy_command = (proxy_command or
|
||||
connect_parameters.get('proxy_command'))
|
||||
if not proxy_command and self.proxy_client:
|
||||
proxy_command = self.proxy_client.get_ssh_command()
|
||||
|
||||
return _command.ssh_command(host=host,
|
||||
username=username,
|
||||
port=port,
|
||||
command=command,
|
||||
config_files=config_files,
|
||||
host_config=host_config,
|
||||
proxy_command=proxy_command,
|
||||
key_filename=key_filename,
|
||||
**options)
|
||||
|
||||
|
||||
UNDEFINED_CLIENT = 'UNDEFINED_CLIENT'
|
||||
|
||||
@ -425,7 +453,7 @@ def ssh_connect(hostname, username=None, port=None, connection_interval=None,
|
||||
return client, proxy_sock
|
||||
|
||||
|
||||
def ssh_proxy_sock(hostname, port=None, command=None, client=None,
|
||||
def ssh_proxy_sock(hostname=None, port=None, command=None, client=None,
|
||||
source_address=None):
|
||||
if not command:
|
||||
if client:
|
||||
@ -441,7 +469,8 @@ def ssh_proxy_sock(hostname, port=None, command=None, client=None,
|
||||
# Apply connect parameters to proxy command
|
||||
if not isinstance(command, six.string_types):
|
||||
command = subprocess.list2cmdline(command)
|
||||
command = command.format(hostname=hostname, port=(port or 22))
|
||||
if hostname:
|
||||
command = command.format(hostname=hostname, port=(port or 22))
|
||||
if client:
|
||||
if isinstance(client, SSHClientFixture):
|
||||
# Connect to proxy server
|
||||
|
@ -33,7 +33,7 @@ def ssh_login(hostname, username=None, port=None):
|
||||
|
||||
def ssh_command(host, username=None, port=None, command=None,
|
||||
config_files=None, host_config=None, proxy_command=None,
|
||||
**options):
|
||||
key_filename=None, **options):
|
||||
host_config = host_config or _config.ssh_host_config(
|
||||
host=host, config_files=config_files)
|
||||
|
||||
@ -50,11 +50,15 @@ def ssh_command(host, username=None, port=None, command=None,
|
||||
|
||||
port = port or host_config.port
|
||||
if port:
|
||||
command += ['-p', port]
|
||||
command += ['-p', str(port)]
|
||||
|
||||
if key_filename:
|
||||
command += ['-i', key_filename]
|
||||
|
||||
if proxy_command:
|
||||
if not isinstance(proxy_command, six.string_types):
|
||||
proxy_command = subprocess.list2cmdline(proxy_command)
|
||||
proxy_command = subprocess.list2cmdline([str(a)
|
||||
for a in proxy_command])
|
||||
options['ProxyCommand'] = proxy_command
|
||||
|
||||
for name, value in host_config.host_config.items():
|
||||
|
@ -19,17 +19,78 @@ import collections
|
||||
import contextlib
|
||||
import socket
|
||||
|
||||
import netaddr
|
||||
from oslo_log import log
|
||||
import six
|
||||
from six.moves import urllib
|
||||
import sshtunnel
|
||||
|
||||
import tobiko
|
||||
from tobiko.shell.ssh import _client
|
||||
|
||||
|
||||
LOG = log.getLogger(__name__)
|
||||
|
||||
|
||||
def get_forward_port_address(address, ssh_client=None, manager=None):
|
||||
if ssh_client is None:
|
||||
ssh_client = _client.ssh_proxy_client()
|
||||
manager = manager or DEFAULT_SSH_PORT_FORWARD_MANAGER
|
||||
return manager.get_forward_port_address(address, ssh_client=ssh_client)
|
||||
|
||||
|
||||
def get_forward_url(url, ssh_client=None, manager=None):
|
||||
url = parse_url(url)
|
||||
if ssh_client is None:
|
||||
ssh_client = _client.ssh_proxy_client()
|
||||
manager = manager or DEFAULT_SSH_PORT_FORWARD_MANAGER
|
||||
address = binding_address(url)
|
||||
forward_address = get_forward_port_address(address, ssh_client=ssh_client,
|
||||
manager=manager)
|
||||
return binding_url(forward_address)
|
||||
|
||||
|
||||
class SSHPortForwardManager(object):
|
||||
|
||||
def __init__(self):
|
||||
self.forward_addresses = {}
|
||||
self.forwarders = {}
|
||||
|
||||
def get_forward_port_address(self, address, ssh_client):
|
||||
try:
|
||||
return self.forward_addresses[address, ssh_client]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
forwarder = self.get_forwarder(address, ssh_client=ssh_client)
|
||||
if forwarder:
|
||||
forward_address = forwarder.get_forwarding(address)
|
||||
else:
|
||||
forward_address = address
|
||||
|
||||
self.forward_addresses[address, ssh_client] = forward_address
|
||||
return forward_address
|
||||
|
||||
def get_forwarder(self, address, ssh_client):
|
||||
try:
|
||||
return self.forwarders[address, ssh_client]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if ssh_client:
|
||||
tobiko.check_valid_type(ssh_client, _client.SSHClientFixture)
|
||||
forwarder = SSHTunnelForwarderFixture(ssh_client=ssh_client)
|
||||
forwarder.put_forwarding(address)
|
||||
tobiko.setup_fixture(forwarder)
|
||||
else:
|
||||
forwarder = None
|
||||
|
||||
self.forwarders[address, ssh_client] = forwarder
|
||||
return forwarder
|
||||
|
||||
|
||||
DEFAULT_SSH_PORT_FORWARD_MANAGER = SSHPortForwardManager()
|
||||
|
||||
|
||||
class SSHTunnelForwarderFixture(tobiko.SharedFixture):
|
||||
|
||||
forwarder = None
|
||||
@ -39,14 +100,17 @@ class SSHTunnelForwarderFixture(tobiko.SharedFixture):
|
||||
self.ssh_client = ssh_client
|
||||
self._forwarding = collections.OrderedDict()
|
||||
|
||||
def put_forwarding(self, remote_address, remote_port=None,
|
||||
local_address=None, local_port=None):
|
||||
remote = AddressPair.create(remote_address, remote_port)
|
||||
local = AddressPair.create(local_address, local_port)
|
||||
def put_forwarding(self, remote, local=None):
|
||||
if not local:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
with contextlib.closing(sock):
|
||||
sock.bind(('127.0.0.1', 0))
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
hostname, port = sock.getsockname()
|
||||
local = hostname, port
|
||||
return self._forwarding.setdefault(remote, local)
|
||||
|
||||
def get_forwarding(self, remote_address, remote_port=None):
|
||||
remote = AddressPair.create(remote_address, remote_port)
|
||||
def get_forwarding(self, remote):
|
||||
return self._forwarding.get(remote)
|
||||
|
||||
def setup_fixture(self):
|
||||
@ -64,7 +128,6 @@ class SSHTunnelForwarderFixture(tobiko.SharedFixture):
|
||||
self.addCleanup(self.cleanup_forwarder)
|
||||
forwarder.start()
|
||||
self.ssh_client.addCleanup(self)
|
||||
|
||||
return forwarder
|
||||
|
||||
def cleanup_forwarder(self):
|
||||
@ -74,6 +137,48 @@ class SSHTunnelForwarderFixture(tobiko.SharedFixture):
|
||||
forwarder.stop()
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
SSHForwardHandler = sshtunnel._ForwardHandler
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
class SSHUnixForwardHandler(SSHForwardHandler):
|
||||
|
||||
transport = None
|
||||
|
||||
def handle(self):
|
||||
uid = sshtunnel.get_connection_id()
|
||||
self.info = '#{0} <-- {1}'.format(uid, self.client_address or
|
||||
self.server.local_address)
|
||||
|
||||
remote_address = self.remote_address
|
||||
assert isinstance(remote_address, six.string_types)
|
||||
command = 'sudo nc -U "{}"'.format(remote_address)
|
||||
|
||||
chan = self.transport.open_session()
|
||||
chan.exec_command(command)
|
||||
|
||||
self.logger.log(sshtunnel.TRACE_LEVEL,
|
||||
'{0} connected'.format(self.info))
|
||||
try:
|
||||
self._redirect(chan)
|
||||
except socket.error:
|
||||
# Sometimes a RST is sent and a socket error is raised, treat this
|
||||
# exception. It was seen that a 3way FIN is processed later on, so
|
||||
# no need to make an ordered close of the connection here or raise
|
||||
# the exception beyond this point...
|
||||
self.logger.log(sshtunnel.TRACE_LEVEL,
|
||||
'{0} sending RST'.format(self.info))
|
||||
except Exception as e:
|
||||
self.logger.log(sshtunnel.TRACE_LEVEL,
|
||||
'{0} error: {1}'.format(self.info, repr(e)))
|
||||
finally:
|
||||
chan.close()
|
||||
self.request.close()
|
||||
self.logger.log(sshtunnel.TRACE_LEVEL,
|
||||
'{0} connection closed.'.format(self.info))
|
||||
|
||||
|
||||
class SSHTunnelForwarder(sshtunnel.SSHTunnelForwarder):
|
||||
|
||||
daemon_forward_servers = True #: flag tunnel threads in daemon mode
|
||||
@ -125,36 +230,76 @@ class SSHTunnelForwarder(sshtunnel.SSHTunnelForwarder):
|
||||
assert not self.is_active
|
||||
super(SSHTunnelForwarder, self)._stop_transport()
|
||||
|
||||
@staticmethod
|
||||
def _get_binds(bind_address, bind_addresses, is_remote=False):
|
||||
addr_kind = 'remote' if is_remote else 'local'
|
||||
|
||||
class AddressPair(collections.namedtuple('AddressPair', ['host', 'port'])):
|
||||
|
||||
@classmethod
|
||||
def create(cls, address=None, port=None):
|
||||
port = port and int(port) or None
|
||||
address = address or '127.0.0.1'
|
||||
if isinstance(address, netaddr.IPAddress):
|
||||
if port is None:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
with contextlib.closing(sock):
|
||||
sock.bind((str(address), 0))
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
return cls(*sock.getsockname())
|
||||
if not bind_address and not bind_addresses:
|
||||
if is_remote:
|
||||
raise ValueError("No {0} bind addresses specified. Use "
|
||||
"'{0}_bind_address' or '{0}_bind_addresses'"
|
||||
" argument".format(addr_kind))
|
||||
else:
|
||||
return cls(str(address), port)
|
||||
elif isinstance(address, urllib.parse.ParseResult):
|
||||
return cls(address.hostname or address.path, address.port or None)
|
||||
elif isinstance(address, six.string_types):
|
||||
try:
|
||||
return cls.create(netaddr.IPAddress(address), port)
|
||||
except ValueError:
|
||||
pass
|
||||
if port is None:
|
||||
return cls.create(urllib.parse.urlparse(address))
|
||||
else:
|
||||
return cls(address.lower(), port)
|
||||
elif isinstance(address, collections.Sequence):
|
||||
return cls.create(*address)
|
||||
return []
|
||||
elif bind_address and bind_addresses:
|
||||
raise ValueError("You can't use both '{0}_bind_address' and "
|
||||
"'{0}_bind_addresses' arguments. Use one of "
|
||||
"them.".format(addr_kind))
|
||||
if bind_address:
|
||||
bind_addresses = [bind_address]
|
||||
if not is_remote:
|
||||
# Add random port if missing in local bind
|
||||
for (i, local_bind) in enumerate(bind_addresses):
|
||||
if isinstance(local_bind, tuple) and len(local_bind) == 1:
|
||||
bind_addresses[i] = (local_bind[0], 0)
|
||||
# check_addresses(bind_addresses, is_remote)
|
||||
return bind_addresses
|
||||
|
||||
message = ("Invalid address pair parameters: "
|
||||
"address={!r}, port={!r}").format(address, port)
|
||||
raise TypeError(message)
|
||||
def _make_ssh_forward_handler_class(self, remote_address_):
|
||||
"""
|
||||
Make SSH Handler class
|
||||
"""
|
||||
if isinstance(remote_address_, tuple):
|
||||
return super(
|
||||
SSHTunnelForwarder, self)._make_ssh_forward_handler_class(
|
||||
remote_address_)
|
||||
|
||||
class Handler(SSHUnixForwardHandler):
|
||||
transport = self._transport
|
||||
remote_address = remote_address_
|
||||
logger = self.logger
|
||||
return Handler
|
||||
|
||||
|
||||
def parse_url(url):
|
||||
if isinstance(url, urllib.parse.ParseResult):
|
||||
return url
|
||||
else:
|
||||
return urllib.parse.urlparse(url)
|
||||
|
||||
|
||||
def binding_address(url):
|
||||
url = parse_url(url)
|
||||
if url.netloc:
|
||||
# Retains only scheme and netloc
|
||||
return (url.hostname, url.port)
|
||||
elif url.path:
|
||||
# Retains only scheme and path
|
||||
return url.path
|
||||
|
||||
raise ValueError('Invalid URL: {!r}'.format(url))
|
||||
|
||||
|
||||
def binding_url(address):
|
||||
if isinstance(address, tuple):
|
||||
try:
|
||||
hostname, = address
|
||||
except ValueError:
|
||||
hostname, port = address
|
||||
return 'tcp://{hostname}:{port}'.format(hostname=hostname,
|
||||
port=port)
|
||||
|
||||
elif isinstance(address, six.string_types):
|
||||
return 'unix://{path}'.format(path=address)
|
||||
|
||||
raise TypeError('Invalid address type: {!r}'.format(address))
|
||||
|
Loading…
Reference in New Issue
Block a user