205 lines
6.2 KiB
Python
205 lines
6.2 KiB
Python
"""Cache for SSH tunnels."""
|
|
|
|
from __future__ import absolute_import
|
|
|
|
import collections
|
|
from collections import abc
|
|
import getpass
|
|
import logging
|
|
import os
|
|
import subprocess
|
|
import threading
|
|
import time
|
|
import weakref
|
|
from contextlib import suppress
|
|
|
|
import psutil
|
|
|
|
Context = collections.namedtuple('Context', (
|
|
'uri',
|
|
'interface',
|
|
'local_socket',
|
|
'remote_socket',
|
|
'username',
|
|
'hostname',
|
|
'port',
|
|
'identity_file',
|
|
'ignore_hosts',
|
|
'known_hosts',
|
|
))
|
|
Context.__new__.__defaults__ = (None, ) * len(Context._fields) # type: ignore
|
|
|
|
|
|
class Portal(abc.MutableMapping):
|
|
"""Expiring container for tunnels."""
|
|
|
|
def __init__(self, sweap=25):
|
|
"""Construct portal, reap tunnels every sweap seconds."""
|
|
self.data = collections.OrderedDict()
|
|
self.sweap = sweap
|
|
self.ttl = sweap * 2
|
|
self.lock = threading.RLock()
|
|
self._schedule_reaper()
|
|
|
|
def __getitem__(self, key):
|
|
"""Given uri return tunnel and update TTL."""
|
|
with self.lock:
|
|
value, _ = self.data[key]
|
|
self.data[key] = (value, time.time() + self.ttl)
|
|
self.data.move_to_end(key)
|
|
return value
|
|
|
|
def __setitem__(self, key, value):
|
|
"""Store given tunnel keyed with uri."""
|
|
if not isinstance(value, Tunnel):
|
|
raise ValueError('Portals only support Tunnels.')
|
|
|
|
with self.lock:
|
|
self.data[key] = (value, time.time() + self.ttl)
|
|
self.data.move_to_end(key)
|
|
|
|
def __delitem__(self, key):
|
|
"""Remove and close tunnel from portal."""
|
|
with self.lock:
|
|
value, _ = self.data[key]
|
|
del self.data[key]
|
|
value.close()
|
|
del value
|
|
|
|
def __iter__(self):
|
|
"""Iterate tunnels."""
|
|
with self.lock:
|
|
values = list(self.data.values())
|
|
|
|
for tunnel, _ in values:
|
|
yield tunnel
|
|
|
|
def __len__(self):
|
|
"""Return number of tunnels in portal."""
|
|
with self.lock:
|
|
return len(self.data)
|
|
|
|
def _schedule_reaper(self):
|
|
timer = threading.Timer(interval=self.sweap, function=self.reap)
|
|
timer.name = 'PortalReaper'
|
|
timer.daemon = True
|
|
timer.start()
|
|
|
|
def reap(self):
|
|
"""Remove tunnels who's TTL has expired."""
|
|
now = time.time()
|
|
with self.lock:
|
|
reaped_data = self.data.copy()
|
|
for entry in reaped_data.items():
|
|
if entry[1][1] < now:
|
|
del self.data[entry[0]]
|
|
else:
|
|
# StopIteration as soon as possible
|
|
break
|
|
self._schedule_reaper()
|
|
|
|
|
|
class Tunnel:
|
|
"""SSH tunnel."""
|
|
|
|
def __init__(self, context):
|
|
"""Construct Tunnel."""
|
|
self.context = context
|
|
self._closed = True
|
|
|
|
@property
|
|
def closed(self):
|
|
"""Is tunnel closed."""
|
|
return self._closed
|
|
|
|
def bore(self):
|
|
"""Create SSH tunnel from given context."""
|
|
cmd = ['ssh', '-fNT']
|
|
|
|
if logging.getLogger().getEffectiveLevel() == logging.DEBUG:
|
|
cmd.append('-v')
|
|
else:
|
|
cmd.append('-q')
|
|
|
|
if self.context.port:
|
|
cmd.extend(('-p', str(self.context.port)))
|
|
|
|
cmd.extend(('-L', '{}:{}'.format(self.context.local_socket,
|
|
self.context.remote_socket)))
|
|
|
|
if self.context.ignore_hosts:
|
|
cmd.extend(('-o', 'StrictHostKeyChecking=no',
|
|
'-o', 'UserKnownHostsFile=/dev/null'))
|
|
elif self.context.known_hosts:
|
|
cmd.extend(('-o', 'UserKnownHostsFile={known_hosts}'.format(
|
|
known_hosts=self.context.known_hosts)))
|
|
|
|
if self.context.identity_file:
|
|
cmd.extend(('-i', self.context.identity_file))
|
|
|
|
cmd.append('{}@{}'.format(self.context.username,
|
|
self.context.hostname))
|
|
|
|
logging.debug('Opening tunnel "%s", cmd "%s"', self.context.uri,
|
|
' '.join(cmd))
|
|
|
|
tunnel = subprocess.Popen(cmd, close_fds=True)
|
|
# The return value of Popen() has no long term value as that process
|
|
# has already exited by the time control is returned here. This is a
|
|
# side effect of the -f option. wait() will be called to clean up
|
|
# resources.
|
|
for _ in range(300):
|
|
# TODO: Make timeout configurable
|
|
if os.path.exists(self.context.local_socket) \
|
|
or tunnel.returncode is not None:
|
|
break
|
|
with suppress(subprocess.TimeoutExpired):
|
|
# waiting for either socket to be created
|
|
# or first child to exit
|
|
tunnel.wait(0.5)
|
|
else:
|
|
raise TimeoutError(
|
|
'Failed to create tunnel "{}", using: "{}"'.format(
|
|
self.context.uri, ' '.join(cmd)))
|
|
if tunnel.returncode is not None and tunnel.returncode != 0:
|
|
raise subprocess.CalledProcessError(tunnel.returncode,
|
|
' '.join(cmd))
|
|
tunnel.wait()
|
|
|
|
self._closed = False
|
|
weakref.finalize(self, self.close)
|
|
return self
|
|
|
|
def close(self):
|
|
"""Close SSH tunnel."""
|
|
logging.debug('Closing tunnel "%s"', self.context.uri)
|
|
|
|
if self._closed:
|
|
return
|
|
|
|
# Find all ssh instances for user with uri tunnel the hard way...
|
|
targets = [
|
|
p
|
|
for p in psutil.process_iter(attrs=['name', 'username', 'cmdline'])
|
|
if p.info['username'] == getpass.getuser()
|
|
and p.info['name'] == 'ssh'
|
|
and self.context.local_socket in ' '.join(p.info['cmdline'])
|
|
] # yapf: disable
|
|
|
|
# ask nicely for ssh to quit, reap results
|
|
for proc in targets:
|
|
proc.terminate()
|
|
_, alive = psutil.wait_procs(targets, timeout=300)
|
|
|
|
# kill off the uncooperative, then report any stragglers
|
|
for proc in alive:
|
|
proc.kill()
|
|
_, alive = psutil.wait_procs(targets, timeout=300)
|
|
|
|
for proc in alive:
|
|
logging.info('process %d survived SIGKILL, giving up.', proc.pid)
|
|
|
|
with suppress(OSError):
|
|
os.remove(self.context.local_socket)
|
|
self._closed = True
|