Create a hook in base object to modify the fields before DB operations

Added modify_fields_to_db(), modify_fields_from_db() methods and
fields_need_translation dict to define the map what name in object
implementation should be changed to database naming.
It will prepare dicts before writing to DB and also modify the DB output
to match the OVO layout.

Partial-Bug: #1541928
Change-Id: I923b58870584c5e5756b307760f9d502c53c18b1
This commit is contained in:
Artur Korzeniewski
2016-03-10 14:05:12 +01:00
committed by Ihar Hrachyshka
parent 53c03f5ed3
commit 412012de59
2 changed files with 127 additions and 40 deletions

View File

@@ -11,6 +11,7 @@
# under the License. # under the License.
import abc import abc
import copy
from neutron_lib import exceptions from neutron_lib import exceptions
from oslo_db import exception as obj_exc from oslo_db import exception as obj_exc
@@ -115,14 +116,58 @@ class NeutronDbObject(NeutronObject):
fields_no_update = [] fields_no_update = []
# dict with name mapping: {'field_name_in_object': 'field_name_in_db'}
fields_need_translation = {}
def from_db_object(self, *objs): def from_db_object(self, *objs):
db_objs = [self.modify_fields_from_db(db_obj) for db_obj in objs]
for field in self.fields: for field in self.fields:
for db_obj in objs: for db_obj in db_objs:
if field in db_obj: if field in db_obj:
setattr(self, field, db_obj[field]) setattr(self, field, db_obj[field])
break break
self.obj_reset_changes() self.obj_reset_changes()
@classmethod
def modify_fields_to_db(cls, fields):
"""
This method enables to modify the fields and its
content before data is inserted into DB.
It uses the fields_need_translation dict with structure:
{
'field_name_in_object': 'field_name_in_db'
}
:param fields: dict of fields from NeutronDbObject
:return: modified dict of fields
"""
result = copy.deepcopy(dict(fields))
for field, field_db in cls.fields_need_translation.items():
if field in result:
result[field_db] = result.pop(field)
return result
@classmethod
def modify_fields_from_db(cls, db_obj):
"""
This method enables to modify the fields and its
content after data was fetched from DB.
It uses the fields_need_translation dict with structure:
{
'field_name_in_object': 'field_name_in_db'
}
:param db_obj: dict of object fetched from database
:return: modified dict of DB values
"""
result = dict(db_obj)
for field, field_db in cls.fields_need_translation.items():
if field_db in result:
result[field] = result.pop(field_db)
return result
@classmethod @classmethod
def get_object(cls, context, **kwargs): def get_object(cls, context, **kwargs):
""" """
@@ -140,7 +185,7 @@ class NeutronDbObject(NeutronObject):
db_obj = obj_db_api.get_object(context, cls.db_model, **kwargs) db_obj = obj_db_api.get_object(context, cls.db_model, **kwargs)
if db_obj: if db_obj:
obj = cls(context, **db_obj) obj = cls(context, **cls.modify_fields_from_db(db_obj))
obj.obj_reset_changes() obj.obj_reset_changes()
return obj return obj
@@ -148,10 +193,12 @@ class NeutronDbObject(NeutronObject):
def get_objects(cls, context, **kwargs): def get_objects(cls, context, **kwargs):
cls.validate_filters(**kwargs) cls.validate_filters(**kwargs)
db_objs = obj_db_api.get_objects(context, cls.db_model, **kwargs) db_objs = obj_db_api.get_objects(context, cls.db_model, **kwargs)
objs = [cls(context, **db_obj) for db_obj in db_objs] result = []
for obj in objs: for db_obj in db_objs:
obj = cls(context, **cls.modify_fields_from_db(db_obj))
obj.obj_reset_changes() obj.obj_reset_changes()
return objs result.append(obj)
return result
@classmethod @classmethod
def is_accessible(cls, context, db_obj): def is_accessible(cls, context, db_obj):
@@ -181,18 +228,17 @@ class NeutronDbObject(NeutronObject):
fields = self._get_changed_persistent_fields() fields = self._get_changed_persistent_fields()
try: try:
db_obj = obj_db_api.create_object(self._context, self.db_model, db_obj = obj_db_api.create_object(self._context, self.db_model,
fields) self.modify_fields_to_db(fields))
except obj_exc.DBDuplicateEntry as db_exc: except obj_exc.DBDuplicateEntry as db_exc:
raise NeutronDbObjectDuplicateEntry(object_class=self.__class__, raise NeutronDbObjectDuplicateEntry(object_class=self.__class__,
db_exception=db_exc) db_exception=db_exc)
self.from_db_object(db_obj) self.from_db_object(db_obj)
def _get_composite_keys(self): def _get_composite_keys(self):
keys = {} keys = {}
for key in self.primary_keys: for key in self.primary_keys:
keys[key] = getattr(self, key) keys[key] = getattr(self, key)
return keys return self.modify_fields_to_db(keys)
def update(self): def update(self):
updates = self._get_changed_persistent_fields() updates = self._get_changed_persistent_fields()
@@ -200,8 +246,8 @@ class NeutronDbObject(NeutronObject):
if updates: if updates:
db_obj = obj_db_api.update_object(self._context, self.db_model, db_obj = obj_db_api.update_object(self._context, self.db_model,
updates, self.modify_fields_to_db(updates),
**self._get_composite_keys()) **self._get_composite_keys())
self.from_db_object(self, db_obj) self.from_db_object(self, db_obj)
def delete(self): def delete(self):

View File

@@ -92,6 +92,34 @@ class FakeNeutronObjectCompositePrimaryKey(base.NeutronDbObject):
synthetic_fields = ['field2'] synthetic_fields = ['field2']
@obj_base.VersionedObjectRegistry.register_if(False)
class FakeNeutronObjectRenamedField(base.NeutronDbObject):
"""
Testing renaming the parameter from DB to NeutronDbObject
For tests:
- db fields: id, field_db, field2
- object: id, field_ovo, field2
"""
# Version 1.0: Initial version
VERSION = '1.0'
db_model = FakeModel
primary_keys = ['id']
fields = {
'id': obj_fields.UUIDField(),
'field_ovo': obj_fields.StringField(),
'field2': obj_fields.StringField()
}
synthetic_fields = ['field2']
fields_no_update = ['id']
fields_need_translation = {'field_ovo': 'field_db'}
@obj_base.VersionedObjectRegistry.register_if(False) @obj_base.VersionedObjectRegistry.register_if(False)
class FakeNeutronObjectCompositePrimaryKeyWithId(base.NeutronDbObject): class FakeNeutronObjectCompositePrimaryKeyWithId(base.NeutronDbObject):
# Version 1.0: Initial version # Version 1.0: Initial version
@@ -136,9 +164,15 @@ class _BaseObjectTestCase(object):
self.db_objs = list(self.get_random_fields() for _ in range(3)) self.db_objs = list(self.get_random_fields() for _ in range(3))
self.db_obj = self.db_objs[0] self.db_obj = self.db_objs[0]
self.obj_fields = []
for db_obj in self.db_objs:
self.obj_fields.append(
self._test_class.modify_fields_from_db(db_obj))
valid_field = [f for f in self._test_class.fields valid_field = [f for f in self._test_class.fields
if f not in self._test_class.synthetic_fields][0] if f not in self._test_class.synthetic_fields][0]
self.valid_field_filter = {valid_field: self.db_obj[valid_field]} self.valid_field_filter = {valid_field:
self.obj_fields[0][valid_field]}
@classmethod @classmethod
def get_random_fields(cls, obj_cls=None): def get_random_fields(cls, obj_cls=None):
@@ -148,7 +182,7 @@ class _BaseObjectTestCase(object):
if field not in obj_cls.synthetic_fields: if field not in obj_cls.synthetic_fields:
generator = FIELD_TYPE_VALUE_GENERATOR_MAP[type(field_obj)] generator = FIELD_TYPE_VALUE_GENERATOR_MAP[type(field_obj)]
fields[field] = generator() fields[field] = generator()
return fields return obj_cls.modify_fields_to_db(fields)
@classmethod @classmethod
def generate_object_keys(cls, obj_cls): def generate_object_keys(cls, obj_cls):
@@ -175,7 +209,8 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
obj_keys = self.generate_object_keys(self._test_class) obj_keys = self.generate_object_keys(self._test_class)
obj = self._test_class.get_object(self.context, **obj_keys) obj = self._test_class.get_object(self.context, **obj_keys)
self.assertTrue(self._is_test_class(obj)) self.assertTrue(self._is_test_class(obj))
self.assertEqual(self.db_obj, get_obj_db_fields(obj)) self.assertEqual(self.obj_fields[0],
get_obj_db_fields(obj))
get_object_mock.assert_called_once_with( get_object_mock.assert_called_once_with(
self.context, self._test_class.db_model, **obj_keys) self.context, self._test_class.db_model, **obj_keys)
@@ -250,7 +285,8 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
def _validate_objects(self, expected, observed): def _validate_objects(self, expected, observed):
self.assertTrue(all(self._is_test_class(obj) for obj in observed)) self.assertTrue(all(self._is_test_class(obj) for obj in observed))
self.assertEqual( self.assertEqual(
sorted(expected, sorted([self._test_class.modify_fields_from_db(db_obj)
for db_obj in expected],
key=common_utils.safe_sort_key), key=common_utils.safe_sort_key),
sorted([get_obj_db_fields(obj) for obj in observed], sorted([get_obj_db_fields(obj) for obj in observed],
key=common_utils.safe_sort_key)) key=common_utils.safe_sort_key))
@@ -263,25 +299,25 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
def test_create(self): def test_create(self):
with mock.patch.object(obj_db_api, 'create_object', with mock.patch.object(obj_db_api, 'create_object',
return_value=self.db_obj) as create_mock: return_value=self.db_obj) as create_mock:
obj = self._test_class(self.context, **self.db_obj) obj = self._test_class(self.context, **self.obj_fields[0])
self._check_equal(obj, self.db_obj) self._check_equal(obj, self.obj_fields[0])
obj.create() obj.create()
self._check_equal(obj, self.db_obj) self._check_equal(obj, self.obj_fields[0])
create_mock.assert_called_once_with( create_mock.assert_called_once_with(
self.context, self._test_class.db_model, self.db_obj) self.context, self._test_class.db_model, self.db_obj)
def test_create_updates_from_db_object(self): def test_create_updates_from_db_object(self):
with mock.patch.object(obj_db_api, 'create_object', with mock.patch.object(obj_db_api, 'create_object',
return_value=self.db_obj): return_value=self.db_obj):
obj = self._test_class(self.context, **self.db_objs[1]) obj = self._test_class(self.context, **self.obj_fields[1])
self._check_equal(obj, self.db_objs[1]) self._check_equal(obj, self.obj_fields[1])
obj.create() obj.create()
self._check_equal(obj, self.db_obj) self._check_equal(obj, self.obj_fields[0])
def test_create_duplicates(self): def test_create_duplicates(self):
with mock.patch.object(obj_db_api, 'create_object', with mock.patch.object(obj_db_api, 'create_object',
side_effect=obj_exc.DBDuplicateEntry): side_effect=obj_exc.DBDuplicateEntry):
obj = self._test_class(self.context, **self.db_obj) obj = self._test_class(self.context, **self.obj_fields[0])
self.assertRaises(base.NeutronDbObjectDuplicateEntry, obj.create) self.assertRaises(base.NeutronDbObjectDuplicateEntry, obj.create)
@mock.patch.object(obj_db_api, 'update_object') @mock.patch.object(obj_db_api, 'update_object')
@@ -300,7 +336,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
with mock.patch.object(base.NeutronDbObject, with mock.patch.object(base.NeutronDbObject,
'_get_changed_persistent_fields', '_get_changed_persistent_fields',
return_value=fields_to_update): return_value=fields_to_update):
obj = self._test_class(self.context, **self.db_obj) obj = self._test_class(self.context, **self.obj_fields[0])
obj.update() obj.update()
update_mock.assert_called_once_with( update_mock.assert_called_once_with(
self.context, self._test_class.db_model, self.context, self._test_class.db_model,
@@ -322,20 +358,20 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
def test_update_updates_from_db_object(self): def test_update_updates_from_db_object(self):
with mock.patch.object(obj_db_api, 'update_object', with mock.patch.object(obj_db_api, 'update_object',
return_value=self.db_obj): return_value=self.db_obj):
obj = self._test_class(self.context, **self.db_objs[1]) obj = self._test_class(self.context, **self.obj_fields[1])
fields_to_update = self.get_updatable_fields(self.db_objs[1]) fields_to_update = self.get_updatable_fields(self.obj_fields[1])
with mock.patch.object(base.NeutronDbObject, with mock.patch.object(base.NeutronDbObject,
'_get_changed_persistent_fields', '_get_changed_persistent_fields',
return_value=fields_to_update): return_value=fields_to_update):
obj.update() obj.update()
self._check_equal(obj, self.db_obj) self._check_equal(obj, self.obj_fields[0])
@mock.patch.object(obj_db_api, 'delete_object') @mock.patch.object(obj_db_api, 'delete_object')
def test_delete(self, delete_mock): def test_delete(self, delete_mock):
obj = self._test_class(self.context, **self.db_obj) obj = self._test_class(self.context, **self.obj_fields[0])
self._check_equal(obj, self.db_obj) self._check_equal(obj, self.obj_fields[0])
obj.delete() obj.delete()
self._check_equal(obj, self.db_obj) self._check_equal(obj, self.obj_fields[0])
delete_mock.assert_called_once_with( delete_mock.assert_called_once_with(
self.context, self._test_class.db_model, self.context, self._test_class.db_model,
**obj._get_composite_keys()) **obj._get_composite_keys())
@@ -363,6 +399,11 @@ class BaseDbObjectCompositePrimaryKeyWithIdTestCase(BaseObjectIfaceTestCase):
_test_class = FakeNeutronObjectCompositePrimaryKeyWithId _test_class = FakeNeutronObjectCompositePrimaryKeyWithId
class BaseDbObjectRenamedFieldTestCase(BaseObjectIfaceTestCase):
_test_class = FakeNeutronObjectRenamedField
class BaseDbObjectTestCase(_BaseObjectTestCase): class BaseDbObjectTestCase(_BaseObjectTestCase):
def _create_test_network(self): def _create_test_network(self):
@@ -386,7 +427,7 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
'device_owner': 'fake_owner'}) 'device_owner': 'fake_owner'})
def test_get_object_create_update_delete(self): def test_get_object_create_update_delete(self):
obj = self._test_class(self.context, **self.db_obj) obj = self._test_class(self.context, **self.obj_fields[0])
obj.create() obj.create()
new = self._test_class.get_object(self.context, new = self._test_class.get_object(self.context,
@@ -395,7 +436,7 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
obj = new obj = new
for key, val in self.get_updatable_fields(self.db_objs[1]).items(): for key, val in self.get_updatable_fields(self.obj_fields[1]).items():
setattr(obj, key, val) setattr(obj, key, val)
obj.update() obj.update()
@@ -407,33 +448,33 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
new.delete() new.delete()
new = self._test_class.get_object(self.context, new = self._test_class.get_object(self.context,
**obj._get_composite_keys()) **obj._get_composite_keys())
self.assertIsNone(new) self.assertIsNone(new)
def test_update_non_existent_object_raises_not_found(self): def test_update_non_existent_object_raises_not_found(self):
obj = self._test_class(self.context, **self.db_obj) obj = self._test_class(self.context, **self.obj_fields[0])
obj.obj_reset_changes() obj.obj_reset_changes()
for key, val in self.get_updatable_fields(self.db_obj).items(): for key, val in self.get_updatable_fields(self.obj_fields[0]).items():
setattr(obj, key, val) setattr(obj, key, val)
self.assertRaises(n_exc.ObjectNotFound, obj.update) self.assertRaises(n_exc.ObjectNotFound, obj.update)
def test_delete_non_existent_object_raises_not_found(self): def test_delete_non_existent_object_raises_not_found(self):
obj = self._test_class(self.context, **self.db_obj) obj = self._test_class(self.context, **self.obj_fields[0])
self.assertRaises(n_exc.ObjectNotFound, obj.delete) self.assertRaises(n_exc.ObjectNotFound, obj.delete)
@mock.patch(SQLALCHEMY_COMMIT) @mock.patch(SQLALCHEMY_COMMIT)
def test_create_single_transaction(self, mock_commit): def test_create_single_transaction(self, mock_commit):
obj = self._test_class(self.context, **self.db_obj) obj = self._test_class(self.context, **self.obj_fields[0])
obj.create() obj.create()
self.assertEqual(1, mock_commit.call_count) self.assertEqual(1, mock_commit.call_count)
def test_update_single_transaction(self): def test_update_single_transaction(self):
obj = self._test_class(self.context, **self.db_obj) obj = self._test_class(self.context, **self.obj_fields[0])
obj.create() obj.create()
for key, val in self.get_updatable_fields(self.db_obj).items(): for key, val in self.get_updatable_fields(self.obj_fields[1]).items():
setattr(obj, key, val) setattr(obj, key, val)
with mock.patch(SQLALCHEMY_COMMIT) as mock_commit: with mock.patch(SQLALCHEMY_COMMIT) as mock_commit:
@@ -441,7 +482,7 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
self.assertEqual(1, mock_commit.call_count) self.assertEqual(1, mock_commit.call_count)
def test_delete_single_transaction(self): def test_delete_single_transaction(self):
obj = self._test_class(self.context, **self.db_obj) obj = self._test_class(self.context, **self.obj_fields[0])
obj.create() obj.create()
with mock.patch(SQLALCHEMY_COMMIT) as mock_commit: with mock.patch(SQLALCHEMY_COMMIT) as mock_commit:
@@ -455,9 +496,9 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
@mock.patch(SQLALCHEMY_COMMIT) @mock.patch(SQLALCHEMY_COMMIT)
def test_get_object_single_transaction(self, mock_commit): def test_get_object_single_transaction(self, mock_commit):
obj = self._test_class(self.context, **self.db_obj) obj = self._test_class(self.context, **self.obj_fields[0])
obj.create() obj.create()
obj = self._test_class.get_object(self.context, obj = self._test_class.get_object(self.context,
**obj._get_composite_keys()) **obj._get_composite_keys())
self.assertEqual(2, mock_commit.call_count) self.assertEqual(2, mock_commit.call_count)