diff --git a/src/lib/charm/vault_pki.py b/src/lib/charm/vault_pki.py index 0403a8e..c6f0631 100644 --- a/src/lib/charm/vault_pki.py +++ b/src/lib/charm/vault_pki.py @@ -1,4 +1,8 @@ import hvac +import json + +from subprocess import check_output, CalledProcessError +from tempfile import NamedTemporaryFile import charmhelpers.contrib.network.ip as ch_ip import charmhelpers.core.hookenv as hookenv @@ -9,6 +13,9 @@ CHARM_PKI_MP = "charm-pki-local" CHARM_PKI_ROLE = "local" CHARM_PKI_ROLE_CLIENT = "local-client" +PKI_CACHE_KEY = "pki" +TOP_LEVEL_CERT_KEY = "top_level" + def configure_pki_backend(client, name, ttl=None, max_ttl=None): """Ensure a pki backend is enabled @@ -337,3 +344,197 @@ def update_roles(**kwargs): local.update(**kwargs) del local['server_flag'] write_roles(client, **local) + + +def verify_cert(ca_cert, untrusted_cert): + """Verify that the 'untrusted_cert' is signed by the 'ca_cert'. + + :param ca_cert: CA certificate that should sign the untrusted cert. + :param untrusted_cert: Certificate that is verified by the CA cert. + :return: True if CA cert can verify the untrusted cert + :rtype: bool + """ + with NamedTemporaryFile() as ca_file, NamedTemporaryFile() as cert_file: + ca_file.write(ca_cert.encode("UTF-8")) + ca_file.flush() + + cert_file.write(untrusted_cert.encode("UTF-8")) + cert_file.flush() + + try: + verify_cmd = ['openssl', 'verify', '-CAfile', + ca_file.name, cert_file.name] + check_output(verify_cmd) + except CalledProcessError as exc: + hookenv.log( + "Certificate verification failed: {}".format(exc.output), + hookenv.WARNING + ) + return False + else: + return True + + +def get_pki_cache(): + """Fetch and parse PKI from the leader storage. + + Returned dictionary contains certificates and keys issued by the vault + leader unit as a response to requests from other charms. The structure + loosely matches the format in which the certificates are shared via data + in the `tls-certificates` relation. + See `tls_certificates_common.CertificateRequest.set_cert()` for more info + on the structure. + + :return: Dictionary containing certs and keys generated by the leader unit + :rtype: dict + """ + raw_cache = hookenv.leader_get(PKI_CACHE_KEY) or '{}' + return json.loads(raw_cache) + + +def find_cert_in_cache(request): + """Return certificate and key from cache that match the request. + + Returned certificate is validated against the current CA cert. If CA cert + is missing, or certificate fails validation or it's simply not found, + returned value is None, None + + :param request: Request for certificate from "client" unit. + :type request: tls_certificates_common.CertificateRequest + :return: Certificate and private key from cache + :rtype: Union[(str, str), (None, None)] + """ + try: + ca_chain = get_chain() + except (hvac.exceptions.VaultDown, TypeError): + # Fetching CA chain may fail + ca_chain = None + + ca_cert = ca_chain or get_ca() + if not ca_cert: + hookenv.log('CA cert not found. Skipping certificate cache lookup.', + hookenv.DEBUG) + return None, None + + pki_cache = get_pki_cache() + unit_data = pki_cache.get(request.unit_name, {}) + + try: + if request._is_top_level_server_cert: + cert = unit_data[TOP_LEVEL_CERT_KEY][request._server_cert_key] + key = unit_data[TOP_LEVEL_CERT_KEY][request._server_key_key] + else: + cert = unit_data[request._publish_key][request.common_name]['cert'] + key = unit_data[request._publish_key][request.common_name]['key'] + except (KeyError, TypeError): + hookenv.log('Certificate for "{}" (cn: "{}") not found in ' + 'cache.'.format(request.unit_name, request.common_name), + hookenv.DEBUG) + return None, None + + if verify_cert(ca_cert, cert): + return cert, key + else: + hookenv.log('Certificate from cache for "{}" (cn: "{}") is no longer' + 'valid and wont be reused.'.format(request.unit_name, + request.common_name)) + return None, None + + +def update_cert_cache(request, cert, key): + """Store certificate and key in the cache. + + Stored values are associated with the request from "client" unit, + so it can be later retrieved when the request is handled again. + + :param request: Request for certificate from "client" unit. + :type request: tls_certificates_common.CertificateRequest + :param cert: Issued certificate for the "client" request (in PEM format) + :type cert: str + :param key: Issued private key from the "client" request (in PEM format) + :type key: str + :return: None + """ + pki_cache = get_pki_cache() + unit_cache = pki_cache.get(request.unit_name, {}) + + if request._is_top_level_server_cert: + unit_cache[TOP_LEVEL_CERT_KEY] = { + request._server_cert_key: cert, + request._server_key_key: key, + } + else: + structured_certs = unit_cache.get(request._publish_key, {}) + structured_certs[request.common_name] = { + 'cert': cert, + 'key': key, + } + unit_cache[request._publish_key] = structured_certs + + hookenv.log('Saving certificate for "{}" ' + '(cn: "{}") into cache.'.format(request.unit_name, + request.common_name), + hookenv.DEBUG) + pki_cache[request.unit_name] = unit_cache + hookenv.leader_set({PKI_CACHE_KEY: json.dumps(pki_cache)}) + + +def remove_unit_from_cache(unit_name): + """Clear certificates and keys related to the unit from the cache. + + :param unit_name: Name of the unit to be removed from the cache. + :type unit_name: str + :return: None + """ + hookenv.log('Removing certificates for unit "{}" from ' + 'cache.'.format(unit_name), hookenv.DEBUG) + pki_cache = get_pki_cache() + pki_cache.pop(unit_name, None) + hookenv.leader_set({PKI_CACHE_KEY: json.dumps(pki_cache)}) + + +def populate_cert_cache(tls_endpoint): + """Store previously issued certificates in the cache. + + This function is used when vault charm is upgraded from older version + that may not have a certificate cache to a version that has it. It + goes through all previously issued certificates and stores them in + cache. + + :param tls_endpoint: Endpoint of "certificates" relation + :type tls_endpoint: interface_tls_certificates.provides.TlsProvides + :return: None + """ + hookenv.log( + "Populating certificate cache with data from relations", hookenv.INFO + ) + + for request in tls_endpoint.all_requests: + try: + if request._is_top_level_server_cert: + relation_data = request._unit.relation.to_publish_raw + cert = relation_data[request._server_cert_key] + key = relation_data[request._server_key_key] + else: + relation_data = request._unit.relation.to_publish + cert = relation_data[request._publish_key][ + request.common_name + ]['cert'] + key = relation_data[request._publish_key][ + request.common_name + ]['key'] + except (KeyError, TypeError): + if request._is_top_level_server_cert: + cert_id = request._server_cert_key + else: + cert_id = request.common_name + hookenv.log( + 'Certificate "{}" (or associated key) issued for unit "{}" ' + 'not found in relation data.'.format( + cert_id, request._unit.unit_name + ), + hookenv.WARNING + ) + continue + + update_cert_cache(request, cert, key) diff --git a/src/reactive/vault_handlers.py b/src/reactive/vault_handlers.py index 5bfcc2c..6011da6 100644 --- a/src/reactive/vault_handlers.py +++ b/src/reactive/vault_handlers.py @@ -1,4 +1,5 @@ import base64 +import json import os import psycopg2 import subprocess @@ -35,6 +36,7 @@ from charmhelpers.core.hookenv import ( log, network_get_primary_address, open_port, + remote_unit, status_set, unit_private_ip, ) @@ -285,6 +287,13 @@ def upgrade_charm(): remove_state('vault.nrpe.configured') remove_state('vault.ssl.configured') remove_state('vault.requested-lb') + # mkalcok: When upgrading from version of a charm that did not have a + # certificate cache, we need to populate the cache with already issued + # certificates. Otherwise the non-leader units would not be able to sync + # their certificate data via cache. + tls = endpoint_from_flag('certificates.available') + if tls and is_flag_set('leadership.is_leader'): + vault_pki.populate_cert_cache(tls) @when_not("is-update-status-hook") @@ -980,23 +989,72 @@ def publish_global_client_cert(): log("Vault not authorized: Skipping publish_global_client_cert", "WARNING") return - cert_created = is_flag_set('charm.vault.global-client-cert.created') reissue_requested = is_flag_set('certificates.reissue.global.requested') tls = endpoint_from_flag('certificates.available') - if not cert_created or reissue_requested: + bundle = json.loads(leader_get('charm.vault.global-client-cert') or '{}') + certificate_present = "certificate" in bundle and "private_key" in bundle + if not certificate_present or reissue_requested: ttl = config()['default-ttl'] max_ttl = config()['max-ttl'] bundle = vault_pki.generate_certificate('client', 'global-client', [], ttl, max_ttl) - unitdata.kv().set('charm.vault.global-client-cert', bundle) + leader_set({'charm.vault.global-client-cert': json.dumps(bundle)}) set_flag('charm.vault.global-client-cert.created') clear_flag('certificates.reissue.global.requested') - else: - bundle = unitdata.kv().get('charm.vault.global-client-cert') + tls.set_client_cert(bundle['certificate'], bundle['private_key']) +@when("certificates.available") +@when_not('leadership.is_leader') +def sync_cert_from_cache(): + """Sync cert and key data in the tls-certificate relation. + + Non-leader units should keep the relation data up-to-date according + to the data from PKI cache that's maintained by the leader. This ensures + that "client" units can use data from any of the related vault units to + receive valid keys and certificates. + """ + tls = endpoint_from_flag('certificates.available') + cert_requests = tls.all_requests + + # propagate CA cert + tls.set_ca(vault_pki.get_ca()) + try: + # this might fail if we were restarted and need to be unsealed + chain = vault_pki.get_chain() + except (vault.hvac.exceptions.VaultDown, TypeError): + pass + else: + tls.set_chain(chain) + + # propagate global client cert from cache + bundle = json.loads(leader_get('charm.vault.global-client-cert') or '{}') + if bundle.get('certificate') and bundle.get('private_key'): + tls.set_client_cert(bundle['certificate'], bundle['private_key']) + + # update certificate data in relations + for request in cert_requests: + cache_cert, cache_key = vault_pki.find_cert_in_cache(request) + if cache_cert and cache_key: + request.set_cert(cache_cert, cache_key) + + +@hook('certificates-relation-departed') +def cert_client_leaving(relation): + """Remove certs and keys of the departing unit from cache.""" + if is_flag_set('leadership.is_leader'): + # mkalcok: Due to certificates requests replacing "/" in the unit + # name with "_" (see: tls_certificates_common.CertificateRequest), + # we must emulate the same behavior when removing unit certs from + # cache. + departing_unit = remote_unit() + log("Removing certificates for {} from cache.".format(departing_unit)) + unit_name = departing_unit.replace('/', '_') + vault_pki.remove_unit_from_cache(unit_name) + + @when_not("is-update-status-hook") @when('leadership.is_leader', 'charm.vault.ca.ready', @@ -1024,6 +1082,16 @@ def create_certs(): processed_applications.append(request.application_name) else: cert_type = request.cert_type + + cache_cert, cache_key = vault_pki.find_cert_in_cache(request) + if not reissue_requested and cache_cert and cache_key: + # If valid certificates are in cache, and re-issue was not + # requested, reuse them. + log("Reusing certificate for unit '{}' and CN '{}' from " + "cache.".format(request.unit_name, request.common_name)) + request.set_cert(cache_cert, cache_key) + continue + try: ttl = config()['default-ttl'] max_ttl = config()['max-ttl'] @@ -1031,6 +1099,9 @@ def create_certs(): request.common_name, request.sans, ttl, max_ttl) request.set_cert(bundle['certificate'], bundle['private_key']) + vault_pki.update_cert_cache(request, + bundle["certificate"], + bundle["private_key"]) except vault.VaultInvalidRequest as e: log(str(e), level=ERROR) continue # TODO: report failure back to client diff --git a/unit_tests/test_lib_charm_vault_pki.py b/unit_tests/test_lib_charm_vault_pki.py index 5b06e22..ea7a72a 100644 --- a/unit_tests/test_lib_charm_vault_pki.py +++ b/unit_tests/test_lib_charm_vault_pki.py @@ -1,7 +1,8 @@ from unittest import mock -from unittest.mock import patch +from unittest.mock import call, patch, MagicMock import hvac +import json import lib.charm.vault_pki as vault_pki import unit_tests.test_utils @@ -515,3 +516,337 @@ class TestLibCharmVaultPKI(unit_tests.test_utils.CharmTestCase): server_flag=False, client_flag=True), ]) + + @patch.object(vault_pki.hookenv, 'leader_get') + def test_get_pki_cache(self, leader_get): + """Test retrieving PKI from cache.""" + expected_pki = { + "client_unit_0": { + vault_pki.TOP_LEVEL_CERT_KEY: { + "client_unit_0.server.cert": "cert_data", + "client_unit_0.server.key": "key_data", + } + } + } + leader_get.return_value = json.dumps(expected_pki) + + pki = vault_pki.get_pki_cache() + self.assertEqual(pki, expected_pki) + + # test retrieval if the PKI is not set + leader_get.return_value = None + + pki = vault_pki.get_pki_cache() + self.assertEqual(pki, {}) + + @patch.object(vault_pki, 'get_pki_cache') + @patch.object(vault_pki, 'get_chain') + @patch.object(vault_pki, 'get_ca') + def test_find_cert_in_cache_no_ca(self, get_ca, get_chain, get_pki_cache): + """Test getting cert from cache when CA is missing.""" + get_ca.return_value = None + get_chain.return_value = None + + cert, key = vault_pki.find_cert_in_cache(MagicMock()) + + # assert that CA cert or chain was retrieved + get_ca.assert_called_once_with() + get_chain.assert_called_once_with() + # assert that function does not proceed due to the missing CA + get_pki_cache.assert_not_called() + + self.assertIsNone(cert) + self.assertIsNone(key) + + @patch.object(vault_pki, 'verify_cert') + @patch.object(vault_pki, 'get_pki_cache') + @patch.object(vault_pki, 'get_chain') + @patch.object(vault_pki, 'get_ca') + def test_find_cert_in_cache_missing(self, get_ca, get_chain, + get_pki_cache, verify_cache): + """Test use case when searched certificate is not in cache.""" + request = MagicMock() + request.unit_name = "client_unit_0" + request._is_top_level_server_cert = True + + get_ca.return_value = MagicMock() + get_pki_cache.return_value = {} + + cert, key = vault_pki.find_cert_in_cache(request) + + # assert that verification of cert is not attempted when + # cert is not found + verify_cache.assert_not_called() + + self.assertIsNone(cert) + self.assertIsNone(key) + + # Same scenario, but with non-top-level certificate + request._is_top_level_server_cert = False + + cert, key = vault_pki.find_cert_in_cache(request) + + verify_cache.assert_not_called() + self.assertIsNone(cert) + self.assertIsNone(key) + + @patch.object(vault_pki, 'verify_cert') + @patch.object(vault_pki, 'get_pki_cache') + @patch.object(vault_pki, 'get_chain') + @patch.object(vault_pki, 'get_ca') + def test_find_cert_in_cache_top_level(self, get_ca, get_chain, + get_pki_cache, verify_cache): + """Test fetching top level cert from cache. + + Additional test scenario: Test that nothing is returned if cert fails + CA verification. + """ + ca_cert = "CA cert data" + expected_cert = "cert data" + expected_key = "key data" + cert_name = "server.cert" + key_name = "server.key" + client_name = "client_unit_0" + + # setup cert request + request = MagicMock() + request.unit_name = client_name + request._is_top_level_server_cert = True + request._server_cert_key = cert_name + request._server_key_key = key_name + + # PKI cache content + pki = { + client_name: { + vault_pki.TOP_LEVEL_CERT_KEY: { + cert_name: expected_cert, + key_name: expected_key + } + } + } + + get_ca.return_value = ca_cert + get_chain.return_value = ca_cert + get_pki_cache.return_value = pki + verify_cache.return_value = True + + cert, key = vault_pki.find_cert_in_cache(request) + + verify_cache.assert_called_once_with(ca_cert, expected_cert) + self.assertEqual(cert, expected_cert) + self.assertEqual(key, expected_key) + + # Additional test: Nothing should be returned if cert failed + # CA verification. + verify_cache.reset_mock() + verify_cache.return_value = False + + cert, key = vault_pki.find_cert_in_cache(request) + + verify_cache.assert_called_once_with(ca_cert, expected_cert) + self.assertIsNone(cert) + self.assertIsNone(key) + + @patch.object(vault_pki, 'verify_cert') + @patch.object(vault_pki, 'get_pki_cache') + @patch.object(vault_pki, 'get_chain') + @patch.object(vault_pki, 'get_ca') + def test_find_cert_in_cache_not_top_level(self, get_ca, get_chain, + get_pki_cache, verify_cache): + """Test fetching non-top level cert from cache. + + Additional test scenario: Test that nothing is returned if cert fails + CA verification. + """ + ca_cert = "CA cert data" + expected_cert = "cert data" + expected_key = "key data" + client_name = "client_unit_0" + publish_key = client_name + ".processed_client_requests" + common_name = "client.0" + + # setup cert request + request = MagicMock() + request.unit_name = client_name + request._is_top_level_server_cert = False + request._publish_key = publish_key + request.common_name = common_name + + # PKI cache content + pki = { + client_name: { + publish_key: { + common_name: { + "cert": expected_cert, + "key": expected_key, + } + } + } + } + + get_ca.return_value = ca_cert + get_chain.return_value = ca_cert + get_pki_cache.return_value = pki + verify_cache.return_value = True + + cert, key = vault_pki.find_cert_in_cache(request) + + verify_cache.assert_called_once_with(ca_cert, expected_cert) + self.assertEqual(cert, expected_cert) + self.assertEqual(key, expected_key) + + # Additional test: Nothing should be returned if cert failed + # CA verification. + verify_cache.reset_mock() + verify_cache.return_value = False + + cert, key = vault_pki.find_cert_in_cache(request) + + verify_cache.assert_called_once_with(ca_cert, expected_cert) + self.assertIsNone(cert) + self.assertIsNone(key) + + @patch.object(vault_pki, 'get_pki_cache') + @patch.object(vault_pki.hookenv, 'leader_set') + def test_update_cert_cache_top_level_cert(self, leader_set, get_pki_cache): + """Test storing top-level cert in cache.""" + cert_data = "cert data" + key_data = "key data" + cert_name = "server.cert" + key_name = "server.key" + client_name = "client_unit_0" + + # setup cert request + request = MagicMock() + request.unit_name = client_name + request.common_name = client_name + request._is_top_level_server_cert = True + request._server_cert_key = cert_name + request._server_key_key = key_name + + # PKI structure + initial_pki = {} + expected_pki = { + client_name: { + vault_pki.TOP_LEVEL_CERT_KEY: { + cert_name: cert_data, + key_name: key_data + } + } + } + + get_pki_cache.return_value = initial_pki + + vault_pki.update_cert_cache(request, cert_data, key_data) + + leader_set.assert_called_once_with( + {vault_pki.PKI_CACHE_KEY: json.dumps(expected_pki)} + ) + + @patch.object(vault_pki, 'get_pki_cache') + @patch.object(vault_pki.hookenv, 'leader_set') + def test_update_cert_cache_non_top_level_cert(self, leader_set, + get_pki_cache): + """Test storing non-top-level cert in cache.""" + cert_data = "cert data" + key_data = "key data" + client_name = "client_unit_0" + publish_key = client_name + ".processed_client_requests" + common_name = "client.0" + + # setup cert request + request = MagicMock() + request.unit_name = client_name + request._is_top_level_server_cert = False + request._publish_key = publish_key + request.common_name = common_name + + # PKI structure + initial_pki = {} + expected_pki = { + client_name: { + publish_key: { + common_name: { + "cert": cert_data, + "key": key_data, + } + } + } + } + + get_pki_cache.return_value = initial_pki + + vault_pki.update_cert_cache(request, cert_data, key_data) + + leader_set.assert_called_once_with( + {vault_pki.PKI_CACHE_KEY: json.dumps(expected_pki)} + ) + + @patch.object(vault_pki, 'get_pki_cache') + @patch.object(vault_pki.hookenv, 'leader_set') + def test_remove_unit_from_cache(self, leader_set, get_pki_cache): + """Test removing unit certificates from cache.""" + remaining_unit = "client/0" + removed_unit = "client/1" + pki = { + remaining_unit: "Unit certificates", + removed_unit: "Unit certificates", + } + expected_pki = { + remaining_unit: "Unit certificates" + } + + get_pki_cache.return_value = pki + + vault_pki.remove_unit_from_cache(removed_unit) + + leader_set.assert_called_once_with( + {vault_pki.PKI_CACHE_KEY: json.dumps(expected_pki)} + ) + + @patch.object(vault_pki, 'update_cert_cache') + def test_populate_cert_cache(self, update_cert_cache): + # Define data for top level certificate and key + top_level_cert_name = "server.crt" + top_level_key_name = "server.key" + top_level_cert_data = "top level cert" + top_level_key_data = "top level key" + + # Define data for non-top level certificate + processed_request_cn = "juju_unit_service.crt" + processed_request_publish_key = "juju_unit_service.processed" + processed_cert_data = "processed cert" + processed_key_data = "processed key" + + # Mock request for top level certificate + top_level_request = MagicMock() + top_level_request._is_top_level_server_cert = True + top_level_request._server_cert_key = top_level_cert_name + top_level_request._server_key_key = top_level_key_name + top_level_request._unit.relation.to_publish_raw = { + top_level_cert_name: top_level_cert_data, + top_level_key_name: top_level_key_data, + } + + # Mock request for non-top level certificate + processed_request = MagicMock() + processed_request._is_top_level_server_cert = False + processed_request.common_name = processed_request_cn + processed_request._publish_key = processed_request_publish_key + processed_request._unit.relation.to_publish = { + processed_request_publish_key: {processed_request_cn: { + "cert": processed_cert_data, + "key": processed_key_data + }} + } + + tls_endpoint = MagicMock() + tls_endpoint.all_requests = [top_level_request, processed_request] + + vault_pki.populate_cert_cache(tls_endpoint) + + expected_update_calls = [ + call(top_level_request, top_level_cert_data, top_level_key_data), + call(processed_request, processed_cert_data, processed_key_data), + ] + update_cert_cache.assert_has_calls(expected_update_calls) diff --git a/unit_tests/test_reactive_vault_handlers.py b/unit_tests/test_reactive_vault_handlers.py index 1945b08..f7c966f 100644 --- a/unit_tests/test_reactive_vault_handlers.py +++ b/unit_tests/test_reactive_vault_handlers.py @@ -1,6 +1,8 @@ from unittest import mock from unittest.mock import patch, call +import json + import charms.reactive # Mock out reactive decorators prior to importing reactive.vault @@ -934,35 +936,41 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): self.set_flag.assert_called_with('failed.to.start') assert not _vault.get_client.called + @mock.patch.object(handlers, 'leader_get') @mock.patch.object(handlers, 'client_approle_authorized') @mock.patch.object(handlers, 'vault_pki') def test_publish_global_client_cert_already_gend( - self, vault_pki, _client_approle_authorized): + self, vault_pki, _client_approle_authorized, leader_get): _client_approle_authorized.return_value = True tls = self.endpoint_from_flag.return_value - self.is_flag_set.side_effect = [True, False] - self.unitdata.kv().get.return_value = {'certificate': 'crt', - 'private_key': 'key'} + self.is_flag_set.return_value = False + leader_get.return_value = json.dumps({'certificate': 'crt', + 'private_key': 'key'}) handlers.publish_global_client_cert() assert not vault_pki.generate_certificate.called assert not self.set_flag.called - self.unitdata.kv().get.assert_called_with('charm.vault.' - 'global-client-cert') + leader_get.assert_called_with('charm.vault.global-client-cert') tls.set_client_cert.assert_called_with('crt', 'key') + @mock.patch.object(handlers, 'leader_get') + @mock.patch.object(handlers, 'leader_set') @mock.patch.object(handlers, 'client_approle_authorized') @mock.patch.object(handlers, 'vault_pki') def test_publish_global_client_cert_reissue( - self, vault_pki, _client_approle_authorized): + self, vault_pki, _client_approle_authorized, leader_set, + leader_get + ): _client_approle_authorized.return_value = True self.config.return_value = { 'default-ttl': '3456h', 'max-ttl': '3456h', } + leader_get.return_value = json.dumps({'certificate': 'stale_cert', + 'private_key': 'stale_key'}) tls = self.endpoint_from_flag.return_value - self.is_flag_set.side_effect = [True, True] + self.is_flag_set.return_value = True bundle = {'certificate': 'crt', 'private_key': 'key'} vault_pki.generate_certificate.return_value = bundle @@ -972,17 +980,21 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): [], '3456h', '3456h') - self.unitdata.kv().set.assert_called_with('charm.vault.' - 'global-client-cert', - bundle) + leader_set.assert_called_with({ + 'charm.vault.global-client-cert': json.dumps(bundle) + }) self.set_flag.assert_called_with('charm.vault.' 'global-client-cert.created') tls.set_client_cert.assert_called_with('crt', 'key') + @mock.patch.object(handlers, 'leader_get') + @mock.patch.object(handlers, 'leader_set') @mock.patch.object(handlers, 'client_approle_authorized') @mock.patch.object(handlers, 'vault_pki') def test_publish_global_client_certe( - self, vault_pki, _client_approle_authorized): + self, vault_pki, _client_approle_authorized, leader_set, + leader_get + ): _client_approle_authorized.return_value = True self.config.return_value = { 'default-ttl': '3456h', @@ -990,7 +1002,8 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): } tls = self.endpoint_from_flag.return_value - self.is_flag_set.side_effect = [False, False] + self.is_flag_set.return_value = False + leader_get.return_value = None bundle = {'certificate': 'crt', 'private_key': 'key'} vault_pki.generate_certificate.return_value = bundle @@ -1000,15 +1013,16 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): [], '3456h', '3456h') - self.unitdata.kv().set.assert_called_with('charm.vault.' - 'global-client-cert', - bundle) + leader_set.assert_called_with( + {'charm.vault.global-client-cert': json.dumps(bundle)} + ) self.set_flag.assert_called_with('charm.vault.' 'global-client-cert.created') tls.set_client_cert.assert_called_with('crt', 'key') @mock.patch.object(handlers, 'vault_pki') def test_create_certs(self, vault_pki): + vault_pki.find_cert_in_cache.return_value = (None, None) self.config.return_value = { 'default-ttl': '3456h', 'max-ttl': '3456h', @@ -1025,12 +1039,18 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): mock.Mock(cert_type='cert_type2', common_name='common_name2', sans='sans2')] + expected_cache_calls = [call(request) for request in tls.new_requests] vault_pki.generate_certificate.side_effect = [ {'certificate': 'crt1', 'private_key': 'key1'}, handlers.vault.VaultInvalidRequest, {'certificate': 'crt2', 'private_key': 'key2'}, ] + expected_cache_update_calls = [ + call(tls.new_requests[0], "crt1", "key1"), + call(tls.new_requests[2], "crt2", "key2"), + ] handlers.create_certs() + vault_pki.find_cert_in_cache.assert_has_calls(expected_cache_calls) vault_pki.generate_certificate.assert_has_calls([ mock.call('cert_type1', 'common_name1', 'sans1', '3456h', '3456h'), @@ -1046,6 +1066,153 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): tls.new_requests[2].set_cert.assert_has_calls([ mock.call('crt2', 'key2'), ]) + vault_pki.update_cert_cache.assert_has_calls( + expected_cache_update_calls + ) + + @mock.patch.object(handlers, 'vault_pki') + def test_create_certs_from_cache(self, vault_pki): + """Serve certificates from cache if they are available.""" + cert_cache = ( + ("common_name1_cert", "common_name1_key"), + ("common_name2_cert", "common_name2_key"), + ) + vault_pki.find_cert_in_cache.side_effect = cert_cache + tls = self.endpoint_from_flag.return_value + self.is_flag_set.return_value = False + tls.new_requests = [mock.Mock(cert_type='cert_type1', + common_name='common_name1', + sans='sans1'), + mock.Mock(cert_type='cert_type2', + common_name='common_name2', + sans='sans2'), + ] + + handlers.create_certs() + + vault_pki.generate_certificate.assert_not_called() + for index, request in enumerate(tls.new_requests): + request.set_cert.assert_called_once_with(*cert_cache[index]) + + @mock.patch.object(handlers, 'vault_pki') + def test_create_certs_reissue(self, vault_pki): + """Test that certificates are not served from cache on reissue. + + Even when certificates are available from cache, they should not + be reused if reissue was requested. + """ + self.config.return_value = { + 'default-ttl': '3456h', + 'max-ttl': '3456h', + } + cert_cache = ( + ("common_name1_cert", "common_name1_key"), + ("common_name2_cert", "common_name2_key"), + ) + new_certs = ( + {"certificate": "cn1_new_cert", "private_key": "cn1_new_key"}, + {"certificate": "cn2_new_cert", "private_key": "cn2_new_key"}, + ) + vault_pki.find_cert_in_cache.side_effect = cert_cache + vault_pki.generate_certificate.side_effect = new_certs + + tls = self.endpoint_from_flag.return_value + self.is_flag_set.return_value = True + tls.all_requests = [mock.Mock(cert_type='cert_type1', + common_name='common_name1', + sans='sans1'), + mock.Mock(cert_type='cert_type2', + common_name='common_name2', + sans='sans2'), + ] + expected_cache_update_calls = ( + call(tls.all_requests[0], + new_certs[0]["certificate"], + new_certs[0]["private_key"]), + call(tls.all_requests[1], + new_certs[1]["certificate"], + new_certs[1]["private_key"]), + ) + + handlers.create_certs() + + vault_pki.generate_certificate.assert_has_calls([ + mock.call('cert_type1', 'common_name1', 'sans1', + '3456h', '3456h'), + mock.call('cert_type2', 'common_name2', 'sans2', + '3456h', '3456h') + ]) + + for index, request in enumerate(tls.new_requests): + request.set_cert.assert_called_once_with( + new_certs[index]["certificate"], + new_certs[index]["private_key"], + ) + vault_pki.update_cert_cache.assert_has_calls( + expected_cache_update_calls + ) + + @mock.patch.object(handlers, 'vault_pki') + @mock.patch.object(handlers, 'remote_unit') + def test_cert_client_leaving(self, remote_unit, vault_pki): + """Test that certificates are removed from cache on unit departure.""" + # This should be performed only on leader unit + self.is_flag_set.return_value = True + unit_name = "client/0" + cache_unit_id = "client_0" + remote_unit.return_value = unit_name + + handlers.cert_client_leaving(mock.MagicMock()) + + vault_pki.remove_unit_from_cache.assert_called_once_with(cache_unit_id) + + # non-leaders should not perform this action + vault_pki.remove_unit_from_cache.reset_mock() + self.is_flag_set.return_value = False + + handlers.cert_client_leaving(mock.MagicMock()) + + vault_pki.remove_unit_from_cache.assert_not_called() + + @mock.patch.object(handlers, 'vault_pki') + @mock.patch.object(handlers, 'leader_get') + def test_sync_cert_from_cache(self, leader_get, vault_pki): + """Test that non-leaders copy data from cache to relations.""" + global_client_bundle = { + "certificate": "Global client cert", + "private_key": "Global client key", + } + leader_get.return_value = json.dumps(global_client_bundle) + + certs_in_cache = ( + ("cn1_cert", "cn1_key"), + ("cn2_cert", "cn2_key"), + ) + vault_pki.find_cert_in_cache.side_effect = certs_in_cache + + self.is_flag_set.return_value = False + tls = self.endpoint_from_flag.return_value + self.is_flag_set.return_value = True + tls.all_requests = [mock.Mock(cert_type='cert_type1', + common_name='common_name1', + sans='sans1'), + mock.Mock(cert_type='cert_type2', + common_name='common_name2', + sans='sans2'), + ] + + handlers.sync_cert_from_cache() + + tls.set_client_cert.assert_called_once_with( + global_client_bundle["certificate"], + global_client_bundle["private_key"], + ) + + for index, request in enumerate(tls.all_requests): + request.set_cert.assert_called_once_with( + certs_in_cache[index][0], + certs_in_cache[index][1], + ) @mock.patch.object(handlers, 'vault_pki') def test_tune_pki_backend(self, vault_pki):