From 0abcbc4f7d412de7b711352e81e61daada56c741 Mon Sep 17 00:00:00 2001 From: German Eichberger Date: Fri, 27 Feb 2015 14:45:54 -0800 Subject: [PATCH] haproxy reference amphora REST API client Adds rest driver methods Adds rest driver tests Add cert task for generating server certs Modified compute task/flow Fixed local certificate stuff Refactored to use requests-mock inetad of responses Added a "conditiobal flow" for REST Cleaned up and changed the code to work with https://review.openstack.org/#/c/160034/ Replaces: https://review.openstack.org/#/c/144348/ https://review.openstack.org/#/c/145637/14 Change-Id: Ibcbf0717b785aab4c604deef1061e8b2fa41006c Co-Authored-By: Phillip Toohill Co-Authored-By: German Eichberger Co-Authored-By: Stephen Balukoff Implements: bp/haproxy-amphora-driver --- devstack/plugin.sh | 1 + .../install.d/75-run_setup_install | 1 + etc/octavia.conf | 14 +- .../backends/agent/api_server/listener.py | 2 +- .../amphorae/drivers/haproxy/data_models.py | 110 ++++ .../amphorae/drivers/haproxy/exceptions.py | 72 +++ .../drivers/haproxy/rest_api_driver.py | 299 +++++++++ octavia/certificates/generator/local.py | 27 +- octavia/common/config.py | 20 +- octavia/common/constants.py | 1 + octavia/common/data_models.py | 4 + .../controller/worker/flows/amphora_flows.py | 32 +- octavia/controller/worker/tasks/cert_task.py | 52 ++ .../controller/worker/tasks/compute_tasks.py | 26 +- .../backend/agent/api_server/test_server.py | 8 +- .../drivers/haproxy/test_rest_api_driver.py | 601 ++++++++++++++++++ .../worker/flows/test_amphora_flows.py | 38 ++ .../controller/worker/tasks/test_cert_task.py | 33 + .../worker/tasks/test_compute_tasks.py | 60 +- setup.cfg | 9 +- test-requirements.txt | 3 +- 21 files changed, 1380 insertions(+), 33 deletions(-) create mode 100644 octavia/amphorae/drivers/haproxy/data_models.py create mode 100644 octavia/amphorae/drivers/haproxy/exceptions.py create mode 100644 octavia/amphorae/drivers/haproxy/rest_api_driver.py create mode 100644 octavia/controller/worker/tasks/cert_task.py create mode 100644 octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver.py create mode 100644 octavia/tests/unit/controller/worker/tasks/test_cert_task.py diff --git a/devstack/plugin.sh b/devstack/plugin.sh index 54ee64bb9e..d2c24d73ce 100644 --- a/devstack/plugin.sh +++ b/devstack/plugin.sh @@ -61,6 +61,7 @@ function octavia_configure { # Setting other required default options iniset $OCTAVIA_CONF controller_worker amphora_driver amphora_haproxy_ssh_driver + #iniset $OCTAVIA_CONF controller_worker amphora_driver amphora_haproxy_rest_driver iniset $OCTAVIA_CONF controller_worker compute_driver compute_nova_driver iniset $OCTAVIA_CONF controller_worker network_driver allowed_address_pairs_driver diff --git a/elements/amphora-agent/install.d/75-run_setup_install b/elements/amphora-agent/install.d/75-run_setup_install index df76810083..55f46d3c5c 100755 --- a/elements/amphora-agent/install.d/75-run_setup_install +++ b/elements/amphora-agent/install.d/75-run_setup_install @@ -3,6 +3,7 @@ set -eux install-packages libffi-dev libssl-dev cd /opt/amphora-agent/ +pip install -r requirements.txt python setup.py install cp etc/init/octavia-agent.conf /etc/init/ mkdir /etc/octavia diff --git a/etc/octavia.conf b/etc/octavia.conf index 7fb730e1fb..cfc3dc9c30 100644 --- a/etc/octavia.conf +++ b/etc/octavia.conf @@ -53,7 +53,6 @@ # Network to communicate with amphora # lb_network_name = - [haproxy_amphora] # base_path = /var/lib/octavia # base_cert_dir = /var/lib/octavia/certs @@ -62,8 +61,12 @@ # base_log_dir = /logs # connection_max_retries = 10 # connection_retry_interval = 5 +# Cert manager options are local_cert_manager, +# barbican_cert_manager, +# +# cert_manager = barbican_cert_manager -#SSH Driver specific +# SSH Driver specific # username = ubuntu # key_path = /opt/stack/.ssh/id_rsa @@ -86,6 +89,7 @@ # amp_ssh_key_name = # amp_network = # amp_secgroup_list = +# client_ca = /etc/octavia/certs/ca_01.pem # Amphora driver options are amphora_noop_driver, # amphora_haproxy_rest_driver, @@ -102,6 +106,10 @@ # allowed_address_pairs_driver # # network_driver = network_noop_driver +# +# Certificate Generator options are local_cert_generator +# barbican_cert_generator +# cert_generator = local_cert_generator [task_flow] # engine = serial @@ -121,4 +129,4 @@ # rpc_thread_pool_size = 2 # Topic (i.e. Queue) Name -# topic = octavia_prov \ No newline at end of file +# topic = octavia_prov diff --git a/octavia/amphorae/backends/agent/api_server/listener.py b/octavia/amphorae/backends/agent/api_server/listener.py index d1d0b74d92..bd6d6fe141 100644 --- a/octavia/amphorae/backends/agent/api_server/listener.py +++ b/octavia/amphorae/backends/agent/api_server/listener.py @@ -361,7 +361,7 @@ def _check_ssl_filename_format(filename): def _cert_dir(listener_id): - return os.path.join(util.CONF.haproxy_amphora.haproxy_cert_dir, + return os.path.join(util.CONF.haproxy_amphora.base_cert_dir, listener_id) diff --git a/octavia/amphorae/drivers/haproxy/data_models.py b/octavia/amphorae/drivers/haproxy/data_models.py new file mode 100644 index 0000000000..3dd4d27579 --- /dev/null +++ b/octavia/amphorae/drivers/haproxy/data_models.py @@ -0,0 +1,110 @@ +# Copyright 2014 Rackspace +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import octavia.common.data_models as models + + +class Topology(models.BaseDataModel): + + def __init__(self, hostname=None, uuid=None, topology=None, role=None, + ip=None, ha_ip=None): + self.hostname = hostname + self.uuid = uuid + self.topology = topology + self.role = role + self.ip = ip + self.ha_ip = ha_ip + + +class Info(models.BaseDataModel): + + def __init__(self, hostname=None, uuid=None, version=None, + api_version=None): + self.hostname = hostname + self.uuid = uuid + self.version = version + self.api_version = api_version + + +class Details(models.BaseDataModel): + + def __init__(self, hostname=None, uuid=None, version=None, + api_version=None, network_tx=None, network_rx=None, + active=None, haproxy_count=None, cpu=None, memory=None, + disk=None, load=None, listeners=None, packages=None): + self.hostname = hostname + self.uuid = uuid, + self.version = version + self.api_version = api_version + self.network_tx = network_tx + self.network_rx = network_rx + self.active = active + self.haproxy_count = haproxy_count + self.cpu = cpu + self.memory = memory + self.disk = disk + self.load = load or [] + self.listeners = listeners or [] + self.packages = packages or [] + + +class CPU(models.BaseDataModel): + + def __init__(self, total=None, user=None, system=None, soft_irq=None): + self.total = total + self.user = user + self.system = system + self.soft_irq = soft_irq + + +class Memory(models.BaseDataModel): + + def __init__(self, total=None, free=None, available=None, buffers=None, + cached=None, swap_used=None, shared=None, slab=None, + committed_as=None): + self.total = total + self.free = free + self.available = available + self.buffers = buffers + self.cached = cached + self.swap_used = swap_used + self.shared = shared + self.slab = slab + self.committed_as = committed_as + + +class Disk(models.BaseDataModel): + + def __init__(self, used=None, available=None): + self.used = used + self.available = available + + +class ListenerStatus(models.BaseDataModel): + + def __init__(self, status=None, uuid=None, provisioning_status=None, + type=None, pools=None): + self.status = status + self.uuid = uuid + self.provisioning_status = provisioning_status + self.type = type + self.pools = pools or [] + + +class Pool(models.BaseDataModel): + + def __init__(self, uuid=None, status=None, members=None): + self.uuid = uuid + self.status = status + self.members = members or [] diff --git a/octavia/amphorae/drivers/haproxy/exceptions.py b/octavia/amphorae/drivers/haproxy/exceptions.py new file mode 100644 index 0000000000..774c816d70 --- /dev/null +++ b/octavia/amphorae/drivers/haproxy/exceptions.py @@ -0,0 +1,72 @@ +# Copyright 2014 Rackspace +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from webob import exc + + +def check_exception(response): + status_code = response.status_code + responses = { + 401: Unauthorized, + 403: InvalidRequest, + 404: NotFound, + 405: InvalidRequest, + 409: Conflict, + 500: InternalServerError, + 503: ServiceUnavailable + } + if status_code in responses: + raise responses[status_code]() + + return response + + +class APIException(exc.HTTPClientError): + msg = "Something unknown went wrong" + code = 500 + + def __init__(self, **kwargs): + self.msg = self.msg % kwargs + super(APIException, self).__init__(detail=self.msg) + + +class Unauthorized(APIException): + msg = "Unauthorized" + code = 401 + + +class InvalidRequest(APIException): + msg = "Invalid request" + code = 403 + + +class NotFound(APIException): + msg = "Not Found" + code = 404 + + +class Conflict(APIException): + msg = "Conflict" + code = 409 + + +class InternalServerError(APIException): + msg = "Internal Server Error" + code = 500 + + +class ServiceUnavailable(APIException): + msg = "Service Unavailable" + code = 503 \ No newline at end of file diff --git a/octavia/amphorae/drivers/haproxy/rest_api_driver.py b/octavia/amphorae/drivers/haproxy/rest_api_driver.py new file mode 100644 index 0000000000..13878e534c --- /dev/null +++ b/octavia/amphorae/drivers/haproxy/rest_api_driver.py @@ -0,0 +1,299 @@ +# Copyright 2015 Hewlett-Packard Development Company, L.P. +# Copyright (c) 2015 Rackspace +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import hashlib +import time + +from oslo_log import log as logging +import requests +from stevedore import driver as stevedore_driver + +from octavia.amphorae.drivers import driver_base as driver_base +from octavia.amphorae.drivers.haproxy import exceptions as exc +from octavia.amphorae.drivers.haproxy.jinja import jinja_cfg +from octavia.common.config import cfg +from octavia.common import data_models as data_models +from octavia.common.tls_utils import cert_parser +from octavia.i18n import _LW + +LOG = logging.getLogger(__name__) +API_VERSION = '0.5' +OCTAVIA_API_CLIENT = ( + "Octavia HaProxy Rest Client/{version} " + "(https://wiki.openstack.org/wiki/Octavia)").format(version=API_VERSION) +CONF = cfg.CONF +CONF.import_group('haproxy_amphora', 'octavia.common.config') + + +class HaproxyAmphoraLoadBalancerDriver(driver_base.AmphoraLoadBalancerDriver): + def __init__(self): + super(HaproxyAmphoraLoadBalancerDriver, self).__init__() + self.client = AmphoraAPIClient() + self.cert_manager = stevedore_driver.DriverManager( + namespace='octavia.cert_manager', + name=CONF.haproxy_amphora.cert_manager, + invoke_on_load=True, + ).driver + self.jinja = jinja_cfg.JinjaTemplater( + base_amp_path=CONF.haproxy_amphora.base_path, + base_crt_dir=CONF.haproxy_amphora.base_cert_dir, + haproxy_template=CONF.haproxy_amphora.haproxy_template) + + def update(self, listener, vip): + LOG.debug("Amphora %s haproxy, updating listener %s, vip %s", + self.__class__.__name__, listener.protocol_port, + vip.ip_address) + + # Process listener certificate info + certs = self._process_tls_certificates(listener) + + # Generate HaProxy configuration from listener object + config = self.jinja.build_config(listener, certs['tls_cert'], + certs['sni_certs']) + + for amp in listener.load_balancer.amphorae: + self.client.upload_config(amp, listener.id, config) + # todo (german): add a method to REST interface to reload or start + # without having to check + # Is that listener running? + r = self.client.get_listener_status(amp, + listener.id) + if r['status'] == 'ACTIVE': + self.client.reload_listener(amp, listener.id) + else: + self.client.start_listener(amp, listener.id) + + def _apply(self, func, listener=None, *args): + for amp in listener.load_balancer.amphorae: + func(amp, listener.id, *args) + + def stop(self, listener, vip): + self._apply(self.client.stop_listener, listener) + + def start(self, listener, vip): + self._apply(self.client.start_listener, listener) + + def delete(self, listener, vip): + self._apply(self.client.delete_listener, listener) + + def get_info(self, amphora): + self.driver.get_info(amphora.lb_network_ip) + + def get_diagnostics(self, amphora): + self.driver.get_diagnostics(amphora.lb_network_ip) + + def finalize_amphora(self, amphora): + pass + + def post_vip_plug(self, load_balancer): + for amp in load_balancer.amphorae: + self.client.plug_vip(amp, load_balancer.vip.ip_address) + + def post_network_plug(self, amphora): + self.client.plug_network(amphora) + + def _process_tls_certificates(self, listener): + """Processes TLS data from the listener. + + Converts and uploads PEM data to the Amphora API + + return TLS_CERT and SNI_CERTS + """ + tls_cert = None + sni_certs = [] + + certs = [] + + if listener.tls_certificate_id: + tls_cert = self._map_cert_tls_container( + self.cert_manager.get_cert(listener.tls_certificate_id)) + certs.append(tls_cert) + if listener.sni_containers: + for sni_cont in listener.sni_containers: + bbq_container = self._map_cert_tls_container( + self.cert_manager.get_cert(sni_cont.tls_container.id)) + sni_certs.append(bbq_container) + certs.append(bbq_container) + + for cert in certs: + pem = self._build_pem(cert) + md5 = hashlib.md5(pem).hexdigest() + name = '{cn}.pem'.format(cn=cert.primary_cn) + self._apply(self._upload_cert, listener, pem, md5, name) + + return {'tls_cert': tls_cert, 'sni_certs': sni_certs} + + def _upload_cert(self, amp, listener_id, pem, md5, name): + try: + if self.client.get_cert_md5sum(amp, listener_id, name) == md5: + return + except exc.NotFound: + pass + + self.client.upload_cert_pem( + amp, listener_id, name, pem) + + def _get_primary_cn(self, tls_cert): + """Returns primary CN for Certificate.""" + return cert_parser.get_host_names(tls_cert.get_certificate())['cn'] + + def _map_cert_tls_container(self, cert): + return data_models.TLSContainer( + primary_cn=self._get_primary_cn(cert), + private_key=cert.get_private_key(), + certificate=cert.get_certificate(), + intermediates=cert.get_intermediates()) + + def _build_pem(self, tls_cert): + """Concatenate TLS Certificate fields to create a PEM + + encoded certificate file + """ + # TODO(ptoohill): Maybe this should be part of utils or manager? + pem = tls_cert.intermediates[:] + pem.extend([tls_cert.certificate, tls_cert.private_key]) + + return "\n".join(pem) + + +# Check a custom hostname +class CustomHostNameCheckingAdapter(requests.adapters.HTTPAdapter): + def cert_verify(self, conn, url, verify, cert): + conn.assert_hostname = self.uuid + return super(CustomHostNameCheckingAdapter, + self).cert_verify(conn, url, verify, cert) + + +class AmphoraAPIClient(object): + def __init__(self): + super(AmphoraAPIClient, self).__init__() + self.secure = False + + self.get = functools.partial(self.request, 'get') + self.post = functools.partial(self.request, 'post') + self.put = functools.partial(self.request, 'put') + self.delete = functools.partial(self.request, 'delete') + self.head = functools.partial(self.request, 'head') + + self.start_listener = functools.partial(self._action, 'start') + self.stop_listener = functools.partial(self._action, 'stop') + self.reload_listener = functools.partial(self._action, 'reload') + + self.session = requests.Session() + self.session.cert = CONF.haproxy_amphora.client_cert + self.ssl_adapter = CustomHostNameCheckingAdapter() + self.session.mount('https://', self.ssl_adapter) + + def _base_url(self, ip): + return "https://{ip}:{port}/{version}/".format( + ip=ip, + port=CONF.haproxy_amphora.bind_port, + version=API_VERSION) + + def request(self, method, amp, path='/', **kwargs): + LOG.debug("request url " + path) + _request = getattr(self.session, method.lower()) + _url = self._base_url(amp.lb_network_ip) + path + + reqargs = { + 'verify': CONF.haproxy_amphora.server_ca, + 'url': _url, } + reqargs.update(kwargs) + headers = reqargs.setdefault('headers', {}) + + headers['User-Agent'] = OCTAVIA_API_CLIENT + self.ssl_adapter.uuid = amp.id + # Keep retrying + for attempts in xrange(CONF.haproxy_amphora.connection_max_retries): + try: + r = _request(**reqargs) + except requests.ConnectionError: + LOG.warn(_LW("Could not talk to instance")) + time.sleep(CONF.haproxy_amphora.connection_retry_interval) + if attempts >= CONF.haproxy_amphora.connection_max_retries: + raise exc.TimeOutException() + else: + return r + raise exc.UnavailableException() + + def upload_config(self, amp, listener_id, config): + r = self.put( + amp, + 'listeners/{listener_id}/haproxy'.format(listener_id=listener_id), + data=config) + return exc.check_exception(r) + + def get_listener_status(self, amp, listener_id): + r = self.get( + amp, + 'listeners/{listener_id}'.format(listener_id=listener_id)) + if exc.check_exception(r): + return r.json() + + def _action(self, action, amp, listener_id): + r = self.put(amp, 'listeners/{listener_id}/{action}'.format( + listener_id=listener_id, action=action)) + return exc.check_exception(r) + + def upload_cert_pem(self, amp, listener_id, pem_filename, pem_file): + r = self.put( + amp, + 'listeners/{listener_id}/certificates/{filename}'.format( + listener_id=listener_id, filename=pem_filename), + data=pem_file) + return exc.check_exception(r) + + def get_cert_md5sum(self, amp, listener_id, pem_filename): + r = self.get(amp, + 'listeners/{listener_id}/certificates/{filename}'.format( + listener_id=listener_id, filename=pem_filename)) + if exc.check_exception(r): + return r.json().get("md5sum") + + def delete_listener(self, amp, listener_id): + r = self.delete( + amp, 'listeners/{listener_id}'.format(listener_id=listener_id)) + return exc.check_exception(r) + + def get_info(self, amp): + r = self.get(amp, "info") + if exc.check_exception(r): + return r.json() + + def get_details(self, amp): + r = self.get(amp, "details") + if exc.check_exception(r): + return r.json() + + def get_all_listeners(self, amp): + r = self.get(amp, "listeners") + if exc.check_exception(r): + return r.json() + + def delete_cert_pem(self, amp, listener_id, pem_filename): + r = self.delete( + amp, + 'listeners/{listener_id}/certificates/{filename}'.format( + listener_id=listener_id, filename=pem_filename)) + return exc.check_exception(r) + + def plug_network(self, amp): + r = self.post(amp, 'plug/network') + return exc.check_exception(r) + + def plug_vip(self, amp, vip): + r = self.post(amp, 'plug/vip/{vip}'.format(vip=vip)) + return exc.check_exception(r) \ No newline at end of file diff --git a/octavia/certificates/generator/local.py b/octavia/certificates/generator/local.py index 931a7dac5e..2e8662dd8a 100644 --- a/octavia/certificates/generator/local.py +++ b/octavia/certificates/generator/local.py @@ -93,10 +93,19 @@ class LocalCertGenerator(cert_gen.CertGenerator): cls._validate_cert(ca_cert, ca_key, ca_key_pass) if not ca_digest: ca_digest = CONF.certificates.signing_digest + if not ca_cert: + with open(CONF.certificates.ca_certificate, 'r') as f: + ca_cert = f.read() + if not ca_key: + with open(CONF.certificates.ca_private_key, 'r') as f: + ca_key = f.read() + if not ca_key_pass: + ca_key_pass = CONF.certificates.ca_private_key_passphrase + try: lo_cert = crypto.load_certificate(crypto.FILETYPE_PEM, ca_cert) lo_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key, - passphrase=ca_key_pass) + ca_key_pass) lo_req = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr) new_cert = crypto.X509() @@ -147,11 +156,17 @@ class LocalCertGenerator(cert_gen.CertGenerator): @classmethod def _generate_csr(cls, cn, private_key, passphrase=None): - pk = crypto.load_privatekey( - crypto.FILETYPE_PEM, - private_key, - passphrase - ) + if passphrase: + pk = crypto.load_privatekey( + crypto.FILETYPE_PEM, + private_key, + passphrase + ) + else: + pk = crypto.load_privatekey( + crypto.FILETYPE_PEM, + private_key + ) csr = crypto.X509Req() csr.set_pubkey(pk) subject = csr.get_subject() diff --git a/octavia/common/config.py b/octavia/common/config.py index ebab4ac65f..6228b2697b 100644 --- a/octavia/common/config.py +++ b/octavia/common/config.py @@ -113,6 +113,9 @@ haproxy_amphora_opts = [ cfg.IntOpt('connection_retry_interval', default=5, help=_('Retry timeout between attempts in seconds.')), + cfg.StrOpt('cert_manager', + default='barbican_cert_manager', + help=_('Name of the cert manager to use')), # REST server cfg.StrOpt('bind_host', default='0.0.0.0', @@ -135,7 +138,12 @@ haproxy_amphora_opts = [ cfg.StrOpt('agent_server_network_dir', default='/etc/network/interfaces.d/', help=_("The directory where new network interfaces " - "are located")) + "are located")), + # REST client + cfg.StrOpt('client_cert', default='/etc/octavia/certs/client.pem', + help=_("The client certificate to talk to the agent")), + cfg.StrOpt('server_ca', default='/etc/octavia/certs/server_ca.pem', + help=_("The ca which signed the server certificates")), ] controller_worker_opts = [ @@ -160,6 +168,9 @@ controller_worker_opts = [ cfg.ListOpt('amp_secgroup_list', default='', help=_('List of security groups to attach to the Amphora')), + cfg.StrOpt('client_ca', + default='/etc/octavia/certs/ca_01.pem', + help=_('Client CA for the amphora agent to use')), cfg.StrOpt('amphora_driver', default='amphora_noop_driver', help=_('Name of the amphora driver to use')), @@ -168,7 +179,11 @@ controller_worker_opts = [ help=_('Name of the compute driver to use')), cfg.StrOpt('network_driver', default='network_noop_driver', - help=_('Name of the network driver to use')) + help=_('Name of the network driver to use')), + cfg.StrOpt('cert_generator', + default='local_cert_generator', + help=_('Name of the cert generator to use')) + ] task_flow_opts = [ @@ -194,6 +209,7 @@ cfg.CONF.register_cli_opts(core_cli_opts) cfg.CONF.import_group('keystone_authtoken', 'keystonemiddleware.auth_token') cfg.CONF.register_opts(keystone_authtoken_v3_opts, group='keystone_authtoken_v3') + # Ensure that the control exchange is set correctly messaging.set_transport_defaults(control_exchange='octavia') _SQL_CONNECTION_DEFAULT = 'sqlite://' diff --git a/octavia/common/constants.py b/octavia/common/constants.py index 680732ff18..b402f2b96d 100644 --- a/octavia/common/constants.py +++ b/octavia/common/constants.py @@ -94,6 +94,7 @@ VIP = 'vip' POOL = 'pool' POOL_ID = 'pool_id' OBJECT = 'object' +SERVER_PEM = 'server_pem' CREATE_AMPHORA_FLOW = 'octavia-create-amphora-flow' CREATE_AMPHORA_FOR_LB_FLOW = 'octavia-create-amp-for-lb-flow' diff --git a/octavia/common/data_models.py b/octavia/common/data_models.py index 9dab553852..b7bf9e3300 100644 --- a/octavia/common/data_models.py +++ b/octavia/common/data_models.py @@ -37,6 +37,10 @@ class BaseDataModel(object): return self.to_dict() == other.to_dict() return False + @classmethod + def from_dict(cls, dict): + return cls(**dict) + @classmethod def _name(cls): """Returns class name in a more human readable form.""" diff --git a/octavia/controller/worker/flows/amphora_flows.py b/octavia/controller/worker/flows/amphora_flows.py index ed4db91b2d..498df16a93 100644 --- a/octavia/controller/worker/flows/amphora_flows.py +++ b/octavia/controller/worker/flows/amphora_flows.py @@ -20,17 +20,21 @@ from taskflow import retry from octavia.common import constants from octavia.controller.worker.flows import load_balancer_flows from octavia.controller.worker.tasks import amphora_driver_tasks +from octavia.controller.worker.tasks import cert_task from octavia.controller.worker.tasks import compute_tasks from octavia.controller.worker.tasks import controller_tasks from octavia.controller.worker.tasks import database_tasks + CONF = cfg.CONF CONF.import_group('controller_worker', 'octavia.common.config') class AmphoraFlows(object): - def __init__(self): + # for some reason only this has the values from the config file + self.REST_AMPHORA_DRIVER = (CONF.controller_worker.amphora_driver == + 'amphora_haproxy_rest_driver') self._lb_flows = load_balancer_flows.LoadBalancerFlows() def get_create_amphora_flow(self): @@ -45,9 +49,16 @@ class AmphoraFlows(object): create_amphora_flow = linear_flow.Flow(constants.CREATE_AMPHORA_FLOW) create_amphora_flow.add(database_tasks.CreateAmphoraInDB( provides=constants.AMPHORA_ID)) - create_amphora_flow.add(compute_tasks.ComputeCreate( - requires=constants.AMPHORA_ID, - provides=constants.COMPUTE_ID)) + if self.REST_AMPHORA_DRIVER: + create_amphora_flow.add(cert_task.GenerateServerPEMTask( + provides=constants.SERVER_PEM)) + create_amphora_flow.add(compute_tasks.CertComputeCreate( + requires=(constants.AMPHORA_ID, constants.SERVER_PEM), + provides=constants.COMPUTE_ID)) + else: + create_amphora_flow.add(compute_tasks.ComputeCreate( + requires=constants.AMPHORA_ID, + provides=constants.COMPUTE_ID)) create_amphora_flow.add(database_tasks.MarkAmphoraBootingInDB( requires=(constants.AMPHORA_ID, constants.COMPUTE_ID))) wait_flow = linear_flow.Flow('wait_for_amphora', @@ -80,9 +91,16 @@ class AmphoraFlows(object): CREATE_AMPHORA_FOR_LB_FLOW) create_amp_for_lb_flow.add(database_tasks.CreateAmphoraInDB( provides=constants.AMPHORA_ID)) - create_amp_for_lb_flow.add(compute_tasks.ComputeCreate( - requires=constants.AMPHORA_ID, - provides=constants.COMPUTE_ID)) + if self.REST_AMPHORA_DRIVER: + create_amp_for_lb_flow.add(cert_task.GenerateServerPEMTask( + provides=constants.SERVER_PEM)) + create_amp_for_lb_flow.add(compute_tasks.CertComputeCreate( + requires=(constants.AMPHORA_ID, constants.SERVER_PEM), + provides=constants.COMPUTE_ID)) + else: + create_amp_for_lb_flow.add(compute_tasks.ComputeCreate( + requires=constants.AMPHORA_ID, + provides=constants.COMPUTE_ID)) create_amp_for_lb_flow.add(database_tasks.UpdateAmphoraComputeId( requires=(constants.AMPHORA_ID, constants.COMPUTE_ID))) create_amp_for_lb_flow.add(database_tasks.MarkAmphoraBootingInDB( diff --git a/octavia/controller/worker/tasks/cert_task.py b/octavia/controller/worker/tasks/cert_task.py new file mode 100644 index 0000000000..3592a79784 --- /dev/null +++ b/octavia/controller/worker/tasks/cert_task.py @@ -0,0 +1,52 @@ +# Copyright 2015 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# + + +import logging + +from oslo.config import cfg +from stevedore import driver as stevedore_driver +from taskflow import task + +CONF = cfg.CONF +CONF.import_group('controller_worker', 'octavia.common.config') +LOG = logging.getLogger(__name__) +CERT_VALIDITY = 365 + + +class BaseCertTask(task.Task): + """Base task to load drivers common to the tasks.""" + + def __init__(self, **kwargs): + super(BaseCertTask, self).__init__(**kwargs) + self.cert_generator = stevedore_driver.DriverManager( + namespace='octavia.cert_generator', + name=CONF.controller_worker.cert_generator, + invoke_on_load=True, + ).driver + + +class GenerateServerPEMTask(BaseCertTask): + """Create the server certs for the agent comm + + Use the amphora_id for the CN + """ + + def execute(self, amphora_id): + cert = self.cert_generator.generate_cert_key_pair( + cn=amphora_id, + validity=CERT_VALIDITY) + + return cert.certificate + cert.private_key \ No newline at end of file diff --git a/octavia/controller/worker/tasks/compute_tasks.py b/octavia/controller/worker/tasks/compute_tasks.py index 61629fe96d..61610cf059 100644 --- a/octavia/controller/worker/tasks/compute_tasks.py +++ b/octavia/controller/worker/tasks/compute_tasks.py @@ -46,12 +46,14 @@ class BaseComputeTask(task.Task): class ComputeCreate(BaseComputeTask): """Create the compute instance for a new amphora.""" - def execute(self, amphora_id): + def execute(self, amphora_id, config_drive_files=None): """Create an amphora :returns: an amphora """ - LOG.debug("Nova Create execute for amphora with id %s" % amphora_id) + + LOG.debug("Nova Create execute for amphora with id %s" + % amphora_id) try: compute_id = self.compute.build( @@ -60,7 +62,8 @@ class ComputeCreate(BaseComputeTask): image_id=CONF.controller_worker.amp_image_id, key_name=CONF.controller_worker.amp_ssh_key_name, sec_groups=CONF.controller_worker.amp_secgroup_list, - network_ids=[CONF.controller_worker.amp_network]) + network_ids=[CONF.controller_worker.amp_network], + config_drive_files=config_drive_files) LOG.debug("Server created with id: %s for amphora id: %s" % (compute_id, amphora_id)) @@ -90,6 +93,23 @@ class ComputeCreate(BaseComputeTask): " with exception %s"), e) +class CertComputeCreate(ComputeCreate): + def execute(self, amphora_id, server_pem): + """Create an amphora + + :returns: an amphora + """ + + # load client certificate + with open(CONF.controller_worker.client_ca, 'r') as client_ca: + config_drive_files = { + # '/etc/octavia/octavia.conf' + '/etc/octavia/certs/server.pem': server_pem, + '/etc/octavia/certs/client_ca.pem': client_ca} + return super(CertComputeCreate, self).execute(amphora_id, + config_drive_files) + + class DeleteAmphoraeOnLoadBalancer(BaseComputeTask): """Delete the amphorae on a load balancer. diff --git a/octavia/tests/functional/amphorae/backend/agent/api_server/test_server.py b/octavia/tests/functional/amphorae/backend/agent/api_server/test_server.py index 9551c6b7cf..75eff0ba9b 100644 --- a/octavia/tests/functional/amphorae/backend/agent/api_server/test_server.py +++ b/octavia/tests/functional/amphorae/backend/agent/api_server/test_server.py @@ -349,7 +349,8 @@ class ServerTestCase(base.TestCase): details='No certificate with filename: test.pem', message='Certificate Not Found'), json.loads(rv.data)) - mock_exists.assert_called_once_with('/tmp/123/test.pem') + mock_exists.assert_called_once_with( + '/var/lib/octavia/certs/123/test.pem') # wrong file name mock_exists.side_effect = [True] @@ -363,7 +364,8 @@ class ServerTestCase(base.TestCase): '/listeners/123/certificates/test.pem') self.assertEqual(200, rv.status_code) self.assertEqual(OK, json.loads(rv.data)) - mock_remove.assert_called_once_with('/tmp/123/test.pem') + mock_remove.assert_called_once_with( + '/var/lib/octavia/certs/123/test.pem') @mock.patch('os.path.exists') def test_get_certificate_md5(self, mock_exists): @@ -381,7 +383,7 @@ class ServerTestCase(base.TestCase): details='No certificate with filename: test.pem', message='Certificate Not Found'), json.loads(rv.data)) - mock_exists.assert_called_with('/tmp/123/test.pem') + mock_exists.assert_called_with('/var/lib/octavia/certs/123/test.pem') # wrong file name mock_exists.side_effect = [True] diff --git a/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver.py b/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver.py new file mode 100644 index 0000000000..4ff8f535fa --- /dev/null +++ b/octavia/tests/unit/amphorae/drivers/haproxy/test_rest_api_driver.py @@ -0,0 +1,601 @@ +# Copyright 2015 Hewlett-Packard Development Company, L.P. +# Copyright (c) 2015 Rackspace +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock +from oslo_utils import uuidutils +import requests_mock + +from octavia.amphorae.drivers.haproxy import exceptions as exc +from octavia.amphorae.drivers.haproxy import rest_api_driver as driver +from octavia.db import models +from octavia.tests.unit import base as base +from octavia.tests.unit.common.sample_configs import sample_configs + + +class HaproxyAmphoraLoadBalancerDriverTest(base.TestCase): + FAKE_UUID_1 = uuidutils.generate_uuid() + + def setUp(self): + super(HaproxyAmphoraLoadBalancerDriverTest, self).setUp() + self.driver = driver.HaproxyAmphoraLoadBalancerDriver() + + self.driver.cert_manager = mock.MagicMock() + self.driver.client = mock.MagicMock() + self.driver.jinja = mock.MagicMock() + + # Build sample Listener and VIP configs + self.sl = sample_configs.sample_listener_tuple(tls=True, sni=True) + self.amp = self.sl.load_balancer.amphorae[0] + self.sv = sample_configs.sample_vip_tuple() + self.lb = self.sl.load_balancer + + @mock.patch('octavia.common.tls_utils.cert_parser.get_host_names') + def test_update(self, mock_cert): + mock_cert.return_value = {'cn': 'fakeCN'} + self.driver.client.get_cert_md5sum.side_effect = [ + exc.NotFound, 'Fake_MD5', 'd41d8cd98f00b204e9800998ecf8427e'] + self.driver.jinja.build_config.side_effect = ['fake_config'] + self.driver.client.get_listener_status.side_effect = [ + dict(status='ACTIVE')] + + # Execute driver method + self.driver.update(self.sl, self.sv) + + # verify result + # this is called 3 times + self.driver.client.get_cert_md5sum.assert_called_with( + self.amp, self.sl.id, 'fakeCN.pem') + # this is called twice (last MD5 matches) + self.driver.client.upload_cert_pem.assert_called_with( + self.amp, self.sl.id, 'fakeCN.pem', '') + self.assertEqual(2, self.driver.client.upload_cert_pem.call_count) + # upload only one config file + self.driver.client.upload_config.assert_called_once_with( + self.amp, self.sl.id, 'fake_config') + # start should be called once + self.driver.client.reload_listener.assert_called_once_with( + self.amp, self.sl.id) + + # listener down + self.driver.client.get_cert_md5sum.side_effect = [ + 'd41d8cd98f00b204e9800998ecf8427e'] * 3 + self.driver.jinja.build_config.side_effect = ['fake_config'] + self.driver.client.get_listener_status.side_effect = [ + dict(status='BLAH')] + + self.driver.update(self.sl, self.sv) + + self.driver.client.start_listener.assert_called_once_with( + self.amp, self.sl.id) + + def test_stop(self): + # Execute driver method + self.driver.stop(self.sl, self.sv) + self.driver.client.stop_listener.assert_called_once_with( + self.amp, self.sl.id) + + def test_start(self): + # Execute driver method + self.driver.start(self.sl, self.sv) + self.driver.client.start_listener.assert_called_once_with( + self.amp, self.sl.id) + + def test_delete(self): + # Execute driver method + self.driver.delete(self.sl, self.sv) + self.driver.client.delete_listener.assert_called_once_with( + self.amp, self.sl.id) + + def test_get_info(self): + pass + + def test_get_diagnostics(self): + pass + + def test_finalize_amphora(self): + pass + + def test_post_vip_plug(self): + self.driver.post_vip_plug(self.lb) + self.driver.client.plug_vip.assert_called_once_with( + self.amp, self.lb.vip.ip_address) + + def test_post_network_plug(self): + self.driver.post_network_plug(self.amp) + self.driver.client.plug_network.assert_called_once_with(self.amp) + + +class AmphoraAPIClientTest(base.TestCase): + FAKE_UUID_1 = uuidutils.generate_uuid() + FAKE_PEM_FILENAME = "file_name" + + def setUp(self): + super(AmphoraAPIClientTest, self).setUp() + self.driver = driver.AmphoraAPIClient() + self.base_url = "https://127.0.0.1:8443/0.5" + self.amp = models.Amphora(lb_network_ip='127.0.0.1', compute_id='123') + + @requests_mock.mock() + def test_get_info(self, m): + info = {"hostname": "some_hostname", "version": "some_version", + "api_version": "0.5", "uuid": self.FAKE_UUID_1} + m.get("{base}/info".format(base=self.base_url), + json=info) + information = self.driver.get_info(self.amp) + self.assertEqual(info, information) + + @requests_mock.mock() + def test_get_info_unauthorized(self, m): + m.get("{base}/info".format(base=self.base_url), + status_code=401) + self.assertRaises(exc.Unauthorized, self.driver.get_info, self.amp) + + @requests_mock.mock() + def test_get_info_missing(self, m): + m.get("{base}/info".format(base=self.base_url), + status_code=404) + self.assertRaises(exc.NotFound, self.driver.get_info, self.amp) + + @requests_mock.mock() + def test_get_info_server_error(self, m): + m.get("{base}/info".format(base=self.base_url), + status_code=500) + self.assertRaises(exc.InternalServerError, self.driver.get_info, + self.amp) + + @requests_mock.mock() + def test_get_info_service_unavailable(self, m): + m.get("{base}/info".format(base=self.base_url), + status_code=503) + self.assertRaises(exc.ServiceUnavailable, self.driver.get_info, + self.amp) + + @requests_mock.mock() + def test_get_details(self, m): + details = {"hostname": "some_hostname", "version": "some_version", + "api_version": "0.5", "uuid": self.FAKE_UUID_1, + "network_tx": "some_tx", "network_rx": "some_rx", + "active": True, "haproxy_count": 10} + m.get("{base}/details".format(base=self.base_url), + json=details) + amp_details = self.driver.get_details(self.amp) + self.assertEqual(details, amp_details) + + @requests_mock.mock() + def test_get_details_unauthorized(self, m): + m.get("{base}/details".format(base=self.base_url), + status_code=401) + self.assertRaises(exc.Unauthorized, self.driver.get_details, self.amp) + + @requests_mock.mock() + def test_get_details_missing(self, m): + m.get("{base}/details".format(base=self.base_url), + status_code=404) + self.assertRaises(exc.NotFound, self.driver.get_details, self.amp) + + @requests_mock.mock() + def test_get_details_server_error(self, m): + m.get("{base}/details".format(base=self.base_url), + status_code=500) + self.assertRaises(exc.InternalServerError, self.driver.get_details, + self.amp) + + @requests_mock.mock() + def test_get_details_service_unavailable(self, m): + m.get("{base}/details".format(base=self.base_url), + status_code=503) + self.assertRaises(exc.ServiceUnavailable, self.driver.get_details, + self.amp) + + @requests_mock.mock() + def test_get_all_listeners(self, m): + listeners = [{"status": "ONLINE", "provisioning_status": "ACTIVE", + "type": "PASSIVE", "uuid": self.FAKE_UUID_1}] + m.get("{base}/listeners".format(base=self.base_url), + json=listeners) + all_listeners = self.driver.get_all_listeners(self.amp) + self.assertEqual(listeners, all_listeners) + + @requests_mock.mock() + def test_get_all_listeners_unauthorized(self, m): + m.get("{base}/listeners".format(base=self.base_url), + status_code=401) + self.assertRaises(exc.Unauthorized, self.driver.get_all_listeners, + self.amp) + + @requests_mock.mock() + def test_get_all_listeners_missing(self, m): + m.get("{base}/listeners".format(base=self.base_url), + status_code=404) + self.assertRaises(exc.NotFound, self.driver.get_all_listeners, + self.amp) + + @requests_mock.mock() + def test_get_all_listeners_server_error(self, m): + m.get("{base}/listeners".format(base=self.base_url), + status_code=500) + self.assertRaises(exc.InternalServerError, + self.driver.get_all_listeners, self.amp) + + @requests_mock.mock() + def test_get_all_listeners_service_unavailable(self, m): + m.get("{base}/listeners".format(base=self.base_url), + status_code=503) + self.assertRaises(exc.ServiceUnavailable, + self.driver.get_all_listeners, self.amp) + + @requests_mock.mock() + def test_get_listener_status(self, m): + listener = {"status": "ONLINE", "provisioning_status": "ACTIVE", + "type": "PASSIVE", "uuid": self.FAKE_UUID_1} + m.get("{base}/listeners/{listener_id}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + json=listener) + status = self.driver.get_listener_status(self.amp, self.FAKE_UUID_1) + self.assertEqual(listener, status) + + @requests_mock.mock() + def test_get_listener_status_unauthorized(self, m): + m.get("{base}/listeners/{listener_id}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=401) + self.assertRaises(exc.Unauthorized, + self.driver.get_listener_status, self.amp, + self.FAKE_UUID_1) + + @requests_mock.mock() + def test_get_listener_status_missing(self, m): + m.get("{base}/listeners/{listener_id}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=404) + self.assertRaises(exc.NotFound, + self.driver.get_listener_status, self.amp, + self.FAKE_UUID_1) + + @requests_mock.mock() + def test_get_listener_status_server_error(self, m): + m.get("{base}/listeners/{listener_id}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=500) + self.assertRaises(exc.InternalServerError, + self.driver.get_listener_status, self.amp, + self.FAKE_UUID_1) + + @requests_mock.mock() + def test_get_listener_status_service_unavailable(self, m): + m.get("{base}/listeners/{listener_id}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=503) + self.assertRaises(exc.ServiceUnavailable, + self.driver.get_listener_status, self.amp, + self.FAKE_UUID_1) + + @requests_mock.mock() + def test_start_listener(self, m): + m.put("{base}/listeners/{listener_id}/start".format( + base=self.base_url, listener_id=self.FAKE_UUID_1)) + self.driver.start_listener(self.amp, self.FAKE_UUID_1) + self.assertTrue(m.called) + + @requests_mock.mock() + def test_start_listener_missing(self, m): + m.put("{base}/listeners/{listener_id}/start".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=404) + self.assertRaises(exc.NotFound, self.driver.start_listener, + self.amp, self.FAKE_UUID_1) + + @requests_mock.mock() + def test_start_listener_unauthorized(self, m): + m.put("{base}/listeners/{listener_id}/start".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=401) + self.assertRaises(exc.Unauthorized, self.driver.start_listener, + self.amp, self.FAKE_UUID_1) + + @requests_mock.mock() + def test_start_listener_server_error(self, m): + m.put("{base}/listeners/{listener_id}/start".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=500) + self.assertRaises(exc.InternalServerError, self.driver.start_listener, + self.amp, self.FAKE_UUID_1) + + @requests_mock.mock() + def test_start_listener_service_unavailable(self, m): + m.put("{base}/listeners/{listener_id}/start".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=503) + self.assertRaises(exc.ServiceUnavailable, self.driver.start_listener, + self.amp, self.FAKE_UUID_1) + + @requests_mock.mock() + def test_stop_listener(self, m): + m.put("{base}/listeners/{listener_id}/stop".format( + base=self.base_url, listener_id=self.FAKE_UUID_1)) + self.driver.stop_listener(self.amp, self.FAKE_UUID_1) + self.assertTrue(m.called) + + @requests_mock.mock() + def test_stop_listener_missing(self, m): + m.put("{base}/listeners/{listener_id}/stop".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=404) + self.assertRaises(exc.NotFound, self.driver.stop_listener, + self.amp, self.FAKE_UUID_1) + + @requests_mock.mock() + def test_stop_listener_unauthorized(self, m): + m.put("{base}/listeners/{listener_id}/stop".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=401) + self.assertRaises(exc.Unauthorized, self.driver.stop_listener, + self.amp, self.FAKE_UUID_1) + + @requests_mock.mock() + def test_stop_listener_server_error(self, m): + m.put("{base}/listeners/{listener_id}/stop".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=500) + self.assertRaises(exc.InternalServerError, self.driver.stop_listener, + self.amp, self.FAKE_UUID_1) + + @requests_mock.mock() + def test_stop_listener_service_unavailable(self, m): + m.put("{base}/listeners/{listener_id}/stop".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=503) + self.assertRaises(exc.ServiceUnavailable, self.driver.stop_listener, + self.amp, self.FAKE_UUID_1) + + @requests_mock.mock() + def test_delete_listener(self, m): + m.delete("{base}/listeners/{listener_id}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), json={}) + self.driver.delete_listener(self.amp, self.FAKE_UUID_1) + self.assertTrue(m.called) + + @requests_mock.mock() + def test_delete_listener_missing(self, m): + m.delete("{base}/listeners/{listener_id}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=404) + self.assertRaises(exc.NotFound, self.driver.delete_listener, + self.amp, self.FAKE_UUID_1) + + @requests_mock.mock() + def test_delete_listener_unauthorized(self, m): + m.delete("{base}/listeners/{listener_id}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=401) + self.assertRaises(exc.Unauthorized, self.driver.delete_listener, + self.amp, self.FAKE_UUID_1) + + @requests_mock.mock() + def test_delete_listener_server_error(self, m): + m.delete("{base}/listeners/{listener_id}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=500) + self.assertRaises(exc.InternalServerError, self.driver.delete_listener, + self.amp, self.FAKE_UUID_1) + + @requests_mock.mock() + def test_delete_listener_service_unavailable(self, m): + m.delete("{base}/listeners/{listener_id}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=503) + self.assertRaises(exc.ServiceUnavailable, self.driver.delete_listener, + self.amp, self.FAKE_UUID_1) + + @requests_mock.mock() + def test_upload_cert_pem(self, m): + m.put("{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME)) + self.driver.upload_cert_pem(self.amp, self.FAKE_UUID_1, + self.FAKE_PEM_FILENAME, + "some_file") + self.assertTrue(m.called) + + @requests_mock.mock() + def test_upload_invalid_cert_pem(self, m): + m.put("{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME), status_code=403) + self.assertRaises(exc.InvalidRequest, self.driver.upload_cert_pem, + self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME, + "some_file") + + @requests_mock.mock() + def test_upload_cert_pem_unauthorized(self, m): + m.put("{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME), status_code=401) + self.assertRaises(exc.Unauthorized, self.driver.upload_cert_pem, + self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME, + "some_file") + + @requests_mock.mock() + def test_upload_cert_pem_server_error(self, m): + m.put("{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME), status_code=500) + self.assertRaises(exc.InternalServerError, self.driver.upload_cert_pem, + self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME, + "some_file") + + @requests_mock.mock() + def test_upload_cert_pem_service_unavailable(self, m): + m.put("{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME), status_code=503) + self.assertRaises(exc.ServiceUnavailable, self.driver.upload_cert_pem, + self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME, + "some_file") + + @requests_mock.mock() + def test_get_cert_5sum(self, m): + md5sum = {"md5sum": "some_real_sum"} + m.get("{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME), json=md5sum) + sum_test = self.driver.get_cert_md5sum(self.amp, self.FAKE_UUID_1, + self.FAKE_PEM_FILENAME) + self.assertIsNotNone(sum_test) + + @requests_mock.mock() + def test_get_cert_5sum_missing(self, m): + m.get("{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME), status_code=404) + self.assertRaises(exc.NotFound, self.driver.get_cert_md5sum, + self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME) + + @requests_mock.mock() + def test_get_cert_5sum_unauthorized(self, m): + m.get("{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME), status_code=401) + self.assertRaises(exc.Unauthorized, self.driver.get_cert_md5sum, + self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME) + + @requests_mock.mock() + def test_get_cert_5sum_server_error(self, m): + m.get("{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME), status_code=500) + self.assertRaises(exc.InternalServerError, self.driver.get_cert_md5sum, + self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME) + + @requests_mock.mock() + def test_get_cert_5sum_service_unavailable(self, m): + m.get("{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME), status_code=503) + self.assertRaises(exc.ServiceUnavailable, self.driver.get_cert_md5sum, + self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME) + + @requests_mock.mock() + def test_delete_cert_pem(self, m): + m.delete( + "{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME)) + self.driver.delete_cert_pem(self.amp, self.FAKE_UUID_1, + self.FAKE_PEM_FILENAME) + self.assertTrue(m.called) + + @requests_mock.mock() + def test_delete_cert_pem_missing(self, m): + m.delete( + "{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME), status_code=404) + self.assertRaises(exc.NotFound, self.driver.delete_cert_pem, self.amp, + self.FAKE_UUID_1, self.FAKE_PEM_FILENAME) + + @requests_mock.mock() + def test_delete_cert_pem_unauthorized(self, m): + m.delete( + "{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME), status_code=401) + self.assertRaises(exc.Unauthorized, self.driver.delete_cert_pem, + self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME) + + @requests_mock.mock() + def test_delete_cert_pem_server_error(self, m): + m.delete( + "{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME), status_code=500) + self.assertRaises(exc.InternalServerError, self.driver.delete_cert_pem, + self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME) + + @requests_mock.mock() + def test_delete_cert_pem_service_unavailable(self, m): + m.delete( + "{base}/listeners/{listener_id}/certificates/{filename}".format( + base=self.base_url, listener_id=self.FAKE_UUID_1, + filename=self.FAKE_PEM_FILENAME), status_code=503) + self.assertRaises(exc.ServiceUnavailable, self.driver.delete_cert_pem, + self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME) + + @requests_mock.mock() + def test_upload_config(self, m): + config = {"name": "fake_config"} + m.put( + "{base}/listeners/{listener_id}/haproxy".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + json=config) + self.driver.upload_config(self.amp, self.FAKE_UUID_1, config) + self.assertTrue(m.called) + + @requests_mock.mock() + def test_upload_invalid_config(self, m): + config = '{"name": "bad_config"}' + m.put( + "{base}/listeners/{listener_id}/haproxy".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=403) + self.assertRaises(exc.InvalidRequest, self.driver.upload_config, + self.amp, self.FAKE_UUID_1, config) + + @requests_mock.mock() + def test_upload_config_unauthorized(self, m): + config = '{"name": "bad_config"}' + m.put( + "{base}/listeners/{listener_id}/haproxy".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=401) + self.assertRaises(exc.Unauthorized, self.driver.upload_config, + self.amp, self.FAKE_UUID_1, config) + + @requests_mock.mock() + def test_upload_config_server_error(self, m): + config = '{"name": "bad_config"}' + m.put( + "{base}/listeners/{listener_id}/haproxy".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=500) + self.assertRaises(exc.InternalServerError, self.driver.upload_config, + self.amp, self.FAKE_UUID_1, config) + + @requests_mock.mock() + def test_upload_config_service_unavailable(self, m): + config = '{"name": "bad_config"}' + m.put( + "{base}/listeners/{listener_id}/haproxy".format( + base=self.base_url, listener_id=self.FAKE_UUID_1), + status_code=503) + self.assertRaises(exc.ServiceUnavailable, self.driver.upload_config, + self.amp, self.FAKE_UUID_1, config) + + @requests_mock.mock() + def test_plug_vip(self, m): + FAKE_IP = 'fake' + m.post("{base}/plug/vip/{vip}".format( + base=self.base_url, vip=FAKE_IP) + ) + self.driver.plug_vip(self.amp, FAKE_IP) + self.assertTrue(m.called) + + @requests_mock.mock() + def test_plug_network(self, m): + m.post("{base}/plug/network".format( + base=self.base_url) + ) + self.driver.plug_network(self.amp) + self.assertTrue(m.called) diff --git a/octavia/tests/unit/controller/worker/flows/test_amphora_flows.py b/octavia/tests/unit/controller/worker/flows/test_amphora_flows.py index b57421fea8..f060482e4b 100644 --- a/octavia/tests/unit/controller/worker/flows/test_amphora_flows.py +++ b/octavia/tests/unit/controller/worker/flows/test_amphora_flows.py @@ -27,6 +27,8 @@ AUTH_VERSION = '2' class TestAmphoraFlows(base.TestCase): def setUp(self): + cfg.CONF.set_override('amphora_driver', 'amphora_haproxy_ssh_driver', + group='controller_worker') self.AmpFlow = amphora_flows.AmphoraFlows() conf = oslo_fixture.Config(cfg.CONF) conf.config(group="keystone_authtoken", auth_version=AUTH_VERSION) @@ -46,6 +48,22 @@ class TestAmphoraFlows(base.TestCase): self.assertEqual(len(amp_flow.provides), 3) self.assertEqual(len(amp_flow.requires), 0) + def test_get_create_amphora_flow_cert(self): + cfg.CONF.set_override('amphora_driver', 'amphora_haproxy_rest_driver', + group='controller_worker') + self.AmpFlow = amphora_flows.AmphoraFlows() + + amp_flow = self.AmpFlow.get_create_amphora_flow() + + self.assertIsInstance(amp_flow, flow.Flow) + + self.assertIn(constants.AMPHORA, amp_flow.provides) + self.assertIn(constants.AMPHORA_ID, amp_flow.provides) + self.assertIn(constants.COMPUTE_ID, amp_flow.provides) + + self.assertEqual(len(amp_flow.provides), 4) + self.assertEqual(len(amp_flow.requires), 0) + def test_get_create_amphora_for_lb_flow(self): amp_flow = self.AmpFlow.get_create_amphora_for_lb_flow() @@ -64,6 +82,26 @@ class TestAmphoraFlows(base.TestCase): self.assertEqual(len(amp_flow.provides), 7) self.assertEqual(len(amp_flow.requires), 1) + def test_get_cert_create_amphora_for_lb_flow(self): + cfg.CONF.set_override('amphora_driver', 'amphora_haproxy_rest_driver', + group='controller_worker') + self.AmpFlow = amphora_flows.AmphoraFlows() + amp_flow = self.AmpFlow.get_create_amphora_for_lb_flow() + + self.assertIsInstance(amp_flow, flow.Flow) + + self.assertIn(constants.LOADBALANCER_ID, amp_flow.requires) + self.assertIn(constants.AMPHORA, amp_flow.provides) + self.assertIn(constants.LOADBALANCER, amp_flow.provides) + self.assertIn(constants.VIP, amp_flow.provides) + self.assertIn(constants.AMPS_DATA, amp_flow.provides) + self.assertIn(constants.AMPHORA_ID, amp_flow.provides) + self.assertIn(constants.COMPUTE_ID, amp_flow.provides) + self.assertIn(constants.COMPUTE_OBJ, amp_flow.provides) + + self.assertEqual(len(amp_flow.provides), 8) + self.assertEqual(len(amp_flow.requires), 1) + def test_get_delete_amphora_flow(self): amp_flow = self.AmpFlow.get_delete_amphora_flow() diff --git a/octavia/tests/unit/controller/worker/tasks/test_cert_task.py b/octavia/tests/unit/controller/worker/tasks/test_cert_task.py new file mode 100644 index 0000000000..93d60912e8 --- /dev/null +++ b/octavia/tests/unit/controller/worker/tasks/test_cert_task.py @@ -0,0 +1,33 @@ +# Copyright 2015 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# + +import mock + +from octavia.certificates.common import local +from octavia.controller.worker.tasks import cert_task +import octavia.tests.unit.base as base + + +class TestCertTasks(base.TestCase): + @mock.patch('stevedore.driver.DriverManager.driver') + def test_execute(self, mock_driver): + dummy_cert = local.LocalCert('test_cert', 'test_key') + mock_driver.generate_cert_key_pair.side_effect = [dummy_cert] + c = cert_task.GenerateServerPEMTask() + pem = c.execute('123') + self.assertEqual( + pem, dummy_cert.get_certificate() + dummy_cert.get_private_key()) + mock_driver.generate_cert_key_pair.assert_called_once_with( + cn='123', validity=365) \ No newline at end of file diff --git a/octavia/tests/unit/controller/worker/tasks/test_compute_tasks.py b/octavia/tests/unit/controller/worker/tasks/test_compute_tasks.py index bcb42c3d37..bdc5066e52 100644 --- a/octavia/tests/unit/controller/worker/tasks/test_compute_tasks.py +++ b/octavia/tests/unit/controller/worker/tasks/test_compute_tasks.py @@ -73,10 +73,10 @@ class TestComputeTasks(base.TestCase): @mock.patch('stevedore.driver.DriverManager.driver') def test_compute_create(self, mock_driver): - mock_driver.build.side_effect = [COMPUTE_ID, TestException('test')] - - # Test execute() createcompute = compute_tasks.ComputeCreate() + + mock_driver.build.side_effect = [COMPUTE_ID, TestException('test')] + # Test execute() compute_id = createcompute.execute(_amphora_mock.id) # Validate that the build method was called properly @@ -86,17 +86,65 @@ class TestComputeTasks(base.TestCase): image_id=AMP_IMAGE_ID, key_name=AMP_SSH_KEY_NAME, sec_groups=AMP_SEC_GROUPS, - network_ids=[AMP_NET]) + network_ids=[AMP_NET], + config_drive_files=None) # Make sure it returns the expected compute_id assert(compute_id == COMPUTE_ID) # Test that a build exception is raised - createcompute = compute_tasks.ComputeCreate() self.assertRaises(TestException, createcompute.execute, - _amphora_mock) + _amphora_mock, 'test_cert') + + # Test revert() + + _amphora_mock.compute_id = COMPUTE_ID + + createcompute = compute_tasks.ComputeCreate() + createcompute.revert(compute_id, _amphora_mock.id) + + # Validate that the delete method was called properly + mock_driver.delete.assert_called_once_with( + COMPUTE_ID) + + # Test that a delete exception is not raised + + createcompute.revert(COMPUTE_ID, _amphora_mock.id) + + @mock.patch('stevedore.driver.DriverManager.driver') + def test_compute_create_cert(self, mock_driver): + + createcompute = compute_tasks.CertComputeCreate() + + mock_driver.build.side_effect = [COMPUTE_ID, TestException('test')] + m = mock.mock_open(read_data='test') + with mock.patch('__builtin__.open', m, create=True): + # Test execute() + compute_id = createcompute.execute(_amphora_mock.id, 'test_cert') + + # Validate that the build method was called properly + mock_driver.build.assert_called_once_with( + name="amphora-" + _amphora_mock.id, + amphora_flavor=AMP_FLAVOR_ID, + image_id=AMP_IMAGE_ID, + key_name=AMP_SSH_KEY_NAME, + sec_groups=AMP_SEC_GROUPS, + network_ids=[AMP_NET], + config_drive_files={ + '/etc/octavia/certs/server.pem': 'test_cert', + '/etc/octavia/certs/client_ca.pem': m.return_value}) + + # Make sure it returns the expected compute_id + assert(compute_id == COMPUTE_ID) + + # Test that a build exception is raised + with mock.patch('__builtin__.open', m, create=True): + createcompute = compute_tasks.ComputeCreate() + self.assertRaises(TestException, + createcompute.execute, + _amphora_mock, 'test_cert') # Test revert() diff --git a/setup.cfg b/setup.cfg index d0801e878f..e22aa84f2a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,7 +44,7 @@ octavia.api.handlers = queue_producer = octavia.api.v1.handlers.queue.producer:ProducerHandler octavia.amphora.drivers = amphora_noop_driver = octavia.amphorae.drivers.noop_driver.driver:NoopAmphoraLoadBalancerDriver - amphora_haproxy_rest_driver = octavia.amphorae.drivers.haproxy.rest_driver:NoopAmphoraLoadBalancerDriver + amphora_haproxy_rest_driver = octavia.amphorae.drivers.haproxy.rest_api_driver:HaproxyAmphoraLoadBalancerDriver amphora_haproxy_ssh_driver = octavia.amphorae.drivers.haproxy.ssh_driver:HaproxyManager octavia.compute.drivers = compute_noop_driver = octavia.compute.drivers.noop_driver.driver:NoopComputeDriver @@ -52,3 +52,10 @@ octavia.compute.drivers = octavia.network.drivers = network_noop_driver = octavia.network.drivers.noop_driver.driver:NoopNetworkDriver allowed_address_pairs_driver = octavia.network.drivers.neutron.allowed_address_pairs:AllowedAddressPairsDriver +octavia.cert_generator = + local_cert_generator = octavia.certificates.generator.local:LocalCertGenerator + barbican_cert_generator = octavia.certificates.generator.barbican:BarbicanCertGenerator +octavia.cert_manager = + local_cert_manager = octavia.certificates.manager.local:LocalCertManager + barbican_cert_manager = octavia.certificates.manager.barbican:BarbicanCertManager + diff --git a/test-requirements.txt b/test-requirements.txt index c0beae71b7..398ccf800c 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -2,7 +2,7 @@ # of appearance. Changing the order has an impact on the overall integration # process, which may cause wedges in the gate later. hacking<0.10,>=0.9.1 - +requests-mock cliff>=1.13.0 # Apache-2.0 coverage>=3.6 discover @@ -15,3 +15,4 @@ testrepository>=0.0.18 testtools>=1.4.0 WebTest>=2.0 doc8 # Apache-2.0 +