From 8bccf9c2a5821be606cded7294efa00078bd3e4b Mon Sep 17 00:00:00 2001 From: rossella Date: Mon, 22 Feb 2016 13:07:04 +0100 Subject: [PATCH] Handle synthetic fields in NeutronDbObject Many objects in neutron have synthetic fields. Synthetic fields are filled using the data from fields of a different table. This patch makes it possible to handle synthetic fields directly in NeutronDbObject so that objects can share this implementation. Change-Id: Ia4695b1b10c0370c77b66f31588a56de332f462e Partial-bug: #1541928 --- neutron/objects/base.py | 128 ++++++++++--- neutron/tests/unit/objects/qos/test_policy.py | 8 + neutron/tests/unit/objects/test_base.py | 179 +++++++++++++----- 3 files changed, 242 insertions(+), 73 deletions(-) diff --git a/neutron/objects/base.py b/neutron/objects/base.py index 3abb6840e6a..8dfc8e8cec5 100644 --- a/neutron/objects/base.py +++ b/neutron/objects/base.py @@ -18,9 +18,11 @@ from neutron_lib import exceptions from oslo_db import exception as obj_exc from oslo_utils import reflection from oslo_versionedobjects import base as obj_base +from oslo_versionedobjects import fields as obj_fields import six from neutron._i18n import _ +from neutron.db import api as db_api from neutron.objects.db import api as obj_db_api @@ -52,6 +54,11 @@ class NeutronPrimaryKeyMissing(exceptions.BadRequest): ) +class NeutronSyntheticFieldMultipleForeignKeys(exceptions.NeutronException): + message = _("Synthetic fields %(fields)s shouldn't have more than one " + "foreign key") + + def get_updatable_fields(cls, fields): fields = fields.copy() for field in cls.fields_no_update: @@ -72,7 +79,24 @@ class NeutronObject(obj_base.VersionedObject, self.obj_set_defaults() def to_dict(self): - return dict(self.items()) + dict_ = dict(self.items()) + for field in self.synthetic_fields: + if field in dict_: + if isinstance(dict_[field], obj_fields.ListOfObjectsField): + dict_[field] = [obj.to_dict() for obj in dict_[field]] + elif isinstance(dict_[field], obj_fields.ObjectField): + dict_[field] = ( + dict_[field].to_dict() if dict_[field] else None) + return dict_ + + @classmethod + def is_synthetic(cls, field): + return field in cls.synthetic_fields + + @classmethod + def is_object_field(cls, field): + return (isinstance(cls.fields[field], obj_fields.ListOfObjectsField) or + isinstance(cls.fields[field], obj_fields.ObjectField)) @classmethod def clean_obj_from_primitive(cls, primitive, context=None): @@ -87,7 +111,7 @@ class NeutronObject(obj_base.VersionedObject, @classmethod def validate_filters(cls, **kwargs): bad_filters = [key for key in kwargs - if key not in cls.fields or key in cls.synthetic_fields] + if key not in cls.fields or cls.is_synthetic(key)] if bad_filters: bad_filters = ', '.join(bad_filters) msg = _("'%s' is not supported for filtering") % bad_filters @@ -127,6 +151,14 @@ class NeutronDbObject(NeutronObject): primary_keys = ['id'] + # this is a dict to store the association between the foreign key and the + # corresponding key in the main table, e.g. port extension have 'port_id' + # as foreign key, that is associated with the key 'id' of the table Port, + # so foreign_keys = {'port_id': 'id'}. The assumption is the association is + # the same for all object fields. E.g. all the port extension will use + # 'port_id' as key. + foreign_keys = {} + fields_no_update = [] # dict with name mapping: {'field_name_in_object': 'field_name_in_db'} @@ -136,9 +168,10 @@ class NeutronDbObject(NeutronObject): db_objs = [self.modify_fields_from_db(db_obj) for db_obj in objs] for field in self.fields: for db_obj in db_objs: - if field in db_obj: + if field in db_obj and not self.is_synthetic(field): setattr(self, field, db_obj[field]) break + self.load_synthetic_db_fields() self.obj_reset_changes() @classmethod @@ -181,6 +214,12 @@ class NeutronDbObject(NeutronObject): result[field] = result.pop(field_db) return result + @classmethod + def _load_object(cls, context, db_obj): + obj = cls(context) + obj.from_db_object(db_obj) + return obj + @classmethod def get_object(cls, context, **kwargs): """ @@ -196,22 +235,17 @@ class NeutronDbObject(NeutronObject): raise NeutronPrimaryKeyMissing(object_class=cls.__class__, missing_keys=missing_keys) - db_obj = obj_db_api.get_object(context, cls.db_model, **kwargs) - if db_obj: - obj = cls(context, **cls.modify_fields_from_db(db_obj)) - obj.obj_reset_changes() - return obj + with db_api.autonested_transaction(context.session): + db_obj = obj_db_api.get_object(context, cls.db_model, **kwargs) + if db_obj: + return cls._load_object(context, db_obj) @classmethod def get_objects(cls, context, **kwargs): cls.validate_filters(**kwargs) - db_objs = obj_db_api.get_objects(context, cls.db_model, **kwargs) - result = [] - for db_obj in db_objs: - obj = cls(context, **cls.modify_fields_from_db(db_obj)) - obj.obj_reset_changes() - result.append(obj) - return result + with db_api.autonested_transaction(context.session): + db_objs = obj_db_api.get_objects(context, cls.db_model, **kwargs) + return [cls._load_object(context, db_obj) for db_obj in db_objs] @classmethod def is_accessible(cls, context, db_obj): @@ -233,15 +267,55 @@ class NeutronDbObject(NeutronObject): return fields + def load_synthetic_db_fields(self): + """ + This method loads the synthetic fields that are stored in a different + table from the main object + + This method doesn't take care of loading synthetic fields that aren't + stored in the DB, e.g. 'shared' in rbac policy + """ + + # TODO(rossella_s) Find a way to handle ObjectFields with + # subclasses=True + for field in self.synthetic_fields: + try: + objclasses = obj_base.VersionedObjectRegistry.obj_classes( + ).get(self.fields[field].objname) + except AttributeError: + # NOTE(rossella_s) this is probably because this field is not + # an ObjectField + continue + if not objclasses: + # NOTE(rossella_s) some synthetic fields are not handled by + # this method, for example the ones that have subclasses, see + # QosRule + continue + objclass = objclasses[0] + if len(objclass.foreign_keys.keys()) > 1: + raise NeutronSyntheticFieldMultipleForeignKeys(field=field) + objs = objclass.get_objects( + self._context, **{ + k: getattr( + self, v) for k, v in objclass.foreign_keys.items()}) + if isinstance(self.fields[field], obj_fields.ObjectField): + setattr(self, field, objs[0] if objs else None) + else: + setattr(self, field, objs) + self.obj_reset_changes([field]) + def create(self): fields = self._get_changed_persistent_fields() - try: - db_obj = obj_db_api.create_object(self._context, self.db_model, - self.modify_fields_to_db(fields)) - except obj_exc.DBDuplicateEntry as db_exc: - raise NeutronDbObjectDuplicateEntry(object_class=self.__class__, - db_exception=db_exc) - self.from_db_object(db_obj) + with db_api.autonested_transaction(self._context.session): + try: + db_obj = obj_db_api.create_object( + self._context, self.db_model, + self.modify_fields_to_db(fields)) + except obj_exc.DBDuplicateEntry as db_exc: + raise NeutronDbObjectDuplicateEntry( + object_class=self.__class__, db_exception=db_exc) + + self.from_db_object(db_obj) def _get_composite_keys(self): keys = {} @@ -254,10 +328,12 @@ class NeutronDbObject(NeutronObject): updates = self._validate_changed_fields(updates) if updates: - db_obj = obj_db_api.update_object(self._context, self.db_model, - self.modify_fields_to_db(updates), - **self._get_composite_keys()) - self.from_db_object(self, db_obj) + with db_api.autonested_transaction(self._context.session): + db_obj = obj_db_api.update_object( + self._context, self.db_model, + self.modify_fields_to_db(updates), + **self._get_composite_keys()) + self.from_db_object(self, db_obj) def delete(self): obj_db_api.delete_object(self._context, self.db_model, diff --git a/neutron/tests/unit/objects/qos/test_policy.py b/neutron/tests/unit/objects/qos/test_policy.py index b6c687d0350..80a8b25ee28 100644 --- a/neutron/tests/unit/objects/qos/test_policy.py +++ b/neutron/tests/unit/objects/qos/test_policy.py @@ -105,6 +105,14 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase, def setUp(self): super(QosPolicyDbObjectTestCase, self).setUp() + self.db_qos_bandwidth_rules = [ + self.get_random_fields(rule.QosBandwidthLimitRule) + for _ in range(3)] + + self.model_map.update({ + rule.QosBandwidthLimitRule.db_model: self.db_qos_bandwidth_rules + }) + self._create_test_network() self._create_test_port(self._network) diff --git a/neutron/tests/unit/objects/test_base.py b/neutron/tests/unit/objects/test_base.py index 7c6c9756ce7..d2a8cf6222e 100644 --- a/neutron/tests/unit/objects/test_base.py +++ b/neutron/tests/unit/objects/test_base.py @@ -17,6 +17,7 @@ from oslo_db import exception as obj_exc from oslo_utils import uuidutils from oslo_versionedobjects import base as obj_base from oslo_versionedobjects import fields as obj_fields +from oslo_versionedobjects import fixture from neutron.common import exceptions as n_exc from neutron.common import utils as common_utils @@ -38,6 +39,45 @@ class FakeModel(object): pass +class ObjectFieldsModel(object): + def __init__(self, *args, **kwargs): + pass + + +@obj_base.VersionedObjectRegistry.register_if(False) +class FakeSmallNeutronObject(base.NeutronDbObject): + # Version 1.0: Initial version + VERSION = '1.0' + + db_model = ObjectFieldsModel + + primary_keys = ['field1'] + + foreign_keys = {'field1': 'id'} + + fields = { + 'field1': obj_fields.UUIDField(), + 'field2': obj_fields.StringField(), + } + + +@obj_base.VersionedObjectRegistry.register_if(False) +class FakeWeirdKeySmallNeutronObject(base.NeutronDbObject): + # Version 1.0: Initial version + VERSION = '1.0' + + db_model = ObjectFieldsModel + + primary_keys = ['field1'] + + foreign_keys = {'field1': 'weird_key'} + + fields = { + 'field1': obj_fields.UUIDField(), + 'field2': obj_fields.StringField(), + } + + @obj_base.VersionedObjectRegistry.register_if(False) class FakeNeutronObject(base.NeutronDbObject): # Version 1.0: Initial version @@ -48,14 +88,15 @@ class FakeNeutronObject(base.NeutronDbObject): fields = { 'id': obj_fields.UUIDField(), 'field1': obj_fields.StringField(), - 'field2': obj_fields.StringField() + 'obj_field': obj_fields.ObjectField('FakeSmallNeutronObject', + nullable=True) } primary_keys = ['id'] fields_no_update = ['field1'] - synthetic_fields = ['field2'] + synthetic_fields = ['obj_field'] @obj_base.VersionedObjectRegistry.register_if(False) @@ -70,10 +111,12 @@ class FakeNeutronObjectNonStandardPrimaryKey(base.NeutronDbObject): fields = { 'weird_key': obj_fields.UUIDField(), 'field1': obj_fields.StringField(), + 'obj_field': obj_fields.ListOfObjectsField( + 'FakeWeirdKeySmallNeutronObject'), 'field2': obj_fields.StringField() } - synthetic_fields = ['field2'] + synthetic_fields = ['obj_field', 'field2'] @obj_base.VersionedObjectRegistry.register_if(False) @@ -88,10 +131,11 @@ class FakeNeutronObjectCompositePrimaryKey(base.NeutronDbObject): fields = { 'weird_key': obj_fields.UUIDField(), 'field1': obj_fields.StringField(), - 'field2': obj_fields.StringField() + 'obj_field': obj_fields.ListOfObjectsField( + 'FakeWeirdKeySmallNeutronObject') } - synthetic_fields = ['field2'] + synthetic_fields = ['obj_field'] @obj_base.VersionedObjectRegistry.register_if(False) @@ -132,10 +176,10 @@ class FakeNeutronObjectCompositePrimaryKeyWithId(base.NeutronDbObject): fields = { 'id': obj_fields.UUIDField(), 'field1': obj_fields.StringField(), - 'field2': obj_fields.StringField() + 'obj_field': obj_fields.ListOfObjectsField('FakeSmallNeutronObject') } - synthetic_fields = ['field2'] + synthetic_fields = ['obj_field'] FIELD_TYPE_VALUE_GENERATOR_MAP = { @@ -143,6 +187,7 @@ FIELD_TYPE_VALUE_GENERATOR_MAP = { obj_fields.IntegerField: tools.get_random_integer, obj_fields.StringField: tools.get_random_string, obj_fields.UUIDField: uuidutils.generate_uuid, + obj_fields.ObjectField: lambda: None, obj_fields.ListOfObjectsField: lambda: [] } @@ -171,6 +216,14 @@ class _BaseObjectTestCase(object): if f not in self._test_class.synthetic_fields][0] self.valid_field_filter = {valid_field: self.obj_fields[0][valid_field]} + self.obj_registry = self.useFixture( + fixture.VersionedObjectRegistryFixture()) + self.obj_registry.register(FakeSmallNeutronObject) + self.obj_registry.register(FakeWeirdKeySmallNeutronObject) + synthetic_obj_fields = self.get_random_fields(FakeSmallNeutronObject) + self.model_map = { + self._test_class.db_model: self.db_objs, + ObjectFieldsModel: [synthetic_obj_fields]} @classmethod def get_random_fields(cls, obj_cls=None): @@ -198,19 +251,23 @@ class _BaseObjectTestCase(object): def _is_test_class(cls, obj): return isinstance(obj, cls._test_class) + def fake_get_objects(self, context, model, **kwargs): + return self.model_map[model] + class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): def test_get_object(self): with mock.patch.object(obj_db_api, 'get_object', return_value=self.db_obj) as get_object_mock: - obj_keys = self.generate_object_keys(self._test_class) - obj = self._test_class.get_object(self.context, **obj_keys) - self.assertTrue(self._is_test_class(obj)) - self.assertEqual(self.obj_fields[0], - get_obj_db_fields(obj)) - get_object_mock.assert_called_once_with( - self.context, self._test_class.db_model, **obj_keys) + with mock.patch.object(obj_db_api, 'get_objects', + side_effect=self.fake_get_objects): + obj_keys = self.generate_object_keys(self._test_class) + obj = self._test_class.get_object(self.context, **obj_keys) + self.assertTrue(self._is_test_class(obj)) + self.assertEqual(self.obj_fields[0], get_obj_db_fields(obj)) + get_object_mock.assert_called_once_with( + self.context, self._test_class.db_model, **obj_keys) def test_get_object_missing_object(self): with mock.patch.object(obj_db_api, 'get_object', return_value=None): @@ -225,13 +282,27 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self._test_class.get_object, self.context, **obj_keys) + def _get_synthetic_fields_get_objects_calls(self, db_objs): + mock_calls = [] + for db_obj in db_objs: + for field in self._test_class.synthetic_fields: + if self._test_class.is_object_field(field): + mock_calls.append( + mock.call( + self.context, FakeSmallNeutronObject.db_model, + **{FakeSmallNeutronObject.primary_keys[0]: db_obj[ + self._test_class.primary_keys[0]]})) + return mock_calls + def test_get_objects(self): with mock.patch.object(obj_db_api, 'get_objects', return_value=self.db_objs) as get_objects_mock: objs = self._test_class.get_objects(self.context) self._validate_objects(self.db_objs, objs) - get_objects_mock.assert_called_once_with( - self.context, self._test_class.db_model) + mock_calls = [mock.call(self.context, self._test_class.db_model)] + mock_calls.extend(self._get_synthetic_fields_get_objects_calls( + self.db_objs)) + get_objects_mock.assert_has_calls(mock_calls) def test_get_objects_valid_fields(self): with mock.patch.object( @@ -242,9 +313,11 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): **self.valid_field_filter) self._validate_objects([self.db_obj], objs) - get_objects_mock.assert_called_with( - self.context, self._test_class.db_model, - **self.valid_field_filter) + mock_calls = [mock.call(self.context, self._test_class.db_model, + **self.valid_field_filter)] + mock_calls.extend(self._get_synthetic_fields_get_objects_calls( + [self.db_obj])) + get_objects_mock.assert_has_calls(mock_calls) def test_get_objects_mixed_fields(self): synthetic_fields = self._test_class.synthetic_fields @@ -268,14 +341,14 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self._test_class) with mock.patch.object(obj_db_api, 'get_objects', - return_value=self.db_objs): + side_effect=self.fake_get_objects): self.assertRaises(base.exceptions.InvalidInput, self._test_class.get_objects, self.context, **{synthetic_fields[0]: 'xxx'}) def test_get_objects_invalid_fields(self): with mock.patch.object(obj_db_api, 'get_objects', - return_value=self.db_objs): + side_effect=self.fake_get_objects): self.assertRaises(base.exceptions.InvalidInput, self._test_class.get_objects, self.context, fake_field='xxx') @@ -297,20 +370,24 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): def test_create(self): with mock.patch.object(obj_db_api, 'create_object', return_value=self.db_obj) as create_mock: - obj = self._test_class(self.context, **self.obj_fields[0]) - self._check_equal(obj, self.obj_fields[0]) - obj.create() - self._check_equal(obj, self.obj_fields[0]) - create_mock.assert_called_once_with( - self.context, self._test_class.db_model, self.db_obj) + with mock.patch.object(obj_db_api, 'get_objects', + side_effect=self.fake_get_objects): + obj = self._test_class(self.context, **self.obj_fields[0]) + self._check_equal(obj, self.obj_fields[0]) + obj.create() + self._check_equal(obj, self.obj_fields[0]) + create_mock.assert_called_once_with( + self.context, self._test_class.db_model, self.db_obj) def test_create_updates_from_db_object(self): with mock.patch.object(obj_db_api, 'create_object', return_value=self.db_obj): - obj = self._test_class(self.context, **self.obj_fields[1]) - self._check_equal(obj, self.obj_fields[1]) - obj.create() - self._check_equal(obj, self.obj_fields[0]) + with mock.patch.object(obj_db_api, 'get_objects', + side_effect=self.fake_get_objects): + obj = self._test_class(self.context, **self.obj_fields[1]) + self._check_equal(obj, self.obj_fields[1]) + obj.create() + self._check_equal(obj, self.obj_fields[0]) def test_create_duplicates(self): with mock.patch.object(obj_db_api, 'create_object', @@ -338,12 +415,14 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): with mock.patch.object(base.NeutronDbObject, '_get_changed_persistent_fields', return_value=fields_to_update): - obj = self._test_class(self.context, **self.obj_fields[0]) - obj.update() - update_mock.assert_called_once_with( - self.context, self._test_class.db_model, - fields_to_update, - **obj._get_composite_keys()) + with mock.patch.object(obj_db_api, 'get_objects', + side_effect=self.fake_get_objects): + obj = self._test_class(self.context, **self.db_obj) + obj.update() + update_mock.assert_called_once_with( + self.context, self._test_class.db_model, + fields_to_update, + **obj._get_composite_keys()) @mock.patch.object(base.NeutronDbObject, '_get_changed_persistent_fields', @@ -360,16 +439,22 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): def test_update_updates_from_db_object(self): with mock.patch.object(obj_db_api, 'update_object', return_value=self.db_obj): - obj = self._test_class(self.context, **self.obj_fields[1]) - fields_to_update = self.get_updatable_fields(self.obj_fields[1]) - if not fields_to_update: - self.skipTest('No updatable fields found in test class %r' % - self._test_class) - with mock.patch.object(base.NeutronDbObject, - '_get_changed_persistent_fields', - return_value=fields_to_update): - obj.update() - self._check_equal(obj, self.obj_fields[0]) + with mock.patch.object(obj_db_api, 'get_objects', + side_effect=self.fake_get_objects): + obj = self._test_class(self.context, **self.obj_fields[1]) + fields_to_update = self.get_updatable_fields( + self.obj_fields[1]) + if not fields_to_update: + self.skipTest('No updatable fields found in test ' + 'class %r' % self._test_class) + with mock.patch.object(base.NeutronDbObject, + '_get_changed_persistent_fields', + return_value=fields_to_update): + with mock.patch.object( + obj_db_api, 'get_objects', + side_effect=self.fake_get_objects): + obj.update() + self._check_equal(obj, self.obj_fields[0]) @mock.patch.object(obj_db_api, 'delete_object') def test_delete(self, delete_mock):