diff --git a/octavia/api/v1/controllers/base.py b/octavia/api/v1/controllers/base.py index 11e002d77a..4cbaa68b55 100644 --- a/octavia/api/v1/controllers/base.py +++ b/octavia/api/v1/controllers/base.py @@ -16,6 +16,7 @@ from oslo_config import cfg from pecan import rest from stevedore import driver as stevedore_driver +from octavia.api.v1.types import listener as listener_types from octavia.api.v1.types import load_balancer as lb_types from octavia.api.v1.types import pool as pool_types from octavia.db import repositories @@ -52,9 +53,12 @@ class BaseController(rest.RestController): api_type.session_persistence = ( pool_types.SessionPersistenceResponse.from_data_model( db_obj.session_persistence)) + elif to_type == listener_types.ListenerResponse: + api_type.sni_containers = [sni_c.tls_container_id + for sni_c in db_obj.sni_containers] return api_type if isinstance(db_entity, list): converted = [_convert(db_obj) for db_obj in db_entity] else: converted = _convert(db_entity) - return converted \ No newline at end of file + return converted diff --git a/octavia/api/v1/controllers/listener.py b/octavia/api/v1/controllers/listener.py index 4cfa38059d..686c4132fc 100644 --- a/octavia/api/v1/controllers/listener.py +++ b/octavia/api/v1/controllers/listener.py @@ -83,8 +83,16 @@ class ListenersController(base.BaseController): Update the load balancer db when provisioning status changes. """ try: + sni_container_ids = listener_dict.pop('sni_containers') db_listener = self.repositories.listener.create( session, **listener_dict) + if sni_container_ids is not None: + for container_id in sni_container_ids: + sni_dict = {'listener_id': db_listener.id, + 'tls_container_id': container_id} + self.repositories.sni.create(session, **sni_dict) + db_listener = self.repositories.listener.get(session, + id=db_listener.id) except odb_exceptions.DBDuplicateEntry as de: # Setting LB back to active because this is just a validation # failure @@ -134,7 +142,6 @@ class ListenersController(base.BaseController): del listener_dict['tls_termination'] # This is the extra validation layer for wrong protocol or duplicate # listeners on the same load balancer. - return self._validate_listeners(session, lb_repo, listener_dict) def _test_lb_status_put(self, session, id): diff --git a/octavia/api/v1/types/listener.py b/octavia/api/v1/types/listener.py index 5e7d199357..0adf4c58fc 100644 --- a/octavia/api/v1/types/listener.py +++ b/octavia/api/v1/types/listener.py @@ -36,6 +36,7 @@ class ListenerResponse(base.BaseType): protocol_port = wtypes.wsattr(wtypes.IntegerType()) connection_limit = wtypes.wsattr(wtypes.IntegerType()) tls_certificate_id = wtypes.wsattr(wtypes.StringType(max_length=255)) + sni_containers = [wtypes.StringType(max_length=255)] class ListenerPOST(base.BaseType): @@ -49,6 +50,7 @@ class ListenerPOST(base.BaseType): connection_limit = wtypes.wsattr(wtypes.IntegerType()) tls_certificate_id = wtypes.wsattr(wtypes.StringType(max_length=255)) tls_termination = wtypes.wsattr(TLSTermination) + sni_containers = [wtypes.StringType(max_length=255)] class ListenerPUT(base.BaseType): @@ -61,3 +63,4 @@ class ListenerPUT(base.BaseType): connection_limit = wtypes.wsattr(wtypes.IntegerType()) tls_certificate_id = wtypes.wsattr(wtypes.StringType(max_length=255)) tls_termination = wtypes.wsattr(TLSTermination) + sni_containers = [wtypes.StringType(max_length=255)] diff --git a/octavia/tests/functional/api/v1/test_listener.py b/octavia/tests/functional/api/v1/test_listener.py index d2d220e8fb..a16f26af17 100644 --- a/octavia/tests/functional/api/v1/test_listener.py +++ b/octavia/tests/functional/api/v1/test_listener.py @@ -74,10 +74,13 @@ class TestListener(base.BaseAPITest): self.get(listener_path, status=404) def test_create(self, **optionals): + sni1 = uuidutils.generate_uuid() + sni2 = uuidutils.generate_uuid() lb_listener = {'name': 'listener1', 'description': 'desc1', 'enabled': False, 'protocol': constants.PROTOCOL_HTTP, 'protocol_port': 80, 'connection_limit': 10, - 'tls_certificate_id': uuidutils.generate_uuid()} + 'tls_certificate_id': uuidutils.generate_uuid(), + 'sni_containers': [sni1, sni2]} lb_listener.update(optionals) response = self.post(self.listeners_path, lb_listener) listener_api = response.json @@ -88,6 +91,12 @@ class TestListener(base.BaseAPITest): for key, value in optionals.items(): self.assertEqual(value, lb_listener.get(key)) lb_listener['id'] = listener_api.get('id') + lb_listener.pop('sni_containers') + sni_ex = [sni1, sni2] + sni_resp = listener_api.pop('sni_containers') + self.assertEqual(2, len(sni_resp)) + for sni in sni_resp: + self.assertTrue(sni in sni_ex) self.assertEqual(lb_listener, listener_api) self.assert_correct_lb_status(self.lb.get('id'), constants.PENDING_UPDATE, @@ -111,7 +120,8 @@ class TestListener(base.BaseAPITest): def test_create_defaults(self): defaults = {'name': None, 'description': None, 'enabled': True, - 'connection_limit': None, 'tls_certificate_id': None} + 'connection_limit': None, 'tls_certificate_id': None, + 'sni_containers': []} lb_listener = {'protocol': constants.PROTOCOL_HTTP, 'protocol_port': 80} response = self.post(self.listeners_path, lb_listener)