Merge "Implement cert cache for vault units (v3)"

This commit is contained in:
Zuul 2023-01-23 11:13:22 +00:00 committed by Gerrit Code Review
commit 15ab73ea72
5 changed files with 915 additions and 21 deletions

View File

@ -9,6 +9,7 @@ includes:
- interface:hacluster
- interface:vault-kv
- interface:tls-certificates
- interface:vault-ha
options:
basic:
use_venv: True

View File

@ -1,7 +1,11 @@
import hvac
from subprocess import check_output, CalledProcessError
from tempfile import NamedTemporaryFile
import charmhelpers.contrib.network.ip as ch_ip
import charmhelpers.core.hookenv as hookenv
from charms.reactive.relations import endpoint_from_name
from . import vault
@ -9,6 +13,9 @@ CHARM_PKI_MP = "charm-pki-local"
CHARM_PKI_ROLE = "local"
CHARM_PKI_ROLE_CLIENT = "local-client"
PKI_CACHE_KEY = "pki"
TOP_LEVEL_CERT_KEY = "top_level"
def configure_pki_backend(client, name, ttl=None, max_ttl=None):
"""Ensure a pki backend is enabled
@ -370,3 +377,197 @@ def update_roles(**kwargs):
local.update(**kwargs)
del local['server_flag']
write_roles(client, **local)
def verify_cert(ca_cert, untrusted_cert):
"""Verify that the 'untrusted_cert' is signed by the 'ca_cert'.
:param ca_cert: CA certificate that should sign the untrusted cert.
:param untrusted_cert: Certificate that is verified by the CA cert.
:return: True if CA cert can verify the untrusted cert
:rtype: bool
"""
with NamedTemporaryFile() as ca_file, NamedTemporaryFile() as cert_file:
ca_file.write(ca_cert.encode("UTF-8"))
ca_file.flush()
cert_file.write(untrusted_cert.encode("UTF-8"))
cert_file.flush()
try:
verify_cmd = ['openssl', 'verify', '-CAfile',
ca_file.name, cert_file.name]
check_output(verify_cmd)
except CalledProcessError as exc:
hookenv.log(
"Certificate verification failed: {}".format(exc.output),
hookenv.WARNING
)
return False
else:
return True
def get_pki_cache(unit_name):
"""Fetch and parse PKI from the leader storage.
Returned dictionary contains certificates and keys issued by the vault
leader unit as a response to requests from other charms. The structure
loosely matches the format in which the certificates are shared via data
in the `tls-certificates` relation.
See `tls_certificates_common.CertificateRequest.set_cert()` for more info
on the structure.
:return: Dictionary containing certs and keys generated by the leader unit
:rtype: dict
"""
unit_pki_cache_key = "{}_{}".format(PKI_CACHE_KEY, unit_name)
cluster = endpoint_from_name('cluster')
return cluster.get_unit_pki(unit_pki_cache_key)
def find_cert_in_cache(request):
"""Return certificate and key from cache that match the request.
Returned certificate is validated against the current CA cert. If CA cert
is missing, or certificate fails validation or it's simply not found,
returned value is None, None
:param request: Request for certificate from "client" unit.
:type request: tls_certificates_common.CertificateRequest
:return: Certificate and private key from cache
:rtype: Union[(str, str), (None, None)]
"""
try:
ca_chain = get_chain()
except (hvac.exceptions.VaultDown, hvac.exceptions.InvalidPath):
# Fetching CA chain may fail
ca_chain = None
ca_cert = ca_chain or get_ca()
if not ca_cert:
hookenv.log('CA cert not found. Skipping certificate cache lookup.',
hookenv.DEBUG)
return None, None
unit_data = get_pki_cache(request.unit_name)
try:
if request._is_top_level_server_cert:
cert = unit_data[TOP_LEVEL_CERT_KEY][request._server_cert_key]
key = unit_data[TOP_LEVEL_CERT_KEY][request._server_key_key]
else:
cert = unit_data[request._publish_key][request.common_name]['cert']
key = unit_data[request._publish_key][request.common_name]['key']
except (KeyError, TypeError):
hookenv.log('Certificate for "{}" (cn: "{}") not found in '
'cache.'.format(request.unit_name, request.common_name),
hookenv.DEBUG)
return None, None
if verify_cert(ca_cert, cert):
return cert, key
else:
hookenv.log('Certificate from cache for "{}" (cn: "{}") is no longer'
'valid and wont be reused.'.format(request.unit_name,
request.common_name))
return None, None
def update_cert_cache(request, cert, key):
"""Store certificate and key in the cache.
Stored values are associated with the request from "client" unit,
so it can be later retrieved when the request is handled again.
:param request: Request for certificate from "client" unit.
:type request: tls_certificates_common.CertificateRequest
:param cert: Issued certificate for the "client" request (in PEM format)
:type cert: str
:param key: Issued private key from the "client" request (in PEM format)
:type key: str
:return: None
"""
unit_cache = get_pki_cache(request.unit_name)
if request._is_top_level_server_cert:
unit_cache[TOP_LEVEL_CERT_KEY] = {
request._server_cert_key: cert,
request._server_key_key: key,
}
else:
structured_certs = unit_cache.get(request._publish_key, {})
structured_certs[request.common_name] = {
'cert': cert,
'key': key,
}
unit_cache[request._publish_key] = structured_certs
hookenv.log('Saving certificate for "{}" '
'(cn: "{}") into cache.'.format(request.unit_name,
request.common_name),
hookenv.DEBUG)
unit_pki_cache_key = "{}_{}".format(PKI_CACHE_KEY, request.unit_name)
cluster = endpoint_from_name('cluster')
cluster.set_unit_pki(unit_pki_cache_key, unit_cache)
def remove_unit_from_cache(unit_name):
"""Clear certificates and keys related to the unit from the cache.
:param unit_name: Name of the unit to be removed from the cache.
:type unit_name: str
:return: None
"""
hookenv.log('Removing certificates for unit "{}" from '
'cache.'.format(unit_name), hookenv.DEBUG)
unit_pki_cache_key = "{}_{}".format(PKI_CACHE_KEY, unit_name)
cluster = endpoint_from_name('cluster')
cluster.set_unit_pki(unit_pki_cache_key, None)
def populate_cert_cache(tls_endpoint):
"""Store previously issued certificates in the cache.
This function is used when vault charm is upgraded from older version
that may not have a certificate cache to a version that has it. It
goes through all previously issued certificates and stores them in
cache.
:param tls_endpoint: Endpoint of "certificates" relation
:type tls_endpoint: interface_tls_certificates.provides.TlsProvides
:return: None
"""
hookenv.log(
"Populating certificate cache with data from relations", hookenv.INFO
)
for request in tls_endpoint.all_requests:
try:
if request._is_top_level_server_cert:
relation_data = request._unit.relation.to_publish_raw
cert = relation_data[request._server_cert_key]
key = relation_data[request._server_key_key]
else:
relation_data = request._unit.relation.to_publish
cert = relation_data[request._publish_key][
request.common_name
]['cert']
key = relation_data[request._publish_key][
request.common_name
]['key']
except (KeyError, TypeError):
if request._is_top_level_server_cert:
cert_id = request._server_cert_key
else:
cert_id = request.common_name
hookenv.log(
'Certificate "{}" (or associated key) issued for unit "{}" '
'not found in relation data.'.format(
cert_id, request._unit.unit_name
),
hookenv.WARNING
)
continue
update_cert_cache(request, cert, key)

View File

@ -38,6 +38,7 @@ from charmhelpers.core.hookenv import (
log,
network_get_primary_address,
open_port,
remote_unit,
status_set,
unit_private_ip,
)
@ -342,6 +343,11 @@ def upgrade_charm():
remove_state('vault.nrpe.configured')
remove_state('vault.ssl.configured')
remove_state('vault.requested-lb')
# mkalcok: When upgrading from version of a charm that did not have a
# certificate cache, we need to populate the cache with already issued
# certificates. Otherwise the non-leader units would not be able to sync
# their certificate data via cache.
set_flag('needs-cert-cache-repopulation')
@when_not("is-update-status-hook")
@ -1083,23 +1089,104 @@ def publish_global_client_cert():
log("Vault not authorized: Skipping publish_global_client_cert",
"WARNING")
return
cert_created = is_flag_set('charm.vault.global-client-cert.created')
reissue_requested = is_flag_set('certificates.reissue.global.requested')
tls = endpoint_from_flag('certificates.available')
if not cert_created or reissue_requested:
cluster = endpoint_from_name('cluster')
bundle = cluster.get_global_client_cert()
certificate_present = "certificate" in bundle and "private_key" in bundle
if not certificate_present or reissue_requested:
ttl = config()['default-ttl']
max_ttl = config()['max-ttl']
bundle = vault_pki.generate_certificate('client',
'global-client',
[], ttl, max_ttl)
unitdata.kv().set('charm.vault.global-client-cert', bundle)
cluster.set_global_client_cert(bundle)
set_flag('charm.vault.global-client-cert.created')
clear_flag('certificates.reissue.global.requested')
else:
bundle = unitdata.kv().get('charm.vault.global-client-cert')
tls.set_client_cert(bundle['certificate'], bundle['private_key'])
@when('certificates.available')
@when('charm.vault.ca.ready')
@when('leadership.is_leader')
@when('needs-cert-cache-repopulation')
def repopulate_cert_cache():
"""Force repopulation of cert cache on the leader.
Certain circumstances such as 'upgrade-charm' hook should force the leader
to populate the cert cache, so then non-leaders will follow in the
'sync_cert_from_cache' method."""
tls = endpoint_from_flag('certificates.available')
if tls:
vault_pki.populate_cert_cache(tls)
clear_flag('needs-cert-cache-repopulation')
@when("certificates.available")
@when_not('leadership.is_leader')
def sync_cert_from_cache():
"""Sync cert and key data in the tls-certificate relation.
Non-leader units should keep the relation data up-to-date according
to the data from PKI cache that's maintained by the leader. This ensures
that "client" units can use data from any of the related vault units to
receive valid keys and certificates.
"""
tls = endpoint_from_flag('certificates.available')
cert_requests = tls.all_requests
ca = vault_pki.get_ca()
if not ca:
# Don't bother syncing now if we are in a state
# which we don't have a CA, defer syncing to later
return
# propagate CA cert
tls.set_ca(ca)
try:
# this might fail if we were restarted and need to be unsealed
chain = vault_pki.get_chain()
except (
vault.hvac.exceptions.VaultDown,
vault.hvac.exceptions.InvalidPath,
):
pass
except vault.VaultNotReady:
# With Vault not being ready, there's no sense in continuing
return
except vault.hvac.exceptions.InternalServerError:
# We either cannot communicate with Vault or
# parse a CA/Chain in this state, defer syncing to later
return
else:
tls.set_chain(chain)
# propagate global client cert from cache
cluster = endpoint_from_name('cluster')
bundle = cluster.get_global_client_cert()
if bundle.get('certificate') and bundle.get('private_key'):
tls.set_client_cert(bundle['certificate'], bundle['private_key'])
# update certificate data in relations
for request in cert_requests:
cache_cert, cache_key = vault_pki.find_cert_in_cache(request)
if cache_cert and cache_key:
request.set_cert(cache_cert, cache_key)
@hook('certificates-relation-departed')
def cert_client_leaving(relation):
"""Remove certs and keys of the departing unit from cache."""
if is_flag_set('leadership.is_leader'):
# mkalcok: Due to certificates requests replacing "/" in the unit
# name with "_" (see: tls_certificates_common.CertificateRequest),
# we must emulate the same behavior when removing unit certs from
# cache.
departing_unit = remote_unit()
log("Removing certificates for {} from cache.".format(departing_unit))
unit_name = departing_unit.replace('/', '_')
vault_pki.remove_unit_from_cache(unit_name)
@when_not("is-update-status-hook")
@when('leadership.is_leader',
'charm.vault.ca.ready',
@ -1127,6 +1214,16 @@ def create_certs():
processed_applications.append(request.application_name)
else:
cert_type = request.cert_type
cache_cert, cache_key = vault_pki.find_cert_in_cache(request)
if not reissue_requested and cache_cert and cache_key:
# If valid certificates are in cache, and re-issue was not
# requested, reuse them.
log("Reusing certificate for unit '{}' and CN '{}' from "
"cache.".format(request.unit_name, request.common_name))
request.set_cert(cache_cert, cache_key)
continue
try:
ttl = config()['default-ttl']
max_ttl = config()['max-ttl']
@ -1134,6 +1231,9 @@ def create_certs():
request.common_name,
request.sans, ttl, max_ttl)
request.set_cert(bundle['certificate'], bundle['private_key'])
vault_pki.update_cert_cache(request,
bundle["certificate"],
bundle["private_key"])
except vault.VaultInvalidRequest as e:
log(str(e), level=ERROR)
continue # TODO: report failure back to client

View File

@ -1,5 +1,5 @@
from unittest import mock
from unittest.mock import patch
from unittest.mock import call, patch, MagicMock
import hvac
@ -12,7 +12,9 @@ class TestLibCharmVaultPKI(unit_tests.test_utils.CharmTestCase):
def setUp(self):
super(TestLibCharmVaultPKI, self).setUp()
self.obj = vault_pki
self.patches = []
self.patches = [
'endpoint_from_name',
]
self.patch_all()
@patch.object(vault_pki.vault, 'is_backend_mounted')
@ -459,3 +461,339 @@ class TestLibCharmVaultPKI(unit_tests.test_utils.CharmTestCase):
client_flag=True)
),
])
def test_get_pki_cache(self):
"""Test retrieving PKI from cache."""
expected_pki = {
vault_pki.TOP_LEVEL_CERT_KEY: {
"client_unit_0.server.cert": "cert_data",
"client_unit_0.server.key": "key_data",
}
}
cluster_relation = MagicMock()
self.endpoint_from_name.return_value = cluster_relation
cluster_relation.get_unit_pki.return_value = expected_pki
pki = vault_pki.get_pki_cache('client_unit_0')
cluster_relation.get_unit_pki.assert_called_once_with(
'pki_client_unit_0')
self.assertEqual(pki, expected_pki)
# test retrieval if the PKI is not set
cluster_relation.get_unit_pki.return_value = {}
cluster_relation.get_unit_pki.reset_mock()
pki = vault_pki.get_pki_cache('client_unit_0')
cluster_relation.get_unit_pki.assert_called_once_with(
'pki_client_unit_0')
self.assertEqual(pki, {})
@patch.object(vault_pki, 'get_pki_cache')
@patch.object(vault_pki, 'get_chain')
@patch.object(vault_pki, 'get_ca')
def test_find_cert_in_cache_no_ca(self, get_ca, get_chain, get_pki_cache):
"""Test getting cert from cache when CA is missing."""
get_ca.return_value = None
get_chain.return_value = None
cert, key = vault_pki.find_cert_in_cache(MagicMock())
# assert that CA cert or chain was retrieved
get_ca.assert_called_once_with()
get_chain.assert_called_once_with()
# assert that function does not proceed due to the missing CA
get_pki_cache.assert_not_called()
self.assertIsNone(cert)
self.assertIsNone(key)
@patch.object(vault_pki, 'verify_cert')
@patch.object(vault_pki, 'get_pki_cache')
@patch.object(vault_pki, 'get_chain')
@patch.object(vault_pki, 'get_ca')
def test_find_cert_in_cache_missing(self, get_ca, get_chain,
get_pki_cache, verify_cache):
"""Test use case when searched certificate is not in cache."""
request = MagicMock()
request.unit_name = "client_unit_0"
request._is_top_level_server_cert = True
get_ca.return_value = MagicMock()
get_pki_cache.return_value = {}
cert, key = vault_pki.find_cert_in_cache(request)
# assert that verification of cert is not attempted when
# cert is not found
verify_cache.assert_not_called()
self.assertIsNone(cert)
self.assertIsNone(key)
# Same scenario, but with non-top-level certificate
request._is_top_level_server_cert = False
cert, key = vault_pki.find_cert_in_cache(request)
verify_cache.assert_not_called()
self.assertIsNone(cert)
self.assertIsNone(key)
@patch.object(vault_pki, 'get_pki_cache')
@patch.object(vault_pki, 'get_chain')
@patch.object(vault_pki, 'get_ca')
def test_find_cert_in_cache_err(self, get_ca, get_chain, get_pki_cache):
"""Test getting cert from cache when CA is missing."""
get_ca.return_value = None
get_chain.side_effect = hvac.exceptions.InvalidPath
cert, key = vault_pki.find_cert_in_cache(MagicMock())
# assert that CA cert or chain was retrieved
get_ca.assert_called_once_with()
get_chain.assert_called_once_with()
# assert that function does not proceed due to the missing CA
get_pki_cache.assert_not_called()
self.assertIsNone(cert)
self.assertIsNone(key)
@patch.object(vault_pki, 'verify_cert')
@patch.object(vault_pki, 'get_pki_cache')
@patch.object(vault_pki, 'get_chain')
@patch.object(vault_pki, 'get_ca')
def test_find_cert_in_cache_top_level(self, get_ca, get_chain,
get_pki_cache, verify_cache):
"""Test fetching top level cert from cache.
Additional test scenario: Test that nothing is returned if cert fails
CA verification.
"""
ca_cert = "CA cert data"
expected_cert = "cert data"
expected_key = "key data"
cert_name = "server.cert"
key_name = "server.key"
client_name = "client_unit_0"
# setup cert request
request = MagicMock()
request.unit_name = client_name
request._is_top_level_server_cert = True
request._server_cert_key = cert_name
request._server_key_key = key_name
# PKI cache content
pki = {
vault_pki.TOP_LEVEL_CERT_KEY: {
cert_name: expected_cert,
key_name: expected_key
}
}
get_ca.return_value = ca_cert
get_chain.return_value = ca_cert
get_pki_cache.return_value = pki
verify_cache.return_value = True
cert, key = vault_pki.find_cert_in_cache(request)
verify_cache.assert_called_once_with(ca_cert, expected_cert)
self.assertEqual(cert, expected_cert)
self.assertEqual(key, expected_key)
# Additional test: Nothing should be returned if cert failed
# CA verification.
verify_cache.reset_mock()
verify_cache.return_value = False
cert, key = vault_pki.find_cert_in_cache(request)
verify_cache.assert_called_once_with(ca_cert, expected_cert)
self.assertIsNone(cert)
self.assertIsNone(key)
@patch.object(vault_pki, 'verify_cert')
@patch.object(vault_pki, 'get_pki_cache')
@patch.object(vault_pki, 'get_chain')
@patch.object(vault_pki, 'get_ca')
def test_find_cert_in_cache_not_top_level(self, get_ca, get_chain,
get_pki_cache, verify_cache):
"""Test fetching non-top level cert from cache.
Additional test scenario: Test that nothing is returned if cert fails
CA verification.
"""
ca_cert = "CA cert data"
expected_cert = "cert data"
expected_key = "key data"
client_name = "client_unit_0"
publish_key = client_name + ".processed_client_requests"
common_name = "client.0"
# setup cert request
request = MagicMock()
request.unit_name = client_name
request._is_top_level_server_cert = False
request._publish_key = publish_key
request.common_name = common_name
# PKI cache content
pki = {
publish_key: {
common_name: {
"cert": expected_cert,
"key": expected_key,
}
}
}
get_ca.return_value = ca_cert
get_chain.return_value = ca_cert
get_pki_cache.return_value = pki
verify_cache.return_value = True
cert, key = vault_pki.find_cert_in_cache(request)
verify_cache.assert_called_once_with(ca_cert, expected_cert)
self.assertEqual(cert, expected_cert)
self.assertEqual(key, expected_key)
# Additional test: Nothing should be returned if cert failed
# CA verification.
verify_cache.reset_mock()
verify_cache.return_value = False
cert, key = vault_pki.find_cert_in_cache(request)
verify_cache.assert_called_once_with(ca_cert, expected_cert)
self.assertIsNone(cert)
self.assertIsNone(key)
@patch.object(vault_pki, 'get_pki_cache')
def test_update_cert_cache_top_level_cert(self, get_pki_cache):
"""Test storing top-level cert in cache."""
cert_data = "cert data"
key_data = "key data"
cert_name = "server.cert"
key_name = "server.key"
client_name = "client_unit_0"
# setup cert request
request = MagicMock()
request.unit_name = client_name
request.common_name = client_name
request._is_top_level_server_cert = True
request._server_cert_key = cert_name
request._server_key_key = key_name
cluster_relation = MagicMock()
self.endpoint_from_name.return_value = cluster_relation
# PKI structure
initial_pki = {}
expected_pki = {
vault_pki.TOP_LEVEL_CERT_KEY: {
cert_name: cert_data,
key_name: key_data
}
}
get_pki_cache.return_value = initial_pki
vault_pki.update_cert_cache(request, cert_data, key_data)
key = "{}_{}".format(vault_pki.PKI_CACHE_KEY, client_name)
cluster_relation.set_unit_pki.assert_called_once_with(
key, expected_pki)
@patch.object(vault_pki, 'get_pki_cache')
def test_update_cert_cache_non_top_level_cert(self, get_pki_cache):
"""Test storing non-top-level cert in cache."""
cert_data = "cert data"
key_data = "key data"
client_name = "client_unit_0"
publish_key = client_name + ".processed_client_requests"
common_name = "client.0"
cluster_relation = MagicMock()
self.endpoint_from_name.return_value = cluster_relation
# setup cert request
request = MagicMock()
request.unit_name = client_name
request._is_top_level_server_cert = False
request._publish_key = publish_key
request.common_name = common_name
# PKI structure
initial_pki = {}
expected_pki = {
publish_key: {
common_name: {
"cert": cert_data,
"key": key_data,
}
}
}
get_pki_cache.return_value = initial_pki
vault_pki.update_cert_cache(request, cert_data, key_data)
key = "{}_{}".format(vault_pki.PKI_CACHE_KEY, client_name)
cluster_relation.set_unit_pki.assert_called_once_with(
key, expected_pki)
def test_remove_unit_from_cache(self):
"""Test removing unit certificates from cache."""
cluster_relation = MagicMock()
self.endpoint_from_name.return_value = cluster_relation
vault_pki.remove_unit_from_cache('client_0')
key = "{}_{}".format(vault_pki.PKI_CACHE_KEY, 'client_0')
cluster_relation.set_unit_pki.assert_called_once_with(key, None)
@patch.object(vault_pki, 'update_cert_cache')
def test_populate_cert_cache(self, update_cert_cache):
# Define data for top level certificate and key
top_level_cert_name = "server.crt"
top_level_key_name = "server.key"
top_level_cert_data = "top level cert"
top_level_key_data = "top level key"
# Define data for non-top level certificate
processed_request_cn = "juju_unit_service.crt"
processed_request_publish_key = "juju_unit_service.processed"
processed_cert_data = "processed cert"
processed_key_data = "processed key"
# Mock request for top level certificate
top_level_request = MagicMock()
top_level_request._is_top_level_server_cert = True
top_level_request._server_cert_key = top_level_cert_name
top_level_request._server_key_key = top_level_key_name
top_level_request._unit.relation.to_publish_raw = {
top_level_cert_name: top_level_cert_data,
top_level_key_name: top_level_key_data,
}
# Mock request for non-top level certificate
processed_request = MagicMock()
processed_request._is_top_level_server_cert = False
processed_request.common_name = processed_request_cn
processed_request._publish_key = processed_request_publish_key
processed_request._unit.relation.to_publish = {
processed_request_publish_key: {processed_request_cn: {
"cert": processed_cert_data,
"key": processed_key_data
}}
}
tls_endpoint = MagicMock()
tls_endpoint.all_requests = [top_level_request, processed_request]
vault_pki.populate_cert_cache(tls_endpoint)
expected_update_calls = [
call(top_level_request, top_level_cert_data, top_level_key_data),
call(processed_request, processed_cert_data, processed_key_data),
]
update_cert_cache.assert_has_calls(expected_update_calls)

View File

@ -2,6 +2,7 @@ from unittest import mock
from unittest.mock import patch, call
import charms.reactive
import hvac
# Mock out reactive decorators prior to importing reactive.vault
dec_mock = mock.MagicMock()
@ -247,6 +248,16 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase):
mock.call('vault.ssl.configured')]
handlers.upgrade_charm()
self.remove_state.assert_has_calls(calls)
self.set_flag.assert_called_once_with(
'needs-cert-cache-repopulation')
@mock.patch.object(handlers, 'vault_pki')
def test_repopulate_cert_cache(self, mock_vault_pki):
handlers.repopulate_cert_cache()
mock_vault_pki.populate_cert_cache.assert_called_once_with(
self.endpoint_from_flag.return_value)
self.clear_flag.assert_called_once_with(
'needs-cert-cache-repopulation')
def test_request_db(self):
psql = mock.MagicMock()
@ -979,14 +990,18 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase):
self, vault_pki, _client_approle_authorized):
_client_approle_authorized.return_value = True
tls = self.endpoint_from_flag.return_value
self.is_flag_set.side_effect = [True, False]
self.unitdata.kv().get.return_value = {'certificate': 'crt',
'private_key': 'key'}
self.is_flag_set.return_value = False
cluster_relation = mock.MagicMock()
self.endpoint_from_name.return_value = cluster_relation
cluster_relation.get_global_client_cert.return_value = {
'certificate': 'crt',
'private_key': 'key'
}
handlers.publish_global_client_cert()
assert not vault_pki.generate_certificate.called
assert not self.set_flag.called
self.unitdata.kv().get.assert_called_with('charm.vault.'
'global-client-cert')
cluster_relation.get_global_client_cert.assert_called_with()
tls.set_client_cert.assert_called_with('crt', 'key')
@mock.patch.object(handlers, 'client_approle_authorized')
@ -999,9 +1014,16 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase):
'max-ttl': '3456h',
}
tls = self.endpoint_from_flag.return_value
cluster_relation = mock.MagicMock()
self.endpoint_from_name.return_value = cluster_relation
self.is_flag_set.side_effect = [True, True]
cluster_relation.get_global_client_cert.return_value = {
'certificate': 'stale_cert',
'private_key': 'stale_key'
}
tls = self.endpoint_from_flag.return_value
self.is_flag_set.return_value = True
bundle = {'certificate': 'crt',
'private_key': 'key'}
vault_pki.generate_certificate.return_value = bundle
@ -1011,9 +1033,7 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase):
[],
'3456h',
'3456h')
self.unitdata.kv().set.assert_called_with('charm.vault.'
'global-client-cert',
bundle)
cluster_relation.set_global_client_cert.assert_called_with(bundle)
self.set_flag.assert_called_with('charm.vault.'
'global-client-cert.created')
tls.set_client_cert.assert_called_with('crt', 'key')
@ -1028,8 +1048,13 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase):
'max-ttl': '3456h',
}
cluster_relation = mock.MagicMock()
self.endpoint_from_name.return_value = cluster_relation
cluster_relation.get_global_client_cert.return_value = {}
tls = self.endpoint_from_flag.return_value
self.is_flag_set.side_effect = [False, False]
self.is_flag_set.return_value = False
bundle = {'certificate': 'crt',
'private_key': 'key'}
vault_pki.generate_certificate.return_value = bundle
@ -1039,15 +1064,15 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase):
[],
'3456h',
'3456h')
self.unitdata.kv().set.assert_called_with('charm.vault.'
'global-client-cert',
bundle)
cluster_relation.set_global_client_cert.assert_called_with(
bundle)
self.set_flag.assert_called_with('charm.vault.'
'global-client-cert.created')
tls.set_client_cert.assert_called_with('crt', 'key')
@mock.patch.object(handlers, 'vault_pki')
def test_create_certs(self, vault_pki):
vault_pki.find_cert_in_cache.return_value = (None, None)
self.config.return_value = {
'default-ttl': '3456h',
'max-ttl': '3456h',
@ -1064,12 +1089,18 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase):
mock.Mock(cert_type='cert_type2',
common_name='common_name2',
sans='sans2')]
expected_cache_calls = [call(request) for request in tls.new_requests]
vault_pki.generate_certificate.side_effect = [
{'certificate': 'crt1', 'private_key': 'key1'},
handlers.vault.VaultInvalidRequest,
{'certificate': 'crt2', 'private_key': 'key2'},
]
expected_cache_update_calls = [
call(tls.new_requests[0], "crt1", "key1"),
call(tls.new_requests[2], "crt2", "key2"),
]
handlers.create_certs()
vault_pki.find_cert_in_cache.assert_has_calls(expected_cache_calls)
vault_pki.generate_certificate.assert_has_calls([
mock.call('cert_type1', 'common_name1', 'sans1',
'3456h', '3456h'),
@ -1085,6 +1116,229 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase):
tls.new_requests[2].set_cert.assert_has_calls([
mock.call('crt2', 'key2'),
])
vault_pki.update_cert_cache.assert_has_calls(
expected_cache_update_calls
)
@mock.patch.object(handlers, 'vault_pki')
def test_create_certs_from_cache(self, vault_pki):
"""Serve certificates from cache if they are available."""
cert_cache = (
("common_name1_cert", "common_name1_key"),
("common_name2_cert", "common_name2_key"),
)
vault_pki.find_cert_in_cache.side_effect = cert_cache
tls = self.endpoint_from_flag.return_value
self.is_flag_set.return_value = False
tls.new_requests = [mock.Mock(cert_type='cert_type1',
common_name='common_name1',
sans='sans1'),
mock.Mock(cert_type='cert_type2',
common_name='common_name2',
sans='sans2'),
]
handlers.create_certs()
vault_pki.generate_certificate.assert_not_called()
for index, request in enumerate(tls.new_requests):
request.set_cert.assert_called_once_with(*cert_cache[index])
@mock.patch.object(handlers, 'vault_pki')
def test_create_certs_reissue(self, vault_pki):
"""Test that certificates are not served from cache on reissue.
Even when certificates are available from cache, they should not
be reused if reissue was requested.
"""
self.config.return_value = {
'default-ttl': '3456h',
'max-ttl': '3456h',
}
cert_cache = (
("common_name1_cert", "common_name1_key"),
("common_name2_cert", "common_name2_key"),
)
new_certs = (
{"certificate": "cn1_new_cert", "private_key": "cn1_new_key"},
{"certificate": "cn2_new_cert", "private_key": "cn2_new_key"},
)
vault_pki.find_cert_in_cache.side_effect = cert_cache
vault_pki.generate_certificate.side_effect = new_certs
tls = self.endpoint_from_flag.return_value
self.is_flag_set.return_value = True
tls.all_requests = [mock.Mock(cert_type='cert_type1',
common_name='common_name1',
sans='sans1'),
mock.Mock(cert_type='cert_type2',
common_name='common_name2',
sans='sans2'),
]
expected_cache_update_calls = (
call(tls.all_requests[0],
new_certs[0]["certificate"],
new_certs[0]["private_key"]),
call(tls.all_requests[1],
new_certs[1]["certificate"],
new_certs[1]["private_key"]),
)
handlers.create_certs()
vault_pki.generate_certificate.assert_has_calls([
mock.call('cert_type1', 'common_name1', 'sans1',
'3456h', '3456h'),
mock.call('cert_type2', 'common_name2', 'sans2',
'3456h', '3456h')
])
for index, request in enumerate(tls.new_requests):
request.set_cert.assert_called_once_with(
new_certs[index]["certificate"],
new_certs[index]["private_key"],
)
vault_pki.update_cert_cache.assert_has_calls(
expected_cache_update_calls
)
@mock.patch.object(handlers, 'vault_pki')
@mock.patch.object(handlers, 'remote_unit')
def test_cert_client_leaving(self, remote_unit, vault_pki):
"""Test that certificates are removed from cache on unit departure."""
# This should be performed only on leader unit
self.is_flag_set.return_value = True
unit_name = "client/0"
cache_unit_id = "client_0"
remote_unit.return_value = unit_name
handlers.cert_client_leaving(mock.MagicMock())
vault_pki.remove_unit_from_cache.assert_called_once_with(cache_unit_id)
# non-leaders should not perform this action
vault_pki.remove_unit_from_cache.reset_mock()
self.is_flag_set.return_value = False
handlers.cert_client_leaving(mock.MagicMock())
vault_pki.remove_unit_from_cache.assert_not_called()
@mock.patch.object(handlers, 'vault_pki')
def test_sync_cert_from_cache(self, vault_pki):
"""Test that non-leaders copy data from cache to relations."""
global_client_bundle = {
"certificate": "Global client cert",
"private_key": "Global client key",
}
cluster_relation = mock.MagicMock()
self.endpoint_from_name.return_value = cluster_relation
cluster_relation.get_global_client_cert.return_value = (
global_client_bundle
)
certs_in_cache = (
("cn1_cert", "cn1_key"),
("cn2_cert", "cn2_key"),
)
vault_pki.find_cert_in_cache.side_effect = certs_in_cache
self.is_flag_set.return_value = False
tls = self.endpoint_from_flag.return_value
self.is_flag_set.return_value = True
tls.all_requests = [mock.Mock(cert_type='cert_type1',
common_name='common_name1',
sans='sans1'),
mock.Mock(cert_type='cert_type2',
common_name='common_name2',
sans='sans2'),
]
handlers.sync_cert_from_cache()
tls.set_client_cert.assert_called_once_with(
global_client_bundle["certificate"],
global_client_bundle["private_key"],
)
for index, request in enumerate(tls.all_requests):
request.set_cert.assert_called_once_with(
certs_in_cache[index][0],
certs_in_cache[index][1],
)
@mock.patch.object(handlers, 'vault_pki')
def test_sync_cert_from_cache_no_ca(self, vault_pki):
"""Test that non-leaders copy data from cache to relations."""
vault_pki.get_ca.return_value = None
handlers.sync_cert_from_cache()
vault_pki.get_ca.assert_called_once_with()
tls = self.endpoint_from_flag.return_value
tls.set_ca.assert_not_called()
@mock.patch.object(handlers, 'vault_pki')
def test_sync_cert_from_cache_no_chain_err(self, vault_pki):
"""Test that non-leaders copy data from cache to relations."""
vault_pki.get_chain.side_effect = hvac.exceptions.InternalServerError
handlers.sync_cert_from_cache()
vault_pki.get_ca.assert_called_once_with()
tls = self.endpoint_from_flag.return_value
tls.set_ca.assert_called_once_with(vault_pki.get_ca.return_value)
vault_pki.get_chain.assert_called_once_with()
tls.set_chain.assert_not_called()
@mock.patch.object(handlers, 'vault_pki')
@mock.patch.object(handlers, 'leader_get')
def test_sync_cert_from_cache_err(self, leader_get, vault_pki):
"""Test that it gracefully fails if get_chain doesn't succeed."""
global_client_bundle = {
"certificate": "Global client cert",
"private_key": "Global client key",
}
cluster_relation = mock.MagicMock()
self.endpoint_from_name.return_value = cluster_relation
cluster_relation.get_global_client_cert.return_value = (
global_client_bundle
)
certs_in_cache = (
("cn1_cert", "cn1_key"),
("cn2_cert", "cn2_key"),
)
vault_pki.find_cert_in_cache.side_effect = certs_in_cache
vault_pki.get_chain.side_effect = hvac.exceptions.InvalidPath
self.is_flag_set.return_value = False
tls = self.endpoint_from_flag.return_value
self.is_flag_set.return_value = True
tls.set_chain.assert_not_called()
tls.all_requests = [mock.Mock(cert_type='cert_type1',
common_name='common_name1',
sans='sans1'),
mock.Mock(cert_type='cert_type2',
common_name='common_name2',
sans='sans2'),
]
handlers.sync_cert_from_cache()
tls.set_client_cert.assert_called_once_with(
global_client_bundle["certificate"],
global_client_bundle["private_key"],
)
for index, request in enumerate(tls.all_requests):
request.set_cert.assert_called_once_with(
certs_in_cache[index][0],
certs_in_cache[index][1],
)
@mock.patch.object(handlers, 'vault_pki')
def test_tune_pki_backend(self, vault_pki):