Add cert tracking and rotating in Housekeeping

The goal of this patch is to add the function that once we detect an
amphora's cert will expire in 2 weeks from utcnow, we will update its
cert with a new one and update its db information at the same time.

In order to achieve this target, I did the following changes:

Add 2 new columns cert_busy and cert_expiration in amphora table
Add methods to get cert expiration date from PEM server_pem and
update db info
Use the new REST agent method to perform cycling
Add process in housekeeping to facilitate rotation
Add unit tests

Change-Id: I28578a3e560ee09ba300788a5423863c893b8638
This commit is contained in:
minwang 2015-08-20 15:02:34 -07:00
parent 900e8a5256
commit 19c7f93882
27 changed files with 657 additions and 15 deletions

View File

@ -185,6 +185,19 @@ class AmphoraLoadBalancerDriver(object):
"""
pass
def upload_cert_amp(self, amphora, pem_file):
"""upload cert info to amphora
:param amphora: amphora object, needs id and network ip(s)
:type amphora: object
:param pem_file: a certificate file
:type pem_file: file object
upload cert file to amphora for Controller Communication
"""
pass
@six.add_metaclass(abc.ABCMeta)
class HealthMixin(object):

View File

@ -79,6 +79,12 @@ class HaproxyAmphoraLoadBalancerDriver(driver_base.AmphoraLoadBalancerDriver):
else:
self.client.start_listener(amp, listener.id)
def upload_cert_amp(self, amp, pem):
LOG.debug("Amphora %s updating cert in REST driver "
"with amphora id %s,",
self.__class__.__name__, amp.id)
self.client.update_cert_for_rotation(amp, pem)
def _apply(self, func, listener=None, *args):
for amp in listener.load_balancer.amphorae:
if amp.status != constants.DELETED:
@ -250,6 +256,10 @@ class AmphoraAPIClient(object):
data=pem_file)
return exc.check_exception(r)
def update_cert_for_rotation(self, amp, pem_file):
r = self.put(amp, 'certificate', data=pem_file)
return exc.check_exception(r)
def get_cert_md5sum(self, amp, listener_id, pem_filename):
r = self.get(amp,
'listeners/{listener_id}/certificates/{filename}'.format(

View File

@ -92,6 +92,12 @@ class NoopManager(object):
self.amphoraconfig[(load_balancer.id, id(amphorae_network_config))] = (
load_balancer.id, amphorae_network_config, 'post_vip_plug')
def upload_cert_amp(self, amphora, pem_file):
LOG.debug("Amphora %s no-op, upload cert amphora %s,with pem fle %s",
self.__class__.__name__, amphora.id, pem_file)
self.amphoraconfig[amphora.id, pem_file] = (amphora.id, pem_file,
'update_amp_cert_file')
class NoopAmphoraLoadBalancerDriver(driver_base.AmphoraLoadBalancerDriver):
def __init__(self):
@ -133,3 +139,7 @@ class NoopAmphoraLoadBalancerDriver(driver_base.AmphoraLoadBalancerDriver):
def post_vip_plug(self, load_balancer, amphorae_network_config):
self.driver.post_vip_plug(load_balancer, amphorae_network_config)
def upload_cert_amp(self, amphora, pem_file):
self.driver.upload_cert_amp(amphora, pem_file)

View File

@ -34,6 +34,7 @@ CONF.import_group('house_keeping', 'octavia.common.config')
spare_amp_thread_event = threading.Event()
db_cleanup_thread_event = threading.Event()
cert_rotate_thread_event = threading.Event()
def spare_amphora_check():
@ -65,6 +66,18 @@ def db_cleanup():
time.sleep(interval)
def cert_rotation():
"""Perform certificate rotation."""
interval = CONF.house_keeping.cert_interval
LOG.info(
_LI("Expiring certificate check interval is set to %d sec") % interval)
cert_rotate = house_keeping.CertRotation()
while cert_rotate_thread_event.is_set():
LOG.debug("Initiating certification rotation ...")
cert_rotate.rotate()
time.sleep(interval)
def main():
service.prepare_service(sys.argv)
@ -85,6 +98,12 @@ def main():
db_cleanup_thread_event.set()
db_cleanup_thread.start()
# Thread to perform certificate rotation
cert_rotate_thread = threading.Thread(target=cert_rotation)
cert_rotate_thread.daemon = True
cert_rotate_thread_event.set()
cert_rotate_thread.start()
# Try-Exception block should be at the end to gracefully exit threads
try:
while True:
@ -93,6 +112,8 @@ def main():
LOG.info(_LI("Attempting to gracefully terminate House-Keeping"))
spare_amp_thread_event.clear()
db_cleanup_thread_event.clear()
cert_rotate_thread_event.clear()
spare_amp_thread.join()
db_cleanup_thread.join()
cert_rotate_thread.join()
LOG.info(_LI("House-Keeping process terminated"))

View File

@ -274,7 +274,18 @@ house_keeping_opts = [
help=_('DB cleanup interval in seconds')),
cfg.IntOpt('amphora_expiry_age',
default=604800,
help=_('Amphora expiry age in seconds'))
help=_('Amphora expiry age in seconds')),
cfg.IntOpt('cert_interval',
default=3600,
help=_('Certificate check interval in seconds')),
# 14 days for cert expiry buffer
cfg.IntOpt('cert_expiry_buffer',
default=1209600,
help=_('Seconds until certificate expiration')),
cfg.IntOpt('cert_rotate_threads',
default=10,
help=_('Number of threads performing amphora certificate'
' rotation'))
]

View File

@ -103,6 +103,7 @@ ADDED_PORTS = 'added_ports'
PORTS = 'ports'
MEMBER_PORTS = 'member_ports'
CERT_ROTATE_AMPHORA_FLOW = 'octavia-cert-rotate-amphora-flow'
CREATE_AMPHORA_FLOW = 'octavia-create-amphora-flow'
CREATE_AMPHORA_FOR_LB_FLOW = 'octavia-create-amp-for-lb-flow'
CREATE_HEALTH_MONITOR_FLOW = 'octavia-create-health-monitor-flow'

View File

@ -225,7 +225,8 @@ class Amphora(BaseDataModel):
def __init__(self, id=None, load_balancer_id=None, compute_id=None,
status=None, lb_network_ip=None, vrrp_ip=None,
ha_ip=None, vrrp_port_id=None, ha_port_id=None,
load_balancer=None, role=None):
load_balancer=None, role=None, cert_expiration=None,
cert_busy=False):
self.id = id
self.load_balancer_id = load_balancer_id
self.compute_id = compute_id
@ -237,6 +238,8 @@ class Amphora(BaseDataModel):
self.ha_port_id = ha_port_id
self.role = role
self.load_balancer = load_balancer
self.cert_expiration = cert_expiration
self.cert_busy = cert_busy
class AmphoraHealth(BaseDataModel):

View File

@ -107,6 +107,7 @@ def get_host_names(certificate):
"""
try:
certificate = certificate.encode('ascii')
cert = x509.load_pem_x509_certificate(certificate,
backends.default_backend())
cn = cert.subject.get_attributes_for_oid(x509.OID_COMMON_NAME)[0]
@ -130,6 +131,23 @@ def get_host_names(certificate):
raise exceptions.UnreadableCert
def get_cert_expiration(certificate_pem):
"""Extract the expiration date from the Pem encoded X509 certificate
:param certificate_pem: Certificate in PEM format
:returns: Expiration date of certificate_pem
"""
try:
certificate = certificate_pem.encode('ascii')
cert = x509.load_pem_x509_certificate(certificate,
backends.default_backend())
return cert.not_valid_after
except Exception as e:
LOG.exception(e)
raise exceptions.UnreadableCert
def _get_x509_from_pem_bytes(certificate_pem):
"""Parse X509 data from a PEM encoded certificate

View File

@ -14,6 +14,7 @@
import datetime
from concurrent import futures
from oslo_config import cfg
from oslo_log import log as logging
@ -78,4 +79,31 @@ class DatabaseCleanup(object):
exp_age):
LOG.info(_LI('Attempting to delete Amphora id : %s'), amp.id)
self.amp_repo.delete(session, id=amp.id)
LOG.info(_LI('Deleted Amphora id : %s'), amp.id)
LOG.info(_LI('Deleted Amphora id : %s') % amp.id)
class CertRotation(object):
def __init__(self):
self.threads = CONF.house_keeping.cert_rotate_threads
self.cw = cw.ControllerWorker()
def rotate(self):
"""Check the amphora db table for expiring auth certs."""
amp_repo = repo.AmphoraRepository()
with futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
try:
session = db_api.get_session()
rotation_count = 0
while True:
amp = amp_repo.get_cert_expiring_amphora(session)
if not amp:
break
rotation_count += 1
LOG.debug("Cert expired amphora's id is: %s", amp.id)
executor.submit(self.cw.amphora_cert_rotation, amp.id)
if rotation_count > 0:
LOG.info(_LI("Rotated certificates for %s ampohra") %
rotation_count)
finally:
executor.shutdown(wait=True)

View File

@ -26,6 +26,7 @@ from octavia.controller.worker.flows import member_flows
from octavia.controller.worker.flows import pool_flows
from octavia.db import api as db_apis
from octavia.db import repositories as repo
from octavia.i18n import _LI
from taskflow.listeners import logging as tf_logging
@ -486,3 +487,25 @@ class ControllerWorker(base_taskflow.BaseTaskFlowEngine):
with tf_logging.DynamicLoggingListener(failover_amphora_tf,
log=LOG):
failover_amphora_tf.run()
def amphora_cert_rotation(self, amphora_id):
"""Perform cert rotation for an amphora.
:param amphora_id: ID for amphora to rotate
:returns: None
:raises AmphoraNotFound: The referenced amphora was not found
"""
amp = self._amphora_repo.get(db_apis.get_session(),
id=amphora_id)
LOG.info(_LI("Start amphora cert rotation, amphora's id is: %s")
% amp.id)
certrotation_amphora_tf = self._taskflow_load(
self._amphora_flows.cert_rotate_amphora_flow(),
store={constants.AMPHORA: amp,
constants.AMPHORA_ID: amp.id})
with tf_logging.DynamicLoggingListener(certrotation_amphora_tf,
log=LOG):
certrotation_amphora_tf.run()

View File

@ -25,7 +25,6 @@ from octavia.controller.worker.tasks import compute_tasks
from octavia.controller.worker.tasks import database_tasks
from octavia.controller.worker.tasks import network_tasks
CONF = cfg.CONF
CONF.import_group('controller_worker', 'octavia.common.config')
@ -52,6 +51,11 @@ class AmphoraFlows(object):
if self.REST_AMPHORA_DRIVER:
create_amphora_flow.add(cert_task.GenerateServerPEMTask(
provides=constants.SERVER_PEM))
create_amphora_flow.add(
database_tasks.UpdateAmphoraDBCertExpiration(
requires=(constants.AMPHORA_ID, constants.SERVER_PEM)))
create_amphora_flow.add(compute_tasks.CertComputeCreate(
requires=(constants.AMPHORA_ID, constants.SERVER_PEM),
provides=constants.COMPUTE_ID))
@ -98,6 +102,11 @@ class AmphoraFlows(object):
if self.REST_AMPHORA_DRIVER:
create_amp_for_lb_flow.add(cert_task.GenerateServerPEMTask(
provides=constants.SERVER_PEM))
create_amp_for_lb_flow.add(
database_tasks.UpdateAmphoraDBCertExpiration(
requires=(constants.AMPHORA_ID, constants.SERVER_PEM)))
create_amp_for_lb_flow.add(compute_tasks.CertComputeCreate(
requires=(constants.AMPHORA_ID, constants.SERVER_PEM),
provides=constants.COMPUTE_ID))
@ -249,3 +258,35 @@ class AmphoraFlows(object):
requires=(constants.AMPHORA, constants.LOADBALANCER_ID)))
return failover_amphora_flow
def cert_rotate_amphora_flow(self):
"""Implement rotation for amphora's cert.
1. Create a new certificate
2. Upload the cert to amphora
3. update the newly created certificate info to amphora
4. update the cert_busy flag to be false after rotation
:returns: The flow for updating an amphora
"""
rotated_amphora_flow = linear_flow.Flow(
constants.CERT_ROTATE_AMPHORA_FLOW)
# create a new certificate, the returned value is the newly created
# certificate
rotated_amphora_flow.add(cert_task.GenerateServerPEMTask(
provides=constants.SERVER_PEM))
# update it in amphora task
rotated_amphora_flow.add(amphora_driver_tasks.AmphoraCertUpload(
requires=(constants.AMPHORA, constants.SERVER_PEM)))
# update the newly created certificate info to amphora
rotated_amphora_flow.add(database_tasks.UpdateAmphoraDBCertExpiration(
requires=(constants.AMPHORA_ID, constants.SERVER_PEM)))
# update the cert_busy flag to be false after rotation
rotated_amphora_flow.add(database_tasks.UpdateAmphoraCertBusyToFalse(
requires=constants.AMPHORA))
return rotated_amphora_flow

View File

@ -239,3 +239,12 @@ class AmphoraPostVIPPlug(BaseAmphoraTask):
self.loadbalancer_repo.update(db_apis.get_session(),
id=loadbalancer.id,
status=constants.ERROR)
class AmphoraCertUpload(BaseAmphoraTask):
"""Upload a certificate to the amphora."""
def execute(self, amphora, server_pem):
"""Execute cert_update_amphora routine."""
LOG.debug("Upload cert in amphora REST driver")
self.amphora_driver.upload_cert_amp(amphora, server_pem)

View File

@ -22,10 +22,12 @@ from taskflow.types import failure
from octavia.common import constants
from octavia.common import data_models
from octavia.common import exceptions
import octavia.common.tls_utils.cert_parser as cert_parser
from octavia.db import api as db_apis
from octavia.db import repositories as repo
from octavia.i18n import _LI, _LW
LOG = logging.getLogger(__name__)
@ -54,7 +56,8 @@ class CreateAmphoraInDB(BaseDatabaseTask):
amphora = self.amphora_repo.create(db_apis.get_session(),
id=uuidutils.generate_uuid(),
status=constants.PENDING_CREATE)
status=constants.PENDING_CREATE,
cert_busy=False)
LOG.info(_LI("Created Amphora in DB with id %s"), amphora.id)
return amphora.id
@ -474,6 +477,24 @@ class UpdateAmphoraInfo(BaseDatabaseTask):
return self.amphora_repo.get(db_apis.get_session(), id=amphora_id)
class UpdateAmphoraDBCertExpiration(BaseDatabaseTask):
"""Update the amphora expiration date with new cert file date."""
def execute(self, amphora_id, server_pem):
LOG.debug("Update DB cert expiry date of amphora id: %s", amphora_id)
cert_expiration = cert_parser.get_cert_expiration(server_pem)
LOG.debug("Certificate expiration date is %s ", cert_expiration)
self.amphora_repo.update(db_apis.get_session(), amphora_id,
cert_expiration=cert_expiration)
class UpdateAmphoraCertBusyToFalse(BaseDatabaseTask):
"""Update the amphora cert_busy flag to be false."""
def execute(self, amphora):
self.amphora_repo.update(db_apis.get_session(), amphora.id,
cert_busy=False)
class MarkLBActiveInDB(BaseDatabaseTask):
"""Mark the load balancer active in the DB.

View File

@ -0,0 +1,37 @@
# Copyright 2015 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""add cert expiration infor in amphora table
Revision ID: 5a3ee5472c31
Revises: 3b199c848b96
Create Date: 2015-08-20 10:15:19.561066
"""
# revision identifiers, used by Alembic.
revision = '5a3ee5472c31'
down_revision = '3b199c848b96'
from alembic import op
import sqlalchemy as sa
def upgrade():
op.add_column(u'amphora',
sa.Column(u'cert_expiration', sa.DateTime(timezone=True),
nullable=True)
)
op.add_column(u'amphora', sa.Column(u'cert_busy', sa.Boolean(),
nullable=False, default=False))

View File

@ -340,6 +340,10 @@ class Amphora(base_models.BASE):
ha_ip = sa.Column(sa.String(64), nullable=True)
vrrp_port_id = sa.Column(sa.String(36), nullable=True)
ha_port_id = sa.Column(sa.String(36), nullable=True)
cert_expiration = sa.Column(sa.DateTime(timezone=True), default=None,
nullable=True)
cert_busy = sa.Column(sa.Boolean(), default=False, nullable=False)
role = sa.Column(
sa.String(36),
sa.ForeignKey("amphora_roles.name", name="fk_amphora_roles_name"),

View File

@ -27,9 +27,11 @@ from octavia.common import constants
from octavia.common import exceptions
from octavia.db import models
CONF = cfg.CONF
LOG = logging.getLogger(__name__)
CONF.import_group('health_manager', 'octavia.common.config')
CONF.import_group('house_keeping', 'octavia.common.config')
class BaseRepository(object):
@ -402,6 +404,30 @@ class AmphoraRepository(BaseRepository):
return count
def get_cert_expiring_amphora(self, session):
"""Retrieves an amphora whose cert is close to expiring..
:param session: A Sql Alchemy database session.
:returns: one amphora with expiring certificate
"""
# get amphorae with certs that will expire within the
# configured buffer period, so we can rotate their certs ahead of time
expired_seconds = CONF.house_keeping.cert_expiry_buffer
expired_date = datetime.datetime.utcnow() + datetime.timedelta(
seconds=expired_seconds)
with session.begin(subtransactions=True):
amp = session.query(self.model_class).with_for_update().filter_by(
cert_busy=False).filter(
self.model_class.cert_expiration < expired_date).first()
if amp is None:
return None
amp.cert_busy = True
return amp.to_data_model()
class SNIRepository(BaseRepository):
model_class = models.SNI

View File

@ -123,7 +123,9 @@ class ModelTestMixin(object):
'ha_ip': self.FAKE_IP,
'vrrp_port_id': self.FAKE_UUID_1,
'ha_port_id': self.FAKE_UUID_2,
'lb_network_ip': self.FAKE_IP}
'lb_network_ip': self.FAKE_IP,
'cert_expiration': datetime.datetime.utcnow(),
'cert_busy': False}
kwargs.update(overrides)
return self._insert(session, models.Amphora, kwargs)

View File

@ -15,6 +15,8 @@
import datetime
import random
from oslo_config import cfg
from oslo_log import log as logging
from oslo_utils import uuidutils
from octavia.common import constants
@ -22,6 +24,10 @@ from octavia.common import data_models as models
from octavia.db import repositories as repo
from octavia.tests.functional.db import base
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
CONF.import_group('house_keeping', 'octavia.common.config')
class BaseRepositoryTest(base.OctaviaDBTestBase):
@ -1108,13 +1114,16 @@ class AmphoraRepositoryTest(BaseRepositoryTest):
operating_status=constants.ONLINE, enabled=True)
def create_amphora(self, amphora_id):
expiration = datetime.datetime.utcnow()
amphora = self.amphora_repo.create(self.session, id=amphora_id,
compute_id=self.FAKE_UUID_3,
status=constants.ACTIVE,
lb_network_ip=self.FAKE_IP,
vrrp_ip=self.FAKE_IP,
ha_ip=self.FAKE_IP,
role=constants.ROLE_MASTER)
role=constants.ROLE_MASTER,
cert_expiration=expiration,
cert_busy=False)
return amphora
def test_get(self):
@ -1208,6 +1217,37 @@ class AmphoraRepositoryTest(BaseRepositoryTest):
count = self.amphora_repo.get_spare_amphora_count(self.session)
self.assertEqual(2, count)
def test_get_none_cert_expired_amphora(self):
# test with no expired amphora
amp = self.amphora_repo.get_cert_expiring_amphora(self.session)
self.assertIsNone(amp)
amphora = self.create_amphora(self.FAKE_UUID_1)
expired_interval = CONF.house_keeping.cert_expiry_buffer
expiration = datetime.datetime.utcnow() + datetime.timedelta(
seconds=2 * expired_interval)
self.amphora_repo.update(self.session, amphora.id,
cert_expiration=expiration)
amp = self.amphora_repo.get_cert_expiring_amphora(self.session)
self.assertIsNone(amp)
def test_get_cert_expired_amphora(self):
# test with expired amphora
amphora2 = self.create_amphora(self.FAKE_UUID_2)
expiration = datetime.datetime.utcnow() + datetime.timedelta(
seconds=1)
self.amphora_repo.update(self.session, amphora2.id,
cert_expiration=expiration)
cert_expired_amphora = self.amphora_repo.get_cert_expiring_amphora(
self.session)
self.assertEqual(cert_expired_amphora.cert_expiration, expiration)
self.assertEqual(cert_expired_amphora.id, amphora2.id)
class AmphoraHealthRepositoryTest(BaseRepositoryTest):
def setUp(self):

View File

@ -16,6 +16,7 @@
import mock
from oslo_utils import uuidutils
import requests_mock
import six
from octavia.amphorae.drivers.haproxy import exceptions as exc
from octavia.amphorae.drivers.haproxy import rest_api_driver as driver
@ -34,10 +35,10 @@ FAKE_SUBNET_INFO = {'subnet_cidr': FAKE_CIDR,
FAKE_UUID_1 = uuidutils.generate_uuid()
class HaproxyAmphoraLoadBalancerDriverTest(base.TestCase):
class TestHaproxyAmphoraLoadBalancerDriverTest(base.TestCase):
def setUp(self):
super(HaproxyAmphoraLoadBalancerDriverTest, self).setUp()
super(TestHaproxyAmphoraLoadBalancerDriverTest, self).setUp()
self.driver = driver.HaproxyAmphoraLoadBalancerDriver()
self.driver.cert_manager = mock.MagicMock()
@ -110,6 +111,11 @@ class HaproxyAmphoraLoadBalancerDriverTest(base.TestCase):
self.driver.client.start_listener.assert_called_once_with(
self.amp, self.sl.id)
def test_upload_cert_amp(self):
self.driver.upload_cert_amp(self.amp, six.b('test'))
self.driver.client.update_cert_for_rotation.assert_called_once_with(
self.amp, six.b('test'))
def test_stop(self):
# Execute driver method
self.driver.stop(self.sl, self.sv)
@ -152,10 +158,10 @@ class HaproxyAmphoraLoadBalancerDriverTest(base.TestCase):
self.amp, dict(mac_address='123'))
class AmphoraAPIClientTest(base.TestCase):
class TestAmphoraAPIClientTest(base.TestCase):
def setUp(self):
super(AmphoraAPIClientTest, self).setUp()
super(TestAmphoraAPIClientTest, self).setUp()
self.driver = driver.AmphoraAPIClient()
self.base_url = "https://127.0.0.1:9443/0.5"
self.amp = models.Amphora(lb_network_ip='127.0.0.1', compute_id='123')
@ -479,6 +485,41 @@ class AmphoraAPIClientTest(base.TestCase):
self.amp, FAKE_UUID_1, FAKE_PEM_FILENAME,
"some_file")
@requests_mock.mock()
def test_update_cert_for_rotation(self, m):
m.put("{base}/certificate".format(base=self.base_url))
resp_body = self.driver.update_cert_for_rotation(self.amp,
"some_file")
self.assertEqual(200, resp_body.status_code)
@requests_mock.mock()
def test_update_invalid_cert_for_rotation(self, m):
m.put("{base}/certificate".format(base=self.base_url), status_code=403)
self.assertRaises(exc.InvalidRequest,
self.driver.update_cert_for_rotation, self.amp,
"some_file")
@requests_mock.mock()
def test_update_cert_for_rotation_unauthorized(self, m):
m.put("{base}/certificate".format(base=self.base_url), status_code=401)
self.assertRaises(exc.Unauthorized,
self.driver.update_cert_for_rotation, self.amp,
"some_file")
@requests_mock.mock()
def test_update_cert_for_rotation_error(self, m):
m.put("{base}/certificate".format(base=self.base_url), status_code=500)
self.assertRaises(exc.InternalServerError,
self.driver.update_cert_for_rotation, self.amp,
"some_file")
@requests_mock.mock()
def test_update_cert_for_rotation_unavailable(self, m):
m.put("{base}/certificate".format(base=self.base_url), status_code=503)
self.assertRaises(exc.ServiceUnavailable,
self.driver.update_cert_for_rotation, self.amp,
"some_file")
@requests_mock.mock()
def test_get_cert_5sum(self, m):
md5sum = {"md5sum": "some_real_sum"}

View File

@ -62,6 +62,7 @@ class TestNoopAmphoraLoadBalancerDriver(base.TestCase):
amphora=self.amphora,
vip_subnet=network_models.Subnet(id=self.FAKE_UUID_1))
}
self.pem_file = 'test_pem_file'
def test_update(self):
self.driver.update(self.listener, self.vip)
@ -125,3 +126,10 @@ class TestNoopAmphoraLoadBalancerDriver(base.TestCase):
self.load_balancer.id, id(self.amphorae_net_configs)
)]
self.assertEqual(expected_method_and_args, actual_method_and_args)
def test_upload_cert_amp(self):
self.driver.upload_cert_amp(self.amphora, self.pem_file)
self.assertEqual(
(self.amphora.id, self.pem_file, 'update_amp_cert_file'),
self.driver.driver.amphoraconfig[(
self.amphora.id, self.pem_file)])

View File

@ -64,23 +64,79 @@ class TestHouseKeepingCMD(base.TestCase):
mock_DatabaseCleanup.assert_called_once_with()
self.assertEqual(1, db_cleanup.delete_old_amphorae.call_count)
@mock.patch('octavia.cmd.house_keeping.cert_rotate_thread_event')
@mock.patch('octavia.controller.housekeeping.'
'house_keeping.CertRotation')
@mock.patch('time.sleep')
def test_hk_cert_rotation_with_exception(self, sleep_mock,
mock_CertRotation,
cert_rotate_event_mock):
# mock cert_rotate object
cert_rotate_mock = mock.MagicMock()
# mock rotate()
rotate_mock = mock.MagicMock()
cert_rotate_mock.rotate = rotate_mock
mock_CertRotation.return_value = cert_rotate_mock
# mock cert_rotate_thread_event.is_set() in the while loop
cert_rotate_event_mock.is_set = mock.MagicMock()
cert_rotate_event_mock.is_set.side_effect = [True, Exception('break')]
self.assertRaisesRegexp(Exception, 'break',
house_keeping.cert_rotation)
mock_CertRotation.assert_called_once_with()
self.assertEqual(1, cert_rotate_mock.rotate.call_count)
@mock.patch('octavia.cmd.house_keeping.cert_rotate_thread_event')
@mock.patch('octavia.controller.housekeeping.'
'house_keeping.CertRotation')
@mock.patch('time.sleep')
def test_hk_cert_rotation_without_exception(self, sleep_mock,
mock_CertRotation,
cert_rotate_event_mock):
# mock cert_rotate object
cert_rotate_mock = mock.MagicMock()
# mock rotate()
rotate_mock = mock.MagicMock()
cert_rotate_mock.rotate = rotate_mock
mock_CertRotation.return_value = cert_rotate_mock
# mock cert_rotate_thread_event.is_set() in the while loop
cert_rotate_event_mock.is_set = mock.MagicMock()
cert_rotate_event_mock.is_set.side_effect = [True, None]
self.assertEqual(None, house_keeping.cert_rotation())
mock_CertRotation.assert_called_once_with()
self.assertEqual(1, cert_rotate_mock.rotate.call_count)
@mock.patch('time.sleep')
@mock.patch('octavia.cmd.house_keeping.cert_rotate_thread_event')
@mock.patch('octavia.cmd.house_keeping.db_cleanup_thread_event')
@mock.patch('octavia.cmd.house_keeping.spare_amp_thread_event')
@mock.patch('threading.Thread')
@mock.patch('octavia.common.service.prepare_service')
def test_main(self, mock_service, mock_thread,
spare_amp_thread_event_mock,
db_cleanup_thread_event_mock, sleep_time):
db_cleanup_thread_event_mock,
cert_rotate_thread_event_mock, sleep_time):
spare_amp_thread_mock = mock.MagicMock()
db_cleanup_thread_mock = mock.MagicMock()
cert_rotate_thread_mock = mock.MagicMock()
mock_thread.side_effect = [spare_amp_thread_mock,
db_cleanup_thread_mock]
db_cleanup_thread_mock,
cert_rotate_thread_mock]
spare_amp_thread_mock.daemon.return_value = True
db_cleanup_thread_mock.daemon.return_value = True
cert_rotate_thread_mock.daemon.return_value = True
# mock the time.sleep() in the while loop
sleep_time.side_effect = [True, Exception('break')]
@ -88,14 +144,18 @@ class TestHouseKeepingCMD(base.TestCase):
spare_amp_thread_event_mock.set.assert_called_once_with()
db_cleanup_thread_event_mock.set.assert_called_once_with()
cert_rotate_thread_event_mock.set.assert_called_once_with()
spare_amp_thread_mock.start.assert_called_once_with()
db_cleanup_thread_mock.start.assert_called_once_with()
cert_rotate_thread_mock.start.assert_called_once_with()
self.assertTrue(spare_amp_thread_mock.daemon)
self.assertTrue(db_cleanup_thread_mock.daemon)
self.assertTrue(cert_rotate_thread_mock.daemon)
@mock.patch('time.sleep')
@mock.patch('octavia.cmd.house_keeping.cert_rotate_thread_event')
@mock.patch('octavia.cmd.house_keeping.db_cleanup_thread_event')
@mock.patch('octavia.cmd.house_keeping.spare_amp_thread_event')
@mock.patch('threading.Thread')
@ -103,15 +163,19 @@ class TestHouseKeepingCMD(base.TestCase):
def test_main_keyboard_interrupt(self, mock_service, mock_thread,
spare_amp_thread_event_mock,
db_cleanup_thread_event_mock,
cert_rotate_thread_event_mock,
sleep_time):
spare_amp_thread_mock = mock.MagicMock()
db_cleanup_thread_mock = mock.MagicMock()
cert_rotate_thread_mock = mock.MagicMock()
mock_thread.side_effect = [spare_amp_thread_mock,
db_cleanup_thread_mock]
db_cleanup_thread_mock,
cert_rotate_thread_mock]
spare_amp_thread_mock.daemon.return_value = True
db_cleanup_thread_mock.daemon.return_value = True
cert_rotate_thread_mock.daemon.return_value = True
# mock the time.sleep() in the while loop
sleep_time.side_effect = [True, KeyboardInterrupt]
@ -123,11 +187,17 @@ class TestHouseKeepingCMD(base.TestCase):
db_cleanup_thread_event_mock.set.assert_called_once_with()
db_cleanup_thread_event_mock.clear.assert_called_once_with()
cert_rotate_thread_event_mock.set.assert_called_once_with()
cert_rotate_thread_event_mock.clear.assert_called_once_with()
spare_amp_thread_mock.start.assert_called_once_with()
db_cleanup_thread_mock.start.assert_called_once_with()
cert_rotate_thread_mock.start.assert_called_once_with()
self.assertTrue(spare_amp_thread_mock.daemon)
self.assertTrue(db_cleanup_thread_mock.daemon)
self.assertTrue(cert_rotate_thread_mock.daemon)
spare_amp_thread_mock.join.assert_called_once_with()
db_cleanup_thread_mock.join.assert_called_once_with()
cert_rotate_thread_mock.join.assert_called_once_with()

View File

@ -12,6 +12,7 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
import mock
import six
@ -316,3 +317,11 @@ class TestTLSParseUtils(base.TestCase):
cp.return_value = {'cn': 'fakeCN'}
cn = cert_parser.get_primary_cn(cert)
self.assertEqual('fakeCN', cn)
def test_get_cert_expiration(self):
exp_date = cert_parser.get_cert_expiration(ALT_EXT_CRT)
self.assertEqual(datetime.datetime(2025, 5, 18, 20, 33, 23), exp_date)
# test the exception
self.assertRaises(exceptions.UnreadableCert,
cert_parser.get_cert_expiration, 'bad-cert-file')

View File

@ -24,6 +24,16 @@ import octavia.tests.unit.base as base
CONF = cfg.CONF
CONF.import_group('house_keeping', 'octavia.common.config')
AMPHORA_ID = uuidutils.generate_uuid()
class TestException(Exception):
def __init__(self, value):
self.value = value
def __str__(self):
return repr(self.value)
class TestSpareCheck(base.TestCase):
@ -122,3 +132,63 @@ class TestDatabaseCleanup(base.TestCase):
self.assertTrue(self.amp_repo.get_all.called)
self.assertTrue(self.amp_health_repo.check_amphora_expired.called)
self.assertFalse(self.amp_repo.delete.called)
class TestCertRotation(base.TestCase):
def setUp(self):
super(TestCertRotation, self).setUp()
@mock.patch('octavia.controller.worker.controller_worker.'
'ControllerWorker.amphora_cert_rotation')
@mock.patch('octavia.db.repositories.AmphoraRepository.'
'get_cert_expiring_amphora')
@mock.patch('octavia.db.api.get_session')
def test_cert_rotation_expired_amphora_with_exception(self, session,
cert_exp_amp_mock,
amp_cert_mock
):
amphora = mock.MagicMock()
amphora.id = AMPHORA_ID
session.return_value = session
cert_exp_amp_mock.side_effect = [amphora, TestException(
'break_while')]
cr = house_keeping.CertRotation()
self.assertRaises(TestException, cr.rotate)
amp_cert_mock.assert_called_once_with(AMPHORA_ID)
@mock.patch('octavia.controller.worker.controller_worker.'
'ControllerWorker.amphora_cert_rotation')
@mock.patch('octavia.db.repositories.AmphoraRepository.'
'get_cert_expiring_amphora')
@mock.patch('octavia.db.api.get_session')
def test_cert_rotation_expired_amphora_without_exception(self, session,
cert_exp_amp_mock,
amp_cert_mock
):
amphora = mock.MagicMock()
amphora.id = AMPHORA_ID
session.return_value = session
cert_exp_amp_mock.side_effect = [amphora, None]
cr = house_keeping.CertRotation()
self.assertEqual(None, cr.rotate())
amp_cert_mock.assert_called_once_with(AMPHORA_ID)
@mock.patch('octavia.controller.worker.controller_worker.'
'ControllerWorker.amphora_cert_rotation')
@mock.patch('octavia.db.repositories.AmphoraRepository.'
'get_cert_expiring_amphora')
@mock.patch('octavia.db.api.get_session')
def test_cert_rotation_non_expired_amphora(self, session,
cert_exp_amp_mock,
amp_cert_mock):
session.return_value = session
cert_exp_amp_mock.return_value = None
cr = house_keeping.CertRotation()
cr.rotate()
self.assertFalse(amp_cert_mock.called)

View File

@ -133,3 +133,17 @@ class TestAmphoraFlows(base.TestCase):
self.assertEqual(2, len(amp_flow.requires))
self.assertEqual(12, len(amp_flow.provides))
def test_cert_rotate_amphora_flow(self):
cfg.CONF.set_override('amphora_driver', 'amphora_haproxy_rest_driver',
group='controller_worker')
self.AmpFlow = amphora_flows.AmphoraFlows()
amp_rotate_flow = self.AmpFlow.cert_rotate_amphora_flow()
self.assertIsInstance(amp_rotate_flow, flow.Flow)
self.assertIn(constants.SERVER_PEM, amp_rotate_flow.provides)
self.assertIn(constants.AMPHORA, amp_rotate_flow.requires)
self.assertEqual(1, len(amp_rotate_flow.provides))
self.assertEqual(2, len(amp_rotate_flow.requires))

View File

@ -270,3 +270,17 @@ class TestAmphoraDriverTasks(base.TestCase):
status=constants.ERROR)
self.assertIsNone(amp)
def test_amphora_cert_upload(self,
mock_driver,
mock_generate_uuid,
mock_log,
mock_get_session,
mock_listener_repo_update,
mock_amphora_repo_update):
pem_file_mock = 'test-perm-file'
amphora_cert_upload_mock = amphora_driver_tasks.AmphoraCertUpload()
amphora_cert_upload_mock.execute(_amphora_mock, pem_file_mock)
mock_driver.upload_cert_amp.assert_called_once_with(
_amphora_mock, pem_file_mock)

View File

@ -44,6 +44,43 @@ _pool_mock.id = POOL_ID
_listener_mock = mock.MagicMock()
_listener_mock.id = LISTENER_ID
_tf_failure_mock = mock.Mock(spec=failure.Failure)
_cert_mock = mock.MagicMock()
_pem_mock = """Junk
-----BEGIN CERTIFICATE-----
MIIBhDCCAS6gAwIBAgIGAUo7hO/eMA0GCSqGSIb3DQEBCwUAMA8xDTALBgNVBAMT
BElNRDIwHhcNMTQxMjExMjI0MjU1WhcNMjUxMTIzMjI0MjU1WjAPMQ0wCwYDVQQD
EwRJTUQzMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAKHIPXo2pfD5dpnpVDVz4n43
zn3VYsjz/mgOZU0WIWjPA97mvulb7mwb4/LB4ijOMzHj9XfwP75GiOFxYFs8O80C
AwEAAaNwMG4wDwYDVR0TAQH/BAUwAwEB/zA8BgNVHSMENTAzgBS6rfnABCO3oHEz
NUUtov2hfXzfVaETpBEwDzENMAsGA1UEAxMESU1EMYIGAUo7hO/DMB0GA1UdDgQW
BBRiLW10LVJiFO/JOLsQFev0ToAcpzANBgkqhkiG9w0BAQsFAANBABtdF+89WuDi
TC0FqCocb7PWdTucaItD9Zn55G8KMd93eXrOE/FQDf1ScC+7j0jIHXjhnyu6k3NV
8el/x5gUHlc=
-----END CERTIFICATE-----
Junk should be ignored by x509 splitter
-----BEGIN CERTIFICATE-----
MIIBhDCCAS6gAwIBAgIGAUo7hO/DMA0GCSqGSIb3DQEBCwUAMA8xDTALBgNVBAMT
BElNRDEwHhcNMTQxMjExMjI0MjU1WhcNMjUxMTIzMjI0MjU1WjAPMQ0wCwYDVQQD
EwRJTUQyMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAJYHqnsisVKTlwVaCSa2wdrv
CeJJzqpEVV0RVgAAF6FXjX2Tioii+HkXMR9zFgpE1w4yD7iu9JDb8yTdNh+NxysC
AwEAAaNwMG4wDwYDVR0TAQH/BAUwAwEB/zA8BgNVHSMENTAzgBQt3KvN8ncGj4/s
if1+wdvIMCoiE6ETpBEwDzENMAsGA1UEAxMEcm9vdIIGAUo7hO+mMB0GA1UdDgQW
BBS6rfnABCO3oHEzNUUtov2hfXzfVTANBgkqhkiG9w0BAQsFAANBAIlJODvtmpok
eoRPOb81MFwPTTGaIqafebVWfBlR0lmW8IwLhsOUdsQqSzoeypS3SJUBpYT1Uu2v
zEDOmgdMsBY=
-----END CERTIFICATE-----
Junk should be thrown out like junk
-----BEGIN CERTIFICATE-----
MIIBfzCCASmgAwIBAgIGAUo7hO+mMA0GCSqGSIb3DQEBCwUAMA8xDTALBgNVBAMT
BHJvb3QwHhcNMTQxMjExMjI0MjU1WhcNMjUxMTIzMjI0MjU1WjAPMQ0wCwYDVQQD
EwRJTUQxMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAI+tSJxr60ogwXFmgqbLMW7K
3fkQnh9sZBi7Qo6AzUnfe/AhXoisib651fOxKXCbp57IgzLTv7O9ygq3I+5fQqsC
AwEAAaNrMGkwDwYDVR0TAQH/BAUwAwEB/zA3BgNVHSMEMDAugBR73ZKSpjbsz9tZ
URkvFwpIO7gB4KETpBEwDzENMAsGA1UEAxMEcm9vdIIBATAdBgNVHQ4EFgQULdyr
zfJ3Bo+P7In9fsHbyDAqIhMwDQYJKoZIhvcNAQELBQADQQBenkZ2k7RgZqgj+dxA
D7BF8MN1oUAOpyYqAjkGddSEuMyNmwtHKZI1dyQ0gBIQdiU9yAG2oTbUIK4msbBV
uJIQ
-----END CERTIFICATE-----"""
@mock.patch('octavia.db.repositories.AmphoraRepository.delete')
@ -92,7 +129,8 @@ class TestDatabaseTasks(base.TestCase):
repo.AmphoraRepository.create.assert_called_once_with(
'TEST',
id=AMP_ID,
status=constants.PENDING_CREATE)
status=constants.PENDING_CREATE,
cert_busy=False)
assert(amp_id == _amphora_mock.id)
@ -616,6 +654,41 @@ class TestDatabaseTasks(base.TestCase):
LB_ID,
provisioning_status=constants.ERROR)
@mock.patch('octavia.common.tls_utils.cert_parser.get_cert_expiration',
return_value=_cert_mock)
def test_update_amphora_db_cert_exp(self,
mock_generate_uuid,
mock_LOG,
mock_get_session,
mock_loadbalancer_repo_update,
mock_listener_repo_update,
mock_amphora_repo_update,
mock_amphora_repo_delete,
mock_get_cert_exp):
update_amp_cert = database_tasks.UpdateAmphoraDBCertExpiration()
update_amp_cert.execute(_amphora_mock.id, _pem_mock)
repo.AmphoraRepository.update.assert_called_once_with(
'TEST',
AMP_ID,
cert_expiration=_cert_mock)
def test_update_amphora_cert_busy_to_false(self,
mock_generate_uuid,
mock_LOG,
mock_get_session,
mock_loadbalancer_repo_update,
mock_listener_repo_update,
mock_amphora_repo_update,
mock_amphora_repo_delete):
amp_cert_busy_to_F = database_tasks.UpdateAmphoraCertBusyToFalse()
amp_cert_busy_to_F.execute(_amphora_mock)
repo.AmphoraRepository.update.assert_called_once_with(
'TEST',
AMP_ID,
cert_busy=False)
def test_mark_LB_active_in_db(self,
mock_generate_uuid,
mock_LOG,

View File

@ -44,6 +44,7 @@ _member_mock = mock.MagicMock()
_pool_mock = mock.MagicMock()
_create_map_flow_mock = mock.MagicMock()
_amphora_mock.load_balancer_id = LB_ID
_amphora_mock.id = AMP_ID
@mock.patch('octavia.db.repositories.AmphoraRepository.get',
@ -681,3 +682,27 @@ class TestControllerWorker(base.TestCase):
_amphora_mock.load_balancer_id}))
_flow_mock.run.assert_called_once_with()
@mock.patch('octavia.controller.worker.flows.'
'amphora_flows.AmphoraFlows.cert_rotate_amphora_flow',
return_value=_flow_mock)
def test_amphora_cert_rotation(self,
mock_get_update_listener_flow,
mock_api_get_session,
mock_dyn_log_listener,
mock_taskflow_load,
mock_pool_repo_get,
mock_member_repo_get,
mock_listener_repo_get,
mock_lb_repo_get,
mock_health_mon_repo_get,
mock_amp_repo_get):
_flow_mock.reset_mock()
cw = controller_worker.ControllerWorker()
cw.amphora_cert_rotation(AMP_ID)
(base_taskflow.BaseTaskFlowEngine._taskflow_load.
assert_called_once_with(_flow_mock,
store={constants.AMPHORA: _amphora_mock,
constants.AMPHORA_ID:
_amphora_mock.id}))
_flow_mock.run.assert_called_once_with()