Switch sshutils.SSH from subprocess to paramiko

Methods for executing commands/scripts and upload
files to remote host via ssh of sshutils.SSH class
were rewritten using the paramiko lib.

Related: blueprint serverdto-ssh

Change-Id: I1786d5c4d4b6e9c661ec84785dd9255ab39de4aa
This commit is contained in:
Maksym Iarmak 2013-11-12 16:55:26 +02:00
parent b510482e83
commit b4e7cec59b
2 changed files with 142 additions and 73 deletions

View File

@ -14,7 +14,13 @@
# under the License.
import eventlet
import subprocess
import os
import paramiko
import random
import select
import socket
import string
import time
from rally import exceptions
from rally.openstack.common.gettextutils import _ # noqa
@ -26,43 +32,89 @@ LOG = logging.getLogger(__name__)
class SSH(object):
"""SSH common functions."""
OPTIONS = ['-o', 'StrictHostKeyChecking=no']
def __init__(self, ip, user, port=22, key=None, timeout=1800):
"""Initialize SSH client with ip, username and the default values.
def __init__(self, ip, user, port=22):
timeout - the timeout for execution of the command
"""
self.ip = ip
self.user = user
self.timeout = timeout
self.client = None
if key:
self.key = key
else:
self.key = os.path.expanduser('~/.ssh/id_rsa')
def _get_ssh_connection(self):
self.client = paramiko.SSHClient()
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
self.client.connect(self.ip, username=self.user, key_filename=self.key)
def _is_timed_out(self, start_time):
return (time.time() - self.timeout) > start_time
def execute(self, *cmd):
pipe = subprocess.Popen(['ssh'] + self.OPTIONS +
['%s@%s' % (self.user, self.ip)] + list(cmd),
stderr=subprocess.PIPE)
(out, err) = pipe.communicate()
if pipe.returncode:
raise exceptions.SSHError(err)
"""Execute the specified command on the server."""
self._get_ssh_connection()
cmd = ' '.join(cmd)
transport = self.client.get_transport()
channel = transport.open_session()
channel.fileno()
channel.exec_command(cmd)
channel.shutdown_write()
poll = select.poll()
poll.register(channel, select.POLLIN)
start_time = time.time()
while True:
ready = poll.poll(16)
if not any(ready):
if not self._is_timed_out(start_time):
continue
raise exceptions.TimeoutException('SSH Timeout')
if not ready[0]:
continue
out_chunk = err_chunk = None
if channel.recv_ready():
out_chunk = channel.recv(4096)
LOG.debug(out_chunk)
if channel.recv_stderr_ready():
err_chunk = channel.recv_stderr(4096)
LOG.debug(err_chunk)
if channel.closed and not err_chunk and not out_chunk:
break
exit_status = channel.recv_exit_status()
if 0 != exit_status:
raise exceptions.SSHError(
'SSHExecCommandFailed with exit_status %s'
% exit_status)
self.client.close()
def upload(self, source, destination):
"""Upload the specified file to the server."""
if destination.startswith('~'):
destination = '/home/' + self.user + destination[1:]
self._get_ssh_connection()
ftp = self.client.open_sftp()
ftp.put(os.path.expanduser(source), destination)
ftp.close()
def execute_script(self, script, enterpreter='/bin/sh'):
cmd = ['ssh'] + self.OPTIONS + ['%s@%s' % (self.user, self.ip),
enterpreter]
pipe = subprocess.Popen(cmd, stdin=open(script, 'r'),
stderr=subprocess.PIPE)
(out, err) = pipe.communicate()
if pipe.returncode:
raise exceptions.SSHError(err)
"""Execute the specified local script on the remote server."""
destination = '/tmp/' + ''.join(
random.choice(string.lowercase) for i in range(16))
def wait(self, timeout=15, interval=1):
self.upload(script, destination)
self.execute('%s %s' % (enterpreter, destination))
self.execute('rm %s' % destination)
def wait(self, timeout=120, interval=1):
"""Wait for the host will be available via ssh."""
with eventlet.timeout.Timeout(timeout, exceptions.TimeoutException):
while True:
try:
return self.execute('uname')
except exceptions.SSHError as e:
LOG.debug(_('Ssh is still unavailable. '
'Exception is: ') + repr(e))
except (socket.error, exceptions.SSHError) as e:
LOG.debug(
_('Ssh is still unavailable. (Exception was: %r)') % e)
eventlet.sleep(interval)
def upload(self, source, destination):
cmd = ['scp'] + self.OPTIONS + [
source, '%s@%s:%s' % (self.user, self.ip, destination)]
pipe = subprocess.Popen(cmd, stderr=subprocess.PIPE)
(out, err) = pipe.communicate()
if pipe.returncode:
raise exceptions.SSHError(err)

View File

@ -13,8 +13,8 @@
# License for the specific language governing permissions and limitations
# under the License.
import mock
import os
from rally import exceptions
from rally import sshutils
@ -26,52 +26,69 @@ class SSHTestCase(test.TestCase):
def setUp(self):
super(SSHTestCase, self).setUp()
self.ssh = sshutils.SSH('example.net', 'root')
self.pipe = mock.MagicMock()
self.pipe.communicate = mock.MagicMock(return_value=(mock.MagicMock(),
mock.MagicMock()))
self.pipe.returncode = 0
self.sp = mock.MagicMock()
self.sp.PIPE = self.pipe
self.sp.Popen = mock.MagicMock(return_value=self.pipe)
self.mod = 'rally.sshutils'
self.channel = mock.Mock()
self.channel.fileno.return_value = 15
self.channel.recv.return_value = 0
self.channel.recv_stderr.return_value = 0
self.channel.recv_exit_status.return_value = 0
self.transport = mock.Mock()
self.transport.open_session = mock.MagicMock(return_value=self.channel)
self.poll = mock.Mock()
self.poll.poll.return_value = [(self.channel, 1)]
self.policy = mock.Mock()
self.client = mock.Mock()
self.client.get_transport = mock.MagicMock(return_value=self.transport)
def test_execute(self):
with mock.patch(self.mod + '.subprocess', new=self.sp) as sp:
self.ssh.execute('ps', 'ax')
expected = [
mock.call.Popen(['ssh', '-o', 'StrictHostKeyChecking=no',
'root@example.net', 'ps', 'ax'],
stderr=self.pipe),
mock.call.PIPE.communicate()]
self.assertEqual(sp.mock_calls, expected)
@mock.patch('rally.sshutils.paramiko')
@mock.patch('rally.sshutils.select')
def test_execute(self, st, pk):
pk.SSHClient.return_value = self.client
st.poll.return_value = self.poll
self.ssh.execute('uname')
def test_execute_script(self):
with mock.patch(self.mod + '.subprocess', new=self.sp) as sp:
with mock.patch(self.mod + '.open', create=True) as op:
self.ssh.execute_script('/bin/script')
expected = [
mock.call.Popen(['ssh', '-o', 'StrictHostKeyChecking=no',
'root@example.net', '/bin/sh'],
stdin=op(), stderr=self.pipe),
mock.call.PIPE.communicate()]
self.assertEqual(sp.mock_calls, expected)
expected = [mock.call.fileno(),
mock.call.exec_command('uname'),
mock.call.shutdown_write(),
mock.call.recv_ready(),
mock.call.recv(4096),
mock.call.recv_stderr_ready(),
mock.call.recv_stderr(4096),
mock.call.recv_exit_status()]
def test_upload_file(self):
with mock.patch(self.mod + '.subprocess', new=self.sp) as sp:
self.ssh.upload('/tmp/s', '/tmp/d')
expected = [mock.call.Popen(['scp', '-o',
'StrictHostKeyChecking=no',
'/tmp/s', 'root@example.net:/tmp/d'],
stderr=self.pipe),
mock.call.PIPE.communicate()]
self.assertEqual(sp.mock_calls, expected)
self.assertEqual(self.channel.mock_calls, expected)
def test_wait(self):
with mock.patch(self.mod + '.SSH.execute'):
self.ssh.wait()
@mock.patch('rally.sshutils.paramiko')
def test_upload_file(self, pk):
pk.AutoAddPolicy.return_value = self.policy
self.ssh.upload('/tmp/s', '/tmp/d')
def test_wait_timeout(self):
with mock.patch(self.mod + '.SSH.execute', new=mock.Mock(
side_effect=exceptions.SSHError)):
self.assertRaises(exceptions.TimeoutException,
self.ssh.wait, 1, 1)
expected = [mock.call.set_missing_host_key_policy(self.policy),
mock.call.connect('example.net', username='root',
key_filename=os.path.expanduser(
'~/.ssh/id_rsa')),
mock.call.open_sftp(),
mock.call.open_sftp().put('/tmp/s', '/tmp/d'),
mock.call.open_sftp().close()]
self.assertEqual(pk.SSHClient().mock_calls, expected)
@mock.patch('rally.sshutils.SSH.execute')
@mock.patch('rally.sshutils.SSH.upload')
@mock.patch('rally.sshutils.random.choice')
def test_execute_script_new(self, rc, up, ex):
rc.return_value = 'a'
self.ssh.execute_script('/bin/script')
up.assert_called_once_with('/bin/script', '/tmp/aaaaaaaaaaaaaaaa')
ex.assert_has_calls([mock.call('/bin/sh /tmp/aaaaaaaaaaaaaaaa'),
mock.call('rm /tmp/aaaaaaaaaaaaaaaa')])
@mock.patch('rally.sshutils.SSH.execute')
def test_wait(self, ex):
self.ssh.wait()
@mock.patch('rally.sshutils.SSH.execute')
def test_wait_timeout(self, ex):
ex.side_effect = exceptions.SSHError
self.assertRaises(exceptions.TimeoutException,
self.ssh.wait, 1, 1)