diff --git a/octavia/amphorae/drivers/driver_base.py b/octavia/amphorae/drivers/driver_base.py index 0309a5694f..2baf852481 100644 --- a/octavia/amphorae/drivers/driver_base.py +++ b/octavia/amphorae/drivers/driver_base.py @@ -185,6 +185,19 @@ class AmphoraLoadBalancerDriver(object): """ pass + def upload_cert_amp(self, amphora, pem_file): + """upload cert info to amphora + + + :param amphora: amphora object, needs id and network ip(s) + :type amphora: object + :param pem_file: a certificate file + :type pem_file: file object + + upload cert file to amphora for Controller Communication + """ + pass + @six.add_metaclass(abc.ABCMeta) class HealthMixin(object): diff --git a/octavia/amphorae/drivers/haproxy/rest_api_driver.py b/octavia/amphorae/drivers/haproxy/rest_api_driver.py index f0008385a9..4c8a78fc51 100644 --- a/octavia/amphorae/drivers/haproxy/rest_api_driver.py +++ b/octavia/amphorae/drivers/haproxy/rest_api_driver.py @@ -79,6 +79,12 @@ class HaproxyAmphoraLoadBalancerDriver(driver_base.AmphoraLoadBalancerDriver): else: self.client.start_listener(amp, listener.id) + def upload_cert_amp(self, amp, pem): + LOG.debug("Amphora %s updating cert in REST driver " + "with amphora id %s,", + self.__class__.__name__, amp.id) + self.client.update_cert_for_rotation(amp, pem) + def _apply(self, func, listener=None, *args): for amp in listener.load_balancer.amphorae: if amp.status != constants.DELETED: @@ -250,6 +256,10 @@ class AmphoraAPIClient(object): data=pem_file) return exc.check_exception(r) + def update_cert_for_rotation(self, amp, pem_file): + r = self.put(amp, 'certificate', data=pem_file) + return exc.check_exception(r) + def get_cert_md5sum(self, amp, listener_id, pem_filename): r = self.get(amp, 'listeners/{listener_id}/certificates/{filename}'.format( diff --git a/octavia/amphorae/drivers/noop_driver/driver.py b/octavia/amphorae/drivers/noop_driver/driver.py index 9824b98c74..77d65e1919 100644 --- a/octavia/amphorae/drivers/noop_driver/driver.py +++ b/octavia/amphorae/drivers/noop_driver/driver.py @@ -92,6 +92,12 @@ class NoopManager(object): self.amphoraconfig[(load_balancer.id, id(amphorae_network_config))] = ( load_balancer.id, amphorae_network_config, 'post_vip_plug') + def upload_cert_amp(self, amphora, pem_file): + LOG.debug("Amphora %s no-op, upload cert amphora %s,with pem fle %s", + self.__class__.__name__, amphora.id, pem_file) + self.amphoraconfig[amphora.id, pem_file] = (amphora.id, pem_file, + 'update_amp_cert_file') + class NoopAmphoraLoadBalancerDriver(driver_base.AmphoraLoadBalancerDriver): def __init__(self): @@ -133,3 +139,7 @@ class NoopAmphoraLoadBalancerDriver(driver_base.AmphoraLoadBalancerDriver): def post_vip_plug(self, load_balancer, amphorae_network_config): self.driver.post_vip_plug(load_balancer, amphorae_network_config) + + def upload_cert_amp(self, amphora, pem_file): + + self.driver.upload_cert_amp(amphora, pem_file) diff --git a/octavia/cmd/house_keeping.py b/octavia/cmd/house_keeping.py index 8039efc940..3e4e15a92e 100755 --- a/octavia/cmd/house_keeping.py +++ b/octavia/cmd/house_keeping.py @@ -34,6 +34,7 @@ CONF.import_group('house_keeping', 'octavia.common.config') spare_amp_thread_event = threading.Event() db_cleanup_thread_event = threading.Event() +cert_rotate_thread_event = threading.Event() def spare_amphora_check(): @@ -65,6 +66,18 @@ def db_cleanup(): time.sleep(interval) +def cert_rotation(): + """Perform certificate rotation.""" + interval = CONF.house_keeping.cert_interval + LOG.info( + _LI("Expiring certificate check interval is set to %d sec") % interval) + cert_rotate = house_keeping.CertRotation() + while cert_rotate_thread_event.is_set(): + LOG.debug("Initiating certification rotation ...") + cert_rotate.rotate() + time.sleep(interval) + + def main(): service.prepare_service(sys.argv) @@ -85,6 +98,12 @@ def main(): db_cleanup_thread_event.set() db_cleanup_thread.start() + # Thread to perform certificate rotation + cert_rotate_thread = threading.Thread(target=cert_rotation) + cert_rotate_thread.daemon = True + cert_rotate_thread_event.set() + cert_rotate_thread.start() + # Try-Exception block should be at the end to gracefully exit threads try: while True: @@ -93,6 +112,8 @@ def main(): LOG.info(_LI("Attempting to gracefully terminate House-Keeping")) spare_amp_thread_event.clear() db_cleanup_thread_event.clear() + cert_rotate_thread_event.clear() spare_amp_thread.join() db_cleanup_thread.join() + cert_rotate_thread.join() LOG.info(_LI("House-Keeping process terminated")) diff --git a/octavia/common/config.py b/octavia/common/config.py index f8d584b52c..f2755755f7 100644 --- a/octavia/common/config.py +++ b/octavia/common/config.py @@ -274,7 +274,18 @@ house_keeping_opts = [ help=_('DB cleanup interval in seconds')), cfg.IntOpt('amphora_expiry_age', default=604800, - help=_('Amphora expiry age in seconds')) + help=_('Amphora expiry age in seconds')), + cfg.IntOpt('cert_interval', + default=3600, + help=_('Certificate check interval in seconds')), + # 14 days for cert expiry buffer + cfg.IntOpt('cert_expiry_buffer', + default=1209600, + help=_('Seconds until certificate expiration')), + cfg.IntOpt('cert_rotate_threads', + default=10, + help=_('Number of threads performing amphora certificate' + ' rotation')) ] diff --git a/octavia/common/constants.py b/octavia/common/constants.py index b57abbf557..a902392a50 100644 --- a/octavia/common/constants.py +++ b/octavia/common/constants.py @@ -103,6 +103,7 @@ ADDED_PORTS = 'added_ports' PORTS = 'ports' MEMBER_PORTS = 'member_ports' +CERT_ROTATE_AMPHORA_FLOW = 'octavia-cert-rotate-amphora-flow' CREATE_AMPHORA_FLOW = 'octavia-create-amphora-flow' CREATE_AMPHORA_FOR_LB_FLOW = 'octavia-create-amp-for-lb-flow' CREATE_HEALTH_MONITOR_FLOW = 'octavia-create-health-monitor-flow' diff --git a/octavia/common/data_models.py b/octavia/common/data_models.py index 2d455cebda..6b5318ec1d 100644 --- a/octavia/common/data_models.py +++ b/octavia/common/data_models.py @@ -225,7 +225,8 @@ class Amphora(BaseDataModel): def __init__(self, id=None, load_balancer_id=None, compute_id=None, status=None, lb_network_ip=None, vrrp_ip=None, ha_ip=None, vrrp_port_id=None, ha_port_id=None, - load_balancer=None, role=None): + load_balancer=None, role=None, cert_expiration=None, + cert_busy=False): self.id = id self.load_balancer_id = load_balancer_id self.compute_id = compute_id @@ -237,6 +238,8 @@ class Amphora(BaseDataModel): self.ha_port_id = ha_port_id self.role = role self.load_balancer = load_balancer + self.cert_expiration = cert_expiration + self.cert_busy = cert_busy class AmphoraHealth(BaseDataModel): diff --git a/octavia/common/tls_utils/cert_parser.py b/octavia/common/tls_utils/cert_parser.py index a89d74629f..0f55bfad50 100644 --- a/octavia/common/tls_utils/cert_parser.py +++ b/octavia/common/tls_utils/cert_parser.py @@ -107,6 +107,7 @@ def get_host_names(certificate): """ try: certificate = certificate.encode('ascii') + cert = x509.load_pem_x509_certificate(certificate, backends.default_backend()) cn = cert.subject.get_attributes_for_oid(x509.OID_COMMON_NAME)[0] @@ -130,6 +131,23 @@ def get_host_names(certificate): raise exceptions.UnreadableCert +def get_cert_expiration(certificate_pem): + """Extract the expiration date from the Pem encoded X509 certificate + + :param certificate_pem: Certificate in PEM format + :returns: Expiration date of certificate_pem + """ + try: + certificate = certificate_pem.encode('ascii') + + cert = x509.load_pem_x509_certificate(certificate, + backends.default_backend()) + return cert.not_valid_after + except Exception as e: + LOG.exception(e) + raise exceptions.UnreadableCert + + def _get_x509_from_pem_bytes(certificate_pem): """Parse X509 data from a PEM encoded certificate diff --git a/octavia/controller/housekeeping/house_keeping.py b/octavia/controller/housekeeping/house_keeping.py index ccb826df90..668afa3fd9 100644 --- a/octavia/controller/housekeeping/house_keeping.py +++ b/octavia/controller/housekeeping/house_keeping.py @@ -14,6 +14,7 @@ import datetime +from concurrent import futures from oslo_config import cfg from oslo_log import log as logging @@ -78,4 +79,31 @@ class DatabaseCleanup(object): exp_age): LOG.info(_LI('Attempting to delete Amphora id : %s'), amp.id) self.amp_repo.delete(session, id=amp.id) - LOG.info(_LI('Deleted Amphora id : %s'), amp.id) + LOG.info(_LI('Deleted Amphora id : %s') % amp.id) + + +class CertRotation(object): + def __init__(self): + self.threads = CONF.house_keeping.cert_rotate_threads + self.cw = cw.ControllerWorker() + + def rotate(self): + """Check the amphora db table for expiring auth certs.""" + amp_repo = repo.AmphoraRepository() + + with futures.ThreadPoolExecutor(max_workers=self.threads) as executor: + try: + session = db_api.get_session() + rotation_count = 0 + while True: + amp = amp_repo.get_cert_expiring_amphora(session) + if not amp: + break + rotation_count += 1 + LOG.debug("Cert expired amphora's id is: %s", amp.id) + executor.submit(self.cw.amphora_cert_rotation, amp.id) + if rotation_count > 0: + LOG.info(_LI("Rotated certificates for %s ampohra") % + rotation_count) + finally: + executor.shutdown(wait=True) diff --git a/octavia/controller/worker/controller_worker.py b/octavia/controller/worker/controller_worker.py index 52f80edd59..1c3f1e4a11 100644 --- a/octavia/controller/worker/controller_worker.py +++ b/octavia/controller/worker/controller_worker.py @@ -26,6 +26,7 @@ from octavia.controller.worker.flows import member_flows from octavia.controller.worker.flows import pool_flows from octavia.db import api as db_apis from octavia.db import repositories as repo +from octavia.i18n import _LI from taskflow.listeners import logging as tf_logging @@ -486,3 +487,25 @@ class ControllerWorker(base_taskflow.BaseTaskFlowEngine): with tf_logging.DynamicLoggingListener(failover_amphora_tf, log=LOG): failover_amphora_tf.run() + + def amphora_cert_rotation(self, amphora_id): + """Perform cert rotation for an amphora. + + :param amphora_id: ID for amphora to rotate + :returns: None + :raises AmphoraNotFound: The referenced amphora was not found + """ + + amp = self._amphora_repo.get(db_apis.get_session(), + id=amphora_id) + LOG.info(_LI("Start amphora cert rotation, amphora's id is: %s") + % amp.id) + + certrotation_amphora_tf = self._taskflow_load( + self._amphora_flows.cert_rotate_amphora_flow(), + store={constants.AMPHORA: amp, + constants.AMPHORA_ID: amp.id}) + + with tf_logging.DynamicLoggingListener(certrotation_amphora_tf, + log=LOG): + certrotation_amphora_tf.run() diff --git a/octavia/controller/worker/flows/amphora_flows.py b/octavia/controller/worker/flows/amphora_flows.py index 9c1277a078..d3636476a4 100644 --- a/octavia/controller/worker/flows/amphora_flows.py +++ b/octavia/controller/worker/flows/amphora_flows.py @@ -25,7 +25,6 @@ from octavia.controller.worker.tasks import compute_tasks from octavia.controller.worker.tasks import database_tasks from octavia.controller.worker.tasks import network_tasks - CONF = cfg.CONF CONF.import_group('controller_worker', 'octavia.common.config') @@ -52,6 +51,11 @@ class AmphoraFlows(object): if self.REST_AMPHORA_DRIVER: create_amphora_flow.add(cert_task.GenerateServerPEMTask( provides=constants.SERVER_PEM)) + + create_amphora_flow.add( + database_tasks.UpdateAmphoraDBCertExpiration( + requires=(constants.AMPHORA_ID, constants.SERVER_PEM))) + create_amphora_flow.add(compute_tasks.CertComputeCreate( requires=(constants.AMPHORA_ID, constants.SERVER_PEM), provides=constants.COMPUTE_ID)) @@ -98,6 +102,11 @@ class AmphoraFlows(object): if self.REST_AMPHORA_DRIVER: create_amp_for_lb_flow.add(cert_task.GenerateServerPEMTask( provides=constants.SERVER_PEM)) + + create_amp_for_lb_flow.add( + database_tasks.UpdateAmphoraDBCertExpiration( + requires=(constants.AMPHORA_ID, constants.SERVER_PEM))) + create_amp_for_lb_flow.add(compute_tasks.CertComputeCreate( requires=(constants.AMPHORA_ID, constants.SERVER_PEM), provides=constants.COMPUTE_ID)) @@ -249,3 +258,35 @@ class AmphoraFlows(object): requires=(constants.AMPHORA, constants.LOADBALANCER_ID))) return failover_amphora_flow + + def cert_rotate_amphora_flow(self): + """Implement rotation for amphora's cert. + + 1. Create a new certificate + 2. Upload the cert to amphora + 3. update the newly created certificate info to amphora + 4. update the cert_busy flag to be false after rotation + + :returns: The flow for updating an amphora + """ + rotated_amphora_flow = linear_flow.Flow( + constants.CERT_ROTATE_AMPHORA_FLOW) + + # create a new certificate, the returned value is the newly created + # certificate + rotated_amphora_flow.add(cert_task.GenerateServerPEMTask( + provides=constants.SERVER_PEM)) + + # update it in amphora task + rotated_amphora_flow.add(amphora_driver_tasks.AmphoraCertUpload( + requires=(constants.AMPHORA, constants.SERVER_PEM))) + + # update the newly created certificate info to amphora + rotated_amphora_flow.add(database_tasks.UpdateAmphoraDBCertExpiration( + requires=(constants.AMPHORA_ID, constants.SERVER_PEM))) + + # update the cert_busy flag to be false after rotation + rotated_amphora_flow.add(database_tasks.UpdateAmphoraCertBusyToFalse( + requires=constants.AMPHORA)) + + return rotated_amphora_flow diff --git a/octavia/controller/worker/tasks/amphora_driver_tasks.py b/octavia/controller/worker/tasks/amphora_driver_tasks.py index 27cbe3674c..625c96b84b 100644 --- a/octavia/controller/worker/tasks/amphora_driver_tasks.py +++ b/octavia/controller/worker/tasks/amphora_driver_tasks.py @@ -239,3 +239,12 @@ class AmphoraPostVIPPlug(BaseAmphoraTask): self.loadbalancer_repo.update(db_apis.get_session(), id=loadbalancer.id, status=constants.ERROR) + + +class AmphoraCertUpload(BaseAmphoraTask): + """Upload a certificate to the amphora.""" + + def execute(self, amphora, server_pem): + """Execute cert_update_amphora routine.""" + LOG.debug("Upload cert in amphora REST driver") + self.amphora_driver.upload_cert_amp(amphora, server_pem) diff --git a/octavia/controller/worker/tasks/database_tasks.py b/octavia/controller/worker/tasks/database_tasks.py index 7a059aa92f..79a48f867d 100644 --- a/octavia/controller/worker/tasks/database_tasks.py +++ b/octavia/controller/worker/tasks/database_tasks.py @@ -22,10 +22,12 @@ from taskflow.types import failure from octavia.common import constants from octavia.common import data_models from octavia.common import exceptions +import octavia.common.tls_utils.cert_parser as cert_parser from octavia.db import api as db_apis from octavia.db import repositories as repo from octavia.i18n import _LI, _LW + LOG = logging.getLogger(__name__) @@ -54,7 +56,8 @@ class CreateAmphoraInDB(BaseDatabaseTask): amphora = self.amphora_repo.create(db_apis.get_session(), id=uuidutils.generate_uuid(), - status=constants.PENDING_CREATE) + status=constants.PENDING_CREATE, + cert_busy=False) LOG.info(_LI("Created Amphora in DB with id %s"), amphora.id) return amphora.id @@ -474,6 +477,24 @@ class UpdateAmphoraInfo(BaseDatabaseTask): return self.amphora_repo.get(db_apis.get_session(), id=amphora_id) +class UpdateAmphoraDBCertExpiration(BaseDatabaseTask): + """Update the amphora expiration date with new cert file date.""" + + def execute(self, amphora_id, server_pem): + LOG.debug("Update DB cert expiry date of amphora id: %s", amphora_id) + cert_expiration = cert_parser.get_cert_expiration(server_pem) + LOG.debug("Certificate expiration date is %s ", cert_expiration) + self.amphora_repo.update(db_apis.get_session(), amphora_id, + cert_expiration=cert_expiration) + + +class UpdateAmphoraCertBusyToFalse(BaseDatabaseTask): + """Update the amphora cert_busy flag to be false.""" + def execute(self, amphora): + self.amphora_repo.update(db_apis.get_session(), amphora.id, + cert_busy=False) + + class MarkLBActiveInDB(BaseDatabaseTask): """Mark the load balancer active in the DB. diff --git a/octavia/db/migration/alembic_migrations/versions/5a3ee5472c31_add_cert_expiration__infor_in_amphora_table.py b/octavia/db/migration/alembic_migrations/versions/5a3ee5472c31_add_cert_expiration__infor_in_amphora_table.py new file mode 100644 index 0000000000..046021594c --- /dev/null +++ b/octavia/db/migration/alembic_migrations/versions/5a3ee5472c31_add_cert_expiration__infor_in_amphora_table.py @@ -0,0 +1,37 @@ +# Copyright 2015 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +"""add cert expiration infor in amphora table + +Revision ID: 5a3ee5472c31 +Revises: 3b199c848b96 +Create Date: 2015-08-20 10:15:19.561066 + +""" + +# revision identifiers, used by Alembic. +revision = '5a3ee5472c31' +down_revision = '3b199c848b96' + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + op.add_column(u'amphora', + sa.Column(u'cert_expiration', sa.DateTime(timezone=True), + nullable=True) + ) + + op.add_column(u'amphora', sa.Column(u'cert_busy', sa.Boolean(), + nullable=False, default=False)) diff --git a/octavia/db/models.py b/octavia/db/models.py index 1a3d9c9c12..8e80820177 100644 --- a/octavia/db/models.py +++ b/octavia/db/models.py @@ -340,6 +340,10 @@ class Amphora(base_models.BASE): ha_ip = sa.Column(sa.String(64), nullable=True) vrrp_port_id = sa.Column(sa.String(36), nullable=True) ha_port_id = sa.Column(sa.String(36), nullable=True) + cert_expiration = sa.Column(sa.DateTime(timezone=True), default=None, + nullable=True) + cert_busy = sa.Column(sa.Boolean(), default=False, nullable=False) + role = sa.Column( sa.String(36), sa.ForeignKey("amphora_roles.name", name="fk_amphora_roles_name"), diff --git a/octavia/db/repositories.py b/octavia/db/repositories.py index e05ddb58bd..7068db8737 100644 --- a/octavia/db/repositories.py +++ b/octavia/db/repositories.py @@ -27,9 +27,11 @@ from octavia.common import constants from octavia.common import exceptions from octavia.db import models + CONF = cfg.CONF LOG = logging.getLogger(__name__) CONF.import_group('health_manager', 'octavia.common.config') +CONF.import_group('house_keeping', 'octavia.common.config') class BaseRepository(object): @@ -402,6 +404,30 @@ class AmphoraRepository(BaseRepository): return count + def get_cert_expiring_amphora(self, session): + """Retrieves an amphora whose cert is close to expiring.. + + :param session: A Sql Alchemy database session. + :returns: one amphora with expiring certificate + """ + # get amphorae with certs that will expire within the + # configured buffer period, so we can rotate their certs ahead of time + expired_seconds = CONF.house_keeping.cert_expiry_buffer + expired_date = datetime.datetime.utcnow() + datetime.timedelta( + seconds=expired_seconds) + + with session.begin(subtransactions=True): + amp = session.query(self.model_class).with_for_update().filter_by( + cert_busy=False).filter( + self.model_class.cert_expiration < expired_date).first() + + if amp is None: + return None + + amp.cert_busy = True + + return amp.to_data_model() + class SNIRepository(BaseRepository): model_class = models.SNI diff --git a/octavia/tests/functional/db/test_models.py b/octavia/tests/functional/db/test_models.py index c5f05a1763..a7ea153b92 100644 --- a/octavia/tests/functional/db/test_models.py +++ b/octavia/tests/functional/db/test_models.py @@ -123,7 +123,9 @@ class ModelTestMixin(object): 'ha_ip': self.FAKE_IP, 'vrrp_port_id': self.FAKE_UUID_1, 'ha_port_id': self.FAKE_UUID_2, - 'lb_network_ip': self.FAKE_IP} + 'lb_network_ip': self.FAKE_IP, + 'cert_expiration': datetime.datetime.utcnow(), + 'cert_busy': False} kwargs.update(overrides) return self._insert(session, models.Amphora, kwargs) diff --git a/octavia/tests/functional/db/test_repositories.py b/octavia/tests/functional/db/test_repositories.py index 88b3cf8aa7..654c9cb22f 100644 --- a/octavia/tests/functional/db/test_repositories.py +++ b/octavia/tests/functional/db/test_repositories.py @@ -15,6 +15,8 @@ import datetime import random +from oslo_config import cfg +from oslo_log import log as logging from oslo_utils import uuidutils from octavia.common import constants @@ -22,6 +24,10 @@ from octavia.common import data_models as models from octavia.db import repositories as repo from octavia.tests.functional.db import base +LOG = logging.getLogger(__name__) +CONF = cfg.CONF +CONF.import_group('house_keeping', 'octavia.common.config') + class BaseRepositoryTest(base.OctaviaDBTestBase): @@ -1108,13 +1114,16 @@ class AmphoraRepositoryTest(BaseRepositoryTest): operating_status=constants.ONLINE, enabled=True) def create_amphora(self, amphora_id): + expiration = datetime.datetime.utcnow() amphora = self.amphora_repo.create(self.session, id=amphora_id, compute_id=self.FAKE_UUID_3, status=constants.ACTIVE, lb_network_ip=self.FAKE_IP, vrrp_ip=self.FAKE_IP, ha_ip=self.FAKE_IP, - role=constants.ROLE_MASTER) + role=constants.ROLE_MASTER, + cert_expiration=expiration, + cert_busy=False) return amphora def test_get(self): @@ -1208,6 +1217,37 @@ class AmphoraRepositoryTest(BaseRepositoryTest): count = self.amphora_repo.get_spare_amphora_count(self.session) self.assertEqual(2, count) + def test_get_none_cert_expired_amphora(self): + # test with no expired amphora + amp = self.amphora_repo.get_cert_expiring_amphora(self.session) + self.assertIsNone(amp) + + amphora = self.create_amphora(self.FAKE_UUID_1) + + expired_interval = CONF.house_keeping.cert_expiry_buffer + expiration = datetime.datetime.utcnow() + datetime.timedelta( + seconds=2 * expired_interval) + + self.amphora_repo.update(self.session, amphora.id, + cert_expiration=expiration) + amp = self.amphora_repo.get_cert_expiring_amphora(self.session) + self.assertIsNone(amp) + + def test_get_cert_expired_amphora(self): + # test with expired amphora + amphora2 = self.create_amphora(self.FAKE_UUID_2) + + expiration = datetime.datetime.utcnow() + datetime.timedelta( + seconds=1) + self.amphora_repo.update(self.session, amphora2.id, + cert_expiration=expiration) + + cert_expired_amphora = self.amphora_repo.get_cert_expiring_amphora( + self.session) + + self.assertEqual(cert_expired_amphora.cert_expiration, expiration) + self.assertEqual(cert_expired_amphora.id, amphora2.id) + class AmphoraHealthRepositoryTest(BaseRepositoryTest): def setUp(self): diff --git a/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver.py b/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver.py index b601d56f0c..b0c452d90b 100644 --- a/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver.py +++ b/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver.py @@ -16,6 +16,7 @@ import mock from oslo_utils import uuidutils import requests_mock +import six from octavia.amphorae.drivers.haproxy import exceptions as exc from octavia.amphorae.drivers.haproxy import rest_api_driver as driver @@ -34,10 +35,10 @@ FAKE_SUBNET_INFO = {'subnet_cidr': FAKE_CIDR, FAKE_UUID_1 = uuidutils.generate_uuid() -class HaproxyAmphoraLoadBalancerDriverTest(base.TestCase): +class TestHaproxyAmphoraLoadBalancerDriverTest(base.TestCase): def setUp(self): - super(HaproxyAmphoraLoadBalancerDriverTest, self).setUp() + super(TestHaproxyAmphoraLoadBalancerDriverTest, self).setUp() self.driver = driver.HaproxyAmphoraLoadBalancerDriver() self.driver.cert_manager = mock.MagicMock() @@ -110,6 +111,11 @@ class HaproxyAmphoraLoadBalancerDriverTest(base.TestCase): self.driver.client.start_listener.assert_called_once_with( self.amp, self.sl.id) + def test_upload_cert_amp(self): + self.driver.upload_cert_amp(self.amp, six.b('test')) + self.driver.client.update_cert_for_rotation.assert_called_once_with( + self.amp, six.b('test')) + def test_stop(self): # Execute driver method self.driver.stop(self.sl, self.sv) @@ -152,10 +158,10 @@ class HaproxyAmphoraLoadBalancerDriverTest(base.TestCase): self.amp, dict(mac_address='123')) -class AmphoraAPIClientTest(base.TestCase): +class TestAmphoraAPIClientTest(base.TestCase): def setUp(self): - super(AmphoraAPIClientTest, self).setUp() + super(TestAmphoraAPIClientTest, self).setUp() self.driver = driver.AmphoraAPIClient() self.base_url = "https://127.0.0.1:9443/0.5" self.amp = models.Amphora(lb_network_ip='127.0.0.1', compute_id='123') @@ -479,6 +485,41 @@ class AmphoraAPIClientTest(base.TestCase): self.amp, FAKE_UUID_1, FAKE_PEM_FILENAME, "some_file") + @requests_mock.mock() + def test_update_cert_for_rotation(self, m): + m.put("{base}/certificate".format(base=self.base_url)) + resp_body = self.driver.update_cert_for_rotation(self.amp, + "some_file") + self.assertEqual(200, resp_body.status_code) + + @requests_mock.mock() + def test_update_invalid_cert_for_rotation(self, m): + m.put("{base}/certificate".format(base=self.base_url), status_code=403) + self.assertRaises(exc.InvalidRequest, + self.driver.update_cert_for_rotation, self.amp, + "some_file") + + @requests_mock.mock() + def test_update_cert_for_rotation_unauthorized(self, m): + m.put("{base}/certificate".format(base=self.base_url), status_code=401) + self.assertRaises(exc.Unauthorized, + self.driver.update_cert_for_rotation, self.amp, + "some_file") + + @requests_mock.mock() + def test_update_cert_for_rotation_error(self, m): + m.put("{base}/certificate".format(base=self.base_url), status_code=500) + self.assertRaises(exc.InternalServerError, + self.driver.update_cert_for_rotation, self.amp, + "some_file") + + @requests_mock.mock() + def test_update_cert_for_rotation_unavailable(self, m): + m.put("{base}/certificate".format(base=self.base_url), status_code=503) + self.assertRaises(exc.ServiceUnavailable, + self.driver.update_cert_for_rotation, self.amp, + "some_file") + @requests_mock.mock() def test_get_cert_5sum(self, m): md5sum = {"md5sum": "some_real_sum"} diff --git a/octavia/tests/unit/amphorae/drivers/test_noop_amphoraloadbalancer_driver.py b/octavia/tests/unit/amphorae/drivers/test_noop_amphoraloadbalancer_driver.py index ec34999f80..758cfd4f39 100644 --- a/octavia/tests/unit/amphorae/drivers/test_noop_amphoraloadbalancer_driver.py +++ b/octavia/tests/unit/amphorae/drivers/test_noop_amphoraloadbalancer_driver.py @@ -62,6 +62,7 @@ class TestNoopAmphoraLoadBalancerDriver(base.TestCase): amphora=self.amphora, vip_subnet=network_models.Subnet(id=self.FAKE_UUID_1)) } + self.pem_file = 'test_pem_file' def test_update(self): self.driver.update(self.listener, self.vip) @@ -125,3 +126,10 @@ class TestNoopAmphoraLoadBalancerDriver(base.TestCase): self.load_balancer.id, id(self.amphorae_net_configs) )] self.assertEqual(expected_method_and_args, actual_method_and_args) + + def test_upload_cert_amp(self): + self.driver.upload_cert_amp(self.amphora, self.pem_file) + self.assertEqual( + (self.amphora.id, self.pem_file, 'update_amp_cert_file'), + self.driver.driver.amphoraconfig[( + self.amphora.id, self.pem_file)]) diff --git a/octavia/tests/unit/cmd/test_house_keeping.py b/octavia/tests/unit/cmd/test_house_keeping.py index e60130e2da..2ef1f35ab1 100644 --- a/octavia/tests/unit/cmd/test_house_keeping.py +++ b/octavia/tests/unit/cmd/test_house_keeping.py @@ -64,23 +64,79 @@ class TestHouseKeepingCMD(base.TestCase): mock_DatabaseCleanup.assert_called_once_with() self.assertEqual(1, db_cleanup.delete_old_amphorae.call_count) + @mock.patch('octavia.cmd.house_keeping.cert_rotate_thread_event') + @mock.patch('octavia.controller.housekeeping.' + 'house_keeping.CertRotation') @mock.patch('time.sleep') + def test_hk_cert_rotation_with_exception(self, sleep_mock, + mock_CertRotation, + cert_rotate_event_mock): + # mock cert_rotate object + cert_rotate_mock = mock.MagicMock() + # mock rotate() + rotate_mock = mock.MagicMock() + + cert_rotate_mock.rotate = rotate_mock + + mock_CertRotation.return_value = cert_rotate_mock + + # mock cert_rotate_thread_event.is_set() in the while loop + cert_rotate_event_mock.is_set = mock.MagicMock() + cert_rotate_event_mock.is_set.side_effect = [True, Exception('break')] + + self.assertRaisesRegexp(Exception, 'break', + house_keeping.cert_rotation) + + mock_CertRotation.assert_called_once_with() + self.assertEqual(1, cert_rotate_mock.rotate.call_count) + + @mock.patch('octavia.cmd.house_keeping.cert_rotate_thread_event') + @mock.patch('octavia.controller.housekeeping.' + 'house_keeping.CertRotation') + @mock.patch('time.sleep') + def test_hk_cert_rotation_without_exception(self, sleep_mock, + mock_CertRotation, + cert_rotate_event_mock): + # mock cert_rotate object + cert_rotate_mock = mock.MagicMock() + # mock rotate() + rotate_mock = mock.MagicMock() + + cert_rotate_mock.rotate = rotate_mock + + mock_CertRotation.return_value = cert_rotate_mock + + # mock cert_rotate_thread_event.is_set() in the while loop + cert_rotate_event_mock.is_set = mock.MagicMock() + cert_rotate_event_mock.is_set.side_effect = [True, None] + + self.assertEqual(None, house_keeping.cert_rotation()) + + mock_CertRotation.assert_called_once_with() + self.assertEqual(1, cert_rotate_mock.rotate.call_count) + + @mock.patch('time.sleep') + @mock.patch('octavia.cmd.house_keeping.cert_rotate_thread_event') @mock.patch('octavia.cmd.house_keeping.db_cleanup_thread_event') @mock.patch('octavia.cmd.house_keeping.spare_amp_thread_event') @mock.patch('threading.Thread') @mock.patch('octavia.common.service.prepare_service') def test_main(self, mock_service, mock_thread, spare_amp_thread_event_mock, - db_cleanup_thread_event_mock, sleep_time): + db_cleanup_thread_event_mock, + cert_rotate_thread_event_mock, sleep_time): spare_amp_thread_mock = mock.MagicMock() db_cleanup_thread_mock = mock.MagicMock() + cert_rotate_thread_mock = mock.MagicMock() mock_thread.side_effect = [spare_amp_thread_mock, - db_cleanup_thread_mock] + db_cleanup_thread_mock, + cert_rotate_thread_mock] spare_amp_thread_mock.daemon.return_value = True db_cleanup_thread_mock.daemon.return_value = True + cert_rotate_thread_mock.daemon.return_value = True # mock the time.sleep() in the while loop sleep_time.side_effect = [True, Exception('break')] @@ -88,14 +144,18 @@ class TestHouseKeepingCMD(base.TestCase): spare_amp_thread_event_mock.set.assert_called_once_with() db_cleanup_thread_event_mock.set.assert_called_once_with() + cert_rotate_thread_event_mock.set.assert_called_once_with() spare_amp_thread_mock.start.assert_called_once_with() db_cleanup_thread_mock.start.assert_called_once_with() + cert_rotate_thread_mock.start.assert_called_once_with() self.assertTrue(spare_amp_thread_mock.daemon) self.assertTrue(db_cleanup_thread_mock.daemon) + self.assertTrue(cert_rotate_thread_mock.daemon) @mock.patch('time.sleep') + @mock.patch('octavia.cmd.house_keeping.cert_rotate_thread_event') @mock.patch('octavia.cmd.house_keeping.db_cleanup_thread_event') @mock.patch('octavia.cmd.house_keeping.spare_amp_thread_event') @mock.patch('threading.Thread') @@ -103,15 +163,19 @@ class TestHouseKeepingCMD(base.TestCase): def test_main_keyboard_interrupt(self, mock_service, mock_thread, spare_amp_thread_event_mock, db_cleanup_thread_event_mock, + cert_rotate_thread_event_mock, sleep_time): spare_amp_thread_mock = mock.MagicMock() db_cleanup_thread_mock = mock.MagicMock() + cert_rotate_thread_mock = mock.MagicMock() mock_thread.side_effect = [spare_amp_thread_mock, - db_cleanup_thread_mock] + db_cleanup_thread_mock, + cert_rotate_thread_mock] spare_amp_thread_mock.daemon.return_value = True db_cleanup_thread_mock.daemon.return_value = True + cert_rotate_thread_mock.daemon.return_value = True # mock the time.sleep() in the while loop sleep_time.side_effect = [True, KeyboardInterrupt] @@ -123,11 +187,17 @@ class TestHouseKeepingCMD(base.TestCase): db_cleanup_thread_event_mock.set.assert_called_once_with() db_cleanup_thread_event_mock.clear.assert_called_once_with() + cert_rotate_thread_event_mock.set.assert_called_once_with() + cert_rotate_thread_event_mock.clear.assert_called_once_with() + spare_amp_thread_mock.start.assert_called_once_with() db_cleanup_thread_mock.start.assert_called_once_with() + cert_rotate_thread_mock.start.assert_called_once_with() self.assertTrue(spare_amp_thread_mock.daemon) self.assertTrue(db_cleanup_thread_mock.daemon) + self.assertTrue(cert_rotate_thread_mock.daemon) spare_amp_thread_mock.join.assert_called_once_with() db_cleanup_thread_mock.join.assert_called_once_with() + cert_rotate_thread_mock.join.assert_called_once_with() diff --git a/octavia/tests/unit/common/tls_utils/test_cert_parser.py b/octavia/tests/unit/common/tls_utils/test_cert_parser.py index 101be6e744..e015b4a290 100644 --- a/octavia/tests/unit/common/tls_utils/test_cert_parser.py +++ b/octavia/tests/unit/common/tls_utils/test_cert_parser.py @@ -12,6 +12,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. +import datetime import mock import six @@ -316,3 +317,11 @@ class TestTLSParseUtils(base.TestCase): cp.return_value = {'cn': 'fakeCN'} cn = cert_parser.get_primary_cn(cert) self.assertEqual('fakeCN', cn) + + def test_get_cert_expiration(self): + exp_date = cert_parser.get_cert_expiration(ALT_EXT_CRT) + self.assertEqual(datetime.datetime(2025, 5, 18, 20, 33, 23), exp_date) + + # test the exception + self.assertRaises(exceptions.UnreadableCert, + cert_parser.get_cert_expiration, 'bad-cert-file') \ No newline at end of file diff --git a/octavia/tests/unit/controller/housekeeping/test_house_keeping.py b/octavia/tests/unit/controller/housekeeping/test_house_keeping.py index 6702fc4712..c8370f38ce 100644 --- a/octavia/tests/unit/controller/housekeeping/test_house_keeping.py +++ b/octavia/tests/unit/controller/housekeeping/test_house_keeping.py @@ -24,6 +24,16 @@ import octavia.tests.unit.base as base CONF = cfg.CONF CONF.import_group('house_keeping', 'octavia.common.config') +AMPHORA_ID = uuidutils.generate_uuid() + + +class TestException(Exception): + + def __init__(self, value): + self.value = value + + def __str__(self): + return repr(self.value) class TestSpareCheck(base.TestCase): @@ -122,3 +132,63 @@ class TestDatabaseCleanup(base.TestCase): self.assertTrue(self.amp_repo.get_all.called) self.assertTrue(self.amp_health_repo.check_amphora_expired.called) self.assertFalse(self.amp_repo.delete.called) + + +class TestCertRotation(base.TestCase): + def setUp(self): + super(TestCertRotation, self).setUp() + + @mock.patch('octavia.controller.worker.controller_worker.' + 'ControllerWorker.amphora_cert_rotation') + @mock.patch('octavia.db.repositories.AmphoraRepository.' + 'get_cert_expiring_amphora') + @mock.patch('octavia.db.api.get_session') + def test_cert_rotation_expired_amphora_with_exception(self, session, + cert_exp_amp_mock, + amp_cert_mock + ): + amphora = mock.MagicMock() + amphora.id = AMPHORA_ID + + session.return_value = session + cert_exp_amp_mock.side_effect = [amphora, TestException( + 'break_while')] + + cr = house_keeping.CertRotation() + self.assertRaises(TestException, cr.rotate) + amp_cert_mock.assert_called_once_with(AMPHORA_ID) + + @mock.patch('octavia.controller.worker.controller_worker.' + 'ControllerWorker.amphora_cert_rotation') + @mock.patch('octavia.db.repositories.AmphoraRepository.' + 'get_cert_expiring_amphora') + @mock.patch('octavia.db.api.get_session') + def test_cert_rotation_expired_amphora_without_exception(self, session, + cert_exp_amp_mock, + amp_cert_mock + ): + amphora = mock.MagicMock() + amphora.id = AMPHORA_ID + + session.return_value = session + cert_exp_amp_mock.side_effect = [amphora, None] + + cr = house_keeping.CertRotation() + + self.assertEqual(None, cr.rotate()) + amp_cert_mock.assert_called_once_with(AMPHORA_ID) + + @mock.patch('octavia.controller.worker.controller_worker.' + 'ControllerWorker.amphora_cert_rotation') + @mock.patch('octavia.db.repositories.AmphoraRepository.' + 'get_cert_expiring_amphora') + @mock.patch('octavia.db.api.get_session') + def test_cert_rotation_non_expired_amphora(self, session, + cert_exp_amp_mock, + amp_cert_mock): + + session.return_value = session + cert_exp_amp_mock.return_value = None + cr = house_keeping.CertRotation() + cr.rotate() + self.assertFalse(amp_cert_mock.called) diff --git a/octavia/tests/unit/controller/worker/flows/test_amphora_flows.py b/octavia/tests/unit/controller/worker/flows/test_amphora_flows.py index 36fc3cc070..1d827775bf 100644 --- a/octavia/tests/unit/controller/worker/flows/test_amphora_flows.py +++ b/octavia/tests/unit/controller/worker/flows/test_amphora_flows.py @@ -133,3 +133,17 @@ class TestAmphoraFlows(base.TestCase): self.assertEqual(2, len(amp_flow.requires)) self.assertEqual(12, len(amp_flow.provides)) + + def test_cert_rotate_amphora_flow(self): + cfg.CONF.set_override('amphora_driver', 'amphora_haproxy_rest_driver', + group='controller_worker') + self.AmpFlow = amphora_flows.AmphoraFlows() + + amp_rotate_flow = self.AmpFlow.cert_rotate_amphora_flow() + self.assertIsInstance(amp_rotate_flow, flow.Flow) + + self.assertIn(constants.SERVER_PEM, amp_rotate_flow.provides) + self.assertIn(constants.AMPHORA, amp_rotate_flow.requires) + + self.assertEqual(1, len(amp_rotate_flow.provides)) + self.assertEqual(2, len(amp_rotate_flow.requires)) diff --git a/octavia/tests/unit/controller/worker/tasks/test_amphora_driver_tasks.py b/octavia/tests/unit/controller/worker/tasks/test_amphora_driver_tasks.py index 023d8eff80..5558c2a3d6 100644 --- a/octavia/tests/unit/controller/worker/tasks/test_amphora_driver_tasks.py +++ b/octavia/tests/unit/controller/worker/tasks/test_amphora_driver_tasks.py @@ -270,3 +270,17 @@ class TestAmphoraDriverTasks(base.TestCase): status=constants.ERROR) self.assertIsNone(amp) + + def test_amphora_cert_upload(self, + mock_driver, + mock_generate_uuid, + mock_log, + mock_get_session, + mock_listener_repo_update, + mock_amphora_repo_update): + pem_file_mock = 'test-perm-file' + amphora_cert_upload_mock = amphora_driver_tasks.AmphoraCertUpload() + amphora_cert_upload_mock.execute(_amphora_mock, pem_file_mock) + + mock_driver.upload_cert_amp.assert_called_once_with( + _amphora_mock, pem_file_mock) diff --git a/octavia/tests/unit/controller/worker/tasks/test_database_tasks.py b/octavia/tests/unit/controller/worker/tasks/test_database_tasks.py index 8ee94868b4..91f857644b 100644 --- a/octavia/tests/unit/controller/worker/tasks/test_database_tasks.py +++ b/octavia/tests/unit/controller/worker/tasks/test_database_tasks.py @@ -44,6 +44,43 @@ _pool_mock.id = POOL_ID _listener_mock = mock.MagicMock() _listener_mock.id = LISTENER_ID _tf_failure_mock = mock.Mock(spec=failure.Failure) +_cert_mock = mock.MagicMock() +_pem_mock = """Junk +-----BEGIN CERTIFICATE----- +MIIBhDCCAS6gAwIBAgIGAUo7hO/eMA0GCSqGSIb3DQEBCwUAMA8xDTALBgNVBAMT +BElNRDIwHhcNMTQxMjExMjI0MjU1WhcNMjUxMTIzMjI0MjU1WjAPMQ0wCwYDVQQD +EwRJTUQzMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAKHIPXo2pfD5dpnpVDVz4n43 +zn3VYsjz/mgOZU0WIWjPA97mvulb7mwb4/LB4ijOMzHj9XfwP75GiOFxYFs8O80C +AwEAAaNwMG4wDwYDVR0TAQH/BAUwAwEB/zA8BgNVHSMENTAzgBS6rfnABCO3oHEz +NUUtov2hfXzfVaETpBEwDzENMAsGA1UEAxMESU1EMYIGAUo7hO/DMB0GA1UdDgQW +BBRiLW10LVJiFO/JOLsQFev0ToAcpzANBgkqhkiG9w0BAQsFAANBABtdF+89WuDi +TC0FqCocb7PWdTucaItD9Zn55G8KMd93eXrOE/FQDf1ScC+7j0jIHXjhnyu6k3NV +8el/x5gUHlc= +-----END CERTIFICATE----- +Junk should be ignored by x509 splitter +-----BEGIN CERTIFICATE----- +MIIBhDCCAS6gAwIBAgIGAUo7hO/DMA0GCSqGSIb3DQEBCwUAMA8xDTALBgNVBAMT +BElNRDEwHhcNMTQxMjExMjI0MjU1WhcNMjUxMTIzMjI0MjU1WjAPMQ0wCwYDVQQD +EwRJTUQyMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAJYHqnsisVKTlwVaCSa2wdrv +CeJJzqpEVV0RVgAAF6FXjX2Tioii+HkXMR9zFgpE1w4yD7iu9JDb8yTdNh+NxysC +AwEAAaNwMG4wDwYDVR0TAQH/BAUwAwEB/zA8BgNVHSMENTAzgBQt3KvN8ncGj4/s +if1+wdvIMCoiE6ETpBEwDzENMAsGA1UEAxMEcm9vdIIGAUo7hO+mMB0GA1UdDgQW +BBS6rfnABCO3oHEzNUUtov2hfXzfVTANBgkqhkiG9w0BAQsFAANBAIlJODvtmpok +eoRPOb81MFwPTTGaIqafebVWfBlR0lmW8IwLhsOUdsQqSzoeypS3SJUBpYT1Uu2v +zEDOmgdMsBY= +-----END CERTIFICATE----- +Junk should be thrown out like junk +-----BEGIN CERTIFICATE----- +MIIBfzCCASmgAwIBAgIGAUo7hO+mMA0GCSqGSIb3DQEBCwUAMA8xDTALBgNVBAMT +BHJvb3QwHhcNMTQxMjExMjI0MjU1WhcNMjUxMTIzMjI0MjU1WjAPMQ0wCwYDVQQD +EwRJTUQxMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAI+tSJxr60ogwXFmgqbLMW7K +3fkQnh9sZBi7Qo6AzUnfe/AhXoisib651fOxKXCbp57IgzLTv7O9ygq3I+5fQqsC +AwEAAaNrMGkwDwYDVR0TAQH/BAUwAwEB/zA3BgNVHSMEMDAugBR73ZKSpjbsz9tZ +URkvFwpIO7gB4KETpBEwDzENMAsGA1UEAxMEcm9vdIIBATAdBgNVHQ4EFgQULdyr +zfJ3Bo+P7In9fsHbyDAqIhMwDQYJKoZIhvcNAQELBQADQQBenkZ2k7RgZqgj+dxA +D7BF8MN1oUAOpyYqAjkGddSEuMyNmwtHKZI1dyQ0gBIQdiU9yAG2oTbUIK4msbBV +uJIQ +-----END CERTIFICATE-----""" @mock.patch('octavia.db.repositories.AmphoraRepository.delete') @@ -92,7 +129,8 @@ class TestDatabaseTasks(base.TestCase): repo.AmphoraRepository.create.assert_called_once_with( 'TEST', id=AMP_ID, - status=constants.PENDING_CREATE) + status=constants.PENDING_CREATE, + cert_busy=False) assert(amp_id == _amphora_mock.id) @@ -616,6 +654,41 @@ class TestDatabaseTasks(base.TestCase): LB_ID, provisioning_status=constants.ERROR) + @mock.patch('octavia.common.tls_utils.cert_parser.get_cert_expiration', + return_value=_cert_mock) + def test_update_amphora_db_cert_exp(self, + mock_generate_uuid, + mock_LOG, + mock_get_session, + mock_loadbalancer_repo_update, + mock_listener_repo_update, + mock_amphora_repo_update, + mock_amphora_repo_delete, + mock_get_cert_exp): + + update_amp_cert = database_tasks.UpdateAmphoraDBCertExpiration() + update_amp_cert.execute(_amphora_mock.id, _pem_mock) + + repo.AmphoraRepository.update.assert_called_once_with( + 'TEST', + AMP_ID, + cert_expiration=_cert_mock) + + def test_update_amphora_cert_busy_to_false(self, + mock_generate_uuid, + mock_LOG, + mock_get_session, + mock_loadbalancer_repo_update, + mock_listener_repo_update, + mock_amphora_repo_update, + mock_amphora_repo_delete): + amp_cert_busy_to_F = database_tasks.UpdateAmphoraCertBusyToFalse() + amp_cert_busy_to_F.execute(_amphora_mock) + repo.AmphoraRepository.update.assert_called_once_with( + 'TEST', + AMP_ID, + cert_busy=False) + def test_mark_LB_active_in_db(self, mock_generate_uuid, mock_LOG, diff --git a/octavia/tests/unit/controller/worker/test_controller_worker.py b/octavia/tests/unit/controller/worker/test_controller_worker.py index d0d357c19b..4f023d6097 100644 --- a/octavia/tests/unit/controller/worker/test_controller_worker.py +++ b/octavia/tests/unit/controller/worker/test_controller_worker.py @@ -44,6 +44,7 @@ _member_mock = mock.MagicMock() _pool_mock = mock.MagicMock() _create_map_flow_mock = mock.MagicMock() _amphora_mock.load_balancer_id = LB_ID +_amphora_mock.id = AMP_ID @mock.patch('octavia.db.repositories.AmphoraRepository.get', @@ -681,3 +682,27 @@ class TestControllerWorker(base.TestCase): _amphora_mock.load_balancer_id})) _flow_mock.run.assert_called_once_with() + + @mock.patch('octavia.controller.worker.flows.' + 'amphora_flows.AmphoraFlows.cert_rotate_amphora_flow', + return_value=_flow_mock) + def test_amphora_cert_rotation(self, + mock_get_update_listener_flow, + mock_api_get_session, + mock_dyn_log_listener, + mock_taskflow_load, + mock_pool_repo_get, + mock_member_repo_get, + mock_listener_repo_get, + mock_lb_repo_get, + mock_health_mon_repo_get, + mock_amp_repo_get): + _flow_mock.reset_mock() + cw = controller_worker.ControllerWorker() + cw.amphora_cert_rotation(AMP_ID) + (base_taskflow.BaseTaskFlowEngine._taskflow_load. + assert_called_once_with(_flow_mock, + store={constants.AMPHORA: _amphora_mock, + constants.AMPHORA_ID: + _amphora_mock.id})) + _flow_mock.run.assert_called_once_with()