Refactor ssh_compute_remove() in nova

This removes one loop, but also changes the unit test so that it is not
dependent on the implementation of the function, specifically w.r.t. the
number of writes to the output file (which can now be changed in the
future).

Change-Id: Ieb0a373ed55971af0c357fa89c199fb781e772ab
This commit is contained in:
Alex Kavanagh 2019-06-24 10:50:49 +01:00
parent 7d33c617b2
commit 978e29012f
2 changed files with 25 additions and 11 deletions

View File

@ -1121,23 +1121,30 @@ def ssh_authorized_keys_lines(unit=None, user=None):
def ssh_compute_remove(public_key, unit=None, user=None):
"""Remove a key from the authorized_keys file for the unit/user
:param public_key: the key to remove
:type public_key: str
:param unit: The unit (as identified by Juju) to reference (default None)
:type unit: Union[str, None]
:param user: The username to reference (default None)
:type user: Union[str, None]
"""
if not (os.path.isfile(authorized_keys(unit, user)) or
os.path.isfile(known_hosts(unit, user))):
return
with open(authorized_keys(unit, user), 'rt') as _keys:
keys = [k.strip() for k in _keys.readlines()]
with open(authorized_keys(unit, user), 'rt') as f:
keys = [k.strip() for k in f.readlines()]
if public_key not in keys:
return
[keys.remove(key) for key in keys if key == public_key]
with open(authorized_keys(unit, user), 'wt') as _keys:
keys = '\n'.join(keys)
if not keys.endswith('\n'):
keys += '\n'
_keys.write(keys)
with open(authorized_keys(unit, user), 'wt') as f:
out = "\n".join([key for key in keys if key != public_key])
if not out.endswith('\n'):
out += '\n'
f.write(out)
def determine_endpoints(public_url, internal_url, admin_url):

View File

@ -731,12 +731,19 @@ class NovaCCUtilsTests(CharmTestCase):
)
isfile.return_value = True
self.remote_unit.return_value = 'nova-compute/2'
_written = ""
def _writer(s):
nonlocal _written
_written += s
with patch_open() as (_open, _file):
_file.readlines = MagicMock()
_file.write = MagicMock()
_file.write.side_effect = _writer
_file.readlines.return_value = AUTHORIZED_KEYS.split('\n')
utils.ssh_compute_remove(removed_key)
_file.write.assert_called_with(keys_removed)
self.assertEqual(_written, keys_removed)
def test_determine_endpoints_base(self):
self.relation_ids.return_value = []