Allow objects to opt in to the new engine facade

The new facade is enabled by setting new_facade = True on the object
class of interest. With new_facade on, all OVO actions use the new
reader / writer decorators to activate sessions.
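
For example, as documented in the devref update included in this
patch:

    @obj_base.VersionedObjectRegistry.register
    class ExampleObject(base.NeutronDbObject):
        new_facade = True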

Two new facade decorators are added to OVO: db_context_reader and
db_context_writer. They should be used instead of explicit
autonested_transaction / reader.using / writer.using in OVO context.
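
The gist of the two helpers, as added to NeutronDbObject in
neutron/objects/base.py (db_api here is neutron.db.api):

    @classmethod
    def db_context_writer(cls, context):
        """Return read-write session activation decorator."""
        if cls.new_facade:
            return db_api.context_manager.writer.using(context)
        return db_api.autonested_transaction(context.session)

    @classmethod
    def db_context_reader(cls, context):
        """Return read-only session activation decorator."""
        if cls.new_facade:
            return db_api.context_manager.reader.using(context)
        return db_api.autonested_transaction(context.session)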

All neutron.objects.db.api helpers now receive OVO classes / objects
instead of model classes, since they need to know which type of engine
facade to use for which object. While this changes the signatures of
those helper functions, they are not used anywhere outside the neutron
tree except in vmware-nsx unit tests, and the latter pass anyway
because those tests completely mock the helpers out, disregarding
their signatures.
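
For example, the marker fetch in Pager.to_kwargs changes from

    obj_db_api.get_object(context, model, id=self.marker)

to

    obj_db_api.get_object(obj_cls, context, id=self.marker)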

This patch also adds several new OVO objects so that the
neutron.objects.db.api helpers can still be used to persist models
that previously didn't have corresponding OVO classes.
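
One of the added objects, NetworkRBAC in neutron/objects/network.py,
simply wraps the existing rbac_db_models.NetworkRBAC model:

    @base.NeutronObjectRegistry.register
    class NetworkRBAC(base.NeutronDbObject):
        # Version 1.0: Initial version
        VERSION = '1.0'

        db_model = rbac_db_models.NetworkRBAC

        fields = {
            'object_id': obj_fields.StringField(),
            'target_tenant': obj_fields.StringField(),
            'action': obj_fields.StringField(),
        }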

Finally, the patch adds registration of missing options in
neutron/tests/unit/extensions/test_qos_fip.py to make it possible to
debug failures in those unit tests. Strictly speaking, this change
doesn't belong in the patch, but it is included nevertheless to speed
up the merge this close to release.

There are several non-obvious changes included, specifically:

- in neutron.objects.base, the decorator() that refreshes / expunges
models from the active session now opens a subtransaction for the
whole span of call / refresh / expunge, so that we can safely refresh
the model regardless of whether the caller opened another parent
subtransaction (this was not the case for create_subnetpool in the
base db plugin code).
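
The relevant part of the decorator now reads (excerpt):

    with self.db_context_writer(self.obj_context):
        res = func(self, *args, **kwargs)
        # some relationship based fields may be changed since we captured
        # the model, let's refresh it for the latest database state
        if synthetic_changed:
            self.obj_context.session.refresh(self.db_obj)
        # detach the model so that consequent fetches don't reuse it
        self.obj_context.session.expunge(self.db_obj)
        return res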

- in neutron.db.l3_fip_qos, removed code that updated the obj.db_obj
relationship directly after the corresponding insertions for the child
policy binding model. This code is not needed because the only caller
of the _process_extra_fip_qos_update method refetches the latest state
of the floating IP OVO object anyway, and the code triggered several
unit test failures.

- unit tests checking that a single commit happens for get_object and
get_objects are no longer valid for new facade objects, which use the
reader decorator that closes the connection instead of committing.
This change is intended, so the unit tests were tweaked to check for
close for new facade objects.
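
The tweaked tests pick the expected transaction exit call per facade
(excerpt from the base object test case in this patch):

    def _get_ro_txn_exit_func_name(self):
        # old facade commits even for getters; the new facade closes
        # the connection instead
        return (
            SQLALCHEMY_CLOSE
            if self._test_class.new_facade else SQLALCHEMY_COMMIT)

    def test_get_objects_single_transaction(self):
        with mock.patch(self._get_ro_txn_exit_func_name()) as mock_exit:
            self._test_class.get_objects(self.context)
        self.assertEqual(1, mock_exit.call_count)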

Change-Id: I15ec238c18a464f977f7d1079605b82965052311
Related-Bug: #1746996
Ihar Hrachyshka 2018-02-06 18:12:38 -08:00 committed by Armando Migliaccio
parent a0d705c0e8
commit 6f83466307
25 changed files with 425 additions and 255 deletions

View File

@ -322,6 +322,43 @@ model, the nullable parameter is by default :code:`True`, while for OVO fields,
the nullable is set to :code:`False`. Make sure you correctly map database
column nullability properties to relevant object fields.
Database session activation
---------------------------
By default, all objects use the old ``oslo.db`` engine facade. To enable the new
facade for a particular object, set the ``new_facade`` class attribute to ``True``:
.. code-block:: Python
@obj_base.VersionedObjectRegistry.register
class ExampleObject(base.NeutronDbObject):
new_facade = True
This will make all OVO actions - ``get_object``, ``update``, ``count``, etc. -
use the new ``reader.using`` or ``writer.using`` decorators to manage database
transactions.
Whenever you need to open a new subtransaction in the scope of OVO code, use
the following database session decorators:
.. code-block:: Python
@obj_base.VersionedObjectRegistry.register
class ExampleObject(base.NeutronDbObject):
@classmethod
def get_object(cls, context, **kwargs):
with cls.db_context_reader(context):
super(ExampleObject, cls).get_object(context, **kwargs)
# fetch more data in the same transaction
def create(self):
with self.db_context_writer(self.obj_context):
super(ExampleObject, self).create()
# apply more changes in the same transaction
The ``db_context_reader`` and ``db_context_writer`` decorators abstract the
choice of engine facade used for a particular object away from the action
implementation.
Synthetic fields
----------------

View File

@ -44,8 +44,7 @@ class FloatingQoSDbMixin(object):
def _create_fip_qos_db(self, context, fip_id, policy_id):
policy = self._get_policy_obj(context, policy_id)
policy.attach_floatingip(fip_id)
binding_db_obj = obj_db_api.get_object(
context, policy.fip_binding_model, fip_id=fip_id)
binding_db_obj = obj_db_api.get_object(policy, context, fip_id=fip_id)
return binding_db_obj
def _delete_fip_qos_db(self, context, fip_id, policy_id):
@ -73,14 +72,7 @@ class FloatingQoSDbMixin(object):
self._delete_fip_qos_db(context,
floatingip_obj['id'],
old_qos_policy_id)
if floatingip_obj.db_obj.qos_policy_binding:
floatingip_obj.db_obj.qos_policy_binding['policy_id'] = (
new_qos_policy_id)
if not new_qos_policy_id:
return
qos_policy_binding = self._create_fip_qos_db(
context,
floatingip_obj['id'],
new_qos_policy_id)
if not floatingip_obj.db_obj.qos_policy_binding:
floatingip_obj.db_obj.qos_policy_binding = qos_policy_binding
self._create_fip_qos_db(
context, floatingip_obj['id'], new_qos_policy_id)

View File

@ -89,7 +89,7 @@ class Agent(base.NeutronDbObject):
@classmethod
def get_l3_agent_with_min_routers(cls, context, agent_ids):
"""Return l3 agent with the least number of routers."""
with context.session.begin(subtransactions=True):
with cls.db_context_reader(context):
query = context.session.query(
agent_model.Agent,
func.count(
@ -105,7 +105,7 @@ class Agent(base.NeutronDbObject):
@classmethod
def get_l3_agents_ordered_by_num_routers(cls, context, agent_ids):
with context.session.begin(subtransactions=True):
with cls.db_context_reader(context):
query = (context.session.query(agent_model.Agent, func.count(
rb_model.RouterL3AgentBinding.router_id)
.label('count')).

View File

@ -81,7 +81,7 @@ class Pager(object):
self.page_reverse = page_reverse
self.marker = marker
def to_kwargs(self, context, model):
def to_kwargs(self, context, obj_cls):
res = {
attr: getattr(self, attr)
for attr in ('sorts', 'limit', 'page_reverse')
@ -89,7 +89,7 @@ class Pager(object):
}
if self.marker and self.limit:
res['marker_obj'] = obj_db_api.get_object(
context, model, id=self.marker)
obj_cls, context, id=self.marker)
return res
def __str__(self):
@ -310,16 +310,16 @@ def _detach_db_obj(func):
@functools.wraps(func)
def decorator(self, *args, **kwargs):
synthetic_changed = bool(self._get_changed_synthetic_fields())
res = func(self, *args, **kwargs)
# some relationship based fields may be changed since we
# captured the model, let's refresh it for the latest database
# state
if synthetic_changed:
# TODO(ihrachys) consider refreshing just changed attributes
self.obj_context.session.refresh(self.db_obj)
# detach the model so that consequent fetches don't reuse it
self.obj_context.session.expunge(self.db_obj)
return res
with self.db_context_writer(self.obj_context):
res = func(self, *args, **kwargs)
# some relationship based fields may be changed since we captured
# the model, let's refresh it for the latest database state
if synthetic_changed:
# TODO(ihrachys) consider refreshing just changed attributes
self.obj_context.session.refresh(self.db_obj)
# detach the model so that consequent fetches don't reuse it
self.obj_context.session.expunge(self.db_obj)
return res
return decorator
@ -390,6 +390,12 @@ class NeutronDbObject(NeutronObject):
# should be overridden for all persistent objects
db_model = None
# should be overridden for all rbac aware objects
rbac_db_cls = None
# whether to use new engine facade for the object
new_facade = False
primary_keys = ['id']
# 'unique_keys' is a list of unique keys that can be used with get_object
@ -512,6 +518,20 @@ class NeutronDbObject(NeutronObject):
if is_attr_nullable:
self[attrname] = None
@classmethod
def db_context_writer(cls, context):
"""Return read-write session activation decorator."""
if cls.new_facade:
return db_api.context_manager.writer.using(context)
return db_api.autonested_transaction(context.session)
@classmethod
def db_context_reader(cls, context):
"""Return read-only session activation decorator."""
if cls.new_facade:
return db_api.context_manager.reader.using(context)
return db_api.autonested_transaction(context.session)
@classmethod
def get_object(cls, context, **kwargs):
"""
@ -529,11 +549,9 @@ class NeutronDbObject(NeutronObject):
raise o_exc.NeutronPrimaryKeyMissing(object_class=cls,
missing_keys=missing_keys)
with context.session.begin(subtransactions=True):
with cls.db_context_reader(context):
db_obj = obj_db_api.get_object(
context, cls.db_model,
**cls.modify_fields_to_db(kwargs)
)
cls, context, **cls.modify_fields_to_db(kwargs))
if db_obj:
return cls._load_object(context, db_obj)
@ -553,11 +571,9 @@ class NeutronDbObject(NeutronObject):
"""
if validate_filters:
cls.validate_filters(**kwargs)
with context.session.begin(subtransactions=True):
with cls.db_context_reader(context):
db_objs = obj_db_api.get_objects(
context, cls.db_model, _pager=_pager,
**cls.modify_fields_to_db(kwargs)
)
cls, context, _pager=_pager, **cls.modify_fields_to_db(kwargs))
return [cls._load_object(context, db_obj) for db_obj in db_objs]
@classmethod
@ -582,9 +598,9 @@ class NeutronDbObject(NeutronObject):
return super(NeutronDbObject, cls).update_object(
context, values, validate_filters=False, **kwargs)
else:
with db_api.autonested_transaction(context.session):
with cls.db_context_writer(context):
db_obj = obj_db_api.update_object(
context, cls.db_model,
cls, context,
cls.modify_fields_to_db(values),
**cls.modify_fields_to_db(kwargs))
return cls._load_object(context, db_obj)
@ -604,15 +620,14 @@ class NeutronDbObject(NeutronObject):
if validate_filters:
cls.validate_filters(**kwargs)
# if we have standard attributes, we will need to fetch records to
# update revision numbers
if cls.has_standard_attributes():
return super(NeutronDbObject, cls).update_objects(
context, values, validate_filters=False, **kwargs)
with db_api.autonested_transaction(context.session):
with cls.db_context_writer(context):
# if we have standard attributes, we will need to fetch records to
# update revision numbers
if cls.has_standard_attributes():
return super(NeutronDbObject, cls).update_objects(
context, values, validate_filters=False, **kwargs)
return obj_db_api.update_objects(
context, cls.db_model,
cls, context,
cls.modify_fields_to_db(values),
**cls.modify_fields_to_db(kwargs))
@ -629,9 +644,9 @@ class NeutronDbObject(NeutronObject):
"""
if validate_filters:
cls.validate_filters(**kwargs)
with context.session.begin(subtransactions=True):
with cls.db_context_writer(context):
return obj_db_api.delete_objects(
context, cls.db_model, **cls.modify_fields_to_db(kwargs))
cls, context, **cls.modify_fields_to_db(kwargs))
@classmethod
def is_accessible(cls, context, db_obj):
@ -748,11 +763,10 @@ class NeutronDbObject(NeutronObject):
def create(self):
fields = self._get_changed_persistent_fields()
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
try:
db_obj = obj_db_api.create_object(
self.obj_context, self.db_model,
self.modify_fields_to_db(fields))
self, self.obj_context, self.modify_fields_to_db(fields))
except obj_exc.DBDuplicateEntry as db_exc:
raise o_exc.NeutronDbObjectDuplicateEntry(
object_class=self.__class__, db_exception=db_exc)
@ -786,16 +800,16 @@ class NeutronDbObject(NeutronObject):
updates = self._get_changed_persistent_fields()
updates = self._validate_changed_fields(updates)
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
db_obj = obj_db_api.update_object(
self.obj_context, self.db_model,
self, self.obj_context,
self.modify_fields_to_db(updates),
**self.modify_fields_to_db(
self._get_composite_keys()))
self.from_db_object(db_obj)
def delete(self):
obj_db_api.delete_object(self.obj_context, self.db_model,
obj_db_api.delete_object(self, self.obj_context,
**self.modify_fields_to_db(
self._get_composite_keys()))
self._captured_db_model = None
@ -814,7 +828,7 @@ class NeutronDbObject(NeutronObject):
if validate_filters:
cls.validate_filters(**kwargs)
return obj_db_api.count(
context, cls.db_model, **cls.modify_fields_to_db(kwargs)
cls, context, **cls.modify_fields_to_db(kwargs)
)
@classmethod
@ -832,5 +846,5 @@ class NeutronDbObject(NeutronObject):
cls.validate_filters(**kwargs)
# Succeed if at least a single object matches; no need to fetch more
return bool(obj_db_api.get_object(
context, cls.db_model, **cls.modify_fields_to_db(kwargs))
cls, context, **cls.modify_fields_to_db(kwargs))
)

View File

@ -21,19 +21,20 @@ from neutron.objects import utils as obj_utils
# Common database operation implementations
def _get_filter_query(context, model, **kwargs):
with context.session.begin(subtransactions=True):
def _get_filter_query(obj_cls, context, **kwargs):
with obj_cls.db_context_reader(context):
filters = _kwargs_to_filters(**kwargs)
query = model_query.get_collection_query(context, model, filters)
query = model_query.get_collection_query(
context, obj_cls.db_model, filters)
return query
def get_object(context, model, **kwargs):
return _get_filter_query(context, model, **kwargs).first()
def get_object(obj_cls, context, **kwargs):
return _get_filter_query(obj_cls, context, **kwargs).first()
def count(context, model, **kwargs):
return _get_filter_query(context, model, **kwargs).count()
def count(obj_cls, context, **kwargs):
return _get_filter_query(obj_cls, context, **kwargs).count()
def _kwargs_to_filters(**kwargs):
@ -42,77 +43,80 @@ def _kwargs_to_filters(**kwargs):
for k, v in kwargs.items()}
def get_objects(context, model, _pager=None, **kwargs):
with context.session.begin(subtransactions=True):
def get_objects(obj_cls, context, _pager=None, **kwargs):
with obj_cls.db_context_reader(context):
filters = _kwargs_to_filters(**kwargs)
return model_query.get_collection(
context, model,
context, obj_cls.db_model,
dict_func=None, # return all the data
filters=filters,
**(_pager.to_kwargs(context, model) if _pager else {}))
**(_pager.to_kwargs(context, obj_cls) if _pager else {}))
def create_object(context, model, values, populate_id=True):
with context.session.begin(subtransactions=True):
if populate_id and 'id' not in values and hasattr(model, 'id'):
def create_object(obj_cls, context, values, populate_id=True):
with obj_cls.db_context_writer(context):
if (populate_id and
'id' not in values and
hasattr(obj_cls.db_model, 'id')):
values['id'] = uuidutils.generate_uuid()
db_obj = model(**values)
db_obj = obj_cls.db_model(**values)
context.session.add(db_obj)
return db_obj
def _safe_get_object(context, model, **kwargs):
db_obj = get_object(context, model, **kwargs)
def _safe_get_object(obj_cls, context, **kwargs):
db_obj = get_object(obj_cls, context, **kwargs)
if db_obj is None:
key = ", ".join(['%s=%s' % (key, value) for (key, value)
in kwargs.items()])
raise n_exc.ObjectNotFound(id="%s(%s)" % (model.__name__, key))
raise n_exc.ObjectNotFound(
id="%s(%s)" % (obj_cls.db_model.__name__, key))
return db_obj
def update_object(context, model, values, **kwargs):
with context.session.begin(subtransactions=True):
db_obj = _safe_get_object(context, model, **kwargs)
def update_object(obj_cls, context, values, **kwargs):
with obj_cls.db_context_writer(context):
db_obj = _safe_get_object(obj_cls, context, **kwargs)
db_obj.update(values)
db_obj.save(session=context.session)
return db_obj
def delete_object(context, model, **kwargs):
with context.session.begin(subtransactions=True):
db_obj = _safe_get_object(context, model, **kwargs)
def delete_object(obj_cls, context, **kwargs):
with obj_cls.db_context_writer(context):
db_obj = _safe_get_object(obj_cls, context, **kwargs)
context.session.delete(db_obj)
def update_objects(context, model, values, **kwargs):
def update_objects(obj_cls, context, values, **kwargs):
'''Update matching objects, if any. Return number of updated objects.
This function does not raise exceptions if nothing matches.
:param model: SQL model
:param obj_cls: Object class
:param values: values to update in matching objects
:param kwargs: multiple filters defined by key=value pairs
:return: Number of entries updated
'''
with context.session.begin(subtransactions=True):
with obj_cls.db_context_writer(context):
if not values:
return count(context, model, **kwargs)
q = _get_filter_query(context, model, **kwargs)
return count(obj_cls, context, **kwargs)
q = _get_filter_query(obj_cls, context, **kwargs)
return q.update(values, synchronize_session=False)
def delete_objects(context, model, **kwargs):
def delete_objects(obj_cls, context, **kwargs):
'''Delete matching objects, if any. Return number of deleted objects.
This function does not raise exceptions if nothing matches.
:param model: SQL model
:param obj_cls: Object class
:param kwargs: multiple filters defined by key=value pairs
:return: Number of entries deleted
'''
with context.session.begin(subtransactions=True):
db_objs = get_objects(context, model, **kwargs)
with obj_cls.db_context_writer(context):
db_objs = get_objects(obj_cls, context, **kwargs)
for db_obj in db_objs:
context.session.delete(db_obj)
return len(db_objs)

View File

@ -16,7 +16,6 @@ from neutron_lib.api.definitions import availability_zone as az_def
from neutron_lib.api.validators import availability_zone as az_validator
from oslo_versionedobjects import fields as obj_fields
from neutron.db import api as db_api
from neutron.db.models import dns as dns_models
from neutron.db.models import external_net as ext_net_model
from neutron.db.models import segment as segment_model
@ -32,6 +31,20 @@ from neutron.objects.qos import binding
from neutron.objects import rbac_db
@base.NeutronObjectRegistry.register
class NetworkRBAC(base.NeutronDbObject):
# Version 1.0: Initial version
VERSION = '1.0'
db_model = rbac_db_models.NetworkRBAC
fields = {
'object_id': obj_fields.StringField(),
'target_tenant': obj_fields.StringField(),
'action': obj_fields.StringField(),
}
@base.NeutronObjectRegistry.register
class NetworkDhcpAgentBinding(base.NeutronDbObject):
# Version 1.0: Initial version
@ -86,7 +99,7 @@ class NetworkSegment(base.NeutronDbObject):
def create(self):
fields = self.obj_get_changes()
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
hosts = self.hosts
if hosts is None:
hosts = []
@ -96,7 +109,7 @@ class NetworkSegment(base.NeutronDbObject):
def update(self):
fields = self.obj_get_changes()
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
super(NetworkSegment, self).update()
if 'hosts' in fields:
self._attach_hosts(fields['hosts'])
@ -176,7 +189,7 @@ class Network(rbac_db.NeutronRbacObject):
# Version 1.0: Initial version
VERSION = '1.0'
rbac_db_model = rbac_db_models.NetworkRBAC
rbac_db_cls = NetworkRBAC
db_model = models_v2.Network
fields = {
@ -223,7 +236,7 @@ class Network(rbac_db.NeutronRbacObject):
def create(self):
fields = self.obj_get_changes()
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
dns_domain = self.dns_domain
qos_policy_id = self.qos_policy_id
super(Network, self).create()
@ -234,7 +247,7 @@ class Network(rbac_db.NeutronRbacObject):
def update(self):
fields = self.obj_get_changes()
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
super(Network, self).update()
if 'dns_domain' in fields:
self._set_dns_domain(fields['dns_domain'])

View File

@ -18,7 +18,6 @@ from oslo_utils import versionutils
from oslo_versionedobjects import fields as obj_fields
from neutron.common import utils
from neutron.db import api as db_api
from neutron.db.models import dns as dns_models
from neutron.db.models import l3
from neutron.db.models import securitygroup as sg_models
@ -234,6 +233,22 @@ class PortDNS(base.NeutronDbObject):
primitive.pop('dns_domain', None)
@base.NeutronObjectRegistry.register
class SecurityGroupPortBinding(base.NeutronDbObject):
# Version 1.0: Initial version
VERSION = '1.0'
db_model = sg_models.SecurityGroupPortBinding
fields = {
'port_id': common_types.UUIDField(),
'security_group_id': common_types.UUIDField(),
}
primary_keys = ['port_id', 'security_group_id']
@base.NeutronObjectRegistry.register
class Port(base.NeutronDbObject):
# Version 1.0: Initial version
@ -318,7 +333,7 @@ class Port(base.NeutronDbObject):
def create(self):
fields = self.obj_get_changes()
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
sg_ids = self.security_group_ids
if sg_ids is None:
sg_ids = set()
@ -331,7 +346,7 @@ class Port(base.NeutronDbObject):
def update(self):
fields = self.obj_get_changes()
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
super(Port, self).update()
if 'security_group_ids' in fields:
self._attach_security_groups(fields['security_group_ids'])
@ -353,9 +368,7 @@ class Port(base.NeutronDbObject):
# TODO(ihrachys): consider introducing an (internal) object for the
# binding to decouple database operations a bit more
obj_db_api.delete_objects(
self.obj_context, sg_models.SecurityGroupPortBinding,
port_id=self.id,
)
SecurityGroupPortBinding, self.obj_context, port_id=self.id)
if sg_ids:
for sg_id in sg_ids:
self._attach_security_group(sg_id)
@ -364,7 +377,7 @@ class Port(base.NeutronDbObject):
def _attach_security_group(self, sg_id):
obj_db_api.create_object(
self.obj_context, sg_models.SecurityGroupPortBinding,
SecurityGroupPortBinding, self.obj_context,
{'port_id': self.id, 'security_group_id': sg_id}
)

View File

@ -22,11 +22,10 @@ from oslo_versionedobjects import exception
from oslo_versionedobjects import fields as obj_fields
from neutron.common import exceptions
from neutron.db import api as db_api
from neutron.db.models import l3
from neutron.db import models_v2
from neutron.db.qos import models as qos_db_model
from neutron.db.rbac_db_models import QosPolicyRBAC
from neutron.db import rbac_db_models
from neutron.objects import base as base_db
from neutron.objects import common_types
from neutron.objects.db import api as obj_db_api
@ -35,6 +34,20 @@ from neutron.objects.qos import rule as rule_obj_impl
from neutron.objects import rbac_db
@base_db.NeutronObjectRegistry.register
class QosPolicyRBAC(base_db.NeutronDbObject):
# Version 1.0: Initial version
VERSION = '1.0'
db_model = rbac_db_models.QosPolicyRBAC
fields = {
'object_id': obj_fields.StringField(),
'target_tenant': obj_fields.StringField(),
'action': obj_fields.StringField(),
}
@base_db.NeutronObjectRegistry.register
class QosPolicy(rbac_db.NeutronRbacObject):
# Version 1.0: Initial version
@ -48,13 +61,9 @@ class QosPolicy(rbac_db.NeutronRbacObject):
VERSION = '1.7'
# required by RbacNeutronMetaclass
rbac_db_model = QosPolicyRBAC
rbac_db_cls = QosPolicyRBAC
db_model = qos_db_model.QosPolicy
port_binding_model = qos_db_model.QosPortPolicyBinding
network_binding_model = qos_db_model.QosNetworkPolicyBinding
fip_binding_model = qos_db_model.QosFIPPolicyBinding
fields = {
'id': common_types.UUIDField(),
'project_id': obj_fields.StringField(),
@ -82,7 +91,7 @@ class QosPolicy(rbac_db.NeutronRbacObject):
return super(QosPolicy, self).obj_load_attr(attrname)
def _reload_rules(self):
rules = rule_obj_impl.get_rules(self.obj_context, self.id)
rules = rule_obj_impl.get_rules(self, self.obj_context, self.id)
setattr(self, 'rules', rules)
self.obj_reset_changes(['rules'])
@ -121,7 +130,7 @@ class QosPolicy(rbac_db.NeutronRbacObject):
# We want to get the policy regardless of its tenant id. We'll make
# sure the tenant has permission to access the policy later on.
admin_context = context.elevated()
with db_api.autonested_transaction(admin_context.session):
with cls.db_context_reader(admin_context):
policy_obj = super(QosPolicy, cls).get_object(admin_context,
**kwargs)
if (not policy_obj or
@ -138,7 +147,7 @@ class QosPolicy(rbac_db.NeutronRbacObject):
# We want to get the policy regardless of its tenant id. We'll make
# sure the tenant has permission to access the policy later on.
admin_context = context.elevated()
with db_api.autonested_transaction(admin_context.session):
with cls.db_context_reader(admin_context):
objs = super(QosPolicy, cls).get_objects(admin_context, _pager,
validate_filters,
**kwargs)
@ -152,37 +161,38 @@ class QosPolicy(rbac_db.NeutronRbacObject):
return result
@classmethod
def _get_object_policy(cls, context, model, **kwargs):
with db_api.autonested_transaction(context.session):
binding_db_obj = obj_db_api.get_object(context, model, **kwargs)
def _get_object_policy(cls, context, binding_cls, **kwargs):
with cls.db_context_reader(context):
binding_db_obj = obj_db_api.get_object(binding_cls, context,
**kwargs)
if binding_db_obj:
return cls.get_object(context, id=binding_db_obj['policy_id'])
@classmethod
def get_network_policy(cls, context, network_id):
return cls._get_object_policy(context, cls.network_binding_model,
return cls._get_object_policy(context, binding.QosPolicyNetworkBinding,
network_id=network_id)
@classmethod
def get_port_policy(cls, context, port_id):
return cls._get_object_policy(context, cls.port_binding_model,
return cls._get_object_policy(context, binding.QosPolicyPortBinding,
port_id=port_id)
@classmethod
def get_fip_policy(cls, context, fip_id):
return cls._get_object_policy(context, cls.fip_binding_model,
fip_id=fip_id)
return cls._get_object_policy(
context, binding.QosPolicyFloatingIPBinding, fip_id=fip_id)
# TODO(QoS): Consider extending base to trigger registered methods for us
def create(self):
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
super(QosPolicy, self).create()
if self.is_default:
self.set_default()
self.obj_load_attr('rules')
def update(self):
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
if 'is_default' in self.obj_what_changed():
if self.is_default:
self.set_default()
@ -191,7 +201,7 @@ class QosPolicy(rbac_db.NeutronRbacObject):
super(QosPolicy, self).update()
def delete(self):
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
for object_type, obj_class in self.binding_models.items():
pager = base_db.Pager(limit=1)
binding_obj = obj_class.get_objects(self.obj_context,
@ -322,7 +332,7 @@ class QosPolicy(rbac_db.NeutronRbacObject):
fip = l3.FloatingIP
qosfip = qos_db_model.QosFIPPolicyBinding
bound_tenants = []
with db_api.autonested_transaction(context.session):
with cls.db_context_reader(context):
bound_tenants.extend(cls._get_bound_tenant_ids(
context.session, qosnet, net, qosnet.network_id, policy_id))
bound_tenants.extend(

View File

@ -24,7 +24,6 @@ from oslo_versionedobjects import exception
from oslo_versionedobjects import fields as obj_fields
import six
from neutron.db import api as db_api
from neutron.db.qos import models as qos_db_model
from neutron.objects import base
from neutron.objects import common_types
@ -32,9 +31,9 @@ from neutron.objects import common_types
DSCP_MARK = 'dscp_mark'
def get_rules(context, qos_policy_id):
def get_rules(obj_cls, context, qos_policy_id):
all_rules = []
with db_api.autonested_transaction(context.session):
with obj_cls.db_context_reader(context):
for rule_type in qos_consts.VALID_RULE_TYPES:
rule_cls_name = 'Qos%sRule' % helpers.camelize(rule_type)
rule_cls = getattr(sys.modules[__name__], rule_cls_name)

View File

@ -16,7 +16,6 @@ from oslo_versionedobjects import fields as obj_fields
import sqlalchemy as sa
from sqlalchemy import sql
from neutron.db import api as db_api
from neutron.db.quota import models
from neutron.objects import base
from neutron.objects import common_types
@ -60,7 +59,7 @@ class Reservation(base.NeutronDbObject):
def create(self):
deltas = self.resource_deltas
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
super(Reservation, self).create()
if deltas:
for delta in deltas:

View File

@ -25,7 +25,6 @@ from sqlalchemy import and_
from neutron._i18n import _
from neutron.common import exceptions as n_exc
from neutron.db import _utils as db_utils
from neutron.db import api as db_api
from neutron.db import rbac_db_mixin
from neutron.db import rbac_db_models as models
from neutron.extensions import rbac as ext_rbac
@ -37,7 +36,7 @@ from neutron.objects.db import api as obj_db_api
class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
base.NeutronDbObject):
rbac_db_model = None
rbac_db_cls = None
@classmethod
@abc.abstractmethod
@ -65,9 +64,10 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
return False
@staticmethod
def get_shared_with_tenant(context, rbac_db_model, obj_id, tenant_id):
def get_shared_with_tenant(context, rbac_db_cls, obj_id, tenant_id):
# NOTE(korzen) This method enables to query within already started
# session
rbac_db_model = rbac_db_cls.db_model
return (db_utils.model_query(context, rbac_db_model).filter(
and_(rbac_db_model.object_id == obj_id,
rbac_db_model.action == models.ACCESS_SHARED,
@ -77,9 +77,8 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
@classmethod
def is_shared_with_tenant(cls, context, obj_id, tenant_id):
ctx = context.elevated()
rbac_db_model = cls.rbac_db_model
with ctx.session.begin(subtransactions=True):
return cls.get_shared_with_tenant(ctx, rbac_db_model,
with cls.db_context_reader(ctx):
return cls.get_shared_with_tenant(ctx, cls.rbac_db_cls,
obj_id, tenant_id)
@classmethod
@ -91,23 +90,24 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
@classmethod
def _get_db_obj_rbac_entries(cls, context, rbac_obj_id, rbac_action):
rbac_db_model = cls.rbac_db_model
rbac_db_model = cls.rbac_db_cls.db_model
return db_utils.model_query(context, rbac_db_model).filter(
and_(rbac_db_model.object_id == rbac_obj_id,
rbac_db_model.action == rbac_action))
@classmethod
def _get_tenants_with_shared_access_to_db_obj(cls, context, obj_id):
rbac_db_model = cls.rbac_db_cls.db_model
return set(itertools.chain.from_iterable(context.session.query(
cls.rbac_db_model.target_tenant).filter(
and_(cls.rbac_db_model.object_id == obj_id,
cls.rbac_db_model.action == models.ACCESS_SHARED,
cls.rbac_db_model.target_tenant != '*'))))
rbac_db_model.target_tenant).filter(
and_(rbac_db_model.object_id == obj_id,
rbac_db_model.action == models.ACCESS_SHARED,
rbac_db_model.target_tenant != '*'))))
@classmethod
def _validate_rbac_policy_delete(cls, context, obj_id, target_tenant):
ctx_admin = context.elevated()
rb_model = cls.rbac_db_model
rb_model = cls.rbac_db_cls.db_model
bound_tenant_ids = cls.get_bound_tenant_ids(ctx_admin, obj_id)
db_obj_sharing_entries = cls._get_db_obj_rbac_entries(
ctx_admin, obj_id, models.ACCESS_SHARED)
@ -146,7 +146,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
return
target_tenant = policy['target_tenant']
db_obj = obj_db_api.get_object(
context.elevated(), cls.db_model, id=policy['object_id'])
cls, context.elevated(), id=policy['object_id'])
if db_obj.tenant_id == target_tenant:
return
cls._validate_rbac_policy_delete(context=context,
@ -181,10 +181,10 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
# NeutronDbPluginV2.validate_network_rbac_policy_change(), those pieces
# should be synced and contain the same bugs, until Network RBAC logic
# (hopefully) melded with this one.
if object_type != cls.rbac_db_model.object_type:
if object_type != cls.rbac_db_cls.db_model.object_type:
return
db_obj = obj_db_api.get_object(
context.elevated(), cls.db_model, id=policy['object_id'])
cls, context.elevated(), id=policy['object_id'])
if event in (events.BEFORE_CREATE, events.BEFORE_UPDATE):
if (not context.is_admin and
db_obj['tenant_id'] != context.tenant_id):
@ -198,7 +198,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
object_type, policy, **kwargs)
def attach_rbac(self, obj_id, tenant_id, target_tenant='*'):
obj_type = self.rbac_db_model.object_type
obj_type = self.rbac_db_cls.db_model.object_type
rbac_policy = {'rbac_policy': {'object_id': obj_id,
'target_tenant': target_tenant,
'tenant_id': tenant_id,
@ -208,7 +208,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
def update_shared(self, is_shared_new, obj_id):
admin_context = self.obj_context.elevated()
shared_prev = obj_db_api.get_object(admin_context, self.rbac_db_model,
shared_prev = obj_db_api.get_object(self.rbac_db_cls, admin_context,
object_id=obj_id,
target_tenant='*',
action=models.ACCESS_SHARED)
@ -233,7 +233,7 @@ def _update_post(self, obj_changes):
def _update_hook(self, update_orig):
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
# NOTE(slaweq): copy of object changes is required to pass it later to
# _update_post method because update() will reset all those changes
obj_changes = self.obj_get_changes()
@ -247,7 +247,7 @@ def _create_post(self):
def _create_hook(self, orig_create):
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
orig_create(self)
_create_post(self)
@ -305,8 +305,8 @@ class RbacNeutronMetaclass(type):
def validate_existing_attrs(cls_name, dct):
if 'shared' not in dct['fields']:
raise KeyError(_('No shared key in %s fields') % cls_name)
if 'rbac_db_model' not in dct:
raise AttributeError(_('rbac_db_model not found in %s') % cls_name)
if 'rbac_db_cls' not in dct:
raise AttributeError(_('rbac_db_cls not found in %s') % cls_name)
@staticmethod
def get_replaced_method(orig_method, new_method):

View File

@ -13,7 +13,6 @@
from oslo_versionedobjects import fields as obj_fields
from neutron.common import utils
from neutron.db import api as db_api
from neutron.db.models import securitygroup as sg_models
from neutron.objects import base
from neutron.objects import common_types
@ -47,7 +46,7 @@ class SecurityGroup(base.NeutronDbObject):
def create(self):
# save is_default before super() resets it to False
is_default = self.is_default
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
super(SecurityGroup, self).create()
if is_default:
default_group = DefaultSecurityGroup(

View File

@ -0,0 +1,33 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_versionedobjects import fields as obj_fields
from neutron.db import standard_attr
from neutron.objects import base
from neutron.objects.extensions import standardattributes as stdattr_obj
# TODO(ihrachys): add unit tests for the object
@base.NeutronObjectRegistry.register
class StandardAttribute(base.NeutronDbObject):
# Version 1.0: Initial version
VERSION = '1.0'
db_model = standard_attr.StandardAttribute
fields = {
'id': obj_fields.IntegerField(),
'resource_type': obj_fields.StringField(),
}
fields.update(stdattr_obj.STANDARD_ATTRIBUTES)

View File

@ -17,9 +17,9 @@ from oslo_versionedobjects import fields as obj_fields
from neutron.common import utils
from neutron.db.models import subnet_service_type
from neutron.db import models_v2
from neutron.db import rbac_db_models
from neutron.objects import base
from neutron.objects import common_types
from neutron.objects import network
from neutron.objects import rbac_db
@ -228,7 +228,7 @@ class Subnet(base.NeutronDbObject):
# create), it should be rare case to load 'shared' by that method
shared = (rbac_db.RbacNeutronDbObjectMixin.
get_shared_with_tenant(self.obj_context.elevated(),
rbac_db_models.NetworkRBAC,
network.NetworkRBAC,
self.network_id,
self.project_id))
setattr(self, 'shared', shared)

View File

@ -16,7 +16,6 @@
import netaddr
from oslo_versionedobjects import fields as obj_fields
from neutron.db import api as db_api
from neutron.db import models_v2 as models
from neutron.objects import base
from neutron.objects import common_types
@ -70,7 +69,7 @@ class SubnetPool(base.NeutronDbObject):
# TODO(ihrachys): Consider extending base to trigger registered methods
def create(self):
fields = self.obj_get_changes()
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
prefixes = self.prefixes
super(SubnetPool, self).create()
if 'prefixes' in fields:
@ -79,7 +78,7 @@ class SubnetPool(base.NeutronDbObject):
# TODO(ihrachys): Consider extending base to trigger registered methods
def update(self):
fields = self.obj_get_changes()
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
super(SubnetPool, self).update()
if 'prefixes' in fields:
self._attach_prefixes(fields['prefixes'])

View File

@ -19,7 +19,6 @@ from oslo_db import exception as o_db_exc
from oslo_utils import versionutils
from oslo_versionedobjects import fields as obj_fields
from neutron.db import api as db_api
from neutron.objects import base
from neutron.objects import common_types
from neutron.services.trunk import exceptions as t_exc
@ -52,7 +51,7 @@ class SubPort(base.NeutronDbObject):
return _dict
def create(self):
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
try:
super(SubPort, self).create()
except o_db_exc.DBReferenceError as ex:
@ -104,7 +103,7 @@ class Trunk(base.NeutronDbObject):
synthetic_fields = ['sub_ports']
def create(self):
with db_api.autonested_transaction(self.obj_context.session):
with self.db_context_writer(self.obj_context):
sub_ports = []
if self.obj_attr_is_set('sub_ports'):
sub_ports = self.sub_ports

View File

@ -18,6 +18,7 @@ from oslo_config import cfg
from oslo_utils import uuidutils
from neutron.common import exceptions as n_exception
from neutron.conf.db import extraroute_db
from neutron.db import l3_fip_qos
from neutron.extensions import l3
from neutron.extensions import qos_fip
@ -211,9 +212,12 @@ class FloatingIPQoSDBIntTestCase(test_l3.L3BaseForIntTests,
plugin = ('neutron.tests.unit.extensions.test_qos_fip.'
'TestFloatingIPQoSIntPlugin')
service_plugins = {'qos': 'neutron.services.qos.qos_plugin.QoSPlugin'}
extraroute_db.register_db_extraroute_opts()
# for these tests we need to enable overlapping ips
cfg.CONF.set_default('allow_overlapping_ips', True)
cfg.CONF.set_default('max_routes', 3)
ext_mgr = FloatingIPQoSTestExtensionManager()
super(test_l3.L3BaseForIntTests, self).setUp(
plugin=plugin,
@ -236,9 +240,11 @@ class FloatingIPQoSDBSepTestCase(test_l3.L3BaseForSepTests,
service_plugins = {'l3_plugin_name': l3_plugin,
'qos': 'neutron.services.qos.qos_plugin.QoSPlugin'}
extraroute_db.register_db_extraroute_opts()
# for these tests we need to enable overlapping ips
cfg.CONF.set_default('allow_overlapping_ips', True)
cfg.CONF.set_default('max_routes', 3)
ext_mgr = FloatingIPQoSTestExtensionManager()
super(test_l3.L3BaseForSepTests, self).setUp(
plugin=plugin,

View File

@ -17,9 +17,9 @@ from neutron_lib import context
from neutron_lib import exceptions as n_exc
from neutron.db import _model_query as model_query
from neutron.db import models_v2
from neutron.objects import base
from neutron.objects.db import api
from neutron.objects import network
from neutron.objects import utils as obj_utils
from neutron.tests import base as test_base
from neutron.tests.unit import testlib_api
@ -28,6 +28,15 @@ from neutron.tests.unit import testlib_api
PLUGIN_NAME = 'neutron.db.db_base_plugin_v2.NeutronDbPluginV2'
class FakeModel(object):
def __init__(self, *args, **kwargs):
pass
class FakeObj(base.NeutronDbObject):
db_model = FakeModel
class GetObjectsTestCase(test_base.BaseTestCase):
def setUp(self):
@ -38,7 +47,6 @@ class GetObjectsTestCase(test_base.BaseTestCase):
def test_get_objects_pass_marker_obj_when_limit_and_marker_passed(self):
ctxt = context.get_admin_context()
model = mock.sentinel.model
marker = mock.sentinel.marker
limit = mock.sentinel.limit
pager = base.Pager(marker=marker, limit=limit)
@ -46,10 +54,10 @@ class GetObjectsTestCase(test_base.BaseTestCase):
with mock.patch.object(
model_query, 'get_collection') as get_collection:
with mock.patch.object(api, 'get_object') as get_object:
api.get_objects(ctxt, model, _pager=pager)
get_object.assert_called_with(ctxt, model, id=marker)
api.get_objects(FakeObj, ctxt, _pager=pager)
get_object.assert_called_with(FakeObj, ctxt, id=marker)
get_collection.assert_called_with(
ctxt, model, dict_func=None,
ctxt, FakeObj.db_model, dict_func=None,
filters={},
limit=limit,
marker_obj=get_object.return_value)
@ -58,15 +66,15 @@ class GetObjectsTestCase(test_base.BaseTestCase):
class CreateObjectTestCase(test_base.BaseTestCase):
def test_populate_id(self, populate_id=True):
ctxt = context.get_admin_context()
model_cls = mock.Mock()
values = {'x': 1, 'y': 2, 'z': 3}
with mock.patch.object(ctxt.__class__, 'session'):
api.create_object(ctxt, model_cls, values,
populate_id=populate_id)
with mock.patch.object(FakeObj, 'db_model') as db_model_mock:
with mock.patch.object(ctxt.__class__, 'session'):
api.create_object(FakeObj, ctxt, values,
populate_id=populate_id)
expected = copy.copy(values)
if populate_id:
expected['id'] = mock.ANY
model_cls.assert_called_with(**expected)
db_model_mock.assert_called_with(**expected)
def test_populate_id_False(self):
self.test_populate_id(populate_id=False)
@ -82,90 +90,93 @@ class CRUDScenarioTestCase(testlib_api.SqlTestCase):
# neutron.objects.db.api from core plugin instance
self.setup_coreplugin(self.CORE_PLUGIN)
# NOTE(ihrachys): nothing specific to networks in this test case, but
# we needed to pick some real model, so we picked the network. Any
# other model would work as well for our needs here.
self.model = models_v2.Network
# we needed to pick some real object, so we picked the network. Any
# other object would work as well for our needs here.
self.obj_cls = network.Network
self.ctxt = context.get_admin_context()
def test_get_object_with_None_value_in_filters(self):
obj = api.create_object(self.ctxt, self.model, {'name': 'foo'})
obj = api.create_object(self.obj_cls, self.ctxt, {'name': 'foo'})
new_obj = api.get_object(
self.ctxt, self.model, name='foo', status=None)
self.obj_cls, self.ctxt, name='foo', status=None)
self.assertEqual(obj, new_obj)
def test_get_objects_with_None_value_in_filters(self):
obj = api.create_object(self.ctxt, self.model, {'name': 'foo'})
obj = api.create_object(self.obj_cls, self.ctxt, {'name': 'foo'})
new_objs = api.get_objects(
self.ctxt, self.model, name='foo', status=None)
self.obj_cls, self.ctxt, name='foo', status=None)
self.assertEqual(obj, new_objs[0])
def test_get_objects_with_string_matching_filters_contains(self):
obj1 = api.create_object(self.ctxt, self.model, {'name': 'obj_con_1'})
obj2 = api.create_object(self.ctxt, self.model, {'name': 'obj_con_2'})
obj3 = api.create_object(self.ctxt, self.model, {'name': 'obj_3'})
obj1 = api.create_object(
self.obj_cls, self.ctxt, {'name': 'obj_con_1'})
obj2 = api.create_object(
self.obj_cls, self.ctxt, {'name': 'obj_con_2'})
obj3 = api.create_object(
self.obj_cls, self.ctxt, {'name': 'obj_3'})
objs = api.get_objects(
self.ctxt, self.model, name=obj_utils.StringContains('con'))
self.obj_cls, self.ctxt, name=obj_utils.StringContains('con'))
self.assertEqual(2, len(objs))
self.assertIn(obj1, objs)
self.assertIn(obj2, objs)
self.assertNotIn(obj3, objs)
def test_get_objects_with_string_matching_filters_starts(self):
obj1 = api.create_object(self.ctxt, self.model, {'name': 'pre_obj1'})
obj2 = api.create_object(self.ctxt, self.model, {'name': 'pre_obj2'})
obj3 = api.create_object(self.ctxt, self.model, {'name': 'obj_3'})
obj1 = api.create_object(self.obj_cls, self.ctxt, {'name': 'pre_obj1'})
obj2 = api.create_object(self.obj_cls, self.ctxt, {'name': 'pre_obj2'})
obj3 = api.create_object(self.obj_cls, self.ctxt, {'name': 'obj_3'})
objs = api.get_objects(
self.ctxt, self.model, name=obj_utils.StringStarts('pre'))
self.obj_cls, self.ctxt, name=obj_utils.StringStarts('pre'))
self.assertEqual(2, len(objs))
self.assertIn(obj1, objs)
self.assertIn(obj2, objs)
self.assertNotIn(obj3, objs)
def test_get_objects_with_string_matching_filters_ends(self):
obj1 = api.create_object(self.ctxt, self.model, {'name': 'obj1_end'})
obj2 = api.create_object(self.ctxt, self.model, {'name': 'obj2_end'})
obj3 = api.create_object(self.ctxt, self.model, {'name': 'obj_3'})
obj1 = api.create_object(self.obj_cls, self.ctxt, {'name': 'obj1_end'})
obj2 = api.create_object(self.obj_cls, self.ctxt, {'name': 'obj2_end'})
obj3 = api.create_object(self.obj_cls, self.ctxt, {'name': 'obj_3'})
objs = api.get_objects(
self.ctxt, self.model, name=obj_utils.StringEnds('end'))
self.obj_cls, self.ctxt, name=obj_utils.StringEnds('end'))
self.assertEqual(2, len(objs))
self.assertIn(obj1, objs)
self.assertIn(obj2, objs)
self.assertNotIn(obj3, objs)
def test_get_object_create_update_delete(self):
obj = api.create_object(self.ctxt, self.model, {'name': 'foo'})
obj = api.create_object(self.obj_cls, self.ctxt, {'name': 'foo'})
new_obj = api.get_object(self.ctxt, self.model, id=obj.id)
new_obj = api.get_object(self.obj_cls, self.ctxt, id=obj.id)
self.assertEqual(obj, new_obj)
obj = new_obj
api.update_object(self.ctxt, self.model, {'name': 'bar'}, id=obj.id)
api.update_object(self.obj_cls, self.ctxt, {'name': 'bar'}, id=obj.id)
new_obj = api.get_object(self.ctxt, self.model, id=obj.id)
new_obj = api.get_object(self.obj_cls, self.ctxt, id=obj.id)
self.assertEqual(obj, new_obj)
obj = new_obj
api.delete_object(self.ctxt, self.model, id=obj.id)
api.delete_object(self.obj_cls, self.ctxt, id=obj.id)
new_obj = api.get_object(self.ctxt, self.model, id=obj.id)
new_obj = api.get_object(self.obj_cls, self.ctxt, id=obj.id)
self.assertIsNone(new_obj)
# delete_object raises an exception on missing object
self.assertRaises(
n_exc.ObjectNotFound,
api.delete_object, self.ctxt, self.model, id=obj.id)
api.delete_object, self.obj_cls, self.ctxt, id=obj.id)
# but delete_objects does not
api.delete_objects(self.ctxt, self.model, id=obj.id)
api.delete_objects(self.obj_cls, self.ctxt, id=obj.id)
def test_delete_objects_removes_all_matching_objects(self):
# create some objects with identical description
for i in range(10):
api.create_object(
self.ctxt, self.model,
self.obj_cls, self.ctxt,
{'name': 'foo%d' % i, 'description': 'bar'})
# create some more objects with a different description
descriptions = set()
@ -173,16 +184,16 @@ class CRUDScenarioTestCase(testlib_api.SqlTestCase):
desc = 'bar%d' % i
descriptions.add(desc)
api.create_object(
self.ctxt, self.model,
self.obj_cls, self.ctxt,
{'name': 'foo%d' % i, 'description': desc})
# make sure that all objects are in the database
self.assertEqual(20, api.count(self.ctxt, self.model))
self.assertEqual(20, api.count(self.obj_cls, self.ctxt))
# now delete just those with the 'bar' description
api.delete_objects(self.ctxt, self.model, description='bar')
api.delete_objects(self.obj_cls, self.ctxt, description='bar')
# check that half of objects are gone, and remaining have expected
# descriptions
objs = api.get_objects(self.ctxt, self.model)
objs = api.get_objects(self.obj_cls, self.ctxt)
self.assertEqual(10, len(objs))
self.assertEqual(
descriptions,

View File

@ -18,9 +18,10 @@ from oslo_versionedobjects import exception
import testtools
from neutron.common import exceptions as n_exc
from neutron.db import models_v2
from neutron.objects.db import api as db_api
from neutron.objects import network as net_obj
from neutron.objects import ports as port_obj
from neutron.objects.qos import binding
from neutron.objects.qos import policy
from neutron.objects.qos import rule
from neutron.tests.unit.objects import test_base
@ -34,6 +35,9 @@ RULE_OBJ_CLS = {
}
# TODO(ihrachys): add tests for QosPolicyRBAC
class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
_test_class = policy.QosPolicy
@ -57,8 +61,8 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
self.model_map.update({
self._test_class.db_model: self.db_objs,
self._test_class.port_binding_model: [],
self._test_class.network_binding_model: [],
binding.QosPolicyPortBinding.db_model: [],
binding.QosPolicyNetworkBinding.db_model: [],
rule.QosBandwidthLimitRule.db_model: self.db_qos_bandwidth_rules,
rule.QosDscpMarkingRule.db_model: self.db_qos_dscp_rules,
rule.QosMinimumBandwidthRule.db_model:
@ -73,7 +77,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
objs = self._test_class.get_objects(self.context)
context_mock.assert_called_once_with()
self.get_objects_mock.assert_any_call(
admin_context, self._test_class.db_model, _pager=None)
self._test_class, admin_context, _pager=None)
self.assertItemsEqual(
[test_base.get_obj_persistent_fields(obj) for obj in self.objs],
[test_base.get_obj_persistent_fields(obj) for obj in objs])
@ -95,7 +99,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
**self.valid_field_filter)
context_mock.assert_called_once_with()
get_objects_mock.assert_any_call(
admin_context, self._test_class.db_model, _pager=None,
self._test_class, admin_context, _pager=None,
**self.valid_field_filter)
self._check_equal(self.objs[0], objs[0])
@ -110,7 +114,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
self._check_equal(self.objs[0], obj)
context_mock.assert_called_once_with()
get_object_mock.assert_called_once_with(
admin_context, self._test_class.db_model, id='fake_id')
self._test_class, admin_context, id='fake_id')
def test_to_dict_makes_primitive_field_value(self):
# is_shared_with_tenant requires DB
@ -237,7 +241,7 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
def test_attach_and_get_multiple_policy_ports(self):
port1_id = self._port['id']
port2 = db_api.create_object(self.context, models_v2.Port,
port2 = db_api.create_object(port_obj.Port, self.context,
{'tenant_id': 'fake_tenant_id',
'name': 'test-port2',
'network_id': self._network_id,

View File

@ -32,7 +32,6 @@ from oslo_versionedobjects import fields as obj_fields
import testtools
from neutron.db import _model_query as model_query
from neutron.db import standard_attr
from neutron import objects
from neutron.objects import agent
from neutron.objects import base
@ -46,6 +45,7 @@ from neutron.objects.qos import policy as qos_policy
from neutron.objects import rbac_db
from neutron.objects import router
from neutron.objects import securitygroup
from neutron.objects import stdattrs
from neutron.objects import subnet
from neutron.objects import utils as obj_utils
from neutron.tests import base as test_base
@ -54,6 +54,7 @@ from neutron.tests.unit.db import test_db_base_plugin_v2
SQLALCHEMY_COMMIT = 'sqlalchemy.engine.Connection._commit_impl'
SQLALCHEMY_CLOSE = 'sqlalchemy.engine.Connection.close'
OBJECTS_BASE_OBJ_FROM_PRIMITIVE = ('oslo_versionedobjects.base.'
'VersionedObject.obj_from_primitive')
TIMESTAMP_FIELDS = ['created_at', 'updated_at', 'revision_number']
@ -663,8 +664,8 @@ class _BaseObjectTestCase(object):
def _is_test_class(cls, obj):
return isinstance(obj, cls._test_class)
def fake_get_objects(self, context, model, **kwargs):
return self.model_map[model]
def fake_get_objects(self, obj_cls, context, **kwargs):
return self.model_map[obj_cls.db_model]
def _get_object_synthetic_fields(self, objclass):
return [field for field in objclass.synthetic_fields
@ -705,13 +706,14 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
# NOTE(ihrachys): for matters of basic object behaviour validation,
# mock out rbac code accessing database. There are separate tests that
# cover RBAC, per object type.
if getattr(self._test_class, 'rbac_db_model', None):
mock.patch.object(
rbac_db.RbacNeutronDbObjectMixin,
'is_shared_with_tenant', return_value=False).start()
mock.patch.object(
rbac_db.RbacNeutronDbObjectMixin,
'get_shared_with_tenant').start()
if self._test_class.rbac_db_cls is not None:
if getattr(self._test_class.rbac_db_cls, 'db_model', None):
mock.patch.object(
rbac_db.RbacNeutronDbObjectMixin,
'is_shared_with_tenant', return_value=False).start()
mock.patch.object(
rbac_db.RbacNeutronDbObjectMixin,
'get_shared_with_tenant').start()
def fake_get_object(self, context, model, **kwargs):
objs = self.model_map[model]
@ -719,8 +721,8 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
return None
return [obj for obj in objs if obj['id'] == kwargs['id']][0]
def fake_get_objects(self, context, model, **kwargs):
return self.model_map[model]
def fake_get_objects(self, obj_cls, context, **kwargs):
return self.model_map[obj_cls.db_model]
# TODO(ihrachys) document the intent of all common test cases in docstrings
def test_get_object(self):
@ -734,7 +736,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
self.assertTrue(self._is_test_class(obj))
self._check_equal(self.objs[0], obj)
get_object_mock.assert_called_once_with(
self.context, self._test_class.db_model,
self._test_class, self.context,
**self._test_class.modify_fields_to_db(obj_keys))
def test_get_object_missing_object(self):
@ -773,7 +775,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
self.assertTrue(self._is_test_class(obj))
self._check_equal(self.objs[0], obj)
get_object_mock.assert_called_once_with(
mock.ANY, self._test_class.db_model,
self._test_class, mock.ANY,
**self._test_class.modify_fields_to_db(obj_keys))
def _get_synthetic_fields_get_objects_calls(self, db_objs):
@ -790,7 +792,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
}
mock_calls.append(
mock.call(
self.context, obj_class.db_model,
obj_class, self.context,
_pager=self.pager_map[obj_class.obj_name()],
**filter_kwargs))
return mock_calls
@ -805,7 +807,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
[get_obj_persistent_fields(obj) for obj in self.objs],
[get_obj_persistent_fields(obj) for obj in objs])
get_objects_mock.assert_any_call(
self.context, self._test_class.db_model,
self._test_class, self.context,
_pager=self.pager_map[self._test_class.obj_name()]
)
@ -952,7 +954,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
) as delete_objects_mock:
self.assertEqual(0, self._test_class.delete_objects(self.context))
delete_objects_mock.assert_any_call(
self.context, self._test_class.db_model)
self._test_class, self.context)
def test_delete_objects_valid_fields(self):
'''Test that a valid filter does not raise an error.'''
@ -1012,7 +1014,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
obj.create()
self._check_equal(self.objs[0], obj)
create_mock.assert_called_once_with(
self.context, self._test_class.db_model,
obj, self.context,
self._test_class.modify_fields_to_db(
get_obj_persistent_fields(self.objs[0])))
@ -1126,7 +1128,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
update_mock.return_value[key] = value
obj.update()
update_mock.assert_called_once_with(
self.context, self._test_class.db_model,
obj, self.context,
self._test_class.modify_fields_to_db(fields_to_update),
**fixed_keys)
@ -1172,7 +1174,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
obj.delete()
self._check_equal(self.objs[0], obj)
delete_mock.assert_called_once_with(
self.context, self._test_class.db_model,
obj, self.context,
**self._test_class.modify_fields_to_db(obj._get_composite_keys()))
@mock.patch(OBJECTS_BASE_OBJ_FROM_PRIMITIVE)
@ -1229,7 +1231,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
pager = base.Pager()
self._test_class.get_objects(self.context, _pager=pager)
get_objects.assert_called_once_with(
mock.ANY, self._test_class.db_model, _pager=pager)
self._test_class, mock.ANY, _pager=pager)
class BaseDbObjectNonStandardPrimaryKeyTestCase(BaseObjectIfaceTestCase):
@ -1545,9 +1547,8 @@ class BaseDbObjectTestCase(_BaseObjectTestCase,
'revision_number': tools.get_random_integer()
}
return obj_db_api.create_object(
self.context,
standard_attr.StandardAttribute, attrs,
populate_id=False)['id']
stdattrs.StandardAttribute,
self.context, attrs, populate_id=False)['id']
def _create_test_flavor_id(self):
attrs = self.get_random_object_fields(obj_cls=flavor.Flavor)
@ -1673,19 +1674,27 @@ class BaseDbObjectTestCase(_BaseObjectTestCase,
obj.delete()
self.assertEqual(1, mock_commit.call_count)
@mock.patch(SQLALCHEMY_COMMIT)
def test_get_objects_single_transaction(self, mock_commit):
self._test_class.get_objects(self.context)
self.assertEqual(1, mock_commit.call_count)
def _get_ro_txn_exit_func_name(self):
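# note: SQLALCHEMY_CLOSE and SQLALCHEMY_COMMIT are assumed to be
# module-level mock patch targets for SQLAlchemy's close / commit calls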
# the old engine facade makes no distinction between read-only and
# read/write transactions, so it always calls commit, even for getters;
# the new facade closes its reader transactions instead
return (
SQLALCHEMY_CLOSE
if self._test_class.new_facade else SQLALCHEMY_COMMIT)
@mock.patch(SQLALCHEMY_COMMIT)
def test_get_object_single_transaction(self, mock_commit):
def test_get_objects_single_transaction(self):
with mock.patch(self._get_ro_txn_exit_func_name()) as mock_exit:
self._test_class.get_objects(self.context)
self.assertEqual(1, mock_exit.call_count)
def test_get_object_single_transaction(self):
obj = self._make_object(self.obj_fields[0])
obj.create()
obj = self._test_class.get_object(self.context,
**obj._get_composite_keys())
self.assertEqual(2, mock_commit.call_count)
with mock.patch(self._get_ro_txn_exit_func_name()) as mock_exit:
obj = self._test_class.get_object(self.context,
**obj._get_composite_keys())
self.assertEqual(1, mock_exit.call_count)
def test_get_objects_supports_extra_filtername(self):
self.filtered_args = None


@ -20,6 +20,8 @@ from neutron.tests.unit.objects import test_base as obj_test_base
from neutron.tests.unit import testlib_api
# TODO(ihrachys): add tests for NetworkRBAC
class NetworkDhcpAgentBindingObjectIfaceTestCase(
obj_test_base.BaseObjectIfaceTestCase):


@ -60,6 +60,7 @@ object_data = {
'NetworkDhcpAgentBinding': '1.0-6eeceb5fb4335cd65a305016deb41c68',
'NetworkDNSDomain': '1.0-420db7910294608534c1e2e30d6d8319',
'NetworkPortSecurity': '1.0-b30802391a87945ee9c07582b4ff95e3',
'NetworkRBAC': '1.0-c8a67f39809c5a3c8c7f26f2f2c620b2',
'NetworkSegment': '1.0-57b7f2960971e3b95ded20cbc59244a8',
'Port': '1.1-5bf48d12a7bf7f5b7a319e8003b437a5',
'PortBinding': '1.0-3306deeaa6deb01e33af06777d48d578',
@ -72,6 +73,7 @@ object_data = {
'QosBandwidthLimitRule': '1.3-51b662b12a8d1dfa89288d826c6d26d3',
'QosDscpMarkingRule': '1.3-0313c6554b34fd10c753cb63d638256c',
'QosMinimumBandwidthRule': '1.3-314c3419f4799067cc31cc319080adff',
'QosPolicyRBAC': '1.0-c8a67f39809c5a3c8c7f26f2f2c620b2',
'QosRuleType': '1.3-7286188edeb3a0386f9cf7979b9700fc',
'QosRuleTypeDriver': '1.0-7d8cb9f0ef661ac03700eae97118e3db',
'QosPolicy': '1.7-4adb0cde3102c10d8970ec9487fd7fe7',
@ -90,9 +92,11 @@ object_data = {
'RouterPort': '1.0-c8c8f499bcdd59186fcd83f323106908',
'RouterRoute': '1.0-07fc5337c801fb8c6ccfbcc5afb45907',
'SecurityGroup': '1.0-e26b90c409b31fd2e3c6fcec402ac0b9',
'SecurityGroupPortBinding': '1.0-6879d5c0af80396ef5a72934b6a6ef20',
'SecurityGroupRule': '1.0-e9b8dace9d48b936c62ad40fe1f339d5',
'SegmentHostMapping': '1.0-521597cf82ead26217c3bd10738f00f0',
'ServiceProfile': '1.0-9beafc9e7d081b8258f3c5cb66ac5eed',
'StandardAttribute': '1.0-617d4f46524c4ce734a6fc1cc0ac6a0b',
'Subnet': '1.0-927155c1fdd5a615cbcb981dda97bce4',
'SubnetPool': '1.0-a0e03895d1a6e7b9d4ab7b0ca13c3867',
'SubnetPoolPrefix': '1.0-13c15144135eb869faa4a76dc3ee3b6c',


@ -24,6 +24,16 @@ from neutron.tests.unit.objects import test_base as obj_test_base
from neutron.tests.unit import testlib_api
class SecurityGroupPortBindingIfaceObjTestCase(
obj_test_base.BaseObjectIfaceTestCase):
_test_class = ports.SecurityGroupPortBinding
class SecurityGroupPortBindingDbObjectTestCase(
obj_test_base.BaseDbObjectTestCase):
_test_class = ports.SecurityGroupPortBinding
class BasePortBindingDbObjectTestCase(obj_test_base._BaseObjectTestCase,
testlib_api.SqlTestCase):
def setUp(self):


@ -41,12 +41,25 @@ class FakeRbacModel(rbac_db_models.RBACColumns, model_base.BASEV2):
return (rbac_db_models.ACCESS_SHARED,)
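# fake RBAC OVO counterpart of FakeRbacModel; wired in below as
# rbac_db_cls of FakeNeutronDbObject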
@base.NeutronObjectRegistry.register_if(False)
class FakeNeutronRbacObject(base.NeutronDbObject):
VERSION = '1.0'
db_model = FakeRbacModel
fields = {
'object_id': obj_fields.StringField(),
'target_tenant': obj_fields.StringField(),
'action': obj_fields.StringField(),
}
@base.NeutronObjectRegistry.register_if(False)
class FakeNeutronDbObject(rbac_db.NeutronRbacObject):
# Version 1.0: Initial version
VERSION = '1.0'
rbac_db_model = FakeRbacModel
rbac_db_cls = FakeNeutronRbacObject
db_model = FakeDbModel
fields = {
@ -72,7 +85,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
super(RbacNeutronDbObjectTestCase, self).setUp()
FakeNeutronDbObject.update_post = mock.Mock()
@mock.patch.object(_test_class, 'rbac_db_model')
@mock.patch.object(_test_class.rbac_db_cls, 'db_model')
def test_get_tenants_with_shared_access_to_db_obj_return_tenant_ids(
self, *mocks):
ctx = mock.Mock()
@ -138,7 +151,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
context = mock.Mock(is_admin=True, tenant_id='db_obj_owner_id')
self._rbac_policy_generate_change_events(
resource=None, trigger='dummy_trigger', context=context,
object_type=self._test_class.rbac_db_model.object_type,
object_type=self._test_class.rbac_db_cls.db_model.object_type,
policy={'object_id': 'fake_object_id'},
event_list=(events.BEFORE_CREATE, events.BEFORE_UPDATE))
@ -154,7 +167,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
n_exc.InvalidInput,
self._rbac_policy_generate_change_events,
resource=mock.Mock(), trigger='dummy_trigger', context=context,
object_type=self._test_class.rbac_db_model.object_type,
object_type=self._test_class.rbac_db_cls.db_model.object_type,
policy={'object_id': 'fake_object_id'},
event_list=(events.BEFORE_CREATE, events.BEFORE_UPDATE))
self.assertFalse(mock_validate_update.called)
@ -165,7 +178,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
self._test_class.validate_rbac_policy_delete(
resource=mock.Mock(), event=events.BEFORE_DELETE,
trigger='dummy_trigger', context=n_context.get_admin_context(),
object_type=self._test_class.rbac_db_model.object_type,
object_type=self._test_class.rbac_db_cls.db_model.object_type,
policy=policy)
mock_validate_delete.assert_not_called()
@ -205,7 +218,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
event=events.BEFORE_DELETE,
trigger='dummy_trigger',
context=context,
object_type=self._test_class.rbac_db_model.object_type,
object_type=self._test_class.rbac_db_cls.db_model.object_type,
policy=policy)
def test_validate_rbac_policy_delete_not_bound_tenant_success(self):
@ -247,7 +260,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
event=events.BEFORE_DELETE,
trigger='dummy_trigger',
context=context,
object_type=self._test_class.rbac_db_model.object_type,
object_type=self._test_class.rbac_db_cls.db_model.object_type,
policy=policy)
@mock.patch.object(_test_class, 'attach_rbac')
@ -257,10 +270,10 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
def test_update_shared_avoid_duplicate_update(
self, mock_validate_delete, get_object_mock, attach_rbac_mock):
obj_id = 'fake_obj_id'
self._test_class(mock.Mock()).update_shared(is_shared_new=True,
obj_id=obj_id)
obj = self._test_class(mock.Mock())
obj.update_shared(is_shared_new=True, obj_id=obj_id)
get_object_mock.assert_called_with(
mock.ANY, self._test_class.rbac_db_model, object_id=obj_id,
obj.rbac_db_cls, mock.ANY, object_id=obj_id,
target_tenant='*', action=rbac_db_models.ACCESS_SHARED)
self.assertFalse(mock_validate_delete.called)
self.assertFalse(attach_rbac_mock.called)
@ -275,7 +288,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
test_neutron_obj = self._test_class(mock.Mock())
test_neutron_obj.update_shared(is_shared_new=True, obj_id=obj_id)
get_object_mock.assert_called_with(
mock.ANY, self._test_class.rbac_db_model, object_id=obj_id,
test_neutron_obj.rbac_db_cls, mock.ANY, object_id=obj_id,
target_tenant='*', action=rbac_db_models.ACCESS_SHARED)
attach_rbac_mock.assert_called_with(
@ -292,10 +305,10 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
def test_update_shared_remove_wildcard_sharing(
self, mock_validate_delete, get_object_mock, attach_rbac_mock):
obj_id = 'fake_obj_id'
self._test_class(mock.Mock()).update_shared(is_shared_new=False,
obj_id=obj_id)
obj = self._test_class(mock.Mock())
obj.update_shared(is_shared_new=False, obj_id=obj_id)
get_object_mock.assert_called_with(
mock.ANY, self._test_class.rbac_db_model, object_id=obj_id,
obj.rbac_db_cls, mock.ANY, object_id=obj_id,
target_tenant='*', action=rbac_db_models.ACCESS_SHARED)
self.assertFalse(attach_rbac_mock.attach_rbac.called)
@ -313,4 +326,4 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
self.assertEqual(rbac_pol['target_tenant'], target_tenant)
self.assertEqual(rbac_pol['action'], rbac_db_models.ACCESS_SHARED)
self.assertEqual(rbac_pol['object_type'],
self._test_class.rbac_db_model.object_type)
self._test_class.rbac_db_cls.db_model.object_type)


@ -17,6 +17,7 @@ from oslo_utils import uuidutils
from neutron.db import rbac_db_models
from neutron.objects import base as obj_base
from neutron.objects.db import api as obj_db_api
from neutron.objects import network as net_obj
from neutron.objects import rbac_db
from neutron.objects import subnet
from neutron.tests.unit.objects import test_base as obj_test_base
@ -175,8 +176,7 @@ class SubnetDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
'target_tenant': '*',
'action': rbac_db_models.ACCESS_SHARED
}
obj_db_api.create_object(self.context, rbac_db_models.NetworkRBAC,
attrs)
obj_db_api.create_object(net_obj.NetworkRBAC, self.context, attrs)
def test_get_subnet_shared_true(self):
network = self._create_test_network()