diff --git a/tobiko/shell/sh/__init__.py b/tobiko/shell/sh/__init__.py index 6b1068f85..6640a285e 100644 --- a/tobiko/shell/sh/__init__.py +++ b/tobiko/shell/sh/__init__.py @@ -51,6 +51,7 @@ ShellExecuteResult = _execute.ShellExecuteResult HostNameError = _hostname.HostnameError get_hostname = _hostname.get_hostname +ssh_hostname = _hostname.ssh_hostname join_chunks = _io.join_chunks ShellStdout = _io.ShellStdout diff --git a/tobiko/shell/sh/_hostname.py b/tobiko/shell/sh/_hostname.py index f402ece1e..9fb9b6c3f 100644 --- a/tobiko/shell/sh/_hostname.py +++ b/tobiko/shell/sh/_hostname.py @@ -16,6 +16,8 @@ from __future__ import absolute_import import socket +import typing +import weakref import tobiko from tobiko.shell.sh import _exception @@ -27,13 +29,36 @@ class HostnameError(tobiko.TobikoException): message = "Unable to get hostname from host: {error}" +HOSTNAMES_CACHE: typing.MutableMapping[typing.Optional[ssh.SSHClientFixture], + str] = weakref.WeakKeyDictionary() + + def get_hostname(ssh_client: ssh.SSHClientType = None, + cached=True, **execute_params) -> str: - if ssh_client is False: + ssh_client = ssh.ssh_client_fixture(ssh_client) + if ssh_client is None: return socket.gethostname() - tobiko.check_valid_type(ssh_client, ssh.SSHClientFixture, - type(None)) + if cached: + try: + hostname = HOSTNAMES_CACHE[ssh_client] + except KeyError: + pass + else: + return hostname + + hostname = ssh_hostname(ssh_client=ssh_client, + **execute_params) + if cached: + HOSTNAMES_CACHE[ssh_client] = hostname + return hostname + + +def ssh_hostname(ssh_client: ssh.SSHClientFixture, + **execute_params) \ + -> str: + tobiko.check_valid_type(ssh_client, ssh.SSHClientFixture) try: result = _execute.execute('hostname', ssh_client=ssh_client, @@ -48,5 +73,4 @@ def get_hostname(ssh_client: ssh.SSHClientType = None, break else: raise HostnameError(error=f"Invalid result: '{result}'") - return hostname diff --git a/tobiko/tests/scenario/neutron/test_floating_ip.py b/tobiko/tests/scenario/neutron/test_floating_ip.py index 575d71dfe..65ae43f1b 100644 --- a/tobiko/tests/scenario/neutron/test_floating_ip.py +++ b/tobiko/tests/scenario/neutron/test_floating_ip.py @@ -44,7 +44,7 @@ class FloatingIPTest(testtools.TestCase): def test_ssh(self): """Test SSH connectivity to floating IP address""" - hostname = sh.get_hostname(ssh_client=self.stack.ssh_client) + hostname = sh.ssh_hostname(ssh_client=self.stack.ssh_client) self.assertEqual(self.stack.server_name.lower(), hostname) def test_ping(self): diff --git a/tobiko/tests/scenario/neutron/test_network.py b/tobiko/tests/scenario/neutron/test_network.py index 5af510e77..55bcee365 100644 --- a/tobiko/tests/scenario/neutron/test_network.py +++ b/tobiko/tests/scenario/neutron/test_network.py @@ -38,7 +38,7 @@ class NetworkTest(testtools.TestCase): def test_ssh(self): """Test SSH connectivity to floating IP address""" - hostname = sh.get_hostname(ssh_client=self.stack.ssh_client) + hostname = sh.ssh_hostname(ssh_client=self.stack.ssh_client) self.assertEqual(self.stack.server_name.lower(), hostname) def test_ping(self): diff --git a/tobiko/tests/unit/shell/sh/test_hostname.py b/tobiko/tests/unit/shell/sh/test_hostname.py new file mode 100644 index 000000000..bdfc4641c --- /dev/null +++ b/tobiko/tests/unit/shell/sh/test_hostname.py @@ -0,0 +1,88 @@ +# Copyright (c) 2021 Red Hat, Inc. +# +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import absolute_import + +import socket +from unittest import mock + +import paramiko + +from tobiko.shell import sh +from tobiko.shell import ssh +from tobiko.tests import unit + + +class HostnameTest(unit.TobikoUnitTest): + + def mock_ssh_client(self, + stdout='mocked-hostname\n', + stderr='', + exit_status=0) \ + -> ssh.SSHClientFixture: + channel_mock = mock.MagicMock(spec=paramiko.Channel, + exit_status=exit_status) + channel_mock.recv.side_effect = [bytes(stdout, 'utf-8'), + EOFError, + EOFError] * 10 + channel_mock.recv_stderr.side_effect = [bytes(stderr, 'utf-8'), + EOFError, + EOFError] * 10 + + client_mock = mock.MagicMock(spec=ssh.SSHClientFixture) + client_mock.connect().get_transport().open_session.return_value = \ + channel_mock + return client_mock + + def test_get_hostname_with_no_ssh_client(self): + hostname = sh.get_hostname(ssh_client=False) + self.assertEqual(socket.gethostname(), hostname) + + def test_get_hostname_with_ssh_client(self): + ssh_client = self.mock_ssh_client() + hostname = sh.get_hostname(ssh_client=ssh_client) + self.assertEqual('mocked-hostname', hostname) + self.assertIs(hostname, + sh.get_hostname(ssh_client=ssh_client)) + + def test_get_hostname_with_no_cached(self): + ssh_client = self.mock_ssh_client() + hostname = sh.get_hostname(ssh_client=ssh_client, + cached=False) + self.assertEqual('mocked-hostname', hostname) + self.assertIsNot(hostname, + sh.get_hostname(ssh_client=ssh_client, + cached=False)) + + def test_get_hostname_with_ssh_proxy(self): + ssh_client = self.mock_ssh_client() + self.patch(ssh, 'ssh_client_fixture', return_value=ssh_client) + hostname = sh.get_hostname(ssh_client=None) + self.assertEqual('mocked-hostname', hostname) + + def test_get_hostname_with_ssh_client_no_output(self): + ssh_client = self.mock_ssh_client(stdout='\n') + ex = self.assertRaises(sh.HostNameError, + sh.get_hostname, + ssh_client=ssh_client) + self.assertIn('Invalid result', str(ex)) + + def test_get_hostname_with_ssh_client_and_failure(self): + ssh_client = self.mock_ssh_client(exit_status=1, + stderr='command not found') + ex = self.assertRaises(sh.HostNameError, + sh.get_hostname, + ssh_client=ssh_client) + self.assertIn('command not found', str(ex))