diff --git a/neutron/db/api.py b/neutron/db/api.py index fdaaca1f431..7f65e6bd1ca 100644 --- a/neutron/db/api.py +++ b/neutron/db/api.py @@ -114,22 +114,30 @@ def create_object(context, model, values): return db_obj.__dict__ -def _safe_get_object(context, model, id): - db_obj = get_object(context, model, id=id) +def _safe_get_object(context, model, id, key='id'): + db_obj = get_object(context, model, **{key: id}) if db_obj is None: raise n_exc.ObjectNotFound(id=id) return db_obj -def update_object(context, model, id, values): +def update_object(context, model, id, values, key=None): with context.session.begin(subtransactions=True): - db_obj = _safe_get_object(context, model, id) + kwargs = {} + if key: + kwargs['key'] = key + db_obj = _safe_get_object(context, model, id, + **kwargs) db_obj.update(values) db_obj.save(session=context.session) return db_obj.__dict__ -def delete_object(context, model, id): +def delete_object(context, model, id, key=None): with context.session.begin(subtransactions=True): - db_obj = _safe_get_object(context, model, id) + kwargs = {} + if key: + kwargs['key'] = key + db_obj = _safe_get_object(context, model, id, + **kwargs) context.session.delete(db_obj) diff --git a/neutron/objects/base.py b/neutron/objects/base.py index ac16d4152e4..1d587499e28 100644 --- a/neutron/objects/base.py +++ b/neutron/objects/base.py @@ -99,6 +99,8 @@ class NeutronDbObject(NeutronObject): # should be overridden for all persistent objects db_model = None + primary_key = 'id' + fields_no_update = [] def from_db_object(self, *objs): @@ -111,7 +113,8 @@ class NeutronDbObject(NeutronObject): @classmethod def get_by_id(cls, context, id): - db_obj = db_api.get_object(context, cls.db_model, id=id) + db_obj = db_api.get_object(context, cls.db_model, + **{cls.primary_key: id}) if db_obj: obj = cls(context, **db_obj) obj.obj_reset_changes() @@ -161,8 +164,11 @@ class NeutronDbObject(NeutronObject): if updates: db_obj = db_api.update_object(self._context, self.db_model, - self.id, updates) + getattr(self, self.primary_key), + updates, key=self.primary_key) self.from_db_object(self, db_obj) def delete(self): - db_api.delete_object(self._context, self.db_model, self.id) + db_api.delete_object(self._context, self.db_model, + getattr(self, self.primary_key), + key=self.primary_key) diff --git a/neutron/tests/unit/objects/test_base.py b/neutron/tests/unit/objects/test_base.py index 34743ac5f94..af6e279ff84 100644 --- a/neutron/tests/unit/objects/test_base.py +++ b/neutron/tests/unit/objects/test_base.py @@ -54,6 +54,24 @@ class FakeNeutronObject(base.NeutronDbObject): synthetic_fields = ['field2'] +@obj_base.VersionedObjectRegistry.register_if(False) +class FakeNeutronObjectNonStandardPrimaryKey(base.NeutronDbObject): + # Version 1.0: Initial version + VERSION = '1.0' + + db_model = FakeModel + + primary_key = 'weird_key' + + fields = { + 'weird_key': obj_fields.UUIDField(), + 'field1': obj_fields.StringField(), + 'field2': obj_fields.StringField() + } + + synthetic_fields = ['field2'] + + FIELD_TYPE_VALUE_GENERATOR_MAP = { obj_fields.BooleanField: tools.get_random_boolean, obj_fields.IntegerField: tools.get_random_integer, @@ -109,7 +127,8 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self.assertTrue(self._is_test_class(obj)) self.assertEqual(self.db_obj, get_obj_db_fields(obj)) get_object_mock.assert_called_once_with( - self.context, self._test_class.db_model, id='fake_id') + self.context, self._test_class.db_model, + **{self._test_class.primary_key: 'fake_id'}) def test_get_by_id_missing_object(self): with mock.patch.object(db_api, 'get_object', return_value=None): @@ -227,7 +246,9 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): obj.update() update_mock.assert_called_once_with( self.context, self._test_class.db_model, - self.db_obj['id'], fields_to_update) + self.db_obj[self._test_class.primary_key], + fields_to_update, + key=self._test_class.primary_key) @mock.patch.object(base.NeutronDbObject, '_get_changed_persistent_fields', @@ -259,7 +280,9 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): obj.delete() self._check_equal(obj, self.db_obj) delete_mock.assert_called_once_with( - self.context, self._test_class.db_model, self.db_obj['id']) + self.context, self._test_class.db_model, + self.db_obj[self._test_class.primary_key], + key=self._test_class.primary_key) @mock.patch(OBJECTS_BASE_OBJ_FROM_PRIMITIVE) def test_clean_obj_from_primitive(self, get_prim_m): @@ -269,13 +292,19 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self.assertTrue(observed_obj.obj_reset_changes.called) +class BaseDbObjectNonStandardPrimaryKeyTestCase(BaseObjectIfaceTestCase): + + _test_class = FakeNeutronObjectNonStandardPrimaryKey + + class BaseDbObjectTestCase(_BaseObjectTestCase): def test_get_by_id_create_update_delete(self): obj = self._test_class(self.context, **self.db_obj) obj.create() - new = self._test_class.get_by_id(self.context, id=obj.id) + new = self._test_class.get_by_id(self.context, + id=getattr(obj, obj.primary_key)) self.assertEqual(obj, new) obj = new @@ -284,13 +313,15 @@ class BaseDbObjectTestCase(_BaseObjectTestCase): setattr(obj, key, val) obj.update() - new = self._test_class.get_by_id(self.context, id=obj.id) + new = self._test_class.get_by_id(self.context, + getattr(obj, obj.primary_key)) self.assertEqual(obj, new) obj = new new.delete() - new = self._test_class.get_by_id(self.context, id=obj.id) + new = self._test_class.get_by_id(self.context, + getattr(obj, obj.primary_key)) self.assertIsNone(new) def test_update_non_existent_object_raises_not_found(self): @@ -341,5 +372,6 @@ class BaseDbObjectTestCase(_BaseObjectTestCase): obj = self._test_class(self.context, **self.db_obj) obj.create() - obj = self._test_class.get_by_id(self.context, obj.id) + obj = self._test_class.get_by_id(self.context, + getattr(obj, obj.primary_key)) self.assertEqual(2, mock_commit.call_count)