Adds Windows Support to satori discovery

This change adds the capability to satori to do data plane
discovery on Windows devices. For that, a class has been
added that mirrors the functionality of satori/ssh.py and
utilizes a "3rd party" script (satori/contrib/psexec.py)
which is called via subprocess.Popen(). Further to that, an
SSH tunneling class has been put in place that uses paramiko
to establish a tunnel (similar to running ssh -L from a shell).
requirements.txt has been extended to include impacket which
satori/contrib/psexec.py imports.

Move support for PoSH-Ohai into its own provider module.
Raise UnsupportedPlatform exceptions in ohai_solo.py
when the client is non-linux, and raise the same
exception in posh_ohai.py when the client is non-windows.

Co-Authored-By: Nico Engelen <engelen.nico@googlemail.com>
Co-Authored-By: Samuel Stavinoha <samuel.stavinoha@rackspace.com>
Change-Id: I7a94eea9446bc7f57843407fb98880222f7af6af
Implements: blueprint windows-support
This commit is contained in:
Samuel Stavinoha 2014-07-23 15:04:43 +00:00
parent 078468a468
commit 6cabc1773d
12 changed files with 1386 additions and 129 deletions

View File

@ -9,3 +9,7 @@ no-docstring-rgx=((__.*__)|([tT]est.*)|setUp|tearDown)$
[Design]
min-public-methods=0
max-args=6
[Master]
#We try to keep contrib files unmodified
ignore=satori/contrib

View File

@ -1,10 +1,14 @@
# slightly improved version of impacket that allows you to use impacket.examples.serviceinstall
# to create services that are not randomly named
-e git://github.com/nick-o/impacket.git@e4fcac42975fd2f20d9ae6e8643c0fd9fab33c7a#egg=impacket
ipaddress>=1.0.6 # in stdlib as of python3.3
iso8601>=0.1.5
Jinja2>=2.7.1 # bug resolve @2.7.1
paramiko>=1.12.0 # ecdsa added
pbr>=0.5.21,<1.0
python-novaclient>=2.6.0.1 # breaks before
pythonwhois>=2.0.0
pythonwhois>=2.4.3
six>=1.4.0 # urllib introduced
tldextract>=1.2
argparse

View File

@ -14,6 +14,12 @@
__all__ = ('__version__')
try:
import eventlet
eventlet.monkey_patch()
except ImportError:
pass
import pbr.version
from satori import shell

View File

@ -23,6 +23,7 @@ import shlex
import subprocess
from satori import errors
from satori import smb
from satori import ssh
from satori import utils
@ -33,16 +34,21 @@ class ShellMixin(object):
"""Handle platform detection and define execute command."""
def execute(self, command, wd=None, with_exit_code=None):
def execute(self, command, **kwargs):
"""Execute a (shell) command on the target.
:param command: Shell command to be executed
:param with_exit_code: Include the exit_code in the return body.
:param wd: The child's current directory will be changed
to `wd` before it is executed. Note that this
:param cwd: The child's current directory will be changed
to `cwd` before it is executed. Note that this
directory is not considered when searching the
executable, so you can't specify the program's
path relative to this argument
:returns: a dict with stdin, stdout, and
(optionally), the exit_code of the call
See SSH.remote_execute(), SMB.remote_execute(), and
LocalShell.execute() for client-specific keyword arguments.
"""
pass
@ -88,6 +94,9 @@ class ShellMixin(object):
Uses the platform_info property.
"""
if hasattr(self, '_client'):
if isinstance(self._client, smb.SMBClient):
return True
if not self.platform_info['dist']:
raise errors.UndeterminedPlatform(
'Unable to determine whether the system is Windows based.')
@ -122,26 +131,28 @@ class LocalShell(ShellMixin):
self._platform_info = utils.get_platform_info()
return self._platform_info
def execute(self, command, wd=None, with_exit_code=None):
def execute(self, command, **kwargs):
"""Execute a command (containing no shell operators) locally.
:param command: Shell command to be executed.
:param with_exit_code: Include the exit_code in the return body.
Default is False.
:param wd: The child's current directory will be changed
to `wd` before it is executed. Note that this
:param cwd: The child's current directory will be changed
to `cwd` before it is executed. Note that this
directory is not considered when searching the
executable, so you can't specify the program's
path relative to this argument
:returns: A dict with stdin, stdout, and
(optionally) the exit code.
"""
cwd = kwargs.get('cwd')
with_exit_code = kwargs.get('with_exit_code')
spipe = subprocess.PIPE
cmd = shlex.split(command)
LOG.debug("Executing `%s` on local machine", command)
result = subprocess.Popen(
cmd, stdout=spipe, stderr=spipe, cwd=wd)
cmd, stdout=spipe, stderr=spipe, cwd=cwd)
out, err = result.communicate()
resultdict = {
'stdout': out.strip(),
@ -159,7 +170,7 @@ class RemoteShell(ShellMixin):
def __init__(self, address, password=None, username=None,
private_key=None, key_filename=None, port=None,
timeout=None, gateway=None, options=None, interactive=False,
**kwargs):
protocol='ssh', **kwargs):
"""An interface for executing shell commands on remote machines.
:param str host: The ip address or host name of the server
@ -189,11 +200,20 @@ class RemoteShell(ShellMixin):
LOG.warning("Satori RemoteClient received unrecognized "
"keyword arguments: %s", kwargs.keys())
self._client = ssh.connect(
address, password=password, username=username,
private_key=private_key, key_filename=key_filename, port=port,
timeout=timeout, gateway=gateway, options=options,
interactive=interactive)
if protocol == 'smb':
self._client = smb.connect(address, password=password,
username=username,
port=port, timeout=timeout,
gateway=gateway)
else:
self._client = ssh.connect(address, password=password,
username=username,
private_key=private_key,
key_filename=key_filename,
port=port, timeout=timeout,
gateway=gateway,
options=options,
interactive=interactive)
self.host = self._client.host
self.port = self._client.port

558
satori/contrib/psexec.py Normal file
View File

@ -0,0 +1,558 @@
#!/usr/bin/python
# Copyright (c) 2003-2012 CORE Security Technologies
#
# This software is provided under under a slightly modified version
# of the Apache Software License. See the accompanying LICENSE file
# for more information.
#
# $Id: psexec.py 712 2012-09-06 04:26:22Z bethus@gmail.com $
#
# PSEXEC like functionality example using
#RemComSvc (https://github.com/kavika13/RemCom)
#
# Author:
# beto (bethus@gmail.com)
#
# Reference for:
# DCE/RPC and SMB.
""".
OK
"""
import cmd
import os
import re
import sys
#from impacket.smbconnection import *
from impacket.dcerpc import dcerpc
from impacket.dcerpc import transport
from impacket.examples import remcomsvc
from impacket.examples import serviceinstall
from impacket import smbconnection
from impacket import structure as im_structure
from impacket import version
#from impacket.dcerpc import dcerpc_v4
#from impacket.dcerpc import srvsvc
#from impacket.dcerpc import svcctl
#from impacket.smbconnection import smb
#from impacket.smbconnection import SMB_DIALECT
#from impacket.smbconnection import SMBConnection
import argparse
import random
import string
import threading
import time
class RemComMessage(im_structure.Structure):
"""."""
structure = (
('Command', '4096s=""'),
('WorkingDir', '260s=""'),
('Priority', '<L=0x20'),
('ProcessID', '<L=0x01'),
('Machine', '260s=""'),
('NoWait', '<L=0'),
)
class RemComResponse(im_structure.Structure):
"""."""
structure = (
('ErrorCode', '<L=0'),
('ReturnCode', '<L=0'),
)
RemComSTDOUT = "RemCom_stdout"
RemComSTDIN = "RemCom_stdin"
RemComSTDERR = "RemCom_stderr"
lock = threading.Lock()
class PSEXEC:
"""."""
KNOWN_PROTOCOLS = {
'139/SMB': (r'ncacn_np:%s[\pipe\svcctl]', 139),
'445/SMB': (r'ncacn_np:%s[\pipe\svcctl]', 445),
}
def __init__(self, command, path, exeFile, protocols=None,
username='', password='', domain='', hashes=None):
"""."""
if not protocols:
protocols = PSEXEC.KNOWN_PROTOCOLS.keys()
self.__username = username
self.__password = password
self.__protocols = [protocols]
self.__command = command
self.__path = path
self.__domain = domain
self.__lmhash = ''
self.__nthash = ''
self.__exeFile = exeFile
if hashes is not None:
self.__lmhash, self.__nthash = hashes.split(':')
def run(self, addr):
"""."""
for protocol in self.__protocols:
protodef = PSEXEC.KNOWN_PROTOCOLS[protocol]
port = protodef[1]
print("Trying protocol %s...\n" % protocol)
stringbinding = protodef[0] % addr
rpctransport = transport.DCERPCTransportFactory(stringbinding)
rpctransport.set_dport(port)
if hasattr(rpctransport, 'preferred_dialect'):
rpctransport.preferred_dialect(smbconnection.SMB_DIALECT)
if hasattr(rpctransport, 'set_credentials'):
# This method exists only for selected protocol sequences.
rpctransport.set_credentials(self.__username, self.__password,
self.__domain, self.__lmhash,
self.__nthash)
self.doStuff(rpctransport)
def openPipe(self, s, tid, pipe, accessMask):
"""."""
pipeReady = False
tries = 50
while pipeReady is False and tries > 0:
try:
s.waitNamedPipe(tid, pipe)
pipeReady = True
except Exception:
tries -= 1
time.sleep(2)
pass
if tries == 0:
print('[!] Pipe not ready, aborting')
raise
fid = s.openFile(tid, pipe, accessMask, creationOption=0x40,
fileAttributes=0x80)
return fid
def doStuff(self, rpctransport):
"""."""
dce = dcerpc.DCERPC_v5(rpctransport)
try:
dce.connect()
except Exception as e:
print(e)
sys.exit(1)
global dialect
dialect = rpctransport.get_smb_connection().getDialect()
try:
unInstalled = False
s = rpctransport.get_smb_connection()
# We don't wanna deal with timeouts from now on.
s.setTimeout(100000)
svcName = "RackspaceSystemDiscovery"
executableName = "RackspaceSystemDiscovery.exe"
if self.__exeFile is None:
svc = remcomsvc.RemComSvc()
installService = serviceinstall.ServiceInstall(s, svc,
svcName,
executableName)
else:
try:
f = open(self.__exeFile)
except Exception as e:
print(e)
sys.exit(1)
installService = serviceinstall.ServiceInstall(s, f,
svcName,
executableName)
installService.install()
if self.__exeFile is not None:
f.close()
tid = s.connectTree('IPC$')
fid_main = self.openPipe(s, tid, '\RemCom_communicaton', 0x12019f)
packet = RemComMessage()
pid = os.getpid()
packet['Machine'] = ''.join([random.choice(string.letters)
for i in range(4)])
if self.__path is not None:
packet['WorkingDir'] = self.__path
packet['Command'] = self.__command
packet['ProcessID'] = pid
s.writeNamedPipe(tid, fid_main, str(packet))
# Here we'll store the command we type so we don't print it back ;)
# ( I know.. globals are nasty :P )
global LastDataSent
LastDataSent = ''
retCode = None
# Create the pipes threads
stdin_pipe = RemoteStdInPipe(rpctransport,
'\%s%s%d' % (RemComSTDIN,
packet['Machine'],
packet['ProcessID']),
smbconnection.smb.FILE_WRITE_DATA |
smbconnection.smb.FILE_APPEND_DATA,
installService.getShare())
stdin_pipe.start()
stdout_pipe = RemoteStdOutPipe(rpctransport,
'\%s%s%d' % (RemComSTDOUT,
packet['Machine'],
packet['ProcessID']),
smbconnection.smb.FILE_READ_DATA)
stdout_pipe.start()
stderr_pipe = RemoteStdErrPipe(rpctransport,
'\%s%s%d' % (RemComSTDERR,
packet['Machine'],
packet['ProcessID']),
smbconnection.smb.FILE_READ_DATA)
stderr_pipe.start()
# And we stay here till the end
ans = s.readNamedPipe(tid, fid_main, 8)
if len(ans):
retCode = RemComResponse(ans)
print("[*] Process %s finished with ErrorCode: %d, "
"ReturnCode: %d" % (self.__command, retCode['ErrorCode'],
retCode['ReturnCode']))
installService.uninstall()
unInstalled = True
sys.exit(retCode['ReturnCode'])
except Exception:
if unInstalled is False:
installService.uninstall()
sys.stdout.flush()
if retCode:
sys.exit(retCode['ReturnCode'])
else:
sys.exit(1)
class Pipes(threading.Thread):
"""."""
def __init__(self, transport, pipe, permissions, share=None):
"""."""
threading.Thread.__init__(self)
self.server = 0
self.transport = transport
self.credentials = transport.get_credentials()
self.tid = 0
self.fid = 0
self.share = share
self.port = transport.get_dport()
self.pipe = pipe
self.permissions = permissions
self.daemon = True
def connectPipe(self):
"""."""
try:
lock.acquire()
global dialect
remoteHost = self.transport.get_smb_connection().getRemoteHost()
#self.server = SMBConnection('*SMBSERVER',
#self.transport.get_smb_connection().getRemoteHost(),
#sess_port = self.port, preferredDialect = SMB_DIALECT)
self.server = smbconnection.SMBConnection('*SMBSERVER', remoteHost,
sess_port=self.port,
preferredDialect=dialect) # noqa
user, passwd, domain, lm, nt = self.credentials
self.server.login(user, passwd, domain, lm, nt)
lock.release()
self.tid = self.server.connectTree('IPC$')
self.server.waitNamedPipe(self.tid, self.pipe)
self.fid = self.server.openFile(self.tid, self.pipe,
self.permissions,
creationOption=0x40,
fileAttributes=0x80)
self.server.setTimeout(1000000)
except Exception:
message = ("[!] Something wen't wrong connecting the pipes(%s), "
"try again")
print(message % self.__class__)
class RemoteStdOutPipe(Pipes):
"""."""
def __init__(self, transport, pipe, permisssions):
"""."""
Pipes.__init__(self, transport, pipe, permisssions)
def run(self):
"""."""
self.connectPipe()
while True:
try:
ans = self.server.readFile(self.tid, self.fid, 0, 1024)
except Exception:
pass
else:
try:
global LastDataSent
if ans != LastDataSent: # noqa
sys.stdout.write(ans)
sys.stdout.flush()
else:
# Don't echo what I sent, and clear it up
LastDataSent = ''
# Just in case this got out of sync, i'm cleaning it
# up if there are more than 10 chars,
# it will give false positives tho.. we should find a
# better way to handle this.
if LastDataSent > 10:
LastDataSent = ''
except Exception:
pass
class RemoteStdErrPipe(Pipes):
"""."""
def __init__(self, transport, pipe, permisssions):
"""."""
Pipes.__init__(self, transport, pipe, permisssions)
def run(self):
"""."""
self.connectPipe()
while True:
try:
ans = self.server.readFile(self.tid, self.fid, 0, 1024)
except Exception:
pass
else:
try:
sys.stderr.write(str(ans))
sys.stderr.flush()
except Exception:
pass
class RemoteShell(cmd.Cmd):
"""."""
def __init__(self, server, port, credentials, tid, fid, share):
"""."""
cmd.Cmd.__init__(self, False)
self.prompt = '\x08'
self.server = server
self.transferClient = None
self.tid = tid
self.fid = fid
self.credentials = credentials
self.share = share
self.port = port
self.intro = '[!] Press help for extra shell commands'
def connect_transferClient(self):
"""."""
#self.transferClient = SMBConnection('*SMBSERVER',
#self.server.getRemoteHost(), sess_port = self.port,
#preferredDialect = SMB_DIALECT)
self.transferClient = smbconnection.SMBConnection('*SMBSERVER',
self.server.getRemoteHost(),
sess_port=self.port,
preferredDialect=dialect) # noqa
user, passwd, domain, lm, nt = self.credentials
self.transferClient.login(user, passwd, domain, lm, nt)
def do_help(self, line):
"""."""
print("""
lcd {path} - changes the current local directory to {path}
exit - terminates the server process (and this session)
put {src_file, dst_path} - uploads a local file to the dst_path RELATIVE to
the connected share (%s)
get {file} - downloads pathname RELATIVE to the connected
share (%s) to the current local dir
! {cmd} - executes a local shell cmd
""" % (self.share, self.share))
self.send_data('\r\n', False)
def do_shell(self, s):
"""."""
os.system(s)
self.send_data('\r\n')
def do_get(self, src_path):
"""."""
try:
if self.transferClient is None:
self.connect_transferClient()
import ntpath
filename = ntpath.basename(src_path)
fh = open(filename, 'wb')
print("[*] Downloading %s\%s" % (self.share, src_path))
self.transferClient.getFile(self.share, src_path, fh.write)
fh.close()
except Exception as e:
print(e)
pass
self.send_data('\r\n')
def do_put(self, s):
"""."""
try:
if self.transferClient is None:
self.connect_transferClient()
params = s.split(' ')
if len(params) > 1:
src_path = params[0]
dst_path = params[1]
elif len(params) == 1:
src_path = params[0]
dst_path = '/'
src_file = os.path.basename(src_path)
fh = open(src_path, 'rb')
f = dst_path + '/' + src_file
pathname = string.replace(f, '/', '\\')
print("[*] Uploading %s to %s\\%s" % (src_file, self.share,
dst_path))
self.transferClient.putFile(self.share, pathname, fh.read)
fh.close()
except Exception as e:
print(e)
pass
self.send_data('\r\n')
def do_lcd(self, s):
"""."""
if s == '':
print(os.getcwd())
else:
os.chdir(s)
self.send_data('\r\n')
def emptyline(self):
"""."""
self.send_data('\r\n')
return
def do_EOF(self, line):
"""."""
self.server.logoff()
def default(self, line):
"""."""
self.send_data(line+'\r\n')
def send_data(self, data, hideOutput=True):
"""."""
if hideOutput is True:
global LastDataSent
LastDataSent = data
else:
LastDataSent = ''
self.server.writeFile(self.tid, self.fid, data)
class RemoteStdInPipe(Pipes):
"""RemoteStdInPipe class.
Used to connect to RemComSTDIN named pipe on remote system
"""
def __init__(self, transport, pipe, permisssions, share=None):
"""Constructor."""
Pipes.__init__(self, transport, pipe, permisssions, share)
def run(self):
"""."""
self.connectPipe()
self.shell = RemoteShell(self.server, self.port, self.credentials,
self.tid, self.fid, self.share)
self.shell.cmdloop()
# Process command-line arguments.
if __name__ == '__main__':
print(version.BANNER)
parser = argparse.ArgumentParser()
parser.add_argument('target', action='store',
help='[domain/][username[:password]@]<address>')
parser.add_argument('command', action='store',
help='command to execute at the target (w/o path)')
parser.add_argument('-path', action='store',
help='path of the command to execute')
parser.add_argument(
'-file', action='store',
help="alternative RemCom binary (be sure it doesn't require CRT)")
parser.add_argument(
'-port', action='store',
help='alternative port to use, this will copy settings from 445/SMB')
parser.add_argument('protocol', choices=PSEXEC.KNOWN_PROTOCOLS.keys(),
nargs='?', default='445/SMB',
help='transport protocol (default 445/SMB)')
group = parser.add_argument_group('authentication')
group.add_argument('-hashes', action="store", metavar="LMHASH:NTHASH",
help='NTLM hashes, format is LMHASH:NTHASH')
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
options = parser.parse_args()
domain, username, password, address = re.compile(
'(?:(?:([^/@:]*)/)?([^@:]*)(?::([^.]*))?@)?(.*)'
).match(options.target).groups('')
if domain is None:
domain = ''
if options.port:
options.protocol = "%s/SMB" % options.port
executer = PSEXEC(options.command, options.path, options.file,
options.protocol, username, password, domain,
options.hashes)
if options.protocol not in PSEXEC.KNOWN_PROTOCOLS:
connection_string = 'ncacn_np:%s[\\pipe\\svcctl]'
PSEXEC.KNOWN_PROTOCOLS[options.protocol] = (connection_string,
options.port)
executer.run(address)

View File

@ -56,9 +56,9 @@ def run(target, config=None, interactive=False):
found['hostname'] = hostname
ip_address = six.text_type(dns.resolve_hostname(hostname))
# TODO(sam): Use ipaddress.ip_address.is_global
# " .is_private
# " .is_unspecified
# " .is_multicast
# .is_private
# .is_unspecified
# .is_multicast
# To determine address "type"
if not ipaddress.ip_address(ip_address).is_loopback:
try:

316
satori/smb.py Normal file
View File

@ -0,0 +1,316 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Windows remote client module implemented using psexec.py."""
try:
import eventlet
eventlet.monkey_patch()
from eventlet.green import time
except ImportError:
import time
import ast
import base64
import logging
import os
import re
import shlex
import subprocess
import tempfile
from satori import ssh
from satori import tunnel
LOG = logging.getLogger(__name__)
def connect(*args, **kwargs):
"""Connect to a remote device using psexec.py."""
try:
return SMBClient.get_client(*args, **kwargs)
except Exception as exc:
LOG.error("ERROR: pse.py failed to connect: %s", str(exc))
def _posh_encode(command):
"""Encode a powershell command to base64.
This is using utf-16 encoding and disregarding the first two bytes
:param command: command to encode
"""
return base64.b64encode(command.encode('utf-16')[2:])
class SubprocessError(Exception):
"""Custom Exception.
This will be raised when the subprocess running psexec.py has exited.
"""
pass
class SMBClient(object): # pylint: disable=R0902
"""Connects to devices over SMB/psexec to execute commands."""
_prompt_pattern = re.compile(r'^[a-zA-Z]:\\.*>$', re.MULTILINE)
# pylint: disable=R0913
def __init__(self, host, password=None, username="Administrator",
port=445, timeout=10, gateway=None, **kwargs):
"""Create an instance of the PSE class.
:param str host: The ip address or host name of the server
to connect to
:param str password: A password to use for authentication
:param str username: The username to authenticate as (defaults to
Administrator)
:param int port: tcp/ip port to use (defaults to 445)
:param float timeout: an optional timeout (in seconds) for the
TCP connection
:param gateway: instance of satori.ssh.SSH to be used to set up
an SSH tunnel (equivalent to ssh -L)
"""
self.password = password
self.host = host
self.port = port or 445
self.username = username or 'Administrator'
self.timeout = timeout
self._connected = False
self._platform_info = None
self._process = None
self._orig_host = None
self._orig_port = None
self.ssh_tunnel = None
self._substituted_command = None
# creating temp file to talk to _process with
self._file_write = tempfile.NamedTemporaryFile()
self._file_read = open(self._file_write.name, 'r')
self._command = ("nice python %s/contrib/psexec.py -port %s %s:%s@%s "
"'c:\\Windows\\sysnative\\cmd'")
self._output = ''
self.gateway = gateway
if gateway:
if not isinstance(self.gateway, ssh.SSH):
raise TypeError("'gateway' must be a satori.ssh.SSH instance. "
"( instances of this type are returned by"
"satori.ssh.connect() )")
if kwargs:
LOG.debug("DEBUG: Following arguments passed into PSE constructor "
"not used: %s", kwargs.keys())
def __del__(self):
"""Destructor of the PSE class."""
try:
self.close()
except ValueError:
pass
@classmethod
def get_client(cls, *args, **kwargs):
"""Return a pse client object from this module."""
return cls(*args, **kwargs)
@property
def platform_info(self):
"""Return Windows edition, version and architecture.
requires Powershell version 3
"""
if not self._platform_info:
command = ('Get-WmiObject Win32_OperatingSystem |'
' select @{n="dist";e={$_.Caption.Trim()}},'
'@{n="version";e={$_.Version}},@{n="arch";'
'e={$_.OSArchitecture}} | '
' ConvertTo-Json -Compress')
stdout = self.remote_execute(command, retry=3)
self._platform_info = ast.literal_eval(stdout)
return self._platform_info
def create_tunnel(self):
"""Create an ssh tunnel via gateway.
This will tunnel a local ephemeral port to the host's port.
This will preserve the original host and port
"""
self.ssh_tunnel = tunnel.connect(self.host, self.port, self.gateway)
self._orig_host = self.host
self._orig_port = self.port
self.host, self.port = self.ssh_tunnel.address
self.ssh_tunnel.serve_forever(async=True)
def shutdown_tunnel(self):
"""Terminate the ssh tunnel. Restores original host and port."""
self.ssh_tunnel.shutdown()
self.host = self._orig_host
self.port = self._orig_port
def test_connection(self):
"""Connect to a Windows server and disconnect again.
Make sure the returncode is 0, otherwise return False
"""
self.connect()
self.close()
self._get_output()
if self._output.find('ErrorCode: 0, ReturnCode: 0') > -1:
return True
else:
return False
def connect(self):
"""Attempt a connection using psexec.py.
This will create a subprocess.Popen() instance and communicate with it
via _file_read/_file_write and _process.stdin
"""
try:
if self._connected and self._process:
if self._process.poll() is None:
return
else:
self._process.wait()
if self.gateway:
self.shutdown_tunnel()
if self.gateway:
self.create_tunnel()
self._substituted_command = self._command % (
os.path.dirname(__file__),
self.port,
self.username,
self.password,
self.host)
self._process = subprocess.Popen(
shlex.split(self._substituted_command),
stdout=self._file_write,
stderr=subprocess.STDOUT,
stdin=subprocess.PIPE,
close_fds=True,
bufsize=0)
output = ''
while not self._prompt_pattern.findall(output):
output += self._get_output()
self._connected = True
except Exception:
self.close()
raise
def close(self):
"""Close the psexec connection by sending 'exit' to the subprocess.
This will cleanly exit psexec (i.e. stop and uninstall the service and
delete the files)
This method will be called when an instance of this class is about to
being destroyed. It will try to close the connection (which will clean
up on the remote server) and catch the exception that is raised when
the connection has already been closed.
"""
try:
self._process.communicate('exit')
except Exception as exc:
LOG.warning("ERROR: Failed to close %s: %s", self, str(exc))
del exc
try:
if self.gateway:
self.shutdown_tunnel()
self.gateway.close()
except Exception as exc:
LOG.warning("ERROR: Failed to close gateway %s: %s", self.gateway,
str(exc))
del exc
finally:
if self._process:
LOG.warning("Killing process: %s", self._process)
subprocess.call(['pkill', '-STOP', '-P',
str(self._process.pid)])
def remote_execute(self, command, powershell=True, retry=0, **kwargs):
"""Execute a command on a remote host.
:param command: Command to be executed
:param powershell: If True, command will be interpreted as Powershell
command and therefore converted to base64 and
prepended with 'powershell -EncodedCommand
:param int retry: Number of retries when SubprocessError is thrown
by _get_output before giving up
"""
self.connect()
if powershell:
command = ('powershell -EncodedCommand %s' %
_posh_encode(command))
self._process.stdin.write('%s\n' % command)
try:
output = self._get_output()
output = "\n".join(output.splitlines()[:-1]).strip()
return output
except SubprocessError:
if not retry:
raise
else:
return self.remote_execute(command, powershell=powershell,
retry=retry - 1)
def _get_output(self, prompt_expected=True, wait=200):
"""Retrieve output from _process.
This method will wait until output is started to be received and then
wait until no further output is received within a defined period
:param prompt_expected: only return when regular expression defined
in _prompt_pattern is matched
:param wait: Time in milliseconds to wait in each of the
two loops that wait for (more) output.
"""
tmp_out = ''
while tmp_out == '':
self._file_read.seek(0, 1)
tmp_out += self._file_read.read()
# leave loop if underlying process has a return code
# obviously meaning that it has terminated
if self._process.poll() is not None:
import json
error = {"error": tmp_out}
raise SubprocessError("subprocess with pid: %s has terminated "
"unexpectedly with return code: %s\n%s"
% (self._process.pid,
self._process.poll(),
json.dumps(error)))
time.sleep(wait/1000)
stdout = tmp_out
while (not tmp_out == '' or
(not self._prompt_pattern.findall(stdout) and
prompt_expected)):
self._file_read.seek(0, 1)
tmp_out = self._file_read.read()
stdout += tmp_out
# leave loop if underlying process has a return code
# obviously meaning that it has terminated
if self._process.poll() is not None:
import json
error = {"error": stdout}
raise SubprocessError("subprocess with pid: %s has terminated "
"unexpectedly with return code: %s\n%s"
% (self._process.pid,
self._process.poll(),
json.dumps(error)))
time.sleep(wait/1000)
self._output += stdout
stdout = stdout.replace('\r', '').replace('\x08', '')
return stdout

View File

@ -129,10 +129,10 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
"""
self.password = password
self.host = host
self.username = username
self.username = username or 'root'
self.private_key = private_key
self.key_filename = key_filename
self.port = port
self.port = port or 22
self.timeout = timeout
self._platform_info = None
self.options = options or {}
@ -160,7 +160,6 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
Requires >= Python 2.4 on remote system.
"""
if not self._platform_info:
platform_command = "import platform,sys\n"
platform_command += utils.get_source_definition(
utils.get_platform_info)
@ -169,10 +168,18 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
command = 'echo -e """%s""" | python' % platform_command
output = self.remote_execute(command)
stdout = re.split('\n|\r\n', output['stdout'])[-1].strip()
plat = ast.literal_eval(stdout)
if stdout:
try:
plat = ast.literal_eval(stdout)
except SyntaxError as exc:
plat = {'dist': 'unknown'}
LOG.warning("Error parsing response from host '%s': %s",
self.host, output, exc_info=exc)
else:
plat = {'dist': 'unknown'}
LOG.warning("Blank response from host '%s': %s",
self.host, output)
self._platform_info = plat
LOG.debug("Remote platform info: %s", self._platform_info)
return self._platform_info
def connect_with_host_keys(self):
@ -362,7 +369,7 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
return False
def remote_execute(self, command, with_exit_code=False,
get_pty=False, wd=None):
get_pty=False, cwd=None, **kwargs):
"""Execute an ssh command on a remote host.
Tries cert auth first and falls back
@ -370,8 +377,8 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
:param command: Shell command to be executed by this function.
:param with_exit_code: Include the exit_code in the return body.
:param wd: The child's current directory will be changed
to `wd` before it is executed. Note that this
:param cwd: The child's current directory will be changed
to `cwd` before it is executed. Note that this
directory is not considered when searching the
executable, so you can't specify the program's
path relative to this argument
@ -380,8 +387,8 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
:returns: a dict with stdin, stdout,
and (optionally) the exit code of the call.
"""
if wd:
prefix = "cd %s && " % wd
if cwd:
prefix = "cd %s && " % cwd
command = prefix + command
LOG.debug("Executing '%s' on ssh://%s@%s:%s.",
@ -408,10 +415,10 @@ class SSH(paramiko.SSHClient): # pylint: disable=R0902
'stderr': stderr.read()
}
LOG.debug("STDOUT from ssh://%s@%s:%d: %s",
LOG.debug("STDOUT from ssh://%s@%s:%d: %.5000s ...",
self.username, self.host, self.port,
results['stdout'])
LOG.debug("STDERR from ssh://%s@%s:%d: %s",
LOG.debug("STDERR from ssh://%s@%s:%d: %.5000s ...",
self.username, self.host, self.port,
results['stderr'])
exit_code = chan.recv_exit_status()

View File

@ -24,10 +24,6 @@ from satori import errors
from satori import utils
LOG = logging.getLogger(__name__)
if six.PY3:
def unicode(text, errors=None): # noqa
"""A hacky Python 3 version of unicode() function."""
return str(text)
def get_systeminfo(ipaddress, config, interactive=False):
@ -38,7 +34,7 @@ def get_systeminfo(ipaddress, config, interactive=False):
:keyword interactive: whether to prompt the user for information.
"""
if (ipaddress in utils.get_local_ips() or
ipaddress_module.ip_address(unicode(ipaddress)).is_loopback):
ipaddress_module.ip_address(six.text_type(ipaddress)).is_loopback):
client = bash.LocalShell()
client.host = "localhost"
@ -66,46 +62,66 @@ def system_info(client):
SystemInfoNotJson if `ohai` does not return valid JSON.
SystemInfoMissingJson if `ohai` does not return any JSON.
"""
output = client.execute("sudo -i ohai-solo")
not_found_msgs = ["command not found", "Could not find ohai"]
if any(m in k for m in not_found_msgs
for k in list(output.values()) if isinstance(k, six.string_types)):
LOG.warning("SystemInfoCommandMissing on host: [%s]", client.host)
raise errors.SystemInfoCommandMissing("ohai-solo missing on %s",
client.host)
unicode_output = unicode(output['stdout'], errors='replace')
try:
results = json.loads(unicode_output)
except ValueError as exc:
if client.is_windows():
raise errors.UnsupportedPlatform(
"ohai-solo is a linux-only sytem info provider. "
"Target platform was %s", client.platform_info['dist'])
else:
output = client.execute("sudo -i ohai-solo")
not_found_msgs = ["command not found", "Could not find ohai"]
if any(m in k for m in not_found_msgs
for k in list(output.values()) if isinstance(k,
six.string_types)):
LOG.warning("SystemInfoCommandMissing on host: [%s]", client.host)
raise errors.SystemInfoCommandMissing("ohai-solo missing on %s" %
client.host)
# use string formatting to handle unicode
unicode_output = "%s" % output['stdout']
try:
clean_output = get_json(unicode_output)
results = json.loads(clean_output)
results = json.loads(unicode_output)
except ValueError as exc:
raise errors.SystemInfoNotJson(exc)
return results
try:
clean_output = get_json(unicode_output)
results = json.loads(clean_output)
except ValueError as exc:
raise errors.SystemInfoNotJson(exc)
return results
def install_remote(client):
"""Install ohai-solo on remote system."""
LOG.info("Installing (or updating) ohai-solo on device %s at %s:%d",
client.host, client.host, client.port)
# Download to host
command = "sudo wget -N http://ohai.rax.io/install.sh"
client.execute(command, wd='/tmp')
# Run install
command = "sudo bash install.sh"
output = client.execute(command, wd='/tmp', with_exit_code=True)
# Be a good citizen and clean up your tmp data
command = "sudo rm install.sh"
client.execute(command, wd='/tmp')
# Process install command output
if output['exit_code'] != 0:
raise errors.SystemInfoCommandInstallFailed(output['stderr'][:256])
# Check if it a windows box, but fail safely to Linux
is_windows = False
try:
is_windows = client.is_windows()
except Exception:
pass
if is_windows:
raise errors.UnsupportedPlatform(
"ohai-solo is a linux-only sytem info provider. "
"Target platform was %s", client.platform_info['dist'])
else:
return output
# Download to host
command = "sudo wget -N http://ohai.rax.io/install.sh"
client.execute(command, cwd='/tmp')
# Run install
command = "sudo bash install.sh"
output = client.execute(command, cwd='/tmp', with_exit_code=True)
# Be a good citizen and clean up your tmp data
command = "sudo rm install.sh"
client.execute(command, cwd='/tmp')
# Process install command output
if output['exit_code'] != 0:
raise errors.SystemInfoCommandInstallFailed(
output['stderr'][:256])
else:
return output
def remove_remote(client):
@ -117,24 +133,29 @@ def remove_remote(client):
- redhat [5.x, 6.x]
- centos [5.x, 6.x]
"""
platform_info = client.platform_info
if client.is_debian():
remove = "sudo dpkg --purge ohai-solo"
elif client.is_fedora():
remove = "sudo yum -y erase ohai-solo"
if client.is_windows():
raise errors.UnsupportedPlatform(
"ohai-solo is a linux-only sytem info provider. "
"Target platform was %s", client.platform_info['dist'])
else:
raise errors.UnsupportedPlatform("Unknown distro: %s" %
platform_info['dist'])
command = "%s" % remove
output = client.execute(command, wd='/tmp')
return output
platform_info = client.platform_info
if client.is_debian():
remove = "sudo dpkg --purge ohai-solo"
elif client.is_fedora():
remove = "sudo yum -y erase ohai-solo"
else:
raise errors.UnsupportedPlatform("Unknown distro: %s" %
platform_info['dist'])
command = "%s" % remove
output = client.execute(command, cwd='/tmp')
return output
def get_json(data):
"""Find the JSON string in data and return a string.
:param data: :string:
:returns: string -- JSON string striped of non-JSON data
:returns: string -- JSON string stripped of non-JSON data
:raises: SystemInfoMissingJson
SystemInfoMissingJson if `ohai` does not return any JSON.

149
satori/sysinfo/posh_ohai.py Normal file
View File

@ -0,0 +1,149 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# pylint: disable=W0622
"""PoSh-Ohai Data Plane Discovery Module."""
import json
import logging
import ipaddress as ipaddress_module
import six
from satori import bash
from satori import errors
from satori import utils
LOG = logging.getLogger(__name__)
def get_systeminfo(ipaddress, config, interactive=False):
"""Run data plane discovery using this module against a host.
:param ipaddress: address to the host to discover.
:param config: arguments and configuration suppplied to satori.
:keyword interactive: whether to prompt the user for information.
"""
if (ipaddress in utils.get_local_ips() or
ipaddress_module.ip_address(six.text_type(ipaddress)).is_loopback):
client = bash.LocalShell()
client.host = "localhost"
client.port = 0
else:
client = bash.RemoteShell(ipaddress, username=config['host_username'],
private_key=config['host_key'],
interactive=interactive)
install_remote(client)
return system_info(client)
def system_info(client):
"""Run Posh-Ohai on a remote system and gather the output.
:param client: :class:`smb.SMB` instance
:returns: dict -- system information from PoSh-Ohai
:raises: SystemInfoCommandMissing, SystemInfoCommandOld, SystemInfoNotJson
SystemInfoMissingJson
SystemInfoCommandMissing if `posh-ohai` is not installed.
SystemInfoCommandOld if `posh-ohai` is not the latest.
SystemInfoNotJson if `posh-ohai` does not return valid JSON.
SystemInfoMissingJson if `posh-ohai` does not return any JSON.
"""
if client.is_windows():
powershell_command = 'Get-ComputerConfiguration'
output = client.execute(powershell_command)
unicode_output = "%s" % output
try:
results = json.loads(unicode_output)
except ValueError:
try:
clean_output = get_json(unicode_output)
results = json.loads(clean_output)
except ValueError as err:
raise errors.SystemInfoNotJson(err)
return results
else:
raise errors.PlatformNotSupported(
"PoSh-Ohai is a Windows-only sytem info provider. "
"Target platform was %s", client.platform_info['dist'])
def install_remote(client):
"""Install PoSh-Ohai on remote system."""
LOG.info("Installing (or updating) PoSh-Ohai on device %s at %s:%d",
client.host, client.host, client.port)
# Check is it is a windows box, but fail safely to Linux
is_windows = False
try:
is_windows = client.is_windows()
except Exception:
pass
if is_windows:
powershell_command = ('[scriptblock]::Create((New-Object -TypeName '
'System.Net.WebClient).DownloadString('
'"http://ohai.rax.io/deploy.ps1"))'
'.Invoke()')
# check output to ensure that installation was successful
# if not, raise SystemInfoCommandInstallFailed
output = client.execute(powershell_command)
return output
else:
raise errors.PlatformNotSupported(
"PoSh-Ohai is a Windows-only sytem info provider. "
"Target platform was %s", client.platform_info['dist'])
def remove_remote(client):
"""Remove PoSh-Ohai from specifc remote system.
Currently supports:
- ubuntu [10.x, 12.x]
- debian [6.x, 7.x]
- redhat [5.x, 6.x]
- centos [5.x, 6.x]
"""
if client.is_windows():
powershell_command = ('Remove-Item -Path (Join-Path -Path '
'$($env:PSModulePath.Split(";") '
'| Where-Object { $_.StartsWith('
'$env:SystemRoot)}) -ChildPath '
'"PoSh-Ohai") -Recurse -Force -ErrorAction '
'SilentlyContinue')
output = client.execute(powershell_command)
return output
else:
raise errors.PlatformNotSupported(
"PoSh-Ohai is a Windows-only sytem info provider. "
"Target platform was %s", client.platform_info['dist'])
def get_json(data):
"""Find the JSON string in data and return a string.
:param data: :string:
:returns: string -- JSON string stripped of non-JSON data
:raises: SystemInfoMissingJson
SystemInfoMissingJson if no JSON is returned.
"""
try:
first = data.index('{')
last = data.rindex('}')
return data[first:last + 1]
except ValueError as exc:
context = {"ValueError": "%s" % exc}
raise errors.SystemInfoMissingJson(context)

View File

@ -48,119 +48,124 @@ class TestOhaiSolo(utils.TestCase):
class TestOhaiInstall(utils.TestCase):
def test_install_remote_fedora(self):
mock_ssh = mock.MagicMock()
response = {'exit_code': 0, 'foo': 'bar'}
mock_ssh.execute.return_value = response
result = ohai_solo.install_remote(mock_ssh)
self.assertEqual(result, response)
self.assertEqual(mock_ssh.execute.call_count, 3)
mock_ssh.execute.assert_has_calls([
mock.call('sudo wget -N http://ohai.rax.io/install.sh', wd='/tmp'),
mock.call('sudo bash install.sh', wd='/tmp', with_exit_code=True),
mock.call('sudo rm install.sh', wd='/tmp')])
def setUp(self):
super(TestOhaiInstall, self).setUp()
self.mock_remotesshclient = mock.MagicMock()
self.mock_remotesshclient.is_windows.return_value = False
def test_install_remote_failed(self):
mock_ssh = mock.MagicMock()
def test_install_remote_fedora(self):
response = {'exit_code': 0, 'foo': 'bar'}
self.mock_remotesshclient.execute.return_value = response
result = ohai_solo.install_remote(self.mock_remotesshclient)
self.assertEqual(result, response)
self.assertEqual(self.mock_remotesshclient.execute.call_count, 3)
self.mock_remotesshclient.execute.assert_has_calls([
mock.call('sudo wget -N http://ohai.rax.io/install.sh', cwd='/tmp'),
mock.call('sudo bash install.sh', cwd='/tmp', with_exit_code=True),
mock.call('sudo rm install.sh', cwd='/tmp')])
def test_install_linux_remote_failed(self):
response = {'exit_code': 1, 'stdout': "", "stderr": "FAIL"}
mock_ssh.execute.return_value = response
self.mock_remotesshclient.execute.return_value = response
self.assertRaises(errors.SystemInfoCommandInstallFailed,
ohai_solo.install_remote, mock_ssh)
ohai_solo.install_remote, self.mock_remotesshclient)
class TestOhaiRemove(utils.TestCase):
def setUp(self):
super(TestOhaiRemove, self).setUp()
self.mock_remotesshclient = mock.MagicMock()
self.mock_remotesshclient.is_windows.return_value = False
def test_remove_remote_fedora(self):
mock_ssh = mock.MagicMock()
mock_ssh.is_debian.return_value = False
mock_ssh.is_fedora.return_value = True
self.mock_remotesshclient.is_debian.return_value = False
self.mock_remotesshclient.is_fedora.return_value = True
response = {'exit_code': 0, 'foo': 'bar'}
mock_ssh.execute.return_value = response
result = ohai_solo.remove_remote(mock_ssh)
self.mock_remotesshclient.execute.return_value = response
result = ohai_solo.remove_remote(self.mock_remotesshclient)
self.assertEqual(result, response)
mock_ssh.execute.assert_called_once_with(
'sudo yum -y erase ohai-solo', wd='/tmp')
self.mock_remotesshclient.execute.assert_called_once_with(
'sudo yum -y erase ohai-solo', cwd='/tmp')
def test_remove_remote_debian(self):
mock_ssh = mock.MagicMock()
mock_ssh.is_debian.return_value = True
mock_ssh.is_fedora.return_value = False
self.mock_remotesshclient.is_debian.return_value = True
self.mock_remotesshclient.is_fedora.return_value = False
response = {'exit_code': 0, 'foo': 'bar'}
mock_ssh.execute.return_value = response
result = ohai_solo.remove_remote(mock_ssh)
self.mock_remotesshclient.execute.return_value = response
result = ohai_solo.remove_remote(self.mock_remotesshclient)
self.assertEqual(result, response)
mock_ssh.execute.assert_called_once_with(
'sudo dpkg --purge ohai-solo', wd='/tmp')
self.mock_remotesshclient.execute.assert_called_once_with(
'sudo dpkg --purge ohai-solo', cwd='/tmp')
def test_remove_remote_unsupported(self):
mock_ssh = mock.MagicMock()
mock_ssh.is_debian.return_value = False
mock_ssh.is_fedora.return_value = False
self.mock_remotesshclient.is_debian.return_value = False
self.mock_remotesshclient.is_fedora.return_value = False
self.assertRaises(errors.UnsupportedPlatform,
ohai_solo.remove_remote, mock_ssh)
ohai_solo.remove_remote, self.mock_remotesshclient)
class TestSystemInfo(utils.TestCase):
def setUp(self):
super(TestSystemInfo, self).setUp()
self.mock_remotesshclient = mock.MagicMock()
self.mock_remotesshclient.is_windows.return_value = False
def test_system_info(self):
mock_ssh = mock.MagicMock()
mock_ssh.execute.return_value = {
self.mock_remotesshclient.execute.return_value = {
'exit_code': 0,
'stdout': "{}",
'stderr': ""
}
ohai_solo.system_info(mock_ssh)
mock_ssh.execute.assert_called_with("sudo -i ohai-solo")
ohai_solo.system_info(self.mock_remotesshclient)
self.mock_remotesshclient.execute.assert_called_with(
"sudo -i ohai-solo")
def test_system_info_with_motd(self):
mock_ssh = mock.MagicMock()
mock_ssh.execute.return_value = {
self.mock_remotesshclient.execute.return_value = {
'exit_code': 0,
'stdout': "Hello world\n {}",
'stderr': ""
}
ohai_solo.system_info(mock_ssh)
mock_ssh.execute.assert_called_with("sudo -i ohai-solo")
ohai_solo.system_info(self.mock_remotesshclient)
self.mock_remotesshclient.execute.assert_called_with("sudo -i ohai-solo")
def test_system_info_bad_json(self):
mock_ssh = mock.MagicMock()
mock_ssh.execute.return_value = {
self.mock_remotesshclient.execute.return_value = {
'exit_code': 0,
'stdout': "{Not JSON!}",
'stderr': ""
}
self.assertRaises(errors.SystemInfoNotJson, ohai_solo.system_info,
mock_ssh)
self.mock_remotesshclient)
def test_system_info_missing_json(self):
mock_ssh = mock.MagicMock()
mock_ssh.execute.return_value = {
self.mock_remotesshclient.execute.return_value = {
'exit_code': 0,
'stdout': "No JSON!",
'stderr': ""
}
self.assertRaises(errors.SystemInfoMissingJson, ohai_solo.system_info,
mock_ssh)
self.mock_remotesshclient)
def test_system_info_command_not_found(self):
mock_ssh = mock.MagicMock()
mock_ssh.execute.return_value = {
self.mock_remotesshclient.execute.return_value = {
'exit_code': 1,
'stdout': "",
'stderr': "ohai-solo command not found"
}
self.assertRaises(errors.SystemInfoCommandMissing,
ohai_solo.system_info, mock_ssh)
ohai_solo.system_info, self.mock_remotesshclient)
def test_system_info_could_not_find(self):
mock_ssh = mock.MagicMock()
mock_ssh.execute.return_value = {
self.mock_remotesshclient.execute.return_value = {
'exit_code': 1,
'stdout': "",
'stderr': "Could not find ohai-solo."
}
self.assertRaises(errors.SystemInfoCommandMissing,
ohai_solo.system_info, mock_ssh)
ohai_solo.system_info, self.mock_remotesshclient)
if __name__ == "__main__":

167
satori/tunnel.py Normal file
View File

@ -0,0 +1,167 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""SSH tunneling module.
Set up a forward tunnel across an SSH server, using paramiko. A local port
(given with -p) is forwarded across an SSH session to an address:port from
the SSH server. This is similar to the openssh -L option.
"""
try:
import eventlet
eventlet.monkey_patch()
from eventlet.green import threading
from eventlet.green import time
except ImportError:
import threading
import time
pass
import logging
import select
import socket
try:
import SocketServer
except ImportError:
import socketserver as SocketServer
import paramiko
LOG = logging.getLogger(__name__)
class TunnelServer(SocketServer.ThreadingTCPServer):
"""Serve on a local ephemeral port.
Clients will connect to that port/server.
"""
daemon_threads = True
allow_reuse_address = True
class TunnelHandler(SocketServer.BaseRequestHandler):
"""Handle forwarding of packets."""
def handle(self):
"""Do all the work required to service a request.
The request is available as self.request, the client address as
self.client_address, and the server instance as self.server, in
case it needs to access per-server information.
This implementation will forward packets.
"""
try:
chan = self.ssh_transport.open_channel('direct-tcpip',
self.target_address,
self.request.getpeername())
except Exception as exc:
LOG.error('Incoming request to %s:%s failed',
self.target_address[0],
self.target_address[1],
exc_info=exc)
return
if chan is None:
LOG.error('Incoming request to %s:%s was rejected '
'by the SSH server.',
self.target_address[0],
self.target_address[1])
return
while True:
r, w, x = select.select([self.request, chan], [], [])
if self.request in r:
data = self.request.recv(1024)
if len(data) == 0:
break
chan.send(data)
if chan in r:
data = chan.recv(1024)
if len(data) == 0:
break
self.request.send(data)
try:
peername = None
peername = str(self.request.getpeername())
except socket.error as exc:
LOG.warning("Couldn't fetch peername.", exc_info=exc)
chan.close()
self.request.close()
LOG.info("Tunnel closed from '%s'", peername or 'unnamed peer')
class Tunnel(object): # pylint: disable=R0902
"""Create a TCP server which will use TunnelHandler."""
def __init__(self, target_host, target_port,
sshclient, tunnel_host='localhost',
tunnel_port=0):
"""Constructor."""
if not isinstance(sshclient, paramiko.SSHClient):
raise TypeError("'sshclient' must be an instance of "
"paramiko.SSHClient.")
self.target_host = target_host
self.target_port = target_port
self.target_address = (target_host, target_port)
self.address = (tunnel_host, tunnel_port)
self._tunnel = None
self._tunnel_thread = None
self.sshclient = sshclient
self._ssh_transport = self.get_sshclient_transport(
self.sshclient)
TunnelHandler.target_address = self.target_address
TunnelHandler.ssh_transport = self._ssh_transport
self._tunnel = TunnelServer(self.address, TunnelHandler)
# reset attribute to the port it has actually been set to
self.address = self._tunnel.server_address
tunnel_host, self.tunnel_port = self.address
def get_sshclient_transport(self, sshclient):
"""Get the sshclient's transport.
Connect the sshclient, that has been passed in and return its
transport.
"""
sshclient.connect()
return sshclient.get_transport()
def serve_forever(self, async=True):
"""Serve the tunnel forever.
if async is True, this will be done in a background thread
"""
if not async:
self._tunnel.serve_forever()
else:
self._tunnel_thread = threading.Thread(
target=self._tunnel.serve_forever)
self.start()
# cooperative yield
time.sleep(0)
def shutdown(self):
"""Stop serving the tunnel.
Also close the socket.
"""
self._tunnel.shutdown()
self._tunnel.socket.close()