Merge "Provide alternative ssh exec_command for non-linux environments"
This commit is contained in:
commit
6d0b712e00
|
@ -96,6 +96,10 @@ class Client(object):
|
|||
def _is_timed_out(self, start_time):
|
||||
return (time.time() - self.timeout) > start_time
|
||||
|
||||
@staticmethod
|
||||
def _can_system_poll():
|
||||
return hasattr(select, 'poll')
|
||||
|
||||
def exec_command(self, cmd):
|
||||
"""Execute the specified command on the server
|
||||
|
||||
|
@ -114,37 +118,49 @@ class Client(object):
|
|||
channel.fileno() # Register event pipe
|
||||
channel.exec_command(cmd)
|
||||
channel.shutdown_write()
|
||||
out_data = []
|
||||
err_data = []
|
||||
poll = select.poll()
|
||||
poll.register(channel, select.POLLIN)
|
||||
start_time = time.time()
|
||||
|
||||
while True:
|
||||
ready = poll.poll(self.channel_timeout)
|
||||
if not any(ready):
|
||||
if not self._is_timed_out(start_time):
|
||||
continue
|
||||
raise exceptions.TimeoutException(
|
||||
"Command: '{0}' executed on host '{1}'.".format(
|
||||
cmd, self.host))
|
||||
if not ready[0]: # If there is nothing to read.
|
||||
continue
|
||||
out_chunk = err_chunk = None
|
||||
if channel.recv_ready():
|
||||
out_chunk = channel.recv(self.buf_size)
|
||||
out_data += out_chunk,
|
||||
if channel.recv_stderr_ready():
|
||||
err_chunk = channel.recv_stderr(self.buf_size)
|
||||
err_data += err_chunk,
|
||||
if channel.closed and not err_chunk and not out_chunk:
|
||||
break
|
||||
exit_status = channel.recv_exit_status()
|
||||
|
||||
# If the executing host is linux-based, poll the channel
|
||||
if self._can_system_poll():
|
||||
out_data_chunks = []
|
||||
err_data_chunks = []
|
||||
poll = select.poll()
|
||||
poll.register(channel, select.POLLIN)
|
||||
start_time = time.time()
|
||||
|
||||
while True:
|
||||
ready = poll.poll(self.channel_timeout)
|
||||
if not any(ready):
|
||||
if not self._is_timed_out(start_time):
|
||||
continue
|
||||
raise exceptions.TimeoutException(
|
||||
"Command: '{0}' executed on host '{1}'.".format(
|
||||
cmd, self.host))
|
||||
if not ready[0]: # If there is nothing to read.
|
||||
continue
|
||||
out_chunk = err_chunk = None
|
||||
if channel.recv_ready():
|
||||
out_chunk = channel.recv(self.buf_size)
|
||||
out_data_chunks += out_chunk,
|
||||
if channel.recv_stderr_ready():
|
||||
err_chunk = channel.recv_stderr(self.buf_size)
|
||||
err_data_chunks += err_chunk,
|
||||
if channel.closed and not err_chunk and not out_chunk:
|
||||
break
|
||||
out_data = ''.join(out_data_chunks)
|
||||
err_data = ''.join(err_data_chunks)
|
||||
# Just read from the channels
|
||||
else:
|
||||
out_file = channel.makefile('rb', self.buf_size)
|
||||
err_file = channel.makefile_stderr('rb', self.buf_size)
|
||||
out_data = out_file.read()
|
||||
err_data = err_file.read()
|
||||
|
||||
if 0 != exit_status:
|
||||
raise exceptions.SSHExecCommandFailed(
|
||||
command=cmd, exit_status=exit_status,
|
||||
stderr=err_data, stdout=out_data)
|
||||
return ''.join(out_data)
|
||||
return out_data
|
||||
|
||||
def test_connection_auth(self):
|
||||
"""Raises an exception when we can not connect to server via ssh."""
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from io import StringIO
|
||||
import socket
|
||||
import time
|
||||
|
||||
|
@ -25,6 +26,8 @@ from tempest_lib.tests import base
|
|||
|
||||
class TestSshClient(base.TestCase):
|
||||
|
||||
SELECT_POLLIN = 1
|
||||
|
||||
@mock.patch('paramiko.RSAKey.from_private_key')
|
||||
@mock.patch('six.StringIO')
|
||||
def test_pkey_calls_paramiko_RSAKey(self, cs_mock, rsa_mock):
|
||||
|
@ -109,12 +112,16 @@ class TestSshClient(base.TestCase):
|
|||
self.assertLess((end_time - start_time), 5)
|
||||
self.assertGreaterEqual((end_time - start_time), 2)
|
||||
|
||||
@mock.patch('select.POLLIN', SELECT_POLLIN, create=True)
|
||||
def test_exec_command(self):
|
||||
gsc_mock = self.patch('tempest_lib.common.ssh.Client.'
|
||||
'_get_ssh_connection')
|
||||
ito_mock = self.patch('tempest_lib.common.ssh.Client._is_timed_out')
|
||||
select_mock = self.patch('select.poll')
|
||||
csp_mock = self.patch(
|
||||
'tempest_lib.common.ssh.Client._can_system_poll')
|
||||
csp_mock.return_value = True
|
||||
|
||||
select_mock = self.patch('select.poll', create=True)
|
||||
client_mock = mock.MagicMock()
|
||||
tran_mock = mock.MagicMock()
|
||||
chan_mock = mock.MagicMock()
|
||||
|
@ -147,8 +154,8 @@ class TestSshClient(base.TestCase):
|
|||
chan_mock.exec_command.assert_called_once_with("test")
|
||||
chan_mock.shutdown_write.assert_called_once_with()
|
||||
|
||||
SELECT_POLLIN = 1
|
||||
poll_mock.register.assert_called_once_with(chan_mock, SELECT_POLLIN)
|
||||
poll_mock.register.assert_called_once_with(
|
||||
chan_mock, self.SELECT_POLLIN)
|
||||
poll_mock.poll.assert_called_once_with(10)
|
||||
|
||||
# Test for proper reading of STDOUT and STDERROR and closing
|
||||
|
@ -177,8 +184,9 @@ class TestSshClient(base.TestCase):
|
|||
chan_mock.exec_command.assert_called_once_with("test")
|
||||
chan_mock.shutdown_write.assert_called_once_with()
|
||||
|
||||
SELECT_POLLIN = 1
|
||||
poll_mock.register.assert_called_once_with(chan_mock, SELECT_POLLIN)
|
||||
select_mock.assert_called_once_with()
|
||||
poll_mock.register.assert_called_once_with(
|
||||
chan_mock, self.SELECT_POLLIN)
|
||||
poll_mock.poll.assert_called_once_with(10)
|
||||
chan_mock.recv_ready.assert_called_once_with()
|
||||
chan_mock.recv.assert_called_once_with(1024)
|
||||
|
@ -186,3 +194,36 @@ class TestSshClient(base.TestCase):
|
|||
chan_mock.recv_stderr.assert_called_once_with(1024)
|
||||
chan_mock.recv_exit_status.assert_called_once_with()
|
||||
closed_prop.assert_called_once_with()
|
||||
|
||||
def test_exec_command_no_select(self):
|
||||
gsc_mock = self.patch('tempest_lib.common.ssh.Client.'
|
||||
'_get_ssh_connection')
|
||||
csp_mock = self.patch(
|
||||
'tempest_lib.common.ssh.Client._can_system_poll')
|
||||
csp_mock.return_value = False
|
||||
|
||||
select_mock = self.patch('select.poll', create=True)
|
||||
client_mock = mock.MagicMock()
|
||||
tran_mock = mock.MagicMock()
|
||||
chan_mock = mock.MagicMock()
|
||||
|
||||
# Test for proper reading of STDOUT and STDERROR
|
||||
|
||||
gsc_mock.return_value = client_mock
|
||||
client_mock.get_transport.return_value = tran_mock
|
||||
tran_mock.open_session.return_value = chan_mock
|
||||
chan_mock.recv_exit_status.return_value = 0
|
||||
|
||||
std_out_mock = mock.MagicMock(StringIO)
|
||||
std_err_mock = mock.MagicMock(StringIO)
|
||||
chan_mock.makefile.return_value = std_out_mock
|
||||
chan_mock.makefile_stderr.return_value = std_err_mock
|
||||
|
||||
client = ssh.Client('localhost', 'root', timeout=2)
|
||||
client.exec_command("test")
|
||||
|
||||
chan_mock.makefile.assert_called_once_with('rb', 1024)
|
||||
chan_mock.makefile_stderr.assert_called_once_with('rb', 1024)
|
||||
std_out_mock.read.assert_called_once_with()
|
||||
std_err_mock.read.assert_called_once_with()
|
||||
self.assertFalse(select_mock.called)
|
||||
|
|
Loading…
Reference in New Issue