diff --git a/tricircle/common/constants.py b/tricircle/common/constants.py index e0ebad03..483cbd8a 100644 --- a/tricircle/common/constants.py +++ b/tricircle/common/constants.py @@ -31,6 +31,7 @@ RT_SD_PORT = 'shadow_port' RT_ROUTER = 'router' RT_NS_ROUTER = 'ns_router' RT_SG = 'security_group' +RT_FIP = 'floatingip' REAL_SHADOW_TYPE_MAP = { RT_NETWORK: RT_SD_NETWORK, diff --git a/tricircle/tests/unit/network/test_central_plugin.py b/tricircle/tests/unit/network/test_central_plugin.py index ba677e11..cf0bb5f9 100644 --- a/tricircle/tests/unit/network/test_central_plugin.py +++ b/tricircle/tests/unit/network/test_central_plugin.py @@ -22,15 +22,9 @@ import six from six.moves import xrange import unittest -from sqlalchemy.orm import attributes -from sqlalchemy.orm import exc -from sqlalchemy.sql import elements -from sqlalchemy.sql import selectable - from neutron_lib.api.definitions import portbindings from neutron_lib.api.definitions import provider_net import neutron_lib.constants as q_constants -import neutron_lib.context as q_context import neutron_lib.exceptions as q_lib_exc from neutron_lib.plugins import directory @@ -42,7 +36,6 @@ from neutron.db import ipam_pluggable_backend from neutron.db import l3_db from neutron.db import models_v2 from neutron.db import rbac_db_models as rbac_db -import neutron.objects.exceptions as q_obj_exceptions from neutron.extensions import availability_zone as az_ext @@ -74,68 +67,37 @@ from tricircle.network.drivers import type_vxlan from tricircle.network import helper from tricircle.network import managers from tricircle.tests.unit.network import test_security_groups +import tricircle.tests.unit.utils as test_utils from tricircle.xjob import xmanager -TOP_NETS = [] -TOP_SUBNETS = [] -TOP_PORTS = [] -TOP_ROUTERS = [] -TOP_ROUTERPORT = [] -TOP_SUBNETPOOLS = [] -TOP_SUBNETPOOLPREFIXES = [] -TOP_IPALLOCATIONS = [] -TOP_VLANALLOCATIONS = [] -TOP_VXLANALLOCATIONS = [] -TOP_FLATALLOCATIONS = [] -TOP_SEGMENTS = [] -TOP_EXTNETS = [] -TOP_FLOATINGIPS = [] -TOP_SGS = [] -TOP_SG_RULES = [] -TOP_NETWORK_RBAC = [] -TOP_SUBNETROUTES = [] -TOP_DNSNAMESERVERS = [] -BOTTOM1_NETS = [] -BOTTOM1_SUBNETS = [] -BOTTOM1_PORTS = [] -BOTTOM1_ROUTERS = [] -BOTTOM1_SGS = [] -BOTTOM1_FIPS = [] -BOTTOM2_NETS = [] -BOTTOM2_SUBNETS = [] -BOTTOM2_PORTS = [] -BOTTOM2_ROUTERS = [] -BOTTOM2_SGS = [] -BOTTOM2_FIPS = [] -RES_LIST = [TOP_NETS, TOP_SUBNETS, TOP_PORTS, TOP_ROUTERS, TOP_ROUTERPORT, - TOP_SUBNETPOOLS, TOP_SUBNETPOOLPREFIXES, TOP_IPALLOCATIONS, - TOP_VLANALLOCATIONS, TOP_VXLANALLOCATIONS, TOP_FLOATINGIPS, - TOP_SEGMENTS, TOP_EXTNETS, TOP_FLOATINGIPS, TOP_SGS, TOP_SG_RULES, - TOP_NETWORK_RBAC, TOP_SUBNETROUTES, TOP_DNSNAMESERVERS, - BOTTOM1_NETS, BOTTOM1_SUBNETS, BOTTOM1_PORTS, BOTTOM1_ROUTERS, - BOTTOM1_SGS, BOTTOM1_FIPS, - BOTTOM2_NETS, BOTTOM2_SUBNETS, BOTTOM2_PORTS, BOTTOM2_ROUTERS, - BOTTOM2_SGS, BOTTOM2_FIPS] -RES_MAP = {'networks': TOP_NETS, - 'subnets': TOP_SUBNETS, - 'ports': TOP_PORTS, - 'routers': TOP_ROUTERS, - 'routerports': TOP_ROUTERPORT, - 'ipallocations': TOP_IPALLOCATIONS, - 'subnetpools': TOP_SUBNETPOOLS, - 'subnetpoolprefixes': TOP_SUBNETPOOLPREFIXES, - 'ml2_vlan_allocations': TOP_VLANALLOCATIONS, - 'ml2_vxlan_allocations': TOP_VXLANALLOCATIONS, - 'ml2_flat_allocations': TOP_FLATALLOCATIONS, - 'networksegments': TOP_SEGMENTS, - 'externalnetworks': TOP_EXTNETS, - 'floatingips': TOP_FLOATINGIPS, - 'securitygroups': TOP_SGS, - 'securitygrouprules': TOP_SG_RULES, - 'networkrbacs': TOP_NETWORK_RBAC, - 'subnetroutes': TOP_SUBNETROUTES, - 'dnsnameservers': TOP_DNSNAMESERVERS} -TEST_TENANT_ID = 'test_tenant_id' +_resource_store = test_utils.get_resource_store() +TOP_NETS = _resource_store.TOP_NETWORKS +TOP_SUBNETS = _resource_store.TOP_SUBNETS +TOP_PORTS = _resource_store.TOP_PORTS +TOP_ROUTERS = _resource_store.TOP_ROUTERS +TOP_ROUTERPORTS = _resource_store.TOP_ROUTERPORTS +TOP_IPALLOCATIONS = _resource_store.TOP_IPALLOCATIONS +TOP_VLANALLOCATIONS = _resource_store.TOP_ML2_VLAN_ALLOCATIONS +TOP_VXLANALLOCATIONS = _resource_store.TOP_ML2_VXLAN_ALLOCATIONS +TOP_FLATALLOCATIONS = _resource_store.TOP_ML2_FLAT_ALLOCATIONS +TOP_SEGMENTS = _resource_store.TOP_NETWORKSEGMENTS +TOP_FLOATINGIPS = _resource_store.TOP_FLOATINGIPS +TOP_SGS = _resource_store.TOP_SECURITYGROUPS +TOP_SG_RULES = _resource_store.TOP_SECURITYGROUPRULES +BOTTOM1_NETS = _resource_store.BOTTOM1_NETWORKS +BOTTOM1_SUBNETS = _resource_store.BOTTOM1_SUBNETS +BOTTOM1_PORTS = _resource_store.BOTTOM1_PORTS +BOTTOM1_SGS = _resource_store.BOTTOM1_SECURITYGROUPS +BOTTOM1_FIPS = _resource_store.BOTTOM1_FLOATINGIPS +BOTTOM1_ROUTERS = _resource_store.BOTTOM1_ROUTERS +BOTTOM2_NETS = _resource_store.BOTTOM2_NETWORKS +BOTTOM2_SUBNETS = _resource_store.BOTTOM2_SUBNETS +BOTTOM2_PORTS = _resource_store.BOTTOM2_PORTS +BOTTOM2_SGS = _resource_store.BOTTOM2_SECURITYGROUPS +BOTTOM2_FIPS = _resource_store.BOTTOM2_FLOATINGIPS +BOTTOM2_ROUTERS = _resource_store.BOTTOM2_ROUTERS +TEST_TENANT_ID = test_utils.TEST_TENANT_ID +FakeNeutronContext = test_utils.FakeNeutronContext def _fill_external_gateway_info(router): @@ -155,7 +117,7 @@ def _fill_external_gateway_info(router): def _transform_az(network): az_hints_key = 'availability_zone_hints' if az_hints_key in network: - ret = DotDict(network) + ret = test_utils.DotDict(network) az_str = network[az_hints_key] ret[az_hints_key] = jsonutils.loads(az_str) if az_str else [] return ret @@ -246,104 +208,14 @@ class FakePool(driver.Pool): pass -class DotDict(dict): - def __init__(self, normal_dict=None): - if normal_dict: - for key, value in six.iteritems(normal_dict): - self[key] = value - - def __getattr__(self, item): - dummy_value_map = { - 'rbac_entries': [], - 'segment_host_mapping': [] - } - if item in dummy_value_map: - return dummy_value_map[item] - return self.get(item) - - def to_dict(self): - return self - - def __copy__(self): - return DotDict(self) - - def bump_revision(self): - pass +class FakeNeutronClient(test_utils.FakeNeutronClient): + _resource = 'port' + ports_path = '' -class DotList(list): - def all(self): - return self - - -class FakeNeutronClient(object): - - _res_map = {'top': {'port': TOP_PORTS}, - 'pod_1': {'port': BOTTOM1_PORTS}, - 'pod_2': {'port': BOTTOM2_PORTS}} - - def __init__(self, region_name): - self.region_name = region_name - self.ports_path = '' - - def _get(self, params=None): - port_list = self._res_map[self.region_name]['port'] - - if not params: - return {'ports': port_list} - if 'marker' in params: - sorted_list = sorted(port_list, key=lambda x: x['id']) - for i, port in enumerate(sorted_list): - if port['id'] == params['marker']: - return {'ports': sorted_list[i + 1:]} - if 'filters' in params: - return_list = [] - for port in port_list: - is_selected = True - for key, value in six.iteritems(params['filters']): - if key not in port or not port[key] or ( - port[key] not in value): - is_selected = False - break - if is_selected: - return_list.append(port) - return {'ports': return_list} - return {'ports': port_list} - - def get(self, path, params=None): - if self.region_name in ['pod_1', 'pod_2', 'top']: - res_list = self._get(params)['ports'] - return_list = [] - for res in res_list: - if self.region_name != 'top': - res = copy.copy(res) - return_list.append(res) - return {'ports': return_list} - else: - raise Exception() - - -class FakeClient(object): - - _res_map = {'top': RES_MAP, - 'pod_1': {'network': BOTTOM1_NETS, - 'subnet': BOTTOM1_SUBNETS, - 'port': BOTTOM1_PORTS, - 'router': BOTTOM1_ROUTERS, - 'security_group': BOTTOM1_SGS, - 'floatingip': BOTTOM1_FIPS}, - 'pod_2': {'network': BOTTOM2_NETS, - 'subnet': BOTTOM2_SUBNETS, - 'port': BOTTOM2_PORTS, - 'router': BOTTOM2_ROUTERS, - 'security_group': BOTTOM2_SGS, - 'floatingip': BOTTOM2_FIPS}} - +class FakeClient(test_utils.FakeClient): def __init__(self, region_name=None): - if not region_name: - self.region_name = 'top' - else: - self.region_name = region_name + super(FakeClient, self).__init__(region_name) self.client = FakeNeutronClient(self.region_name) def get_native_client(self, resource, ctx): @@ -401,53 +273,7 @@ class FakeClient(object): body[_type]['gateway_ip'] = cidr[:cidr.rindex('.')] + '.1' if 'id' not in body[_type]: body[_type]['id'] = uuidutils.generate_uuid() - res_list = self._res_map[self.region_name][_type] - res = dict(body[_type]) - res_list.append(res) - return res - - def list_resources(self, _type, ctx, filters=None): - if self.region_name == 'top': - res_list = self._res_map[self.region_name][_type + 's'] - else: - res_list = self._res_map[self.region_name][_type] - ret_list = [] - for res in res_list: - is_selected = True - for _filter in filters: - if _filter['key'] not in res: - is_selected = False - break - if _filter['value'] != res[_filter['key']]: - is_selected = False - break - if is_selected: - ret_list.append(res) - return ret_list - - def update_resources(self, _type, ctx, _id, body): - if self.region_name == 'top': - res_list = self._res_map[self.region_name][_type + 's'] - else: - res_list = self._res_map[self.region_name][_type] - updated = False - for res in res_list: - if res['id'] == _id: - updated = True - res.update(body[_type]) - return updated - - def delete_resources(self, _type, ctx, _id): - index = -1 - if self.region_name == 'top': - res_list = self._res_map[self.region_name][_type + 's'] - else: - res_list = self._res_map[self.region_name][_type] - for i, res in enumerate(res_list): - if res['id'] == _id: - index = i - if index != -1: - del res_list[index] + return super(FakeClient, self).create_resources(_type, ctx, body) def list_networks(self, ctx, filters=None): networks = self.list_resources('network', ctx, filters) @@ -459,9 +285,7 @@ class FakeClient(object): return ret_list def get_networks(self, ctx, net_id): - return self.list_networks(ctx, [{'key': 'id', - 'comparator': 'eq', - 'value': net_id}])[0] + return self.get_resource(constants.RT_NETWORK, ctx, net_id) def delete_networks(self, ctx, net_id): self.delete_resources('network', ctx, net_id) @@ -473,9 +297,7 @@ class FakeClient(object): return self.list_resources('subnet', ctx, filters) def get_subnets(self, ctx, subnet_id): - return self.list_resources('subnet', ctx, [{'key': 'id', - 'comparator': 'eq', - 'value': subnet_id}])[0] + return self.get_resource(constants.RT_SUBNET, ctx, subnet_id) def delete_subnets(self, ctx, subnet_id): self.delete_resources('subnet', ctx, subnet_id) @@ -509,11 +331,11 @@ class FakeClient(object): if key != 'fields': value = query_filter['value'] filter_dict[key] = value - return self.client.get('', {'filters': filter_dict})['ports'] + return self.client.get('', filter_dict)['ports'] def get_ports(self, ctx, port_id): return self.client.get( - '', params={'filters': {'id': [port_id]}})['ports'][0] + '', params={'id': [port_id]})['ports'][0] def delete_ports(self, ctx, port_id): self.delete_resources('port', ctx, port_id) @@ -582,9 +404,7 @@ class FakeClient(object): pass def get_routers(self, ctx, router_id): - router = self.list_resources('router', ctx, [{'key': 'id', - 'comparator': 'eq', - 'value': router_id}])[0] + router = self.get_resource(constants.RT_ROUTER, ctx, router_id) return _fill_external_gateway_info(router) def delete_routers(self, ctx, router_id): @@ -648,81 +468,16 @@ class FakeClient(object): return def get_security_groups(self, ctx, sg_id): - res_list = self._res_map[self.region_name]['security_group'] - for sg in res_list: - if sg['id'] == sg_id: - # need to do a deep copy because we will traverse the security - # group's 'security_group_rules' field and make change to the - # group - ret_sg = copy.deepcopy(sg) - return ret_sg + sg = self.get_resource(constants.RT_SG, ctx, sg_id) + # need to do a deep copy because we will traverse the security + # group's 'security_group_rules' field and make change to the + # group + return copy.deepcopy(sg) def get_security_group(self, context, _id, fields=None, tenant_id=None): pass -class FakeNeutronContext(q_context.Context): - def __init__(self): - self._session = None - self.is_admin = True - self.is_advsvc = False - self.tenant_id = TEST_TENANT_ID - - @property - def session(self): - if not self._session: - self._session = FakeSession() - return self._session - - def elevated(self): - return self - - -def delete_model(res_list, model_obj, key=None): - if not res_list: - return - if not key: - key = 'id' - if key not in res_list[0]: - return - index = -1 - for i, res in enumerate(res_list): - if res[key] == model_obj[key]: - index = i - break - if index != -1: - del res_list[index] - return - - -def link_models(model_obj, model_dict, foreign_table, foreign_key, table, key, - link_prop): - if model_obj.__tablename__ == foreign_table: - for instance in RES_MAP[table]: - if instance[key] == model_dict[foreign_key]: - if link_prop not in instance: - instance[link_prop] = [] - instance[link_prop].append(model_dict) - - -def unlink_models(res_list, model_dict, foreign_key, key, link_prop, - link_ele_foreign_key, link_ele_key): - if foreign_key not in model_dict: - return - for instance in res_list: - if instance[key] == model_dict[foreign_key]: - if link_prop not in instance: - return - index = -1 - for i, res in enumerate(instance[link_prop]): - if res[link_ele_foreign_key] == model_dict[link_ele_key]: - index = i - break - if index != -1: - del instance[link_prop][index] - return - - def update_floatingip(self, context, _id, floatingip): for fip in TOP_FLOATINGIPS: if fip['id'] != _id: @@ -740,7 +495,7 @@ def update_floatingip(self, context, _id, floatingip): update_dict['fixed_port_id'] = port['id'] update_dict[ 'fixed_ip_address'] = port['fixed_ips'][0]['ip_address'] - for router_port in TOP_ROUTERPORT: + for router_port in TOP_ROUTERPORTS: for _port in TOP_PORTS: if _port['id'] != router_port['port_id']: continue @@ -750,287 +505,6 @@ def update_floatingip(self, context, _id, floatingip): fip.update(update_dict) -class FakeQuery(object): - pk_map = {'ports': 'id'} - - def __init__(self, records, table): - self.records = records - self.table = table - self.index = 0 - - def _handle_pagination_by_id(self, record_id): - for i, record in enumerate(self.records): - if record['id'] == record_id: - if i + 1 < len(self.records): - return FakeQuery(self.records[i + 1:], self.table) - else: - return FakeQuery([], self.table) - return FakeQuery([], self.table) - - def _handle_filter(self, keys, values): - filtered_list = [] - for record in self.records: - selected = True - for i, key in enumerate(keys): - if key not in record or record[key] != values[i]: - selected = False - break - if selected: - filtered_list.append(record) - return FakeQuery(filtered_list, self.table) - - def filter(self, *criteria): - _filter = [] - keys = [] - values = [] - for e in criteria: - if not hasattr(e, 'right') and isinstance(e, elements.False_): - # filter is a single False value, set key to a 'INVALID_FIELD' - # then no records will be returned - keys.append('INVALID_FIELD') - values.append(False) - elif hasattr(e, 'right') and not isinstance(e.right, - elements.Null): - _filter.append(e) - elif isinstance(e, selectable.Exists): - # handle external network filter - expression = e.element.element._whereclause - if expression.right.name == 'network_id': - keys.append('router:external') - values.append(True) - if not _filter: - if not keys: - return FakeQuery(self.records, self.table) - else: - return self._handle_filter(keys, values) - if hasattr(_filter[0].right, 'value'): - keys.extend([e.left.name for e in _filter]) - values.extend([e.right.value for e in _filter]) - else: - keys.extend([e.expression.left.name for e in _filter]) - values.extend( - [e.expression.right.element.clauses[0].value for e in _filter]) - if _filter[0].expression.operator.__name__ == 'lt': - return self._handle_pagination_by_id(values[0]) - else: - return self._handle_filter(keys, values) - - def filter_by(self, **kwargs): - filtered_list = [] - for record in self.records: - selected = True - for key, value in six.iteritems(kwargs): - if key not in record or record[key] != value: - selected = False - break - if selected: - filtered_list.append(record) - return FakeQuery(filtered_list, self.table) - - def get(self, pk): - pk_field = self.pk_map[self.table] - for record in self.records: - if record.get(pk_field) == pk: - return record - - def delete(self, synchronize_session=False): - for model_obj in self.records: - unlink_models(RES_MAP['routers'], model_obj, 'router_id', - 'id', 'attached_ports', 'port_id', 'port_id') - delete_model(RES_MAP[self.table], model_obj, key='port_id') - - def outerjoin(self, *props, **kwargs): - return FakeQuery(self.records, self.table) - - def join(self, *props, **kwargs): - return FakeQuery(self.records, self.table) - - def order_by(self, func): - self.records.sort(key=lambda x: x['id']) - return FakeQuery(self.records, self.table) - - def enable_eagerloads(self, value): - return FakeQuery(self.records, self.table) - - def limit(self, limit): - return FakeQuery(self.records[:limit], self.table) - - def next(self): - if self.index >= len(self.records): - raise StopIteration - self.index += 1 - return self.records[self.index - 1] - - __next__ = next - - def one(self): - if len(self.records) == 0: - raise exc.NoResultFound() - return self.records[0] - - def first(self): - if len(self.records) == 0: - return None - else: - return self.records[0] - - def update(self, values): - for record in self.records: - for key, value in six.iteritems(values): - record[key] = value - return len(self.records) - - def all(self): - return self.records - - def count(self): - return len(self.records) - - def __iter__(self): - return self - - -class FakeSession(object): - class WithWrapper(object): - def __enter__(self): - pass - - def __exit__(self, type, value, traceback): - pass - - def __init__(self): - self.info = {} - - def __getattr__(self, field): - def dummy_method(*args, **kwargs): - pass - - return dummy_method - - def __contains__(self, item): - return False - - @property - def is_active(self): - return True - - def begin(self, subtransactions=False, nested=True): - return FakeSession.WithWrapper() - - def begin_nested(self): - return FakeSession.WithWrapper() - - def query(self, model): - if isinstance(model, attributes.InstrumentedAttribute): - model = model.class_ - if model.__tablename__ not in RES_MAP: - return FakeQuery([], model.__tablename__) - return FakeQuery(RES_MAP[model.__tablename__], - model.__tablename__) - - def _extend_standard_attr(self, model_dict): - if 'standard_attr' in model_dict: - for field in ('resource_type', 'description', 'revision_number', - 'created_at', 'updated_at'): - model_dict[field] = getattr(model_dict['standard_attr'], field) - - def add(self, model_obj): - if model_obj.__tablename__ not in RES_MAP: - return - model_dict = DotDict(model_obj._as_dict()) - if 'project_id' in model_dict: - model_dict['tenant_id'] = model_dict['project_id'] - - if model_obj.__tablename__ == 'networks': - model_dict['subnets'] = [] - if model_obj.__tablename__ == 'ports': - model_dict['dhcp_opts'] = [] - model_dict['security_groups'] = [] - model_dict['fixed_ips'] = [] - - link_models(model_obj, model_dict, - 'subnetpoolprefixes', 'subnetpool_id', - 'subnetpools', 'id', 'prefixes') - link_models(model_obj, model_dict, - 'ipallocations', 'port_id', - 'ports', 'id', 'fixed_ips') - link_models(model_obj, model_dict, - 'subnets', 'network_id', 'networks', 'id', 'subnets') - link_models(model_obj, model_dict, - 'securitygrouprules', 'security_group_id', - 'securitygroups', 'id', 'security_group_rules') - - if model_obj.__tablename__ == 'routerports': - for port in TOP_PORTS: - if port['id'] == model_dict['port_id']: - model_dict['port'] = port - port.update(model_dict) - break - if model_obj.__tablename__ == 'externalnetworks': - for net in TOP_NETS: - if net['id'] == model_dict['network_id']: - net['external'] = True - net['router:external'] = True - break - if model_obj.__tablename__ == 'networkrbacs': - if (model_dict['action'] == 'access_as_shared' and - model_dict['target_tenant'] == '*'): - for net in TOP_NETS: - if net['id'] == model_dict['object']: - net['shared'] = True - break - - link_models(model_obj, model_dict, - 'routerports', 'router_id', - 'routers', 'id', 'attached_ports') - - if model_obj.__tablename__ == 'subnetroutes': - for subnet in TOP_SUBNETS: - if subnet['id'] != model_dict['subnet_id']: - continue - host_route = {'nexthop': model_dict['nexthop'], - 'destination': model_dict['destination']} - subnet['host_routes'].append(host_route) - break - - if model_obj.__tablename__ == 'dnsnameservers': - for subnet in TOP_SUBNETS: - if subnet['id'] != model_dict['subnet_id']: - continue - dnsnameservers = model_dict['address'] - subnet['dns_nameservers'].append(dnsnameservers) - break - - if model_obj.__tablename__ == 'ml2_flat_allocations': - for alloc in TOP_FLATALLOCATIONS: - if alloc['physical_network'] == model_dict['physical_network']: - raise q_obj_exceptions.NeutronDbObjectDuplicateEntry( - model_obj.__class__, - DotDict({'columns': '', 'value': ''})) - - self._extend_standard_attr(model_dict) - - RES_MAP[model_obj.__tablename__].append(model_dict) - - def _cascade_delete(self, model_dict, foreign_key, table, key): - if key not in model_dict: - return - index = -1 - for i, instance in enumerate(RES_MAP[table]): - if instance[foreign_key] == model_dict[key]: - index = i - break - if index != -1: - del RES_MAP[table][index] - - def delete(self, model_obj): - unlink_models(RES_MAP['routers'], model_obj, 'router_id', 'id', - 'attached_ports', 'port_id', 'id') - self._cascade_delete(model_obj, 'port_id', 'ipallocations', 'id') - for res_list in RES_MAP.values(): - delete_model(res_list, model_obj) - - class FakeBaseXManager(xmanager.XManager): def __init__(self, fake_plugin): self.clients = {constants.TOP: client.Client()} @@ -1345,11 +819,11 @@ class PluginTest(unittest.TestCase, 'pod_1', group='tricircle') for vlan in (vlan_min, vlan_max): TOP_VLANALLOCATIONS.append( - DotDict({'physical_network': phynet, - 'vlan_id': vlan, 'allocated': False})) + test_utils.DotDict({'physical_network': phynet, + 'vlan_id': vlan, 'allocated': False})) for vxlan in (vxlan_min, vxlan_max): TOP_VXLANALLOCATIONS.append( - DotDict({'vxlan_vni': vxlan, 'allocated': False})) + test_utils.DotDict({'vxlan_vni': vxlan, 'allocated': False})) def fake_get_plugin(alias=q_constants.CORE): if alias == 'trunk': @@ -1827,7 +1301,7 @@ class PluginTest(unittest.TestCase, 'ethertype': 'IPv4'} ] } - TOP_PORTS.append(DotDict(t_sg)) + TOP_PORTS.append(test_utils.DotDict(t_sg)) b_sg = { 'id': b_sg_id, @@ -1848,9 +1322,9 @@ class PluginTest(unittest.TestCase, ] } if pod_name == 'pod_1': - BOTTOM1_PORTS.append(DotDict(b_sg)) + BOTTOM1_PORTS.append(test_utils.DotDict(b_sg)) else: - BOTTOM2_PORTS.append(DotDict(b_sg)) + BOTTOM2_PORTS.append(test_utils.DotDict(b_sg)) pod_id = 'pod_id_1' if pod_name == 'pod_1' else 'pod_id_2' core.create_resource(ctx, models.ResourceRouting, @@ -1897,7 +1371,7 @@ class PluginTest(unittest.TestCase, if add_ip: t_port.update({'fixed_ips': [{'subnet_id': t_subnet_id, 'ip_address': ip_address}]}) - TOP_PORTS.append(DotDict(t_port)) + TOP_PORTS.append(test_utils.DotDict(t_port)) b_port = { 'id': b_port_id, @@ -1923,9 +1397,9 @@ class PluginTest(unittest.TestCase, 'ip_address': ip_address}]}) if pod_name == 'pod_1': - BOTTOM1_PORTS.append(DotDict(b_port)) + BOTTOM1_PORTS.append(test_utils.DotDict(b_port)) else: - BOTTOM2_PORTS.append(DotDict(b_port)) + BOTTOM2_PORTS.append(test_utils.DotDict(b_port)) pod_id = 'pod_id_1' if pod_name == 'pod_1' else 'pod_id_2' core.create_resource(ctx, models.ResourceRouting, @@ -1977,8 +1451,8 @@ class PluginTest(unittest.TestCase, 'host_routes': [], 'dns_nameservers': [] } - TOP_NETS.append(DotDict(t_net)) - TOP_SUBNETS.append(DotDict(t_subnet)) + TOP_NETS.append(test_utils.DotDict(t_net)) + TOP_SUBNETS.append(test_utils.DotDict(t_subnet)) else: t_net_id = t_nets[0]['id'] t_subnet_name = 'top_subnet_%d' % index @@ -2017,11 +1491,11 @@ class PluginTest(unittest.TestCase, 'dns_nameservers': [] } if region_name == 'pod_1': - BOTTOM1_NETS.append(DotDict(b_net)) - BOTTOM1_SUBNETS.append(DotDict(b_subnet)) + BOTTOM1_NETS.append(test_utils.DotDict(b_net)) + BOTTOM1_SUBNETS.append(test_utils.DotDict(b_subnet)) else: - BOTTOM2_NETS.append(DotDict(b_net)) - BOTTOM2_SUBNETS.append(DotDict(b_subnet)) + BOTTOM2_NETS.append(test_utils.DotDict(b_net)) + BOTTOM2_SUBNETS.append(test_utils.DotDict(b_subnet)) pod_id = 'pod_id_1' if region_name == 'pod_1' else 'pod_id_2' core.create_resource(ctx, models.ResourceRouting, @@ -2080,11 +1554,11 @@ class PluginTest(unittest.TestCase, 'project_id': project_id } b_port.update(extra_attrs) - TOP_PORTS.append(DotDict(t_port)) + TOP_PORTS.append(test_utils.DotDict(t_port)) if region_name == 'pod_1': - BOTTOM1_PORTS.append(DotDict(b_port)) + BOTTOM1_PORTS.append(test_utils.DotDict(b_port)) else: - BOTTOM2_PORTS.append(DotDict(b_port)) + BOTTOM2_PORTS.append(test_utils.DotDict(b_port)) pod_id = 'pod_id_1' if region_name == 'pod_1' else 'pod_id_2' core.create_resource(ctx, models.ResourceRouting, @@ -2102,12 +1576,12 @@ class PluginTest(unittest.TestCase, 'name': 'top_router', 'distributed': False, 'tenant_id': project_id, - 'attached_ports': DotList(), + 'attached_ports': test_utils.DotList(), 'extra_attributes': { 'availability_zone_hints': router_az_hints } } - TOP_ROUTERS.append(DotDict(t_router)) + TOP_ROUTERS.append(test_utils.DotDict(t_router)) return t_router_id def _prepare_router_test(self, tenant_id, ctx, region_name, index, @@ -2709,7 +2183,7 @@ class PluginTest(unittest.TestCase, 'name': 'top_router', 'distributed': False, 'tenant_id': tenant_id, - 'attached_ports': DotList(), + 'attached_ports': test_utils.DotList(), 'availability_zone_hints': ['pod_1'] } @@ -2922,7 +2396,7 @@ class PluginTest(unittest.TestCase, 'security_groups': [], 'tenant_id': t_subnet['tenant_id'] } - TOP_PORTS.append(DotDict(t_port)) + TOP_PORTS.append(test_utils.DotDict(t_port)) return t_port_id @patch.object(directory, 'get_plugin', new=fake_get_plugin) @@ -3179,13 +2653,13 @@ class PluginTest(unittest.TestCase, 'name': 'router', 'distributed': False, 'tenant_id': tenant_id, - 'attached_ports': DotList(), + 'attached_ports': test_utils.DotList(), 'extra_attributes': { 'availability_zone_hints': router_az_hints } } - TOP_ROUTERS.append(DotDict(t_router)) + TOP_ROUTERS.append(test_utils.DotDict(t_router)) return t_net_id, t_subnet_id, t_router_id, @patch.object(directory, 'get_plugin', new=fake_get_plugin) @@ -3337,10 +2811,10 @@ class PluginTest(unittest.TestCase, 'name': 'router', 'distributed': False, 'tenant_id': tenant_id, - 'attached_ports': DotList() + 'attached_ports': test_utils.DotList() } - TOP_ROUTERS.append(DotDict(t_router)) + TOP_ROUTERS.append(test_utils.DotDict(t_router)) add_gw_body = { 'router': {'external_gateway_info': { 'network_id': t_net_id, @@ -3959,7 +3433,6 @@ class PluginTest(unittest.TestCase, def tearDown(self): core.ModelBase.metadata.drop_all(core.get_engine()) - for res in RES_LIST: - del res[:] + test_utils.get_resource_store().clean() cfg.CONF.unregister_opts(q_config.core_opts) xmanager.IN_TEST = False diff --git a/tricircle/tests/unit/network/test_central_trunk_plugin.py b/tricircle/tests/unit/network/test_central_trunk_plugin.py index 5d5dc655..f78e2fa0 100644 --- a/tricircle/tests/unit/network/test_central_trunk_plugin.py +++ b/tricircle/tests/unit/network/test_central_trunk_plugin.py @@ -14,22 +14,16 @@ # under the License. -import copy from mock import patch import six import unittest from six.moves import xrange -from sqlalchemy.orm import attributes -from sqlalchemy.orm import exc -from sqlalchemy.sql import elements -from sqlalchemy.sql import expression import neutron.conf.common as q_config from neutron.db import db_base_plugin_v2 from neutron.plugins.common import utils from neutron_lib.api.definitions import portbindings -import neutron_lib.context as q_context from neutron_lib.plugins import directory from oslo_config import cfg @@ -44,47 +38,21 @@ from tricircle.db import models from tricircle.network import central_plugin import tricircle.network.central_trunk_plugin as trunk_plugin from tricircle.network import helper +import tricircle.tests.unit.utils as test_utils from tricircle.xjob import xmanager -TOP_TRUNKS = [] -TOP_SUBPORTS = [] -TOP_PORTS = [] -BOTTOM1_TRUNKS = [] -BOTTOM2_TRUNKS = [] -BOTTOM1_SUBPORTS = [] -BOTTOM2_SUBPORTS = [] -BOTTOM1_PORTS = [] -BOTTOM2_PORTS = [] -RES_LIST = [TOP_TRUNKS, TOP_SUBPORTS, TOP_PORTS, - BOTTOM1_TRUNKS, BOTTOM2_TRUNKS, BOTTOM1_PORTS, BOTTOM2_PORTS, - BOTTOM1_SUBPORTS, BOTTOM2_SUBPORTS] -RES_MAP = {'trunks': TOP_TRUNKS, 'subports': TOP_SUBPORTS, 'ports': TOP_PORTS} -TEST_TENANT_ID = 'test_tenant_id' - - -class DotDict(dict): - def __init__(self, normal_dict=None): - if normal_dict: - for key, value in six.iteritems(normal_dict): - self[key] = value - - def __getattr__(self, item): - return self.get(item) - - def __copy__(self): - return DotDict(self) - - def bump_revision(self): - pass - - def save(self, session=None): - pass - - -class DotList(list): - def all(self): - return self +_resource_store = test_utils.get_resource_store() +TOP_TRUNKS = _resource_store.TOP_TRUNKS +TOP_SUBPORTS = _resource_store.TOP_SUBPORTS +TOP_PORTS = _resource_store.TOP_PORTS +BOTTOM1_TRUNKS = _resource_store.BOTTOM1_TRUNKS +BOTTOM2_TRUNKS = _resource_store.BOTTOM2_TRUNKS +BOTTOM1_SUBPORTS = _resource_store.BOTTOM1_SUBPORTS +BOTTOM2_SUBPORTS = _resource_store.BOTTOM2_SUBPORTS +BOTTOM1_PORTS = _resource_store.BOTTOM1_PORTS +BOTTOM2_PORTS = _resource_store.BOTTOM2_PORTS +TEST_TENANT_ID = test_utils.TEST_TENANT_ID class FakeBaseXManager(xmanager.XManager): @@ -120,155 +88,27 @@ class FakeRPCAPI(FakeBaseRPCAPI): self.xmanager = FakeXManager(fake_plugin) -class FakeNeutronClient(object): - - _res_map = {'top': {'trunk': TOP_TRUNKS}, - 'pod_1': {'trunk': BOTTOM1_TRUNKS}, - 'pod_2': {'trunk': BOTTOM2_TRUNKS}} - - def __init__(self, region_name): - self.region_name = region_name - self.trunks_path = '' - - def _get(self, params=None): - trunk_list = self._res_map[self.region_name]['trunk'] - return_list = [] - - if not params: - return {'trunks': trunk_list} - - params_copy = copy.deepcopy(params) - limit = params_copy.pop('limit', None) - marker = params_copy.pop('marker', None) - - if params_copy: - for trunk in trunk_list: - is_selected = True - for key, value in six.iteritems(params_copy): - if (key not in trunk - or not trunk[key] - or trunk[key] not in value): - is_selected = False - break - if is_selected: - return_list.append(trunk) - else: - return_list = trunk_list - - if marker: - sorted_list = sorted(return_list, key=lambda x: x['id']) - for i, trunk in enumerate(sorted_list): - if trunk['id'] == marker: - return_list = sorted_list[i + 1:] - - if limit: - sorted_list = sorted(return_list, key=lambda x: x['id']) - if limit > len(sorted_list): - last_index = len(sorted_list) - else: - last_index = limit - return_list = sorted_list[0: last_index] - - return {'trunks': return_list} - - def get(self, path, params=None): - if self.region_name in ['pod_1', 'pod_2', 'top']: - res_list = self._get(params)['trunks'] - return_list = [] - for res in res_list: - if self.region_name != 'top': - res = copy.copy(res) - return_list.append(res) - return {'trunks': return_list} - else: - raise Exception() +class FakeNeutronClient(test_utils.FakeNeutronClient): + _resource = 'trunk' + trunks_path = '' -class FakeClient(object): - - _res_map = {'top': RES_MAP, - 'pod_1': {'trunk': BOTTOM1_TRUNKS, 'port': BOTTOM1_PORTS}, - 'pod_2': {'trunk': BOTTOM2_TRUNKS}} - +class FakeClient(test_utils.FakeClient): def __init__(self, region_name=None): - if not region_name: - self.region_name = 'top' - else: - self.region_name = region_name + super(FakeClient, self).__init__(region_name) self.client = FakeNeutronClient(self.region_name) def get_native_client(self, resource, ctx): return self.client - def create_resources(self, _type, ctx, body): - res_list = self._res_map[self.region_name][_type] - res = dict(body[_type]) - res_list.append(res) - return res - - def list_resources(self, _type, ctx, filters=None): - if self.region_name == 'top': - res_list = self._res_map[self.region_name][_type + 's'] - else: - res_list = self._res_map[self.region_name][_type] - ret_list = [] - for res in res_list: - is_selected = True - for _filter in filters: - if _filter['key'] not in res: - is_selected = False - break - if _filter['value'] != res[_filter['key']]: - is_selected = False - break - if is_selected: - ret_list.append(res) - return ret_list - - def delete_resources(self, _type, ctx, _id): - index = -1 - if self.region_name == 'top': - res_list = self._res_map[self.region_name][_type + 's'] - else: - res_list = self._res_map[self.region_name][_type] - for i, res in enumerate(res_list): - if res['id'] == _id: - index = i - if index != -1: - del res_list[index] - def get_trunks(self, ctx, trunk_id): - res = self.list_resources('trunk', ctx, - [{'key': 'id', - 'comparator': 'eq', - 'value': trunk_id}]) - if res: - return res[0] - return res + return self.get_resource(constants.RT_TRUNK, ctx, trunk_id) def update_trunks(self, context, trunk_id, trunk): - trunk_data = trunk[constants.RT_TRUNK] - if self.region_name == 'pod_1': - btm_trunks = BOTTOM1_TRUNKS - else: - btm_trunks = BOTTOM2_TRUNKS - - for trunk in btm_trunks: - if trunk['id'] == trunk_id: - for key in trunk_data: - trunk[key] = trunk_data[key] - return + self.update_resources(constants.RT_TRUNK, context, trunk_id, trunk) def delete_trunks(self, context, trunk_id): - if self.region_name == 'pod_1': - btm_trunks = BOTTOM1_TRUNKS - else: - btm_trunks = BOTTOM2_TRUNKS - - for trunk in btm_trunks: - if trunk['id'] == trunk_id: - btm_trunks.remove(trunk) - return + self.delete_resources(constants.RT_TRUNK, context, trunk_id) def action_trunks(self, ctx, action, resource_id, body): if self.region_name == 'pod_1': @@ -327,211 +167,20 @@ class FakeClient(object): return self.create_resources('port', ctx, body) -class FakeNeutronContext(q_context.Context): - def __init__(self): - self._session = None - self.is_admin = True - self.is_advsvc = False - self.tenant_id = TEST_TENANT_ID - - @property - def session(self): - if not self._session: - self._session = FakeSession() - return self._session - - def elevated(self): - return self +class FakeNeutronContext(test_utils.FakeNeutronContext): + def session_class(self): + return FakeSession -def delete_model(res_list, model_obj, key=None): - if not res_list: - return - if not key: - key = 'id' - if key not in res_list[0]: - return - index = -1 - for i, res in enumerate(res_list): - if res[key] == model_obj[key]: - index = i - break - if index != -1: - del res_list[index] - return - - -class FakeQuery(object): - def __init__(self, records, table): - self.records = records - self.table = table - self.index = 0 - - def _handle_pagination_by_id(self, record_id): - for i, record in enumerate(self.records): - if record['id'] == record_id: - if i + 1 < len(self.records): - return FakeQuery(self.records[i + 1:], self.table) - else: - return FakeQuery([], self.table) - return FakeQuery([], self.table) - - def _handle_filter(self, keys, values): - filtered_list = [] - for record in self.records: - selected = True - for i, key in enumerate(keys): - if key not in record or record[key] != values[i]: - selected = False - break - if selected: - filtered_list.append(record) - return FakeQuery(filtered_list, self.table) - - def filter(self, *criteria): - _filter = [] - keys = [] - values = [] - for e in criteria: - if isinstance(e, expression.BooleanClauseList): - e = e.clauses[0] - if not hasattr(e, 'right') and isinstance(e, elements.False_): - # filter is a single False value, set key to a 'INVALID_FIELD' - # then no records will be returned - keys.append('INVALID_FIELD') - values.append(False) - elif hasattr(e, 'right') and not isinstance(e.right, - elements.Null): - _filter.append(e) - if not _filter: - if not keys: - return FakeQuery(self.records, self.table) - else: - return self._handle_filter(keys, values) - if hasattr(_filter[0].right, 'value'): - keys.extend([f.left.name for f in _filter]) - values.extend([f.right.value for f in _filter]) - else: - keys.extend([f.expression.left.name for f in _filter]) - values.extend( - [f.expression.right.element.clauses[0].value for f in _filter]) - if _filter[0].expression.operator.__name__ == 'lt': - return self._handle_pagination_by_id(values[0]) - else: - return self._handle_filter(keys, values) - - def filter_by(self, **kwargs): - filtered_list = [] - for record in self.records: - selected = True - for key, value in six.iteritems(kwargs): - if key not in record or record[key] != value: - selected = False - break - if selected: - filtered_list.append(record) - return FakeQuery(filtered_list, self.table) - - def delete(self, synchronize_session=False): - for model_obj in self.records: - delete_model(RES_MAP[self.table], model_obj) - - def order_by(self, func): - self.records.sort(key=lambda x: x['id']) - return FakeQuery(self.records, self.table) - - def enable_eagerloads(self, value): - return FakeQuery(self.records, self.table) - - def limit(self, limit): - return FakeQuery(self.records[:limit], self.table) - - def next(self): - if self.index >= len(self.records): - raise StopIteration - self.index += 1 - return self.records[self.index - 1] - - __next__ = next - - def one(self): - if len(self.records) == 0: - raise exc.NoResultFound() - return self.records[0] - - def first(self): - if len(self.records) == 0: - return None - else: - return self.records[0] - - def update(self, values): - for record in self.records: - for key, value in six.iteritems(values): - record[key] = value - return len(self.records) - - def all(self): - return self.records - - def count(self): - return len(self.records) - - def __iter__(self): - return self - - -class FakeSession(object): - class WithWrapper(object): - def __enter__(self): - pass - - def __exit__(self, type, value, traceback): - pass - - def __init__(self): - self.info = {} - - def __getattr__(self, field): - def dummy_method(*args, **kwargs): - pass - - return dummy_method - - @property - def is_active(self): - return True - - def begin(self, subtransactions=False, nested=True): - return FakeSession.WithWrapper() - - def begin_nested(self): - return FakeSession.WithWrapper() - - def query(self, model): - if isinstance(model, attributes.InstrumentedAttribute): - model = model.class_ - if model.__tablename__ not in RES_MAP: - return FakeQuery([], model.__tablename__) - return FakeQuery(RES_MAP[model.__tablename__], - model.__tablename__) - - def add(self, model_obj): - if model_obj.__tablename__ not in RES_MAP: - return - model_dict = DotDict(model_obj._as_dict()) - if 'project_id' in model_dict: - model_dict['tenant_id'] = model_dict['project_id'] - +class FakeSession(test_utils.FakeSession): + def add_hook(self, model_obj, model_dict): if model_obj.__tablename__ == 'subports': for top_trunk in TOP_TRUNKS: if top_trunk['id'] == model_dict['trunk_id']: top_trunk['sub_ports'].append(model_dict) - RES_MAP[model_obj.__tablename__].append(model_dict) - def delete_top_subport(self, port_id): - for res_list in RES_MAP.values(): + for res_list in self.resource_store.store_map.values(): for res in res_list: sub_ports = res.get('sub_ports') if sub_ports: @@ -539,13 +188,10 @@ class FakeSession(object): if sub_port['port_id'] == port_id: sub_ports.remove(sub_port) - def delete(self, model_obj): - key = None + def delete_hook(self, model_obj): if model_obj.get('segmentation_type'): - key = 'port_id' self.delete_top_subport(model_obj['port_id']) - for res_list in RES_MAP.values(): - delete_model(res_list, model_obj, key) + return 'port_id' class FakePlugin(trunk_plugin.TricircleTrunkPlugin): @@ -641,7 +287,7 @@ class PluginTest(unittest.TestCase): 'network_id': t_net_id, 'fixed_ips': [{'subnet_id': t_subnet_id}] } - TOP_PORTS.append(DotDict(t_port)) + TOP_PORTS.append(test_utils.DotDict(t_port)) if create_bottom: b_port = { @@ -665,9 +311,9 @@ class PluginTest(unittest.TestCase): 'fixed_ips': [{'subnet_id': t_subnet_id}] } if pod_name == 'pod_1': - BOTTOM1_PORTS.append(DotDict(b_port)) + BOTTOM1_PORTS.append(test_utils.DotDict(b_port)) else: - BOTTOM2_PORTS.append(DotDict(b_port)) + BOTTOM2_PORTS.append(test_utils.DotDict(b_port)) pod_id = 'pod_id_1' if pod_name == 'pod_1' else 'pod_id_2' core.create_resource(ctx, models.ResourceRouting, @@ -704,8 +350,8 @@ class PluginTest(unittest.TestCase): 'project_id': project_id, 'sub_ports': [t_subport] } - TOP_TRUNKS.append(DotDict(t_trunk)) - TOP_SUBPORTS.append(DotDict(t_subport)) + TOP_TRUNKS.append(test_utils.DotDict(t_trunk)) + TOP_SUBPORTS.append(test_utils.DotDict(t_subport)) b_trunk = None if is_create_bottom: @@ -728,11 +374,11 @@ class PluginTest(unittest.TestCase): } if pod_name == 'pod_1': - BOTTOM1_SUBPORTS.append(DotDict(t_subport)) - BOTTOM1_TRUNKS.append(DotDict(b_trunk)) + BOTTOM1_SUBPORTS.append(test_utils.DotDict(t_subport)) + BOTTOM1_TRUNKS.append(test_utils.DotDict(b_trunk)) else: - BOTTOM2_SUBPORTS.append(DotDict(t_subport)) - BOTTOM2_TRUNKS.append(DotDict(b_trunk)) + BOTTOM2_SUBPORTS.append(test_utils.DotDict(t_subport)) + BOTTOM2_TRUNKS.append(test_utils.DotDict(b_trunk)) pod_id = 'pod_id_1' if pod_name == 'pod_1' else 'pod_id_2' core.create_resource(ctx, models.ResourceRouting, @@ -1028,7 +674,6 @@ class PluginTest(unittest.TestCase): def tearDown(self): core.ModelBase.metadata.drop_all(core.get_engine()) - for res in RES_LIST: - del res[:] + test_utils.get_resource_store().clean() cfg.CONF.unregister_opts(q_config.core_opts) xmanager.IN_TEST = False diff --git a/tricircle/tests/unit/network/test_local_plugin.py b/tricircle/tests/unit/network/test_local_plugin.py index 6ec3709f..86d2072e 100644 --- a/tricircle/tests/unit/network/test_local_plugin.py +++ b/tricircle/tests/unit/network/test_local_plugin.py @@ -33,35 +33,33 @@ from tricircle.common import constants import tricircle.common.context as t_context from tricircle.network import helper import tricircle.network.local_plugin as plugin +import tricircle.tests.unit.utils as test_utils -TOP_NETS = [] -TOP_SUBNETS = [] -TOP_PORTS = [] -TOP_SGS = [] -TOP_TRUNKS = [] -BOTTOM_NETS = [] -BOTTOM_SUBNETS = [] -BOTTOM_PORTS = [] -BOTTOM_SGS = [] -BOTTOM_AGENTS = [] -RES_LIST = [TOP_NETS, TOP_SUBNETS, TOP_PORTS, TOP_SGS, TOP_TRUNKS, - BOTTOM_NETS, BOTTOM_SUBNETS, BOTTOM_PORTS, BOTTOM_SGS, - BOTTOM_AGENTS] -RES_MAP = {'network': {True: TOP_NETS, False: BOTTOM_NETS}, - 'subnet': {True: TOP_SUBNETS, False: BOTTOM_SUBNETS}, - 'port': {True: TOP_PORTS, False: BOTTOM_PORTS}, - 'security_group': {True: TOP_SGS, False: BOTTOM_SGS}, - 'agent': {True: [], False: BOTTOM_AGENTS}, - 'trunk': {True: TOP_TRUNKS, False: []}} +_resource_store = test_utils.get_resource_store() +TOP_NETS = _resource_store.TOP_NETWORKS +TOP_SUBNETS = _resource_store.TOP_SUBNETS +TOP_PORTS = _resource_store.TOP_PORTS +TOP_SGS = _resource_store.TOP_SECURITYGROUPS +TOP_TRUNKS = _resource_store.TOP_TRUNKS +BOTTOM_NETS = _resource_store.BOTTOM1_NETWORKS +BOTTOM_SUBNETS = _resource_store.BOTTOM1_SUBNETS +BOTTOM_PORTS = _resource_store.BOTTOM1_PORTS +BOTTOM_SGS = _resource_store.BOTTOM1_SECURITYGROUPS +BOTTOM_AGENTS = _resource_store.BOTTOM1_AGENTS + + +def get_resource_list(_type, is_top): + pod = 'top' if is_top else 'pod_1' + return _resource_store.pod_store_map[pod][_type] def create_resource(_type, is_top, body): - RES_MAP[_type][is_top].append(body) + get_resource_list(_type, is_top).append(body) def update_resource(_type, is_top, resource_id, body): - for resource in RES_MAP[_type][is_top]: + for resource in get_resource_list(_type, is_top): if resource['id'] == resource_id: resource.update(body) return copy.deepcopy(resource) @@ -69,17 +67,19 @@ def update_resource(_type, is_top, resource_id, body): def get_resource(_type, is_top, resource_id): - for resource in RES_MAP[_type][is_top]: + for resource in get_resource_list(_type, is_top): if resource['id'] == resource_id: return copy.deepcopy(resource) raise q_exceptions.NotFound() def list_resource(_type, is_top, filters=None): + resource_list = get_resource_list(_type, is_top) if not filters: - return [copy.deepcopy(resource) for resource in RES_MAP[_type][is_top]] + return [copy.deepcopy(resource) for resource in get_resource_list( + _type, is_top)] ret = [] - for resource in RES_MAP[_type][is_top]: + for resource in resource_list: pick = True for key, value in six.iteritems(filters): if resource.get(key) not in value: @@ -90,10 +90,6 @@ def list_resource(_type, is_top, filters=None): return ret -def delete_resource(_type, is_top, body): - RES_MAP[_type][is_top].append(body) - - class FakeCorePlugin(object): supported_extension_aliases = ['agent'] @@ -147,21 +143,9 @@ class FakeCorePlugin(object): pass -class FakeSession(object): - class WithWrapper(object): - def __enter__(self): - pass - - def __exit__(self, type, value, traceback): - pass - - def begin(self, subtransactions=True): - return FakeSession.WithWrapper() - - class FakeContext(object): def __init__(self): - self.session = FakeSession() + self.session = test_utils.FakeSession() self.auth_token = 'token' self.project_id = '' self.request_id = 'req-' + uuidutils.generate_uuid() @@ -694,5 +678,4 @@ class PluginTest(unittest.TestCase): self.assertEqual(b_port['device_owner'], 'network:dhcp') def tearDown(self): - for res in RES_LIST: - del res[:] + test_utils.get_resource_store().clean() diff --git a/tricircle/tests/unit/utils.py b/tricircle/tests/unit/utils.py new file mode 100644 index 00000000..04be2fc9 --- /dev/null +++ b/tricircle/tests/unit/utils.py @@ -0,0 +1,594 @@ +# Copyright 2017 Huawei Technologies Co., Ltd. +# All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy + +import six +from sqlalchemy.orm import attributes +from sqlalchemy.orm import exc +from sqlalchemy.sql import elements +import sqlalchemy.sql.expression as sql_expression +from sqlalchemy.sql import selectable + +import neutron.objects.exceptions as q_obj_exceptions +import neutron_lib.context as q_context + +from tricircle.common import constants + + +class ResourceStore(object): + _resource_list = [('networks', constants.RT_NETWORK), + ('subnets', constants.RT_SUBNET), + ('ports', constants.RT_PORT), + ('routers', constants.RT_ROUTER), + ('routerports', None), + ('ipallocations', None), + ('subnetpools', None), + ('subnetpoolprefixes', None), + ('ml2_vlan_allocations', None), + ('ml2_vxlan_allocations', None), + ('ml2_flat_allocations', None), + ('networksegments', None), + ('externalnetworks', None), + ('floatingips', constants.RT_FIP), + ('securitygroups', constants.RT_SG), + ('securitygrouprules', None), + ('networkrbacs', None), + ('subnetroutes', None), + ('dnsnameservers', None), + ('trunks', 'trunk'), + ('subports', None), + ('agents', 'agent')] + + def __init__(self): + self.store_list = [] + self.store_map = {} + self.pod_store_map = {'top': {}, 'pod_1': {}, 'pod_2': {}} + for prefix, pod in [('TOP', 'top'), ('BOTTOM1', 'pod_1'), + ('BOTTOM2', 'pod_2')]: + for table, resource in self._resource_list: + store_name = '%s_%s' % (prefix, table.upper()) + setattr(self, store_name, []) + store = getattr(self, store_name) + self.store_list.append(store) + if prefix == 'TOP': + self.store_map[table] = store + if resource: + self.pod_store_map[pod][resource] = store + + def clean(self): + for store in self.store_list: + del store[:] + + +TEST_TENANT_ID = 'test_tenant_id' +_RESOURCE_STORE = None + + +def get_resource_store(): + global _RESOURCE_STORE + if not _RESOURCE_STORE: + _RESOURCE_STORE = ResourceStore() + return _RESOURCE_STORE + + +class DotDict(dict): + def __init__(self, normal_dict=None): + if normal_dict: + for key, value in six.iteritems(normal_dict): + self[key] = value + + def __getattr__(self, item): + dummy_value_map = { + 'rbac_entries': [], + 'segment_host_mapping': [] + } + if item in dummy_value_map: + return dummy_value_map[item] + return self.get(item) + + def to_dict(self): + return self + + def __copy__(self): + return DotDict(self) + + def bump_revision(self): + pass + + def save(self, session=None): + pass + + +class DotList(list): + def all(self): + return self + + +class FakeQuery(object): + pk_map = {'ports': 'id'} + + def __init__(self, records, table): + self.records = records + self.table = table + self.index = 0 + + def _handle_pagination_by_id(self, record_id): + for i, record in enumerate(self.records): + if record['id'] == record_id: + if i + 1 < len(self.records): + return FakeQuery(self.records[i + 1:], self.table) + else: + return FakeQuery([], self.table) + return FakeQuery([], self.table) + + def _handle_filter(self, keys, values): + filtered_list = [] + for record in self.records: + selected = True + for i, key in enumerate(keys): + if key not in record or record[key] != values[i]: + selected = False + break + if selected: + filtered_list.append(record) + return FakeQuery(filtered_list, self.table) + + def filter(self, *criteria): + _filter = [] + keys = [] + values = [] + for e in criteria: + if isinstance(e, sql_expression.BooleanClauseList): + e = e.clauses[0] + if not hasattr(e, 'right') and isinstance(e, elements.False_): + # filter is a single False value, set key to a 'INVALID_FIELD' + # then no records will be returned + keys.append('INVALID_FIELD') + values.append(False) + elif hasattr(e, 'right') and not isinstance(e.right, + elements.Null): + _filter.append(e) + elif isinstance(e, selectable.Exists): + # handle external network filter + expression = e.element.element._whereclause + if hasattr(expression, 'right') and ( + expression.right.name == 'network_id'): + keys.append('router:external') + values.append(True) + if not _filter: + if not keys: + return FakeQuery(self.records, self.table) + else: + return self._handle_filter(keys, values) + if hasattr(_filter[0].right, 'value'): + keys.extend([f.left.name for f in _filter]) + values.extend([f.right.value for f in _filter]) + else: + keys.extend([f.expression.left.name for f in _filter]) + values.extend( + [f.expression.right.element.clauses[0].value for f in _filter]) + if _filter[0].expression.operator.__name__ == 'lt': + return self._handle_pagination_by_id(values[0]) + else: + return self._handle_filter(keys, values) + + def filter_by(self, **kwargs): + filtered_list = [] + for record in self.records: + selected = True + for key, value in six.iteritems(kwargs): + if key not in record or record[key] != value: + selected = False + break + if selected: + filtered_list.append(record) + return FakeQuery(filtered_list, self.table) + + def get(self, pk): + pk_field = self.pk_map[self.table] + for record in self.records: + if record.get(pk_field) == pk: + return record + + def delete(self, synchronize_session=False): + pass + + def outerjoin(self, *props, **kwargs): + return FakeQuery(self.records, self.table) + + def join(self, *props, **kwargs): + return FakeQuery(self.records, self.table) + + def order_by(self, func): + self.records.sort(key=lambda x: x['id']) + return FakeQuery(self.records, self.table) + + def enable_eagerloads(self, value): + return FakeQuery(self.records, self.table) + + def limit(self, limit): + return FakeQuery(self.records[:limit], self.table) + + def next(self): + if self.index >= len(self.records): + raise StopIteration + self.index += 1 + return self.records[self.index - 1] + + __next__ = next + + def one(self): + if len(self.records) == 0: + raise exc.NoResultFound() + return self.records[0] + + def first(self): + if len(self.records) == 0: + return None + else: + return self.records[0] + + def update(self, values): + for record in self.records: + for key, value in six.iteritems(values): + record[key] = value + return len(self.records) + + def all(self): + return self.records + + def count(self): + return len(self.records) + + def __iter__(self): + return self + + +def delete_model(res_list, model_obj, key=None): + if not res_list: + return + if not key: + key = 'id' + if key not in res_list[0]: + return + index = -1 + for i, res in enumerate(res_list): + if res[key] == model_obj[key]: + index = i + break + if index != -1: + del res_list[index] + return + + +def link_models(model_obj, model_dict, foreign_table, foreign_key, table, key, + link_prop): + if model_obj.__tablename__ == foreign_table: + for instance in get_resource_store().store_map[table]: + if instance[key] == model_dict[foreign_key]: + if link_prop not in instance: + instance[link_prop] = [] + instance[link_prop].append(model_dict) + + +def unlink_models(res_list, model_dict, foreign_key, key, link_prop, + link_ele_foreign_key, link_ele_key): + if foreign_key not in model_dict: + return + for instance in res_list: + if instance[key] == model_dict[foreign_key]: + if link_prop not in instance: + return + index = -1 + for i, res in enumerate(instance[link_prop]): + if res[link_ele_foreign_key] == model_dict[link_ele_key]: + index = i + break + if index != -1: + del instance[link_prop][index] + return + + +class FakeSession(object): + class WithWrapper(object): + def __enter__(self): + pass + + def __exit__(self, type, value, traceback): + pass + + def __init__(self): + self.info = {} + self.resource_store = get_resource_store() + + def __getattr__(self, field): + def dummy_method(*args, **kwargs): + pass + + return dummy_method + + def __contains__(self, item): + return False + + @property + def is_active(self): + return True + + def begin(self, subtransactions=False, nested=True): + return FakeSession.WithWrapper() + + def begin_nested(self): + return FakeSession.WithWrapper() + + def query(self, model): + if isinstance(model, attributes.InstrumentedAttribute): + model = model.class_ + if model.__tablename__ not in self.resource_store.store_map: + return FakeQuery([], model.__tablename__) + return FakeQuery(self.resource_store.store_map[model.__tablename__], + model.__tablename__) + + def _extend_standard_attr(self, model_dict): + if 'standard_attr' in model_dict: + for field in ('resource_type', 'description', 'revision_number', + 'created_at', 'updated_at'): + model_dict[field] = getattr(model_dict['standard_attr'], field) + + def add_hook(self, model_obj, model_dict): + # hook for operations before adding the model_obj to the resource store + pass + + def delete_hook(self, model_obj): + # hook for operations before deleting the model_obj from the resource + # store. the default key to find the target object is "id", return + # non-None value if you would like specify other key + return None + + def add(self, model_obj): + if model_obj.__tablename__ not in self.resource_store.store_map: + return + model_dict = DotDict(model_obj._as_dict()) + if 'project_id' in model_dict: + model_dict['tenant_id'] = model_dict['project_id'] + + if model_obj.__tablename__ == 'networks': + model_dict['subnets'] = [] + if model_obj.__tablename__ == 'ports': + model_dict['dhcp_opts'] = [] + model_dict['security_groups'] = [] + model_dict['fixed_ips'] = [] + + link_models(model_obj, model_dict, + 'subnetpoolprefixes', 'subnetpool_id', + 'subnetpools', 'id', 'prefixes') + link_models(model_obj, model_dict, + 'ipallocations', 'port_id', + 'ports', 'id', 'fixed_ips') + link_models(model_obj, model_dict, + 'subnets', 'network_id', 'networks', 'id', 'subnets') + link_models(model_obj, model_dict, + 'securitygrouprules', 'security_group_id', + 'securitygroups', 'id', 'security_group_rules') + + if model_obj.__tablename__ == 'routerports': + for port in self.resource_store.TOP_PORTS: + if port['id'] == model_dict['port_id']: + model_dict['port'] = port + port.update(model_dict) + break + if model_obj.__tablename__ == 'externalnetworks': + for net in self.resource_store.TOP_NETWORKS: + if net['id'] == model_dict['network_id']: + net['external'] = True + net['router:external'] = True + break + if model_obj.__tablename__ == 'networkrbacs': + if (model_dict['action'] == 'access_as_shared' and + model_dict['target_tenant'] == '*'): + for net in self.resource_store.TOP_NETWORKS: + if net['id'] == model_dict['object']: + net['shared'] = True + break + + link_models(model_obj, model_dict, + 'routerports', 'router_id', + 'routers', 'id', 'attached_ports') + + if model_obj.__tablename__ == 'subnetroutes': + for subnet in self.resource_store.TOP_SUBNETS: + if subnet['id'] != model_dict['subnet_id']: + continue + host_route = {'nexthop': model_dict['nexthop'], + 'destination': model_dict['destination']} + subnet['host_routes'].append(host_route) + break + + if model_obj.__tablename__ == 'dnsnameservers': + for subnet in self.resource_store.TOP_SUBNETS: + if subnet['id'] != model_dict['subnet_id']: + continue + dnsnameservers = model_dict['address'] + subnet['dns_nameservers'].append(dnsnameservers) + break + + if model_obj.__tablename__ == 'ml2_flat_allocations': + for alloc in self.resource_store.TOP_ML2_FLAT_ALLOCATIONS: + if alloc['physical_network'] == model_dict['physical_network']: + raise q_obj_exceptions.NeutronDbObjectDuplicateEntry( + model_obj.__class__, + DotDict({'columns': '', 'value': ''})) + + self._extend_standard_attr(model_dict) + + self.add_hook(model_obj, model_dict) + self.resource_store.store_map[ + model_obj.__tablename__].append(model_dict) + + def _cascade_delete(self, model_dict, foreign_key, table, key): + if key not in model_dict: + return + index = -1 + for i, instance in enumerate(self.resource_store.store_map[table]): + if instance[foreign_key] == model_dict[key]: + index = i + break + if index != -1: + del self.resource_store.store_map[table][index] + + def delete(self, model_obj): + unlink_models(self.resource_store.store_map['routers'], model_obj, + 'router_id', 'id', 'attached_ports', 'port_id', 'id') + self._cascade_delete(model_obj, 'port_id', 'ipallocations', 'id') + key = self.delete_hook(model_obj) + for res_list in self.resource_store.store_map.values(): + delete_model(res_list, model_obj, key) + + +class FakeNeutronContext(q_context.Context): + def __init__(self): + self._session = None + self.is_admin = True + self.is_advsvc = False + self.tenant_id = TEST_TENANT_ID + + def session_class(self): + return FakeSession + + @property + def session(self): + if not self._session: + self._session = self.session_class()() + return self._session + + def elevated(self): + return self + + +def filter_resource(resource_list, params): + if not params: + return resource_list + + params_copy = copy.deepcopy(params) + limit = params_copy.pop('limit', None) + marker = params_copy.pop('marker', None) + + if params_copy: + return_list = [] + for resource in resource_list: + is_selected = True + for key, value in six.iteritems(params_copy): + if (key not in resource + or not resource[key] + or resource[key] not in value): + is_selected = False + break + if is_selected: + return_list.append(resource) + else: + return_list = resource_list + + if marker: + sorted_list = sorted(return_list, key=lambda x: x['id']) + for i, resource in enumerate(sorted_list): + if resource['id'] == marker: + return_list = sorted_list[i + 1:] + + if limit: + sorted_list = sorted(return_list, key=lambda x: x['id']) + if limit > len(sorted_list): + last_index = len(sorted_list) + else: + last_index = limit + return_list = sorted_list[0: last_index] + return return_list + + +class FakeNeutronClient(object): + # override this + _resource = None + + def __init__(self, region_name): + self.region_name = region_name + self._res_map = get_resource_store().pod_store_map + + def get(self, path, params=None): + if self.region_name in ['pod_1', 'pod_2', 'top']: + res_list = self._res_map[self.region_name][self._resource] + filtered_res_list = filter_resource(res_list, params) + return_list = [] + for res in filtered_res_list: + if self.region_name != 'top': + res = copy.copy(res) + return_list.append(res) + return {self._resource + 's': return_list} + else: + raise Exception() + + +class FakeClient(object): + def __init__(self, region_name=None): + if not region_name: + self.region_name = 'top' + else: + self.region_name = region_name + self._res_map = get_resource_store().pod_store_map + + def create_resources(self, _type, ctx, body): + res_list = self._res_map[self.region_name][_type] + res = dict(body[_type]) + res_list.append(res) + return res + + def list_resources(self, _type, ctx, filters=None): + res_list = self._res_map[self.region_name][_type] + ret_list = [] + for res in res_list: + is_selected = True + for _filter in filters: + if _filter['key'] not in res: + is_selected = False + break + if _filter['value'] != res[_filter['key']]: + is_selected = False + break + if is_selected: + ret_list.append(res) + return ret_list + + def get_resource(self, _type, ctx, _id): + res = self.list_resources( + _type, ctx, [{'key': 'id', 'comparator': 'eq', 'value': _id}]) + if res: + return res[0] + return None + + def delete_resources(self, _type, ctx, _id): + index = -1 + res_list = self._res_map[self.region_name][_type] + for i, res in enumerate(res_list): + if res['id'] == _id: + index = i + if index != -1: + del res_list[index] + + def update_resources(self, _type, ctx, _id, body): + res_list = self._res_map[self.region_name][_type] + updated = False + for res in res_list: + if res['id'] == _id: + updated = True + res.update(body[_type]) + return updated