Merge "Provide alternative ssh exec_command for non-linux environments"
This commit is contained in:
@@ -96,6 +96,10 @@ class Client(object):
|
|||||||
def _is_timed_out(self, start_time):
|
def _is_timed_out(self, start_time):
|
||||||
return (time.time() - self.timeout) > start_time
|
return (time.time() - self.timeout) > start_time
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _can_system_poll():
|
||||||
|
return hasattr(select, 'poll')
|
||||||
|
|
||||||
def exec_command(self, cmd):
|
def exec_command(self, cmd):
|
||||||
"""Execute the specified command on the server
|
"""Execute the specified command on the server
|
||||||
|
|
||||||
@@ -114,37 +118,49 @@ class Client(object):
|
|||||||
channel.fileno() # Register event pipe
|
channel.fileno() # Register event pipe
|
||||||
channel.exec_command(cmd)
|
channel.exec_command(cmd)
|
||||||
channel.shutdown_write()
|
channel.shutdown_write()
|
||||||
out_data = []
|
|
||||||
err_data = []
|
|
||||||
poll = select.poll()
|
|
||||||
poll.register(channel, select.POLLIN)
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
while True:
|
|
||||||
ready = poll.poll(self.channel_timeout)
|
|
||||||
if not any(ready):
|
|
||||||
if not self._is_timed_out(start_time):
|
|
||||||
continue
|
|
||||||
raise exceptions.TimeoutException(
|
|
||||||
"Command: '{0}' executed on host '{1}'.".format(
|
|
||||||
cmd, self.host))
|
|
||||||
if not ready[0]: # If there is nothing to read.
|
|
||||||
continue
|
|
||||||
out_chunk = err_chunk = None
|
|
||||||
if channel.recv_ready():
|
|
||||||
out_chunk = channel.recv(self.buf_size)
|
|
||||||
out_data += out_chunk,
|
|
||||||
if channel.recv_stderr_ready():
|
|
||||||
err_chunk = channel.recv_stderr(self.buf_size)
|
|
||||||
err_data += err_chunk,
|
|
||||||
if channel.closed and not err_chunk and not out_chunk:
|
|
||||||
break
|
|
||||||
exit_status = channel.recv_exit_status()
|
exit_status = channel.recv_exit_status()
|
||||||
|
|
||||||
|
# If the executing host is linux-based, poll the channel
|
||||||
|
if self._can_system_poll():
|
||||||
|
out_data_chunks = []
|
||||||
|
err_data_chunks = []
|
||||||
|
poll = select.poll()
|
||||||
|
poll.register(channel, select.POLLIN)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
ready = poll.poll(self.channel_timeout)
|
||||||
|
if not any(ready):
|
||||||
|
if not self._is_timed_out(start_time):
|
||||||
|
continue
|
||||||
|
raise exceptions.TimeoutException(
|
||||||
|
"Command: '{0}' executed on host '{1}'.".format(
|
||||||
|
cmd, self.host))
|
||||||
|
if not ready[0]: # If there is nothing to read.
|
||||||
|
continue
|
||||||
|
out_chunk = err_chunk = None
|
||||||
|
if channel.recv_ready():
|
||||||
|
out_chunk = channel.recv(self.buf_size)
|
||||||
|
out_data_chunks += out_chunk,
|
||||||
|
if channel.recv_stderr_ready():
|
||||||
|
err_chunk = channel.recv_stderr(self.buf_size)
|
||||||
|
err_data_chunks += err_chunk,
|
||||||
|
if channel.closed and not err_chunk and not out_chunk:
|
||||||
|
break
|
||||||
|
out_data = ''.join(out_data_chunks)
|
||||||
|
err_data = ''.join(err_data_chunks)
|
||||||
|
# Just read from the channels
|
||||||
|
else:
|
||||||
|
out_file = channel.makefile('rb', self.buf_size)
|
||||||
|
err_file = channel.makefile_stderr('rb', self.buf_size)
|
||||||
|
out_data = out_file.read()
|
||||||
|
err_data = err_file.read()
|
||||||
|
|
||||||
if 0 != exit_status:
|
if 0 != exit_status:
|
||||||
raise exceptions.SSHExecCommandFailed(
|
raise exceptions.SSHExecCommandFailed(
|
||||||
command=cmd, exit_status=exit_status,
|
command=cmd, exit_status=exit_status,
|
||||||
stderr=err_data, stdout=out_data)
|
stderr=err_data, stdout=out_data)
|
||||||
return ''.join(out_data)
|
return out_data
|
||||||
|
|
||||||
def test_connection_auth(self):
|
def test_connection_auth(self):
|
||||||
"""Raises an exception when we can not connect to server via ssh."""
|
"""Raises an exception when we can not connect to server via ssh."""
|
||||||
|
@@ -12,6 +12,7 @@
|
|||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
from io import StringIO
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -25,6 +26,8 @@ from tempest_lib.tests import base
|
|||||||
|
|
||||||
class TestSshClient(base.TestCase):
|
class TestSshClient(base.TestCase):
|
||||||
|
|
||||||
|
SELECT_POLLIN = 1
|
||||||
|
|
||||||
@mock.patch('paramiko.RSAKey.from_private_key')
|
@mock.patch('paramiko.RSAKey.from_private_key')
|
||||||
@mock.patch('six.StringIO')
|
@mock.patch('six.StringIO')
|
||||||
def test_pkey_calls_paramiko_RSAKey(self, cs_mock, rsa_mock):
|
def test_pkey_calls_paramiko_RSAKey(self, cs_mock, rsa_mock):
|
||||||
@@ -109,12 +112,16 @@ class TestSshClient(base.TestCase):
|
|||||||
self.assertLess((end_time - start_time), 5)
|
self.assertLess((end_time - start_time), 5)
|
||||||
self.assertGreaterEqual((end_time - start_time), 2)
|
self.assertGreaterEqual((end_time - start_time), 2)
|
||||||
|
|
||||||
|
@mock.patch('select.POLLIN', SELECT_POLLIN, create=True)
|
||||||
def test_exec_command(self):
|
def test_exec_command(self):
|
||||||
gsc_mock = self.patch('tempest_lib.common.ssh.Client.'
|
gsc_mock = self.patch('tempest_lib.common.ssh.Client.'
|
||||||
'_get_ssh_connection')
|
'_get_ssh_connection')
|
||||||
ito_mock = self.patch('tempest_lib.common.ssh.Client._is_timed_out')
|
ito_mock = self.patch('tempest_lib.common.ssh.Client._is_timed_out')
|
||||||
select_mock = self.patch('select.poll')
|
csp_mock = self.patch(
|
||||||
|
'tempest_lib.common.ssh.Client._can_system_poll')
|
||||||
|
csp_mock.return_value = True
|
||||||
|
|
||||||
|
select_mock = self.patch('select.poll', create=True)
|
||||||
client_mock = mock.MagicMock()
|
client_mock = mock.MagicMock()
|
||||||
tran_mock = mock.MagicMock()
|
tran_mock = mock.MagicMock()
|
||||||
chan_mock = mock.MagicMock()
|
chan_mock = mock.MagicMock()
|
||||||
@@ -147,8 +154,8 @@ class TestSshClient(base.TestCase):
|
|||||||
chan_mock.exec_command.assert_called_once_with("test")
|
chan_mock.exec_command.assert_called_once_with("test")
|
||||||
chan_mock.shutdown_write.assert_called_once_with()
|
chan_mock.shutdown_write.assert_called_once_with()
|
||||||
|
|
||||||
SELECT_POLLIN = 1
|
poll_mock.register.assert_called_once_with(
|
||||||
poll_mock.register.assert_called_once_with(chan_mock, SELECT_POLLIN)
|
chan_mock, self.SELECT_POLLIN)
|
||||||
poll_mock.poll.assert_called_once_with(10)
|
poll_mock.poll.assert_called_once_with(10)
|
||||||
|
|
||||||
# Test for proper reading of STDOUT and STDERROR and closing
|
# Test for proper reading of STDOUT and STDERROR and closing
|
||||||
@@ -177,8 +184,9 @@ class TestSshClient(base.TestCase):
|
|||||||
chan_mock.exec_command.assert_called_once_with("test")
|
chan_mock.exec_command.assert_called_once_with("test")
|
||||||
chan_mock.shutdown_write.assert_called_once_with()
|
chan_mock.shutdown_write.assert_called_once_with()
|
||||||
|
|
||||||
SELECT_POLLIN = 1
|
select_mock.assert_called_once_with()
|
||||||
poll_mock.register.assert_called_once_with(chan_mock, SELECT_POLLIN)
|
poll_mock.register.assert_called_once_with(
|
||||||
|
chan_mock, self.SELECT_POLLIN)
|
||||||
poll_mock.poll.assert_called_once_with(10)
|
poll_mock.poll.assert_called_once_with(10)
|
||||||
chan_mock.recv_ready.assert_called_once_with()
|
chan_mock.recv_ready.assert_called_once_with()
|
||||||
chan_mock.recv.assert_called_once_with(1024)
|
chan_mock.recv.assert_called_once_with(1024)
|
||||||
@@ -186,3 +194,36 @@ class TestSshClient(base.TestCase):
|
|||||||
chan_mock.recv_stderr.assert_called_once_with(1024)
|
chan_mock.recv_stderr.assert_called_once_with(1024)
|
||||||
chan_mock.recv_exit_status.assert_called_once_with()
|
chan_mock.recv_exit_status.assert_called_once_with()
|
||||||
closed_prop.assert_called_once_with()
|
closed_prop.assert_called_once_with()
|
||||||
|
|
||||||
|
def test_exec_command_no_select(self):
|
||||||
|
gsc_mock = self.patch('tempest_lib.common.ssh.Client.'
|
||||||
|
'_get_ssh_connection')
|
||||||
|
csp_mock = self.patch(
|
||||||
|
'tempest_lib.common.ssh.Client._can_system_poll')
|
||||||
|
csp_mock.return_value = False
|
||||||
|
|
||||||
|
select_mock = self.patch('select.poll', create=True)
|
||||||
|
client_mock = mock.MagicMock()
|
||||||
|
tran_mock = mock.MagicMock()
|
||||||
|
chan_mock = mock.MagicMock()
|
||||||
|
|
||||||
|
# Test for proper reading of STDOUT and STDERROR
|
||||||
|
|
||||||
|
gsc_mock.return_value = client_mock
|
||||||
|
client_mock.get_transport.return_value = tran_mock
|
||||||
|
tran_mock.open_session.return_value = chan_mock
|
||||||
|
chan_mock.recv_exit_status.return_value = 0
|
||||||
|
|
||||||
|
std_out_mock = mock.MagicMock(StringIO)
|
||||||
|
std_err_mock = mock.MagicMock(StringIO)
|
||||||
|
chan_mock.makefile.return_value = std_out_mock
|
||||||
|
chan_mock.makefile_stderr.return_value = std_err_mock
|
||||||
|
|
||||||
|
client = ssh.Client('localhost', 'root', timeout=2)
|
||||||
|
client.exec_command("test")
|
||||||
|
|
||||||
|
chan_mock.makefile.assert_called_once_with('rb', 1024)
|
||||||
|
chan_mock.makefile_stderr.assert_called_once_with('rb', 1024)
|
||||||
|
std_out_mock.read.assert_called_once_with()
|
||||||
|
std_err_mock.read.assert_called_once_with()
|
||||||
|
self.assertFalse(select_mock.called)
|
||||||
|
Reference in New Issue
Block a user