satori/satori/ssh.py

513 lines
20 KiB
Python

# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""SSH Module for connecting to and automating remote commands.
Supports proxying through an ssh tunnel ('gateway' keyword argument.)
To control the behavior of the SSH client, use the specific connect_with_*
calls. The .connect() call behaves like the ssh command and attempts a number
of connection methods, including using the current user's ssh keys.
If interactive is set to true, the module will also prompt for a password if no
other connection method succeeds.
Note that test_connection() calls connect(). To test a connection and control
the authentication methods used, just call connect_with_* and catch any
exceptions instead of using test_connection().
"""
import ast
import getpass
import logging
import os
import re
import time
import paramiko
import six
from satori import errors
from satori import utils
LOG = logging.getLogger(__name__)

# Exclusive bounds on the number of buffered stdout bytes that
# SSH._handle_password_prompt will inspect as a possible password prompt.
MIN_PASSWORD_PROMPT_LEN = 8
MAX_PASSWORD_PROMPT_LEN = 64

# NOTE(review): not referenced in this module's visible code; presumably a
# filename prefix for temporary private-key files -- confirm with callers.
TEMPFILE_PREFIX = ".satori.tmp.key."

# Substrings of remote command output that mean the command has to be re-run
# with a pseudo-terminal (checked by SSH._handle_tty_required).
TTY_REQUIRED = [
    "you must have a tty to run sudo",
    "is not a tty",
    "no tty present",
    "must be run from a terminal",
]
def shellquote(s):
    r"""Return *s* wrapped so a POSIX shell sees it as one literal word.

    The whole string is enclosed in single quotes. Because nothing can be
    escaped inside single quotes, each embedded single quote is rewritten as
    the four characters '\'' -- close the quoted run, emit an escaped quote,
    then open a new quoted run.
    """
    escaped = s.replace("'", "'\\''")
    return "'" + escaped + "'"
def make_pkey(private_key):
    """Return a paramiko.pkey.PKey from private key string.

    Every key class available in the installed paramiko is tried in turn
    (RSA, DSS, ECDSA, then Ed25519); the first one that parses the key is
    used.

    :param private_key: the private key material, as a string
    :returns: a paramiko PKey instance
    :raises paramiko.SSHException: if none of the key classes can parse the
        key (this includes passphrase-protected keys, since no passphrase
        is supplied here).
    """
    # Look the classes up by name instead of via their submodules so that a
    # paramiko release lacking one of them (older ones have no Ed25519Key,
    # newer ones drop DSSKey) still loads every key type it does support.
    # Ed25519 is tried last so keys that parsed before keep the same class.
    candidates = (getattr(paramiko, name, None)
                  for name in ('RSAKey', 'DSSKey', 'ECDSAKey', 'Ed25519Key'))
    key_classes = [cls for cls in candidates if cls is not None]
    keyfile = six.StringIO(private_key)
    for cls in key_classes:
        keyfile.seek(0)  # each attempt must read the key from the start
        try:
            pkey = cls.from_private_key(keyfile)
        except paramiko.SSHException:
            continue
        LOG.info("Valid SSH Key provided (%s)", cls.__name__)
        return pkey
    raise paramiko.SSHException("Is not a valid private key")
def connect(*args, **kwargs):
    """Build and return an SSH client for a remote device.

    Arguments are forwarded unchanged to SSH.get_client(). A TypeError about
    an unexpected argument is re-raised with its message rewritten to name
    connect() rather than the internal constructor.
    """
    try:
        client = SSH.get_client(*args, **kwargs)
    except TypeError as exc:
        text = str(exc)
        marker = "got an unexpected"
        position = text.find(marker)
        if position < 0:
            raise
        message = "%s " + text[position:]
        raise exc.__class__(message % "connect()")
    return client
class AcceptMissingHostKey(paramiko.client.MissingHostKeyPolicy):
    """Allow connections to hosts whose fingerprints are not on record.

    The unknown key is added to the client's host keys
    (``client.get_host_keys()``); this policy does not save it to disk.
    """

    # pylint: disable=R0903
    def missing_host_key(self, client, hostname, key):
        """Add missing host key."""
        # Use the public accessor rather than the private _host_keys.
        client.get_host_keys().add(hostname, key.get_name(), key)
class SSH(paramiko.SSHClient):  # pylint: disable=R0902
    """Connects to devices via SSH to execute commands."""

    # pylint: disable=R0913
    def __init__(self, host, password=None, username="root", private_key=None,
                 root_password=None, key_filename=None, port=22,
                 timeout=20, gateway=None, options=None, interactive=False):
        """Create an instance of the SSH class.

        :param str host: The ip address or host name of the server
                         to connect to
        :param str password: A password to use for password authentication
                             (also typed at remote password prompts
                             encountered by remote_execute)
        :param username: The username to authenticate as
        :param root_password: root user password to be used if 'username'
                              is not root. This will use 'username' and
                              'password' to login and then 'su' to root
                              using root_password
        :param private_key: Private SSH Key string to use
                            (instead of using a filename)
        :param key_filename: a private key filename (path)
        :param port: tcp/ip port to use (defaults to 22)
        :param float timeout: an optional timeout (in seconds) for the
                              TCP connection
        :param SSH gateway: an existing satori.ssh.SSH instance to use
                            for proxying (its transport is used to open a
                            direct-tcpip channel to 'host')
        :param dict options: A dictionary used to set ssh options.
                             e.g. for `ssh -o StrictHostKeyChecking=no`,
                             you would provide
                             (.., options={'StrictHostKeyChecking': 'no'})
                             Conversion of booleans is also supported,
                             (.., options={'StrictHostKeyChecking': False})
                             is equivalent.
        :keyword interactive: If true, prompt for password if missing.
        :raises TypeError: if 'gateway' is given but is not an SSH instance.
        """
        # Initialise paramiko.SSHClient before anything can raise: __del__
        # (and therefore close()) still runs on an instance whose __init__
        # failed, and close() relies on the base-class attributes.
        super(SSH, self).__init__()
        if gateway and not isinstance(gateway, SSH):
            raise TypeError("'gateway' must be a satori.ssh.SSH instance. "
                            "( instances of this type are returned by "
                            "satori.ssh.connect() )")
        self.password = password
        self.host = host
        self.username = username or 'root'
        self.root_password = root_password
        self.private_key = private_key
        self.key_filename = key_filename
        self.port = port or 22
        self.timeout = timeout
        self._platform_info = None
        self.options = options or {}
        self.gateway = gateway
        self.sock = None
        self.interactive = interactive
        # With a root password, escalate via su (the password is answered
        # by _handle_password_prompt); otherwise escalate via sudo.
        if self.root_password:
            self.escalation_command = "su -c '%s'"
        else:
            self.escalation_command = 'sudo -i %s'

    def __del__(self):
        """Destructor to close the connection.

        Best effort only: this also runs on partially-initialised instances
        and during interpreter shutdown, where close() may fail.
        """
        try:
            self.close()
        except Exception:  # pylint: disable=W0703
            pass

    @classmethod
    def get_client(cls, *args, **kwargs):
        """Return an ssh client object from this module."""
        return cls(*args, **kwargs)

    @staticmethod
    def _as_text(value):
        """Return 'value' as text, decoding bytes/bytearray as UTF-8.

        remote_execute opens stdout in binary mode ('rb'), so its output may
        be bytes; anything that is not bytes is returned unchanged.
        """
        if isinstance(value, (bytes, bytearray)):
            return value.decode('utf-8', 'replace')
        return value

    @property
    def platform_info(self):
        """Return distro, version, architecture.

        Requires >= Python 2.4 on remote system. A non-empty result is
        cached on the instance, so the remote probe normally runs once.
        """
        if not self._platform_info:
            platform_command = "import platform,sys\n"
            platform_command += utils.get_source_definition(
                utils.get_platform_info)
            platform_command += ("\nsys.stdout.write(str("
                                 "get_platform_info()))\n")
            command = 'echo %s | python' % shellquote(platform_command)
            output = self.remote_execute(command)
            # Only the last output line is expected to hold the result.
            lines = re.split('\n|\r\n', self._as_text(output['stdout']))
            stdout = lines[-1].strip()
            if stdout:
                try:
                    plat = ast.literal_eval(stdout)
                except (SyntaxError, ValueError) as exc:
                    # ValueError: valid syntax but not a literal (e.g. a
                    # bare word such as an error message).
                    plat = {'dist': 'unknown'}
                    LOG.warning("Error parsing response from host '%s': %s",
                                self.host, output, exc_info=exc)
            else:
                plat = {'dist': 'unknown'}
                LOG.warning("Blank response from host '%s': %s",
                            self.host, output)
            self._platform_info = plat
        return self._platform_info

    def connect_with_host_keys(self):
        """Try connecting with locally available keys (ex. ~/.ssh/id_rsa)."""
        LOG.debug("Trying to connect with local host keys")
        return self._connect(look_for_keys=True, allow_agent=False)

    def connect_with_password(self):
        """Try connecting with password.

        Prompts for the password first when interactive is set and no
        password was supplied.

        :raises paramiko.PasswordRequiredException: if no password is
            available.
        """
        LOG.debug("Trying to connect with password")
        if self.interactive and not self.password:
            LOG.debug("Prompting for password (interactive=%s)",
                      self.interactive)
            try:
                self.password = getpass.getpass("Enter password for %s:" %
                                                self.username)
            except KeyboardInterrupt:
                LOG.debug("User cancelled at password prompt")
        if not self.password:
            raise paramiko.PasswordRequiredException("Password not provided")
        return self._connect(
            password=self.password,
            look_for_keys=False,
            allow_agent=False)

    def connect_with_key_file(self):
        """Try connecting with key file.

        :raises paramiko.AuthenticationException: if no key file was given.
        """
        LOG.debug("Trying to connect with key file")
        if not self.key_filename:
            raise paramiko.AuthenticationException("No key file supplied")
        return self._connect(
            key_filename=os.path.expanduser(self.key_filename),
            look_for_keys=False,
            allow_agent=False)

    def connect_with_key(self):
        """Try connecting with key string.

        :raises paramiko.AuthenticationException: if no key was given.
        :raises paramiko.SSHException: if the key string cannot be parsed.
        """
        LOG.debug("Trying to connect with private key string")
        if not self.private_key:
            raise paramiko.AuthenticationException("No key supplied")
        pkey = make_pkey(self.private_key)
        return self._connect(
            pkey=pkey,
            look_for_keys=False,
            allow_agent=False)

    def _connect(self, **kwargs):
        """Set up client and connect to target.

        Keyword arguments are passed to paramiko.SSHClient.connect; timeout,
        port, username, pkey and sock default to this instance's settings.
        """
        self.load_system_host_keys()
        if self.options.get('StrictHostKeyChecking') in (False, "no"):
            self.set_missing_host_key_policy(AcceptMissingHostKey())
        if self.gateway:
            # Lazy load: SSH.connect() returns immediately when the gateway
            # transport is active and (re)connects it when missing or dead.
            self.gateway.connect()
            self.sock = self.gateway.get_transport().open_channel(
                'direct-tcpip', (self.host, self.port), ('', 0))
        return super(SSH, self).connect(
            self.host,
            timeout=kwargs.pop('timeout', self.timeout),
            port=kwargs.pop('port', self.port),
            username=kwargs.pop('username', self.username),
            pkey=kwargs.pop('pkey', None),
            sock=kwargs.pop('sock', self.sock),
            **kwargs)

    def connect(self):  # pylint: disable=W0221
        """Attempt an SSH connection through paramiko.SSHClient.connect.

        Does nothing if this client already has an active transport.

        The order for authentication attempts is:
          - private_key (if set)
          - key_filename (if set)
          - any key discoverable in ~/.ssh/
          - username/password (will prompt if the password is not supplied
            and interactive is true)

        A paramiko.SSHException from a key-based attempt moves on to the
        next method, except paramiko.BadHostKeyException: a changed host key
        is a security problem, so it is logged and re-raised immediately
        instead of being retried with other credentials. Failures of the
        final (password) attempt are logged and re-raised.
        """
        # idempotency
        transport = self.get_transport()
        if transport and transport.is_active():
            return None

        key_methods = []
        if self.private_key:
            key_methods.append(self.connect_with_key)
        if self.key_filename:
            key_methods.append(self.connect_with_key_file)
        key_methods.append(self.connect_with_host_keys)

        try:
            for method in key_methods:
                try:
                    return method()
                except paramiko.BadHostKeyException:
                    raise
                except paramiko.SSHException as exc:
                    # try next method
                    LOG.debug("%s failed: %s", method.__name__, exc)
            return self.connect_with_password()
        except paramiko.BadHostKeyException as exc:
            msg = (
                "ssh://%s@%s:%d failed: %s. You might have a bad key "
                "entry on your server, but this is a security issue and "
                "won't be handled automatically. To fix this you can remove "
                "the host entry for this host from the ~/.ssh/known_hosts "
                "file")
            LOG.info(msg, self.username, self.host, self.port, exc)
            raise
        except Exception as exc:
            LOG.info('ssh://%s@%s:%d failed. %s',
                     self.username, self.host, self.port, exc)
            raise

    def test_connection(self):
        """Connect to an ssh server and verify that it responds.

        The order for authentication attempts is:
        (1) private_key
        (2) key_filename
        (3) any key discoverable in ~/.ssh/
        (4) username/password

        The connection is always closed afterwards.

        :returns: True if connect() succeeded, False otherwise.
        """
        LOG.debug("Checking for a response from ssh://%s@%s:%d.",
                  self.username, self.host, self.port)
        try:
            self.connect()
            LOG.debug("ssh://%s@%s:%d is up.",
                      self.username, self.host, self.port)
            return True
        except Exception as exc:
            LOG.info("ssh://%s@%s:%d failed. %s",
                     self.username, self.host, self.port, exc)
            return False
        finally:
            self.close()

    def close(self):
        """Close the connection to the remote host.

        If an ssh tunnel is being used, close that first.
        """
        if self.gateway:
            self.gateway.close()
        return super(SSH, self).close()

    def _handle_tty_required(self, results, get_pty):
        """Determine whether the result implies a tty request.

        :returns: True if the command should be retried with a pty.
        :raises errors.GetPTYRetryFailure: if a pty is required but the
            command was already run with get_pty=True.
        """
        if any(m in str(k) for m in TTY_REQUIRED for k in results.values()):
            LOG.info('%s requires TTY for sudo/su. Using TTY mode.',
                     self.host)
            if get_pty is True:  # if this is *already* True
                raise errors.GetPTYRetryFailure(
                    "Running command with get_pty=True FAILED: %s@%s:%d"
                    % (self.username, self.host, self.port))
            return True
        return False

    def _handle_password_prompt(self, stdin, stdout, su_auth=False):
        """Determine whether the remote host is prompting for a password.

        Respond to the prompt through stdin if applicable: root_password
        when su_auth is set, otherwise password.

        :returns: True if a prompt was detected and answered.
        """
        if stdout.channel.closed:
            return False
        buflen = len(stdout.channel.in_buffer)
        # min and max determined from max username length
        # and a set of encountered linux password prompts
        if not MIN_PASSWORD_PROMPT_LEN < buflen < MAX_PASSWORD_PROMPT_LEN:
            return False
        prompt = stdout.channel.recv(buflen)
        prompt_text = self._as_text(prompt).lower()
        if all(m in prompt_text for m in ['password', ':']):
            LOG.warning("%s@%s encountered prompt! of length "
                        " [%s] {%s}",
                        self.username, self.host, buflen, prompt)
            if su_auth:
                LOG.warning("Escalating using 'su -'.")
                stdin.write("%s\n" % self.root_password)
            else:
                stdin.write("%s\n" % self.password)
            stdin.flush()
            return True
        LOG.warning("Nearly a False-Positive on "
                    "password prompt detection. [%s] {%s}",
                    buflen, prompt)
        # Not a prompt: the bytes were consumed from the channel, so put them
        # back at the front of the stdout file's read buffer (nothing has
        # been read through 'stdout' yet) instead of sending them to the
        # remote process's stdin. This way stdout.read() still returns them
        # in order.
        stdout._rbuffer = prompt + stdout._rbuffer  # pylint: disable=W0212
        return False

    def _command_is_already_running(self, command):
        """Check to see if the command is already running using ps & grep.

        :returns: True unless the remote match count is exactly '0'.
        """
        # check plain 'command' w/o prefix or escalation; quote it for the
        # shell and match it literally (-F) rather than as a regex.
        check_cmd = 'ps -ef |grep -v grep|grep -cF %s' % shellquote(command)
        result = self.remote_execute(check_cmd, keepalive=True,
                                     allow_many=True)
        if self._as_text(result['stdout']).strip() != '0':
            return True
        LOG.debug("Remote command %s IS NOT already running. "
                  "Continuing with remote_execute.", command)
        return False

    def remote_execute(self, command, with_exit_code=False,  # noqa
                       get_pty=False, cwd=None, keepalive=True,
                       escalate=False, allow_many=True, **kw):
        """Execute an ssh command on a remote host.

        Connects first if needed (see connect() for the authentication
        order). If the output shows the command needs a tty, it is re-run
        once with get_pty=True.

        :param command: Shell command to be executed by this function.
        :param with_exit_code: Include the exit_code in the return body.
        :param get_pty: Request a pseudo-terminal from the server.
        :param cwd: The child's current directory will be changed
                    to `cwd` before it is executed. Note that this
                    directory is not considered when searching the
                    executable, so you can't specify the program's
                    path relative to this argument
        :param keepalive: If False, close the channel and the connection
                          once the command has finished.
        :param escalate: If True and not logged in as root, wrap the command
                         with self.escalation_command (sudo or su).
        :param allow_many: If False, do not run command if it is already
                           found running on remote client.
        :returns: a dict with stdout (stripped), stderr,
                  and (optionally) the exit code of the call.
        :raises errors.SatoriDuplicateCommandException: if allow_many is
            False and the command is already running.
        :raises errors.GetPTYRetryFailure: if a tty is still required after
            retrying with a pty.
        """
        if escalate and self.username != 'root':
            run_command = self.escalation_command % command
        else:
            run_command = command

        if cwd:
            prefix = "cd %s && " % cwd
            run_command = prefix + run_command

        # _command_is_already_running wont be called if allow_many is True
        if not allow_many and self._command_is_already_running(command):
            raise errors.SatoriDuplicateCommandException(
                "Remote command %s is already running and allow_many was "
                "set to False. Aborting remote_execute." % command)
        try:
            self.connect()
            chan = self.get_transport().open_session()
            # 'su -' (which also matches the "su -c" escalation) will prompt
            # for the root password; force a pty for it.
            su_auth = 'su -' in run_command
            if su_auth:
                get_pty = True
            if get_pty:
                chan.get_pty()
            stdin = chan.makefile('wb')
            stdout = chan.makefile('rb')
            stderr = chan.makefile_stderr('rb')

            LOG.debug("Executing '%s' on ssh://%s@%s:%s.",
                      run_command, self.username, self.host, self.port)
            chan.exec_command(run_command)
            LOG.debug('ssh://%s@%s:%d responded.', self.username, self.host,
                      self.port)

            # Give the remote side a moment to print a password prompt.
            time.sleep(.25)
            self._handle_password_prompt(stdin, stdout, su_auth=su_auth)

            # NOTE(review): stdout is drained completely before stderr is
            # read; a command producing a lot of stderr output may stall on
            # channel flow control. Confirm whether this matters in practice.
            results = {
                'stdout': stdout.read().strip(),
                'stderr': stderr.read()
            }

            LOG.debug("STDOUT from ssh://%s@%s:%d: %.5000s ...",
                      self.username, self.host, self.port,
                      results['stdout'])
            LOG.debug("STDERR from ssh://%s@%s:%d: %.5000s ...",
                      self.username, self.host, self.port,
                      results['stderr'])
            exit_code = chan.recv_exit_status()

            if with_exit_code:
                results.update({'exit_code': exit_code})
            if not keepalive:
                chan.close()

            if self._handle_tty_required(results, get_pty):
                return self.remote_execute(
                    command, with_exit_code=with_exit_code, get_pty=True,
                    cwd=cwd, keepalive=keepalive, escalate=escalate,
                    allow_many=allow_many, **kw)

            return results
        except Exception as exc:
            LOG.info("ssh://%s@%s:%d failed. | %s", self.username, self.host,
                     self.port, exc)
            raise
        finally:
            if not keepalive:
                self.close()
# Share SSH.__init__'s docstring with the module-level connect() and with
# SSH.get_client(); both just forward their arguments to the SSH constructor.
connect.__doc__ = SSH.__init__.__doc__
# Set it on the function underneath the classmethod: that is what
# SSH.get_client.__doc__ and help(SSH.get_client) read. Assigning to the
# classmethod object itself can succeed on some Python versions without
# changing either, and fails outright on others.
SSH.get_client.__func__.__doc__ = SSH.__init__.__doc__