Add caching mechanism to get_hostname function

Change-Id: I5c810740d030acd1141b4904cd478eb03df975e7
This commit is contained in:
Federico Ressi 2021-07-06 12:49:03 +02:00
parent 06038db65b
commit c6aa9b7b06
5 changed files with 119 additions and 6 deletions

View File

@ -51,6 +51,7 @@ ShellExecuteResult = _execute.ShellExecuteResult
HostNameError = _hostname.HostnameError
get_hostname = _hostname.get_hostname
ssh_hostname = _hostname.ssh_hostname
join_chunks = _io.join_chunks
ShellStdout = _io.ShellStdout

View File

@ -16,6 +16,8 @@
from __future__ import absolute_import
import socket
import typing
import weakref
import tobiko
from tobiko.shell.sh import _exception
@ -27,13 +29,36 @@ class HostnameError(tobiko.TobikoException):
message = "Unable to get hostname from host: {error}"
HOSTNAMES_CACHE: typing.MutableMapping[typing.Optional[ssh.SSHClientFixture],
str] = weakref.WeakKeyDictionary()
def get_hostname(ssh_client: ssh.SSHClientType = None,
cached=True,
**execute_params) -> str:
if ssh_client is False:
ssh_client = ssh.ssh_client_fixture(ssh_client)
if ssh_client is None:
return socket.gethostname()
tobiko.check_valid_type(ssh_client, ssh.SSHClientFixture,
type(None))
if cached:
try:
hostname = HOSTNAMES_CACHE[ssh_client]
except KeyError:
pass
else:
return hostname
hostname = ssh_hostname(ssh_client=ssh_client,
**execute_params)
if cached:
HOSTNAMES_CACHE[ssh_client] = hostname
return hostname
def ssh_hostname(ssh_client: ssh.SSHClientFixture,
**execute_params) \
-> str:
tobiko.check_valid_type(ssh_client, ssh.SSHClientFixture)
try:
result = _execute.execute('hostname',
ssh_client=ssh_client,
@ -48,5 +73,4 @@ def get_hostname(ssh_client: ssh.SSHClientType = None,
break
else:
raise HostnameError(error=f"Invalid result: '{result}'")
return hostname

View File

@ -44,7 +44,7 @@ class FloatingIPTest(testtools.TestCase):
def test_ssh(self):
"""Test SSH connectivity to floating IP address"""
hostname = sh.get_hostname(ssh_client=self.stack.ssh_client)
hostname = sh.ssh_hostname(ssh_client=self.stack.ssh_client)
self.assertEqual(self.stack.server_name.lower(), hostname)
def test_ping(self):

View File

@ -38,7 +38,7 @@ class NetworkTest(testtools.TestCase):
def test_ssh(self):
"""Test SSH connectivity to floating IP address"""
hostname = sh.get_hostname(ssh_client=self.stack.ssh_client)
hostname = sh.ssh_hostname(ssh_client=self.stack.ssh_client)
self.assertEqual(self.stack.server_name.lower(), hostname)
def test_ping(self):

View File

@ -0,0 +1,88 @@
# Copyright (c) 2021 Red Hat, Inc.
#
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
import socket
from unittest import mock
import paramiko
from tobiko.shell import sh
from tobiko.shell import ssh
from tobiko.tests import unit
class HostnameTest(unit.TobikoUnitTest):
def mock_ssh_client(self,
stdout='mocked-hostname\n',
stderr='',
exit_status=0) \
-> ssh.SSHClientFixture:
channel_mock = mock.MagicMock(spec=paramiko.Channel,
exit_status=exit_status)
channel_mock.recv.side_effect = [bytes(stdout, 'utf-8'),
EOFError,
EOFError] * 10
channel_mock.recv_stderr.side_effect = [bytes(stderr, 'utf-8'),
EOFError,
EOFError] * 10
client_mock = mock.MagicMock(spec=ssh.SSHClientFixture)
client_mock.connect().get_transport().open_session.return_value = \
channel_mock
return client_mock
def test_get_hostname_with_no_ssh_client(self):
hostname = sh.get_hostname(ssh_client=False)
self.assertEqual(socket.gethostname(), hostname)
def test_get_hostname_with_ssh_client(self):
ssh_client = self.mock_ssh_client()
hostname = sh.get_hostname(ssh_client=ssh_client)
self.assertEqual('mocked-hostname', hostname)
self.assertIs(hostname,
sh.get_hostname(ssh_client=ssh_client))
def test_get_hostname_with_no_cached(self):
ssh_client = self.mock_ssh_client()
hostname = sh.get_hostname(ssh_client=ssh_client,
cached=False)
self.assertEqual('mocked-hostname', hostname)
self.assertIsNot(hostname,
sh.get_hostname(ssh_client=ssh_client,
cached=False))
def test_get_hostname_with_ssh_proxy(self):
ssh_client = self.mock_ssh_client()
self.patch(ssh, 'ssh_client_fixture', return_value=ssh_client)
hostname = sh.get_hostname(ssh_client=None)
self.assertEqual('mocked-hostname', hostname)
def test_get_hostname_with_ssh_client_no_output(self):
ssh_client = self.mock_ssh_client(stdout='\n')
ex = self.assertRaises(sh.HostNameError,
sh.get_hostname,
ssh_client=ssh_client)
self.assertIn('Invalid result', str(ex))
def test_get_hostname_with_ssh_client_and_failure(self):
ssh_client = self.mock_ssh_client(exit_status=1,
stderr='command not found')
ex = self.assertRaises(sh.HostNameError,
sh.get_hostname,
ssh_client=ssh_client)
self.assertIn('command not found', str(ex))