haproxy reference amphora REST API client

Adds rest driver methods
Adds rest driver tests
Add cert task for generating server certs
Modified compute task/flow
Fixed local certificate stuff
Refactored to use requests-mock inetad of responses
Added a "conditiobal flow" for REST

Cleaned up and changed the code to work with
https://review.openstack.org/#/c/160034/

Replaces:
https://review.openstack.org/#/c/144348/
https://review.openstack.org/#/c/145637/14

Change-Id: Ibcbf0717b785aab4c604deef1061e8b2fa41006c
Co-Authored-By: Phillip Toohill <phillip.toohill@rackspace.com>
Co-Authored-By: German Eichberger <german.eichberger@hp.com>
Co-Authored-By: Stephen Balukoff <sbalukoff@bluebox.net>
Implements: bp/haproxy-amphora-driver
changes/72/171172/15
German Eichberger 8 years ago
parent 792a116523
commit 0abcbc4f7d

@ -61,6 +61,7 @@ function octavia_configure {
# Setting other required default options
iniset $OCTAVIA_CONF controller_worker amphora_driver amphora_haproxy_ssh_driver
#iniset $OCTAVIA_CONF controller_worker amphora_driver amphora_haproxy_rest_driver
iniset $OCTAVIA_CONF controller_worker compute_driver compute_nova_driver
iniset $OCTAVIA_CONF controller_worker network_driver allowed_address_pairs_driver

@ -3,6 +3,7 @@ set -eux
install-packages libffi-dev libssl-dev
cd /opt/amphora-agent/
pip install -r requirements.txt
python setup.py install
cp etc/init/octavia-agent.conf /etc/init/
mkdir /etc/octavia

@ -53,7 +53,6 @@
# Network to communicate with amphora
# lb_network_name =
[haproxy_amphora]
# base_path = /var/lib/octavia
# base_cert_dir = /var/lib/octavia/certs
@ -62,8 +61,12 @@
# base_log_dir = /logs
# connection_max_retries = 10
# connection_retry_interval = 5
# Cert manager options are local_cert_manager,
# barbican_cert_manager,
#
# cert_manager = barbican_cert_manager
#SSH Driver specific
# SSH Driver specific
# username = ubuntu
# key_path = /opt/stack/.ssh/id_rsa
@ -86,6 +89,7 @@
# amp_ssh_key_name =
# amp_network =
# amp_secgroup_list =
# client_ca = /etc/octavia/certs/ca_01.pem
# Amphora driver options are amphora_noop_driver,
# amphora_haproxy_rest_driver,
@ -102,6 +106,10 @@
# allowed_address_pairs_driver
#
# network_driver = network_noop_driver
#
# Certificate Generator options are local_cert_generator
# barbican_cert_generator
# cert_generator = local_cert_generator
[task_flow]
# engine = serial
@ -121,4 +129,4 @@
# rpc_thread_pool_size = 2
# Topic (i.e. Queue) Name
# topic = octavia_prov
# topic = octavia_prov

@ -361,7 +361,7 @@ def _check_ssl_filename_format(filename):
def _cert_dir(listener_id):
return os.path.join(util.CONF.haproxy_amphora.haproxy_cert_dir,
return os.path.join(util.CONF.haproxy_amphora.base_cert_dir,
listener_id)

@ -0,0 +1,110 @@
# Copyright 2014 Rackspace
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import octavia.common.data_models as models
class Topology(models.BaseDataModel):
def __init__(self, hostname=None, uuid=None, topology=None, role=None,
ip=None, ha_ip=None):
self.hostname = hostname
self.uuid = uuid
self.topology = topology
self.role = role
self.ip = ip
self.ha_ip = ha_ip
class Info(models.BaseDataModel):
def __init__(self, hostname=None, uuid=None, version=None,
api_version=None):
self.hostname = hostname
self.uuid = uuid
self.version = version
self.api_version = api_version
class Details(models.BaseDataModel):
def __init__(self, hostname=None, uuid=None, version=None,
api_version=None, network_tx=None, network_rx=None,
active=None, haproxy_count=None, cpu=None, memory=None,
disk=None, load=None, listeners=None, packages=None):
self.hostname = hostname
self.uuid = uuid,
self.version = version
self.api_version = api_version
self.network_tx = network_tx
self.network_rx = network_rx
self.active = active
self.haproxy_count = haproxy_count
self.cpu = cpu
self.memory = memory
self.disk = disk
self.load = load or []
self.listeners = listeners or []
self.packages = packages or []
class CPU(models.BaseDataModel):
def __init__(self, total=None, user=None, system=None, soft_irq=None):
self.total = total
self.user = user
self.system = system
self.soft_irq = soft_irq
class Memory(models.BaseDataModel):
def __init__(self, total=None, free=None, available=None, buffers=None,
cached=None, swap_used=None, shared=None, slab=None,
committed_as=None):
self.total = total
self.free = free
self.available = available
self.buffers = buffers
self.cached = cached
self.swap_used = swap_used
self.shared = shared
self.slab = slab
self.committed_as = committed_as
class Disk(models.BaseDataModel):
def __init__(self, used=None, available=None):
self.used = used
self.available = available
class ListenerStatus(models.BaseDataModel):
def __init__(self, status=None, uuid=None, provisioning_status=None,
type=None, pools=None):
self.status = status
self.uuid = uuid
self.provisioning_status = provisioning_status
self.type = type
self.pools = pools or []
class Pool(models.BaseDataModel):
def __init__(self, uuid=None, status=None, members=None):
self.uuid = uuid
self.status = status
self.members = members or []

@ -0,0 +1,72 @@
# Copyright 2014 Rackspace
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from webob import exc
def check_exception(response):
status_code = response.status_code
responses = {
401: Unauthorized,
403: InvalidRequest,
404: NotFound,
405: InvalidRequest,
409: Conflict,
500: InternalServerError,
503: ServiceUnavailable
}
if status_code in responses:
raise responses[status_code]()
return response
class APIException(exc.HTTPClientError):
msg = "Something unknown went wrong"
code = 500
def __init__(self, **kwargs):
self.msg = self.msg % kwargs
super(APIException, self).__init__(detail=self.msg)
class Unauthorized(APIException):
msg = "Unauthorized"
code = 401
class InvalidRequest(APIException):
msg = "Invalid request"
code = 403
class NotFound(APIException):
msg = "Not Found"
code = 404
class Conflict(APIException):
msg = "Conflict"
code = 409
class InternalServerError(APIException):
msg = "Internal Server Error"
code = 500
class ServiceUnavailable(APIException):
msg = "Service Unavailable"
code = 503

@ -0,0 +1,299 @@
# Copyright 2015 Hewlett-Packard Development Company, L.P.
# Copyright (c) 2015 Rackspace
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import functools
import hashlib
import time
from oslo_log import log as logging
import requests
from stevedore import driver as stevedore_driver
from octavia.amphorae.drivers import driver_base as driver_base
from octavia.amphorae.drivers.haproxy import exceptions as exc
from octavia.amphorae.drivers.haproxy.jinja import jinja_cfg
from octavia.common.config import cfg
from octavia.common import data_models as data_models
from octavia.common.tls_utils import cert_parser
from octavia.i18n import _LW
LOG = logging.getLogger(__name__)
API_VERSION = '0.5'
OCTAVIA_API_CLIENT = (
"Octavia HaProxy Rest Client/{version} "
"(https://wiki.openstack.org/wiki/Octavia)").format(version=API_VERSION)
CONF = cfg.CONF
CONF.import_group('haproxy_amphora', 'octavia.common.config')
class HaproxyAmphoraLoadBalancerDriver(driver_base.AmphoraLoadBalancerDriver):
def __init__(self):
super(HaproxyAmphoraLoadBalancerDriver, self).__init__()
self.client = AmphoraAPIClient()
self.cert_manager = stevedore_driver.DriverManager(
namespace='octavia.cert_manager',
name=CONF.haproxy_amphora.cert_manager,
invoke_on_load=True,
).driver
self.jinja = jinja_cfg.JinjaTemplater(
base_amp_path=CONF.haproxy_amphora.base_path,
base_crt_dir=CONF.haproxy_amphora.base_cert_dir,
haproxy_template=CONF.haproxy_amphora.haproxy_template)
def update(self, listener, vip):
LOG.debug("Amphora %s haproxy, updating listener %s, vip %s",
self.__class__.__name__, listener.protocol_port,
vip.ip_address)
# Process listener certificate info
certs = self._process_tls_certificates(listener)
# Generate HaProxy configuration from listener object
config = self.jinja.build_config(listener, certs['tls_cert'],
certs['sni_certs'])
for amp in listener.load_balancer.amphorae:
self.client.upload_config(amp, listener.id, config)
# todo (german): add a method to REST interface to reload or start
# without having to check
# Is that listener running?
r = self.client.get_listener_status(amp,
listener.id)
if r['status'] == 'ACTIVE':
self.client.reload_listener(amp, listener.id)
else:
self.client.start_listener(amp, listener.id)
def _apply(self, func, listener=None, *args):
for amp in listener.load_balancer.amphorae:
func(amp, listener.id, *args)
def stop(self, listener, vip):
self._apply(self.client.stop_listener, listener)
def start(self, listener, vip):
self._apply(self.client.start_listener, listener)
def delete(self, listener, vip):
self._apply(self.client.delete_listener, listener)
def get_info(self, amphora):
self.driver.get_info(amphora.lb_network_ip)
def get_diagnostics(self, amphora):
self.driver.get_diagnostics(amphora.lb_network_ip)
def finalize_amphora(self, amphora):
pass
def post_vip_plug(self, load_balancer):
for amp in load_balancer.amphorae:
self.client.plug_vip(amp, load_balancer.vip.ip_address)
def post_network_plug(self, amphora):
self.client.plug_network(amphora)
def _process_tls_certificates(self, listener):
"""Processes TLS data from the listener.
Converts and uploads PEM data to the Amphora API
return TLS_CERT and SNI_CERTS
"""
tls_cert = None
sni_certs = []
certs = []
if listener.tls_certificate_id:
tls_cert = self._map_cert_tls_container(
self.cert_manager.get_cert(listener.tls_certificate_id))
certs.append(tls_cert)
if listener.sni_containers:
for sni_cont in listener.sni_containers:
bbq_container = self._map_cert_tls_container(
self.cert_manager.get_cert(sni_cont.tls_container.id))
sni_certs.append(bbq_container)
certs.append(bbq_container)
for cert in certs:
pem = self._build_pem(cert)
md5 = hashlib.md5(pem).hexdigest()
name = '{cn}.pem'.format(cn=cert.primary_cn)
self._apply(self._upload_cert, listener, pem, md5, name)
return {'tls_cert': tls_cert, 'sni_certs': sni_certs}
def _upload_cert(self, amp, listener_id, pem, md5, name):
try:
if self.client.get_cert_md5sum(amp, listener_id, name) == md5:
return
except exc.NotFound:
pass
self.client.upload_cert_pem(
amp, listener_id, name, pem)
def _get_primary_cn(self, tls_cert):
"""Returns primary CN for Certificate."""
return cert_parser.get_host_names(tls_cert.get_certificate())['cn']
def _map_cert_tls_container(self, cert):
return data_models.TLSContainer(
primary_cn=self._get_primary_cn(cert),
private_key=cert.get_private_key(),
certificate=cert.get_certificate(),
intermediates=cert.get_intermediates())
def _build_pem(self, tls_cert):
"""Concatenate TLS Certificate fields to create a PEM
encoded certificate file
"""
# TODO(ptoohill): Maybe this should be part of utils or manager?
pem = tls_cert.intermediates[:]
pem.extend([tls_cert.certificate, tls_cert.private_key])
return "\n".join(pem)
# Check a custom hostname
class CustomHostNameCheckingAdapter(requests.adapters.HTTPAdapter):
def cert_verify(self, conn, url, verify, cert):
conn.assert_hostname = self.uuid
return super(CustomHostNameCheckingAdapter,
self).cert_verify(conn, url, verify, cert)
class AmphoraAPIClient(object):
def __init__(self):
super(AmphoraAPIClient, self).__init__()
self.secure = False
self.get = functools.partial(self.request, 'get')
self.post = functools.partial(self.request, 'post')
self.put = functools.partial(self.request, 'put')
self.delete = functools.partial(self.request, 'delete')
self.head = functools.partial(self.request, 'head')
self.start_listener = functools.partial(self._action, 'start')
self.stop_listener = functools.partial(self._action, 'stop')
self.reload_listener = functools.partial(self._action, 'reload')
self.session = requests.Session()
self.session.cert = CONF.haproxy_amphora.client_cert
self.ssl_adapter = CustomHostNameCheckingAdapter()
self.session.mount('https://', self.ssl_adapter)
def _base_url(self, ip):
return "https://{ip}:{port}/{version}/".format(
ip=ip,
port=CONF.haproxy_amphora.bind_port,
version=API_VERSION)
def request(self, method, amp, path='/', **kwargs):
LOG.debug("request url " + path)
_request = getattr(self.session, method.lower())
_url = self._base_url(amp.lb_network_ip) + path
reqargs = {
'verify': CONF.haproxy_amphora.server_ca,
'url': _url, }
reqargs.update(kwargs)
headers = reqargs.setdefault('headers', {})
headers['User-Agent'] = OCTAVIA_API_CLIENT
self.ssl_adapter.uuid = amp.id
# Keep retrying
for attempts in xrange(CONF.haproxy_amphora.connection_max_retries):
try:
r = _request(**reqargs)
except requests.ConnectionError:
LOG.warn(_LW("Could not talk to instance"))
time.sleep(CONF.haproxy_amphora.connection_retry_interval)
if attempts >= CONF.haproxy_amphora.connection_max_retries:
raise exc.TimeOutException()
else:
return r
raise exc.UnavailableException()
def upload_config(self, amp, listener_id, config):
r = self.put(
amp,
'listeners/{listener_id}/haproxy'.format(listener_id=listener_id),
data=config)
return exc.check_exception(r)
def get_listener_status(self, amp, listener_id):
r = self.get(
amp,
'listeners/{listener_id}'.format(listener_id=listener_id))
if exc.check_exception(r):
return r.json()
def _action(self, action, amp, listener_id):
r = self.put(amp, 'listeners/{listener_id}/{action}'.format(
listener_id=listener_id, action=action))
return exc.check_exception(r)
def upload_cert_pem(self, amp, listener_id, pem_filename, pem_file):
r = self.put(
amp,
'listeners/{listener_id}/certificates/{filename}'.format(
listener_id=listener_id, filename=pem_filename),
data=pem_file)
return exc.check_exception(r)
def get_cert_md5sum(self, amp, listener_id, pem_filename):
r = self.get(amp,
'listeners/{listener_id}/certificates/{filename}'.format(
listener_id=listener_id, filename=pem_filename))
if exc.check_exception(r):
return r.json().get("md5sum")
def delete_listener(self, amp, listener_id):
r = self.delete(
amp, 'listeners/{listener_id}'.format(listener_id=listener_id))
return exc.check_exception(r)
def get_info(self, amp):
r = self.get(amp, "info")
if exc.check_exception(r):
return r.json()
def get_details(self, amp):
r = self.get(amp, "details")
if exc.check_exception(r):
return r.json()
def get_all_listeners(self, amp):
r = self.get(amp, "listeners")
if exc.check_exception(r):
return r.json()
def delete_cert_pem(self, amp, listener_id, pem_filename):
r = self.delete(
amp,
'listeners/{listener_id}/certificates/{filename}'.format(
listener_id=listener_id, filename=pem_filename))
return exc.check_exception(r)
def plug_network(self, amp):
r = self.post(amp, 'plug/network')
return exc.check_exception(r)
def plug_vip(self, amp, vip):
r = self.post(amp, 'plug/vip/{vip}'.format(vip=vip))
return exc.check_exception(r)

@ -93,10 +93,19 @@ class LocalCertGenerator(cert_gen.CertGenerator):
cls._validate_cert(ca_cert, ca_key, ca_key_pass)
if not ca_digest:
ca_digest = CONF.certificates.signing_digest
if not ca_cert:
with open(CONF.certificates.ca_certificate, 'r') as f:
ca_cert = f.read()
if not ca_key:
with open(CONF.certificates.ca_private_key, 'r') as f:
ca_key = f.read()
if not ca_key_pass:
ca_key_pass = CONF.certificates.ca_private_key_passphrase
try:
lo_cert = crypto.load_certificate(crypto.FILETYPE_PEM, ca_cert)
lo_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key,
passphrase=ca_key_pass)
ca_key_pass)
lo_req = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr)
new_cert = crypto.X509()
@ -147,11 +156,17 @@ class LocalCertGenerator(cert_gen.CertGenerator):
@classmethod
def _generate_csr(cls, cn, private_key, passphrase=None):
pk = crypto.load_privatekey(
crypto.FILETYPE_PEM,
private_key,
passphrase
)
if passphrase:
pk = crypto.load_privatekey(
crypto.FILETYPE_PEM,
private_key,
passphrase
)
else:
pk = crypto.load_privatekey(
crypto.FILETYPE_PEM,
private_key
)
csr = crypto.X509Req()
csr.set_pubkey(pk)
subject = csr.get_subject()

@ -113,6 +113,9 @@ haproxy_amphora_opts = [
cfg.IntOpt('connection_retry_interval',
default=5,
help=_('Retry timeout between attempts in seconds.')),
cfg.StrOpt('cert_manager',
default='barbican_cert_manager',
help=_('Name of the cert manager to use')),
# REST server
cfg.StrOpt('bind_host', default='0.0.0.0',
@ -135,7 +138,12 @@ haproxy_amphora_opts = [
cfg.StrOpt('agent_server_network_dir',
default='/etc/network/interfaces.d/',
help=_("The directory where new network interfaces "
"are located"))
"are located")),
# REST client
cfg.StrOpt('client_cert', default='/etc/octavia/certs/client.pem',
help=_("The client certificate to talk to the agent")),
cfg.StrOpt('server_ca', default='/etc/octavia/certs/server_ca.pem',
help=_("The ca which signed the server certificates")),
]
controller_worker_opts = [
@ -160,6 +168,9 @@ controller_worker_opts = [
cfg.ListOpt('amp_secgroup_list',
default='',
help=_('List of security groups to attach to the Amphora')),
cfg.StrOpt('client_ca',
default='/etc/octavia/certs/ca_01.pem',
help=_('Client CA for the amphora agent to use')),
cfg.StrOpt('amphora_driver',
default='amphora_noop_driver',
help=_('Name of the amphora driver to use')),
@ -168,7 +179,11 @@ controller_worker_opts = [
help=_('Name of the compute driver to use')),
cfg.StrOpt('network_driver',
default='network_noop_driver',
help=_('Name of the network driver to use'))
help=_('Name of the network driver to use')),
cfg.StrOpt('cert_generator',
default='local_cert_generator',
help=_('Name of the cert generator to use'))
]
task_flow_opts = [
@ -194,6 +209,7 @@ cfg.CONF.register_cli_opts(core_cli_opts)
cfg.CONF.import_group('keystone_authtoken', 'keystonemiddleware.auth_token')
cfg.CONF.register_opts(keystone_authtoken_v3_opts,
group='keystone_authtoken_v3')
# Ensure that the control exchange is set correctly
messaging.set_transport_defaults(control_exchange='octavia')
_SQL_CONNECTION_DEFAULT = 'sqlite://'

@ -94,6 +94,7 @@ VIP = 'vip'
POOL = 'pool'
POOL_ID = 'pool_id'
OBJECT = 'object'
SERVER_PEM = 'server_pem'
CREATE_AMPHORA_FLOW = 'octavia-create-amphora-flow'
CREATE_AMPHORA_FOR_LB_FLOW = 'octavia-create-amp-for-lb-flow'

@ -37,6 +37,10 @@ class BaseDataModel(object):
return self.to_dict() == other.to_dict()
return False
@classmethod
def from_dict(cls, dict):
return cls(**dict)
@classmethod
def _name(cls):
"""Returns class name in a more human readable form."""

@ -20,17 +20,21 @@ from taskflow import retry
from octavia.common import constants
from octavia.controller.worker.flows import load_balancer_flows
from octavia.controller.worker.tasks import amphora_driver_tasks
from octavia.controller.worker.tasks import cert_task
from octavia.controller.worker.tasks import compute_tasks
from octavia.controller.worker.tasks import controller_tasks
from octavia.controller.worker.tasks import database_tasks
CONF = cfg.CONF
CONF.import_group('controller_worker', 'octavia.common.config')
class AmphoraFlows(object):
def __init__(self):
# for some reason only this has the values from the config file
self.REST_AMPHORA_DRIVER = (CONF.controller_worker.amphora_driver ==
'amphora_haproxy_rest_driver')
self._lb_flows = load_balancer_flows.LoadBalancerFlows()
def get_create_amphora_flow(self):
@ -45,9 +49,16 @@ class AmphoraFlows(object):
create_amphora_flow = linear_flow.Flow(constants.CREATE_AMPHORA_FLOW)
create_amphora_flow.add(database_tasks.CreateAmphoraInDB(
provides=constants.AMPHORA_ID))
create_amphora_flow.add(compute_tasks.ComputeCreate(
requires=constants.AMPHORA_ID,
provides=constants.COMPUTE_ID))
if self.REST_AMPHORA_DRIVER:
create_amphora_flow.add(cert_task.GenerateServerPEMTask(
provides=constants.SERVER_PEM))
create_amphora_flow.add(compute_tasks.CertComputeCreate(
requires=(constants.AMPHORA_ID, constants.SERVER_PEM),
provides=constants.COMPUTE_ID))
else:
create_amphora_flow.add(compute_tasks.ComputeCreate(
requires=constants.AMPHORA_ID,
provides=constants.COMPUTE_ID))
create_amphora_flow.add(database_tasks.MarkAmphoraBootingInDB(
requires=(constants.AMPHORA_ID, constants.COMPUTE_ID)))
wait_flow = linear_flow.Flow('wait_for_amphora',
@ -80,9 +91,16 @@ class AmphoraFlows(object):
CREATE_AMPHORA_FOR_LB_FLOW)
create_amp_for_lb_flow.add(database_tasks.CreateAmphoraInDB(
provides=constants.AMPHORA_ID))
create_amp_for_lb_flow.add(compute_tasks.ComputeCreate(
requires=constants.AMPHORA_ID,
provides=constants.COMPUTE_ID))
if self.REST_AMPHORA_DRIVER:
create_amp_for_lb_flow.add(cert_task.GenerateServerPEMTask(
provides=constants.SERVER_PEM))
create_amp_for_lb_flow.add(compute_tasks.CertComputeCreate(
requires=(constants.AMPHORA_ID, constants.SERVER_PEM),
provides=constants.COMPUTE_ID))
else:
create_amp_for_lb_flow.add(compute_tasks.ComputeCreate(
requires=constants.AMPHORA_ID,
provides=constants.COMPUTE_ID))
create_amp_for_lb_flow.add(database_tasks.UpdateAmphoraComputeId(
requires=(constants.AMPHORA_ID, constants.COMPUTE_ID)))
create_amp_for_lb_flow.add(database_tasks.MarkAmphoraBootingInDB(

@ -0,0 +1,52 @@
# Copyright 2015 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
import logging
from oslo.config import cfg
from stevedore import driver as stevedore_driver
from taskflow import task
CONF = cfg.CONF
CONF.import_group('controller_worker', 'octavia.common.config')
LOG = logging.getLogger(__name__)
CERT_VALIDITY = 365
class BaseCertTask(task.Task):
"""Base task to load drivers common to the tasks."""
def __init__(self, **kwargs):
super(BaseCertTask, self).__init__(**kwargs)
self.cert_generator = stevedore_driver.DriverManager(
namespace='octavia.cert_generator',
name=CONF.controller_worker.cert_generator,
invoke_on_load=True,
).driver
class GenerateServerPEMTask(BaseCertTask):
"""Create the server certs for the agent comm
Use the amphora_id for the CN
"""
def execute(self, amphora_id):
cert = self.cert_generator.generate_cert_key_pair(
cn=amphora_id,
validity=CERT_VALIDITY)
return cert.certificate + cert.private_key

@ -46,12 +46,14 @@ class BaseComputeTask(task.Task):
class ComputeCreate(BaseComputeTask):
"""Create the compute instance for a new amphora."""
def execute(self, amphora_id):
def execute(self, amphora_id, config_drive_files=None):
"""Create an amphora
:returns: an amphora
"""
LOG.debug("Nova Create execute for amphora with id %s" % amphora_id)
LOG.debug("Nova Create execute for amphora with id %s"
% amphora_id)
try:
compute_id = self.compute.build(
@ -60,7 +62,8 @@ class ComputeCreate(BaseComputeTask):
image_id=CONF.controller_worker.amp_image_id,
key_name=CONF.controller_worker.amp_ssh_key_name,
sec_groups=CONF.controller_worker.amp_secgroup_list,
network_ids=[CONF.controller_worker.amp_network])
network_ids=[CONF.controller_worker.amp_network],
config_drive_files=config_drive_files)
LOG.debug("Server created with id: %s for amphora id: %s" %
(compute_id, amphora_id))
@ -90,6 +93,23 @@ class ComputeCreate(BaseComputeTask):
" with exception %s"), e)
class CertComputeCreate(ComputeCreate):
def execute(self, amphora_id, server_pem):
"""Create an amphora
:returns: an amphora
"""
# load client certificate
with open(CONF.controller_worker.client_ca, 'r') as client_ca:
config_drive_files = {
# '/etc/octavia/octavia.conf'
'/etc/octavia/certs/server.pem': server_pem,
'/etc/octavia/certs/client_ca.pem': client_ca}
return super(CertComputeCreate, self).execute(amphora_id,
config_drive_files)
class DeleteAmphoraeOnLoadBalancer(BaseComputeTask):
"""Delete the amphorae on a load balancer.

@ -349,7 +349,8 @@ class ServerTestCase(base.TestCase):
details='No certificate with filename: test.pem',
message='Certificate Not Found'),
json.loads(rv.data))
mock_exists.assert_called_once_with('/tmp/123/test.pem')
mock_exists.assert_called_once_with(
'/var/lib/octavia/certs/123/test.pem')
# wrong file name
mock_exists.side_effect = [True]
@ -363,7 +364,8 @@ class ServerTestCase(base.TestCase):
'/listeners/123/certificates/test.pem')
self.assertEqual(200, rv.status_code)
self.assertEqual(OK, json.loads(rv.data))
mock_remove.assert_called_once_with('/tmp/123/test.pem')
mock_remove.assert_called_once_with(
'/var/lib/octavia/certs/123/test.pem')
@mock.patch('os.path.exists')
def test_get_certificate_md5(self, mock_exists):
@ -381,7 +383,7 @@ class ServerTestCase(base.TestCase):
details='No certificate with filename: test.pem',
message='Certificate Not Found'),
json.loads(rv.data))
mock_exists.assert_called_with('/tmp/123/test.pem')
mock_exists.assert_called_with('/var/lib/octavia/certs/123/test.pem')
# wrong file name
mock_exists.side_effect = [True]

@ -0,0 +1,601 @@
# Copyright 2015 Hewlett-Packard Development Company, L.P.
# Copyright (c) 2015 Rackspace
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from oslo_utils import uuidutils
import requests_mock
from octavia.amphorae.drivers.haproxy import exceptions as exc
from octavia.amphorae.drivers.haproxy import rest_api_driver as driver
from octavia.db import models
from octavia.tests.unit import base as base
from octavia.tests.unit.common.sample_configs import sample_configs
class HaproxyAmphoraLoadBalancerDriverTest(base.TestCase):
FAKE_UUID_1 = uuidutils.generate_uuid()
def setUp(self):
super(HaproxyAmphoraLoadBalancerDriverTest, self).setUp()
self.driver = driver.HaproxyAmphoraLoadBalancerDriver()
self.driver.cert_manager = mock.MagicMock()
self.driver.client = mock.MagicMock()
self.driver.jinja = mock.MagicMock()
# Build sample Listener and VIP configs
self.sl = sample_configs.sample_listener_tuple(tls=True, sni=True)
self.amp = self.sl.load_balancer.amphorae[0]
self.sv = sample_configs.sample_vip_tuple()
self.lb = self.sl.load_balancer
@mock.patch('octavia.common.tls_utils.cert_parser.get_host_names')
def test_update(self, mock_cert):
mock_cert.return_value = {'cn': 'fakeCN'}
self.driver.client.get_cert_md5sum.side_effect = [
exc.NotFound, 'Fake_MD5', 'd41d8cd98f00b204e9800998ecf8427e']
self.driver.jinja.build_config.side_effect = ['fake_config']
self.driver.client.get_listener_status.side_effect = [
dict(status='ACTIVE')]
# Execute driver method
self.driver.update(self.sl, self.sv)
# verify result
# this is called 3 times
self.driver.client.get_cert_md5sum.assert_called_with(
self.amp, self.sl.id, 'fakeCN.pem')
# this is called twice (last MD5 matches)
self.driver.client.upload_cert_pem.assert_called_with(
self.amp, self.sl.id, 'fakeCN.pem', '')
self.assertEqual(2, self.driver.client.upload_cert_pem.call_count)
# upload only one config file
self.driver.client.upload_config.assert_called_once_with(
self.amp, self.sl.id, 'fake_config')
# start should be called once
self.driver.client.reload_listener.assert_called_once_with(
self.amp, self.sl.id)
# listener down
self.driver.client.get_cert_md5sum.side_effect = [
'd41d8cd98f00b204e9800998ecf8427e'] * 3
self.driver.jinja.build_config.side_effect = ['fake_config']
self.driver.client.get_listener_status.side_effect = [
dict(status='BLAH')]
self.driver.update(self.sl, self.sv)
self.driver.client.start_listener.assert_called_once_with(
self.amp, self.sl.id)
def test_stop(self):
# Execute driver method
self.driver.stop(self.sl, self.sv)
self.driver.client.stop_listener.assert_called_once_with(
self.amp, self.sl.id)
def test_start(self):
# Execute driver method
self.driver.start(self.sl, self.sv)
self.driver.client.start_listener.assert_called_once_with(
self.amp, self.sl.id)
def test_delete(self):
# Execute driver method
self.driver.delete(self.sl, self.sv)
self.driver.client.delete_listener.assert_called_once_with(
self.amp, self.sl.id)
def test_get_info(self):
pass
def test_get_diagnostics(self):
pass
def test_finalize_amphora(self):
pass
def test_post_vip_plug(self):
self.driver.post_vip_plug(self.lb)
self.driver.client.plug_vip.assert_called_once_with(
self.amp, self.lb.vip.ip_address)
def test_post_network_plug(self):
self.driver.post_network_plug(self.amp)
self.driver.client.plug_network.assert_called_once_with(self.amp)
class AmphoraAPIClientTest(base.TestCase):
FAKE_UUID_1 = uuidutils.generate_uuid()
FAKE_PEM_FILENAME = "file_name"
def setUp(self):
super(AmphoraAPIClientTest, self).setUp()
self.driver = driver.AmphoraAPIClient()
self.base_url = "https://127.0.0.1:8443/0.5"
self.amp = models.Amphora(lb_network_ip='127.0.0.1', compute_id='123')
@requests_mock.mock()
def test_get_info(self, m):
info = {"hostname": "some_hostname", "version": "some_version",
"api_version": "0.5", "uuid": self.FAKE_UUID_1}
m.get("{base}/info".format(base=self.base_url),
json=info)
information = self.driver.get_info(self.amp)
self.assertEqual(info, information)
@requests_mock.mock()
def test_get_info_unauthorized(self, m):
m.get("{base}/info".format(base=self.base_url),
status_code=401)
self.assertRaises(exc.Unauthorized, self.driver.get_info, self.amp)
@requests_mock.mock()
def test_get_info_missing(self, m):
m.get("{base}/info".format(base=self.base_url),
status_code=404)
self.assertRaises(exc.NotFound, self.driver.get_info, self.amp)
@requests_mock.mock()
def test_get_info_server_error(self, m):
m.get("{base}/info".format(base=self.base_url),
status_code=500)
self.assertRaises(exc.InternalServerError, self.driver.get_info,
self.amp)
@requests_mock.mock()
def test_get_info_service_unavailable(self, m):
m.get("{base}/info".format(base=self.base_url),
status_code=503)
self.assertRaises(exc.ServiceUnavailable, self.driver.get_info,
self.amp)
@requests_mock.mock()
def test_get_details(self, m):
details = {"hostname": "some_hostname", "version": "some_version",
"api_version": "0.5", "uuid": self.FAKE_UUID_1,
"network_tx": "some_tx", "network_rx": "some_rx",
"active": True, "haproxy_count": 10}
m.get("{base}/details".format(base=self.base_url),
json=details)
amp_details = self.driver.get_details(self.amp)
self.assertEqual(details, amp_details)
@requests_mock.mock()
def test_get_details_unauthorized(self, m):
m.get("{base}/details".format(base=self.base_url),
status_code=401)
self.assertRaises(exc.Unauthorized, self.driver.get_details, self.amp)
@requests_mock.mock()
def test_get_details_missing(self, m):
m.get("{base}/details".format(base=self.base_url),
status_code=404)
self.assertRaises(exc.NotFound, self.driver.get_details, self.amp)
@requests_mock.mock()
def test_get_details_server_error(self, m):
m.get("{base}/details".format(base=self.base_url),
status_code=500)
self.assertRaises(exc.InternalServerError, self.driver.get_details,
self.amp)
@requests_mock.mock()
def test_get_details_service_unavailable(self, m):
m.get("{base}/details".format(base=self.base_url),
status_code=503)
self.assertRaises(exc.ServiceUnavailable, self.driver.get_details,
self.amp)
@requests_mock.mock()
def test_get_all_listeners(self, m):
listeners = [{"status": "ONLINE", "provisioning_status": "ACTIVE",
"type": "PASSIVE", "uuid": self.FAKE_UUID_1}]
m.get("{base}/listeners".format(base=self.base_url),
json=listeners)
all_listeners = self.driver.get_all_listeners(self.amp)
self.assertEqual(listeners, all_listeners)
@requests_mock.mock()
def test_get_all_listeners_unauthorized(self, m):
m.get("{base}/listeners".format(base=self.base_url),
status_code=401)
self.assertRaises(exc.Unauthorized, self.driver.get_all_listeners,
self.amp)
@requests_mock.mock()
def test_get_all_listeners_missing(self, m):
m.get("{base}/listeners".format(base=self.base_url),
status_code=404)
self.assertRaises(exc.NotFound, self.driver.get_all_listeners,
self.amp)
@requests_mock.mock()
def test_get_all_listeners_server_error(self, m):
m.get("{base}/listeners".format(base=self.base_url),
status_code=500)
self.assertRaises(exc.InternalServerError,
self.driver.get_all_listeners, self.amp)
@requests_mock.mock()
def test_get_all_listeners_service_unavailable(self, m):
m.get("{base}/listeners".format(base=self.base_url),
status_code=503)
self.assertRaises(exc.ServiceUnavailable,
self.driver.get_all_listeners, self.amp)
@requests_mock.mock()
def test_get_listener_status(self, m):
listener = {"status": "ONLINE", "provisioning_status": "ACTIVE",
"type": "PASSIVE", "uuid": self.FAKE_UUID_1}
m.get("{base}/listeners/{listener_id}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
json=listener)
status = self.driver.get_listener_status(self.amp, self.FAKE_UUID_1)
self.assertEqual(listener, status)
@requests_mock.mock()
def test_get_listener_status_unauthorized(self, m):
m.get("{base}/listeners/{listener_id}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=401)
self.assertRaises(exc.Unauthorized,
self.driver.get_listener_status, self.amp,
self.FAKE_UUID_1)
@requests_mock.mock()
def test_get_listener_status_missing(self, m):
m.get("{base}/listeners/{listener_id}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=404)
self.assertRaises(exc.NotFound,
self.driver.get_listener_status, self.amp,
self.FAKE_UUID_1)
@requests_mock.mock()
def test_get_listener_status_server_error(self, m):
m.get("{base}/listeners/{listener_id}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=500)
self.assertRaises(exc.InternalServerError,
self.driver.get_listener_status, self.amp,
self.FAKE_UUID_1)
@requests_mock.mock()
def test_get_listener_status_service_unavailable(self, m):
m.get("{base}/listeners/{listener_id}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=503)
self.assertRaises(exc.ServiceUnavailable,
self.driver.get_listener_status, self.amp,
self.FAKE_UUID_1)
@requests_mock.mock()
def test_start_listener(self, m):
m.put("{base}/listeners/{listener_id}/start".format(
base=self.base_url, listener_id=self.FAKE_UUID_1))
self.driver.start_listener(self.amp, self.FAKE_UUID_1)
self.assertTrue(m.called)
@requests_mock.mock()
def test_start_listener_missing(self, m):
m.put("{base}/listeners/{listener_id}/start".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=404)
self.assertRaises(exc.NotFound, self.driver.start_listener,
self.amp, self.FAKE_UUID_1)
@requests_mock.mock()
def test_start_listener_unauthorized(self, m):
m.put("{base}/listeners/{listener_id}/start".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=401)
self.assertRaises(exc.Unauthorized, self.driver.start_listener,
self.amp, self.FAKE_UUID_1)
@requests_mock.mock()
def test_start_listener_server_error(self, m):
m.put("{base}/listeners/{listener_id}/start".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=500)
self.assertRaises(exc.InternalServerError, self.driver.start_listener,
self.amp, self.FAKE_UUID_1)
@requests_mock.mock()
def test_start_listener_service_unavailable(self, m):
m.put("{base}/listeners/{listener_id}/start".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=503)
self.assertRaises(exc.ServiceUnavailable, self.driver.start_listener,
self.amp, self.FAKE_UUID_1)
@requests_mock.mock()
def test_stop_listener(self, m):
m.put("{base}/listeners/{listener_id}/stop".format(
base=self.base_url, listener_id=self.FAKE_UUID_1))
self.driver.stop_listener(self.amp, self.FAKE_UUID_1)
self.assertTrue(m.called)
@requests_mock.mock()
def test_stop_listener_missing(self, m):
m.put("{base}/listeners/{listener_id}/stop".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=404)
self.assertRaises(exc.NotFound, self.driver.stop_listener,
self.amp, self.FAKE_UUID_1)
@requests_mock.mock()
def test_stop_listener_unauthorized(self, m):
m.put("{base}/listeners/{listener_id}/stop".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=401)
self.assertRaises(exc.Unauthorized, self.driver.stop_listener,
self.amp, self.FAKE_UUID_1)
@requests_mock.mock()
def test_stop_listener_server_error(self, m):
m.put("{base}/listeners/{listener_id}/stop".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=500)
self.assertRaises(exc.InternalServerError, self.driver.stop_listener,
self.amp, self.FAKE_UUID_1)
@requests_mock.mock()
def test_stop_listener_service_unavailable(self, m):
m.put("{base}/listeners/{listener_id}/stop".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=503)
self.assertRaises(exc.ServiceUnavailable, self.driver.stop_listener,
self.amp, self.FAKE_UUID_1)
@requests_mock.mock()
def test_delete_listener(self, m):
m.delete("{base}/listeners/{listener_id}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1), json={})
self.driver.delete_listener(self.amp, self.FAKE_UUID_1)
self.assertTrue(m.called)
@requests_mock.mock()
def test_delete_listener_missing(self, m):
m.delete("{base}/listeners/{listener_id}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=404)
self.assertRaises(exc.NotFound, self.driver.delete_listener,
self.amp, self.FAKE_UUID_1)
@requests_mock.mock()
def test_delete_listener_unauthorized(self, m):
m.delete("{base}/listeners/{listener_id}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=401)
self.assertRaises(exc.Unauthorized, self.driver.delete_listener,
self.amp, self.FAKE_UUID_1)
@requests_mock.mock()
def test_delete_listener_server_error(self, m):
m.delete("{base}/listeners/{listener_id}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=500)
self.assertRaises(exc.InternalServerError, self.driver.delete_listener,
self.amp, self.FAKE_UUID_1)
@requests_mock.mock()
def test_delete_listener_service_unavailable(self, m):
m.delete("{base}/listeners/{listener_id}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1),
status_code=503)
self.assertRaises(exc.ServiceUnavailable, self.driver.delete_listener,
self.amp, self.FAKE_UUID_1)
@requests_mock.mock()
def test_upload_cert_pem(self, m):
m.put("{base}/listeners/{listener_id}/certificates/{filename}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1,
filename=self.FAKE_PEM_FILENAME))
self.driver.upload_cert_pem(self.amp, self.FAKE_UUID_1,
self.FAKE_PEM_FILENAME,
"some_file")
self.assertTrue(m.called)
@requests_mock.mock()
def test_upload_invalid_cert_pem(self, m):
m.put("{base}/listeners/{listener_id}/certificates/{filename}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1,
filename=self.FAKE_PEM_FILENAME), status_code=403)
self.assertRaises(exc.InvalidRequest, self.driver.upload_cert_pem,
self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME,
"some_file")
@requests_mock.mock()
def test_upload_cert_pem_unauthorized(self, m):
m.put("{base}/listeners/{listener_id}/certificates/{filename}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1,
filename=self.FAKE_PEM_FILENAME), status_code=401)
self.assertRaises(exc.Unauthorized, self.driver.upload_cert_pem,
self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME,
"some_file")
@requests_mock.mock()
def test_upload_cert_pem_server_error(self, m):
m.put("{base}/listeners/{listener_id}/certificates/{filename}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1,
filename=self.FAKE_PEM_FILENAME), status_code=500)
self.assertRaises(exc.InternalServerError, self.driver.upload_cert_pem,
self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME,
"some_file")
@requests_mock.mock()
def test_upload_cert_pem_service_unavailable(self, m):
m.put("{base}/listeners/{listener_id}/certificates/{filename}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1,
filename=self.FAKE_PEM_FILENAME), status_code=503)
self.assertRaises(exc.ServiceUnavailable, self.driver.upload_cert_pem,
self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME,
"some_file")
@requests_mock.mock()
def test_get_cert_5sum(self, m):
md5sum = {"md5sum": "some_real_sum"}
m.get("{base}/listeners/{listener_id}/certificates/{filename}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1,
filename=self.FAKE_PEM_FILENAME), json=md5sum)
sum_test = self.driver.get_cert_md5sum(self.amp, self.FAKE_UUID_1,
self.FAKE_PEM_FILENAME)
self.assertIsNotNone(sum_test)
@requests_mock.mock()
def test_get_cert_5sum_missing(self, m):
m.get("{base}/listeners/{listener_id}/certificates/{filename}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1,
filename=self.FAKE_PEM_FILENAME), status_code=404)
self.assertRaises(exc.NotFound, self.driver.get_cert_md5sum,
self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME)
@requests_mock.mock()
def test_get_cert_5sum_unauthorized(self, m):
m.get("{base}/listeners/{listener_id}/certificates/{filename}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1,
filename=self.FAKE_PEM_FILENAME), status_code=401)
self.assertRaises(exc.Unauthorized, self.driver.get_cert_md5sum,
self.amp, self.FAKE_UUID_1, self.FAKE_PEM_FILENAME)
@requests_mock.mock()
def test_get_cert_5sum_server_error(self, m):
m.get("{base}/listeners/{listener_id}/certificates/{filename}".format(
base=self.base_url, listener_id=self.FAKE_UUID_1,
filename=self.FAKE_PEM_FILENAME), status_code=500)
self.assertRaises(exc.InternalServerError, self.driver.get_cert_md5sum,
self.amp, self.FAKE_UUID_1