Add the corresponding DB context to all SQL transactions

The goal of this patch is to make the Neutron code compliant
with SQLAlchemy 2.0.

All SQL transactions must be executed inside an explicit
writer/reader context. SQLAlchemy no longer will create an
implicit transaction if the session has no active transaction.

A warning message, only available in debug mode, is added. When
an ORM session calls "do_orm_execute", if there is no active
transaction, a warning message with a traceback will be logged
to help to debug the regression introduced.

Related-Bug: #1964575

Change-Id: I3da37fee205b8d67d10673075b9130147d9eab5f
This commit is contained in:
Rodolfo Alonso Hernandez 2022-03-05 14:34:40 +00:00
parent 452a3093f6
commit eeb918e1b9
32 changed files with 275 additions and 158 deletions

View File

@ -77,7 +77,7 @@ def get_availability_zones_by_agent_type(context, agent_type,
availability_zones):
"""Get list of availability zones based on agent type"""
agents = agent_obj.Agent._get_agents_by_availability_zones_and_agent_type(
agents = agent_obj.Agent.get_agents_by_availability_zones_and_agent_type(
context, agent_type=agent_type, availability_zones=availability_zones)
return set(agent.availability_zone for agent in agents)

View File

@ -259,6 +259,7 @@ class DbBasePluginCommon(object):
res.pop('bulk')
return db_utils.resource_fields(res, fields)
@db_api.CONTEXT_READER
def _get_network(self, context, id):
try:
network = model_query.get_by_id(context, models_v2.Network, id)
@ -266,6 +267,7 @@ class DbBasePluginCommon(object):
raise exceptions.NetworkNotFound(net_id=id)
return network
@db_api.CONTEXT_READER
def _network_exists(self, context, network_id):
query = model_query.query_with_hooks(
context, models_v2.Network, field='id')
@ -284,6 +286,7 @@ class DbBasePluginCommon(object):
raise exceptions.SubnetPoolNotFound(subnetpool_id=id)
return subnetpool
@db_api.CONTEXT_READER
def _get_port(self, context, id, lazy_fields=None):
try:
port = model_query.get_by_id(context, models_v2.Port, id,

View File

@ -230,37 +230,39 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
tenant_to_check = policy['target_project']
if tenant_to_check:
self.ensure_no_tenant_ports_on_network(net['id'], net['tenant_id'],
tenant_to_check)
self.ensure_no_tenant_ports_on_network(
context, net['id'], net['tenant_id'], tenant_to_check)
def ensure_no_tenant_ports_on_network(self, network_id, net_tenant_id,
tenant_id):
ctx_admin = ctx.get_admin_context()
ports = model_query.query_with_hooks(ctx_admin, models_v2.Port).filter(
models_v2.Port.network_id == network_id)
if tenant_id == '*':
# for the wildcard we need to get all of the rbac entries to
# see if any allow the remaining ports on the network.
# any port with another RBAC entry covering it or one belonging to
# the same tenant as the network owner is ok
other_rbac_objs = network_obj.NetworkRBAC.get_objects(
ctx_admin, object_id=network_id, action='access_as_shared')
allowed_tenants = [rbac['target_project'] for rbac
in other_rbac_objs
if rbac.target_project != tenant_id]
allowed_tenants.append(net_tenant_id)
ports = ports.filter(
~models_v2.Port.tenant_id.in_(allowed_tenants))
else:
# if there is a wildcard rule, we can return early because it
# allows any ports
if network_obj.NetworkRBAC.get_object(
ctx_admin, object_id=network_id, action='access_as_shared',
target_project='*'):
return
ports = ports.filter(models_v2.Port.project_id == tenant_id)
if ports.count():
raise exc.InvalidSharedSetting(network=network_id)
def ensure_no_tenant_ports_on_network(self, context, network_id,
net_tenant_id, tenant_id):
elevated = context.elevated()
with db_api.CONTEXT_READER.using(elevated):
ports = model_query.query_with_hooks(
elevated, models_v2.Port).filter(
models_v2.Port.network_id == network_id)
if tenant_id == '*':
# for the wildcard we need to get all of the rbac entries to
# see if any allow the remaining ports on the network.
# any port with another RBAC entry covering it or one belonging
# to the same tenant as the network owner is ok
other_rbac_objs = network_obj.NetworkRBAC.get_objects(
elevated, object_id=network_id, action='access_as_shared')
allowed_tenants = [rbac['target_project'] for rbac
in other_rbac_objs
if rbac.target_project != tenant_id]
allowed_tenants.append(net_tenant_id)
ports = ports.filter(
~models_v2.Port.tenant_id.in_(allowed_tenants))
else:
# if there is a wildcard rule, we can return early because it
# allows any ports
if network_obj.NetworkRBAC.get_object(
elevated, object_id=network_id,
action='access_as_shared', target_project='*'):
return
ports = ports.filter(models_v2.Port.project_id == tenant_id)
if ports.count():
raise exc.InvalidSharedSetting(network=network_id)
def set_ipam_backend(self):
self.ipam = ipam_pluggable_backend.IpamPluggableBackend()
@ -487,8 +489,8 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
registry.publish(resources.NETWORK, events.BEFORE_DELETE, self,
payload=events.DBEventPayload(
context, resource_id=id))
self._ensure_network_not_in_use(context, id)
with db_api.CONTEXT_READER.using(context):
self._ensure_network_not_in_use(context, id)
auto_delete_port_ids = [p.id for p in context.session.query(
models_v2.Port.id).filter_by(network_id=id).filter(
models_v2.Port.device_owner.in_(
@ -647,10 +649,9 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
s_gateway_ip != cur_subnet['gateway_ip'] and
not ipv6_utils.is_ipv6_pd_enabled(s)):
gateway_ip = str(cur_subnet['gateway_ip'])
with db_api.CONTEXT_READER.using(context):
alloc = port_obj.IPAllocation.get_alloc_routerports(
context, cur_subnet['id'], gateway_ip=gateway_ip,
first=True)
alloc = port_obj.IPAllocation.get_alloc_routerports(
context, cur_subnet['id'], gateway_ip=gateway_ip,
first=True)
if alloc and alloc.port_id:
raise exc.GatewayIpInUse(
@ -1600,6 +1601,7 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
return query
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def get_ports(self, context, filters=None, fields=None,
sorts=None, limit=None, marker=None,
page_reverse=False):
@ -1619,6 +1621,7 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
return items
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def get_ports_count(self, context, filters=None):
return self._get_ports_query(context, filters).count()

View File

@ -33,6 +33,7 @@ from neutron._i18n import _
from neutron.db import models_v2
from neutron.extensions import rbac as rbac_ext
from neutron.objects import network as net_obj
from neutron.objects import ports as port_obj
from neutron.objects import router as l3_obj
@ -127,9 +128,9 @@ class External_net_db_mixin(object):
# must make sure we do not have any external gateway ports
# (and thus, possible floating IPs) on this network before
# allow it to be update to external=False
if context.session.query(models_v2.Port.id).filter_by(
device_owner=constants.DEVICE_OWNER_ROUTER_GW,
network_id=net_data['id']).first():
if port_obj.Port.count(
context, network_id=net_data['id'],
device_owner=constants.DEVICE_OWNER_ROUTER_GW):
raise extnet_exc.ExternalNetworkInUse(net_id=net_id)
net_obj.ExternalNetwork.delete_objects(
@ -200,10 +201,9 @@ class External_net_db_mixin(object):
if new_project == policy['target_project']:
# nothing to validate if the tenant didn't change
return
gw_ports = context.session.query(models_v2.Port.id).filter_by(
device_owner=constants.DEVICE_OWNER_ROUTER_GW,
network_id=policy['object_id'])
gw_ports = [gw_port[0] for gw_port in gw_ports]
gw_ports = port_obj.Port.get_gateway_port_ids_by_network(
context, policy['object_id'])
if policy['target_project'] != '*':
filters = {
'gw_port_id': gw_ports,

View File

@ -391,10 +391,9 @@ class L3AgentSchedulerDbMixin(l3agentscheduler.L3AgentSchedulerPluginBase,
rb_obj.RouterL3AgentBinding.get_l3_agents_by_router_ids(
context, router_ids))
@db_api.CONTEXT_READER
def list_l3_agents_hosting_router(self, context, router_id):
with db_api.CONTEXT_READER.using(context):
agents = self._get_l3_agents_hosting_routers(
context, [router_id])
agents = self._get_l3_agents_hosting_routers(context, [router_id])
return {'agents': [self._make_agent_dict(agent)
for agent in agents]}

View File

@ -621,6 +621,7 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase,
return self._make_router_dict(router, fields)
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def get_routers(self, context, filters=None, fields=None,
sorts=None, limit=None, marker=None,
page_reverse=False):
@ -635,6 +636,7 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase,
page_reverse=page_reverse)
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def get_routers_count(self, context, filters=None):
return model_query.get_collection_count(
context, l3_models.Router, filters=filters,
@ -1364,7 +1366,8 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase,
fip_id = uuidutils.generate_uuid()
f_net_id = fip['floating_network_id']
f_net_db = self._core_plugin._get_network(context, f_net_id)
with db_api.CONTEXT_READER.using(context):
f_net_db = self._core_plugin._get_network(context, f_net_id)
if not f_net_db.external:
msg = _("Network %s is not a valid external network") % f_net_id
raise n_exc.BadRequest(resource='floatingip', msg=msg)
@ -1834,6 +1837,7 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase,
continue
yield port
@db_api.CONTEXT_READER
def _get_subnets_by_network_list(self, context, network_ids):
if not network_ids:
return {}

View File

@ -561,7 +561,6 @@ class L3_HA_NAT_db_mixin(l3_dvr_db.L3_NAT_with_dvr_db_mixin,
for agent in self.get_l3_agents_hosting_routers(context, [router_id]):
self.remove_router_from_l3_agent(context, agent['id'], router_id)
@db_api.CONTEXT_READER
def get_ha_router_port_bindings(self, context, router_ids, host=None):
if not router_ids:
return []

View File

@ -60,6 +60,7 @@ class IpAvailabilityMixin(object):
total_ips_columns.append(mod.IPAllocationPool.last_ip)
@classmethod
@db_api.CONTEXT_READER
def get_network_ip_availabilities(cls, context, filters=None):
"""Get IP availability stats on a per subnet basis.

View File

@ -42,6 +42,7 @@ def add_model_for_resource(resource, model):
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_WRITER
def add_provisioning_component(context, object_id, object_type, entity):
"""Adds a provisioning block by an entity to a given object.
@ -77,6 +78,7 @@ def add_provisioning_component(context, object_id, object_type, entity):
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_WRITER
def remove_provisioning_component(context, object_id, object_type, entity,
standard_attr_id=None):
"""Remove a provisioning block for an object without triggering a callback.
@ -125,26 +127,30 @@ def provisioning_complete(context, object_id, object_type, entity):
# tricking us into thinking there are remaining provisioning components
if utils.is_session_active(context.session):
raise RuntimeError(_("Must not be called in a transaction"))
standard_attr_id = _get_standard_attr_id(context, object_id,
object_type)
if not standard_attr_id:
return
if remove_provisioning_component(context, object_id, object_type, entity,
standard_attr_id):
LOG.debug("Provisioning for %(otype)s %(oid)s completed by entity "
"%(entity)s.", log_dict)
# now with that committed, check if any records are left. if None, emit
# an event that provisioning is complete.
if not pb_obj.ProvisioningBlock.objects_exist(
context, standard_attr_id=standard_attr_id):
LOG.debug("Provisioning complete for %(otype)s %(oid)s triggered by "
"entity %(entity)s.", log_dict)
registry.publish(object_type, PROVISIONING_COMPLETE, entity,
payload=events.DBEventPayload(
context, resource_id=object_id))
with db_api.CONTEXT_WRITER.using(context):
standard_attr_id = _get_standard_attr_id(context, object_id,
object_type)
if not standard_attr_id:
return
if remove_provisioning_component(context, object_id, object_type,
entity, standard_attr_id):
LOG.debug("Provisioning for %(otype)s %(oid)s completed by entity "
"%(entity)s.", log_dict)
# now with that committed, check if any records are left. if None, emit
# an event that provisioning is complete.
if pb_obj.ProvisioningBlock.objects_exist(
context, standard_attr_id=standard_attr_id):
return
LOG.debug("Provisioning complete for %(otype)s %(oid)s triggered by "
"entity %(entity)s.", log_dict)
registry.publish(object_type, PROVISIONING_COMPLETE, entity,
payload=events.DBEventPayload(
context, resource_id=object_id))
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def is_object_blocked(context, object_id, object_type):
"""Return boolean indicating if object has a provisioning block.

View File

@ -44,6 +44,7 @@ class ReservationInfo(collections.namedtuple(
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def get_quota_usage_by_resource_and_project(context, resource, project_id):
"""Return usage info for a given resource and project.

View File

@ -441,6 +441,7 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
"""Server-side RPC mixin using DB for SG notifications and responses."""
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def _select_sg_ids_for_ports(self, context, ports):
if not ports:
return []
@ -451,6 +452,7 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
return query.all()
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def _select_rules_for_ports(self, context, ports):
if not ports:
return []
@ -467,6 +469,7 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
return query.all()
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def _select_ips_for_remote_group(self, context, remote_group_ids):
ips_by_group = {}
if not remote_group_ids:
@ -507,6 +510,7 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
return ips_by_group
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def _select_ips_for_remote_address_group(self, context,
remote_address_group_ids):
ips_by_group = {}

View File

@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from oslo_versionedobjects import fields as obj_fields
@ -53,6 +54,7 @@ class AddressScope(rbac_db.NeutronRbacObject):
}
@classmethod
@db_api.CONTEXT_READER
def get_network_address_scope(cls, context, network_id, ip_version):
query = context.session.query(cls.db_model)
query = query.join(

View File

@ -13,6 +13,7 @@
# under the License.
from neutron_lib import constants as const
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from neutron_lib.objects import utils as obj_utils
from oslo_utils import versionutils
@ -122,11 +123,10 @@ class Agent(base.NeutronDbObject):
group_by(agent_model.Agent).
filter(agent_model.Agent.id.in_(agent_ids)).
order_by('count'))
agents = [cls._load_object(context, record[0]) for record in query]
return agents
return [cls._load_object(context, record[0]) for record in query]
@classmethod
@db_api.CONTEXT_READER
def get_ha_agents(cls, context, network_id=None, router_id=None):
if not (network_id or router_id):
return []
@ -154,7 +154,8 @@ class Agent(base.NeutronDbObject):
return agents
@classmethod
def _get_agents_by_availability_zones_and_agent_type(
@db_api.CONTEXT_READER
def get_agents_by_availability_zones_and_agent_type(
cls, context, agent_type, availability_zones):
query = context.session.query(
agent_model.Agent).filter_by(

View File

@ -16,12 +16,15 @@ from collections import abc as collections_abc
import copy
import functools
import itertools
import sys
import traceback
from neutron_lib.db import api as db_api
from neutron_lib.db import standard_attr
from neutron_lib import exceptions as n_exc
from neutron_lib.objects import exceptions as o_exc
from neutron_lib.objects.extensions import standardattributes
from oslo_config import cfg
from oslo_db import exception as obj_exc
from oslo_db.sqlalchemy import enginefacade
from oslo_db.sqlalchemy import utils as db_utils
@ -39,10 +42,38 @@ from neutron.objects.db import api as obj_db_api
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
_NO_DB_MODEL = object()
# NOTE(ralonsoh): this is a method evaluated anytime an ORM session is
# executing a SQL transaction.
# If "autocommit" is disabled (the default value in SQLAlchemy 1.4 and the
# only value in SQLAlchemy 2.0) and there is not active transaction, that
# means the SQL transaction is being run on an "implicit transaction". Under
# autocommit, this transaction is created, executed and discarded immediately;
# under non-autocommit, a transaction must be explicitly created
# (writer/reader) and sticks open.
# This evaluation is done only in debug mode to monitor the Neutron code
# compliance to SQLAlchemy 2.0.
def do_orm_execute(orm_execute_state):
if not orm_execute_state.session.in_transaction():
trace_string = '\n'.join(traceback.format_stack(sys._getframe(1)))
LOG.warning('ORM session: SQL execution without transaction in '
'progress, traceback:\n%s', trace_string)
try:
_debug = cfg.CONF.debug
except cfg.NoSuchOptError:
_debug = False
if _debug:
db_api.sqla_listen(orm.Session, 'do_orm_execute', do_orm_execute)
def get_object_class_by_model(model):
for obj_class in NeutronObjectRegistry.obj_classes().values():
obj_class = obj_class[0]
@ -919,6 +950,7 @@ class NeutronDbObject(NeutronObject, metaclass=DeclarativeObject):
self._captured_db_model = None
@classmethod
@db_api.CONTEXT_READER
def count(cls, context, validate_filters=True, **kwargs):
"""Count the number of objects matching filtering criteria.
@ -935,6 +967,7 @@ class NeutronDbObject(NeutronObject, metaclass=DeclarativeObject):
)
@classmethod
@db_api.CONTEXT_READER
def objects_exist(cls, context, validate_filters=True, **kwargs):
"""Check if objects are present in DB.

View File

@ -13,6 +13,7 @@
# TODO(ihrachys): cover the module with functional tests targeting supported
# backends
from neutron_lib.db import api as db_api
from neutron_lib.db import model_query
from neutron_lib import exceptions as n_exc
from neutron_lib.objects import utils as obj_utils
@ -34,6 +35,7 @@ def get_object(obj_cls, context, **kwargs):
return _get_filter_query(obj_cls, context, **kwargs).first()
@db_api.CONTEXT_READER
def count(obj_cls, context, query_field=None, query_limit=None, **kwargs):
if not query_field and obj_cls.primary_keys:
query_field = obj_cls.primary_keys[0]

View File

@ -13,6 +13,7 @@
# under the License.
from neutron_lib import constants
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from oslo_versionedobjects import fields as obj_fields
@ -40,6 +41,7 @@ class L3HARouterAgentPortBinding(base.NeutronDbObject):
fields_no_update = ['router_id', 'port_id']
@classmethod
@db_api.CONTEXT_READER
def get_l3ha_filter_host_router(cls, context, router_ids, host):
query = context.session.query(l3ha.L3HARouterAgentPortBinding)

View File

@ -10,6 +10,7 @@
# License for the specific language governing permissions and limitations
# under the License.
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from oslo_versionedobjects import fields as obj_fields
import sqlalchemy as sa
@ -42,6 +43,7 @@ class RouterL3AgentBinding(base.NeutronDbObject):
# TODO(ihrachys) return OVO objects not models
# TODO(ihrachys) move under Agent object class
@classmethod
@db_api.CONTEXT_READER
def get_l3_agents_by_router_ids(cls, context, router_ids):
query = context.session.query(l3agent.RouterL3AgentBinding)
query = query.options(joinedload('l3_agent')).filter(
@ -49,6 +51,7 @@ class RouterL3AgentBinding(base.NeutronDbObject):
return [db_obj.l3_agent for db_obj in query.all()]
@classmethod
@db_api.CONTEXT_READER
def get_down_router_bindings(cls, context, cutoff):
query = (context.session.query(
l3agent.RouterL3AgentBinding).

View File

@ -15,6 +15,7 @@
import itertools
import netaddr
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from neutron.db.models import l3
@ -128,21 +129,23 @@ class PortForwarding(base.NeutronDbObject):
return result
@classmethod
@db_api.CONTEXT_READER
def get_port_forwarding_obj_by_routers(cls, context, router_ids):
query = context.session.query(cls.db_model, l3.FloatingIP)
query = query.join(l3.FloatingIP,
cls.db_model.floatingip_id == l3.FloatingIP.id)
query = query.filter(l3.FloatingIP.router_id.in_(router_ids))
return cls._unique_port_forwarding_iterator(query)
return cls._unique_port_forwarding(query)
@classmethod
def _unique_port_forwarding_iterator(cls, query):
@staticmethod
def _unique_port_forwarding(query):
q = query.order_by(l3.FloatingIP.router_id)
keyfunc = lambda row: row[1]
group_iterator = itertools.groupby(q, keyfunc)
result = []
for key, value in group_iterator:
for row in value:
yield (row[1]['router_id'], row[1]['floating_ip_address'],
row[0]['id'], row[1]['id'])
result.extend([(row[1]['router_id'], row[1]['floating_ip_address'],
row[0]['id'], row[1]['id']) for row in value])
return result

View File

@ -243,6 +243,7 @@ class IPAllocation(base.NeutronDbObject):
alloc_obj.delete()
@classmethod
@db_api.CONTEXT_READER
def get_alloc_routerports(cls, context, subnet_id, gateway_ip=None,
first=False):
alloc_qry = context.session.query(cls.db_model.port_id)
@ -466,6 +467,7 @@ class Port(base.NeutronDbObject):
return port_array
@classmethod
@db_api.CONTEXT_READER
def get_auto_deletable_port_ids_and_proper_port_count_by_segment(
cls, context, segment_id):
@ -584,6 +586,7 @@ class Port(base.NeutronDbObject):
primitive.pop('device_profile', None)
@classmethod
@db_api.CONTEXT_READER
def get_ports_by_router_and_network(cls, context, router_id, owner,
network_id):
"""Returns port objects filtering by router ID, owner and network ID"""
@ -593,6 +596,7 @@ class Port(base.NeutronDbObject):
rports_filter, router_filter)
@classmethod
@db_api.CONTEXT_READER
def get_ports_by_router_and_port(cls, context, router_id, owner, port_id):
"""Returns port objects filtering by router ID, owner and port ID"""
rports_filter = (l3.RouterPort.port_id == port_id, )
@ -645,6 +649,7 @@ class Port(base.NeutronDbObject):
return ports_rports
@classmethod
@db_api.CONTEXT_READER
def get_ports_ids_by_security_groups(cls, context, security_group_ids,
excluded_device_owners=None):
query = context.session.query(sg_models.SecurityGroupPortBinding)
@ -658,6 +663,7 @@ class Port(base.NeutronDbObject):
return [port_binding['port_id'] for port_binding in query.all()]
@classmethod
@db_api.CONTEXT_READER
def get_ports_by_host(cls, context, host):
query = context.session.query(models_v2.Port.id).join(
ml2_models.PortBinding)
@ -666,6 +672,7 @@ class Port(base.NeutronDbObject):
return [port_id[0] for port_id in query.all()]
@classmethod
@db_api.CONTEXT_READER
def get_ports_by_binding_type_and_host(cls, context,
binding_type, host):
query = context.session.query(models_v2.Port).join(
@ -676,6 +683,7 @@ class Port(base.NeutronDbObject):
return [cls._load_object(context, db_obj) for db_obj in query.all()]
@classmethod
@db_api.CONTEXT_READER
def get_ports_by_vnic_type_and_host(
cls, context, vnic_type, host):
query = context.session.query(models_v2.Port).join(
@ -686,6 +694,7 @@ class Port(base.NeutronDbObject):
return [cls._load_object(context, db_obj) for db_obj in query.all()]
@classmethod
@db_api.CONTEXT_READER
def check_network_ports_by_binding_types(
cls, context, network_id, binding_types, negative_search=False):
"""This method is to check whether networks have ports with given
@ -710,6 +719,7 @@ class Port(base.NeutronDbObject):
return bool(query.count())
@classmethod
@db_api.CONTEXT_READER
def get_ports_allocated_by_subnet_id(cls, context, subnet_id):
"""Return ports with fixed IPs in a subnet"""
return context.session.query(models_v2.Port).filter(
@ -731,3 +741,11 @@ class Port(base.NeutronDbObject):
for _binding in port.bindings:
if _binding.get('profile', {}).get('pci_slot') == pci_slot:
return port
@classmethod
@db_api.CONTEXT_READER
def get_gateway_port_ids_by_network(cls, context, network_id):
gw_ports = context.session.query(models_v2.Port.id).filter_by(
device_owner=constants.DEVICE_OWNER_ROUTER_GW,
network_id=network_id)
return [gw_port[0] for gw_port in gw_ports]

View File

@ -15,6 +15,7 @@
import abc
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from sqlalchemy import and_
from sqlalchemy import exists
@ -55,6 +56,7 @@ class QosPolicyPortBinding(base.NeutronDbObject, _QosPolicyBindingMixin):
_bound_model_id = db_model.port_id
@classmethod
@db_api.CONTEXT_READER
def get_ports_by_network_id(cls, context, network_id, policy_id=None):
query = context.session.query(models_v2.Port).filter(
models_v2.Port.network_id == network_id)
@ -103,6 +105,7 @@ class QosPolicyFloatingIPBinding(base.NeutronDbObject, _QosPolicyBindingMixin):
_bound_model_id = db_model.fip_id
@classmethod
@db_api.CONTEXT_READER
def get_fips_by_network_id(cls, context, network_id, policy_id=None):
"""Return the FIP belonging to a network, filtered by a QoS policy

View File

@ -15,6 +15,7 @@
import abc
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from oslo_utils import versionutils
from oslo_versionedobjects import fields as obj_fields
@ -39,6 +40,7 @@ class RBACBaseObject(base.NeutronDbObject, metaclass=abc.ABCMeta):
fields_no_update = ['id', 'project_id', 'object_id']
@classmethod
@db_api.CONTEXT_READER
def get_projects(cls, context, object_id=None, action=None,
target_project=None):
clauses = []

View File

@ -18,6 +18,7 @@ import itertools
from neutron_lib.callbacks import events
from neutron_lib.callbacks import registry
from neutron_lib.callbacks import resources
from neutron_lib.db import api as db_api
from neutron_lib import exceptions
from sqlalchemy import and_
@ -104,6 +105,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
rbac_db_model.target_project != '*'))))
@classmethod
@db_api.CONTEXT_READER
def _validate_rbac_policy_delete(cls, context, obj_id, target_project):
ctx_admin = context.elevated()
rb_model = cls.rbac_db_cls.db_model
@ -147,13 +149,14 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
if policy['action'] != models.ACCESS_SHARED:
return
target_project = policy['target_project']
db_obj = obj_db_api.get_object(
cls, context.elevated(), id=policy['object_id'])
elevated_context = context.elevated()
with db_api.CONTEXT_READER.using(elevated_context):
db_obj = obj_db_api.get_object(cls, elevated_context,
id=policy['object_id'])
if db_obj.project_id == target_project:
return
cls._validate_rbac_policy_delete(context=context,
obj_id=policy['object_id'],
target_project=target_project)
cls._validate_rbac_policy_delete(context, policy['object_id'],
target_project)
@classmethod
def validate_rbac_policy_create(cls, resource, event, trigger,
@ -199,8 +202,10 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
# (hopefully) melded with this one.
if object_type != cls.rbac_db_cls.db_model.object_type:
return
db_obj = obj_db_api.get_object(
cls, context.elevated(), id=policy['object_id'])
elevated_context = context.elevated()
with db_api.CONTEXT_READER.using(elevated_context):
db_obj = obj_db_api.get_object(cls, elevated_context,
id=policy['object_id'])
if event in (events.BEFORE_CREATE, events.BEFORE_UPDATE):
if (not context.is_admin and
db_obj['project_id'] != context.project_id):
@ -225,23 +230,23 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
def update_shared(self, is_shared_new, obj_id):
admin_context = self.obj_context.elevated()
shared_prev = obj_db_api.get_object(self.rbac_db_cls, admin_context,
object_id=obj_id,
target_project='*',
action=models.ACCESS_SHARED)
is_shared_prev = bool(shared_prev)
if is_shared_prev == is_shared_new:
return
with db_api.CONTEXT_WRITER.using(admin_context):
shared_prev = obj_db_api.get_object(
self.rbac_db_cls, admin_context, object_id=obj_id,
target_project='*', action=models.ACCESS_SHARED)
is_shared_prev = bool(shared_prev)
if is_shared_prev == is_shared_new:
return
# 'shared' goes False -> True
if not is_shared_prev and is_shared_new:
self.attach_rbac(obj_id, self.obj_context.project_id)
return
# 'shared' goes False -> True
if not is_shared_prev and is_shared_new:
self.attach_rbac(obj_id, self.obj_context.project_id)
return
# 'shared' goes True -> False is actually an attempt to delete
# rbac rule for sharing obj_id with target_project = '*'
self._validate_rbac_policy_delete(self.obj_context, obj_id, '*')
return self.obj_context.session.delete(shared_prev)
# 'shared' goes True -> False is actually an attempt to delete
# rbac rule for sharing obj_id with target_project = '*'
self._validate_rbac_policy_delete(self.obj_context, obj_id, '*')
return self.obj_context.session.delete(shared_prev)
def from_db_object(self, db_obj):
self._load_shared(db_obj)

View File

@ -17,6 +17,7 @@ import netaddr
from neutron_lib.api.definitions import availability_zone as az_def
from neutron_lib.api.validators import availability_zone as az_validator
from neutron_lib import constants as n_const
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from neutron_lib.utils import net as net_utils
from oslo_utils import versionutils
@ -108,6 +109,7 @@ class RouterExtraAttributes(base.NeutronDbObject):
return result
@classmethod
@db_api.CONTEXT_READER
def get_router_agents_count(cls, context):
# TODO(sshank): This is pulled out from l3_agentschedulers_db.py
# until a way to handle joins is figured out.
@ -146,6 +148,7 @@ class RouterPort(base.NeutronDbObject):
}
@classmethod
@db_api.CONTEXT_READER
def get_router_ids_by_subnetpool(cls, context, subnetpool_id):
query = context.session.query(l3.RouterPort.router_id)
query = query.join(models_v2.Port)
@ -216,6 +219,7 @@ class Router(base.NeutronDbObject):
fields_no_update = ['project_id']
@classmethod
@db_api.CONTEXT_READER
def check_routers_not_owned_by_projects(cls, context, gw_ports, projects):
"""This method is to check whether routers that aren't owned by
existing projects or not
@ -332,6 +336,7 @@ class FloatingIP(base.NeutronDbObject):
primitive.pop('qos_network_policy_id', None)
@classmethod
@db_api.CONTEXT_READER
def get_scoped_floating_ips(cls, context, router_ids):
query = context.session.query(l3.FloatingIP,
models_v2.SubnetPool.address_scope_id)
@ -366,6 +371,7 @@ class FloatingIP(base.NeutronDbObject):
yield (cls._load_object(context, row[0]), row[1])
@classmethod
@db_api.CONTEXT_READER
def get_disassociated_ids_for_net(cls, context, network_id):
query = context.session.query(cls.db_model.id)
query = query.filter_by(

View File

@ -11,6 +11,7 @@
# under the License.
from neutron_lib import context as context_lib
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from neutron_lib.utils import net as net_utils
from oslo_utils import versionutils
@ -239,11 +240,14 @@ class SecurityGroupRule(base.NeutronDbObject):
- The rule belongs to a security group that belongs to the project_id
"""
context = context_lib.get_admin_context()
query = context.session.query(cls.db_model.id)
query = query.join(
SecurityGroup.db_model,
cls.db_model.security_group_id == SecurityGroup.db_model.id)
clauses = or_(SecurityGroup.db_model.project_id == project_id,
cls.db_model.project_id == project_id)
rule_ids = query.filter(clauses).all()
return [rule_id[0] for rule_id in rule_ids]
# NOTE(ralonsoh): do no use a READER decorator in this method. Elevated
# permissions are needed here.
with db_api.CONTEXT_READER.using(context):
query = context.session.query(cls.db_model.id)
query = query.join(
SecurityGroup.db_model,
cls.db_model.security_group_id == SecurityGroup.db_model.id)
clauses = or_(SecurityGroup.db_model.project_id == project_id,
cls.db_model.project_id == project_id)
rule_ids = query.filter(clauses).all()
return [rule_id[0] for rule_id in rule_ids]

View File

@ -14,6 +14,7 @@
# under the License.
import netaddr
from neutron_lib.db import api as db_api
from neutron_lib.db import model_query
from neutron_lib.objects import common_types
from oslo_versionedobjects import fields as obj_fields
@ -123,21 +124,22 @@ class SubnetPool(rbac_db.NeutronRbacObject):
# Nothing to validate
return
rbac_as_model = rbac_db_models.AddressScopeRBAC
with db_api.CONTEXT_READER.using(context):
rbac_as_model = rbac_db_models.AddressScopeRBAC
# Ensure that target project has access to AS
shared_to_target_project_or_to_all = (
sa.and_(
rbac_as_model.target_project.in_(
["*", policy['target_project']]
),
rbac_as_model.object_id == db_obj["address_scope_id"]
# Ensure that target project has access to AS
shared_to_target_project_or_to_all = (
sa.and_(
rbac_as_model.target_project.in_(
["*", policy['target_project']]
),
rbac_as_model.object_id == db_obj["address_scope_id"]
)
)
)
matching_policies = model_query.query_with_hooks(
context, rbac_db_models.AddressScopeRBAC
).filter(shared_to_target_project_or_to_all).count()
matching_policies = model_query.query_with_hooks(
context, rbac_db_models.AddressScopeRBAC
).filter(shared_to_target_project_or_to_all).count()
if matching_policies == 0:
raise ext_rbac.RbacPolicyInitError(

View File

@ -315,10 +315,9 @@ def _prevent_segment_delete_with_port_bound(resource, event, trigger,
# don't check for network deletes
return
with db_api.CONTEXT_READER.using(payload.context):
auto_delete_port_ids, proper_port_count = port_obj.Port.\
get_auto_deletable_port_ids_and_proper_port_count_by_segment(
payload.context, segment_id=payload.resource_id)
auto_delete_port_ids, proper_port_count = port_obj.Port.\
get_auto_deletable_port_ids_and_proper_port_count_by_segment(
payload.context, segment_id=payload.resource_id)
if proper_port_count:
reason = (_("The segment is still bound with %s port(s)") %

View File

@ -345,44 +345,54 @@ class EndpointTunnelTypeDriver(ML2TunnelTypeDriver):
def get_endpoint_by_host(self, host):
LOG.debug("get_endpoint_by_host() called for host %s", host)
session = db_api.get_reader_session()
return (session.query(self.endpoint_model).
filter_by(host=host).first())
ctx = context.get_admin_context()
with db_api.CONTEXT_READER.using(ctx):
return (ctx.session.query(self.endpoint_model).
filter_by(host=host).first())
def get_endpoint_by_ip(self, ip):
LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
session = db_api.get_reader_session()
return (session.query(self.endpoint_model).
filter_by(ip_address=ip).first())
ctx = context.get_admin_context()
with db_api.CONTEXT_READER.using(ctx):
return (ctx.session.query(self.endpoint_model).
filter_by(ip_address=ip).first())
def delete_endpoint(self, ip):
LOG.debug("delete_endpoint() called for ip %s", ip)
session = db_api.get_writer_session()
session.query(self.endpoint_model).filter_by(ip_address=ip).delete()
ctx = context.get_admin_context()
with db_api.CONTEXT_WRITER.using(ctx):
ctx.session.query(self.endpoint_model).filter_by(
ip_address=ip).delete()
def delete_endpoint_by_host_or_ip(self, host, ip):
LOG.debug("delete_endpoint_by_host_or_ip() called for "
"host %(host)s or %(ip)s", {'host': host, 'ip': ip})
session = db_api.get_writer_session()
session.query(self.endpoint_model).filter(
or_(self.endpoint_model.host == host,
self.endpoint_model.ip_address == ip)).delete()
ctx = context.get_admin_context()
with db_api.CONTEXT_WRITER.using(ctx):
ctx.session.query(self.endpoint_model).filter(
or_(self.endpoint_model.host == host,
self.endpoint_model.ip_address == ip)).delete()
def _get_endpoints(self):
LOG.debug("_get_endpoints() called")
session = db_api.get_reader_session()
return session.query(self.endpoint_model)
ctx = context.get_admin_context()
with db_api.CONTEXT_READER.using(ctx):
return ctx.session.query(self.endpoint_model).all()
def _add_endpoint(self, ip, host, **kwargs):
LOG.debug("_add_endpoint() called for ip %s", ip)
session = db_api.get_writer_session()
ctx = context.get_admin_context()
try:
endpoint = self.endpoint_model(ip_address=ip, host=host, **kwargs)
endpoint.save(session)
with db_api.CONTEXT_WRITER.using(ctx):
endpoint = self.endpoint_model(ip_address=ip, host=host,
**kwargs)
endpoint.save(ctx.session)
except db_exc.DBDuplicateEntry:
endpoint = (session.query(self.endpoint_model).
filter_by(ip_address=ip).one())
LOG.warning("Endpoint with ip %s already exists", ip)
with db_api.CONTEXT_READER.using(ctx):
endpoint = (ctx.session.query(self.endpoint_model).
filter_by(ip_address=ip).one())
LOG.warning("Endpoint with ip %s already exists", ip)
return endpoint

View File

@ -1993,12 +1993,13 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
@utils.transaction_guard
@db_api.retry_if_session_inactive()
def delete_port(self, context, id, l3_port_check=True):
try:
port_db = self._get_port(context, id)
port = self._make_port_dict(port_db)
except exc.PortNotFound:
LOG.debug("The port '%s' was deleted", id)
return
with db_api.CONTEXT_READER.using(context):
try:
port_db = self._get_port(context, id)
port = self._make_port_dict(port_db)
except exc.PortNotFound:
LOG.debug("The port '%s' was deleted", id)
return
self._pre_delete_port(context, id, l3_port_check, port)
# TODO(armax): get rid of the l3 dependency in the with block

View File

@ -268,6 +268,7 @@ class TrackedResource(BaseResource):
# Update quota usage
return self._resync(context, project_id, in_use)
@db_api.CONTEXT_WRITER
def count_used(self, context, project_id, resync_usage=True):
"""Returns the current usage count for the resource.

View File

@ -50,6 +50,7 @@ class TagPlugin(tagging.TagPluginBase):
tags = [tag_db.tag for tag_db in db_data.standard_attr.tags]
response_data['tags'] = tags
@db_api.CONTEXT_READER
def _get_resource(self, context, resource, resource_id):
model = resource_model_map[resource]
try:

View File

@ -176,24 +176,24 @@ class NetworkRBACTestCase(testlib_api.SqlTestCase):
def test_ensure_no_port_in_asterisk(self):
self._create_network(self.tenant_1, self.network_id, True)
self.plugin.ensure_no_tenant_ports_on_network(
self.network_id, self.tenant_1, '*')
self.cxt, self.network_id, self.tenant_1, '*')
def test_ensure_no_port_in_tenant_1(self):
self._create_network(self.tenant_1, self.network_id, True)
self.plugin.ensure_no_tenant_ports_on_network(
self.network_id, self.tenant_1, self.tenant_1)
self.cxt, self.network_id, self.tenant_1, self.tenant_1)
def test_ensure_no_port_in_tenant_2(self):
self._create_network(self.tenant_1, self.network_id, True)
self.plugin.ensure_no_tenant_ports_on_network(
self.network_id, self.tenant_1, self.tenant_2)
self.cxt, self.network_id, self.tenant_1, self.tenant_2)
def test_ensure_port_tenant_1_in_asterisk(self):
self._create_network(self.tenant_1, self.network_id, True)
self._create_subnet(self.tenant_1, self.subnet_1_id, True)
self._create_port(self.tenant_1, self.network_id, self.port_id)
self.plugin.ensure_no_tenant_ports_on_network(
self.network_id, self.tenant_1, '*')
self.cxt, self.network_id, self.tenant_1, '*')
def test_ensure_port_tenant_2_in_asterisk(self):
self._create_network(self.tenant_1, self.network_id, True)
@ -201,21 +201,21 @@ class NetworkRBACTestCase(testlib_api.SqlTestCase):
self._create_port(self.tenant_2, self.network_id, self.port_id)
self.assertRaises(n_exc.InvalidSharedSetting,
self.plugin.ensure_no_tenant_ports_on_network,
self.network_id, self.tenant_1, '*')
self.cxt, self.network_id, self.tenant_1, '*')
def test_ensure_port_tenant_1_in_tenant_1(self):
self._create_network(self.tenant_1, self.network_id, True)
self._create_subnet(self.tenant_1, self.subnet_1_id, True)
self._create_port(self.tenant_1, self.network_id, self.port_id)
self.plugin.ensure_no_tenant_ports_on_network(
self.network_id, self.tenant_1, self.tenant_1)
self.cxt, self.network_id, self.tenant_1, self.tenant_1)
def test_ensure_no_share_port_tenant_2_in_tenant_1(self):
self._create_network(self.tenant_1, self.network_id, False)
self._create_subnet(self.tenant_1, self.subnet_1_id, True)
self._create_port(self.tenant_2, self.network_id, self.port_id)
self.plugin.ensure_no_tenant_ports_on_network(
self.network_id, self.tenant_1, self.tenant_1)
self.cxt, self.network_id, self.tenant_1, self.tenant_1)
def test_ensure_no_share_port_tenant_2_in_tenant_2(self):
self._create_network(self.tenant_1, self.network_id, False)
@ -223,4 +223,5 @@ class NetworkRBACTestCase(testlib_api.SqlTestCase):
self._create_port(self.tenant_2, self.network_id, self.port_id)
self.assertRaises(n_exc.InvalidSharedSetting,
self.plugin.ensure_no_tenant_ports_on_network,
self.network_id, self.tenant_1, self.tenant_2)
self.cxt, self.network_id, self.tenant_1,
self.tenant_2)

View File

@ -249,9 +249,7 @@ class RbacNeutronDbObjectTestCase(test_rbac.RBACBaseObjectIfaceTestCase,
'_get_projects_with_shared_access_to_db_obj') as sh_tids:
get_rbac_entries_mock.filter.return_value.count.return_value = 0
self._test_class._validate_rbac_policy_delete(
context=context,
obj_id='fake_obj_id',
target_project='fake_tid1')
context, 'fake_obj_id', 'fake_tid1')
sh_tids.assert_not_called()
@mock.patch.object(_test_class, '_get_db_obj_rbac_entries')