Add caching mechanism to get_hostname function
Change-Id: I5c810740d030acd1141b4904cd478eb03df975e7
This commit is contained in:
parent
06038db65b
commit
c6aa9b7b06
|
@ -51,6 +51,7 @@ ShellExecuteResult = _execute.ShellExecuteResult
|
|||
|
||||
HostNameError = _hostname.HostnameError
|
||||
get_hostname = _hostname.get_hostname
|
||||
ssh_hostname = _hostname.ssh_hostname
|
||||
|
||||
join_chunks = _io.join_chunks
|
||||
ShellStdout = _io.ShellStdout
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
import socket
|
||||
import typing
|
||||
import weakref
|
||||
|
||||
import tobiko
|
||||
from tobiko.shell.sh import _exception
|
||||
|
@ -27,13 +29,36 @@ class HostnameError(tobiko.TobikoException):
|
|||
message = "Unable to get hostname from host: {error}"
|
||||
|
||||
|
||||
HOSTNAMES_CACHE: typing.MutableMapping[typing.Optional[ssh.SSHClientFixture],
|
||||
str] = weakref.WeakKeyDictionary()
|
||||
|
||||
|
||||
def get_hostname(ssh_client: ssh.SSHClientType = None,
|
||||
cached=True,
|
||||
**execute_params) -> str:
|
||||
if ssh_client is False:
|
||||
ssh_client = ssh.ssh_client_fixture(ssh_client)
|
||||
if ssh_client is None:
|
||||
return socket.gethostname()
|
||||
|
||||
tobiko.check_valid_type(ssh_client, ssh.SSHClientFixture,
|
||||
type(None))
|
||||
if cached:
|
||||
try:
|
||||
hostname = HOSTNAMES_CACHE[ssh_client]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
return hostname
|
||||
|
||||
hostname = ssh_hostname(ssh_client=ssh_client,
|
||||
**execute_params)
|
||||
if cached:
|
||||
HOSTNAMES_CACHE[ssh_client] = hostname
|
||||
return hostname
|
||||
|
||||
|
||||
def ssh_hostname(ssh_client: ssh.SSHClientFixture,
|
||||
**execute_params) \
|
||||
-> str:
|
||||
tobiko.check_valid_type(ssh_client, ssh.SSHClientFixture)
|
||||
try:
|
||||
result = _execute.execute('hostname',
|
||||
ssh_client=ssh_client,
|
||||
|
@ -48,5 +73,4 @@ def get_hostname(ssh_client: ssh.SSHClientType = None,
|
|||
break
|
||||
else:
|
||||
raise HostnameError(error=f"Invalid result: '{result}'")
|
||||
|
||||
return hostname
|
||||
|
|
|
@ -44,7 +44,7 @@ class FloatingIPTest(testtools.TestCase):
|
|||
|
||||
def test_ssh(self):
|
||||
"""Test SSH connectivity to floating IP address"""
|
||||
hostname = sh.get_hostname(ssh_client=self.stack.ssh_client)
|
||||
hostname = sh.ssh_hostname(ssh_client=self.stack.ssh_client)
|
||||
self.assertEqual(self.stack.server_name.lower(), hostname)
|
||||
|
||||
def test_ping(self):
|
||||
|
|
|
@ -38,7 +38,7 @@ class NetworkTest(testtools.TestCase):
|
|||
|
||||
def test_ssh(self):
|
||||
"""Test SSH connectivity to floating IP address"""
|
||||
hostname = sh.get_hostname(ssh_client=self.stack.ssh_client)
|
||||
hostname = sh.ssh_hostname(ssh_client=self.stack.ssh_client)
|
||||
self.assertEqual(self.stack.server_name.lower(), hostname)
|
||||
|
||||
def test_ping(self):
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
# Copyright (c) 2021 Red Hat, Inc.
|
||||
#
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import absolute_import
|
||||
|
||||
import socket
|
||||
from unittest import mock
|
||||
|
||||
import paramiko
|
||||
|
||||
from tobiko.shell import sh
|
||||
from tobiko.shell import ssh
|
||||
from tobiko.tests import unit
|
||||
|
||||
|
||||
class HostnameTest(unit.TobikoUnitTest):
|
||||
|
||||
def mock_ssh_client(self,
|
||||
stdout='mocked-hostname\n',
|
||||
stderr='',
|
||||
exit_status=0) \
|
||||
-> ssh.SSHClientFixture:
|
||||
channel_mock = mock.MagicMock(spec=paramiko.Channel,
|
||||
exit_status=exit_status)
|
||||
channel_mock.recv.side_effect = [bytes(stdout, 'utf-8'),
|
||||
EOFError,
|
||||
EOFError] * 10
|
||||
channel_mock.recv_stderr.side_effect = [bytes(stderr, 'utf-8'),
|
||||
EOFError,
|
||||
EOFError] * 10
|
||||
|
||||
client_mock = mock.MagicMock(spec=ssh.SSHClientFixture)
|
||||
client_mock.connect().get_transport().open_session.return_value = \
|
||||
channel_mock
|
||||
return client_mock
|
||||
|
||||
def test_get_hostname_with_no_ssh_client(self):
|
||||
hostname = sh.get_hostname(ssh_client=False)
|
||||
self.assertEqual(socket.gethostname(), hostname)
|
||||
|
||||
def test_get_hostname_with_ssh_client(self):
|
||||
ssh_client = self.mock_ssh_client()
|
||||
hostname = sh.get_hostname(ssh_client=ssh_client)
|
||||
self.assertEqual('mocked-hostname', hostname)
|
||||
self.assertIs(hostname,
|
||||
sh.get_hostname(ssh_client=ssh_client))
|
||||
|
||||
def test_get_hostname_with_no_cached(self):
|
||||
ssh_client = self.mock_ssh_client()
|
||||
hostname = sh.get_hostname(ssh_client=ssh_client,
|
||||
cached=False)
|
||||
self.assertEqual('mocked-hostname', hostname)
|
||||
self.assertIsNot(hostname,
|
||||
sh.get_hostname(ssh_client=ssh_client,
|
||||
cached=False))
|
||||
|
||||
def test_get_hostname_with_ssh_proxy(self):
|
||||
ssh_client = self.mock_ssh_client()
|
||||
self.patch(ssh, 'ssh_client_fixture', return_value=ssh_client)
|
||||
hostname = sh.get_hostname(ssh_client=None)
|
||||
self.assertEqual('mocked-hostname', hostname)
|
||||
|
||||
def test_get_hostname_with_ssh_client_no_output(self):
|
||||
ssh_client = self.mock_ssh_client(stdout='\n')
|
||||
ex = self.assertRaises(sh.HostNameError,
|
||||
sh.get_hostname,
|
||||
ssh_client=ssh_client)
|
||||
self.assertIn('Invalid result', str(ex))
|
||||
|
||||
def test_get_hostname_with_ssh_client_and_failure(self):
|
||||
ssh_client = self.mock_ssh_client(exit_status=1,
|
||||
stderr='command not found')
|
||||
ex = self.assertRaises(sh.HostNameError,
|
||||
sh.get_hostname,
|
||||
ssh_client=ssh_client)
|
||||
self.assertIn('command not found', str(ex))
|
Loading…
Reference in New Issue