From 4b227c3771eba1cbaa27c6c33829108981cd9b69 Mon Sep 17 00:00:00 2001 From: Artur Korzeniewski Date: Tue, 1 Mar 2016 12:07:15 +0100 Subject: [PATCH] Objects DB api: added composite key to handle multiple primary key Moving CRUD DB operation for objects from db/api.py to objects/db/api.py Renaming object get_by_id(id) to get_object(**kwargs) Many models in Neutron DB have complex primary keys, concatenated from a few properties. This patch adds ability to define multiple primary keys in NeutronDbObject, which are automatically evaluated into DB query when performing operations. Partial-Bug: #1541928 Change-Id: I0f63a62418db76415ddd40c30c778ff7541b93dc --- doc/source/devref/quality_of_service.rst | 2 +- neutron/core_extensions/qos.py | 2 +- neutron/db/api.py | 7 + neutron/objects/base.py | 60 +++++-- neutron/objects/db/__init__.py | 0 neutron/objects/db/api.py | 64 +++++++ neutron/objects/qos/policy.py | 14 +- neutron/objects/rbac_db.py | 7 +- .../qos/notification_drivers/message_queue.py | 2 +- neutron/services/qos/qos_plugin.py | 6 +- .../tests/unit/core_extensions/test_qos.py | 12 +- neutron/tests/unit/objects/qos/test_policy.py | 17 +- neutron/tests/unit/objects/test_base.py | 161 ++++++++++++------ neutron/tests/unit/objects/test_rbac_db.py | 30 ++-- .../unit/services/qos/test_qos_plugin.py | 36 ++-- 15 files changed, 297 insertions(+), 123 deletions(-) create mode 100644 neutron/objects/db/__init__.py create mode 100644 neutron/objects/db/api.py diff --git a/doc/source/devref/quality_of_service.rst b/doc/source/devref/quality_of_service.rst index 8e3e6d81d25..62548360042 100644 --- a/doc/source/devref/quality_of_service.rst +++ b/doc/source/devref/quality_of_service.rst @@ -144,7 +144,7 @@ effort. Every NeutronObject supports the following operations: -* get_by_id: returns specific object that is represented by the id passed as an +* get_object: returns specific object that is represented by the id passed as an argument. * get_objects: returns all objects of the type, potentially with a filter applied. diff --git a/neutron/core_extensions/qos.py b/neutron/core_extensions/qos.py index 72fb898836c..0ee323ff703 100644 --- a/neutron/core_extensions/qos.py +++ b/neutron/core_extensions/qos.py @@ -32,7 +32,7 @@ class QosCoreResourceExtension(base.CoreResourceExtension): return self._plugin_loaded def _get_policy_obj(self, context, policy_id): - obj = policy_object.QosPolicy.get_by_id(context, policy_id) + obj = policy_object.QosPolicy.get_object(context, id=policy_id) if obj is None: raise n_exc.QosPolicyNotFound(policy_id=policy_id) return obj diff --git a/neutron/db/api.py b/neutron/db/api.py index 01a9f8c7f6a..adb5e28bdcc 100644 --- a/neutron/db/api.py +++ b/neutron/db/api.py @@ -15,6 +15,7 @@ import contextlib +import debtcollector from oslo_config import cfg from oslo_db import api as oslo_db_api from oslo_db import exception as db_exc @@ -89,6 +90,7 @@ def autonested_transaction(sess): # Common database operation implementations +@debtcollector.removals.remove(message="This will be removed in the N cycle.") def get_object(context, model, **kwargs): with context.session.begin(subtransactions=True): return (common_db_mixin.model_query(context, model) @@ -96,6 +98,7 @@ def get_object(context, model, **kwargs): .first()) +@debtcollector.removals.remove(message="This will be removed in the N cycle.") def get_objects(context, model, **kwargs): with context.session.begin(subtransactions=True): return (common_db_mixin.model_query(context, model) @@ -103,6 +106,7 @@ def get_objects(context, model, **kwargs): .all()) +@debtcollector.removals.remove(message="This will be removed in the N cycle.") def create_object(context, model, values): with context.session.begin(subtransactions=True): if 'id' not in values and hasattr(model, 'id'): @@ -112,6 +116,7 @@ def create_object(context, model, values): return db_obj.__dict__ +@debtcollector.removals.remove(message="This will be removed in the N cycle.") def _safe_get_object(context, model, id, key='id'): db_obj = get_object(context, model, **{key: id}) if db_obj is None: @@ -119,6 +124,7 @@ def _safe_get_object(context, model, id, key='id'): return db_obj +@debtcollector.removals.remove(message="This will be removed in the N cycle.") def update_object(context, model, id, values, key=None): with context.session.begin(subtransactions=True): kwargs = {} @@ -131,6 +137,7 @@ def update_object(context, model, id, values, key=None): return db_obj.__dict__ +@debtcollector.removals.remove(message="This will be removed in the N cycle.") def delete_object(context, model, id, key=None): with context.session.begin(subtransactions=True): kwargs = {} diff --git a/neutron/objects/base.py b/neutron/objects/base.py index 6a5dfda82fa..b06605d212b 100644 --- a/neutron/objects/base.py +++ b/neutron/objects/base.py @@ -12,14 +12,14 @@ import abc +from neutron_lib import exceptions from oslo_db import exception as obj_exc from oslo_utils import reflection from oslo_versionedobjects import base as obj_base import six from neutron._i18n import _ -from neutron.common import exceptions -from neutron.db import api as db_api +from neutron.objects.db import api as obj_db_api class NeutronObjectUpdateForbidden(exceptions.NeutronException): @@ -38,6 +38,18 @@ class NeutronDbObjectDuplicateEntry(exceptions.Conflict): values=db_exception.value) +class NeutronPrimaryKeyMissing(exceptions.BadRequest): + message = _("For class %(object_type)s missing primary keys: " + "%(missing_keys)s") + + def __init__(self, object_class, missing_keys): + super(NeutronPrimaryKeyMissing, self).__init__( + object_type=reflection.get_class_name(object_class, + fully_qualified=False), + missing_keys=missing_keys + ) + + def get_updatable_fields(cls, fields): fields = fields.copy() for field in cls.fields_no_update: @@ -67,7 +79,7 @@ class NeutronObject(obj_base.VersionedObject, return obj @classmethod - def get_by_id(cls, context, id): + def get_object(cls, context, **kwargs): raise NotImplementedError() @classmethod @@ -99,7 +111,7 @@ class NeutronDbObject(NeutronObject): # should be overridden for all persistent objects db_model = None - primary_key = 'id' + primary_keys = ['id'] fields_no_update = [] @@ -112,9 +124,21 @@ class NeutronDbObject(NeutronObject): self.obj_reset_changes() @classmethod - def get_by_id(cls, context, id): - db_obj = db_api.get_object(context, cls.db_model, - **{cls.primary_key: id}) + def get_object(cls, context, **kwargs): + """ + This method fetches object from DB and convert it to versioned + object. + + :param context: + :param kwargs: multiple primary keys defined key=value pairs + :return: single object of NeutronDbObject class + """ + missing_keys = set(cls.primary_keys).difference(kwargs.keys()) + if missing_keys: + raise NeutronPrimaryKeyMissing(object_class=cls.__class__, + missing_keys=missing_keys) + + db_obj = obj_db_api.get_object(context, cls.db_model, **kwargs) if db_obj: obj = cls(context, **db_obj) obj.obj_reset_changes() @@ -123,7 +147,7 @@ class NeutronDbObject(NeutronObject): @classmethod def get_objects(cls, context, **kwargs): cls.validate_filters(**kwargs) - db_objs = db_api.get_objects(context, cls.db_model, **kwargs) + db_objs = obj_db_api.get_objects(context, cls.db_model, **kwargs) objs = [cls(context, **db_obj) for db_obj in db_objs] for obj in objs: obj.obj_reset_changes() @@ -156,24 +180,30 @@ class NeutronDbObject(NeutronObject): def create(self): fields = self._get_changed_persistent_fields() try: - db_obj = db_api.create_object(self._context, self.db_model, fields) + db_obj = obj_db_api.create_object(self._context, self.db_model, + fields) except obj_exc.DBDuplicateEntry as db_exc: raise NeutronDbObjectDuplicateEntry(object_class=self.__class__, db_exception=db_exc) self.from_db_object(db_obj) + def _get_composite_keys(self): + keys = {} + for key in self.primary_keys: + keys[key] = getattr(self, key) + return keys + def update(self): updates = self._get_changed_persistent_fields() updates = self._validate_changed_fields(updates) if updates: - db_obj = db_api.update_object(self._context, self.db_model, - getattr(self, self.primary_key), - updates, key=self.primary_key) + db_obj = obj_db_api.update_object(self._context, self.db_model, + updates, + **self._get_composite_keys()) self.from_db_object(self, db_obj) def delete(self): - db_api.delete_object(self._context, self.db_model, - getattr(self, self.primary_key), - key=self.primary_key) + obj_db_api.delete_object(self._context, self.db_model, + **self._get_composite_keys()) diff --git a/neutron/objects/db/__init__.py b/neutron/objects/db/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/neutron/objects/db/api.py b/neutron/objects/db/api.py new file mode 100644 index 00000000000..a8f57fc7888 --- /dev/null +++ b/neutron/objects/db/api.py @@ -0,0 +1,64 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from neutron_lib import exceptions as n_exc +from oslo_utils import uuidutils + +from neutron.db import common_db_mixin + + +# Common database operation implementations +def get_object(context, model, **kwargs): + with context.session.begin(subtransactions=True): + return (common_db_mixin.model_query(context, model) + .filter_by(**kwargs) + .first()) + + +def get_objects(context, model, **kwargs): + with context.session.begin(subtransactions=True): + return (common_db_mixin.model_query(context, model) + .filter_by(**kwargs) + .all()) + + +def create_object(context, model, values): + with context.session.begin(subtransactions=True): + if 'id' not in values and hasattr(model, 'id'): + values['id'] = uuidutils.generate_uuid() + db_obj = model(**values) + context.session.add(db_obj) + return db_obj.__dict__ + + +def _safe_get_object(context, model, **kwargs): + db_obj = get_object(context, model, **kwargs) + + if db_obj is None: + key = "".join(['%s:: %s ' % (key, value) for (key, value) + in kwargs.items()]) + raise n_exc.ObjectNotFound(id=key) + return db_obj + + +def update_object(context, model, values, **kwargs): + with context.session.begin(subtransactions=True): + db_obj = _safe_get_object(context, model, **kwargs) + db_obj.update(values) + db_obj.save(session=context.session) + return db_obj.__dict__ + + +def delete_object(context, model, **kwargs): + with context.session.begin(subtransactions=True): + db_obj = _safe_get_object(context, model, **kwargs) + context.session.delete(db_obj) diff --git a/neutron/objects/qos/policy.py b/neutron/objects/qos/policy.py index 985ea45e7f5..5aa9269d4c8 100644 --- a/neutron/objects/qos/policy.py +++ b/neutron/objects/qos/policy.py @@ -27,6 +27,7 @@ from neutron.db.qos import api as qos_db_api from neutron.db.qos import models as qos_db_model from neutron.db.rbac_db_models import QosPolicyRBAC from neutron.objects import base +from neutron.objects.db import api as obj_db_api from neutron.objects.qos import rule as rule_obj_impl from neutron.objects import rbac_db @@ -93,12 +94,13 @@ class QosPolicy(base.NeutronDbObject): rule_id=rule_id) @classmethod - def get_by_id(cls, context, id): + def get_object(cls, context, **kwargs): # We want to get the policy regardless of its tenant id. We'll make # sure the tenant has permission to access the policy later on. admin_context = context.elevated() with db_api.autonested_transaction(admin_context.session): - policy_obj = super(QosPolicy, cls).get_by_id(admin_context, id) + policy_obj = super(QosPolicy, cls).get_object(admin_context, + **kwargs) if (not policy_obj or not cls.is_accessible(context, policy_obj)): return @@ -125,9 +127,9 @@ class QosPolicy(base.NeutronDbObject): @classmethod def _get_object_policy(cls, context, model, **kwargs): with db_api.autonested_transaction(context.session): - binding_db_obj = db_api.get_object(context, model, **kwargs) + binding_db_obj = obj_db_api.get_object(context, model, **kwargs) if binding_db_obj: - return cls.get_by_id(context, binding_db_obj['policy_id']) + return cls.get_object(context, id=binding_db_obj['policy_id']) @classmethod def get_network_policy(cls, context, network_id): @@ -148,8 +150,8 @@ class QosPolicy(base.NeutronDbObject): def delete(self): with db_api.autonested_transaction(self._context.session): for object_type, model in self.binding_models.items(): - binding_db_obj = db_api.get_object(self._context, model, - policy_id=self.id) + binding_db_obj = obj_db_api.get_object(self._context, model, + policy_id=self.id) if binding_db_obj: raise exceptions.QosPolicyInUse( policy_id=self.id, diff --git a/neutron/objects/rbac_db.py b/neutron/objects/rbac_db.py index 11aac789482..c020def1fb5 100644 --- a/neutron/objects/rbac_db.py +++ b/neutron/objects/rbac_db.py @@ -28,6 +28,7 @@ from neutron.db import rbac_db_mixin from neutron.db import rbac_db_models as models from neutron.extensions import rbac as ext_rbac from neutron.objects import base +from neutron.objects.db import api as obj_db_api @add_metaclass(abc.ABCMeta) @@ -123,7 +124,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin, if policy['action'] != models.ACCESS_SHARED: return target_tenant = policy['target_tenant'] - db_obj = cls.get_by_id(context, policy['object_id']) + db_obj = cls.get_object(context, id=policy['object_id']) if db_obj.tenant_id == target_tenant: return cls._validate_rbac_policy_delete(context=context, @@ -160,7 +161,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin, # (hopefully) melded with this one. if object_type != cls.rbac_db_model.object_type: return - db_obj = cls.get_by_id(context.elevated(), policy['object_id']) + db_obj = cls.get_object(context.elevated(), id=policy['object_id']) if event in (events.BEFORE_CREATE, events.BEFORE_UPDATE): if (not context.is_admin and db_obj['tenant_id'] != context.tenant_id): @@ -184,7 +185,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin, def update_shared(self, is_shared_new, obj_id): admin_context = self._context.elevated() - shared_prev = db_api.get_object(admin_context, self.rbac_db_model, + shared_prev = obj_db_api.get_object(admin_context, self.rbac_db_model, object_id=obj_id, target_tenant='*', action=models.ACCESS_SHARED) is_shared_prev = bool(shared_prev) diff --git a/neutron/services/qos/notification_drivers/message_queue.py b/neutron/services/qos/notification_drivers/message_queue.py index 1ffeda0ec41..7e7f1e533eb 100644 --- a/neutron/services/qos/notification_drivers/message_queue.py +++ b/neutron/services/qos/notification_drivers/message_queue.py @@ -33,7 +33,7 @@ def _get_qos_policy_cb(resource, policy_id, **kwargs): ) return - policy = policy_object.QosPolicy.get_by_id(context, policy_id) + policy = policy_object.QosPolicy.get_object(context, id=policy_id) return policy diff --git a/neutron/services/qos/qos_plugin.py b/neutron/services/qos/qos_plugin.py index 1327a748099..624a330c57b 100644 --- a/neutron/services/qos/qos_plugin.py +++ b/neutron/services/qos/qos_plugin.py @@ -61,7 +61,7 @@ class QoSPlugin(qos.QoSPluginBase): policy.delete() def _get_policy_obj(self, context, policy_id): - obj = policy_object.QosPolicy.get_by_id(context, policy_id) + obj = policy_object.QosPolicy.get_object(context, id=policy_id) if obj is None: raise n_exc.QosPolicyNotFound(policy_id=policy_id) return obj @@ -132,8 +132,8 @@ class QoSPlugin(qos.QoSPluginBase): with db_api.autonested_transaction(context.session): # first, validate that we have access to the policy self._get_policy_obj(context, policy_id) - rule = rule_object.QosBandwidthLimitRule.get_by_id( - context, rule_id) + rule = rule_object.QosBandwidthLimitRule.get_object( + context, id=rule_id) if not rule: raise n_exc.QosRuleNotFound(policy_id=policy_id, rule_id=rule_id) return rule diff --git a/neutron/tests/unit/core_extensions/test_qos.py b/neutron/tests/unit/core_extensions/test_qos.py index 07ba6398cca..88aa6b86663 100644 --- a/neutron/tests/unit/core_extensions/test_qos.py +++ b/neutron/tests/unit/core_extensions/test_qos.py @@ -62,7 +62,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase): actual_port = {'id': mock.Mock(), qos_consts.QOS_POLICY_ID: qos_policy_id} qos_policy = mock.MagicMock() - self.policy_m.get_by_id = mock.Mock(return_value=qos_policy) + self.policy_m.get_object = mock.Mock(return_value=qos_policy) self.core_extension.process_fields( self.context, base_core.PORT, {qos_consts.QOS_POLICY_ID: qos_policy_id}, @@ -81,7 +81,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase): self.policy_m.get_port_policy = mock.Mock( return_value=old_qos_policy) new_qos_policy = mock.MagicMock() - self.policy_m.get_by_id = mock.Mock(return_value=new_qos_policy) + self.policy_m.get_object = mock.Mock(return_value=new_qos_policy) self.core_extension.process_fields( self.context, base_core.PORT, {qos_consts.QOS_POLICY_ID: qos_policy2_id}, @@ -101,7 +101,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase): self.policy_m.get_port_policy = mock.Mock( return_value=old_qos_policy) new_qos_policy = mock.MagicMock() - self.policy_m.get_by_id = mock.Mock(return_value=new_qos_policy) + self.policy_m.get_object = mock.Mock(return_value=new_qos_policy) self.core_extension.process_fields( self.context, base_core.PORT, {qos_consts.QOS_POLICY_ID: None}, @@ -120,7 +120,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase): self.policy_m.get_network_policy = mock.Mock( return_value=old_qos_policy) new_qos_policy = mock.MagicMock() - self.policy_m.get_by_id = mock.Mock(return_value=new_qos_policy) + self.policy_m.get_object = mock.Mock(return_value=new_qos_policy) self.core_extension.process_fields( self.context, base_core.NETWORK, {qos_consts.QOS_POLICY_ID: None}, @@ -135,7 +135,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase): actual_network = {'id': mock.Mock(), qos_consts.QOS_POLICY_ID: qos_policy_id} qos_policy = mock.MagicMock() - self.policy_m.get_by_id = mock.Mock(return_value=qos_policy) + self.policy_m.get_object = mock.Mock(return_value=qos_policy) self.core_extension.process_fields( self.context, base_core.NETWORK, {qos_consts.QOS_POLICY_ID: qos_policy_id}, actual_network) @@ -153,7 +153,7 @@ class QosCoreResourceExtensionTestCase(base.BaseTestCase): self.policy_m.get_network_policy = mock.Mock( return_value=old_qos_policy) new_qos_policy = mock.MagicMock() - self.policy_m.get_by_id = mock.Mock(return_value=new_qos_policy) + self.policy_m.get_object = mock.Mock(return_value=new_qos_policy) self.core_extension.process_fields( self.context, base_core.NETWORK, {qos_consts.QOS_POLICY_ID: qos_policy_id}, actual_network) diff --git a/neutron/tests/unit/objects/qos/test_policy.py b/neutron/tests/unit/objects/qos/test_policy.py index 17e9aa93d7d..ce7e6d86782 100644 --- a/neutron/tests/unit/objects/qos/test_policy.py +++ b/neutron/tests/unit/objects/qos/test_policy.py @@ -13,7 +13,7 @@ import mock from neutron.common import exceptions as n_exc -from neutron.db import api as db_api +from neutron.objects.db import api as db_api from neutron.objects.qos import policy from neutron.objects.qos import rule from neutron.tests.unit.objects import test_base @@ -84,14 +84,14 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): **self.valid_field_filter) self._validate_objects([self.db_obj], objs) - def test_get_by_id(self): + def test_get_object(self): admin_context = self.context.elevated() with mock.patch.object(db_api, 'get_object', return_value=self.db_obj) as get_object_mock: with mock.patch.object(self.context, 'elevated', return_value=admin_context) as context_mock: - obj = self._test_class.get_by_id(self.context, id='fake_id') + obj = self._test_class.get_object(self.context, id='fake_id') self.assertTrue(self._is_test_class(obj)) self.assertEqual(self.db_obj, test_base.get_obj_db_fields(obj)) context_mock.assert_called_once_with() @@ -221,19 +221,22 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase, def test_synthetic_rule_fields(self): policy_obj, rule_obj = self._create_test_policy_with_rule() - policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id) + policy_obj = policy.QosPolicy.get_object(self.context, + id=policy_obj.id) self.assertEqual([rule_obj], policy_obj.rules) - def test_get_by_id_fetches_rules_non_lazily(self): + def test_get_object_fetches_rules_non_lazily(self): policy_obj, rule_obj = self._create_test_policy_with_rule() - policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id) + policy_obj = policy.QosPolicy.get_object(self.context, + id=policy_obj.id) primitive = policy_obj.obj_to_primitive() self.assertNotEqual([], (primitive['versioned_object.data']['rules'])) def test_to_dict_returns_rules_as_dicts(self): policy_obj, rule_obj = self._create_test_policy_with_rule() - policy_obj = policy.QosPolicy.get_by_id(self.context, policy_obj.id) + policy_obj = policy.QosPolicy.get_object(self.context, + id=policy_obj.id) obj_dict = policy_obj.to_dict() rule_dict = rule_obj.to_dict() diff --git a/neutron/tests/unit/objects/test_base.py b/neutron/tests/unit/objects/test_base.py index 58ccc92e26b..cd257b0b50c 100644 --- a/neutron/tests/unit/objects/test_base.py +++ b/neutron/tests/unit/objects/test_base.py @@ -21,9 +21,9 @@ from oslo_versionedobjects import fields as obj_fields from neutron.common import exceptions as n_exc from neutron.common import utils as common_utils from neutron import context -from neutron.db import api as db_api from neutron.db import models_v2 from neutron.objects import base +from neutron.objects.db import api as obj_db_api from neutron.tests import base as test_base from neutron.tests import tools @@ -63,7 +63,7 @@ class FakeNeutronObjectNonStandardPrimaryKey(base.NeutronDbObject): db_model = FakeModel - primary_key = 'weird_key' + primary_keys = ['weird_key'] fields = { 'weird_key': obj_fields.UUIDField(), @@ -74,6 +74,44 @@ class FakeNeutronObjectNonStandardPrimaryKey(base.NeutronDbObject): synthetic_fields = ['field2'] +@obj_base.VersionedObjectRegistry.register_if(False) +class FakeNeutronObjectCompositePrimaryKey(base.NeutronDbObject): + # Version 1.0: Initial version + VERSION = '1.0' + + db_model = FakeModel + + primary_keys = ['weird_key', 'field1'] + + fields = { + 'weird_key': obj_fields.UUIDField(), + 'field1': obj_fields.StringField(), + 'field2': obj_fields.StringField() + } + + synthetic_fields = ['field2'] + + +@obj_base.VersionedObjectRegistry.register_if(False) +class FakeNeutronObjectCompositePrimaryKeyWithId(base.NeutronDbObject): + # Version 1.0: Initial version + VERSION = '1.0' + + db_model = FakeModel + + primary_keys = ['id', 'field1'] + + fields = { + 'id': obj_fields.UUIDField(), + 'field1': obj_fields.StringField(), + 'field2': obj_fields.StringField() + } + + synthetic_fields = ['field2'] + + fields_no_update = ['id'] + + FIELD_TYPE_VALUE_GENERATOR_MAP = { obj_fields.BooleanField: tools.get_random_boolean, obj_fields.IntegerField: tools.get_random_integer, @@ -112,6 +150,15 @@ class _BaseObjectTestCase(object): fields[field] = generator() return fields + @classmethod + def generate_object_keys(cls, obj_cls): + keys = {} + for field, field_obj in obj_cls.fields.items(): + if field in obj_cls.primary_keys: + generator = FIELD_TYPE_VALUE_GENERATOR_MAP[type(field_obj)] + keys[field] = generator() + return keys + def get_updatable_fields(self, fields): return base.get_updatable_fields(self._test_class, fields) @@ -122,23 +169,31 @@ class _BaseObjectTestCase(object): class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): - def test_get_by_id(self): - with mock.patch.object(db_api, 'get_object', + def test_get_object(self): + with mock.patch.object(obj_db_api, 'get_object', return_value=self.db_obj) as get_object_mock: - obj = self._test_class.get_by_id(self.context, id='fake_id') + obj_keys = self.generate_object_keys(self._test_class) + obj = self._test_class.get_object(self.context, **obj_keys) self.assertTrue(self._is_test_class(obj)) self.assertEqual(self.db_obj, get_obj_db_fields(obj)) get_object_mock.assert_called_once_with( - self.context, self._test_class.db_model, - **{self._test_class.primary_key: 'fake_id'}) + self.context, self._test_class.db_model, **obj_keys) - def test_get_by_id_missing_object(self): - with mock.patch.object(db_api, 'get_object', return_value=None): - obj = self._test_class.get_by_id(self.context, id='fake_id') + def test_get_object_missing_object(self): + with mock.patch.object(obj_db_api, 'get_object', return_value=None): + obj_keys = self.generate_object_keys(self._test_class) + obj = self._test_class.get_object(self.context, **obj_keys) self.assertIsNone(obj) + def test_get_object_missing_primary_key(self): + obj_keys = self.generate_object_keys(self._test_class) + obj_keys.popitem() + self.assertRaises(base.NeutronPrimaryKeyMissing, + self._test_class.get_object, + self.context, **obj_keys) + def test_get_objects(self): - with mock.patch.object(db_api, 'get_objects', + with mock.patch.object(obj_db_api, 'get_objects', return_value=self.db_objs) as get_objects_mock: objs = self._test_class.get_objects(self.context) self._validate_objects(self.db_objs, objs) @@ -147,7 +202,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): def test_get_objects_valid_fields(self): with mock.patch.object( - db_api, 'get_objects', + obj_db_api, 'get_objects', return_value=[self.db_obj]) as get_objects_mock: objs = self._test_class.get_objects(self.context, @@ -167,7 +222,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): filters = copy.copy(self.valid_field_filter) filters[synthetic_fields[0]] = 'xxx' - with mock.patch.object(db_api, 'get_objects', + with mock.patch.object(obj_db_api, 'get_objects', return_value=self.db_objs): self.assertRaises(base.exceptions.InvalidInput, self._test_class.get_objects, self.context, @@ -179,14 +234,14 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self.skipTest('No synthetic fields found in test class %r' % self._test_class) - with mock.patch.object(db_api, 'get_objects', + with mock.patch.object(obj_db_api, 'get_objects', return_value=self.db_objs): self.assertRaises(base.exceptions.InvalidInput, self._test_class.get_objects, self.context, **{synthetic_fields[0]: 'xxx'}) def test_get_objects_invalid_fields(self): - with mock.patch.object(db_api, 'get_objects', + with mock.patch.object(obj_db_api, 'get_objects', return_value=self.db_objs): self.assertRaises(base.exceptions.InvalidInput, self._test_class.get_objects, self.context, @@ -206,7 +261,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): sorted(get_obj_db_fields(obj))) def test_create(self): - with mock.patch.object(db_api, 'create_object', + with mock.patch.object(obj_db_api, 'create_object', return_value=self.db_obj) as create_mock: obj = self._test_class(self.context, **self.db_obj) self._check_equal(obj, self.db_obj) @@ -216,7 +271,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self.context, self._test_class.db_model, self.db_obj) def test_create_updates_from_db_object(self): - with mock.patch.object(db_api, 'create_object', + with mock.patch.object(obj_db_api, 'create_object', return_value=self.db_obj): obj = self._test_class(self.context, **self.db_objs[1]) self._check_equal(obj, self.db_objs[1]) @@ -224,21 +279,22 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self._check_equal(obj, self.db_obj) def test_create_duplicates(self): - with mock.patch.object(db_api, 'create_object', + with mock.patch.object(obj_db_api, 'create_object', side_effect=obj_exc.DBDuplicateEntry): obj = self._test_class(self.context, **self.db_obj) self.assertRaises(base.NeutronDbObjectDuplicateEntry, obj.create) - @mock.patch.object(db_api, 'update_object') + @mock.patch.object(obj_db_api, 'update_object') def test_update_no_changes(self, update_mock): with mock.patch.object(base.NeutronDbObject, '_get_changed_persistent_fields', return_value={}): - obj = self._test_class(self.context, id=7777) + obj_keys = self.generate_object_keys(self._test_class) + obj = self._test_class(self.context, **obj_keys) obj.update() self.assertFalse(update_mock.called) - @mock.patch.object(db_api, 'update_object') + @mock.patch.object(obj_db_api, 'update_object') def test_update_changes(self, update_mock): fields_to_update = self.get_updatable_fields(self.db_obj) with mock.patch.object(base.NeutronDbObject, @@ -248,9 +304,8 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): obj.update() update_mock.assert_called_once_with( self.context, self._test_class.db_model, - self.db_obj[self._test_class.primary_key], fields_to_update, - key=self._test_class.primary_key) + **obj._get_composite_keys()) @mock.patch.object(base.NeutronDbObject, '_get_changed_persistent_fields', @@ -265,7 +320,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self.assertRaises(base.NeutronObjectUpdateForbidden, obj.update) def test_update_updates_from_db_object(self): - with mock.patch.object(db_api, 'update_object', + with mock.patch.object(obj_db_api, 'update_object', return_value=self.db_obj): obj = self._test_class(self.context, **self.db_objs[1]) fields_to_update = self.get_updatable_fields(self.db_objs[1]) @@ -275,7 +330,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): obj.update() self._check_equal(obj, self.db_obj) - @mock.patch.object(db_api, 'delete_object') + @mock.patch.object(obj_db_api, 'delete_object') def test_delete(self, delete_mock): obj = self._test_class(self.context, **self.db_obj) self._check_equal(obj, self.db_obj) @@ -283,8 +338,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self._check_equal(obj, self.db_obj) delete_mock.assert_called_once_with( self.context, self._test_class.db_model, - self.db_obj[self._test_class.primary_key], - key=self._test_class.primary_key) + **obj._get_composite_keys()) @mock.patch(OBJECTS_BASE_OBJ_FROM_PRIMITIVE) def test_clean_obj_from_primitive(self, get_prim_m): @@ -299,33 +353,44 @@ class BaseDbObjectNonStandardPrimaryKeyTestCase(BaseObjectIfaceTestCase): _test_class = FakeNeutronObjectNonStandardPrimaryKey +class BaseDbObjectCompositePrimaryKeyTestCase(BaseObjectIfaceTestCase): + + _test_class = FakeNeutronObjectCompositePrimaryKey + + +class BaseDbObjectCompositePrimaryKeyWithIdTestCase(BaseObjectIfaceTestCase): + + _test_class = FakeNeutronObjectCompositePrimaryKeyWithId + + class BaseDbObjectTestCase(_BaseObjectTestCase): def _create_test_network(self): # TODO(ihrachys): replace with network.create() once we get an object # implementation for networks - self._network = db_api.create_object(self.context, models_v2.Network, - {'name': 'test-network1'}) + self._network = obj_db_api.create_object(self.context, + models_v2.Network, + {'name': 'test-network1'}) def _create_test_port(self, network): # TODO(ihrachys): replace with port.create() once we get an object # implementation for ports - self._port = db_api.create_object(self.context, models_v2.Port, - {'tenant_id': 'fake_tenant_id', - 'name': 'test-port1', - 'network_id': network['id'], - 'mac_address': 'fake_mac', - 'admin_state_up': True, - 'status': 'ACTIVE', - 'device_id': 'fake_device', - 'device_owner': 'fake_owner'}) + self._port = obj_db_api.create_object(self.context, models_v2.Port, + {'tenant_id': 'fake_tenant_id', + 'name': 'test-port1', + 'network_id': network['id'], + 'mac_address': 'fake_mac', + 'admin_state_up': True, + 'status': 'ACTIVE', + 'device_id': 'fake_device', + 'device_owner': 'fake_owner'}) - def test_get_by_id_create_update_delete(self): + def test_get_object_create_update_delete(self): obj = self._test_class(self.context, **self.db_obj) obj.create() - new = self._test_class.get_by_id(self.context, - id=getattr(obj, obj.primary_key)) + new = self._test_class.get_object(self.context, + **obj._get_composite_keys()) self.assertEqual(obj, new) obj = new @@ -334,15 +399,15 @@ class BaseDbObjectTestCase(_BaseObjectTestCase): setattr(obj, key, val) obj.update() - new = self._test_class.get_by_id(self.context, - getattr(obj, obj.primary_key)) + new = self._test_class.get_object(self.context, + **obj._get_composite_keys()) self.assertEqual(obj, new) obj = new new.delete() - new = self._test_class.get_by_id(self.context, - getattr(obj, obj.primary_key)) + new = self._test_class.get_object(self.context, + **obj._get_composite_keys()) self.assertIsNone(new) def test_update_non_existent_object_raises_not_found(self): @@ -389,10 +454,10 @@ class BaseDbObjectTestCase(_BaseObjectTestCase): self.assertEqual(1, mock_commit.call_count) @mock.patch(SQLALCHEMY_COMMIT) - def test_get_by_id_single_transaction(self, mock_commit): + def test_get_object_single_transaction(self, mock_commit): obj = self._test_class(self.context, **self.db_obj) obj.create() - obj = self._test_class.get_by_id(self.context, - getattr(obj, obj.primary_key)) + obj = self._test_class.get_object(self.context, + **obj._get_composite_keys()) self.assertEqual(2, mock_commit.call_count) diff --git a/neutron/tests/unit/objects/test_rbac_db.py b/neutron/tests/unit/objects/test_rbac_db.py index 8e60dcfa876..17b502a3f24 100644 --- a/neutron/tests/unit/objects/test_rbac_db.py +++ b/neutron/tests/unit/objects/test_rbac_db.py @@ -19,11 +19,11 @@ import sqlalchemy as sa from neutron.callbacks import events from neutron.common import exceptions as n_exc -from neutron.db import api as db_api from neutron.db import model_base from neutron.db import rbac_db_models from neutron.extensions import rbac as ext_rbac from neutron.objects import base +from neutron.objects.db import api as obj_db_api from neutron.objects import rbac_db from neutron.tests.unit.objects import test_base from neutron.tests.unit import testlib_api @@ -133,10 +133,10 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, mock_validate_rbac_update.assert_not_called() @mock.patch.object(_test_class, 'validate_rbac_policy_update') - @mock.patch.object(_test_class, 'get_by_id', + @mock.patch.object(_test_class, 'get_object', return_value={'tenant_id': 'tyrion_lannister'}) def test_validate_rbac_policy_change_allowed_for_admin_or_owner( - self, mock_get_by_id, mock_validate_update): + self, mock_get_object, mock_validate_update): context = mock.Mock(is_admin=True, tenant_id='db_obj_owner_id') self._rbac_policy_generate_change_events( resource=None, trigger='dummy_trigger', context=context, @@ -147,10 +147,10 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, self.assertTrue(self._test_class.validate_rbac_policy_update.called) @mock.patch.object(_test_class, 'validate_rbac_policy_update') - @mock.patch.object(_test_class, 'get_by_id', + @mock.patch.object(_test_class, 'get_object', return_value={'tenant_id': 'king_beyond_the_wall'}) def test_validate_rbac_policy_change_forbidden_for_outsiders( - self, mock_get_by_id, mock_validate_update): + self, mock_get_object, mock_validate_update): context = mock.Mock(is_admin=False, tenant_id='db_obj_owner_id') self.assertRaises( n_exc.InvalidInput, @@ -175,21 +175,21 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, self._test_validate_rbac_policy_delete_handles_policy( {'action': 'unknown_action'}) - @mock.patch.object(_test_class, 'get_by_id') + @mock.patch.object(_test_class, 'get_object') def test_validate_rbac_policy_delete_skips_db_object_owner(self, - mock_get_by_id): + mock_get_object): policy = {'action': rbac_db_models.ACCESS_SHARED, 'target_tenant': 'fake_tenant_id', 'object_id': 'fake_obj_id', 'tenant_id': 'fake_tenant_id'} - mock_get_by_id.return_value.tenant_id = policy['target_tenant'] + mock_get_object.return_value.tenant_id = policy['target_tenant'] self._test_validate_rbac_policy_delete_handles_policy(policy) - @mock.patch.object(_test_class, 'get_by_id') + @mock.patch.object(_test_class, 'get_object') @mock.patch.object(_test_class, 'get_bound_tenant_ids', return_value='tenant_id_shared_with') def test_validate_rbac_policy_delete_fails_single_tenant_and_in_use( - self, get_bound_tenant_ids_mock, mock_get_by_id): + self, get_bound_tenant_ids_mock, mock_get_object): policy = {'action': rbac_db_models.ACCESS_SHARED, 'target_tenant': 'tenant_id_shared_with', 'tenant_id': 'object_owner_tenant_id', @@ -241,7 +241,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, 'tenant_id': 'object_owner_tenant_id', 'object_id': 'fake_obj_id'} context = mock.Mock() - with mock.patch.object(self._test_class, 'get_by_id'): + with mock.patch.object(self._test_class, 'get_object'): self.assertRaises( ext_rbac.RbacPolicyInUse, self._test_class.validate_rbac_policy_delete, @@ -253,7 +253,8 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, policy=policy) @mock.patch.object(_test_class, 'attach_rbac') - @mock.patch.object(db_api, 'get_object', return_value=['fake_rbac_policy']) + @mock.patch.object(obj_db_api, 'get_object', + return_value=['fake_rbac_policy']) @mock.patch.object(_test_class, '_validate_rbac_policy_delete') def test_update_shared_avoid_duplicate_update( self, mock_validate_delete, get_object_mock, attach_rbac_mock): @@ -267,7 +268,7 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, self.assertFalse(attach_rbac_mock.called) @mock.patch.object(_test_class, 'attach_rbac') - @mock.patch.object(db_api, 'get_object', return_value=[]) + @mock.patch.object(obj_db_api, 'get_object', return_value=[]) @mock.patch.object(_test_class, '_validate_rbac_policy_delete') def test_update_shared_wildcard( self, mock_validate_delete, get_object_mock, attach_rbac_mock): @@ -283,7 +284,8 @@ class RbacNeutronDbObjectTestCase(test_base.BaseObjectIfaceTestCase, obj_id, test_neutron_obj._context.tenant_id) @mock.patch.object(_test_class, 'attach_rbac') - @mock.patch.object(db_api, 'get_object', return_value=['fake_rbac_policy']) + @mock.patch.object(obj_db_api, 'get_object', + return_value=['fake_rbac_policy']) @mock.patch.object(_test_class, '_validate_rbac_policy_delete') def test_update_shared_remove_wildcard_sharing( self, mock_validate_delete, get_object_mock, attach_rbac_mock): diff --git a/neutron/tests/unit/services/qos/test_qos_plugin.py b/neutron/tests/unit/services/qos/test_qos_plugin.py index 2bdb9ec8832..40c70c4cbb9 100644 --- a/neutron/tests/unit/services/qos/test_qos_plugin.py +++ b/neutron/tests/unit/services/qos/test_qos_plugin.py @@ -33,10 +33,10 @@ class TestQosPlugin(base.BaseQosTestCase): super(TestQosPlugin, self).setUp() self.setup_coreplugin() - mock.patch('neutron.db.api.create_object').start() - mock.patch('neutron.db.api.update_object').start() - mock.patch('neutron.db.api.delete_object').start() - mock.patch('neutron.db.api.get_object').start() + mock.patch('neutron.objects.db.api.create_object').start() + mock.patch('neutron.objects.db.api.update_object').start() + mock.patch('neutron.objects.db.api.delete_object').start() + mock.patch('neutron.objects.db.api.get_object').start() mock.patch( 'neutron.objects.qos.policy.QosPolicy.obj_load_attr').start() @@ -93,13 +93,13 @@ class TestQosPlugin(base.BaseQosTestCase): self.ctxt, self.policy.id, {'policy': fields}) self._validate_notif_driver_params('update_policy') - @mock.patch('neutron.db.api.get_object', return_value=None) + @mock.patch('neutron.objects.db.api.get_object', return_value=None) def test_delete_policy(self, *mocks): self.qos_plugin.delete_policy(self.ctxt, self.policy.id) self._validate_notif_driver_params('delete_policy') def test_create_policy_rule(self): - with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id', + with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object', return_value=self.policy): self.qos_plugin.create_policy_bandwidth_limit_rule( self.ctxt, self.policy.id, self.rule_data) @@ -108,7 +108,7 @@ class TestQosPlugin(base.BaseQosTestCase): def test_update_policy_rule(self): _policy = policy_object.QosPolicy( self.ctxt, **self.policy_data['policy']) - with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id', + with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object', return_value=_policy): setattr(_policy, "rules", [self.rule]) self.qos_plugin.update_policy_bandwidth_limit_rule( @@ -118,7 +118,7 @@ class TestQosPlugin(base.BaseQosTestCase): def test_update_policy_rule_bad_policy(self): _policy = policy_object.QosPolicy( self.ctxt, **self.policy_data['policy']) - with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id', + with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object', return_value=_policy): setattr(_policy, "rules", []) self.assertRaises( @@ -130,7 +130,7 @@ class TestQosPlugin(base.BaseQosTestCase): def test_delete_policy_rule(self): _policy = policy_object.QosPolicy( self.ctxt, **self.policy_data['policy']) - with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id', + with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object', return_value=_policy): setattr(_policy, "rules", [self.rule]) self.qos_plugin.delete_policy_bandwidth_limit_rule( @@ -140,7 +140,7 @@ class TestQosPlugin(base.BaseQosTestCase): def test_delete_policy_rule_bad_policy(self): _policy = policy_object.QosPolicy( self.ctxt, **self.policy_data['policy']) - with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id', + with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object', return_value=_policy): setattr(_policy, "rules", []) self.assertRaises( @@ -149,7 +149,7 @@ class TestQosPlugin(base.BaseQosTestCase): self.ctxt, self.rule.id, _policy.id) def test_get_policy_bandwidth_limit_rules_for_policy(self): - with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id', + with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object', return_value=self.policy): with mock.patch('neutron.objects.qos.rule.' 'QosBandwidthLimitRule.' @@ -160,7 +160,7 @@ class TestQosPlugin(base.BaseQosTestCase): self.ctxt, qos_policy_id=self.policy.id) def test_get_policy_bandwidth_limit_rules_for_policy_with_filters(self): - with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id', + with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object', return_value=self.policy): with mock.patch('neutron.objects.qos.rule.' 'QosBandwidthLimitRule.' @@ -174,7 +174,7 @@ class TestQosPlugin(base.BaseQosTestCase): filter='filter_id') def test_get_policy_for_nonexistent_policy(self): - with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id', + with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object', return_value=None): self.assertRaises( n_exc.QosPolicyNotFound, @@ -182,7 +182,7 @@ class TestQosPlugin(base.BaseQosTestCase): self.ctxt, self.policy.id) def test_get_policy_bandwidth_limit_rule_for_nonexistent_policy(self): - with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id', + with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object', return_value=None): self.assertRaises( n_exc.QosPolicyNotFound, @@ -190,7 +190,7 @@ class TestQosPlugin(base.BaseQosTestCase): self.ctxt, self.rule.id, self.policy.id) def test_get_policy_bandwidth_limit_rules_for_nonexistent_policy(self): - with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id', + with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object', return_value=None): self.assertRaises( n_exc.QosPolicyNotFound, @@ -198,7 +198,7 @@ class TestQosPlugin(base.BaseQosTestCase): self.ctxt, self.policy.id) def test_create_policy_rule_for_nonexistent_policy(self): - with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id', + with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object', return_value=None): self.assertRaises( n_exc.QosPolicyNotFound, @@ -206,7 +206,7 @@ class TestQosPlugin(base.BaseQosTestCase): self.ctxt, self.policy.id, self.rule_data) def test_update_policy_rule_for_nonexistent_policy(self): - with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id', + with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object', return_value=None): self.assertRaises( n_exc.QosPolicyNotFound, @@ -214,7 +214,7 @@ class TestQosPlugin(base.BaseQosTestCase): self.ctxt, self.rule.id, self.policy.id, self.rule_data) def test_delete_policy_rule_for_nonexistent_policy(self): - with mock.patch('neutron.objects.qos.policy.QosPolicy.get_by_id', + with mock.patch('neutron.objects.qos.policy.QosPolicy.get_object', return_value=None): self.assertRaises( n_exc.QosPolicyNotFound,