 083f56ddc9
			
		
	
	083f56ddc9
	
	
	
		
			
			paramiko returns chunks of binary type, so the code must join all chunks as bytes rather than as str. This patchset also adds an 'encoding' parameter to the exec function so that the binary result can be decoded to str. Closes-Bug: #1495181 Change-Id: I886c9d0c143d3553876ece7b6d1ac2df7e49d701
		
			
				
	
	
		
			174 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			174 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright 2012 OpenStack Foundation
 | |
| # All Rights Reserved.
 | |
| #
 | |
| #    Licensed under the Apache License, Version 2.0 (the "License"); you may
 | |
| #    not use this file except in compliance with the License. You may obtain
 | |
| #    a copy of the License at
 | |
| #
 | |
| #         http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| #    Unless required by applicable law or agreed to in writing, software
 | |
| #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | |
| #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | |
| #    License for the specific language governing permissions and limitations
 | |
| #    under the License.
 | |
| 
 | |
| 
 | |
| import select
 | |
| import socket
 | |
| import time
 | |
| import warnings
 | |
| 
 | |
| from oslo_log import log as logging
 | |
| import six
 | |
| 
 | |
| from tempest_lib import exceptions
 | |
| 
 | |
| 
 | |
# Import paramiko with every warning filter set to "ignore" so that any
# warnings emitted while the paramiko module is being imported are not shown.
# The previous warning filters are restored when the ``with`` block exits.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import paramiko


# Logger for this module, named after the module itself.
LOG = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
class Client(object):
    """SSH client that runs commands on a single remote host via paramiko.

    Every :meth:`exec_command` and :meth:`test_connection_auth` call opens
    its own connection through :meth:`_get_ssh_connection` and closes it
    again before returning.
    """

    def __init__(self, host, username, password=None, timeout=300, pkey=None,
                 channel_timeout=10, look_for_keys=False, key_filename=None):
        """Store connection settings; no connection is opened here.

        :param host: Remote host to connect to.
        :param username: User name to authenticate as.
        :param password: Password passed to paramiko, if any.
        :param timeout: Seconds allowed for establishing a connection
                        (including retries) and for waiting on command
                        output while polling.
        :param pkey: Text of an RSA private key, which is parsed with
                     ``paramiko.RSAKey``; any other value is passed to
                     paramiko unchanged.
        :param channel_timeout: Seconds used as paramiko's connect timeout
                                and as the interval of each poll while
                                waiting for command output.
        :param look_for_keys: Forwarded to ``paramiko.SSHClient.connect``.
        :param key_filename: Forwarded to ``paramiko.SSHClient.connect``.
        """
        self.host = host
        self.username = username
        self.password = password
        if isinstance(pkey, six.string_types):
            # Accept the key as text and turn it into a paramiko key object.
            pkey = paramiko.RSAKey.from_private_key(
                six.StringIO(str(pkey)))
        self.pkey = pkey
        self.look_for_keys = look_for_keys
        self.key_filename = key_filename
        self.timeout = int(timeout)
        self.channel_timeout = float(channel_timeout)
        self.buf_size = 1024

    def _get_ssh_connection(self, sleep=1.5, backoff=1):
        """Returns an ssh connection to the specified host.

        Connection attempts are retried until ``self.timeout`` seconds have
        passed since the first one. The first pause lasts ``sleep + backoff``
        seconds, and every further failure adds another ``backoff`` seconds.

        :param sleep: Base delay in seconds between connection attempts.
        :param backoff: Seconds added to the delay after each failure.
        :returns: A connected ``paramiko.SSHClient``.
        :raises: SSHTimeout if no connection could be established in time.
        """
        bsleep = sleep
        ssh = paramiko.SSHClient()
        # Unknown host keys are accepted automatically.
        ssh.set_missing_host_key_policy(
            paramiko.AutoAddPolicy())
        _start_time = time.time()
        if self.pkey is not None:
            LOG.info("Creating ssh connection to '%s' as '%s'"
                     " with public key authentication",
                     self.host, self.username)
        else:
            # Never write the password itself to the log.
            LOG.info("Creating ssh connection to '%s' as '%s'"
                     " with password authentication",
                     self.host, self.username)
        attempts = 0
        while True:
            try:
                ssh.connect(self.host, username=self.username,
                            password=self.password,
                            look_for_keys=self.look_for_keys,
                            key_filename=self.key_filename,
                            timeout=self.channel_timeout, pkey=self.pkey)
                LOG.info("ssh connection to %s@%s successfuly created",
                         self.username, self.host)
                return ssh
            except (socket.error,
                    paramiko.SSHException) as e:
                # Count the attempt that just failed so every log line
                # reports the real number of tries made so far.
                attempts += 1
                if self._is_timed_out(_start_time):
                    LOG.exception("Failed to establish authenticated ssh"
                                  " connection to %s@%s after %d attempts",
                                  self.username, self.host, attempts)
                    # Release whatever the failed attempts left behind.
                    ssh.close()
                    raise exceptions.SSHTimeout(host=self.host,
                                                user=self.username,
                                                password=self.password)
                bsleep += backoff
                LOG.warning("Failed to establish authenticated ssh"
                            " connection to %s@%s (%s). Number attempts: %s."
                            " Retry after %s seconds.",
                            self.username, self.host, e, attempts, bsleep)
                time.sleep(bsleep)

    def _is_timed_out(self, start_time):
        """Return True if more than ``self.timeout`` seconds have elapsed.

        :param start_time: Reference point as returned by ``time.time()``.
        """
        return (time.time() - self.timeout) > start_time

    @staticmethod
    def _can_system_poll():
        """Return True if the local platform provides ``select.poll``."""
        return hasattr(select, 'poll')

    def exec_command(self, cmd, encoding="utf-8"):
        """Execute the specified command on the server

        Note that this method is reading whole command outputs to memory, thus
        shouldn't be used for large outputs. A new SSH connection is opened
        for every call and closed again before this method returns or raises.

        :param str cmd: Command to run at remote server.
        :param str encoding: Encoding used to decode stdout and stderr.
                             If falsy (e.g. None) the raw bytes are kept.
        :returns: data read from standard output of the command.
        :raises: SSHExecCommandFailed if the command exits with a nonzero
                 status. The exception carries the command, its exit status,
                 stdout and stderr.
        :raises: TimeoutException if, while polling, no output arrives and
                 ``self.timeout`` seconds have passed since reading started.
        """
        ssh = self._get_ssh_connection()
        try:
            transport = ssh.get_transport()
            channel = transport.open_session()
            channel.fileno()  # Register event pipe
            channel.exec_command(cmd)
            channel.shutdown_write()

            # Use select.poll when the local platform provides it; otherwise
            # fall back to blocking reads on file wrappers of the channel.
            if self._can_system_poll():
                out_data, err_data = self._read_with_poll(channel, cmd)
            else:
                out_data, err_data = self._read_from_files(channel)

            # Ask for the exit status only after all output is drained:
            # waiting for it first can hang once the command has written
            # more than the channel window can buffer.
            exit_status = channel.recv_exit_status()
        finally:
            ssh.close()

        if encoding:
            out_data = out_data.decode(encoding)
            err_data = err_data.decode(encoding)

        if 0 != exit_status:
            raise exceptions.SSHExecCommandFailed(
                command=cmd, exit_status=exit_status,
                stderr=err_data, stdout=out_data)
        return out_data

    def _read_with_poll(self, channel, cmd):
        """Drain stdout and stderr of ``channel`` using ``select.poll``.

        :param channel: Channel on which ``cmd`` is running.
        :param cmd: The command, used only in the timeout message.
        :returns: Tuple ``(stdout_bytes, stderr_bytes)``.
        :raises: TimeoutException if a poll interval passes with nothing to
                 read after ``self.timeout`` seconds have elapsed since
                 reading started.
        """
        out_data_chunks = []
        err_data_chunks = []
        poll = select.poll()
        poll.register(channel, select.POLLIN)
        # select.poll expects its timeout in milliseconds, while
        # channel_timeout is expressed in seconds.
        poll_timeout_ms = int(self.channel_timeout * 1000)
        start_time = time.time()

        while True:
            ready = poll.poll(poll_timeout_ms)
            if not ready:
                if not self._is_timed_out(start_time):
                    continue
                raise exceptions.TimeoutException(
                    "Command: '{0}' executed on host '{1}'.".format(
                        cmd, self.host))
            out_chunk = err_chunk = None
            if channel.recv_ready():
                out_chunk = channel.recv(self.buf_size)
                out_data_chunks.append(out_chunk)
            if channel.recv_stderr_ready():
                err_chunk = channel.recv_stderr(self.buf_size)
                err_data_chunks.append(err_chunk)
            # Stop once the channel is closed and this round read nothing.
            if channel.closed and not err_chunk and not out_chunk:
                break
        return b''.join(out_data_chunks), b''.join(err_data_chunks)

    def _read_from_files(self, channel):
        """Read stdout, then stderr, of ``channel`` with blocking reads.

        :param channel: Channel on which the command is running.
        :returns: Tuple ``(stdout_bytes, stderr_bytes)``.
        """
        out_file = channel.makefile('rb', self.buf_size)
        err_file = channel.makefile_stderr('rb', self.buf_size)
        out_data = out_file.read()
        err_data = err_file.read()
        return out_data, err_data

    def test_connection_auth(self):
        """Raises an exception when we can not connect to server via ssh.

        The connection is closed again right away. On failure the
        ``SSHTimeout`` raised by :meth:`_get_ssh_connection` propagates.
        """
        connection = self._get_ssh_connection()
        connection.close()
 |