diff --git a/tempest_lib/common/ssh.py b/tempest_lib/common/ssh.py index 1911f9e..031af4e 100644 --- a/tempest_lib/common/ssh.py +++ b/tempest_lib/common/ssh.py @@ -96,6 +96,10 @@ class Client(object): def _is_timed_out(self, start_time): return (time.time() - self.timeout) > start_time + @staticmethod + def _can_system_poll(): + return hasattr(select, 'poll') + def exec_command(self, cmd): """Execute the specified command on the server @@ -114,37 +118,49 @@ class Client(object): channel.fileno() # Register event pipe channel.exec_command(cmd) channel.shutdown_write() - out_data = [] - err_data = [] - poll = select.poll() - poll.register(channel, select.POLLIN) - start_time = time.time() - - while True: - ready = poll.poll(self.channel_timeout) - if not any(ready): - if not self._is_timed_out(start_time): - continue - raise exceptions.TimeoutException( - "Command: '{0}' executed on host '{1}'.".format( - cmd, self.host)) - if not ready[0]: # If there is nothing to read. - continue - out_chunk = err_chunk = None - if channel.recv_ready(): - out_chunk = channel.recv(self.buf_size) - out_data += out_chunk, - if channel.recv_stderr_ready(): - err_chunk = channel.recv_stderr(self.buf_size) - err_data += err_chunk, - if channel.closed and not err_chunk and not out_chunk: - break exit_status = channel.recv_exit_status() + + # If the executing host is linux-based, poll the channel + if self._can_system_poll(): + out_data_chunks = [] + err_data_chunks = [] + poll = select.poll() + poll.register(channel, select.POLLIN) + start_time = time.time() + + while True: + ready = poll.poll(self.channel_timeout) + if not any(ready): + if not self._is_timed_out(start_time): + continue + raise exceptions.TimeoutException( + "Command: '{0}' executed on host '{1}'.".format( + cmd, self.host)) + if not ready[0]: # If there is nothing to read. + continue + out_chunk = err_chunk = None + if channel.recv_ready(): + out_chunk = channel.recv(self.buf_size) + out_data_chunks += out_chunk, + if channel.recv_stderr_ready(): + err_chunk = channel.recv_stderr(self.buf_size) + err_data_chunks += err_chunk, + if channel.closed and not err_chunk and not out_chunk: + break + out_data = ''.join(out_data_chunks) + err_data = ''.join(err_data_chunks) + # Just read from the channels + else: + out_file = channel.makefile('rb', self.buf_size) + err_file = channel.makefile_stderr('rb', self.buf_size) + out_data = out_file.read() + err_data = err_file.read() + if 0 != exit_status: raise exceptions.SSHExecCommandFailed( command=cmd, exit_status=exit_status, stderr=err_data, stdout=out_data) - return ''.join(out_data) + return out_data def test_connection_auth(self): """Raises an exception when we can not connect to server via ssh.""" diff --git a/tempest_lib/tests/test_ssh.py b/tempest_lib/tests/test_ssh.py index ab0a198..51fa75b 100644 --- a/tempest_lib/tests/test_ssh.py +++ b/tempest_lib/tests/test_ssh.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. +from io import StringIO import socket import time @@ -25,6 +26,8 @@ from tempest_lib.tests import base class TestSshClient(base.TestCase): + SELECT_POLLIN = 1 + @mock.patch('paramiko.RSAKey.from_private_key') @mock.patch('six.StringIO') def test_pkey_calls_paramiko_RSAKey(self, cs_mock, rsa_mock): @@ -109,12 +112,16 @@ class TestSshClient(base.TestCase): self.assertLess((end_time - start_time), 5) self.assertGreaterEqual((end_time - start_time), 2) + @mock.patch('select.POLLIN', SELECT_POLLIN, create=True) def test_exec_command(self): gsc_mock = self.patch('tempest_lib.common.ssh.Client.' '_get_ssh_connection') ito_mock = self.patch('tempest_lib.common.ssh.Client._is_timed_out') - select_mock = self.patch('select.poll') + csp_mock = self.patch( + 'tempest_lib.common.ssh.Client._can_system_poll') + csp_mock.return_value = True + select_mock = self.patch('select.poll', create=True) client_mock = mock.MagicMock() tran_mock = mock.MagicMock() chan_mock = mock.MagicMock() @@ -147,8 +154,8 @@ class TestSshClient(base.TestCase): chan_mock.exec_command.assert_called_once_with("test") chan_mock.shutdown_write.assert_called_once_with() - SELECT_POLLIN = 1 - poll_mock.register.assert_called_once_with(chan_mock, SELECT_POLLIN) + poll_mock.register.assert_called_once_with( + chan_mock, self.SELECT_POLLIN) poll_mock.poll.assert_called_once_with(10) # Test for proper reading of STDOUT and STDERROR and closing @@ -177,8 +184,9 @@ class TestSshClient(base.TestCase): chan_mock.exec_command.assert_called_once_with("test") chan_mock.shutdown_write.assert_called_once_with() - SELECT_POLLIN = 1 - poll_mock.register.assert_called_once_with(chan_mock, SELECT_POLLIN) + select_mock.assert_called_once_with() + poll_mock.register.assert_called_once_with( + chan_mock, self.SELECT_POLLIN) poll_mock.poll.assert_called_once_with(10) chan_mock.recv_ready.assert_called_once_with() chan_mock.recv.assert_called_once_with(1024) @@ -186,3 +194,36 @@ class TestSshClient(base.TestCase): chan_mock.recv_stderr.assert_called_once_with(1024) chan_mock.recv_exit_status.assert_called_once_with() closed_prop.assert_called_once_with() + + def test_exec_command_no_select(self): + gsc_mock = self.patch('tempest_lib.common.ssh.Client.' + '_get_ssh_connection') + csp_mock = self.patch( + 'tempest_lib.common.ssh.Client._can_system_poll') + csp_mock.return_value = False + + select_mock = self.patch('select.poll', create=True) + client_mock = mock.MagicMock() + tran_mock = mock.MagicMock() + chan_mock = mock.MagicMock() + + # Test for proper reading of STDOUT and STDERROR + + gsc_mock.return_value = client_mock + client_mock.get_transport.return_value = tran_mock + tran_mock.open_session.return_value = chan_mock + chan_mock.recv_exit_status.return_value = 0 + + std_out_mock = mock.MagicMock(StringIO) + std_err_mock = mock.MagicMock(StringIO) + chan_mock.makefile.return_value = std_out_mock + chan_mock.makefile_stderr.return_value = std_err_mock + + client = ssh.Client('localhost', 'root', timeout=2) + client.exec_command("test") + + chan_mock.makefile.assert_called_once_with('rb', 1024) + chan_mock.makefile_stderr.assert_called_once_with('rb', 1024) + std_out_mock.read.assert_called_once_with() + std_err_mock.read.assert_called_once_with() + self.assertFalse(select_mock.called)