diff --git a/tempest_lib/common/ssh.py b/tempest_lib/common/ssh.py index 031af4e..7d0a344 100644 --- a/tempest_lib/common/ssh.py +++ b/tempest_lib/common/ssh.py @@ -100,13 +100,15 @@ class Client(object): def _can_system_poll(): return hasattr(select, 'poll') - def exec_command(self, cmd): + def exec_command(self, cmd, encoding="utf-8"): """Execute the specified command on the server Note that this method is reading whole command outputs to memory, thus shouldn't be used for large outputs. :param str cmd: Command to run at remote server. + :param str encoding: Encoding for result from paramiko. + Result will not be decoded if None. :returns: data read from standard output of the command. :raises: SSHExecCommandFailed if command returns nonzero status. The exception contains command status stderr content. @@ -147,14 +149,17 @@ class Client(object): err_data_chunks += err_chunk, if channel.closed and not err_chunk and not out_chunk: break - out_data = ''.join(out_data_chunks) - err_data = ''.join(err_data_chunks) + out_data = b''.join(out_data_chunks) + err_data = b''.join(err_data_chunks) # Just read from the channels else: out_file = channel.makefile('rb', self.buf_size) err_file = channel.makefile_stderr('rb', self.buf_size) out_data = out_file.read() err_data = err_file.read() + if encoding: + out_data = out_data.decode(encoding) + err_data = err_data.decode(encoding) if 0 != exit_status: raise exceptions.SSHExecCommandFailed( diff --git a/tempest_lib/tests/test_ssh.py b/tempest_lib/tests/test_ssh.py index 51fa75b..140bdf0 100644 --- a/tempest_lib/tests/test_ssh.py +++ b/tempest_lib/tests/test_ssh.py @@ -17,6 +17,7 @@ import socket import time import mock +import six import testtools from tempest_lib.common import ssh @@ -113,37 +114,8 @@ class TestSshClient(base.TestCase): self.assertGreaterEqual((end_time - start_time), 2) @mock.patch('select.POLLIN', SELECT_POLLIN, create=True) - def test_exec_command(self): - gsc_mock = self.patch('tempest_lib.common.ssh.Client.' - '_get_ssh_connection') - ito_mock = self.patch('tempest_lib.common.ssh.Client._is_timed_out') - csp_mock = self.patch( - 'tempest_lib.common.ssh.Client._can_system_poll') - csp_mock.return_value = True - - select_mock = self.patch('select.poll', create=True) - client_mock = mock.MagicMock() - tran_mock = mock.MagicMock() - chan_mock = mock.MagicMock() - poll_mock = mock.MagicMock() - - def reset_mocks(): - gsc_mock.reset_mock() - ito_mock.reset_mock() - select_mock.reset_mock() - poll_mock.reset_mock() - client_mock.reset_mock() - tran_mock.reset_mock() - chan_mock.reset_mock() - - select_mock.return_value = poll_mock - gsc_mock.return_value = client_mock - ito_mock.return_value = True - client_mock.get_transport.return_value = tran_mock - tran_mock.open_session.return_value = chan_mock - poll_mock.poll.side_effect = [ - [0, 0, 0] - ] + def test_timeout_in_exec_command(self): + chan_mock, poll_mock, _ = self._set_mocks_for_select([0, 0, 0], True) # Test for a timeout condition immediately raised client = ssh.Client('localhost', 'root', timeout=2) @@ -158,24 +130,16 @@ class TestSshClient(base.TestCase): chan_mock, self.SELECT_POLLIN) poll_mock.poll.assert_called_once_with(10) - # Test for proper reading of STDOUT and STDERROR and closing - # of all file descriptors. - - reset_mocks() - - select_mock.return_value = poll_mock - gsc_mock.return_value = client_mock - ito_mock.return_value = False - client_mock.get_transport.return_value = tran_mock - tran_mock.open_session.return_value = chan_mock - poll_mock.poll.side_effect = [ - [1, 0, 0] - ] + @mock.patch('select.POLLIN', SELECT_POLLIN, create=True) + def test_exec_command(self): + chan_mock, poll_mock, select_mock = ( + self._set_mocks_for_select([[1, 0, 0]], True)) closed_prop = mock.PropertyMock(return_value=True) type(chan_mock).closed = closed_prop + chan_mock.recv_exit_status.return_value = 0 - chan_mock.recv.return_value = '' - chan_mock.recv_stderr.return_value = '' + chan_mock.recv.return_value = b'' + chan_mock.recv_stderr.return_value = b'' client = ssh.Client('localhost', 'root', timeout=2) client.exec_command("test") @@ -195,6 +159,66 @@ class TestSshClient(base.TestCase): chan_mock.recv_exit_status.assert_called_once_with() closed_prop.assert_called_once_with() + def _set_mocks_for_select(self, poll_data, ito_value=False): + gsc_mock = self.patch('tempest_lib.common.ssh.Client.' + '_get_ssh_connection') + ito_mock = self.patch('tempest_lib.common.ssh.Client._is_timed_out') + csp_mock = self.patch( + 'tempest_lib.common.ssh.Client._can_system_poll') + csp_mock.return_value = True + + select_mock = self.patch('select.poll', create=True) + client_mock = mock.MagicMock() + tran_mock = mock.MagicMock() + chan_mock = mock.MagicMock() + poll_mock = mock.MagicMock() + + select_mock.return_value = poll_mock + gsc_mock.return_value = client_mock + ito_mock.return_value = ito_value + client_mock.get_transport.return_value = tran_mock + tran_mock.open_session.return_value = chan_mock + if isinstance(poll_data[0], list): + poll_mock.poll.side_effect = poll_data + else: + poll_mock.poll.return_value = poll_data + + return chan_mock, poll_mock, select_mock + + _utf8_string = six.unichr(1071) + _utf8_bytes = _utf8_string.encode("utf-8") + + @mock.patch('select.POLLIN', SELECT_POLLIN, create=True) + def test_exec_good_command_output(self): + chan_mock, poll_mock, _ = self._set_mocks_for_select([1, 0, 0]) + closed_prop = mock.PropertyMock(return_value=True) + type(chan_mock).closed = closed_prop + + chan_mock.recv_exit_status.return_value = 0 + chan_mock.recv.side_effect = [self._utf8_bytes[0:1], + self._utf8_bytes[1:], b'R', b''] + chan_mock.recv_stderr.return_value = b'' + + client = ssh.Client('localhost', 'root', timeout=2) + out_data = client.exec_command("test") + self.assertEqual(self._utf8_string + 'R', out_data) + + @mock.patch('select.POLLIN', SELECT_POLLIN, create=True) + def test_exec_bad_command_output(self): + chan_mock, poll_mock, _ = self._set_mocks_for_select([1, 0, 0]) + closed_prop = mock.PropertyMock(return_value=True) + type(chan_mock).closed = closed_prop + + chan_mock.recv_exit_status.return_value = 1 + chan_mock.recv.return_value = b'' + chan_mock.recv_stderr.side_effect = [b'R', self._utf8_bytes[0:1], + self._utf8_bytes[1:], b''] + + client = ssh.Client('localhost', 'root', timeout=2) + exc = self.assertRaises(exceptions.SSHExecCommandFailed, + client.exec_command, "test") + self.assertIn('R' + self._utf8_string, six.text_type(exc)) + def test_exec_command_no_select(self): gsc_mock = self.patch('tempest_lib.common.ssh.Client.' '_get_ssh_connection')