diff --git a/neutron_lbaas/services/loadbalancer/data_models.py b/neutron_lbaas/services/loadbalancer/data_models.py index b127c9a05..983234103 100644 --- a/neutron_lbaas/services/loadbalancer/data_models.py +++ b/neutron_lbaas/services/loadbalancer/data_models.py @@ -37,6 +37,10 @@ from neutron_lbaas.services.loadbalancer import constants as l_const class BaseDataModel(object): + # NOTE(ihrachys): we could reuse the list to provide a default __init__ + # implementation. That would require handling custom default values though. + fields = [] + def to_dict(self, **kwargs): ret = {} for attr in self.__dict__: @@ -62,14 +66,16 @@ class BaseDataModel(object): @classmethod def from_dict(cls, model_dict): - return cls(**model_dict) + fields = {k: v for k, v in model_dict.items() + if k in cls.fields} + return cls(**fields) @classmethod def from_sqlalchemy_model(cls, sa_model, calling_classes=None): calling_classes = calling_classes or [] attr_mapping = vars(cls).get("attr_mapping") instance = cls() - for attr_name in vars(instance): + for attr_name in cls.fields: if attr_name.startswith('_'): continue if attr_mapping and attr_name in attr_mapping.keys(): @@ -137,6 +143,8 @@ class BaseDataModel(object): # instead of these. class AllocationPool(BaseDataModel): + fields = ['start', 'end'] + def __init__(self, start=None, end=None): self.start = start self.end = end @@ -144,6 +152,8 @@ class AllocationPool(BaseDataModel): class HostRoute(BaseDataModel): + fields = ['destination', 'nexthop'] + def __init__(self, destination=None, nexthop=None): self.destination = destination self.nexthop = nexthop @@ -151,6 +161,11 @@ class HostRoute(BaseDataModel): class Subnet(BaseDataModel): + fields = ['id', 'name', 'tenant_id', 'network_id', 'ip_version', 'cidr', + 'gateway_ip', 'enable_dhcp', 'ipv6_ra_mode', 'ipv6_address_mode', + 'shared', 'dns_nameservers', 'host_routes', 'allocation_pools', + 'subnetpool_id'] + def __init__(self, id=None, name=None, tenant_id=None, network_id=None, ip_version=None, cidr=None, gateway_ip=None, enable_dhcp=None, ipv6_ra_mode=None, ipv6_address_mode=None, shared=None, @@ -180,11 +195,13 @@ class Subnet(BaseDataModel): for route in host_routes] model_dict['allocation_pools'] = [AllocationPool.from_dict(ap) for ap in allocation_pools] - return Subnet(**model_dict) + return super(Subnet, cls).from_dict(model_dict) class IPAllocation(BaseDataModel): + fields = ['port_id', 'ip_address', 'subnet_id', 'network_id'] + def __init__(self, port_id=None, ip_address=None, subnet_id=None, network_id=None): self.port_id = port_id @@ -197,7 +214,7 @@ class IPAllocation(BaseDataModel): subnet = model_dict.pop('subnet', None) # TODO(blogan): add subnet to __init__. Can't do it yet because it # causes issues with converting SA models into data models. - instance = IPAllocation(**model_dict) + instance = super(IPAllocation, cls).from_dict(model_dict) setattr(instance, 'subnet', None) if subnet: setattr(instance, 'subnet', Subnet.from_dict(subnet)) @@ -206,6 +223,10 @@ class IPAllocation(BaseDataModel): class Port(BaseDataModel): + fields = ['id', 'tenant_id', 'name', 'network_id', 'mac_address', + 'admin_state_up', 'status', 'device_id', 'device_owner', + 'fixed_ips'] + def __init__(self, id=None, tenant_id=None, name=None, network_id=None, mac_address=None, admin_state_up=None, status=None, device_id=None, device_owner=None, fixed_ips=None): @@ -225,11 +246,13 @@ class Port(BaseDataModel): fixed_ips = model_dict.pop('fixed_ips', []) model_dict['fixed_ips'] = [IPAllocation.from_dict(fixed_ip) for fixed_ip in fixed_ips] - return Port(**model_dict) + return super(Port, cls).from_dict(model_dict) class ProviderResourceAssociation(BaseDataModel): + fields = ['provider_name', 'resource_id'] + def __init__(self, provider_name=None, resource_id=None): self.provider_name = provider_name self.resource_id = resource_id @@ -237,13 +260,16 @@ class ProviderResourceAssociation(BaseDataModel): @classmethod def from_dict(cls, model_dict): device_driver = model_dict.pop('device_driver', None) - instance = ProviderResourceAssociation(**model_dict) + instance = super(ProviderResourceAssociation, cls).from_dict( + model_dict) setattr(instance, 'device_driver', device_driver) return instance class SessionPersistence(BaseDataModel): + fields = ['pool_id', 'type', 'cookie_name', 'pool'] + def __init__(self, pool_id=None, type=None, cookie_name=None, pool=None): self.pool_id = pool_id @@ -261,11 +287,14 @@ class SessionPersistence(BaseDataModel): if pool: model_dict['pool'] = Pool.from_dict( pool) - return SessionPersistence(**model_dict) + return super(SessionPersistence, cls).from_dict(model_dict) class LoadBalancerStatistics(BaseDataModel): + fields = ['loadbalancer_id', 'bytes_in', 'bytes_out', 'active_connections', + 'total_connections', 'loadbalancer'] + def __init__(self, loadbalancer_id=None, bytes_in=None, bytes_out=None, active_connections=None, total_connections=None, loadbalancer=None): @@ -283,6 +312,10 @@ class LoadBalancerStatistics(BaseDataModel): class HealthMonitor(BaseDataModel): + fields = ['id', 'tenant_id', 'type', 'delay', 'timeout', 'max_retries', + 'http_method', 'url_path', 'expected_codes', + 'provisioning_status', 'admin_state_up', 'pool', 'name'] + def __init__(self, id=None, tenant_id=None, type=None, delay=None, timeout=None, max_retries=None, http_method=None, url_path=None, expected_codes=None, provisioning_status=None, @@ -323,11 +356,17 @@ class HealthMonitor(BaseDataModel): if pool: model_dict['pool'] = Pool.from_dict( pool) - return HealthMonitor(**model_dict) + return super(HealthMonitor, cls).from_dict(model_dict) class Pool(BaseDataModel): + fields = ['id', 'tenant_id', 'name', 'description', 'healthmonitor_id', + 'protocol', 'lb_algorithm', 'admin_state_up', 'operating_status', + 'provisioning_status', 'members', 'healthmonitor', + 'session_persistence', 'loadbalancer_id', 'loadbalancer', + 'listener', 'listeners', 'l7_policies'] + # Map deprecated attribute names to new ones. attr_mapping = {'sessionpersistence': 'session_persistence'} @@ -410,11 +449,15 @@ class Pool(BaseDataModel): session_persistence) if loadbalancer: model_dict['loadbalancer'] = LoadBalancer.from_dict(loadbalancer) - return Pool(**model_dict) + return super(Pool, cls).from_dict(model_dict) class Member(BaseDataModel): + fields = ['id', 'tenant_id', 'pool_id', 'address', 'protocol_port', + 'weight', 'admin_state_up', 'subnet_id', 'operating_status', + 'provisioning_status', 'pool', 'name'] + def __init__(self, id=None, tenant_id=None, pool_id=None, address=None, protocol_port=None, weight=None, admin_state_up=None, subnet_id=None, operating_status=None, @@ -445,10 +488,13 @@ class Member(BaseDataModel): if pool: model_dict['pool'] = Pool.from_dict( pool) - return Member(**model_dict) + return super(Member, cls).from_dict(model_dict) class SNI(BaseDataModel): + + fields = ['listener_id', 'tls_container_id', 'position', 'listener'] + def __init__(self, listener_id=None, tls_container_id=None, position=None, listener=None): self.listener_id = listener_id @@ -462,13 +508,12 @@ class SNI(BaseDataModel): def to_api_dict(self): return super(SNI, self).to_dict(listener=False) - @classmethod - def from_dict(cls, model_dict): - return SNI(**model_dict) - class TLSContainer(BaseDataModel): + fields = ['id', 'certificate', 'private_key', 'passphrase', + 'intermediates', 'primary_cn'] + def __init__(self, id=None, certificate=None, private_key=None, passphrase=None, intermediates=None, primary_cn=None): self.id = id @@ -481,6 +526,10 @@ class TLSContainer(BaseDataModel): class L7Rule(BaseDataModel): + fields = ['id', 'tenant_id', 'l7policy_id', 'type', 'compare_type', + 'invert', 'key', 'value', 'provisioning_status', + 'admin_state_up', 'policy'] + def __init__(self, id=None, tenant_id=None, l7policy_id=None, type=None, compare_type=None, invert=None, key=None, value=None, provisioning_status=None, @@ -514,11 +563,16 @@ class L7Rule(BaseDataModel): policy = model_dict.pop('policy', None) if policy: model_dict['policy'] = L7Policy.from_dict(policy) - return L7Rule(**model_dict) + return super(L7Rule, cls).from_dict(model_dict) class L7Policy(BaseDataModel): + fields = ['id', 'tenant_id', 'name', 'description', 'listener_id', + 'action', 'redirect_pool_id', 'redirect_url', 'position', + 'admin_state_up', 'provisioning_status', 'listener', 'rules', + 'redirect_pool'] + def __init__(self, id=None, tenant_id=None, name=None, description=None, listener_id=None, action=None, redirect_pool_id=None, redirect_url=None, position=None, @@ -563,11 +617,17 @@ class L7Policy(BaseDataModel): model_dict['redirect_pool'] = Pool.from_dict(redirect_pool) model_dict['rules'] = [L7Rule.from_dict(rule) for rule in rules] - return L7Policy(**model_dict) + return super(L7Policy, cls).from_dict(model_dict) class Listener(BaseDataModel): + fields = ['id', 'tenant_id', 'name', 'description', 'default_pool_id', + 'loadbalancer_id', 'protocol', 'default_tls_container_id', + 'sni_containers', 'protocol_port', 'connection_limit', + 'admin_state_up', 'provisioning_status', 'operating_status', + 'default_pool', 'loadbalancer', 'l7_policies'] + def __init__(self, id=None, tenant_id=None, name=None, description=None, default_pool_id=None, loadbalancer_id=None, protocol=None, default_tls_container_id=None, sni_containers=None, @@ -627,11 +687,16 @@ class Listener(BaseDataModel): model_dict['loadbalancer'] = LoadBalancer.from_dict(loadbalancer) model_dict['l7_policies'] = [L7Policy.from_dict(policy) for policy in l7_policies] - return Listener(**model_dict) + return super(Listener, cls).from_dict(model_dict) class LoadBalancer(BaseDataModel): + fields = ['id', 'tenant_id', 'name', 'description', 'vip_subnet_id', + 'vip_port_id', 'vip_address', 'provisioning_status', + 'operating_status', 'admin_state_up', 'vip_port', 'stats', + 'provider', 'listeners', 'pools', 'flavor_id'] + def __init__(self, id=None, tenant_id=None, name=None, description=None, vip_subnet_id=None, vip_port_id=None, vip_address=None, provisioning_status=None, operating_status=None, @@ -687,7 +752,7 @@ class LoadBalancer(BaseDataModel): if provider: model_dict['provider'] = ProviderResourceAssociation.from_dict( provider) - return LoadBalancer(**model_dict) + return super(LoadBalancer, cls).from_dict(model_dict) SA_MODEL_TO_DATA_MODEL_MAP = { diff --git a/neutron_lbaas/tests/tools.py b/neutron_lbaas/tests/tools.py new file mode 100644 index 000000000..b5880096e --- /dev/null +++ b/neutron_lbaas/tests/tools.py @@ -0,0 +1,20 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import string + + +# NOTE(ihrachys): this function is copied from neutron tree +def get_random_string(n=10): + return ''.join(random.choice(string.ascii_lowercase) for _ in range(n)) diff --git a/neutron_lbaas/tests/unit/drivers/common/test_agent_callbacks.py b/neutron_lbaas/tests/unit/drivers/common/test_agent_callbacks.py index 758d1bcb2..df119c873 100644 --- a/neutron_lbaas/tests/unit/drivers/common/test_agent_callbacks.py +++ b/neutron_lbaas/tests/unit/drivers/common/test_agent_callbacks.py @@ -153,6 +153,7 @@ class TestLoadBalancerCallbacks( expected_lb['provider']['device_driver'] = 'dummy' subnet = self.plugin_instance.db._core_plugin.get_subnet( ctx, expected_lb['vip_subnet_id']) + subnet = data_models.Subnet.from_dict(subnet).to_dict() expected_lb['vip_port']['fixed_ips'][0]['subnet'] = subnet del expected_lb['stats'] self.assertEqual(expected_lb, load_balancer) diff --git a/neutron_lbaas/tests/unit/services/loadbalancer/test_data_models.py b/neutron_lbaas/tests/unit/services/loadbalancer/test_data_models.py new file mode 100644 index 000000000..f5fdb60b7 --- /dev/null +++ b/neutron_lbaas/tests/unit/services/loadbalancer/test_data_models.py @@ -0,0 +1,97 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import mock +import testscenarios + +from neutron_lbaas.services.loadbalancer import data_models +from neutron_lbaas.tests import base +from neutron_lbaas.tests import tools + +load_tests = testscenarios.load_tests_apply_scenarios + + +class TestBaseDataModel(base.BaseTestCase): + + def _get_fake_model_cls(self, fields_): + class FakeModel(data_models.BaseDataModel): + fields = fields_ + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + return FakeModel + + def test_from_dict(self): + + fields_ = ['field1', 'field2'] + dict_ = {field: tools.get_random_string() + for field in fields_} + + model_cls = self._get_fake_model_cls(fields_) + model = model_cls.from_dict(dict_) + + for field in fields_: + self.assertEqual(dict_[field], getattr(model, field)) + + def test_from_dict_filters_by_fields(self): + + fields_ = ['field1', 'field2'] + dict_ = {field: tools.get_random_string() + for field in fields_} + dict_['foo'] = 'bar' + + model_cls = self._get_fake_model_cls(fields_) + model = model_cls.from_dict(dict_) + self.assertFalse(hasattr(model, 'foo')) + + +def _get_models(): + models = [] + for name, obj in inspect.getmembers(data_models): + if inspect.isclass(obj): + if issubclass(obj, data_models.BaseDataModel): + if type(obj) != data_models.BaseDataModel: + models.append(obj) + return models + + +class TestModels(base.BaseTestCase): + + scenarios = [ + (model.__name__, {'model': model}) + for model in _get_models() + ] + + @staticmethod + def _get_iterable_mock(*args, **kwargs): + m = mock.create_autospec(dict, spec_set=True) + + def _get_empty_iterator(*args, **kwargs): + return iter([]) + + m.__iter__ = _get_empty_iterator + m.pop = _get_empty_iterator + return m + + def test_from_dict_filters_by_fields(self): + + dict_ = {field: self._get_iterable_mock() + for field in self.model.fields} + dict_['foo'] = 'bar' + + model = self.model.from_dict(dict_) + self.assertFalse(hasattr(model, 'foo'))