# Copyright 2012 OpenStack Foundation
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.


import select
import socket
import time
import warnings

from oslo_log import log as logging
import six

from tempest.lib import exceptions


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import paramiko


LOG = logging.getLogger(__name__)


class Client(object):
    """Thin paramiko-based SSH client with connection retries.

    Every call to :meth:`exec_command` or :meth:`test_connection_auth` opens
    a fresh connection (optionally tunnelled through ``proxy_client``) and
    closes it again before returning.
    """

    def __init__(self, host, username, password=None, timeout=300, pkey=None,
                 channel_timeout=10, look_for_keys=False, key_filename=None,
                 port=22, proxy_client=None):
        """SSH client.

        Many of the parameters are passed to the underlying implementation
        as they are. See the paramiko documentation for more details.
        http://docs.paramiko.org/en/2.1/api/client.html#paramiko.client.SSHClient.connect

        :param host: Host to login.
        :param username: SSH username.
        :param password: SSH password, or a password to unlock private key.
        :param timeout: Timeout in seconds, including retries.
            Default is 300 seconds.
        :param pkey: Private key. A string is parsed as an RSA private key;
            any other value is handed to paramiko unchanged.
        :param channel_timeout: Channel timeout in seconds, passed to the
            paramiko. Default is 10 seconds.
        :param look_for_keys: Whether or not to search for private keys
            in ``~/.ssh``. Default is False.
        :param key_filename: Filename for private key to use.
        :param port: SSH port number.
        :param proxy_client: Another SSH client to provide a transport
            for ssh-over-ssh. The default is None, which means
            not to use ssh-over-ssh.
        :type proxy_client: ``tempest.lib.common.ssh.Client`` object
        :raises: SSHClientProxyClientLoop if ``proxy_client`` targets the
            same host, port and username as this client.
        """
        self.host = host
        self.username = username
        self.port = port
        self.password = password
        if isinstance(pkey, six.string_types):
            pkey = paramiko.RSAKey.from_private_key(
                six.StringIO(str(pkey)))
        self.pkey = pkey
        self.look_for_keys = look_for_keys
        self.key_filename = key_filename
        self.timeout = int(timeout)
        self.channel_timeout = float(channel_timeout)
        # Chunk size used when reading command output from the channel.
        self.buf_size = 1024
        self.proxy_client = proxy_client
        # A proxy that points back at the very same endpoint would recurse
        # forever when opening the tunnel, so refuse it up front.
        if (self.proxy_client and self.proxy_client.host == self.host and
                self.proxy_client.port == self.port and
                self.proxy_client.username == self.username):
            raise exceptions.SSHClientProxyClientLoop(
                host=self.host, port=self.port, username=self.username)
        self._proxy_conn = None

    def _get_ssh_connection(self, sleep=1.5, backoff=1):
        """Returns an ssh connection to the specified host.

        Connection attempts are retried until ``self.timeout`` seconds have
        elapsed since the first attempt. The pause between attempts grows by
        ``backoff`` seconds before every sleep (so the first pause is
        ``sleep + backoff``).

        :param sleep: Base pause between attempts, in seconds.
        :param backoff: Seconds added to the pause before each retry.
        :returns: a connected ``paramiko.SSHClient``.
        :raises: SSHTimeout if no connection could be established in time.
        """
        bsleep = sleep
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(
            paramiko.AutoAddPolicy())
        _start_time = time.time()
        if self.pkey is not None:
            LOG.info("Creating ssh connection to '%s:%d' as '%s'"
                     " with public key authentication",
                     self.host, self.port, self.username)
        else:
            # NOTE(review): this writes the plaintext password to the log.
            # Left as-is because existing log consumers may rely on it, but
            # it should be reconsidered for anything but throwaway
            # credentials.
            LOG.info("Creating ssh connection to '%s:%d' as '%s'"
                     " with password %s",
                     self.host, self.port, self.username,
                     str(self.password))
        attempts = 0
        while True:
            if self.proxy_client is not None:
                proxy_chan = self._get_proxy_channel()
            else:
                proxy_chan = None
            try:
                ssh.connect(self.host, port=self.port, username=self.username,
                            password=self.password,
                            look_for_keys=self.look_for_keys,
                            key_filename=self.key_filename,
                            timeout=self.channel_timeout, pkey=self.pkey,
                            sock=proxy_chan)
                LOG.info("ssh connection to %s@%s successfully created",
                         self.username, self.host)
                return ssh
            except (EOFError,
                    socket.error, socket.timeout,
                    paramiko.SSHException) as e:
                ssh.close()
                # Count the attempt that just failed before reporting, so
                # the final error message is not off by one.
                attempts += 1
                if self._is_timed_out(_start_time):
                    LOG.exception("Failed to establish authenticated ssh"
                                  " connection to %s@%s after %d attempts. "
                                  "Proxy client: %s",
                                  self.username, self.host, attempts,
                                  self._get_proxy_client_info())
                    raise exceptions.SSHTimeout(host=self.host,
                                                user=self.username,
                                                password=self.password)
                bsleep += backoff
                LOG.warning("Failed to establish authenticated ssh"
                            " connection to %s@%s (%s). Number attempts: %s."
                            " Retry after %d seconds.",
                            self.username, self.host, e, attempts, bsleep)
                time.sleep(bsleep)

    def _is_timed_out(self, start_time):
        """Return True once more than ``self.timeout`` seconds have passed.

        :param start_time: Reference timestamp as returned by ``time.time()``.
        """
        return (time.time() - self.timeout) > start_time

    @staticmethod
    def _can_system_poll():
        """Return True if ``select.poll`` exists on this platform."""
        return hasattr(select, 'poll')

    def exec_command(self, cmd, encoding="utf-8"):
        """Execute the specified command on the server

        Note that this method is reading whole command outputs to memory, thus
        shouldn't be used for large outputs.

        :param str cmd: Command to run at remote server.
        :param str encoding: Encoding for result from paramiko.
                             Result will not be decoded if None.
        :returns: data read from standard output of the command.
        :raises: SSHExecCommandFailed if command returns nonzero
                 status. The exception contains command status stderr content.
        :raises: TimeoutException if cmd doesn't end when timeout expires.
        """
        ssh = self._get_ssh_connection()
        # Close the connection on every exit path; previously a timeout or a
        # decoding error left the connection open.
        try:
            transport = ssh.get_transport()
            with transport.open_session() as channel:
                channel.fileno()  # Register event pipe
                channel.exec_command(cmd)
                channel.shutdown_write()

                # If the executing host is linux-based, poll the channel
                if self._can_system_poll():
                    out_data_chunks = []
                    err_data_chunks = []
                    poll = select.poll()
                    poll.register(channel, select.POLLIN)
                    start_time = time.time()

                    while True:
                        # select.poll() expects milliseconds, whereas
                        # channel_timeout is expressed in seconds. Never wait
                        # past the overall deadline so the timeout is still
                        # detected promptly.
                        remaining = max(
                            0.0, start_time + self.timeout - time.time())
                        ready = poll.poll(
                            min(self.channel_timeout, remaining) * 1000)
                        if not any(ready):
                            if not self._is_timed_out(start_time):
                                continue
                            raise exceptions.TimeoutException(
                                "Command: '{0}' executed on host '{1}'."
                                .format(cmd, self.host))
                        if not ready[0]:  # If there is nothing to read.
                            continue
                        out_chunk = err_chunk = None
                        if channel.recv_ready():
                            out_chunk = channel.recv(self.buf_size)
                            out_data_chunks.append(out_chunk)
                        if channel.recv_stderr_ready():
                            err_chunk = channel.recv_stderr(self.buf_size)
                            err_data_chunks.append(err_chunk)
                        # The channel was signalled but neither stream
                        # produced data: the remote side is done.
                        if not err_chunk and not out_chunk:
                            break
                    out_data = b''.join(out_data_chunks)
                    err_data = b''.join(err_data_chunks)
                # Just read from the channels
                else:
                    out_file = channel.makefile('rb', self.buf_size)
                    err_file = channel.makefile_stderr('rb', self.buf_size)
                    out_data = out_file.read()
                    err_data = err_file.read()
                if encoding:
                    out_data = out_data.decode(encoding)
                    err_data = err_data.decode(encoding)

                exit_status = channel.recv_exit_status()
        finally:
            ssh.close()

        if 0 != exit_status:
            raise exceptions.SSHExecCommandFailed(
                command=cmd, exit_status=exit_status,
                stderr=err_data, stdout=out_data)
        return out_data

    def test_connection_auth(self):
        """Raises an exception when we can not connect to server via ssh."""
        connection = self._get_ssh_connection()
        connection.close()

    def _get_proxy_channel(self):
        """Open a channel through the proxy client that tunnels to our host.

        Connects via ``self.proxy_client`` and runs ``nc <host> <port>``
        there; the resulting channel is what ``_get_ssh_connection`` passes
        to paramiko as ``sock``.

        :returns: the channel on which ``nc`` was started.
        """
        conn = self.proxy_client._get_ssh_connection()
        # Keep a reference to avoid g/c
        # https://github.com/paramiko/paramiko/issues/440
        self._proxy_conn = conn
        transport = conn.get_transport()
        chan = transport.open_session()
        cmd = 'nc %s %s' % (self.host, self.port)
        chan.exec_command(cmd)
        return chan

    def _get_proxy_client_info(self):
        """Describe the chain of proxy clients for log messages.

        :returns: ``'no proxy client'`` when no proxy is configured,
            otherwise ``user@host:port`` of the proxy followed by the same
            description for the proxy's own proxy, recursively.
        """
        if not self.proxy_client:
            return 'no proxy client'
        nested_pclient = self.proxy_client._get_proxy_client_info()
        return ('%(username)s@%(host)s:%(port)s, nested proxy client: '
                '%(nested_pclient)s' % {'username': self.proxy_client.username,
                                        'host': self.proxy_client.host,
                                        'port': self.proxy_client.port,
                                        'nested_pclient': nested_pclient})