diff --git a/neutron/db/agents_db.py b/neutron/db/agents_db.py index a0ff1c1e0cc..7aa28c8faba 100644 --- a/neutron/db/agents_db.py +++ b/neutron/db/agents_db.py @@ -37,6 +37,7 @@ from neutron.callbacks import events from neutron.callbacks import registry from neutron.callbacks import resources from neutron.common import constants as n_const +from neutron.db import _model_query as model_query from neutron.db import _utils as db_utils from neutron.db import api as db_api from neutron.db.models import agent as agent_model @@ -82,8 +83,8 @@ class AgentAvailabilityZoneMixin(az_ext.AvailabilityZonePluginBase): def _list_availability_zones(self, context, filters=None): result = {} - query = self._get_collection_query(context, agent_model.Agent, - filters=filters) + query = model_query.get_collection_query(context, agent_model.Agent, + filters=filters) columns = (agent_model.Agent.admin_state_up, agent_model.Agent.availability_zone, agent_model.Agent.agent_type) @@ -142,7 +143,7 @@ class AgentDbMixin(ext_agent.AgentPluginBase, AgentAvailabilityZoneMixin): def _get_agent(self, context, id): try: - agent = self._get_by_id(context, agent_model.Agent, id) + agent = model_query.get_by_id(context, agent_model.Agent, id) except exc.NoResultFound: raise ext_agent.AgentNotFound(id=id) return agent @@ -242,16 +243,16 @@ class AgentDbMixin(ext_agent.AgentPluginBase, AgentAvailabilityZoneMixin): @db_api.retry_if_session_inactive() def get_agents_db(self, context, filters=None): - query = self._get_collection_query(context, - agent_model.Agent, - filters=filters) + query = model_query.get_collection_query(context, + agent_model.Agent, + filters=filters) return query.all() @db_api.retry_if_session_inactive() def get_agents(self, context, filters=None, fields=None): - agents = self._get_collection(context, agent_model.Agent, - self._make_agent_dict, - filters=filters, fields=fields) + agents = model_query.get_collection(context, agent_model.Agent, + self._make_agent_dict, + filters=filters, fields=fields) alive = filters and filters.get('alive', None) if alive: alive = converters.convert_to_boolean(alive[0]) @@ -280,7 +281,7 @@ class AgentDbMixin(ext_agent.AgentPluginBase, AgentAvailabilityZoneMixin): len(agents)) def _get_agent_by_type_and_host(self, context, agent_type, host): - query = self._model_query(context, agent_model.Agent) + query = model_query.query_with_hooks(context, agent_model.Agent) try: agent_db = query.filter(agent_model.Agent.agent_type == agent_type, agent_model.Agent.host == host).one() diff --git a/neutron/db/db_base_plugin_common.py b/neutron/db/db_base_plugin_common.py index e9d566ac276..9ea399d5c9d 100644 --- a/neutron/db/db_base_plugin_common.py +++ b/neutron/db/db_base_plugin_common.py @@ -26,6 +26,8 @@ from sqlalchemy.orm import exc from neutron.api.v2 import attributes from neutron.common import constants as n_const from neutron.common import exceptions +from neutron.db import _model_query as model_query +from neutron.db import _resource_extend as resource_extend from neutron.db import _utils as db_utils from neutron.db import api as db_api from neutron.db import common_db_mixin @@ -152,7 +154,7 @@ class DbBasePluginCommon(common_db_mixin.CommonDbMixin): # The shared attribute for a subnet is the same as its parent network res['shared'] = self._is_network_shared(context, subnet.rbac_entries) # Call auxiliary extend functions, if any - self._apply_dict_extend_functions(attributes.SUBNETS, res, subnet) + resource_extend.apply_funcs(attributes.SUBNETS, res, subnet) return db_utils.resource_fields(res, fields) def _make_subnetpool_dict(self, subnetpool, fields=None): @@ -171,8 +173,7 @@ class DbBasePluginCommon(common_db_mixin.CommonDbMixin): 'ip_version': subnetpool['ip_version'], 'default_quota': subnetpool['default_quota'], 'address_scope_id': subnetpool['address_scope_id']} - self._apply_dict_extend_functions(attributes.SUBNETPOOLS, res, - subnetpool) + resource_extend.apply_funcs(attributes.SUBNETPOOLS, res, subnetpool) return db_utils.resource_fields(res, fields) def _make_port_dict(self, port, fields=None, @@ -191,20 +192,19 @@ class DbBasePluginCommon(common_db_mixin.CommonDbMixin): "device_owner": port["device_owner"]} # Call auxiliary extend functions, if any if process_extensions: - self._apply_dict_extend_functions( - attributes.PORTS, res, port) + resource_extend.apply_funcs(attributes.PORTS, res, port) return db_utils.resource_fields(res, fields) def _get_network(self, context, id): try: - network = self._get_by_id(context, models_v2.Network, id) + network = model_query.get_by_id(context, models_v2.Network, id) except exc.NoResultFound: raise n_exc.NetworkNotFound(net_id=id) return network def _get_subnet(self, context, id): try: - subnet = self._get_by_id(context, models_v2.Subnet, id) + subnet = model_query.get_by_id(context, models_v2.Subnet, id) except exc.NoResultFound: raise n_exc.SubnetNotFound(subnet_id=id) return subnet @@ -218,7 +218,7 @@ class DbBasePluginCommon(common_db_mixin.CommonDbMixin): def _get_port(self, context, id): try: - port = self._get_by_id(context, models_v2.Port, id) + port = model_query.get_by_id(context, models_v2.Port, id) except exc.NoResultFound: raise n_exc.PortNotFound(port_id=id) return port @@ -254,13 +254,13 @@ class DbBasePluginCommon(common_db_mixin.CommonDbMixin): marker_obj = self._get_marker_obj(context, 'subnet', limit, marker) make_subnet_dict = functools.partial(self._make_subnet_dict, context=context) - return self._get_collection(context, models_v2.Subnet, - make_subnet_dict, - filters=filters, fields=fields, - sorts=sorts, - limit=limit, - marker_obj=marker_obj, - page_reverse=page_reverse) + return model_query.get_collection(context, models_v2.Subnet, + make_subnet_dict, + filters=filters, fields=fields, + sorts=sorts, + limit=limit, + marker_obj=marker_obj, + page_reverse=page_reverse) def _make_network_dict(self, network, fields=None, process_extensions=True, context=None): @@ -275,8 +275,7 @@ class DbBasePluginCommon(common_db_mixin.CommonDbMixin): res['shared'] = self._is_network_shared(context, network.rbac_entries) # Call auxiliary extend functions, if any if process_extensions: - self._apply_dict_extend_functions( - attributes.NETWORKS, res, network) + resource_extend.apply_funcs(attributes.NETWORKS, res, network) return db_utils.resource_fields(res, fields) def _is_network_shared(self, context, rbac_entries): diff --git a/neutron/db/db_base_plugin_v2.py b/neutron/db/db_base_plugin_v2.py index 27a0a4c9a37..13b5da76084 100644 --- a/neutron/db/db_base_plugin_v2.py +++ b/neutron/db/db_base_plugin_v2.py @@ -42,6 +42,7 @@ from neutron.common import exceptions as n_exc from neutron.common import ipv6_utils from neutron.common import utils from neutron.db import _model_query as model_query +from neutron.db import _resource_extend as resource_extend from neutron.db import _utils as ndb_utils from neutron.db import api as db_api from neutron.db import db_base_plugin_common @@ -213,10 +214,11 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon, tenant_id): ctx_admin = ctx.get_admin_context() rb_model = rbac_db.NetworkRBAC - other_rbac_entries = self._model_query(ctx_admin, rb_model).filter( - and_(rb_model.object_id == network_id, - rb_model.action == 'access_as_shared')) - ports = self._model_query(ctx_admin, models_v2.Port).filter( + other_rbac_entries = model_query.query_with_hooks( + ctx_admin, rb_model).filter( + and_(rb_model.object_id == network_id, + rb_model.action == 'access_as_shared')) + ports = model_query.query_with_hooks(ctx_admin, models_v2.Port).filter( models_v2.Port.network_id == network_id) if tenant_id == '*': # for the wildcard we need to get all of the rbac entries to @@ -263,11 +265,11 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon, # goes from True to False if updated['shared'] == original.shared or updated['shared']: return - ports = self._model_query( + ports = model_query.query_with_hooks( context, models_v2.Port).filter(models_v2.Port.network_id == id) ports = ports.filter(not_(models_v2.Port.device_owner.startswith( constants.DEVICE_OWNER_NETWORK_PREFIX))) - subnets = self._model_query( + subnets = model_query.query_with_hooks( context, models_v2.Subnet).filter( models_v2.Subnet.network_id == id) tenant_ids = set([port['tenant_id'] for port in ports] + @@ -462,18 +464,18 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon, marker_obj = self._get_marker_obj(context, 'network', limit, marker) make_network_dict = functools.partial(self._make_network_dict, context=context) - return self._get_collection(context, models_v2.Network, - make_network_dict, - filters=filters, fields=fields, - sorts=sorts, - limit=limit, - marker_obj=marker_obj, - page_reverse=page_reverse) + return model_query.get_collection(context, models_v2.Network, + make_network_dict, + filters=filters, fields=fields, + sorts=sorts, + limit=limit, + marker_obj=marker_obj, + page_reverse=page_reverse) @db_api.retry_if_session_inactive() def get_networks_count(self, context, filters=None): - return self._get_collection_count(context, models_v2.Network, - filters=filters) + return model_query.get_collection_count(context, models_v2.Network, + filters=filters) @db_api.retry_if_session_inactive() def create_subnet_bulk(self, context, subnets): @@ -1015,8 +1017,8 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon, @db_api.retry_if_session_inactive() def get_subnets_count(self, context, filters=None): - return self._get_collection_count(context, models_v2.Subnet, - filters=filters) + return model_query.get_collection_count(context, models_v2.Subnet, + filters=filters) @db_api.retry_if_session_inactive() def get_subnets_by_network(self, context, network_id): @@ -1155,8 +1157,8 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon, for key in ['min_prefixlen', 'max_prefixlen', 'default_prefixlen']: updated['key'] = str(updated[key]) - self._apply_dict_extend_functions(attributes.SUBNETPOOLS, - updated, orig_sp.db_obj) + resource_extend.apply_funcs(attributes.SUBNETPOOLS, + updated, orig_sp.db_obj) return updated @db_api.retry_if_session_inactive() @@ -1351,8 +1353,9 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon, filters = filters or {} fixed_ips = filters.pop('fixed_ips', {}) - query = self._get_collection_query(context, Port, filters=filters, - *args, **kwargs) + query = model_query.get_collection_query(context, Port, + filters=filters, + *args, **kwargs) ip_addresses = fixed_ips.get('ip_address') subnet_ids = fixed_ips.get('subnet_id') if ip_addresses: diff --git a/neutron/db/ipam_backend_mixin.py b/neutron/db/ipam_backend_mixin.py index 40191b3e2a2..66883ab9640 100644 --- a/neutron/db/ipam_backend_mixin.py +++ b/neutron/db/ipam_backend_mixin.py @@ -33,6 +33,7 @@ from neutron.common import constants from neutron.common import exceptions as n_exc from neutron.common import ipv6_utils from neutron.common import utils as common_utils +from neutron.db import _model_query as model_query from neutron.db import _utils as db_utils from neutron.db import api as db_api from neutron.db import db_base_plugin_common @@ -591,7 +592,7 @@ class IpamBackendMixin(db_base_plugin_common.DbBasePluginCommon): return fixed_ip_list def _query_subnets_on_network(self, context, network_id): - query = self._get_collection_query(context, models_v2.Subnet) + query = model_query.get_collection_query(context, models_v2.Subnet) return query.filter(models_v2.Subnet.network_id == network_id) def _query_filter_service_subnets(self, query, service_type): diff --git a/neutron/db/l3_db.py b/neutron/db/l3_db.py index be53d62cbd0..10f12b2b929 100644 --- a/neutron/db/l3_db.py +++ b/neutron/db/l3_db.py @@ -39,6 +39,8 @@ from neutron.common import constants as n_const from neutron.common import ipv6_utils from neutron.common import rpc as n_rpc from neutron.common import utils +from neutron.db import _model_query as model_query +from neutron.db import _resource_extend as resource_extend from neutron.db import _utils as db_utils from neutron.db import api as db_api from neutron.db import common_db_mixin @@ -161,7 +163,8 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase, def _get_router(self, context, router_id): try: - router = self._get_by_id(context, l3_models.Router, router_id) + router = model_query.get_by_id( + context, l3_models.Router, router_id) except exc.NoResultFound: raise l3.RouterNotFound(router_id=router_id) return router @@ -184,7 +187,7 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase, # class inheriting from CommonDbMixin, which is true for all existing # plugins. if process_extensions: - self._apply_dict_extend_functions(l3.ROUTERS, res, router) + resource_extend.apply_funcs(l3.ROUTERS, res, router) return db_utils.resource_fields(res, fields) def _create_router_db(self, context, router, tenant_id): @@ -559,18 +562,18 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase, sorts=None, limit=None, marker=None, page_reverse=False): marker_obj = self._get_marker_obj(context, 'router', limit, marker) - return self._get_collection(context, l3_models.Router, - self._make_router_dict, - filters=filters, fields=fields, - sorts=sorts, - limit=limit, - marker_obj=marker_obj, - page_reverse=page_reverse) + return model_query.get_collection(context, l3_models.Router, + self._make_router_dict, + filters=filters, fields=fields, + sorts=sorts, + limit=limit, + marker_obj=marker_obj, + page_reverse=page_reverse) @db_api.retry_if_session_inactive() def get_routers_count(self, context, filters=None): - return self._get_collection_count(context, l3_models.Router, - filters=filters) + return model_query.get_collection_count(context, l3_models.Router, + filters=filters) def _check_for_dup_router_subnets(self, context, router, network_id, new_subnets): @@ -1006,7 +1009,8 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase, def _get_floatingip(self, context, id): try: - floatingip = self._get_by_id(context, l3_models.FloatingIP, id) + floatingip = model_query.get_by_id( + context, l3_models.FloatingIP, id) except exc.NoResultFound: raise l3.FloatingIPNotFound(floatingip_id=id) return floatingip @@ -1025,7 +1029,7 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase, # class inheriting from CommonDbMixin, which is true for all existing # plugins. if process_extensions: - self._apply_dict_extend_functions(l3.FLOATINGIPS, res, floatingip) + resource_extend.apply_funcs(l3.FLOATINGIPS, res, floatingip) return db_utils.resource_fields(res, fields) def _get_router_for_floatingip(self, context, internal_port, @@ -1295,8 +1299,8 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase, self._process_dns_floatingip_create_postcommit(context, floatingip_dict, dns_data) - self._apply_dict_extend_functions(l3.FLOATINGIPS, floatingip_dict, - floatingip_db) + resource_extend.apply_funcs(l3.FLOATINGIPS, floatingip_dict, + floatingip_db) return floatingip_dict @db_api.retry_if_session_inactive() @@ -1326,8 +1330,8 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase, self._process_dns_floatingip_update_postcommit(context, floatingip_dict, dns_data) - self._apply_dict_extend_functions(l3.FLOATINGIPS, floatingip_dict, - floatingip_db) + resource_extend.apply_funcs(l3.FLOATINGIPS, floatingip_dict, + floatingip_db) return old_floatingip, floatingip_dict def _floatingips_to_router_ids(self, floatingips): @@ -1344,8 +1348,9 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase, @db_api.retry_if_session_inactive() def update_floatingip_status(self, context, floatingip_id, status): """Update operational status for floating IP in neutron DB.""" - fip_query = self._model_query(context, l3_models.FloatingIP).filter( - l3_models.FloatingIP.id == floatingip_id) + fip_query = model_query.query_with_hooks( + context, l3_models.FloatingIP).filter( + l3_models.FloatingIP.id == floatingip_id) fip_query.update({'status': status}, synchronize_session=False) def _delete_floatingip(self, context, id): @@ -1382,17 +1387,17 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase, if key in filters: filters[val] = filters.pop(key) - return self._get_collection(context, l3_models.FloatingIP, - self._make_floatingip_dict, - filters=filters, fields=fields, - sorts=sorts, - limit=limit, - marker_obj=marker_obj, - page_reverse=page_reverse) + return model_query.get_collection(context, l3_models.FloatingIP, + self._make_floatingip_dict, + filters=filters, fields=fields, + sorts=sorts, + limit=limit, + marker_obj=marker_obj, + page_reverse=page_reverse) @db_api.retry_if_session_inactive() def delete_disassociated_floatingips(self, context, network_id): - query = self._model_query(context, l3_models.FloatingIP) + query = model_query.query_with_hooks(context, l3_models.FloatingIP) query = query.filter_by(floating_network_id=network_id, fixed_port_id=None, router_id=None) @@ -1401,8 +1406,8 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase, @db_api.retry_if_session_inactive() def get_floatingips_count(self, context, filters=None): - return self._get_collection_count(context, l3_models.FloatingIP, - filters=filters) + return model_query.get_collection_count(context, l3_models.FloatingIP, + filters=filters) def _router_exists(self, context, router_id): try: @@ -1517,7 +1522,7 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase, filters = {'id': router_ids} if router_ids else {} if active is not None: filters['admin_state_up'] = [active] - router_dicts = self._get_collection( + router_dicts = model_query.get_collection( context, l3_models.Router, self._make_router_dict_with_gw_port, filters=filters) if not router_dicts: diff --git a/neutron/db/metering/metering_db.py b/neutron/db/metering/metering_db.py index 641d146958d..57a514a7640 100644 --- a/neutron/db/metering/metering_db.py +++ b/neutron/db/metering/metering_db.py @@ -19,6 +19,7 @@ from sqlalchemy import orm from neutron.api.rpc.agentnotifiers import metering_rpc_agent_api from neutron.common import constants +from neutron.db import _model_query as model_query from neutron.db import _utils as db_utils from neutron.db import api as db_api from neutron.db import common_db_mixin as base_db @@ -60,9 +61,9 @@ class MeteringDbMixin(metering.MeteringPluginBase, def delete_metering_label(self, context, label_id): with db_api.context_manager.writer.using(context): try: - label = self._get_by_id(context, - metering_models.MeteringLabel, - label_id) + label = model_query.get_by_id(context, + metering_models.MeteringLabel, + label_id) except orm.exc.NoResultFound: raise metering.MeteringLabelNotFound(label_id=label_id) @@ -70,9 +71,8 @@ class MeteringDbMixin(metering.MeteringPluginBase, def get_metering_label(self, context, label_id, fields=None): try: - metering_label = self._get_by_id(context, - metering_models.MeteringLabel, - label_id) + metering_label = model_query.get_by_id( + context, metering_models.MeteringLabel, label_id) except orm.exc.NoResultFound: raise metering.MeteringLabelNotFound(label_id=label_id) @@ -83,13 +83,14 @@ class MeteringDbMixin(metering.MeteringPluginBase, page_reverse=False): marker_obj = self._get_marker_obj(context, 'metering_labels', limit, marker) - return self._get_collection(context, metering_models.MeteringLabel, - self._make_metering_label_dict, - filters=filters, fields=fields, - sorts=sorts, - limit=limit, - marker_obj=marker_obj, - page_reverse=page_reverse) + return model_query.get_collection(context, + metering_models.MeteringLabel, + self._make_metering_label_dict, + filters=filters, fields=fields, + sorts=sorts, + limit=limit, + marker_obj=marker_obj, + page_reverse=page_reverse) @staticmethod def _make_metering_label_rule_dict(metering_label_rule, fields=None): @@ -106,17 +107,18 @@ class MeteringDbMixin(metering.MeteringPluginBase, marker_obj = self._get_marker_obj(context, 'metering_label_rules', limit, marker) - return self._get_collection(context, metering_models.MeteringLabelRule, - self._make_metering_label_rule_dict, - filters=filters, fields=fields, - sorts=sorts, - limit=limit, - marker_obj=marker_obj, - page_reverse=page_reverse) + return model_query.get_collection(context, + metering_models.MeteringLabelRule, + self._make_metering_label_rule_dict, + filters=filters, fields=fields, + sorts=sorts, + limit=limit, + marker_obj=marker_obj, + page_reverse=page_reverse) def get_metering_label_rule(self, context, rule_id, fields=None): try: - metering_label_rule = self._get_by_id( + metering_label_rule = model_query.get_by_id( context, metering_models.MeteringLabelRule, rule_id) except orm.exc.NoResultFound: raise metering.MeteringLabelRuleNotFound(rule_id=rule_id) @@ -167,9 +169,9 @@ class MeteringDbMixin(metering.MeteringPluginBase, def delete_metering_label_rule(self, context, rule_id): with db_api.context_manager.writer.using(context): try: - rule = self._get_by_id(context, - metering_models.MeteringLabelRule, - rule_id) + rule = model_query.get_by_id(context, + metering_models.MeteringLabelRule, + rule_id) except orm.exc.NoResultFound: raise metering.MeteringLabelRuleNotFound(rule_id=rule_id) context.session.delete(rule) @@ -204,8 +206,8 @@ class MeteringDbMixin(metering.MeteringPluginBase, for label in labels: if label.shared: if not all_routers: - all_routers = self._get_collection_query(context, - l3_models.Router) + all_routers = model_query.get_collection_query( + context, l3_models.Router) routers = all_routers else: routers = label.routers @@ -232,7 +234,8 @@ class MeteringDbMixin(metering.MeteringPluginBase, rule['metering_label_id']) if label.shared: - routers = self._get_collection_query(context, l3_models.Router) + routers = model_query.get_collection_query( + context, l3_models.Router) else: routers = label.routers diff --git a/neutron/db/rbac_db_mixin.py b/neutron/db/rbac_db_mixin.py index 594ca57c5d9..ea4a392b967 100644 --- a/neutron/db/rbac_db_mixin.py +++ b/neutron/db/rbac_db_mixin.py @@ -20,6 +20,7 @@ from sqlalchemy.orm import exc from neutron.callbacks import events from neutron.callbacks import exceptions as c_exc from neutron.callbacks import registry +from neutron.db import _model_query as model_query from neutron.db import _utils as db_utils from neutron.db import api as db_api from neutron.db import common_db_mixin @@ -99,8 +100,8 @@ class RbacPluginMixin(common_db_mixin.CommonDbMixin): object_type = self._get_object_type(context, id) dbmodel = models.get_type_model_map()[object_type] try: - return self._model_query(context, - dbmodel).filter(dbmodel.id == id).one() + return model_query.query_with_hooks( + context, dbmodel).filter(dbmodel.id == id).one() except exc.NoResultFound: raise ext_rbac.RbacPolicyNotFound(id=id, object_type=object_type) @@ -118,7 +119,7 @@ class RbacPluginMixin(common_db_mixin.CommonDbMixin): m for t, m in models.get_type_model_map().items() if object_type_filters is None or t in object_type_filters ] - collections = [self._get_collection( + collections = [model_query.get_collection( context, model, self._make_rbac_policy_dict, filters=filters, fields=fields, sorts=sorts, limit=limit, page_reverse=page_reverse) diff --git a/neutron/db/securitygroups_db.py b/neutron/db/securitygroups_db.py index 95f679f0bcb..ae6d7bc88eb 100644 --- a/neutron/db/securitygroups_db.py +++ b/neutron/db/securitygroups_db.py @@ -28,6 +28,7 @@ from neutron.callbacks import registry from neutron.callbacks import resources from neutron.common import constants as n_const from neutron.common import utils +from neutron.db import _model_query as model_query from neutron.db import _resource_extend as resource_extend from neutron.db import _utils as db_utils from neutron.db import api as db_api @@ -144,18 +145,18 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): self._ensure_default_security_group(context, tenant_id) marker_obj = self._get_marker_obj(context, 'security_group', limit, marker) - return self._get_collection(context, - sg_models.SecurityGroup, - self._make_security_group_dict, - filters=filters, fields=fields, - sorts=sorts, - limit=limit, marker_obj=marker_obj, - page_reverse=page_reverse) + return model_query.get_collection(context, + sg_models.SecurityGroup, + self._make_security_group_dict, + filters=filters, fields=fields, + sorts=sorts, + limit=limit, marker_obj=marker_obj, + page_reverse=page_reverse) @db_api.retry_if_session_inactive() def get_security_groups_count(self, context, filters=None): - return self._get_collection_count(context, sg_models.SecurityGroup, - filters=filters) + return model_query.get_collection_count( + context, sg_models.SecurityGroup, filters=filters) @db_api.retry_if_session_inactive() def get_security_group(self, context, id, fields=None, tenant_id=None): @@ -180,7 +181,8 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): def _get_security_group(self, context, id): try: - query = self._model_query(context, sg_models.SecurityGroup) + query = model_query.query_with_hooks( + context, sg_models.SecurityGroup) sg = query.filter(sg_models.SecurityGroup.id == id).one() except exc.NoResultFound: @@ -263,8 +265,8 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): 'description': security_group['description']} res['security_group_rules'] = [self._make_security_group_rule_dict(r) for r in security_group.rules] - self._apply_dict_extend_functions(ext_sg.SECURITYGROUPS, res, - security_group) + resource_extend.apply_funcs(ext_sg.SECURITYGROUPS, res, + security_group) return db_utils.resource_fields(res, fields) @staticmethod @@ -283,16 +285,16 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): def _get_port_security_group_bindings(self, context, filters=None, fields=None): - return self._get_collection(context, - sg_models.SecurityGroupPortBinding, - self._make_security_group_binding_dict, - filters=filters, fields=fields) + return model_query.get_collection( + context, sg_models.SecurityGroupPortBinding, + self._make_security_group_binding_dict, + filters=filters, fields=fields) @db_api.retry_if_session_inactive() def _delete_port_security_group_bindings(self, context, port_id): with db_api.context_manager.writer.using(context): - query = self._model_query(context, - sg_models.SecurityGroupPortBinding) + query = model_query.query_with_hooks( + context, sg_models.SecurityGroupPortBinding) bindings = query.filter( sg_models.SecurityGroupPortBinding.port_id == port_id) for binding in bindings: @@ -491,8 +493,8 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): 'remote_ip_prefix': security_group_rule['remote_ip_prefix'], 'remote_group_id': security_group_rule['remote_group_id']} - self._apply_dict_extend_functions(ext_sg.SECURITYGROUPRULES, res, - security_group_rule) + resource_extend.apply_funcs(ext_sg.SECURITYGROUPRULES, res, + security_group_rule) return db_utils.resource_fields(res, fields) def _make_security_group_rule_filter_dict(self, security_group_rule): @@ -603,18 +605,18 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): page_reverse=False): marker_obj = self._get_marker_obj(context, 'security_group_rule', limit, marker) - return self._get_collection(context, - sg_models.SecurityGroupRule, - self._make_security_group_rule_dict, - filters=filters, fields=fields, - sorts=sorts, - limit=limit, marker_obj=marker_obj, - page_reverse=page_reverse) + return model_query.get_collection(context, + sg_models.SecurityGroupRule, + self._make_security_group_rule_dict, + filters=filters, fields=fields, + sorts=sorts, + limit=limit, marker_obj=marker_obj, + page_reverse=page_reverse) @db_api.retry_if_session_inactive() def get_security_group_rules_count(self, context, filters=None): - return self._get_collection_count(context, sg_models.SecurityGroupRule, - filters=filters) + return model_query.get_collection_count( + context, sg_models.SecurityGroupRule, filters=filters) @db_api.retry_if_session_inactive() def get_security_group_rule(self, context, id, fields=None): @@ -623,7 +625,8 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): def _get_security_group_rule(self, context, id): try: - query = self._model_query(context, sg_models.SecurityGroupRule) + query = model_query.query_with_hooks( + context, sg_models.SecurityGroupRule) sgr = query.filter(sg_models.SecurityGroupRule.id == id).one() except exc.NoResultFound: raise ext_sg.SecurityGroupRuleNotFound(id=id) @@ -640,9 +643,9 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): exc_cls=ext_sg.SecurityGroupRuleInUse, **kwargs) with db_api.context_manager.writer.using(context): - query = self._model_query(context, - sg_models.SecurityGroupRule).filter( - sg_models.SecurityGroupRule.id == id) + query = model_query.query_with_hooks( + context, sg_models.SecurityGroupRule).filter( + sg_models.SecurityGroupRule.id == id) self._registry_notify(resources.SECURITY_GROUP_RULE, events.PRECOMMIT_DELETE, @@ -687,7 +690,8 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): def _get_default_sg_id(self, context, tenant_id): try: - query = self._model_query(context, sg_models.DefaultSecurityGroup) + query = model_query.query_with_hooks( + context, sg_models.DefaultSecurityGroup) default_group = query.filter_by(tenant_id=tenant_id).one() return default_group['security_group_id'] except exc.NoResultFound: diff --git a/neutron/objects/db/api.py b/neutron/objects/db/api.py index 4018e605b6b..80a514eca8a 100644 --- a/neutron/objects/db/api.py +++ b/neutron/objects/db/api.py @@ -14,17 +14,16 @@ # backends from neutron_lib import exceptions as n_exc -from neutron_lib.plugins import directory from oslo_utils import uuidutils +from neutron.db import _model_query as model_query + # Common database operation implementations def _get_filter_query(context, model, **kwargs): - # TODO(jlibosva): decompose _get_collection_query from plugin instance - plugin = directory.get_plugin() with context.session.begin(subtransactions=True): filters = _kwargs_to_filters(**kwargs) - query = plugin._get_collection_query(context, model, filters) + query = model_query.get_collection_query(context, model, filters) return query @@ -44,9 +43,7 @@ def _kwargs_to_filters(**kwargs): def get_objects(context, model, _pager=None, **kwargs): with context.session.begin(subtransactions=True): filters = _kwargs_to_filters(**kwargs) - # TODO(ihrachys): decompose _get_collection from plugin instance - plugin = directory.get_plugin() - return plugin._get_collection( + return model_query.get_collection( context, model, dict_func=None, # return all the data filters=filters, diff --git a/neutron/plugins/ml2/plugin.py b/neutron/plugins/ml2/plugin.py index 9d64d924735..2258724dfb8 100644 --- a/neutron/plugins/ml2/plugin.py +++ b/neutron/plugins/ml2/plugin.py @@ -789,7 +789,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2, registry.notify(resources.NETWORK, events.PRECOMMIT_CREATE, self, context=context, request=net_data, network=result) - self._apply_dict_extend_functions('networks', result, net_db) + resource_extend.apply_funcs('networks', result, net_db) mech_context = driver_context.NetworkContext(self, context, result) self.mechanism_manager.create_network_precommit(mech_context) @@ -1135,7 +1135,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2, self.mechanism_manager.create_port_precommit(mech_context) self._setup_dhcp_agent_provisioning_component(context, result) - self._apply_dict_extend_functions('ports', result, port_db) + resource_extend.apply_funcs('ports', result, port_db) return result, mech_context @utils.transaction_guard diff --git a/neutron/services/tag/tag_plugin.py b/neutron/services/tag/tag_plugin.py index b81e1331e94..5e423f7bf2a 100644 --- a/neutron/services/tag/tag_plugin.py +++ b/neutron/services/tag/tag_plugin.py @@ -73,7 +73,7 @@ class TagPlugin(common_db_mixin.CommonDbMixin, tag_ext.TagPluginBase): def _get_resource(self, context, resource, resource_id): model = resource_model_map[resource] try: - return self._get_by_id(context, model, resource_id) + return model_query.get_by_id(context, model, resource_id) except exc.NoResultFound: raise tag_ext.TagResourceNotFound(resource=resource, resource_id=resource_id) diff --git a/neutron/tests/unit/objects/db/test_api.py b/neutron/tests/unit/objects/db/test_api.py index 16bd68214d2..56cf367f8f8 100644 --- a/neutron/tests/unit/objects/db/test_api.py +++ b/neutron/tests/unit/objects/db/test_api.py @@ -15,8 +15,8 @@ import copy import mock from neutron_lib import context from neutron_lib import exceptions as n_exc -from neutron_lib.plugins import directory +from neutron.db import _model_query as model_query from neutron.db import models_v2 from neutron.objects import base from neutron.objects.db import api @@ -42,8 +42,8 @@ class GetObjectsTestCase(test_base.BaseTestCase): limit = mock.sentinel.limit pager = base.Pager(marker=marker, limit=limit) - plugin = directory.get_plugin() - with mock.patch.object(plugin, '_get_collection') as get_collection: + with mock.patch.object( + model_query, 'get_collection') as get_collection: with mock.patch.object(api, 'get_object') as get_object: api.get_objects(ctxt, model, _pager=pager) get_object.assert_called_with(ctxt, model, id=marker)