From 2e2d19f7c73f4ccf7eae0f348b44e0ee809b3bda Mon Sep 17 00:00:00 2001 From: Alex Kavanagh Date: Wed, 19 Jun 2019 17:08:45 +0100 Subject: [PATCH] Refactor import_authorized_keys() function for performance The main change is to fetch all of the relation_data() at once, and then iterate through the python dictionary. This speeds up processing of potentially hundreds of hosts and authorized_keys. Change-Id: I095104f535c1eae1554f842502ae93ebb92e44fe Related-Bug: #1833420 --- hooks/nova_compute_utils.py | 73 +++++++++++++-------------- unit_tests/test_nova_compute_utils.py | 50 +++++++++--------- 2 files changed, 59 insertions(+), 64 deletions(-) diff --git a/hooks/nova_compute_utils.py b/hooks/nova_compute_utils.py index f524f0f9..24d6cd88 100644 --- a/hooks/nova_compute_utils.py +++ b/hooks/nova_compute_utils.py @@ -598,39 +598,33 @@ def import_authorized_keys(user='root', prefix=None): """Import SSH authorized_keys + known_hosts from a cloud-compute relation. Store known_hosts in user's $HOME/.ssh and authorized_keys in a path specified using authorized-keys-path config option. - """ - known_hosts = [] - authorized_keys = [] - if prefix: - known_hosts_index = relation_get( - '{}_known_hosts_max_index'.format(prefix)) - if known_hosts_index: - for index in range(0, int(known_hosts_index)): - known_hosts.append(relation_get( - '{}_known_hosts_{}'.format(prefix, index))) - authorized_keys_index = relation_get( - '{}_authorized_keys_max_index'.format(prefix)) - if authorized_keys_index: - for index in range(0, int(authorized_keys_index)): - authorized_keys.append(relation_get( - '{}_authorized_keys_{}'.format(prefix, index))) - else: - # XXX: Should this be managed via templates + contexts? - known_hosts_index = relation_get('known_hosts_max_index') - if known_hosts_index: - for index in range(0, int(known_hosts_index)): - known_hosts.append(relation_get( - 'known_hosts_{}'.format(index))) - authorized_keys_index = relation_get('authorized_keys_max_index') - if authorized_keys_index: - for index in range(0, int(authorized_keys_index)): - authorized_keys.append(relation_get( - 'authorized_keys_{}'.format(index))) - # XXX: Should partial return of known_hosts or authorized_keys - # be allowed ? - if not len(known_hosts) or not len(authorized_keys): + The relation_get data is a series of key values of the form: + + [prefix_]known_hosts_max_index: + [prefix_]authorized_keys_max_index: + + [prefix_]known_hosts_[n]: + [prefix_]authorized_keys_[n]: + + :param user: the user to write the known hosts and keys for (default 'root) + :type user: str + :param prefix: A prefix to add to the relation data keys (default None) + :type prefix: Option[str, None] + """ + _prefix = "{}_".format(prefix) if prefix else "" + + # get all the data at once with one relation_get call + rdata = relation_get() or {} + + known_hosts_index = int( + rdata.get('{}known_hosts_max_index'.format(_prefix), '0')) + authorized_keys_index = int( + rdata.get('{}authorized_keys_max_index'.format(_prefix), '0')) + + if known_hosts_index == 0 or authorized_keys_index == 0: return + homedir = pwd.getpwnam(user).pw_dir dest_auth_keys = config('authorized-keys-path').format( homedir=homedir, username=user) @@ -638,12 +632,17 @@ def import_authorized_keys(user='root', prefix=None): log('Saving new known_hosts file to %s and authorized_keys file to: %s.' % (dest_known_hosts, dest_auth_keys)) - with open(dest_known_hosts, 'wt') as _hosts: - for index in range(0, int(known_hosts_index)): - _hosts.write('{}\n'.format(known_hosts[index])) - with open(dest_auth_keys, 'wt') as _keys: - for index in range(0, int(authorized_keys_index)): - _keys.write('{}\n'.format(authorized_keys[index])) + # write known hosts using data from relation_get + with open(dest_known_hosts, 'wt') as f: + for index in range(known_hosts_index): + f.write("{}\n".format( + rdata.get("{}known_hosts_{}".format(_prefix, index)))) + + # write authorized keys using data from relation_get + with open(dest_auth_keys, 'wt') as f: + for index in range(authorized_keys_index): + f.write("{}\n".format( + rdata.get('{}authorized_keys_{}'.format(_prefix, index)))) def do_openstack_upgrade(configs): diff --git a/unit_tests/test_nova_compute_utils.py b/unit_tests/test_nova_compute_utils.py index ff6822a1..77ea3196 100644 --- a/unit_tests/test_nova_compute_utils.py +++ b/unit_tests/test_nova_compute_utils.py @@ -553,16 +553,27 @@ class NovaComputeUtilsTests(CharmTestCase): auth_key_path='/home/foo/.ssh/' 'authorized_keys'): getpwnam.return_value = self.fake_user('foo') - self.relation_get.side_effect = [ - 3, # relation_get('known_hosts_max_index') - 'k_h_0', # relation_get_('known_hosts_0') - 'k_h_1', # relation_get_('known_hosts_1') - 'k_h_2', # relation_get_('known_hosts_2') - 3, # relation_get('authorized_keys_max_index') - 'auth_0', # relation_get('authorized_keys_0') - 'auth_1', # relation_get('authorized_keys_1') - 'auth_2', # relation_get('authorized_keys_2') - ] + + d = { + 'known_hosts_max_index': 3, + 'known_hosts_0': 'k_h_0', + 'known_hosts_1': 'k_h_1', + 'known_hosts_2': 'k_h_2', + 'authorized_keys_max_index': 3, + 'authorized_keys_0': 'auth_0', + 'authorized_keys_1': 'auth_1', + 'authorized_keys_2': 'auth_2', + } + if prefix: + for k, v in d.copy().items(): + d["{}_{}".format(prefix, k)] = v + + def _relation_get(scope=None, *args, **kwargs): + if scope is not None: + return d.get(scope, None) + return d + + self.relation_get.side_effect = _relation_get ex_open = [ call('/home/foo/.ssh/known_hosts', 'wt'), @@ -577,27 +588,12 @@ class NovaComputeUtilsTests(CharmTestCase): call('auth_2\n') ] + # we only have to verify that the files are writen as expected as this + # implicitly checks that the relation_get calls have occurred. with patch_open() as (_open, _file): utils.import_authorized_keys(user='foo', prefix=prefix) self.assertEqual(ex_open, _open.call_args_list) self.assertEqual(ex_write, _file.write.call_args_list) - authkey_root = 'authorized_keys_' - known_hosts_root = 'known_hosts_' - if prefix: - authkey_root = prefix + '_authorized_keys_' - known_hosts_root = prefix + '_known_hosts_' - expected_relations = [ - call(known_hosts_root + 'max_index'), - call(known_hosts_root + '0'), - call(known_hosts_root + '1'), - call(known_hosts_root + '2'), - call(authkey_root + 'max_index'), - call(authkey_root + '0'), - call(authkey_root + '1'), - call(authkey_root + '2') - ] - self.assertEqual(sorted(self.relation_get.call_args_list), - sorted(expected_relations)) def test_import_authorized_keys_noprefix(self): self._test_import_authorized_keys_base()