Implement SSH port forwarding on top of sshtunnel library

Change-Id: I6dbbc0996293602e43db6548afc35cfb34720604
changes/92/691292/3
Federico Ressi 3 years ago
parent fc8843058c
commit c77b681d4b

@ -10,6 +10,7 @@ neutron-lib>=1.25.0 # Apache-2.0
os-faults>=0.1.18 # Apache-2.0
oslo.config>=5.2.0 # Apache-2.0
oslo.log>=3.36.0 # Apache-2.0
pandas>=0.24.2 # BSD
paramiko>=2.4.0 # LGPLv2.1
pbr>=4.0.0 # Apache-2.0
python-heatclient>=1.5.0 # Apache-2.0
@ -20,5 +21,5 @@ python-octaviaclient>=1.9.0 # Apache-2.0
python-openstackclient>=3.0.0 # Apache-2.0
stestr>=2.0 # Apache-2.0
six>=1.10.0 # MIT
sshtunnel>=0.1.5 # MIT
testtools>=2.2.0 # MIT
pandas>=0.24.2 # BSD

@ -72,9 +72,9 @@ class KeystoneSessionFixture(tobiko.SharedFixture):
# api version parameter is not accepted
params.pop('api_version', None)
auth = loader.load_from_options(**params)
http_session = ssh.ssh_tunnel_http_session()
self.session = session = _session.Session(
auth=auth, verify=False, session=http_session)
auth=auth, verify=False)
ssh.setup_http_session_ssh_tunneling(session=session)
self.credentials = credentials

@ -31,4 +31,4 @@ ssh_proxy_client = _client.ssh_proxy_client
SSHConnectFailure = _client.SSHConnectFailure
gather_ssh_connect_parameters = _client.gather_ssh_connect_parameters
ssh_tunnel_http_session = _http.ssh_tunnel_http_session
setup_http_session_ssh_tunneling = _http.setup_http_session_ssh_tunneling

@ -236,6 +236,7 @@ class SSHClientFixture(tobiko.SharedFixture):
self.schema = schema = dict(schema or self.schema)
self._connect_parameters = gather_ssh_connect_parameters(
schema=schema, **kwargs)
self._forwarders = []
def setup_fixture(self):
self.setup_connect_parameters()
@ -301,9 +302,29 @@ class SSHClientFixture(tobiko.SharedFixture):
self.client, self.proxy_sock = ssh_connect(
proxy_client=self.proxy_client,
**self.connect_parameters)
self.addCleanup(self.client.close)
self.addCleanup(self.cleanup_ssh_client)
if self.proxy_sock:
self.addCleanup(self.proxy_sock.close)
self.addCleanup(self.cleanup_proxy_sock)
for forwarder in self._forwarders:
self.useFixture(forwarder)
def cleanup_ssh_client(self):
client = self.client
self.client = None
if client:
try:
client.close()
except Exception:
LOG.exception('Error closing client (%r)', self)
def cleanup_proxy_sock(self):
proxy_sock = self.proxy_sock
self.proxy_sock = None
if proxy_sock:
try:
proxy_sock.close()
except Exception:
LOG.exception('Error closing proxy socket (%r)', self)
def connect(self):
return tobiko.setup_fixture(self).client

@ -0,0 +1,160 @@
# Copyright (c) 2019 Red Hat, Inc.
#
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
import collections
import contextlib
import socket
import netaddr
from oslo_log import log
import six
from six.moves import urllib
import sshtunnel
import tobiko
LOG = log.getLogger(__name__)
class SSHTunnelForwarderFixture(tobiko.SharedFixture):
forwarder = None
def __init__(self, ssh_client):
super(SSHTunnelForwarderFixture, self).__init__()
self.ssh_client = ssh_client
self._forwarding = collections.OrderedDict()
def put_forwarding(self, remote_address, remote_port=None,
local_address=None, local_port=None):
remote = AddressPair.create(remote_address, remote_port)
local = AddressPair.create(local_address, local_port)
return self._forwarding.setdefault(remote, local)
def get_forwarding(self, remote_address, remote_port=None):
remote = AddressPair.create(remote_address, remote_port)
return self._forwarding.get(remote)
def setup_fixture(self):
self.setup_forwarder()
def setup_forwarder(self):
forwarder = self.forwarder
if not forwarder:
remote_bind_addresses = list(self._forwarding.keys())
local_bind_addresses = list(self._forwarding.values())
self.forwarder = forwarder = SSHTunnelForwarder(
ssh_client=self.ssh_client,
local_bind_addresses=local_bind_addresses,
remote_bind_addresses=remote_bind_addresses)
self.addCleanup(self.cleanup_forwarder)
forwarder.start()
self.ssh_client.addCleanup(self)
return forwarder
def cleanup_forwarder(self):
forwarder = self.forwarder
if forwarder:
del self.forwarder
forwarder.stop()
class SSHTunnelForwarder(sshtunnel.SSHTunnelForwarder):
daemon_forward_servers = True #: flag tunnel threads in daemon mode
daemon_transport = True #: flag SSH transport thread in daemon mode
def __init__(self, ssh_client, **kwargs):
self.ssh_client = ssh_client
params = self._merge_parameters(self._get_connect_parameters(),
**kwargs)
super(SSHTunnelForwarder, self).__init__(**params)
def _merge_parameters(self, *dicts, **kwargs):
result = {}
for d in dicts + (kwargs,):
if d:
result.update((k, v) for k, v in d.items() if v is not None)
return result
@staticmethod
def _consolidate_auth(ssh_password=None,
ssh_pkey=None,
ssh_pkey_password=None,
allow_agent=True,
host_pkey_directories=None,
logger=None):
return None, None
def _get_connect_parameters(self):
parameters = self.ssh_client.setup_connect_parameters()
return dict(ssh_address_or_host=parameters['hostname'],
ssh_username=parameters.get('username'),
ssh_password=parameters.get('password'),
ssh_pkey=parameters.get('pkey'),
ssh_port=parameters.get('port'),
ssh_private_key_password=parameters.get('passphrase'),
compression=parameters.get('compress'),
allow_agent=parameters.get('allow_agent'))
def _connect_to_gateway(self):
# pylint: disable=attribute-defined-outside-init
self._transport = self._get_transport()
def _get_transport(self):
return self.ssh_client.connect().get_transport()
def _stop_transport(self):
if self.is_active:
del self._transport
assert not self.is_active
super(SSHTunnelForwarder, self)._stop_transport()
class AddressPair(collections.namedtuple('AddressPair', ['host', 'port'])):
@classmethod
def create(cls, address=None, port=None):
port = port and int(port) or None
address = address or '127.0.0.1'
if isinstance(address, netaddr.IPAddress):
if port is None:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
with contextlib.closing(sock):
sock.bind((str(address), 0))
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return cls(*sock.getsockname())
else:
return cls(str(address), port)
elif isinstance(address, urllib.parse.ParseResult):
return cls(address.hostname or address.path, address.port or None)
elif isinstance(address, six.string_types):
try:
return cls.create(netaddr.IPAddress(address), port)
except ValueError:
pass
if port is None:
return cls.create(urllib.parse.urlparse(address))
else:
return cls(address.lower(), port)
elif isinstance(address, collections.Sequence):
return cls.create(*address)
message = ("Invalid address pair parameters: "
"address={!r}, port={!r}").format(address, port)
raise TypeError(message)

@ -23,65 +23,37 @@ from urllib3 import connection
from urllib3 import connectionpool
from urllib3 import poolmanager
import tobiko
from tobiko.shell.ssh import _client
from tobiko.shell.ssh import _forward
def ssh_tunnel_http_session(ssh_client=None):
def setup_http_session_ssh_tunneling(session=None, ssh_client=None):
session = session or requests.Session()
ssh_client = ssh_client or _client.ssh_proxy_client()
if ssh_client is None:
return None
session = requests.Session()
mount_ssh_tunnel_http_adapter(session=session, ssh_client=ssh_client)
if ssh_client is not None:
for adapter in session.adapters.values():
manager = adapter.poolmanager
manager.pool_classes_by_scheme = pool_classes_by_scheme.copy()
manager.key_fn_by_scheme = key_fn_by_scheme.copy()
manager.connection_pool_kw['ssh_client'] = ssh_client
return session
def mount_ssh_tunnel_http_adapter(session, ssh_client):
adapter = SSHTunnelHttpAdapter(ssh_client=ssh_client)
for scheme in list(session.adapters):
session.mount(scheme, adapter)
class SSHTunnelHttpAdapter(requests.adapters.HTTPAdapter):
"""The custom adapter used to set tunnel HTTP connections over SSH tunnel
"""
def __init__(self, ssh_client, *args, **kwargs):
self.ssh_client = ssh_client
super(SSHTunnelHttpAdapter, self).__init__(*args, **kwargs)
def init_poolmanager(self, connections, maxsize,
block=requests.adapters.DEFAULT_POOLBLOCK,
**pool_kwargs):
# save these values for pickling
self._pool_connections = connections
self._pool_maxsize = maxsize
self._pool_block = block
self.poolmanager = SSHTunnelPoolManager(
num_pools=connections, maxsize=maxsize, block=block, strict=True,
ssh_client=self.ssh_client, **pool_kwargs)
class SSHTunnelPoolManager(poolmanager.PoolManager):
def __init__(self, *args, **kwargs):
super(SSHTunnelPoolManager, self).__init__(*args, **kwargs)
# Locally set the pool classes and keys so other PoolManagers can
# override them.
self.pool_classes_by_scheme = pool_classes_by_scheme
self.key_fn_by_scheme = key_fn_by_scheme.copy()
# pylint: disable=protected-access
# All known keyword arguments that could be provided to the pool manager, its
# pools, or the underlying connections. This is used to construct a pool key.
_key_fields = poolmanager._key_fields + ('key_ssh_client',)
#: The namedtuple class used to construct keys for the connection pool.
#: All custom key schemes should include the fields in this key at a minimum.
SSHTunnelPoolKey = collections.namedtuple("SSHTunnelPoolKey", _key_fields)
class SSHTunnelPoolKey(
collections.namedtuple("SSHTunnelPoolKey", _key_fields)):
"""The namedtuple class used to construct keys for the connection pool.
All custom key schemes should include the fields in this key at a minimum.
"""
#: A dictionary that maps a scheme to a callable that creates a pool key.
#: This can be used to alter the way pool keys are constructed, if desired.
@ -99,20 +71,39 @@ key_fn_by_scheme = {
class SSHTunnelHTTPConnection(connection.HTTPConnection):
def __init__(self, *args, **kw):
self.ssh_client = kw.pop('ssh_client')
assert self.ssh_client is not None
super(SSHTunnelHTTPConnection, self).__init__(*args, **kw)
def __init__(self, local_address, *args, **kwargs):
super(SSHTunnelHTTPConnection, self).__init__(*args, **kwargs)
self.local_address = local_address
def _new_conn(self):
""" Establish a socket connection and set nodelay settings on it.
:return: New socket connection.
"""
return _client.ssh_proxy_sock(hostname=self._dns_host,
port=self.port,
source_address=self.source_address,
client=self.ssh_client)
extra_kw = {}
if self.source_address:
extra_kw["source_address"] = self.source_address
if self.socket_options:
extra_kw["socket_options"] = self.socket_options
try:
conn = connection.connection.create_connection(
self.local_address, self.timeout, **extra_kw)
except connection.SocketTimeout:
raise connection.ConnectTimeoutError(
self,
"Connection to %s timed out. (connect timeout=%s)"
% (self.host, self.timeout),
)
except connection.SocketError as e:
raise connection.NewConnectionError(
self, "Failed to establish a new connection: %s" % e
)
return conn
class SSHTunnelHTTPSConnection(SSHTunnelHTTPConnection,
@ -124,8 +115,17 @@ class SSHTunnelHTTPConnectionPool(connectionpool.HTTPConnectionPool):
ConnectionCls = SSHTunnelHTTPConnection
def __init__(self, host, port, ssh_client, **kwargs):
self.forwarder = forwarder = _forward.SSHTunnelForwarderFixture(
ssh_client=ssh_client)
local_address = forwarder.put_forwarding(host, port)
tobiko.setup_fixture(forwarder)
super(SSHTunnelHTTPConnectionPool, self).__init__(
host=host, port=port, local_address=local_address, **kwargs)
class SSHTunnelHTTPSConnectionPool(connectionpool.HTTPSConnectionPool):
class SSHTunnelHTTPSConnectionPool(SSHTunnelHTTPConnectionPool,
connectionpool.HTTPSConnectionPool):
ConnectionCls = SSHTunnelHTTPSConnection

Loading…
Cancel
Save