# Copyright 2017 Huawei Technologies Co., Ltd.
# All Rights Reserved
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import copy

from oslo_utils import uuidutils
import six
from sqlalchemy.orm import attributes
from sqlalchemy.orm import exc
from sqlalchemy.sql import elements
import sqlalchemy.sql.expression as sql_expression
from sqlalchemy.sql import selectable

import neutron_lib.context as q_context
import neutron_lib.objects.exceptions as q_obj_exceptions

from tricircle.common import constants
from tricircle.network.drivers import type_flat
from tricircle.network.drivers import type_local
from tricircle.network.drivers import type_vlan
from tricircle.network.drivers import type_vxlan
from tricircle.network import managers


class ResourceStore(object):
    """In-memory lists standing in for database tables.

    For every (table, resource) pair in ``_resource_list`` three lists are
    created as attributes: ``TOP_<TABLE>``, ``BOTTOM1_<TABLE>`` and
    ``BOTTOM2_<TABLE>``. ``store_map`` maps table name to the TOP list;
    ``pod_store_map`` maps pod name ('top', 'pod_1', 'pod_2') and resource
    type to the corresponding list, for tables that have a resource type.
    """

    # (table name, resource type); a resource type of None means the table
    # is not registered in pod_store_map
    _resource_list = [('networks', constants.RT_NETWORK),
                      ('subnets', constants.RT_SUBNET),
                      ('ports', constants.RT_PORT),
                      ('routers', constants.RT_ROUTER),
                      ('routerports', None),
                      ('ipallocations', None),
                      ('subnetpools', None),
                      ('subnetpoolprefixes', None),
                      ('ml2_vlan_allocations', None),
                      ('ml2_vxlan_allocations', None),
                      ('ml2_flat_allocations', None),
                      ('networksegments', None),
                      ('externalnetworks', None),
                      ('floatingips', constants.RT_FIP),
                      ('securitygroups', constants.RT_SG),
                      ('securitygrouprules', None),
                      ('networkrbacs', None),
                      ('subnetroutes', None),
                      ('dnsnameservers', None),
                      ('trunks', 'trunk'),
                      ('subports', None),
                      ('agents', 'agent'),
                      ('sfc_port_pairs', constants.RT_PORT_PAIR),
                      ('sfc_port_pair_groups',
                       constants.RT_PORT_PAIR_GROUP),
                      ('sfc_port_chains', constants.RT_PORT_CHAIN),
                      ('sfc_flow_classifiers',
                       constants.RT_FLOW_CLASSIFIER),
                      ('sfc_chain_group_associations', None),
                      ('sfc_chain_classifier_associations', None),
                      ('qos_policies', constants.RT_QOS),
                      ('qos_bandwidth_limit_rules',
                       'qos_bandwidth_limit_rules')]

    def __init__(self):
        self.store_list = []
        self.store_map = {}
        self.pod_store_map = {'top': {}, 'pod_1': {}, 'pod_2': {}}
        for prefix, pod in [('TOP', 'top'),
                            ('BOTTOM1', 'pod_1'),
                            ('BOTTOM2', 'pod_2')]:
            for table, resource in self._resource_list:
                store_name = '%s_%s' % (prefix, table.upper())
                setattr(self, store_name, [])
                store = getattr(self, store_name)
                self.store_list.append(store)
                if prefix == 'TOP':
                    self.store_map[table] = store
                if resource:
                    self.pod_store_map[pod][resource] = store

    def clean(self):
        """Empty every store in place so existing references stay valid."""
        for store in self.store_list:
            del store[:]


TEST_TENANT_ID = 'test_tenant_id'
_RESOURCE_STORE = None


def get_resource_store():
    """Return the module-wide ResourceStore, creating it on first use."""
    global _RESOURCE_STORE
    if _RESOURCE_STORE is None:
        _RESOURCE_STORE = ResourceStore()
    return _RESOURCE_STORE


class DotDict(dict):
    """A dict whose items can also be read and written as attributes."""

    def __init__(self, normal_dict=None):
        if normal_dict:
            for key, value in six.iteritems(normal_dict):
                self[key] = value

    def __getattr__(self, item):
        # These attributes always read as a fresh empty list, even if a key
        # of the same name is stored in the dict.
        dummy_value_map = {
            'rbac_entries': [],
            'segment_host_mapping': []
        }
        if item in dummy_value_map:
            return dummy_value_map[item]
        # Missing keys read as None instead of raising AttributeError.
        return self.get(item)

    def __setattr__(self, name, value):
        self[name] = value

    def to_dict(self):
        return self

    def __copy__(self):
        return DotDict(self)

    def bump_revision(self):
        pass

    def save(self, session=None):
        pass

    def update_fields(self, obj_data):
        """Copy values from obj_data for keys that already exist in self."""
        for k, v in obj_data.items():
            if k in self:
                setattr(self, k, v)


class DotList(list):
    """A list exposing ``all()``, which returns the list itself."""

    def all(self):
        return self


class FakeQuery(object):
    """Minimal stand-in for a query object, backed by a list of records.

    ``records`` is a list of dict-like objects and ``table`` is the table
    name the records belong to.
    """

    # table name -> primary key field, used by get()
    pk_map = {'ports': 'id'}

    def __init__(self, records, table):
        self.records = records
        self.table = table
        self.index = 0

    def _handle_pagination_by_id(self, record_id):
        """Return a query over the records following the one with this id.

        The result is empty when the id is not found or is the last record.
        """
        for i, record in enumerate(self.records):
            if record['id'] == record_id:
                if i + 1 < len(self.records):
                    return FakeQuery(self.records[i + 1:], self.table)
                else:
                    return FakeQuery([], self.table)
        return FakeQuery([], self.table)

    def _handle_filter(self, keys, values):
        """Return a query over records where record[keys[i]] == values[i]."""
        filtered_list = []
        for record in self.records:
            selected = True
            for i, key in enumerate(keys):
                if key not in record or record[key] != values[i]:
                    selected = False
                    break
            if selected:
                filtered_list.append(record)
        return FakeQuery(filtered_list, self.table)

    def filter(self, *criteria):
        """Translate the given filter expressions into key/value matching.

        Only the expression shapes handled below are understood; any other
        criterion is silently ignored.
        """
        _filter = []
        keys = []
        values = []
        for e in criteria:
            if isinstance(e, sql_expression.BooleanClauseList):
                # only the first clause of a clause list is considered
                e = e.clauses[0]
            if not hasattr(e, 'right') and isinstance(e, elements.False_):
                # filter is a single False value, set key to a 'INVALID_FIELD'
                # then no records will be returned
                keys.append('INVALID_FIELD')
                values.append(False)
            elif hasattr(e, 'right') and not isinstance(e.right,
                                                         elements.Null):
                _filter.append(e)
            elif isinstance(e, selectable.Exists):
                # handle external network filter
                expression = e.element.element._whereclause
                if hasattr(expression, 'right') and (
                        expression.right.name == 'network_id'):
                    keys.append('router:external')
                    values.append(True)
        if not _filter:
            if not keys:
                return FakeQuery(self.records, self.table)
            else:
                return self._handle_filter(keys, values)
        # The shape of the first expression decides how all of them are read.
        if hasattr(_filter[0].right, 'value'):
            keys.extend([f.left.name for f in _filter])
            values.extend([f.right.value for f in _filter])
        else:
            keys.extend([f.expression.left.name for f in _filter])
            values.extend(
                [f.expression.right.element.clauses[0].value for f in _filter])
        if _filter[0].expression.operator.__name__ == 'lt':
            return self._handle_pagination_by_id(values[0])
        else:
            return self._handle_filter(keys, values)

    def filter_by(self, **kwargs):
        """Return a query over records matching every keyword exactly."""
        filtered_list = []
        for record in self.records:
            selected = True
            for key, value in six.iteritems(kwargs):
                if key not in record or record[key] != value:
                    selected = False
                    break
            if selected:
                filtered_list.append(record)
        return FakeQuery(filtered_list, self.table)

    def get(self, pk):
        """Return the first record whose primary key equals pk, else None.

        Raises KeyError when the table has no entry in ``pk_map``.
        """
        pk_field = self.pk_map[self.table]
        for record in self.records:
            if record.get(pk_field) == pk:
                return record

    def delete(self, synchronize_session=False):
        pass

    def outerjoin(self, *props, **kwargs):
        return FakeQuery(self.records, self.table)

    def join(self, *props, **kwargs):
        return FakeQuery(self.records, self.table)

    def order_by(self, func):
        # ``func`` is ignored: records are always sorted by 'id', and the
        # underlying list is sorted in place.
        self.records.sort(key=lambda x: x['id'])
        return FakeQuery(self.records, self.table)

    def enable_eagerloads(self, value):
        return FakeQuery(self.records, self.table)

    def limit(self, limit):
        return FakeQuery(self.records[:limit], self.table)

    def next(self):
        if self.index >= len(self.records):
            raise StopIteration
        self.index += 1
        return self.records[self.index - 1]

    __next__ = next

    def one(self):
        """Return the first record; raise NoResultFound when empty."""
        if len(self.records) == 0:
            raise exc.NoResultFound()
        return self.records[0]

    def first(self):
        if len(self.records) == 0:
            return None
        else:
            return self.records[0]

    def update(self, values):
        """Apply ``values`` to every record; return the record count."""
        for record in self.records:
            for key, value in six.iteritems(values):
                record[key] = value
        return len(self.records)

    def all(self):
        return self.records

    def count(self):
        return len(self.records)

    def __iter__(self):
        return self


def delete_model(res_list, model_obj, key=None):
    """Remove from res_list the first entry matching model_obj on ``key``.

    ``key`` defaults to 'id'. Nothing happens when res_list is empty or its
    first entry lacks ``key``.
    """
    if not res_list:
        return
    if not key:
        key = 'id'
    if key not in res_list[0]:
        return
    index = -1
    for i, res in enumerate(res_list):
        if res[key] == model_obj[key]:
            index = i
            break
    if index != -1:
        del res_list[index]
        return


def link_models(model_obj, model_dict, foreign_table, foreign_key, table, key,
                link_prop):
    """Attach model_dict to its parent records in the TOP store.

    Acts only when model_obj belongs to ``foreign_table``. model_dict is
    appended to ``instance[link_prop]`` of every record in ``table`` whose
    ``key`` equals ``model_dict[foreign_key]``.
    """
    if model_obj.__tablename__ == foreign_table:
        for instance in get_resource_store().store_map[table]:
            if instance[key] == model_dict[foreign_key]:
                if link_prop not in instance:
                    instance[link_prop] = []
                instance[link_prop].append(model_dict)


def unlink_models(res_list, model_dict, foreign_key, key, link_prop,
                  link_ele_foreign_key, link_ele_key):
    """Detach model_dict from the ``link_prop`` list of its parent record.

    Parents are records in res_list whose ``key`` equals
    ``model_dict[foreign_key]``. The first linked element whose
    ``link_ele_foreign_key`` equals ``model_dict[link_ele_key]`` is removed.
    """
    if foreign_key not in model_dict:
        return
    for instance in res_list:
        if instance[key] == model_dict[foreign_key]:
            if link_prop not in instance:
                return
            index = -1
            for i, res in enumerate(instance[link_prop]):
                if res[link_ele_foreign_key] == model_dict[link_ele_key]:
                    index = i
                    break
            if index != -1:
                del instance[link_prop][index]
                return


class FakeSession(object):
    """Session stand-in that reads and writes the shared ResourceStore."""

    class WithWrapper(object):
        """No-op context manager returned by begin() and begin_nested()."""

        def __enter__(self):
            pass

        def __exit__(self, type, value, traceback):
            pass

    def __init__(self):
        self.info = {}
        self.resource_store = get_resource_store()

    def __getattr__(self, field):
        # Any attribute not defined on the class resolves to a callable that
        # accepts anything and returns None.
        def dummy_method(*args, **kwargs):
            pass

        return dummy_method

    def __contains__(self, item):
        return False

    @property
    def is_active(self):
        return True

    def begin(self, subtransactions=False, nested=True):
        return FakeSession.WithWrapper()

    def begin_nested(self):
        return FakeSession.WithWrapper()

    def query(self, model):
        """Return a FakeQuery over the TOP store for the model's table.

        An instrumented attribute is resolved to its owning class first.
        Tables that are not in the store yield an empty query.
        """
        if isinstance(model, attributes.InstrumentedAttribute):
            model = model.class_
        if model.__tablename__ not in self.resource_store.store_map:
            return FakeQuery([], model.__tablename__)
        return FakeQuery(self.resource_store.store_map[model.__tablename__],
                         model.__tablename__)

    def _extend_standard_attr(self, model_dict):
        """Copy standard attribute fields onto model_dict itself."""
        if 'standard_attr' in model_dict:
            for field in ('resource_type', 'description', 'revision_number',
                          'created_at', 'updated_at'):
                model_dict[field] = getattr(model_dict['standard_attr'], field)

    def add_hook(self, model_obj, model_dict):
        # hook for operations before adding the model_obj to the resource store
        pass

    def delete_hook(self, model_obj):
        # hook for operations before deleting the model_obj from the resource
        # store. the default key to find the target object is "id", return
        # non-None value if you would like specify other key
        return None

    def add(self, model_obj):
        """Convert model_obj to a DotDict and append it to its TOP store.

        Objects whose table is not in the store are ignored. Before the
        append, related records already in the store are updated, depending
        on the table of model_obj.

        Raises NeutronDbObjectDuplicateEntry when a flat allocation with the
        same physical network already exists.
        """
        if model_obj.__tablename__ not in self.resource_store.store_map:
            return
        model_dict = DotDict(model_obj._as_dict())
        if 'project_id' in model_dict:
            model_dict['tenant_id'] = model_dict['project_id']

        if model_obj.__tablename__ == 'networks':
            model_dict['subnets'] = []
        if model_obj.__tablename__ == 'ports':
            model_dict['dhcp_opts'] = []
            model_dict['security_groups'] = []
            model_dict['fixed_ips'] = []

        link_models(model_obj, model_dict,
                    'subnetpoolprefixes', 'subnetpool_id',
                    'subnetpools', 'id', 'prefixes')
        link_models(model_obj, model_dict,
                    'ipallocations', 'port_id',
                    'ports', 'id', 'fixed_ips')
        link_models(model_obj, model_dict,
                    'subnets', 'network_id', 'networks', 'id', 'subnets')
        link_models(model_obj, model_dict,
                    'securitygrouprules', 'security_group_id',
                    'securitygroups', 'id', 'security_group_rules')

        if model_obj.__tablename__ == 'routerports':
            for port in self.resource_store.TOP_PORTS:
                if port['id'] == model_dict['port_id']:
                    model_dict['port'] = port
                    port.update(model_dict)
                    break
        if model_obj.__tablename__ == 'externalnetworks':
            for net in self.resource_store.TOP_NETWORKS:
                if net['id'] == model_dict['network_id']:
                    net['external'] = True
                    net['router:external'] = True
                    break
        if model_obj.__tablename__ == 'networkrbacs':
            if (model_dict['action'] == 'access_as_shared' and
                    model_dict['target_tenant'] == '*'):
                for net in self.resource_store.TOP_NETWORKS:
                    if net['id'] == model_dict['object']:
                        net['shared'] = True
                        break

        link_models(model_obj, model_dict,
                    'routerports', 'router_id',
                    'routers', 'id', 'attached_ports')

        if model_obj.__tablename__ == 'subnetroutes':
            for subnet in self.resource_store.TOP_SUBNETS:
                if subnet['id'] != model_dict['subnet_id']:
                    continue
                host_route = {'nexthop': model_dict['nexthop'],
                              'destination': model_dict['destination']}
                subnet['host_routes'].append(host_route)
                break

        if model_obj.__tablename__ == 'dnsnameservers':
            for subnet in self.resource_store.TOP_SUBNETS:
                if subnet['id'] != model_dict['subnet_id']:
                    continue
                dnsnameservers = model_dict['address']
                subnet['dns_nameservers'].append(dnsnameservers)
                break

        if model_obj.__tablename__ == 'ml2_flat_allocations':
            for alloc in self.resource_store.TOP_ML2_FLAT_ALLOCATIONS:
                if alloc['physical_network'] == model_dict['physical_network']:
                    raise q_obj_exceptions.NeutronDbObjectDuplicateEntry(
                        model_obj.__class__,
                        DotDict({'columns': '', 'value': ''}))

        self._extend_standard_attr(model_dict)

        self.add_hook(model_obj, model_dict)
        self.resource_store.store_map[
            model_obj.__tablename__].append(model_dict)

    def _cascade_delete(self, model_dict, foreign_key, table, key):
        """Delete the first record in ``table`` that references model_dict.

        A record references model_dict when its ``foreign_key`` equals
        ``model_dict[key]``.
        """
        if key not in model_dict:
            return
        index = -1
        for i, instance in enumerate(self.resource_store.store_map[table]):
            if instance[foreign_key] == model_dict[key]:
                index = i
                break
        if index != -1:
            del self.resource_store.store_map[table][index]

    def delete(self, model_obj):
        """Unlink model_obj from its parents and remove it from the stores.

        Every TOP store is searched, because the table of model_obj is not
        consulted here; matching uses the key returned by delete_hook(), or
        'id' when the hook returns None.
        """
        unlink_models(self.resource_store.store_map['routers'], model_obj,
                      'router_id', 'id', 'attached_ports', 'port_id', 'id')
        unlink_models(self.resource_store.store_map['securitygroups'],
                      model_obj, 'security_group_id', 'id',
                      'security_group_rules', 'id', 'id')
        self._cascade_delete(model_obj, 'port_id', 'ipallocations', 'id')
        key = self.delete_hook(model_obj)
        for res_list in self.resource_store.store_map.values():
            delete_model(res_list, model_obj, key)


class FakeNeutronContext(q_context.Context):
    """Admin context for TEST_TENANT_ID with a lazily created FakeSession."""

    def __init__(self):
        # NOTE: the parent initializer is deliberately not called here.
        self._session = None
        self.is_admin = True
        self.is_advsvc = False
        self.tenant_id = TEST_TENANT_ID

    def session_class(self):
        return FakeSession

    @property
    def session(self):
        if not self._session:
            self._session = self.session_class()()
        return self._session

    def elevated(self):
        return self


def filter_resource(resource_list, params):
    """Filter and paginate resource_list according to params.

    'limit' and 'marker' are taken out of a copy of params; every remaining
    key is a filter. A resource is kept only when, for each filter, the key
    is present, its value is truthy, and that value is ``in`` the filter
    value. With a marker, the result is sorted by 'id' and starts after the
    marker entry. With a limit, the result is sorted by 'id' and truncated.

    Returns resource_list itself when params is empty.
    """
    if not params:
        return resource_list

    params_copy = copy.deepcopy(params)
    limit = params_copy.pop('limit', None)
    marker = params_copy.pop('marker', None)

    if params_copy:
        return_list = []
        for resource in resource_list:
            is_selected = True
            for key, value in six.iteritems(params_copy):
                if (key not in resource or
                        not resource[key] or
                        resource[key] not in value):
                    is_selected = False
                    break
            if is_selected:
                return_list.append(resource)
    else:
        return_list = resource_list

    if marker:
        sorted_list = sorted(return_list, key=lambda x: x['id'])
        for i, resource in enumerate(sorted_list):
            if resource['id'] == marker:
                return_list = sorted_list[i + 1:]

    if limit:
        sorted_list = sorted(return_list, key=lambda x: x['id'])
        if limit > len(sorted_list):
            last_index = len(sorted_list)
        else:
            last_index = limit
        return_list = sorted_list[0: last_index]
    return return_list


class FakeNeutronClient(object):
    """Read-only client over the per-pod store of one resource type."""

    # override this
    _resource = None

    def __init__(self, region_name):
        self.region_name = region_name
        self._res_map = get_resource_store().pod_store_map

    def get(self, path, params=None):
        """Return ``{<resource> + 's': [...]}`` filtered by params.

        ``path`` is ignored. Records from pods other than 'top' are returned
        as shallow copies. Raises Exception for an unknown region.
        """
        if self.region_name in ['pod_1', 'pod_2', 'top']:
            res_list = self._res_map[self.region_name][self._resource]
            filtered_res_list = filter_resource(res_list, params)
            return_list = []
            for res in filtered_res_list:
                if self.region_name != 'top':
                    res = copy.copy(res)
                return_list.append(res)
            return {self._resource + 's': return_list}
        else:
            raise Exception()


class FakeClient(object):
    """CRUD client over the per-pod stores; the region defaults to 'top'."""

    def __init__(self, region_name=None):
        if not region_name:
            self.region_name = 'top'
        else:
            self.region_name = region_name
        self._res_map = get_resource_store().pod_store_map

    def create_resources(self, _type, ctx, body):
        """Store a copy of ``body[<type>]`` and return it.

        For 'qos_policy' the body is read under the 'policy' key and an
        empty 'rules' list is added when absent. An 'id' is generated when
        the body does not carry one.
        """
        res_list = self._res_map[self.region_name][_type]
        if _type == 'qos_policy':
            _type = 'policy'
        res = dict(body[_type])
        if 'id' not in res:
            res['id'] = uuidutils.generate_uuid()
        if _type == 'policy' and 'rules' not in res:
            res['rules'] = []
        res_list.append(res)
        return res

    def list_resources(self, _type, ctx, filters=None):
        """Return the stored resources matching every filter exactly.

        Each filter is a dict with 'key' and 'value'; other entries such as
        'comparator' are not consulted. ``filters=None`` means no filtering.
        """
        # Guard the documented default: iterating None would raise TypeError.
        filters = filters or []
        res_list = self._res_map[self.region_name][_type]
        ret_list = []
        for res in res_list:
            is_selected = True
            for _filter in filters:
                if _filter['key'] not in res:
                    is_selected = False
                    break
                if _filter['value'] != res[_filter['key']]:
                    is_selected = False
                    break
            if is_selected:
                ret_list.append(res)
        return ret_list

    def get_resource(self, _type, ctx, _id):
        """Return the stored resource with this id, or None."""
        res = self.list_resources(
            _type, ctx, [{'key': 'id', 'comparator': 'eq', 'value': _id}])
        if res:
            return res[0]
        return None

    def delete_resources(self, _type, ctx, _id):
        """Remove the stored resource with this id, if any."""
        # Compare by value: identity comparison against a string literal
        # only works by accident of interning.
        if _type == 'policy':
            _type = 'qos_policy'
        index = -1
        res_list = self._res_map[self.region_name][_type]
        for i, res in enumerate(res_list):
            if res['id'] == _id:
                index = i
        if index != -1:
            del res_list[index]

    def update_resources(self, _type, ctx, _id, body):
        """Merge ``body[_type]`` into matching resources.

        Returns True when at least one resource had the given id.
        """
        if _type == 'policy':
            res_list = self._res_map[self.region_name]['qos_policy']
        else:
            res_list = self._res_map[self.region_name][_type]
        updated = False
        for res in res_list:
            if res['id'] == _id:
                updated = True
                res.update(body[_type])
        return updated


class FakeTypeManager(managers.TricircleTypeManager):
    """Type manager that registers the type drivers directly."""

    def _register_types(self):
        local_driver = type_local.LocalTypeDriver()
        self.drivers[constants.NT_LOCAL] = FakeExtension(local_driver)
        vlan_driver = type_vlan.VLANTypeDriver()
        self.drivers[constants.NT_VLAN] = FakeExtension(vlan_driver)
        vxlan_driver = type_vxlan.VxLANTypeDriver()
        self.drivers[constants.NT_VxLAN] = FakeExtension(vxlan_driver)
        flat_driver = type_flat.FlatTypeDriver()
        self.drivers[constants.NT_FLAT] = FakeExtension(flat_driver)

    def extend_network_dict_provider(self, cxt, net):
        """Copy provider attributes from the first matching segment.

        The attributes are written to the stored TOP network whose id equals
        ``net['id']``, not to ``net`` itself. Nothing happens when no such
        network is stored.
        """
        target_net = None
        for t_net in get_resource_store().TOP_NETWORKS:
            if t_net['id'] == net['id']:
                target_net = t_net
        if not target_net:
            return
        for segment in get_resource_store().TOP_NETWORKSEGMENTS:
            if target_net['id'] == segment['network_id']:
                target_net['provider:network_type'] = segment['network_type']
                target_net[
                    'provider:physical_network'] = segment['physical_network']
                target_net[
                    'provider:segmentation_id'] = segment['segmentation_id']
                break


class FakeExtension(object):
    """Wrapper exposing the wrapped object as ``obj``."""

    def __init__(self, ext_obj):
        self.obj = ext_obj