diff --git a/octavia/api/drivers/utils.py b/octavia/api/drivers/utils.py index b4a279976c..e7c717fd60 100644 --- a/octavia/api/drivers/utils.py +++ b/octavia/api/drivers/utils.py @@ -21,6 +21,7 @@ from octavia_lib.api.drivers import exceptions as lib_exceptions from oslo_config import cfg from oslo_context import context as oslo_context from oslo_log import log as logging +from oslo_utils import excutils from stevedore import driver as stevedore_driver from octavia.api.drivers import exceptions as driver_exceptions @@ -124,8 +125,8 @@ def _base_to_provider_dict(current_dict, include_project_id=False): # Note: The provider dict returned from this method will have provider # data model objects in it. -def lb_dict_to_provider_dict(lb_dict, vip=None, - db_pools=None, db_listeners=None): +def lb_dict_to_provider_dict(lb_dict, vip=None, db_pools=None, + db_listeners=None, for_delete=False): new_lb_dict = _base_to_provider_dict(lb_dict, include_project_id=True) new_lb_dict['loadbalancer_id'] = new_lb_dict.pop('id') if vip: @@ -139,19 +140,21 @@ def lb_dict_to_provider_dict(lb_dict, vip=None, new_lb_dict['flavor'] = flavor_repo.get_flavor_metadata_dict( db_api.get_session(), lb_dict['flavor_id']) if db_pools: - new_lb_dict['pools'] = db_pools_to_provider_pools(db_pools) + new_lb_dict['pools'] = db_pools_to_provider_pools( + db_pools, for_delete=for_delete) if db_listeners: new_lb_dict['listeners'] = db_listeners_to_provider_listeners( - db_listeners) + db_listeners, for_delete=for_delete) return new_lb_dict -def db_loadbalancer_to_provider_loadbalancer(db_loadbalancer): +def db_loadbalancer_to_provider_loadbalancer(db_loadbalancer, + for_delete=False): new_loadbalancer_dict = lb_dict_to_provider_dict( db_loadbalancer.to_dict(recurse=True), vip=db_loadbalancer.vip, db_pools=db_loadbalancer.pools, - db_listeners=db_loadbalancer.listeners) + db_listeners=db_loadbalancer.listeners, for_delete=for_delete) for unsupported_field in ['server_group_id', 'amphorae', 'vrrp_group', 'topology', 'vip']: if unsupported_field in new_loadbalancer_dict: @@ -161,20 +164,22 @@ def db_loadbalancer_to_provider_loadbalancer(db_loadbalancer): return provider_loadbalancer -def db_listeners_to_provider_listeners(db_listeners): +def db_listeners_to_provider_listeners(db_listeners, for_delete=False): provider_listeners = [] for listener in db_listeners: - provider_listener = db_listener_to_provider_listener(listener) + provider_listener = db_listener_to_provider_listener( + listener, for_delete=for_delete) provider_listeners.append(provider_listener) return provider_listeners -def db_listener_to_provider_listener(db_listener): +def db_listener_to_provider_listener(db_listener, for_delete=False): new_listener_dict = listener_dict_to_provider_dict( - db_listener.to_dict(recurse=True)) + db_listener.to_dict(recurse=True), for_delete=for_delete) if ('default_pool' in new_listener_dict and new_listener_dict['default_pool']): - provider_pool = db_pool_to_provider_pool(db_listener.default_pool) + provider_pool = db_pool_to_provider_pool(db_listener.default_pool, + for_delete=for_delete) new_listener_dict['default_pool_id'] = provider_pool.pool_id new_listener_dict['default_pool'] = provider_pool if new_listener_dict.get('l7policies', None): @@ -184,16 +189,25 @@ def db_listener_to_provider_listener(db_listener): return provider_listener -def _get_secret_data(cert_manager, project_id, secret_ref): +def _get_secret_data(cert_manager, project_id, secret_ref, for_delete=False): """Get the secret from the certificate manager and upload it to the amp. :returns: The secret data. """ context = oslo_context.RequestContext(project_id=project_id) - return cert_manager.get_secret(context, secret_ref) + try: + secret_data = cert_manager.get_secret(context, secret_ref) + except Exception as e: + LOG.warning('Unable to retrieve certificate: %s due to %s.', + secret_ref, str(e)) + if for_delete: + secret_data = None + else: + raise exceptions.CertificateRetrievalException(ref=secret_ref) + return secret_data -def listener_dict_to_provider_dict(listener_dict): +def listener_dict_to_provider_dict(listener_dict, for_delete=False): new_listener_dict = _base_to_provider_dict(listener_dict, include_project_id=True) new_listener_dict['listener_id'] = new_listener_dict.pop('id') @@ -246,8 +260,16 @@ def listener_dict_to_provider_dict(listener_dict): name=CONF.certificates.cert_manager, invoke_on_load=True, ).driver - cert_dict = cert_parser.load_certificates_data(cert_manager, - listener_obj) + try: + cert_dict = cert_parser.load_certificates_data(cert_manager, + listener_obj) + except Exception as e: + with excutils.save_and_reraise_exception() as ctxt: + LOG.warning('Unable to retrieve certificate(s) due to %s.', + str(e)) + if for_delete: + ctxt.reraise = False + cert_dict = {} if 'tls_cert' in cert_dict and cert_dict['tls_cert']: new_listener_dict['default_tls_container_data'] = ( cert_dict['tls_cert'].to_dict(recurse=True)) @@ -287,7 +309,8 @@ def listener_dict_to_provider_dict(listener_dict): if ('default_pool' in new_listener_dict and new_listener_dict['default_pool']): pool = new_listener_dict.pop('default_pool') - new_listener_dict['default_pool'] = pool_dict_to_provider_dict(pool) + new_listener_dict['default_pool'] = pool_dict_to_provider_dict( + pool, for_delete=for_delete) provider_l7policies = [] if 'l7policies' in new_listener_dict: l7policies = new_listener_dict.pop('l7policies') or [] @@ -298,15 +321,17 @@ def listener_dict_to_provider_dict(listener_dict): return new_listener_dict -def db_pools_to_provider_pools(db_pools): +def db_pools_to_provider_pools(db_pools, for_delete=False): provider_pools = [] for pool in db_pools: - provider_pools.append(db_pool_to_provider_pool(pool)) + provider_pools.append(db_pool_to_provider_pool(pool, + for_delete=for_delete)) return provider_pools -def db_pool_to_provider_pool(db_pool): - new_pool_dict = pool_dict_to_provider_dict(db_pool.to_dict(recurse=True)) +def db_pool_to_provider_pool(db_pool, for_delete=False): + new_pool_dict = pool_dict_to_provider_dict(db_pool.to_dict(recurse=True), + for_delete=for_delete) # Replace the sub-dicts with objects if 'health_monitor' in new_pool_dict: del new_pool_dict['health_monitor'] @@ -325,7 +350,7 @@ def db_pool_to_provider_pool(db_pool): return driver_dm.Pool.from_dict(new_pool_dict) -def pool_dict_to_provider_dict(pool_dict): +def pool_dict_to_provider_dict(pool_dict, for_delete=False): new_pool_dict = _base_to_provider_dict(pool_dict, include_project_id=True) new_pool_dict['pool_id'] = new_pool_dict.pop('id') @@ -348,8 +373,16 @@ def pool_dict_to_provider_dict(pool_dict): name=CONF.certificates.cert_manager, invoke_on_load=True, ).driver - cert_dict = cert_parser.load_certificates_data(cert_manager, - pool_obj) + try: + cert_dict = cert_parser.load_certificates_data(cert_manager, + pool_obj) + except Exception as e: + with excutils.save_and_reraise_exception() as ctxt: + LOG.warning('Unable to retrieve certificate(s) due to %s.', + str(e)) + if for_delete: + ctxt.reraise = False + cert_dict = {} if 'tls_cert' in cert_dict and cert_dict['tls_cert']: new_pool_dict['tls_container_data'] = ( cert_dict['tls_cert'].to_dict(recurse=True)) diff --git a/octavia/api/v2/controllers/listener.py b/octavia/api/v2/controllers/listener.py index fc49793ce0..2b81825b7b 100644 --- a/octavia/api/v2/controllers/listener.py +++ b/octavia/api/v2/controllers/listener.py @@ -592,7 +592,8 @@ class ListenersController(base.BaseController): LOG.info("Sending delete Listener %s to provider %s", id, driver.name) provider_listener = ( - driver_utils.db_listener_to_provider_listener(db_listener)) + driver_utils.db_listener_to_provider_listener( + db_listener, for_delete=True)) driver_utils.call_provider(driver.name, driver.listener_delete, provider_listener) diff --git a/octavia/api/v2/controllers/load_balancer.py b/octavia/api/v2/controllers/load_balancer.py index 210240aefe..4c73170411 100644 --- a/octavia/api/v2/controllers/load_balancer.py +++ b/octavia/api/v2/controllers/load_balancer.py @@ -638,7 +638,8 @@ class LoadBalancersController(base.BaseController): LOG.info("Sending delete Load Balancer %s to provider %s", id, driver.name) provider_loadbalancer = ( - driver_utils.db_loadbalancer_to_provider_loadbalancer(db_lb)) + driver_utils.db_loadbalancer_to_provider_loadbalancer( + db_lb, for_delete=True)) driver_utils.call_provider(driver.name, driver.loadbalancer_delete, provider_loadbalancer, cascade) diff --git a/octavia/api/v2/controllers/pool.py b/octavia/api/v2/controllers/pool.py index 701fcf705b..92aee47c25 100644 --- a/octavia/api/v2/controllers/pool.py +++ b/octavia/api/v2/controllers/pool.py @@ -452,7 +452,8 @@ class PoolsController(base.BaseController): LOG.info("Sending delete Pool %s to provider %s", id, driver.name) provider_pool = ( - driver_utils.db_pool_to_provider_pool(db_pool)) + driver_utils.db_pool_to_provider_pool(db_pool, + for_delete=True)) driver_utils.call_provider(driver.name, driver.pool_delete, provider_pool) diff --git a/octavia/common/tls_utils/cert_parser.py b/octavia/common/tls_utils/cert_parser.py index b28c3b22e0..1ca9c1ca60 100644 --- a/octavia/common/tls_utils/cert_parser.py +++ b/octavia/common/tls_utils/cert_parser.py @@ -27,7 +27,7 @@ from pyasn1_modules import rfc2315 import six from octavia.common import data_models -import octavia.common.exceptions as exceptions +from octavia.common import exceptions X509_BEG = b'-----BEGIN CERTIFICATE-----' X509_END = b'-----END CERTIFICATE-----' @@ -354,16 +354,29 @@ def load_certificates_data(cert_mngr, obj, context=None): context = oslo_context.RequestContext(project_id=obj.project_id) if obj.tls_certificate_id: - tls_cert = _map_cert_tls_container( - cert_mngr.get_cert(context, - obj.tls_certificate_id, - check_only=True)) + try: + tls_cert = _map_cert_tls_container( + cert_mngr.get_cert(context, + obj.tls_certificate_id, + check_only=True)) + except Exception as e: + LOG.warning('Unable to retrieve certificate: %s due to %s.', + obj.tls_certificate_id, str(e)) + raise exceptions.CertificateRetrievalException( + ref=obj.tls_certificate_id) + if hasattr(obj, 'sni_containers') and obj.sni_containers: for sni_cont in obj.sni_containers: - cert_container = _map_cert_tls_container( - cert_mngr.get_cert(context, - sni_cont.tls_container_id, - check_only=True)) + try: + cert_container = _map_cert_tls_container( + cert_mngr.get_cert(context, + sni_cont.tls_container_id, + check_only=True)) + except Exception as e: + LOG.warning('Unable to retrieve certificate: %s due to %s.', + sni_cont.tls_container_id, str(e)) + raise exceptions.CertificateRetrievalException( + ref=sni_cont.tls_container_id) sni_certs.append(cert_container) return {'tls_cert': tls_cert, 'sni_certs': sni_certs} diff --git a/octavia/tests/functional/api/v2/test_listener.py b/octavia/tests/functional/api/v2/test_listener.py index ccb0d90780..c130cc1958 100644 --- a/octavia/tests/functional/api/v2/test_listener.py +++ b/octavia/tests/functional/api/v2/test_listener.py @@ -25,6 +25,7 @@ from octavia.common import constants import octavia.common.context from octavia.common import data_models from octavia.common import exceptions +from octavia.db import api as db_api from octavia.tests.common import sample_certs from octavia.tests.functional.api.v2 import base @@ -1849,6 +1850,37 @@ class TestListener(base.BaseAPITest): self.assert_final_listener_statuses(self.lb_id, api_listener['id'], delete=True) + # Problems with TLS certs should not block a delete + def test_delete_with_bad_tls_ref(self): + listener = self.create_listener(constants.PROTOCOL_TCP, + 443, self.lb_id) + tls_uuid = uuidutils.generate_uuid() + self.set_lb_status(self.lb_id) + self.listener_repo.update(db_api.get_session(), + listener['listener']['id'], + tls_certificate_id=tls_uuid, + protocol=constants.PROTOCOL_TERMINATED_HTTPS) + + listener_path = self.LISTENER_PATH.format( + listener_id=listener['listener']['id']) + self.delete(listener_path) + response = self.get(listener_path) + api_listener = response.json['listener'] + expected = {'name': None, 'default_pool_id': None, + 'description': None, 'admin_state_up': True, + 'operating_status': constants.ONLINE, + 'provisioning_status': constants.PENDING_DELETE, + 'connection_limit': None} + listener['listener'].update(expected) + + self.assertIsNone(listener['listener'].pop('updated_at')) + self.assertIsNotNone(api_listener.pop('updated_at')) + self.assertNotEqual(listener, api_listener) + self.assert_correct_lb_status(self.lb_id, constants.ONLINE, + constants.PENDING_UPDATE) + self.assert_final_listener_statuses(self.lb_id, api_listener['id'], + delete=True) + def test_delete_authorized(self): listener = self.create_listener(constants.PROTOCOL_HTTP, 80, self.lb_id) diff --git a/octavia/tests/functional/api/v2/test_pool.py b/octavia/tests/functional/api/v2/test_pool.py index f79a8e0000..f72e01c36d 100644 --- a/octavia/tests/functional/api/v2/test_pool.py +++ b/octavia/tests/functional/api/v2/test_pool.py @@ -22,6 +22,7 @@ from octavia.common import constants import octavia.common.context from octavia.common import data_models from octavia.common import exceptions +from octavia.db import api as db_api from octavia.tests.common import sample_certs from octavia.tests.functional.api.v2 import base @@ -1754,6 +1755,37 @@ class TestPool(base.BaseAPITest): listener_prov_status=constants.PENDING_UPDATE, pool_prov_status=constants.PENDING_DELETE) + # Problems with TLS certs should not block a delete + def test_delete_with_bad_tls_ref(self): + api_pool = self.create_pool( + self.lb_id, + constants.PROTOCOL_HTTP, + constants.LB_ALGORITHM_ROUND_ROBIN, + listener_id=self.listener_id).get(self.root_tag) + self.set_lb_status(lb_id=self.lb_id) + # Set status to ACTIVE/ONLINE because set_lb_status did it in the db + api_pool['provisioning_status'] = constants.ACTIVE + api_pool['operating_status'] = constants.ONLINE + api_pool.pop('updated_at') + + response = self.get(self.POOL_PATH.format( + pool_id=api_pool.get('id'))).json.get(self.root_tag) + response.pop('updated_at') + self.assertEqual(api_pool, response) + + tls_uuid = uuidutils.generate_uuid() + self.pool_repo.update(db_api.get_session(), + api_pool.get('id'), + tls_certificate_id=tls_uuid) + + self.delete(self.POOL_PATH.format(pool_id=api_pool.get('id'))) + self.assert_correct_status( + lb_id=self.lb_id, listener_id=self.listener_id, + pool_id=api_pool.get('id'), + lb_prov_status=constants.PENDING_UPDATE, + listener_prov_status=constants.PENDING_UPDATE, + pool_prov_status=constants.PENDING_DELETE) + def test_delete_authorize(self): api_pool = self.create_pool( self.lb_id, diff --git a/octavia/tests/unit/api/drivers/test_utils.py b/octavia/tests/unit/api/drivers/test_utils.py index a9c78a12e7..87265ac8f8 100644 --- a/octavia/tests/unit/api/drivers/test_utils.py +++ b/octavia/tests/unit/api/drivers/test_utils.py @@ -215,6 +215,22 @@ class TestUtils(base.TestCase): ref_listeners = copy.deepcopy(self.sample_data.provider_listeners) self.assertEqual(ref_listeners, provider_listeners) + @mock.patch('oslo_context.context.RequestContext', return_value=None) + def test_get_secret_data_errors(self, mock_context): + mock_cert_mngr = mock.MagicMock() + + mock_cert_mngr.get_secret.side_effect = [Exception, Exception] + + # Test for_delete == False path + self.assertRaises(exceptions.CertificateRetrievalException, + utils._get_secret_data, mock_cert_mngr, + 'fake_project_id', 1) + + # Test for_delete == True path + self.assertIsNone( + utils._get_secret_data(mock_cert_mngr, 'fake_project_id', + 2, for_delete=True)) + @mock.patch('octavia.api.drivers.utils._get_secret_data') @mock.patch('octavia.common.tls_utils.cert_parser.load_certificates_data') def test_listener_dict_to_provider_dict(self, mock_load_cert, mock_secret): @@ -241,6 +257,41 @@ class TestUtils(base.TestCase): self.sample_data.test_listener1_dict) self.assertEqual(expect_prov, provider_listener) + @mock.patch('octavia.api.drivers.utils._get_secret_data') + @mock.patch('octavia.common.tls_utils.cert_parser.load_certificates_data') + def test_listener_dict_to_provider_dict_load_cert_error( + self, mock_load_cert, mock_secret): + mock_secret.side_effect = ['ca cert', 'X509 CRL FILE', + 'X509 POOL CA CERT FILE', + 'X509 POOL CRL FILE'] + mock_load_cert.side_effect = [exceptions.OctaviaException, + Exception] + + # Test load_cert exception for_delete == False path + self.assertRaises(exceptions.OctaviaException, + utils.listener_dict_to_provider_dict, + self.sample_data.test_listener1_dict) + + @mock.patch('octavia.api.drivers.utils._get_secret_data') + @mock.patch('octavia.common.tls_utils.cert_parser.load_certificates_data') + def test_listener_dict_to_provider_dict_load_cert_error_for_delete( + self, mock_load_cert, mock_secret): + mock_secret.side_effect = ['ca cert', 'X509 CRL FILE', + 'X509 POOL CA CERT FILE', + 'X509 POOL CRL FILE'] + mock_load_cert.side_effect = [Exception] + + # Test load_cert exception for_delete == True path + expect_prov = copy.deepcopy(self.sample_data.provider_listener1_dict) + expect_pool_prov = copy.deepcopy(self.sample_data.provider_pool1_dict) + del expect_pool_prov['tls_container_data'] + expect_prov['default_pool'] = expect_pool_prov + del expect_prov['default_tls_container_data'] + del expect_prov['sni_container_data'] + provider_listener = utils.listener_dict_to_provider_dict( + self.sample_data.test_listener1_dict, for_delete=True) + self.assertEqual(expect_prov, provider_listener) + @mock.patch('octavia.api.drivers.utils._get_secret_data') @mock.patch('octavia.common.tls_utils.cert_parser.load_certificates_data') def test_listener_dict_to_provider_dict_SNI(self, mock_load_cert, @@ -315,6 +366,37 @@ class TestUtils(base.TestCase): provider_pool_dict.pop('crl_container_ref') self.assertEqual(expect_prov, provider_pool_dict) + @mock.patch('octavia.api.drivers.utils._get_secret_data') + @mock.patch('octavia.common.tls_utils.cert_parser.load_certificates_data') + def test_pool_dict_to_provider_dict_load_cert_error( + self, mock_load_cert, mock_secret): + + mock_load_cert.side_effect = [exceptions.OctaviaException, + Exception] + + # Test load_cert exception for_delete == False path + self.assertRaises(exceptions.OctaviaException, + utils.pool_dict_to_provider_dict, + self.sample_data.test_pool1_dict) + + @mock.patch('octavia.api.drivers.utils._get_secret_data') + @mock.patch('octavia.common.tls_utils.cert_parser.load_certificates_data') + def test_pool_dict_to_provider_dict_load_cert_error_for_delete( + self, mock_load_cert, mock_secret): + + mock_load_cert.side_effect = [Exception] + + # Test load_cert exception for_delete == True path + mock_secret.side_effect = ['X509 POOL CA CERT FILE', + 'X509 POOL CRL FILE'] + expect_prov = copy.deepcopy(self.sample_data.provider_pool1_dict) + expect_prov.pop('crl_container_ref') + del expect_prov['tls_container_data'] + provider_pool_dict = utils.pool_dict_to_provider_dict( + self.sample_data.test_pool1_dict, for_delete=True) + provider_pool_dict.pop('crl_container_ref') + self.assertEqual(expect_prov, provider_pool_dict) + def test_db_HM_to_provider_HM(self): provider_hm = utils.db_HM_to_provider_HM(self.sample_data.db_hm1) self.assertEqual(self.sample_data.provider_hm1, provider_hm) diff --git a/octavia/tests/unit/common/tls_utils/test_cert_parser.py b/octavia/tests/unit/common/tls_utils/test_cert_parser.py index 1d5dfe7a74..f70d968be6 100644 --- a/octavia/tests/unit/common/tls_utils/test_cert_parser.py +++ b/octavia/tests/unit/common/tls_utils/test_cert_parser.py @@ -174,6 +174,29 @@ class TestTLSParseUtils(base.TestCase): self.assertEqual(ref_empty_dict, result) mock_oslo.assert_called() + def test_load_certificates_get_cert_errors(self): + mock_cert_mngr = mock.MagicMock() + mock_obj = mock.MagicMock() + mock_sni_container = mock.MagicMock() + mock_sni_container.tls_container_id = 2 + + mock_cert_mngr.get_cert.side_effect = [Exception, Exception] + + # Test tls_certificate_id error + mock_obj.tls_certificate_id = 1 + + self.assertRaises(exceptions.CertificateRetrievalException, + cert_parser.load_certificates_data, + mock_cert_mngr, mock_obj) + + # Test sni_containers error + mock_obj.tls_certificate_id = None + mock_obj.sni_containers = [mock_sni_container] + + self.assertRaises(exceptions.CertificateRetrievalException, + cert_parser.load_certificates_data, + mock_cert_mngr, mock_obj) + @mock.patch('octavia.certificates.common.cert.Cert') def test_map_cert_tls_container(self, cert_mock): tls = data_models.TLSContainer(