Objects DB api: added composite key to handle multiple primary key

Moving CRUD DB operation for objects from db/api.py to objects/db/api.py
Renaming object get_by_id(id) to get_object(**kwargs)

Many models in Neutron DB have complex primary keys, concatenated from
a few properties. This patch adds ability to define multiple primary keys
in NeutronDbObject, which are automatically evaluated into DB query
when performing operations.

Partial-Bug: #1541928
Change-Id: I0f63a62418db76415ddd40c30c778ff7541b93dc
This commit is contained in:
Artur Korzeniewski 2016-03-01 12:07:15 +01:00 committed by rossella
parent 6386a0fa12
commit 4b227c3771
15 changed files with 297 additions and 123 deletions

View File

@ -144,7 +144,7 @@ effort.
Every NeutronObject supports the following operations:
* get_by_id: returns specific object that is represented by the id passed as an
* get_object: returns specific object that is represented by the id passed as an
argument.
* get_objects: returns all objects of the type, potentially with a filter
applied.

View File

@ -32,7 +32,7 @@ class QosCoreResourceExtension(base.CoreResourceExtension):
return self._plugin_loaded
def _get_policy_obj(self, context, policy_id):
obj = policy_object.QosPolicy.get_by_id(context, policy_id)
obj = policy_object.QosPolicy.get_object(context, id=policy_id)
if obj is None:
raise n_exc.QosPolicyNotFound(policy_id=policy_id)
return obj

View File

@ -15,6 +15,7 @@
import contextlib
import debtcollector
from oslo_config import cfg
from oslo_db import api as oslo_db_api
from oslo_db import exception as db_exc
@ -89,6 +90,7 @@ def autonested_transaction(sess):
# Common database operation implementations
@debtcollector.removals.remove(message="This will be removed in the N cycle.")
def get_object(context, model, **kwargs):
with context.session.begin(subtransactions=True):
return (common_db_mixin.model_query(context, model)
@ -96,6 +98,7 @@ def get_object(context, model, **kwargs):
.first())
@debtcollector.removals.remove(message="This will be removed in the N cycle.")
def get_objects(context, model, **kwargs):
with context.session.begin(subtransactions=True):
return (common_db_mixin.model_query(context, model)
@ -103,6 +106,7 @@ def get_objects(context, model, **kwargs):
.all())
@debtcollector.removals.remove(message="This will be removed in the N cycle.")
def create_object(context, model, values):
with context.session.begin(subtransactions=True):
if 'id' not in values and hasattr(model, 'id'):
@ -112,6 +116,7 @@ def create_object(context, model, values):
return db_obj.__dict__
@debtcollector.removals.remove(message="This will be removed in the N cycle.")
def _safe_get_object(context, model, id, key='id'):
db_obj = get_object(context, model, **{key: id})
if db_obj is None:
@ -119,6 +124,7 @@ def _safe_get_object(context, model, id, key='id'):
return db_obj
@debtcollector.removals.remove(message="This will be removed in the N cycle.")
def update_object(context, model, id, values, key=None):
with context.session.begin(subtransactions=True):
kwargs = {}
@ -131,6 +137,7 @@ def update_object(context, model, id, values, key=None):
return db_obj.__dict__
@debtcollector.removals.remove(message="This will be removed in the N cycle.")
def delete_object(context, model, id, key=None):
with context.session.begin(subtransactions=True):
kwargs = {}

View File

@ -12,14 +12,14 @@
import abc
from neutron_lib import exceptions
from oslo_db import exception as obj_exc
from oslo_utils import reflection
from oslo_versionedobjects import base as obj_base
import six
from neutron._i18n import _
from neutron.common import exceptions
from neutron.db import api as db_api
from neutron.objects.db import api as obj_db_api
class NeutronObjectUpdateForbidden(exceptions.NeutronException):
@ -38,6 +38,18 @@ class NeutronDbObjectDuplicateEntry(exceptions.Conflict):
values=db_exception.value)
class NeutronPrimaryKeyMissing(exceptions.BadRequest):
message = _("For class %(object_type)s missing primary keys: "
"%(missing_keys)s")
def __init__(self, object_class, missing_keys):
super(NeutronPrimaryKeyMissing, self).__init__(
object_type=reflection.get_class_name(object_class,
fully_qualified=False),
missing_keys=missing_keys
)
def get_updatable_fields(cls, fields):
fields = fields.copy()
for field in cls.fields_no_update:
@ -67,7 +79,7 @@ class NeutronObject(obj_base.VersionedObject,
return obj
@classmethod
def get_by_id(cls, context, id):
def get_object(cls, context, **kwargs):
raise NotImplementedError()
@classmethod
@ -99,7 +111,7 @@ class NeutronDbObject(NeutronObject):
# should be overridden for all persistent objects
db_model = None
primary_key = 'id'
primary_keys = ['id']
fields_no_update = []
@ -112,9 +124,21 @@ class NeutronDbObject(NeutronObject):
self.obj_reset_changes()
@classmethod
def get_by_id(cls, context, id):
db_obj = db_api.get_object(context, cls.db_model,
**{cls.primary_key: id})
def get_object(cls, context, **kwargs):
"""
This method fetches object from DB and convert it to versioned
object.
:param context:
:param kwargs: multiple primary keys defined key=value pairs
:return: single object of NeutronDbObject class
"""
missing_keys = set(cls.primary_keys).difference(kwargs.keys())
if missing_keys:
raise NeutronPrimaryKeyMissing(object_class=cls.__class__,
missing_keys=missing_keys)
db_obj = obj_db_api.get_object(context, cls.db_model, **kwargs)
if db_obj:
obj = cls(context, **db_obj)
obj.obj_reset_changes()
@ -123,7 +147,7 @@ class NeutronDbObject(NeutronObject):
@classmethod
def get_objects(cls, context, **kwargs):
cls.validate_filters(**kwargs)
db_objs = db_api.get_objects(context, cls.db_model, **kwargs)
db_objs = obj_db_api.get_objects(context, cls.db_model, **kwargs)
objs = [cls(context, **db_obj) for db_obj in db_objs]
for obj in objs:
obj.obj_reset_changes()
@ -156,24 +180,30 @@ class NeutronDbObject(NeutronObject):
def create(self):
fields = self._get_changed_persistent_fields()
try:
db_obj = db_api.create_object(self._context, self.db_model, fields)
db_obj = obj_db_api.create_object(self._context, self.db_model,
fields)
except obj_exc.DBDuplicateEntry as db_exc:
raise NeutronDbObjectDuplicateEntry(object_class=self.__class__,
db_exception=db_exc)
self.from_db_object(db_obj)
def _get_composite_keys(self):
keys = {}
for key in self.primary_keys:
keys[key] = getattr(self, key)
return keys
def update(self):
updates = self._get_changed_persistent_fields()
updates = self._validate_changed_fields(updates)
if updates:
db_obj = db_api.update_object(self._context, self.db_model,
getattr(self, self.primary_key),
updates, key=self.primary_key)
db_obj = obj_db_api.update_object(self._context, self.db_model,
updates,
**self._get_composite_keys())
self.from_db_object(self, db_obj)
def delete(self):
db_api.delete_object(self._context, self.db_model,
getattr(self, self.primary_key),
key=self.primary_key)
obj_db_api.delete_object(self._context, self.db_model,
**self._get_composite_keys())

View File

64
neutron/objects/db/api.py Normal file
View File

@ -0,0 +1,64 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from neutron_lib import exceptions as n_exc
from oslo_utils import uuidutils
from neutron.db import common_db_mixin
# Common database operation implementations
def get_object(context, model, **kwargs):
with context.session.begin(subtransactions=True):
return (common_db_mixin.model_query(context, model)
.filter_by(**kwargs)
.first())
def get_objects(context, model, **kwargs):
with context.session.begin(subtransactions=True):
return (common_db_mixin.model_query(context, model)
.filter_by(**kwargs)
.all())
def create_object(context, model, values):
with context.session.begin(subtransactions=True):
if 'id' not in values and hasattr(model, 'id'):
values['id'] = uuidutils.generate_uuid()
db_obj = model(**values)
context.session.add(db_obj)
return db_obj.__dict__
def _safe_get_object(context, model, **kwargs):
db_obj = get_object(context, model, **kwargs)
if db_obj is None:
key = "".join(['%s:: %s ' % (key, value) for (key, value)
in kwargs.items()])
raise n_exc.ObjectNotFound(id=key)
return db_obj
def update_object(context, model, values, **kwargs):
with context.session.begin(subtransactions=True):
db_obj = _safe_get_object(context, model, **kwargs)
db_obj.update(values)
db_obj.save(session=context.session)
return db_obj.__dict__
def delete_object(context, model, **kwargs):
with context.session.begin(subtransactions=True):
db_obj = _safe_get_object(context, model, **kwargs)
context.session.delete(db_obj)

View File

@ -27,6 +27,7 @@ from neutron.db.qos import api as qos_db_api
from neutron.db.qos import models as qos_db_model
from neutron.db.rbac_db_models import QosPolicyRBAC
from neutron.objects import base
from neutron.objects.db import api as obj_db_api
from neutron.objects.qos import rule as rule_obj_impl
from neutron.objects import rbac_db
@ -93,12 +94,13 @@ class QosPolicy(base.NeutronDbObject):
rule_id=rule_id)
@classmethod
def get_by_id(cls, context, id):
def get_object(cls, context, **kwargs):
# We want to get the policy regardless of its tenant id. We'll make
# sure the tenant has permission to access the policy later on.
admin_context = context.elevated()
with db_api.autonested_transaction(admin_context.session):
policy_obj = super(QosPolicy, cls).get_by_id(admin_context, id)
policy_obj = super(QosPolicy, cls).get_object(admin_context,
**kwargs)
if (not policy_obj or
not cls.is_accessible(context, policy_obj)):
return
@ -125,9 +127,9 @@ class QosPolicy(base.NeutronDbObject):
@classmethod
def _get_object_policy(cls, context, model, **kwargs):
with db_api.autonested_transaction(context.session):
binding_db_obj = db_api.get_object(context, model, **kwargs)
binding_db_obj = obj_db_api.get_object(context, model, **kwargs)
if binding_db_obj:
return cls.get_by_id(context, binding_db_obj['policy_id'])
return cls.get_object(context, id=binding_db_obj['policy_id'])
@classmethod
def get_network_policy(cls, context, network_id):
@ -148,8 +150,8 @@ class QosPolicy(base.NeutronDbObject):
def delete(self):
with db_api.autonested_transaction(self._context.session):
for object_type, model in self.binding_models.items():
binding_db_obj = db_api.get_object(self._context, model,
policy_id=self.id)
binding_db_obj = obj_db_api.get_object(self._context, model,
policy_id=self.id)
if binding_db_obj:
raise exceptions.QosPolicyInUse(
policy_id=self.id,

View File

@ -28,6 +28,7 @@ from neutron.db import rbac_db_mixin
from neutron.db import rbac_db_models as models
from neutron.extensions import rbac as ext_rbac
from neutron.objects import base
from neutron.objects.db import api as obj_db_api
@add_metaclass(abc.ABCMeta)
@ -123,7 +124,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
if policy['action'] != models.ACCESS_SHARED:
return
target_tenant = policy['target_tenant']
db_obj = cls.get_by_id(context, policy['object_id'])
db_obj = cls.get_object(context, id=policy['object_id'])
if db_obj.tenant_id == target_tenant:
return
cls._validate_rbac_policy_delete(context=context,
@ -160,7 +161,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
# (hopefully) melded with this one.
if object_type != cls.rbac_db_model.object_type:
return
db_obj = cls.get_by_id(context.elevated(), policy['object_id'])
db_obj = cls.get_object(context.elevated(), id=policy['object_id'])
if event in (events.BEFORE_CREATE, events.BEFORE_UPDATE):
if (not context.is_admin and
db_obj['tenant_id'] != context.tenant_id):
@ -184,7 +185,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
def update_shared(self, is_shared_new, obj_id):
admin_context = self._context.elevated()
shared_prev = db_api.get_object(admin_context, self.rbac_db_model,
shared_prev = obj_db_api.get_object(admin_context, self.rbac_db_model,
object_id=obj_id, target_tenant='*',
action=models.ACCESS_SHARED)
is_shared_prev = bool(shared_prev)

View File

@ -33,7 +33,7 @@ def _get_qos_policy_cb(resource, policy_id, **kwargs):
)
return
policy = policy_object.QosPolicy.get_by_id(context, policy_id)
policy = policy_object.QosPolicy.get_object(context, id=policy_id)
return policy

View File

@ -61,7 +61,7 @@ class QoSPlugin(qos.QoSPluginBase):
policy.delete()
def _get_policy_obj(self, context, policy_id):
obj = policy_object.QosPolicy.get_by_id(context, policy_id)
obj = policy_object.QosPolicy.get_object(context, id=policy_id)
if obj is None:
raise n_exc.QosPolicyNotFound(policy_id=policy_id)
return obj
@ -132,8 +132,8 @@ class QoSPlugin(qos.QoSPluginBase):
with db_api.autonested_transaction(context.session):
# first, validate that we have access to the policy
self._get_policy_obj(context, policy_id)
rule = rule_object.QosBandwidthLimitRule.get_by_id(
context, rule_id)
rule = rule_object.QosBandwidthLimitRule.get_object(
context, id=rule_id)
if not rule:
raise n_exc.QosRuleNotFound(policy_id=policy_id, rule_id=rule_id)
return rule

View File

@ -62,7 +62,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase):
actual_port = {'id': mock.Mock(),
qos_consts.QOS_POLICY_ID: qos_policy_id}
qos_policy = mock.MagicMock()
self.policy_m.get_by_id = mock.Mock(return_value=qos_policy)
self.policy_m.get_object = mock.Mock(return_value=qos_policy)
self.core_extension.process_fields(
self.context, base_core.PORT,
{qos_consts.QOS_POLICY_ID: qos_policy_id},
@ -81,7 +81,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase):
self.policy_m.get_port_policy = mock.Mock(
return_value=old_qos_policy)
new_qos_policy = mock.MagicMock()
self.policy_m.get_by_id = mock.Mock(return_value=new_qos_policy)
self.policy_m.get_object = mock.Mock(return_value=new_qos_policy)
self.core_extension.process_fields(
self.context, base_core.PORT,
{qos_consts.QOS_POLICY_ID: qos_policy2_id},
@ -101,7 +101,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase):
self.policy_m.get_port_policy = mock.Mock(
return_value=old_qos_policy)
new_qos_policy = mock.MagicMock()
self.policy_m.get_by_id = mock.Mock(return_value=new_qos_policy)
self.policy_m.get_object = mock.Mock(return_value=new_qos_policy)
self.core_extension.process_fields(
self.context, base_core.PORT,
{qos_consts.QOS_POLICY_ID: None},
@ -120,7 +120,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase):
self.policy_m.get_network_policy = mock.Mock(
return_value=old_qos_policy)
new_qos_policy = mock.MagicMock()
self.policy_m.get_by_id = mock.Mock(return_value=new_qos_policy)
self.policy_m.get_object = mock.Mock(return_value=new_qos_policy)
self.core_extension.process_fields(
self.context, base_core.NETWORK,
{qos_consts.QOS_POLICY_ID: None},
@ -135,7 +135,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase):
actual_network = {'id': mock.Mock(),
qos_consts.QOS_POLICY_ID: qos_policy_id}
qos_policy = mock.MagicMock()
self.policy_m.get_by_id = mock.Mock(return_value=qos_policy)
self.policy_m.get_object = mock.Mock(return_value=qos_policy)
self.core_extension.process_fields(
self.context, base_core.NETWORK,
{qos_consts.QOS_POLICY_ID: qos_policy_id}, actual_network)
@ -153,7 +153,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase):
self.policy_m.get_network_policy = mock.Mock(
return_value=old_qos_policy)
new_qos_policy = mock.MagicMock()
self.policy_m.get_by_id = mock.Mock(return_value=new_qos_policy)
self.policy_m.get_object = mock.Mock(return_value=new_qos_policy)
self.core_extension.process_fields(
self.context, base_core.NETWORK,
{qos_consts.QOS_POLICY_ID: qos_policy_id}, actual_network)

View File

@ -13,7 +13,7 @@
import mock
from neutron.common import exceptions as n_exc
from neutron.db import api as db_api
from neutron.objects.db import api as db_api
from neutron.objects.qos import policy
from neutron.objects.qos import rule
from neutron.tests.unit.objects import test_base
@ -84,14 +84,14 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
**self.valid_field_filter)
self._validate_objects([self.db_obj], objs)
def test_get_by_id(self):
def test_get_object(self):
admin_context = self.context.elevated()
with mock.patch.object(db_api, 'get_object',
return_value=self.db_obj) as get_object_mock:
with mock.patch.object(self.context,
'elevated',
return_value=admin_context) as context_mock:
obj = self._test_class.get_by_id(self.context, id='fake_id')
obj = self._test_class.get_object(self.context, id='fake_id')
self.assertTrue(self._is_test_class(obj))
self.assertEqual(self.db_obj, test_base.get_obj_db_fields(obj))
context_mock.assert_called_once_with()
@ -221,19 +221,22 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
def test_synthetic_rule_fields(self):
policy_obj, rule_obj = self._create_test_policy_with_rule()
policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id)
policy_obj = policy.QosPolicy.get_object(self.context,
id=policy_obj.id)
self.assertEqual([rule_obj], policy_obj.rules)
def test_get_by_id_fetches_rules_non_lazily(self):
def test_get_object_fetches_rules_non_lazily(self):
policy_obj, rule_obj = self._create_test_policy_with_rule()
policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id)
policy_obj = policy.QosPolicy.get_object(self.context,
id=policy_obj.id)
primitive = policy_obj.obj_to_primitive()
self.assertNotEqual([], (primitive['versioned_object.data']['rules']))
def test_to_dict_returns_rules_as_dicts(self):
policy_obj, rule_obj = self._create_test_policy_with_rule()
policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id)
policy_obj = policy.QosPolicy.get_object(self.context,
id=policy_obj.id)
obj_dict = policy_obj.to_dict()
rule_dict = rule_obj.to_dict()

View File

@ -21,9 +21,9 @@ from oslo_versionedobjects import fields as obj_fields
from neutron.common import exceptions as n_exc
from neutron.common import utils as common_utils
from neutron import context
from neutron.db import api as db_api
from neutron.db import models_v2
from neutron.objects import base
from neutron.objects.db import api as obj_db_api
from neutron.tests import base as test_base
from neutron.tests import tools
@ -63,7 +63,7 @@ class FakeNeutronObjectNonStandardPrimaryKey(base.NeutronDbObject):
db_model = FakeModel
primary_key = 'weird_key'
primary_keys = ['weird_key']
fields = {
'weird_key': obj_fields.UUIDField(),
@ -74,6 +74,44 @@ class FakeNeutronObjectNonStandardPrimaryKey(base.NeutronDbObject):
synthetic_fields = ['field2']
@obj_base.VersionedObjectRegistry.register_if(False)
class FakeNeutronObjectCompositePrimaryKey(base.NeutronDbObject):
# Version 1.0: Initial version
VERSION = '1.0'
db_model = FakeModel
primary_keys = ['weird_key', 'field1']
fields = {
'weird_key': obj_fields.UUIDField(),
'field1': obj_fields.StringField(),
'field2': obj_fields.StringField()
}
synthetic_fields = ['field2']
@obj_base.VersionedObjectRegistry.register_if(False)
class FakeNeutronObjectCompositePrimaryKeyWithId(base.NeutronDbObject):
# Version 1.0: Initial version
VERSION = '1.0'
db_model = FakeModel
primary_keys = ['id', 'field1']
fields = {
'id': obj_fields.UUIDField(),
'field1': obj_fields.StringField(),
'field2': obj_fields.StringField()
}
synthetic_fields = ['field2']
fields_no_update = ['id']
FIELD_TYPE_VALUE_GENERATOR_MAP = {
obj_fields.BooleanField: tools.get_random_boolean,
obj_fields.IntegerField: tools.get_random_integer,
@ -112,6 +150,15 @@ class _BaseObjectTestCase(object):
fields[field] = generator()
return fields
@classmethod
def generate_object_keys(cls, obj_cls):
keys = {}
for field, field_obj in obj_cls.fields.items():
if field in obj_cls.primary_keys:
generator = FIELD_TYPE_VALUE_GENERATOR_MAP[type(field_obj)]
keys[field] = generator()
return keys
def get_updatable_fields(self, fields):
return base.get_updatable_fields(self._test_class, fields)
@ -122,23 +169,31 @@ class _BaseObjectTestCase(object):
class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
def test_get_by_id(self):
with mock.patch.object(db_api, 'get_object',
def test_get_object(self):
with mock.patch.object(obj_db_api, 'get_object',
return_value=self.db_obj) as get_object_mock:
obj = self._test_class.get_by_id(self.context, id='fake_id')
obj_keys = self.generate_object_keys(self._test_class)
obj = self._test_class.get_object(self.context, **obj_keys)
self.assertTrue(self._is_test_class(obj))
self.assertEqual(self.db_obj, get_obj_db_fields(obj))
get_object_mock.assert_called_once_with(
self.context, self._test_class.db_model,
**{self._test_class.primary_key: 'fake_id'})
self.context, self._test_class.db_model, **obj_keys)
def test_get_by_id_missing_object(self):
with mock.patch.object(db_api, 'get_object', return_value=None):
obj = self._test_class.get_by_id(self.context, id='fake_id')
def test_get_object_missing_object(self):
with mock.patch.object(obj_db_api, 'get_object', return_value=None):
obj_keys = self.generate_object_keys(self._test_class)
obj = self._test_class.get_object(self.context, **obj_keys)
self.assertIsNone(obj)
def test_get_object_missing_primary_key(self):
obj_keys = self.generate_object_keys(self._test_class)
obj_keys.popitem()
self.assertRaises(base.NeutronPrimaryKeyMissing,
self._test_class.get_object,
self.context, **obj_keys)
def test_get_objects(self):
with mock.patch.object(db_api, 'get_objects',
with mock.patch.object(obj_db_api, 'get_objects',
return_value=self.db_objs) as get_objects_mock:
objs = self._test_class.get_objects(self.context)
self._validate_objects(self.db_objs, objs)
@ -147,7 +202,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
def test_get_objects_valid_fields(self):
with mock.patch.object(
db_api, 'get_objects',
obj_db_api, 'get_objects',
return_value=[self.db_obj]) as get_objects_mock:
objs = self._test_class.get_objects(self.context,
@ -167,7 +222,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
filters = copy.copy(self.valid_field_filter)
filters[synthetic_fields[0]] = 'xxx'
with mock.patch.object(db_api, 'get_objects',
with mock.patch.object(obj_db_api, 'get_objects',
return_value=self.db_objs):
self.assertRaises(base.exceptions.InvalidInput,
self._test_class.get_objects, self.context,
@ -179,14 +234,14 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
self.skipTest('No synthetic fields found in test class %r' %
self._test_class)
with mock.patch.object(db_api, 'get_objects',
with mock.patch.object(obj_db_api, 'get_objects',
return_value=self.db_objs):
self.assertRaises(base.exceptions.InvalidInput,
self._test_class.get_objects, self.context,
**{synthetic_fields[0]: 'xxx'})
def test_get_objects_invalid_fields(self):
with mock.patch.object(db_api, 'get_objects',
with mock.patch.object(obj_db_api, 'get_objects',
return_value=self.db_objs):
self.assertRaises(base.exceptions.InvalidInput,
self._test_class.get_objects, self.context,
@ -206,7 +261,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
sorted(get_obj_db_fields(obj)))
def test_create(self):
with mock.patch.object(db_api, 'create_object',
with mock.patch.object(obj_db_api, 'create_object',
return_value=self.db_obj) as create_mock:
obj = self._test_class(self.context, **self.db_obj)
self._check_equal(obj, self.db_obj)
@ -216,7 +271,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
self.context, self._test_class.db_model, self.db_obj)
def test_create_updates_from_db_object(self):
with mock.patch.object(db_api, 'create_object',
with mock.patch.object(obj_db_api, 'create_object',
return_value=self.db_obj):
obj = self._test_class(self.context, **self.db_objs[1])
self._check_equal(obj, self.db_objs[1])
@ -224,21 +279,22 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
self._check_equal(obj, self.db_obj)
def test_create_duplicates(self):
with mock.patch.object(db_api, 'create_object',
with mock.patch.object(obj_db_api, 'create_object',
side_effect=obj_exc.DBDuplicateEntry):
obj = self._test_class(self.context, **self.db_obj)
self.assertRaises(base.NeutronDbObjectDuplicateEntry, obj.create)
@mock.patch.object(db_api, 'update_object')
@mock.patch.object(obj_db_api, 'update_object')
def test_update_no_changes(self, update_mock):
with mock.patch.object(base.NeutronDbObject,
'_get_changed_persistent_fields',
return_value={}):
obj = self._test_class(self.context, id=7777)
obj_keys = self.generate_object_keys(self._test_class)
obj = self._test_class(self.context, **obj_keys)
obj.update()
self.assertFalse(update_mock.called)
@mock.patch.object(db_api, 'update_object')
@mock.patch.object(obj_db_api, 'update_object')
def test_update_changes(self, update_mock):
fields_to_update = self.get_updatable_fields(self.db_obj)
with mock.patch.object(base.NeutronDbObject,
@ -248,9 +304,8 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
obj.update()
update_mock.assert_called_once_with(
self.context, self._test_class.db_model,
self.db_obj[self._test_class.primary_key],
fields_to_update,
key=self._test_class.primary_key)
**obj._get_composite_keys())
@mock.patch.object(base.NeutronDbObject,
'_get_changed_persistent_fields',
@ -265,7 +320,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
self.assertRaises(base.NeutronObjectUpdateForbidden, obj.update)
def test_update_updates_from_db_object(self):
with mock.patch.object(db_api, 'update_object',
with mock.patch.object(obj_db_api, 'update_object',
return_value=self.db_obj):
obj = self._test_class(self.context, **self.db_objs[1])
fields_to_update = self.get_updatable_fields(self.db_objs[1])
@ -275,7 +330,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
obj.update()
self._check_equal(obj, self.db_obj)
@mock.patch.object(db_api, 'delete_object')
@mock.patch.object(obj_db_api, 'delete_object')
def test_delete(self, delete_mock):
obj = self._test_class(self.context, **self.db_obj)
self._check_equal(obj, self.db_obj)
@ -283,8 +338,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
self._check_equal(obj, self.db_obj)
delete_mock.assert_called_once_with(
self.context, self._test_class.db_model,
self.db_obj[self._test_class.primary_key],
key=self._test_class.primary_key)
**obj._get_composite_keys())
@mock.patch(OBJECTS_BASE_OBJ_FROM_PRIMITIVE)
def test_clean_obj_from_primitive(self, get_prim_m):
@ -299,33 +353,44 @@ class BaseDbObjectNonStandardPrimaryKeyTestCase(BaseObjectIfaceTestCase):
_test_class = FakeNeutronObjectNonStandardPrimaryKey
class BaseDbObjectCompositePrimaryKeyTestCase(BaseObjectIfaceTestCase):
_test_class = FakeNeutronObjectCompositePrimaryKey
class BaseDbObjectCompositePrimaryKeyWithIdTestCase(BaseObjectIfaceTestCase):
_test_class = FakeNeutronObjectCompositePrimaryKeyWithId
class BaseDbObjectTestCase(_BaseObjectTestCase):
def _create_test_network(self):
# TODO(ihrachys): replace with network.create() once we get an object
# implementation for networks
self._network = db_api.create_object(self.context, models_v2.Network,
{'name': 'test-network1'})
self._network = obj_db_api.create_object(self.context,
models_v2.Network,
{'name': 'test-network1'})
def _create_test_port(self, network):
# TODO(ihrachys): replace with port.create() once we get an object
# implementation for ports
self._port = db_api.create_object(self.context, models_v2.Port,
{'tenant_id': 'fake_tenant_id',
'name': 'test-port1',
'network_id': network['id'],
'mac_address': 'fake_mac',
'admin_state_up': True,
'status': 'ACTIVE',
'device_id': 'fake_device',
'device_owner': 'fake_owner'})
self._port = obj_db_api.create_object(self.context, models_v2.Port,
{'tenant_id': 'fake_tenant_id',
'name': 'test-port1',
'network_id': network['id'],
'mac_address': 'fake_mac',
'admin_state_up': True,
'status': 'ACTIVE',
'device_id': 'fake_device',
'device_owner': 'fake_owner'})
def test_get_by_id_create_update_delete(self):
def test_get_object_create_update_delete(self):
obj = self._test_class(self.context, **self.db_obj)
obj.create()
new = self._test_class.get_by_id(self.context,
id=getattr(obj, obj.primary_key))
new = self._test_class.get_object(self.context,
**obj._get_composite_keys())
self.assertEqual(obj, new)
obj = new
@ -334,15 +399,15 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
setattr(obj, key, val)
obj.update()
new = self._test_class.get_by_id(self.context,
getattr(obj, obj.primary_key))
new = self._test_class.get_object(self.context,
**obj._get_composite_keys())
self.assertEqual(obj, new)
obj = new
new.delete()
new = self._test_class.get_by_id(self.context,
getattr(obj, obj.primary_key))
new = self._test_class.get_object(self.context,
**obj._get_composite_keys())
self.assertIsNone(new)
def test_update_non_existent_object_raises_not_found(self):
@ -389,10 +454,10 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
self.assertEqual(1, mock_commit.call_count)
@mock.patch(SQLALCHEMY_COMMIT)
def test_get_by_id_single_transaction(self, mock_commit):
def test_get_object_single_transaction(self, mock_commit):
obj = self._test_class(self.context, **self.db_obj)
obj.create()
obj = self._test_class.get_by_id(self.context,
getattr(obj, obj.primary_key))
obj = self._test_class.get_object(self.context,
**obj._get_composite_keys())
self.assertEqual(2, mock_commit.call_count)

View File

@ -19,11 +19,11 @@ import sqlalchemy as sa
from neutron.callbacks import events
from neutron.common import exceptions as n_exc
from neutron.db import api as db_api
from neutron.db import model_base
from neutron.db import rbac_db_models
from neutron.extensions import rbac as ext_rbac
from neutron.objects import base
from neutron.objects.db import api as obj_db_api
from neutron.objects import rbac_db
from neutron.tests.unit.objects import test_base
from neutron.tests.unit import testlib_api
@ -133,10 +133,10 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
mock_validate_rbac_update.assert_not_called()
@mock.patch.object(_test_class, 'validate_rbac_policy_update')
@mock.patch.object(_test_class, 'get_by_id',
@mock.patch.object(_test_class, 'get_object',
return_value={'tenant_id': 'tyrion_lannister'})
def test_validate_rbac_policy_change_allowed_for_admin_or_owner(
self, mock_get_by_id, mock_validate_update):
self, mock_get_object, mock_validate_update):
context = mock.Mock(is_admin=True, tenant_id='db_obj_owner_id')
self._rbac_policy_generate_change_events(
resource=None, trigger='dummy_trigger', context=context,
@ -147,10 +147,10 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
self.assertTrue(self._test_class.validate_rbac_policy_update.called)
@mock.patch.object(_test_class, 'validate_rbac_policy_update')
@mock.patch.object(_test_class, 'get_by_id',
@mock.patch.object(_test_class, 'get_object',
return_value={'tenant_id': 'king_beyond_the_wall'})
def test_validate_rbac_policy_change_forbidden_for_outsiders(
self, mock_get_by_id, mock_validate_update):
self, mock_get_object, mock_validate_update):
context = mock.Mock(is_admin=False, tenant_id='db_obj_owner_id')
self.assertRaises(
n_exc.InvalidInput,
@ -175,21 +175,21 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
self._test_validate_rbac_policy_delete_handles_policy(
{'action': 'unknown_action'})
@mock.patch.object(_test_class, 'get_by_id')
@mock.patch.object(_test_class, 'get_object')
def test_validate_rbac_policy_delete_skips_db_object_owner(self,
mock_get_by_id):
mock_get_object):
policy = {'action': rbac_db_models.ACCESS_SHARED,
'target_tenant': 'fake_tenant_id',
'object_id': 'fake_obj_id',
'tenant_id': 'fake_tenant_id'}
mock_get_by_id.return_value.tenant_id = policy['target_tenant']
mock_get_object.return_value.tenant_id = policy['target_tenant']
self._test_validate_rbac_policy_delete_handles_policy(policy)
@mock.patch.object(_test_class, 'get_by_id')
@mock.patch.object(_test_class, 'get_object')
@mock.patch.object(_test_class, 'get_bound_tenant_ids',
return_value='tenant_id_shared_with')
def test_validate_rbac_policy_delete_fails_single_tenant_and_in_use(
self, get_bound_tenant_ids_mock, mock_get_by_id):
self, get_bound_tenant_ids_mock, mock_get_object):
policy = {'action': rbac_db_models.ACCESS_SHARED,
'target_tenant': 'tenant_id_shared_with',
'tenant_id': 'object_owner_tenant_id',
@ -241,7 +241,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
'tenant_id': 'object_owner_tenant_id',
'object_id': 'fake_obj_id'}
context = mock.Mock()
with mock.patch.object(self._test_class, 'get_by_id'):
with mock.patch.object(self._test_class, 'get_object'):
self.assertRaises(
ext_rbac.RbacPolicyInUse,
self._test_class.validate_rbac_policy_delete,
@ -253,7 +253,8 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
policy=policy)
@mock.patch.object(_test_class, 'attach_rbac')
@mock.patch.object(db_api, 'get_object', return_value=['fake_rbac_policy'])
@mock.patch.object(obj_db_api, 'get_object',
return_value=['fake_rbac_policy'])
@mock.patch.object(_test_class, '_validate_rbac_policy_delete')
def test_update_shared_avoid_duplicate_update(
self, mock_validate_delete, get_object_mock, attach_rbac_mock):
@ -267,7 +268,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
self.assertFalse(attach_rbac_mock.called)
@mock.patch.object(_test_class, 'attach_rbac')
@mock.patch.object(db_api, 'get_object', return_value=[])
@mock.patch.object(obj_db_api, 'get_object', return_value=[])
@mock.patch.object(_test_class, '_validate_rbac_policy_delete')
def test_update_shared_wildcard(
self, mock_validate_delete, get_object_mock, attach_rbac_mock):
@ -283,7 +284,8 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase,
obj_id, test_neutron_obj._context.tenant_id)
@mock.patch.object(_test_class, 'attach_rbac')
@mock.patch.object(db_api, 'get_object', return_value=['fake_rbac_policy'])
@mock.patch.object(obj_db_api, 'get_object',
return_value=['fake_rbac_policy'])
@mock.patch.object(_test_class, '_validate_rbac_policy_delete')
def test_update_shared_remove_wildcard_sharing(
self, mock_validate_delete, get_object_mock, attach_rbac_mock):

View File

@ -33,10 +33,10 @@ class TestQosPlugin(base.BaseQosTestCase):
super(TestQosPlugin, self).setUp()
self.setup_coreplugin()
mock.patch('neutron.db.api.create_object').start()
mock.patch('neutron.db.api.update_object').start()
mock.patch('neutron.db.api.delete_object').start()
mock.patch('neutron.db.api.get_object').start()
mock.patch('neutron.objects.db.api.create_object').start()
mock.patch('neutron.objects.db.api.update_object').start()
mock.patch('neutron.objects.db.api.delete_object').start()
mock.patch('neutron.objects.db.api.get_object').start()
mock.patch(
'neutron.objects.qos.policy.QosPolicy.obj_load_attr').start()
@ -93,13 +93,13 @@ class TestQosPlugin(base.BaseQosTestCase):
self.ctxt, self.policy.id, {'policy': fields})
self._validate_notif_driver_params('update_policy')
@mock.patch('neutron.db.api.get_object', return_value=None)
@mock.patch('neutron.objects.db.api.get_object', return_value=None)
def test_delete_policy(self, *mocks):
self.qos_plugin.delete_policy(self.ctxt, self.policy.id)
self._validate_notif_driver_params('delete_policy')
def test_create_policy_rule(self):
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id',
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object',
return_value=self.policy):
self.qos_plugin.create_policy_bandwidth_limit_rule(
self.ctxt, self.policy.id, self.rule_data)
@ -108,7 +108,7 @@ class TestQosPlugin(base.BaseQosTestCase):
def test_update_policy_rule(self):
_policy = policy_object.QosPolicy(
self.ctxt, **self.policy_data['policy'])
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id',
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object',
return_value=_policy):
setattr(_policy, "rules", [self.rule])
self.qos_plugin.update_policy_bandwidth_limit_rule(
@ -118,7 +118,7 @@ class TestQosPlugin(base.BaseQosTestCase):
def test_update_policy_rule_bad_policy(self):
_policy = policy_object.QosPolicy(
self.ctxt, **self.policy_data['policy'])
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id',
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object',
return_value=_policy):
setattr(_policy, "rules", [])
self.assertRaises(
@ -130,7 +130,7 @@ class TestQosPlugin(base.BaseQosTestCase):
def test_delete_policy_rule(self):
_policy = policy_object.QosPolicy(
self.ctxt, **self.policy_data['policy'])
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id',
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object',
return_value=_policy):
setattr(_policy, "rules", [self.rule])
self.qos_plugin.delete_policy_bandwidth_limit_rule(
@ -140,7 +140,7 @@ class TestQosPlugin(base.BaseQosTestCase):
def test_delete_policy_rule_bad_policy(self):
_policy = policy_object.QosPolicy(
self.ctxt, **self.policy_data['policy'])
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id',
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object',
return_value=_policy):
setattr(_policy, "rules", [])
self.assertRaises(
@ -149,7 +149,7 @@ class TestQosPlugin(base.BaseQosTestCase):
self.ctxt, self.rule.id, _policy.id)
def test_get_policy_bandwidth_limit_rules_for_policy(self):
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id',
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object',
return_value=self.policy):
with mock.patch('neutron.objects.qos.rule.'
'QosBandwidthLimitRule.'
@ -160,7 +160,7 @@ class TestQosPlugin(base.BaseQosTestCase):
self.ctxt, qos_policy_id=self.policy.id)
def test_get_policy_bandwidth_limit_rules_for_policy_with_filters(self):
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id',
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object',
return_value=self.policy):
with mock.patch('neutron.objects.qos.rule.'
'QosBandwidthLimitRule.'
@ -174,7 +174,7 @@ class TestQosPlugin(base.BaseQosTestCase):
filter='filter_id')
def test_get_policy_for_nonexistent_policy(self):
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id',
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object',
return_value=None):
self.assertRaises(
n_exc.QosPolicyNotFound,
@ -182,7 +182,7 @@ class TestQosPlugin(base.BaseQosTestCase):
self.ctxt, self.policy.id)
def test_get_policy_bandwidth_limit_rule_for_nonexistent_policy(self):
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id',
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object',
return_value=None):
self.assertRaises(
n_exc.QosPolicyNotFound,
@ -190,7 +190,7 @@ class TestQosPlugin(base.BaseQosTestCase):
self.ctxt, self.rule.id, self.policy.id)
def test_get_policy_bandwidth_limit_rules_for_nonexistent_policy(self):
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id',
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object',
return_value=None):
self.assertRaises(
n_exc.QosPolicyNotFound,
@ -198,7 +198,7 @@ class TestQosPlugin(base.BaseQosTestCase):
self.ctxt, self.policy.id)
def test_create_policy_rule_for_nonexistent_policy(self):
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id',
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object',
return_value=None):
self.assertRaises(
n_exc.QosPolicyNotFound,
@ -206,7 +206,7 @@ class TestQosPlugin(base.BaseQosTestCase):
self.ctxt, self.policy.id, self.rule_data)
def test_update_policy_rule_for_nonexistent_policy(self):
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id',
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object',
return_value=None):
self.assertRaises(
n_exc.QosPolicyNotFound,
@ -214,7 +214,7 @@ class TestQosPlugin(base.BaseQosTestCase):
self.ctxt, self.rule.id, self.policy.id, self.rule_data)
def test_delete_policy_rule_for_nonexistent_policy(self):
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id',
with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object',
return_value=None):
self.assertRaises(
n_exc.QosPolicyNotFound,