Add get method support to the driver-agent

This patch adds support for the octavia-lib to get objects by ID.

Change-Id: I98b399891488e5972ea4d332c06b55b34f20fb11
Story: 2005870
Task: 33680
Co-Authored-By: Adam Harwell <flux.adam@gmail.com>
This commit is contained in:
Michael Johnson 2019-06-12 15:32:03 -07:00
parent fa70e9759d
commit 09efc2a423
39 changed files with 849 additions and 66 deletions

View File

@ -1931,11 +1931,13 @@ resources. See the `Octavia API Reference <https://docs.openstack.org/api-ref/lo
API Exception Model
-------------------
The driver support API will include two Exceptions, one for each of the
The driver support API will include exceptions:
two API groups:
* UpdateStatusError
* UpdateStatisticsError
* DriverAgentNotFound
* DriverAgentTimeout
Each exception class will include a message field that describes the error and
references to the failed record if available.
@ -1955,7 +1957,8 @@ references to the failed record if available.
self.status_object_id = kwargs.pop('status_object_id', None)
self.status_record = kwargs.pop('status_record', None)
super(UpdateStatusError, self).__init__(*args, **kwargs)
super(UpdateStatusError, self).__init__(self.fault_string,
*args, **kwargs)
class UpdateStatisticsError(Exception):
fault_string = _("The statistics update had an unknown error.")
@ -1970,7 +1973,24 @@ references to the failed record if available.
self.stats_object_id = kwargs.pop('stats_object_id', None)
self.stats_record = kwargs.pop('stats_record', None)
super(UpdateStatisticsError, self).__init__(*args, **kwargs)
super(UpdateStatisticsError, self).__init__(self.fault_string,
*args, **kwargs)
class DriverAgentNotFound(Exception):
fault_string = _("The driver-agent process was not found or not ready.")
def __init__(self, *args, **kwargs):
self.fault_string = kwargs.pop('fault_string', self.fault_string)
super(DriverAgentNotFound, self).__init__(self.fault_string,
*args, **kwargs)
class DriverAgentTimeout(Exception):
fault_string = _("The driver-agent timeout.")
def __init__(self, *args, **kwargs):
self.fault_string = kwargs.pop('fault_string', self.fault_string)
super(DriverAgentTimeout, self).__init__(self.fault_string,
*args, **kwargs)
Documenting the Driver
======================

View File

@ -504,6 +504,7 @@
[driver_agent]
# status_socket_path = /var/run/octavia/status.sock
# stats_socket_path = /var/run/octavia/stats.sock
# get_socket_path = /var/run/octavia/get.sock
# Maximum time to wait for a status message before checking for shutdown
# status_request_timeout = 5

View File

@ -0,0 +1,82 @@
# Copyright 2019 Red Hat, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from octavia_lib.common import constants as lib_consts
from octavia.api.drivers import utils as driver_utils
from octavia.common import constants
from octavia.db import api as db_api
from octavia.db import repositories
def process_get(get_data):
session = db_api.get_session()
if get_data[constants.OBJECT] == lib_consts.LOADBALANCERS:
lb_repo = repositories.LoadBalancerRepository()
db_lb = lb_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
if db_lb:
provider_lb = (
driver_utils.db_loadbalancer_to_provider_loadbalancer(db_lb))
return provider_lb.to_dict(recurse=True, render_unsets=True)
elif get_data[constants.OBJECT] == lib_consts.LISTENERS:
listener_repo = repositories.ListenerRepository()
db_listener = listener_repo.get(
session, id=get_data[lib_consts.ID], show_deleted=False)
if db_listener:
provider_listener = (
driver_utils.db_listener_to_provider_listener(db_listener))
return provider_listener.to_dict(recurse=True, render_unsets=True)
elif get_data[constants.OBJECT] == lib_consts.POOLS:
pool_repo = repositories.PoolRepository()
db_pool = pool_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
if db_pool:
provider_pool = (
driver_utils.db_pool_to_provider_pool(db_pool))
return provider_pool.to_dict(recurse=True, render_unsets=True)
elif get_data[constants.OBJECT] == lib_consts.MEMBERS:
member_repo = repositories.MemberRepository()
db_member = member_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
if db_member:
provider_member = (
driver_utils.db_member_to_provider_member(db_member))
return provider_member.to_dict(recurse=True, render_unsets=True)
elif get_data[constants.OBJECT] == lib_consts.HEALTHMONITORS:
hm_repo = repositories.HealthMonitorRepository()
db_hm = hm_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
if db_hm:
provider_hm = (
driver_utils.db_HM_to_provider_HM(db_hm))
return provider_hm.to_dict(recurse=True, render_unsets=True)
elif get_data[constants.OBJECT] == lib_consts.L7POLICIES:
l7policy_repo = repositories.L7PolicyRepository()
db_l7policy = l7policy_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
if db_l7policy:
provider_l7policy = (
driver_utils.db_l7policy_to_provider_l7policy(db_l7policy))
return provider_l7policy.to_dict(recurse=True, render_unsets=True)
elif get_data[constants.OBJECT] == lib_consts.L7RULES:
l7rule_repo = repositories.L7RuleRepository()
db_l7rule = l7rule_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
if db_l7rule:
provider_l7rule = (
driver_utils.db_l7rule_to_provider_l7rule(db_l7rule))
return provider_l7rule.to_dict(recurse=True, render_unsets=True)
return {}

View File

@ -1,4 +1,5 @@
# Copyright 2018 Rackspace, US Inc.
# Copyright 2019 Red Hat, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@ -23,6 +24,7 @@ from oslo_config import cfg
from oslo_log import log as logging
from oslo_serialization import jsonutils
from octavia.api.drivers.driver_agent import driver_get
from octavia.api.drivers.driver_agent import driver_updater
@ -80,6 +82,22 @@ class StatsRequestHandler(socketserver.BaseRequestHandler):
self.request.sendall(json_data)
class GetRequestHandler(socketserver.BaseRequestHandler):
def handle(self):
# Get the data request
get_data = _recv(self.request)
# Process the get
response = driver_get.process_get(get_data)
# Send the response
json_data = jsonutils.dump_as_bytes(response)
len_str = '{}\n'.format(len(json_data)).encode('utf-8')
self.request.send(len_str)
self.request.sendall(json_data)
class ForkingUDSServer(socketserver.ForkingMixIn,
socketserver.UnixStreamServer):
pass
@ -142,3 +160,26 @@ def stats_listener(exit_event):
LOG.info('Driver statistics listener shutdown finished.')
server.server_close()
_cleanup_socket_file(CONF.driver_agent.stats_socket_path)
def get_listener(exit_event):
signal.signal(signal.SIGINT, signal.SIG_IGN)
signal.signal(signal.SIGHUP, _mutate_config)
_cleanup_socket_file(CONF.driver_agent.get_socket_path)
server = ForkingUDSServer(CONF.driver_agent.get_socket_path,
GetRequestHandler)
server.timeout = CONF.driver_agent.get_request_timeout
server.max_children = CONF.driver_agent.get_max_processes
while not exit_event.is_set():
server.handle_request()
LOG.info('Waiting for driver get listener to shutdown...')
# Can't shut ourselves down as we would deadlock, spawn a thread
threading.Thread(target=server.shutdown).start()
LOG.info('Driver get listener shutdown finished.')
server.server_close()
_cleanup_socket_file(CONF.driver_agent.get_socket_path)

View File

@ -250,11 +250,11 @@ def listener_dict_to_provider_dict(listener_dict):
listener_obj)
if 'tls_cert' in cert_dict and cert_dict['tls_cert']:
new_listener_dict['default_tls_container_data'] = (
cert_dict['tls_cert'].to_dict())
cert_dict['tls_cert'].to_dict(recurse=True))
if 'sni_certs' in cert_dict and cert_dict['sni_certs']:
sni_data_list = []
for sni in cert_dict['sni_certs']:
sni_data_list.append(sni.to_dict())
sni_data_list.append(sni.to_dict(recurse=True))
new_listener_dict['sni_container_data'] = sni_data_list
if listener_obj.client_ca_tls_certificate_id:
@ -344,7 +344,7 @@ def pool_dict_to_provider_dict(pool_dict):
pool_obj)
if 'tls_cert' in cert_dict and cert_dict['tls_cert']:
new_pool_dict['tls_container_data'] = (
cert_dict['tls_cert'].to_dict())
cert_dict['tls_cert'].to_dict(recurse=True))
if pool_obj.ca_tls_certificate_id:
cert = _get_secret_data(cert_manager, pool_obj.project_id,

View File

@ -18,10 +18,12 @@ import uuid
from oslo_config import cfg
from oslo_log import log as logging
import six
from octavia.certificates.common import local as local_common
from octavia.certificates.manager import cert_mgr
from octavia.common import exceptions
from octavia.common.tls_utils import cert_parser
CONF = cfg.CONF
LOG = logging.getLogger(__name__)
@ -49,6 +51,10 @@ class LocalCertManager(cert_mgr.CertManager):
"""
cert_ref = str(uuid.uuid4())
filename_base = os.path.join(CONF.certificates.storage_path, cert_ref)
if type(certificate) == six.binary_type:
certificate = certificate.decode('utf-8')
if type(private_key) == six.binary_type:
private_key = private_key.decode('utf-8')
LOG.info("Storing certificate data on the local filesystem.")
try:
@ -66,12 +72,17 @@ class LocalCertManager(cert_mgr.CertManager):
if intermediates:
filename_intermediates = "{0}.int".format(filename_base)
if type(intermediates) == six.binary_type:
intermediates = intermediates.decode('utf-8')
with os.fdopen(os.open(
filename_intermediates, flags, mode), 'w') as int_file:
int_file.write(intermediates)
if private_key_passphrase:
filename_pkp = "{0}.pass".format(filename_base)
if type(private_key_passphrase) == six.binary_type:
private_key_passphrase = private_key_passphrase.decode(
'utf-8')
with os.fdopen(os.open(
filename_pkp, flags, mode), 'w') as pass_file:
pass_file.write(private_key_passphrase)
@ -122,6 +133,8 @@ class LocalCertManager(cert_mgr.CertManager):
try:
with os.fdopen(os.open(filename_intermediates, flags)) as int_file:
cert_data['intermediates'] = int_file.read()
cert_data['intermediates'] = list(
cert_parser.get_intermediates_pems(cert_data['intermediates']))
except IOError:
pass
@ -184,7 +197,7 @@ class LocalCertManager(cert_mgr.CertManager):
filename_base = os.path.join(CONF.certificates.storage_path,
secret_ref)
filename_secret = "{0}.pem".format(filename_base)
filename_secret = "{0}.crt".format(filename_base)
secret_data = None

View File

@ -66,16 +66,25 @@ def main():
LOG.info("Driver agent statistics listener process starts:")
stats_listener_proc.start()
get_listener_proc = multiprocessing.Process(
name='get_listener', target=driver_listener.get_listener,
args=(exit_event,))
processes.append(get_listener_proc)
LOG.info("Driver agent get listener process starts:")
get_listener_proc.start()
def process_cleanup(*args, **kwargs):
LOG.info("Driver agent exiting due to signal")
exit_event.set()
status_listener_proc.join()
stats_listener_proc.join()
get_listener_proc.join()
signal.signal(signal.SIGTERM, process_cleanup)
signal.signal(signal.SIGHUP, partial(
_handle_mutate_config, status_listener_proc.pid,
stats_listener_proc.pid))
stats_listener_proc.pid, get_listener_proc.pid))
try:
for process in processes:

View File

@ -637,6 +637,9 @@ driver_agent_opts = [
default='/var/run/octavia/stats.sock',
help=_('Path to the driver statistics unix domain socket '
'file.')),
cfg.StrOpt('get_socket_path',
default='/var/run/octavia/get.sock',
help=_('Path to the driver get unix domain socket file.')),
cfg.IntOpt('status_request_timeout',
default=5,
help=_('Time, in seconds, to wait for a status update '
@ -653,6 +656,13 @@ driver_agent_opts = [
default=50,
help=_('Maximum number of concurrent processes to use '
'servicing statistics updates.')),
cfg.IntOpt('get_request_timeout',
default=5,
help=_('Time, in seconds, to wait for a get request.')),
cfg.IntOpt('get_max_processes',
default=50,
help=_('Maximum number of concurrent processes to use '
'servicing get requests.')),
cfg.FloatOpt('max_process_warning_percent',
default=0.75, min=0.01, max=0.99,
help=_('Percentage of max_processes (both status and stats) '

View File

@ -351,6 +351,9 @@ TOTAL_CONNECTIONS = 'total_connections'
NAME = 'name'
PROVIDER_NAME = 'provider_name'
# Database fields
SNI_CONTAINERS = 'sni_containers'
CERT_ROTATE_AMPHORA_FLOW = 'octavia-cert-rotate-amphora-flow'
CREATE_AMPHORA_FLOW = 'octavia-create-amphora-flow'
CREATE_AMPHORA_FOR_LB_FLOW = 'octavia-create-amp-for-lb-flow'

View File

@ -47,9 +47,15 @@ class BaseDataModel(object):
calling_classes + [type(self)]),
recurse=recurse))
else:
# TODO(rm_work): Is the idea that if this list
# contains ANY BaseDataModel, that all of them
# are data models, and we may as well quit?
# Or, were we supposed to append a `None` for
# each one? I assume the former?
ret[attr] = None
break
else:
ret[attr] = item
ret[attr].append(item)
elif isinstance(getattr(self, attr), BaseDataModel):
if type(self) not in calling_classes:
ret[attr] = value.to_dict(
@ -62,8 +68,10 @@ class BaseDataModel(object):
else:
ret[attr] = value
else:
if isinstance(getattr(self, attr), (BaseDataModel, list)):
if isinstance(getattr(self, attr), BaseDataModel):
ret[attr] = None
elif isinstance(getattr(self, attr), list):
ret[attr] = []
else:
ret[attr] = value

View File

@ -106,6 +106,12 @@ def get_intermediates_pems(intermediates=None):
X509 pem block surrounded by BEGIN CERTIFICATE,
END CERTIFICATE block tags
"""
if isinstance(intermediates, six.string_types):
try:
intermediates = intermediates.encode("utf-8")
except UnicodeDecodeError:
LOG.debug("Couldn't encode intermediates string, it was probably "
"in binary DER format.")
if X509_BEG in intermediates:
for x509Pem in _split_x509s(intermediates):
yield _prepare_x509_cert(_get_x509_from_pem_bytes(x509Pem))
@ -249,6 +255,8 @@ def get_host_names(certificate):
certificate, and 'dns_names' is a list of dNSNames
(possibly empty) from the SubjectAltNames of the certificate.
"""
if isinstance(certificate, six.string_types):
certificate = certificate.encode('utf-8')
try:
cert = x509.load_pem_x509_certificate(certificate,
backends.default_backend())
@ -362,19 +370,34 @@ def load_certificates_data(cert_mngr, obj, context=None):
def _map_cert_tls_container(cert):
certificate = cert.get_certificate()
private_key = cert.get_private_key()
private_key_passphrase = cert.get_private_key_passphrase()
intermediates = cert.get_intermediates()
if isinstance(certificate, six.string_types):
certificate = certificate.encode('utf-8')
if isinstance(private_key, six.string_types):
private_key = private_key.encode('utf-8')
if isinstance(private_key_passphrase, six.string_types):
private_key_passphrase = private_key_passphrase.encode('utf-8')
if intermediates:
intermediates = [
(imd.encode('utf-8') if isinstance(imd, six.string_types) else imd)
for imd in intermediates
]
else:
intermediates = []
return data_models.TLSContainer(
# TODO(rm_work): applying nosec here because this is not intended to be
# secure, it's just a way to get a consistent ID. Changing this would
# break backwards compatibility with existing loadbalancers.
id=hashlib.sha1(cert.get_certificate()).hexdigest(), # nosec
primary_cn=get_primary_cn(cert),
private_key=prepare_private_key(
cert.get_private_key(),
cert.get_private_key_passphrase()),
certificate=cert.get_certificate(),
intermediates=cert.get_intermediates())
id=hashlib.sha1(certificate).hexdigest(), # nosec
primary_cn=get_primary_cn(certificate),
private_key=prepare_private_key(private_key, private_key_passphrase),
certificate=certificate,
intermediates=intermediates)
def get_primary_cn(tls_cert):
"""Returns primary CN for Certificate."""
return get_host_names(tls_cert.get_certificate())['cn']
return get_host_names(tls_cert)['cn']

View File

@ -2607,7 +2607,8 @@ class TestServerTestCase(base.TestCase):
self.assertEqual(500, rv.status_code)
def test_version_discovery(self):
self.test_client = server.Server().app.test_client()
with mock.patch('distro.id', return_value='ubuntu'):
self.test_client = server.Server().app.test_client()
expected_dict = {'api_version': api_server.VERSION}
rv = self.test_client.get('/')
self.assertEqual(200, rv.status_code)

View File

@ -0,0 +1,11 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,11 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,334 @@
# Copyright 2019 Red Hat, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import copy
import multiprocessing
from octavia_lib.api.drivers import driver_lib as octavia_driver_lib
from octavia_lib.common import constants as lib_consts
from oslo_config import cfg
from oslo_config import fixture as oslo_fixture
from oslo_utils import uuidutils
from stevedore import driver as stevedore_driver
from octavia.api.drivers.driver_agent import driver_listener
from octavia.common import config
from octavia.common import constants
from octavia.db import repositories
from octavia.tests.common import sample_certs
from octavia.tests.common import sample_data_models
from octavia.tests.functional.db import base
CONF = cfg.CONF
class DriverAgentTest(base.OctaviaDBTestBase):
def _process_cleanup(self):
self.exit_event.set()
self.status_listener_proc.join(5)
self.stats_listener_proc.join(5)
self.get_listener_proc.join(5)
def setUp(self):
status_socket_file = '/tmp/octavia-{}.status.sock'.format(
uuidutils.generate_uuid())
stats_socket_file = '/tmp/octavia-{}.stats.sock'.format(
uuidutils.generate_uuid())
get_socket_file = '/tmp/octavia-{}.get.sock'.format(
uuidutils.generate_uuid())
sqlite_db_file = '/tmp/octavia-{}.sqlite.db'.format(
uuidutils.generate_uuid())
sqlite_db_connection = 'sqlite:///{}'.format(sqlite_db_file)
# Note that because the driver agent is a multi-process
# agent we must use a sqlite file rather than an
# in-memory instance.
super(DriverAgentTest, self).setUp(
connection_string=sqlite_db_connection)
conf = self.useFixture(oslo_fixture.Config(config.cfg.CONF))
conf.config(group="driver_agent",
status_socket_path=status_socket_file)
conf.config(group="driver_agent",
stats_socket_path=stats_socket_file)
conf.config(group="driver_agent", status_request_timeout=1)
conf.config(group="driver_agent", get_socket_path=get_socket_file)
conf.config(group="certificates", cert_manager='local_cert_manager')
conf.config(group="certificates", storage_path='/tmp')
# Set up the certificate
cert_manager = stevedore_driver.DriverManager(
namespace='octavia.cert_manager',
name=CONF.certificates.cert_manager,
invoke_on_load=True,
).driver
self.cert_ref = cert_manager.store_cert(
None,
sample_certs.X509_CERT,
sample_certs.X509_CERT_KEY_ENCRYPTED,
sample_certs.X509_IMDS,
private_key_passphrase=sample_certs.X509_CERT_KEY_PASSPHRASE)
self.addCleanup(cert_manager.delete_cert, None, self.cert_ref)
self.exit_event = multiprocessing.Event()
self.status_listener_proc = multiprocessing.Process(
name='status_listener', target=driver_listener.status_listener,
args=(self.exit_event,))
# TODO(johnsom) Remove once https://bugs.python.org/issue6721
# is resolved.
self.status_listener_proc.daemon = True
self.status_listener_proc.start()
self.stats_listener_proc = multiprocessing.Process(
name='stats_listener', target=driver_listener.stats_listener,
args=(self.exit_event,))
# TODO(johnsom) Remove once https://bugs.python.org/issue6721
# is resolved.
self.stats_listener_proc.daemon = True
self.stats_listener_proc.start()
self.get_listener_proc = multiprocessing.Process(
name='get_listener', target=driver_listener.get_listener,
args=(self.exit_event,))
# TODO(johnsom) Remove once https://bugs.python.org/issue6721
# is resolved.
self.get_listener_proc.daemon = True
self.get_listener_proc.start()
self.addCleanup(self._process_cleanup)
self.driver_lib = octavia_driver_lib.DriverLibrary(
status_socket=status_socket_file,
stats_socket=stats_socket_file,
get_socket=get_socket_file)
self.sample_data = sample_data_models.SampleDriverDataModels()
self.repos = repositories.Repositories()
# Create the full load balancer in the database
self.tls_container_dict = {
'certificate': sample_certs.X509_CERT.decode('utf-8'),
'id': sample_certs.X509_CERT_SHA1,
'intermediates': [
i.decode('utf-8') for i in sample_certs.X509_IMDS_LIST],
'passphrase': None,
'primary_cn': sample_certs.X509_CERT_CN,
'private_key': sample_certs.X509_CERT_KEY.decode('utf-8')}
# ### Create load balancer
self.repos.flavor_profile.create(
self.session, id=self.sample_data.flavor_profile_id,
provider_name='amphora', flavor_data='{"something": "else"}')
self.repos.flavor.create(
self.session, id=self.sample_data.flavor_id,
enabled=True, flavor_profile_id=self.sample_data.flavor_profile_id)
self.repos.create_load_balancer_and_vip(
self.session, self.sample_data.test_loadbalancer1_dict,
self.sample_data.test_vip_dict)
# ### Create Pool
pool_dict = copy.deepcopy(self.sample_data.test_pool1_dict)
pool_dict['load_balancer_id'] = self.sample_data.lb_id
# Use a live certificate
pool_dict['tls_certificate_id'] = self.cert_ref
pool_dict['ca_tls_certificate_id'] = self.cert_ref
pool_dict['crl_container_id'] = self.cert_ref
# Remove items that are linked in the DB
del pool_dict[lib_consts.MEMBERS]
del pool_dict[constants.HEALTH_MONITOR]
del pool_dict['session_persistence']
del pool_dict[lib_consts.LISTENERS]
del pool_dict[lib_consts.L7POLICIES]
self.repos.pool.create(self.session, **pool_dict)
self.repos.session_persistence.create(
self.session, pool_id=self.sample_data.pool1_id,
type=lib_consts.SESSION_PERSISTENCE_SOURCE_IP)
self.provider_pool_dict = copy.deepcopy(
self.sample_data.provider_pool1_dict)
self.provider_pool_dict[
constants.LISTENER_ID] = self.sample_data.listener1_id
# Fix for render_unsets = True
self.provider_pool_dict['session_persistence']['cookie_name'] = None
self.provider_pool_dict['session_persistence'][
'persistence_granularity'] = None
self.provider_pool_dict['session_persistence'][
'persistence_timeout'] = None
# Use a live certificate
self.provider_pool_dict['tls_container_data'] = self.tls_container_dict
self.provider_pool_dict['tls_container_ref'] = self.cert_ref
self.provider_pool_dict[
'ca_tls_container_data'] = sample_certs.X509_CERT.decode('utf-8')
self.provider_pool_dict['ca_tls_container_ref'] = self.cert_ref
self.provider_pool_dict[
'crl_container_data'] = sample_certs.X509_CERT.decode('utf-8')
self.provider_pool_dict['crl_container_ref'] = self.cert_ref
# ### Create Member
member_dict = copy.deepcopy(self.sample_data.test_member1_dict)
self.repos.member.create(self.session, **member_dict)
self.provider_pool_dict[lib_consts.MEMBERS] = [
self.sample_data.provider_member1_dict]
# ### Create Health Monitor
hm_dict = copy.deepcopy(self.sample_data.test_hm1_dict)
self.repos.health_monitor.create(self.session, **hm_dict)
self.provider_pool_dict[
'healthmonitor'] = self.sample_data.provider_hm1_dict
# ### Create Listener
listener_dict = copy.deepcopy(self.sample_data.test_listener1_dict)
listener_dict['default_pool_id'] = self.sample_data.pool1_id
# Remove items that are linked in the DB
del listener_dict[lib_consts.L7POLICIES]
del listener_dict['default_pool']
del listener_dict[constants.SNI_CONTAINERS]
# Use a live certificate
listener_dict['tls_certificate_id'] = self.cert_ref
listener_dict['client_ca_tls_certificate_id'] = self.cert_ref
listener_dict['client_crl_container_id'] = self.cert_ref
self.repos.listener.create(self.session,
**listener_dict)
self.repos.sni.create(self.session,
listener_id=self.sample_data.listener1_id,
tls_container_id=self.cert_ref, position=1)
# Add our live certs in that differ from the fake certs in sample_data
self.provider_listener_dict = copy.deepcopy(
self.sample_data.provider_listener1_dict)
self.provider_listener_dict['allowed_cidrs'] = None
self.provider_listener_dict[
'default_tls_container_ref'] = self.cert_ref
self.provider_listener_dict[
'default_tls_container_data'] = self.tls_container_dict
self.provider_listener_dict[
'client_ca_tls_container_ref'] = self.cert_ref
self.provider_listener_dict['client_ca_tls_container_data'] = (
sample_certs.X509_CERT.decode('utf-8'))
self.provider_listener_dict['client_crl_container_ref'] = self.cert_ref
self.provider_listener_dict['client_crl_container_data'] = (
sample_certs.X509_CERT.decode('utf-8'))
self.provider_listener_dict[
'sni_container_data'] = [self.tls_container_dict]
self.provider_listener_dict['sni_container_refs'] = [self.cert_ref]
self.provider_listener_dict['default_pool'] = self.provider_pool_dict
self.provider_listener_dict[
'default_pool_id'] = self.sample_data.pool1_id
self.provider_listener_dict[lib_consts.L7POLICIES] = [
self.sample_data.provider_l7policy1_dict]
# ### Create L7 Policy
l7policy_dict = copy.deepcopy(self.sample_data.test_l7policy1_dict)
del l7policy_dict['l7rules']
self.repos.l7policy.create(self.session, **l7policy_dict)
# ### Create L7 Rules
l7rule_dict = copy.deepcopy(self.sample_data.test_l7rule1_dict)
self.repos.l7rule.create(self.session, **l7rule_dict)
l7rule2_dict = copy.deepcopy(self.sample_data.test_l7rule2_dict)
self.repos.l7rule.create(self.session, **l7rule2_dict)
self.provider_lb_dict = copy.deepcopy(
self.sample_data.provider_loadbalancer_tree_dict)
self.provider_lb_dict[lib_consts.POOLS] = [self.provider_pool_dict]
self.provider_lb_dict[
lib_consts.LISTENERS] = [self.provider_listener_dict]
def test_get_loadbalancer(self):
result = self.driver_lib.get_loadbalancer(self.sample_data.lb_id)
self.assertEqual(self.provider_lb_dict,
result.to_dict(render_unsets=True, recurse=True))
# Test non-existent load balancer
result = self.driver_lib.get_loadbalancer('bogus')
self.assertIsNone(result)
def test_get_listener(self):
result = self.driver_lib.get_listener(self.sample_data.listener1_id)
# We need to recurse here to pick up the SNI data
self.assertEqual(self.provider_listener_dict,
result.to_dict(render_unsets=True, recurse=True))
# Test non-existent listener
result = self.driver_lib.get_listener('bogus')
self.assertIsNone(result)
def test_get_pool(self):
result = self.driver_lib.get_pool(self.sample_data.pool1_id)
self.assertEqual(self.provider_pool_dict,
result.to_dict(render_unsets=True, recurse=True))
# Test non-existent pool
result = self.driver_lib.get_pool('bogus')
self.assertIsNone(result)
def test_get_member(self):
result = self.driver_lib.get_member(self.sample_data.member1_id)
self.assertEqual(self.sample_data.provider_member1_dict,
result.to_dict(render_unsets=True))
# Test non-existent member
result = self.driver_lib.get_member('bogus')
self.assertIsNone(result)
def test_get_healthmonitor(self):
result = self.driver_lib.get_healthmonitor(self.sample_data.hm1_id)
self.assertEqual(self.sample_data.provider_hm1_dict,
result.to_dict(render_unsets=True))
# Test non-existent health monitor
result = self.driver_lib.get_healthmonitor('bogus')
self.assertIsNone(result)
def test_get_l7policy(self):
result = self.driver_lib.get_l7policy(self.sample_data.l7policy1_id)
self.assertEqual(self.sample_data.provider_l7policy1_dict,
result.to_dict(render_unsets=True, recurse=True))
# Test non-existent L7 policy
result = self.driver_lib.get_l7policy('bogus')
self.assertIsNone(result)
def test_get_l7rule(self):
result = self.driver_lib.get_l7rule(self.sample_data.l7rule1_id)
self.assertEqual(self.sample_data.provider_l7rule1_dict,
result.to_dict(render_unsets=True))
# Test non-existent L7 rule
result = self.driver_lib.get_l7rule('bogus')
self.assertIsNone(result)

View File

@ -25,8 +25,8 @@ from octavia.common import constants
import octavia.common.context
from octavia.common import data_models
from octavia.common import exceptions
from octavia.tests.common import sample_certs
from octavia.tests.functional.api.v2 import base
from octavia.tests.unit.common.sample_configs import sample_certs
class TestListener(base.BaseAPITest):

View File

@ -22,8 +22,8 @@ from octavia.common import constants
import octavia.common.context
from octavia.common import data_models
from octavia.common import exceptions
from octavia.tests.common import sample_certs
from octavia.tests.functional.api.v2 import base
from octavia.tests.unit.common.sample_configs import sample_certs
class TestPool(base.BaseAPITest):

View File

@ -12,7 +12,11 @@
# License for the specific language governing permissions and limitations
# under the License.
import os
from oslo_config import cfg
from oslo_config import fixture as oslo_fixture
from oslo_db.sqlalchemy import session as db_session
from oslo_db.sqlalchemy import test_base
from octavia.common import config
@ -24,29 +28,40 @@ from octavia.db import models
class OctaviaDBTestBase(test_base.DbTestCase):
def setUp(self):
def setUp(self, connection_string='sqlite://'):
super(OctaviaDBTestBase, self).setUp()
# NOTE(blogan): doing this for now because using the engine and
# session set up in the fixture for test_base.DbTestCase does not work
# with the API functional tests. Need to investigate more if this
# becomes a problem
conf = self.useFixture(oslo_fixture.Config(config.cfg.CONF))
conf.config(group="database", connection='sqlite://')
conf.config(group="database", connection=connection_string)
# We need to get our own Facade so that the file backed sqlite tests
# don't use the _FACADE singleton. Some tests will use in-memory
# sqlite, some will use a file backed sqlite.
if 'sqlite:///' in connection_string:
facade = db_session.EngineFacade.from_config(cfg.CONF,
sqlite_fk=True)
engine = facade.get_engine()
self.session = facade.get_session(expire_on_commit=True,
autocommit=True)
else:
engine = db_api.get_engine()
self.session = db_api.get_session()
# needed for closure
engine = db_api.get_engine()
session = db_api.get_session()
base_models.BASE.metadata.create_all(engine)
self._seed_lookup_tables(session)
self._seed_lookup_tables(self.session)
def clear_tables():
"""Unregister all data models."""
base_models.BASE.metadata.drop_all(engine)
# If we created a file, clean it up too
if 'sqlite:///' in connection_string:
os.remove(connection_string.replace('sqlite:///', ''))
self.addCleanup(clear_tables)
self.session = session
def _seed_lookup_tables(self, session):
self._seed_lookup_table(
session, constants.SUPPORTED_PROVISIONING_STATUSES,

View File

@ -29,8 +29,8 @@ from octavia.amphorae.drivers.haproxy import rest_api_driver as driver
from octavia.common import constants
from octavia.db import models
from octavia.network import data_models as network_models
from octavia.tests.common import sample_certs
from octavia.tests.unit import base
from octavia.tests.unit.common.sample_configs import sample_certs
from octavia.tests.unit.common.sample_configs import sample_configs_split
API_VERSION = '0.5'

View File

@ -29,8 +29,8 @@ from octavia.amphorae.drivers.haproxy import rest_api_driver as driver
from octavia.common import constants
from octavia.db import models
from octavia.network import data_models as network_models
from octavia.tests.common import sample_certs
from octavia.tests.unit import base
from octavia.tests.unit.common.sample_configs import sample_certs
from octavia.tests.unit.common.sample_configs import sample_configs_combined
API_VERSION = '1.0'

View File

@ -160,7 +160,7 @@ class TestDataModelToDict(base.TestCase):
NO_RECURSE_RESULT = {'parent': None,
'text': 'parent_text',
'child': None,
'children': None}
'children': []}
def setUp(self):
super(TestDataModelToDict, self).setUp()

View File

@ -21,7 +21,7 @@ from octavia_lib.api.drivers import exceptions
from octavia.api.drivers.amphora_driver.v1 import driver
from octavia.common import constants as consts
from octavia.network import base as network_base
from octavia.tests.unit.api.drivers import sample_data_models
from octavia.tests.common import sample_data_models
from octavia.tests.unit import base
@ -585,8 +585,7 @@ class TestAmphoraDriver(base.TestRpc):
self.assertRaises(exceptions.DriverError,
self.amp_driver.get_supported_flavor_metadata)
@mock.patch('jsonschema.validators.requests')
def test_validate_flavor(self, mock_validate):
def test_validate_flavor(self):
ref_dict = {consts.LOADBALANCER_TOPOLOGY: consts.TOPOLOGY_SINGLE}
self.amp_driver.validate_flavor(ref_dict)

View File

@ -21,7 +21,7 @@ from octavia_lib.api.drivers import exceptions
from octavia.api.drivers.amphora_driver.v2 import driver
from octavia.common import constants as consts
from octavia.network import base as network_base
from octavia.tests.unit.api.drivers import sample_data_models
from octavia.tests.common import sample_data_models
from octavia.tests.unit import base
@ -585,8 +585,7 @@ class TestAmphoraDriver(base.TestRpc):
self.assertRaises(exceptions.DriverError,
self.amp_driver.get_supported_flavor_metadata)
@mock.patch('jsonschema.validators.requests')
def test_validate_flavor(self, mock_validate):
def test_validate_flavor(self):
ref_dict = {consts.LOADBALANCER_TOPOLOGY: consts.TOPOLOGY_SINGLE}
self.amp_driver.validate_flavor(ref_dict)

View File

@ -0,0 +1,121 @@
# Copyright 2019 Red Hat, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from octavia_lib.common import constants as lib_consts
from oslo_utils import uuidutils
from octavia.api.drivers.driver_agent import driver_get
from octavia.common import constants
import octavia.tests.unit.base as base
class TestDriverGet(base.TestCase):
@mock.patch('octavia.db.api.get_session')
def _test_process_get_object(self, object_name, mock_object_repo,
mock_object_to_provider, mock_get_session):
mock_get_session.return_value = 'bogus_session'
object_repo_mock = mock.MagicMock()
mock_object_repo.return_value = object_repo_mock
db_object_mock = mock.MagicMock()
object_repo_mock.get.return_value = db_object_mock
mock_prov_object = mock.MagicMock()
mock_object_to_provider.return_value = mock_prov_object
ref_prov_dict = mock_prov_object.to_dict(recurse=True,
render_unsets=True)
object_id = uuidutils.generate_uuid()
data = {constants.OBJECT: object_name, lib_consts.ID: object_id}
# Happy path
result = driver_get.process_get(data)
mock_object_repo.assert_called_once_with()
object_repo_mock.get.assert_called_once_with(
'bogus_session', id=object_id, show_deleted=False)
mock_object_to_provider.assert_called_once_with(db_object_mock)
self.assertEqual(ref_prov_dict, result)
# No matching listener
mock_object_repo.reset_mock()
mock_object_to_provider.reset_mock()
object_repo_mock.get.return_value = None
result = driver_get.process_get(data)
mock_object_repo.assert_called_once_with()
object_repo_mock.get.assert_called_once_with(
'bogus_session', id=object_id, show_deleted=False)
mock_object_to_provider.assert_not_called()
self.assertEqual({}, result)
@mock.patch('octavia.api.drivers.utils.'
'db_loadbalancer_to_provider_loadbalancer')
@mock.patch('octavia.db.repositories.LoadBalancerRepository')
def test_process_get_loadbalancer(self, mock_lb_repo, mock_lb_to_provider):
self._test_process_get_object(
lib_consts.LOADBALANCERS, mock_lb_repo, mock_lb_to_provider)
@mock.patch('octavia.api.drivers.utils.db_listener_to_provider_listener')
@mock.patch('octavia.db.repositories.ListenerRepository')
def test_process_get_listener(self, mock_listener_repo,
mock_listener_to_provider):
self._test_process_get_object(lib_consts.LISTENERS, mock_listener_repo,
mock_listener_to_provider)
@mock.patch('octavia.api.drivers.utils.db_pool_to_provider_pool')
@mock.patch('octavia.db.repositories.PoolRepository')
def test_process_get_pool(self, mock_pool_repo, mock_pool_to_provider):
self._test_process_get_object(lib_consts.POOLS, mock_pool_repo,
mock_pool_to_provider)
@mock.patch('octavia.api.drivers.utils.db_member_to_provider_member')
@mock.patch('octavia.db.repositories.MemberRepository')
def test_process_get_member(self, mock_member_repo,
mock_member_to_provider):
self._test_process_get_object(lib_consts.MEMBERS, mock_member_repo,
mock_member_to_provider)
@mock.patch('octavia.api.drivers.utils.db_HM_to_provider_HM')
@mock.patch('octavia.db.repositories.HealthMonitorRepository')
def test_process_get_healthmonitor(self, mock_hm_repo,
mock_hm_to_provider):
self._test_process_get_object(lib_consts.HEALTHMONITORS, mock_hm_repo,
mock_hm_to_provider)
@mock.patch('octavia.api.drivers.utils.db_l7policy_to_provider_l7policy')
@mock.patch('octavia.db.repositories.L7PolicyRepository')
def test_process_get_l7policy(self, mock_l7policy_repo,
mock_l7policy_to_provider):
self._test_process_get_object(lib_consts.L7POLICIES,
mock_l7policy_repo,
mock_l7policy_to_provider)
@mock.patch('octavia.api.drivers.utils.db_l7rule_to_provider_l7rule')
@mock.patch('octavia.db.repositories.L7RuleRepository')
def test_process_get_l7rule(self, mock_l7rule_repo,
mock_l7rule_to_provider):
self._test_process_get_object(lib_consts.L7RULES, mock_l7rule_repo,
mock_l7rule_to_provider)
@mock.patch('octavia.db.api.get_session')
def test_process_get_bogus_object(self, mock_get_session):
data = {constants.OBJECT: 'bogus', lib_consts.ID: 'bad ID'}
result = driver_get.process_get(data)
self.assertEqual({}, result)

View File

@ -110,6 +110,31 @@ class TestDriverListener(base.TestCase):
mock_send.assert_called_with(b'15\n')
mock_sendall.assert_called_with(jsonutils.dump_as_bytes(TEST_OBJECT))
@mock.patch('octavia.api.drivers.driver_agent.driver_get.'
'process_get')
@mock.patch('octavia.api.drivers.driver_agent.driver_listener._recv')
def test_GetRequestHandler_handle(self, mock_recv, mock_process_get):
TEST_OBJECT = {"test": "msg"}
mock_recv.return_value = 'bogus'
mock_process_get.return_value = TEST_OBJECT
mock_request = mock.MagicMock()
mock_send = mock.MagicMock()
mock_sendall = mock.MagicMock()
mock_request.send = mock_send
mock_request.sendall = mock_sendall
GetRequestHandler = driver_listener.GetRequestHandler(
mock_request, 'bogus', 'bogus')
GetRequestHandler.handle()
mock_recv.assert_called_with(mock_request)
mock_process_get.assert_called_with('bogus')
mock_send.assert_called_with(b'15\n')
mock_sendall.assert_called_with(jsonutils.dump_as_bytes(TEST_OBJECT))
@mock.patch('octavia.api.drivers.driver_agent.driver_listener.CONF')
def test_mutate_config(self, mock_conf):
driver_listener._mutate_config()
@ -169,3 +194,24 @@ class TestDriverListener(base.TestCase):
driver_listener.stats_listener(mock_exit_event)
mock_server.handle_request.assert_called()
mock_server.server_close.assert_called_once()
@mock.patch('octavia.api.drivers.driver_agent.driver_listener.'
'_cleanup_socket_file')
@mock.patch('octavia.api.drivers.driver_agent.driver_listener.signal')
@mock.patch('octavia.api.drivers.driver_agent.driver_listener.'
'ForkingUDSServer')
def test_get_listener(self, mock_forking_server,
mock_signal, mock_cleanup):
mock_server = mock.MagicMock()
mock_active_children = mock.PropertyMock(
side_effect=['a', 'a', 'a',
'a' * CONF.driver_agent.status_max_processes, 'a',
'a' * 1000, ''])
type(mock_server).active_children = mock_active_children
mock_forking_server.return_value = mock_server
mock_exit_event = mock.MagicMock()
mock_exit_event.is_set.side_effect = [False, False, False, False, True]
driver_listener.get_listener(mock_exit_event)
mock_server.handle_request.assert_called()
mock_server.server_close.assert_called_once()

View File

@ -24,7 +24,7 @@ from octavia.api.drivers import utils
from octavia.common import constants
from octavia.common import data_models
from octavia.common import exceptions
from octavia.tests.unit.api.drivers import sample_data_models
from octavia.tests.common import sample_data_models
from octavia.tests.unit import base
@ -140,12 +140,20 @@ class TestUtils(base.TestCase):
'operating_status': constants.OFFLINE,
'flavor_id': 'flavor_id',
'provider': 'noop_driver'}
ref_listeners = copy.deepcopy(self.sample_data.provider_listeners)
# TODO(johnsom) Remove this once the listener ACLs patch merges
# https://review.opendev.org/#/c/659626/
for listener in ref_listeners:
try:
del listener.allowed_cidrs
except AttributeError:
pass
ref_prov_lb_dict = {
'vip_address': self.sample_data.ip_address,
'admin_state_up': True,
'loadbalancer_id': self.sample_data.lb_id,
'vip_subnet_id': self.sample_data.subnet_id,
'listeners': self.sample_data.provider_listeners,
'listeners': ref_listeners,
'description': '',
'project_id': self.sample_data.project_id,
'vip_port_id': self.sample_data.port_id,
@ -211,8 +219,15 @@ class TestUtils(base.TestCase):
'sni_certs': [cert2, cert3]}
provider_listeners = utils.db_listeners_to_provider_listeners(
self.sample_data.test_db_listeners)
self.assertEqual(self.sample_data.provider_listeners,
provider_listeners)
ref_listeners = copy.deepcopy(self.sample_data.provider_listeners)
# TODO(johnsom) Remove this once the listener ACLs patch merges
# https://review.opendev.org/#/c/659626/
for listener in ref_listeners:
try:
del listener.allowed_cidrs
except AttributeError:
pass
self.assertEqual(ref_listeners, provider_listeners)
@mock.patch('octavia.api.drivers.utils._get_secret_data')
@mock.patch('octavia.common.tls_utils.cert_parser.load_certificates_data')

View File

@ -18,8 +18,8 @@ import mock
import six
import octavia.certificates.common.barbican as barbican_common
import octavia.tests.common.sample_certs as sample
import octavia.tests.unit.base as base
import octavia.tests.unit.common.sample_configs.sample_certs as sample
class TestBarbicanCert(base.TestCase):

View File

@ -21,8 +21,8 @@ import octavia.certificates.common.barbican as barbican_common
import octavia.certificates.common.cert as cert
import octavia.certificates.manager.barbican as barbican_cert_mgr
from octavia.common import exceptions
import octavia.tests.common.sample_certs as sample
import octavia.tests.unit.base as base
import octavia.tests.unit.common.sample_configs.sample_certs as sample
PROJECT_ID = "12345"

View File

@ -21,8 +21,8 @@ import six
import octavia.certificates.common.barbican as barbican_common
import octavia.certificates.common.cert as cert
import octavia.certificates.manager.barbican_legacy as barbican_cert_mgr
import octavia.tests.common.sample_certs as sample
import octavia.tests.unit.base as base
import octavia.tests.unit.common.sample_configs.sample_certs as sample
PROJECT_ID = "12345"

View File

@ -23,15 +23,16 @@ from oslo_utils import uuidutils
import octavia.certificates.common.cert as cert
import octavia.certificates.manager.local as local_cert_mgr
from octavia.common import exceptions
from octavia.tests.common import sample_certs
import octavia.tests.unit.base as base
class TestLocalManager(base.TestCase):
def setUp(self):
self.certificate = "My Certificate"
self.intermediates = "My Intermediates"
self.private_key = "My Private Key"
self.certificate = sample_certs.X509_CERT.decode('utf-8')
self.intermediates = sample_certs.X509_IMDS.decode('utf-8')
self.private_key = sample_certs.X509_CERT_KEY.decode('utf-8')
self.private_key_passphrase = "My Private Key Passphrase"
conf = oslo_fixture.Config(cfg.CONF)
@ -82,6 +83,12 @@ class TestLocalManager(base.TestCase):
def _get_cert(self, cert_id):
fd_mock = mock.mock_open()
fd_mock.side_effect = [
mock.mock_open(read_data=self.certificate).return_value,
mock.mock_open(read_data=self.private_key).return_value,
mock.mock_open(read_data=self.intermediates).return_value,
mock.mock_open(read_data=self.private_key_passphrase).return_value
]
open_mock = mock.Mock()
# Attempt to retrieve the cert
with mock.patch('os.open', open_mock), mock.patch.object(
@ -120,11 +127,8 @@ class TestLocalManager(base.TestCase):
self._store_cert()
def test_get_cert(self):
# Store a cert
cert_id = self._store_cert()
# Get the cert
self._get_cert(cert_id)
self._get_cert("cert1")
def test_delete_cert(self):
# Store a cert
@ -147,7 +151,7 @@ class TestLocalManager(base.TestCase):
# Verify the correct files were opened
flags = os.O_RDONLY
open_mock.assert_called_once_with('/tmp/{0}.pem'.format(secret_id),
open_mock.assert_called_once_with('/tmp/{0}.crt'.format(secret_id),
flags)
# Test failure path

View File

@ -45,15 +45,19 @@ class TestDriverAgentCMD(base.TestCase):
mock_multiprocessing.Event.return_value = mock_exit_event
mock_status_listener_proc = mock.MagicMock()
mock_stats_listener_proc = mock.MagicMock()
mock_get_listener_proc = mock.MagicMock()
mock_multiprocessing.Process.side_effect = [mock_status_listener_proc,
mock_stats_listener_proc,
mock_get_listener_proc,
mock_status_listener_proc,
mock_stats_listener_proc]
mock_stats_listener_proc,
mock_get_listener_proc]
driver_agent.main()
mock_prep_srvc.assert_called_once()
mock_gmr.assert_called_once()
mock_status_listener_proc.start.assert_called_once()
mock_stats_listener_proc.start.assert_called_once()
mock_get_listener_proc.start.assert_called_once()
process_calls = [mock.call(
args=mock_exit_event, name='status_listener',
target=(octavia.api.drivers.driver_agent.driver_listener.
@ -61,7 +65,11 @@ class TestDriverAgentCMD(base.TestCase):
mock.call(
args=mock_exit_event, name='stats_listener',
target=(octavia.api.drivers.driver_agent.driver_listener.
stats_listener))]
stats_listener)),
mock.call(
args=mock_exit_event, name='get_listener',
target=(octavia.api.drivers.driver_agent.driver_listener.
get_listener))]
mock_multiprocessing.Process.has_calls(process_calls, any_order=True)
# Test keyboard interrupt path

View File

@ -18,7 +18,7 @@ import collections
from oslo_config import cfg
from octavia.common import constants
from octavia.tests.unit.common.sample_configs import sample_certs
from octavia.tests.common import sample_certs
CONF = cfg.CONF

View File

@ -18,7 +18,7 @@ import collections
from oslo_config import cfg
from octavia.common import constants
from octavia.tests.unit.common.sample_configs import sample_certs
from octavia.tests.common import sample_certs
CONF = cfg.CONF

View File

@ -20,8 +20,8 @@ import mock
from octavia.common import data_models
import octavia.common.exceptions as exceptions
import octavia.common.tls_utils.cert_parser as cert_parser
from octavia.tests.common import sample_certs
from octavia.tests.unit import base
from octavia.tests.unit.common.sample_configs import sample_certs
from octavia.tests.unit.common.sample_configs import sample_configs_combined
@ -110,14 +110,12 @@ class TestTLSParseUtils(base.TestCase):
def test_get_intermediates_pem_chain(self):
self.assertEqual(
sample_certs.X509_IMDS_LIST,
[c for c in
cert_parser.get_intermediates_pems(sample_certs.X509_IMDS)])
list(cert_parser.get_intermediates_pems(sample_certs.X509_IMDS)))
def test_get_intermediates_pkcs7_pem(self):
self.assertEqual(
sample_certs.X509_IMDS_LIST,
[c for c in
cert_parser.get_intermediates_pems(sample_certs.PKCS7_PEM)])
list(cert_parser.get_intermediates_pems(sample_certs.PKCS7_PEM)))
def test_get_intermediates_pkcs7_pem_bad(self):
self.assertRaises(
@ -128,8 +126,7 @@ class TestTLSParseUtils(base.TestCase):
def test_get_intermediates_pkcs7_der(self):
self.assertEqual(
sample_certs.X509_IMDS_LIST,
[c for c in
cert_parser.get_intermediates_pems(sample_certs.PKCS7_DER)])
list(cert_parser.get_intermediates_pems(sample_certs.PKCS7_DER)))
def test_get_intermediates_pkcs7_der_bad(self):
self.assertRaises(
@ -217,7 +214,7 @@ class TestTLSParseUtils(base.TestCase):
self.assertEqual(expected, cert_parser.build_pem(tls_tuple))
def test_get_primary_cn(self):
cert = mock.MagicMock()
cert = sample_certs.X509_CERT
with mock.patch.object(cert_parser, 'get_host_names') as cp:
cp.return_value = {'cn': 'fakeCN'}

View File

@ -47,7 +47,8 @@ class TestNeutronUtils(base.TestCase):
project_id=t_constants.MOCK_PROJECT_ID,
gateway_ip=t_constants.MOCK_GATEWAY_IP,
cidr=t_constants.MOCK_CIDR,
ip_version=t_constants.MOCK_IP_VERSION
ip_version=t_constants.MOCK_IP_VERSION,
host_routes=[],
)
self._compare_ignore_value_none(model_obj.to_dict(), assert_dict)
@ -64,6 +65,7 @@ class TestNeutronUtils(base.TestCase):
status=t_constants.MOCK_STATUS,
project_id=t_constants.MOCK_PROJECT_ID,
admin_state_up=t_constants.MOCK_ADMIN_STATE_UP,
fixed_ips=[],
)
self._compare_ignore_value_none(model_obj.to_dict(), assert_dict)
fixed_ips = t_constants.MOCK_NEUTRON_PORT['port']['fixed_ips']

View File

@ -0,0 +1,4 @@
---
features:
- |
Adds support for the driver agent to query for load balancer objects.

View File

@ -48,6 +48,12 @@ commands =
coverage report --fail-under=90 --skip-covered
[testenv:functional]
# This is set as py27 right now, though the name is ambiguous.
basepython = python2.7
setenv = OS_TEST_PATH={toxinidir}/octavia/tests/functional
[testenv:functional-py27]
basepython = python2.7
setenv = OS_TEST_PATH={toxinidir}/octavia/tests/functional
[testenv:functional-py36]