Switch sshutils.SSH from subprocess to paramiko
Methods for executing commands/scripts and upload files to remote host via ssh of sshutils.SSH class were rewritten using the paramiko lib. Related: blueprint serverdto-ssh Change-Id: I1786d5c4d4b6e9c661ec84785dd9255ab39de4aa
This commit is contained in:
parent
b510482e83
commit
b4e7cec59b
@ -14,7 +14,13 @@
|
||||
# under the License.
|
||||
|
||||
import eventlet
|
||||
import subprocess
|
||||
import os
|
||||
import paramiko
|
||||
import random
|
||||
import select
|
||||
import socket
|
||||
import string
|
||||
import time
|
||||
|
||||
from rally import exceptions
|
||||
from rally.openstack.common.gettextutils import _ # noqa
|
||||
@ -26,43 +32,89 @@ LOG = logging.getLogger(__name__)
|
||||
class SSH(object):
|
||||
"""SSH common functions."""
|
||||
|
||||
OPTIONS = ['-o', 'StrictHostKeyChecking=no']
|
||||
def __init__(self, ip, user, port=22, key=None, timeout=1800):
|
||||
"""Initialize SSH client with ip, username and the default values.
|
||||
|
||||
def __init__(self, ip, user, port=22):
|
||||
timeout - the timeout for execution of the command
|
||||
"""
|
||||
self.ip = ip
|
||||
self.user = user
|
||||
self.timeout = timeout
|
||||
self.client = None
|
||||
if key:
|
||||
self.key = key
|
||||
else:
|
||||
self.key = os.path.expanduser('~/.ssh/id_rsa')
|
||||
|
||||
def _get_ssh_connection(self):
|
||||
self.client = paramiko.SSHClient()
|
||||
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
self.client.connect(self.ip, username=self.user, key_filename=self.key)
|
||||
|
||||
def _is_timed_out(self, start_time):
|
||||
return (time.time() - self.timeout) > start_time
|
||||
|
||||
def execute(self, *cmd):
|
||||
pipe = subprocess.Popen(['ssh'] + self.OPTIONS +
|
||||
['%s@%s' % (self.user, self.ip)] + list(cmd),
|
||||
stderr=subprocess.PIPE)
|
||||
(out, err) = pipe.communicate()
|
||||
if pipe.returncode:
|
||||
raise exceptions.SSHError(err)
|
||||
"""Execute the specified command on the server."""
|
||||
self._get_ssh_connection()
|
||||
cmd = ' '.join(cmd)
|
||||
transport = self.client.get_transport()
|
||||
channel = transport.open_session()
|
||||
channel.fileno()
|
||||
channel.exec_command(cmd)
|
||||
channel.shutdown_write()
|
||||
poll = select.poll()
|
||||
poll.register(channel, select.POLLIN)
|
||||
start_time = time.time()
|
||||
while True:
|
||||
ready = poll.poll(16)
|
||||
if not any(ready):
|
||||
if not self._is_timed_out(start_time):
|
||||
continue
|
||||
raise exceptions.TimeoutException('SSH Timeout')
|
||||
if not ready[0]:
|
||||
continue
|
||||
out_chunk = err_chunk = None
|
||||
if channel.recv_ready():
|
||||
out_chunk = channel.recv(4096)
|
||||
LOG.debug(out_chunk)
|
||||
if channel.recv_stderr_ready():
|
||||
err_chunk = channel.recv_stderr(4096)
|
||||
LOG.debug(err_chunk)
|
||||
if channel.closed and not err_chunk and not out_chunk:
|
||||
break
|
||||
exit_status = channel.recv_exit_status()
|
||||
if 0 != exit_status:
|
||||
raise exceptions.SSHError(
|
||||
'SSHExecCommandFailed with exit_status %s'
|
||||
% exit_status)
|
||||
self.client.close()
|
||||
|
||||
def upload(self, source, destination):
|
||||
"""Upload the specified file to the server."""
|
||||
if destination.startswith('~'):
|
||||
destination = '/home/' + self.user + destination[1:]
|
||||
self._get_ssh_connection()
|
||||
ftp = self.client.open_sftp()
|
||||
ftp.put(os.path.expanduser(source), destination)
|
||||
ftp.close()
|
||||
|
||||
def execute_script(self, script, enterpreter='/bin/sh'):
|
||||
cmd = ['ssh'] + self.OPTIONS + ['%s@%s' % (self.user, self.ip),
|
||||
enterpreter]
|
||||
pipe = subprocess.Popen(cmd, stdin=open(script, 'r'),
|
||||
stderr=subprocess.PIPE)
|
||||
(out, err) = pipe.communicate()
|
||||
if pipe.returncode:
|
||||
raise exceptions.SSHError(err)
|
||||
"""Execute the specified local script on the remote server."""
|
||||
destination = '/tmp/' + ''.join(
|
||||
random.choice(string.lowercase) for i in range(16))
|
||||
|
||||
def wait(self, timeout=15, interval=1):
|
||||
self.upload(script, destination)
|
||||
self.execute('%s %s' % (enterpreter, destination))
|
||||
self.execute('rm %s' % destination)
|
||||
|
||||
def wait(self, timeout=120, interval=1):
|
||||
"""Wait for the host will be available via ssh."""
|
||||
with eventlet.timeout.Timeout(timeout, exceptions.TimeoutException):
|
||||
while True:
|
||||
try:
|
||||
return self.execute('uname')
|
||||
except exceptions.SSHError as e:
|
||||
LOG.debug(_('Ssh is still unavailable. '
|
||||
'Exception is: ') + repr(e))
|
||||
except (socket.error, exceptions.SSHError) as e:
|
||||
LOG.debug(
|
||||
_('Ssh is still unavailable. (Exception was: %r)') % e)
|
||||
eventlet.sleep(interval)
|
||||
|
||||
def upload(self, source, destination):
|
||||
cmd = ['scp'] + self.OPTIONS + [
|
||||
source, '%s@%s:%s' % (self.user, self.ip, destination)]
|
||||
pipe = subprocess.Popen(cmd, stderr=subprocess.PIPE)
|
||||
(out, err) = pipe.communicate()
|
||||
if pipe.returncode:
|
||||
raise exceptions.SSHError(err)
|
||||
|
@ -13,8 +13,8 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
import mock
|
||||
import os
|
||||
|
||||
from rally import exceptions
|
||||
from rally import sshutils
|
||||
@ -26,52 +26,69 @@ class SSHTestCase(test.TestCase):
|
||||
def setUp(self):
|
||||
super(SSHTestCase, self).setUp()
|
||||
self.ssh = sshutils.SSH('example.net', 'root')
|
||||
self.pipe = mock.MagicMock()
|
||||
self.pipe.communicate = mock.MagicMock(return_value=(mock.MagicMock(),
|
||||
mock.MagicMock()))
|
||||
self.pipe.returncode = 0
|
||||
self.sp = mock.MagicMock()
|
||||
self.sp.PIPE = self.pipe
|
||||
self.sp.Popen = mock.MagicMock(return_value=self.pipe)
|
||||
self.mod = 'rally.sshutils'
|
||||
self.channel = mock.Mock()
|
||||
self.channel.fileno.return_value = 15
|
||||
self.channel.recv.return_value = 0
|
||||
self.channel.recv_stderr.return_value = 0
|
||||
self.channel.recv_exit_status.return_value = 0
|
||||
self.transport = mock.Mock()
|
||||
self.transport.open_session = mock.MagicMock(return_value=self.channel)
|
||||
self.poll = mock.Mock()
|
||||
self.poll.poll.return_value = [(self.channel, 1)]
|
||||
self.policy = mock.Mock()
|
||||
self.client = mock.Mock()
|
||||
self.client.get_transport = mock.MagicMock(return_value=self.transport)
|
||||
|
||||
def test_execute(self):
|
||||
with mock.patch(self.mod + '.subprocess', new=self.sp) as sp:
|
||||
self.ssh.execute('ps', 'ax')
|
||||
expected = [
|
||||
mock.call.Popen(['ssh', '-o', 'StrictHostKeyChecking=no',
|
||||
'root@example.net', 'ps', 'ax'],
|
||||
stderr=self.pipe),
|
||||
mock.call.PIPE.communicate()]
|
||||
self.assertEqual(sp.mock_calls, expected)
|
||||
@mock.patch('rally.sshutils.paramiko')
|
||||
@mock.patch('rally.sshutils.select')
|
||||
def test_execute(self, st, pk):
|
||||
pk.SSHClient.return_value = self.client
|
||||
st.poll.return_value = self.poll
|
||||
self.ssh.execute('uname')
|
||||
|
||||
def test_execute_script(self):
|
||||
with mock.patch(self.mod + '.subprocess', new=self.sp) as sp:
|
||||
with mock.patch(self.mod + '.open', create=True) as op:
|
||||
self.ssh.execute_script('/bin/script')
|
||||
expected = [
|
||||
mock.call.Popen(['ssh', '-o', 'StrictHostKeyChecking=no',
|
||||
'root@example.net', '/bin/sh'],
|
||||
stdin=op(), stderr=self.pipe),
|
||||
mock.call.PIPE.communicate()]
|
||||
self.assertEqual(sp.mock_calls, expected)
|
||||
expected = [mock.call.fileno(),
|
||||
mock.call.exec_command('uname'),
|
||||
mock.call.shutdown_write(),
|
||||
mock.call.recv_ready(),
|
||||
mock.call.recv(4096),
|
||||
mock.call.recv_stderr_ready(),
|
||||
mock.call.recv_stderr(4096),
|
||||
mock.call.recv_exit_status()]
|
||||
|
||||
def test_upload_file(self):
|
||||
with mock.patch(self.mod + '.subprocess', new=self.sp) as sp:
|
||||
self.ssh.upload('/tmp/s', '/tmp/d')
|
||||
expected = [mock.call.Popen(['scp', '-o',
|
||||
'StrictHostKeyChecking=no',
|
||||
'/tmp/s', 'root@example.net:/tmp/d'],
|
||||
stderr=self.pipe),
|
||||
mock.call.PIPE.communicate()]
|
||||
self.assertEqual(sp.mock_calls, expected)
|
||||
self.assertEqual(self.channel.mock_calls, expected)
|
||||
|
||||
def test_wait(self):
|
||||
with mock.patch(self.mod + '.SSH.execute'):
|
||||
self.ssh.wait()
|
||||
@mock.patch('rally.sshutils.paramiko')
|
||||
def test_upload_file(self, pk):
|
||||
pk.AutoAddPolicy.return_value = self.policy
|
||||
self.ssh.upload('/tmp/s', '/tmp/d')
|
||||
|
||||
def test_wait_timeout(self):
|
||||
with mock.patch(self.mod + '.SSH.execute', new=mock.Mock(
|
||||
side_effect=exceptions.SSHError)):
|
||||
self.assertRaises(exceptions.TimeoutException,
|
||||
self.ssh.wait, 1, 1)
|
||||
expected = [mock.call.set_missing_host_key_policy(self.policy),
|
||||
mock.call.connect('example.net', username='root',
|
||||
key_filename=os.path.expanduser(
|
||||
'~/.ssh/id_rsa')),
|
||||
mock.call.open_sftp(),
|
||||
mock.call.open_sftp().put('/tmp/s', '/tmp/d'),
|
||||
mock.call.open_sftp().close()]
|
||||
|
||||
self.assertEqual(pk.SSHClient().mock_calls, expected)
|
||||
|
||||
@mock.patch('rally.sshutils.SSH.execute')
|
||||
@mock.patch('rally.sshutils.SSH.upload')
|
||||
@mock.patch('rally.sshutils.random.choice')
|
||||
def test_execute_script_new(self, rc, up, ex):
|
||||
rc.return_value = 'a'
|
||||
self.ssh.execute_script('/bin/script')
|
||||
|
||||
up.assert_called_once_with('/bin/script', '/tmp/aaaaaaaaaaaaaaaa')
|
||||
ex.assert_has_calls([mock.call('/bin/sh /tmp/aaaaaaaaaaaaaaaa'),
|
||||
mock.call('rm /tmp/aaaaaaaaaaaaaaaa')])
|
||||
|
||||
@mock.patch('rally.sshutils.SSH.execute')
|
||||
def test_wait(self, ex):
|
||||
self.ssh.wait()
|
||||
|
||||
@mock.patch('rally.sshutils.SSH.execute')
|
||||
def test_wait_timeout(self, ex):
|
||||
ex.side_effect = exceptions.SSHError
|
||||
self.assertRaises(exceptions.TimeoutException,
|
||||
self.ssh.wait, 1, 1)
|
||||
|
Loading…
Reference in New Issue
Block a user