Create a hook in base object to modify the fields before DB operations

Added modify_fields_to_db(), modify_fields_from_db() methods and
fields_need_translation dict to define the map what name in object
implementation should be changed to database naming.
It will prepare dicts before writing to DB and also modify the DB output
to match the OVO layout.

Partial-Bug: #1541928
Change-Id: I923b58870584c5e5756b307760f9d502c53c18b1
This commit is contained in:
Artur Korzeniewski
2016-03-10 14:05:12 +01:00
committed by Ihar Hrachyshka
parent 53c03f5ed3
commit 412012de59
2 changed files with 127 additions and 40 deletions

View File

@@ -11,6 +11,7 @@
# under the License.
import abc
import copy
from neutron_lib import exceptions
from oslo_db import exception as obj_exc
@@ -115,14 +116,58 @@ class NeutronDbObject(NeutronObject):
fields_no_update = []
# dict with name mapping: {'field_name_in_object': 'field_name_in_db'}
fields_need_translation = {}
def from_db_object(self, *objs):
db_objs = [self.modify_fields_from_db(db_obj) for db_obj in objs]
for field in self.fields:
for db_obj in objs:
for db_obj in db_objs:
if field in db_obj:
setattr(self, field, db_obj[field])
break
self.obj_reset_changes()
@classmethod
def modify_fields_to_db(cls, fields):
"""
This method enables to modify the fields and its
content before data is inserted into DB.
It uses the fields_need_translation dict with structure:
{
'field_name_in_object': 'field_name_in_db'
}
:param fields: dict of fields from NeutronDbObject
:return: modified dict of fields
"""
result = copy.deepcopy(dict(fields))
for field, field_db in cls.fields_need_translation.items():
if field in result:
result[field_db] = result.pop(field)
return result
@classmethod
def modify_fields_from_db(cls, db_obj):
"""
This method enables to modify the fields and its
content after data was fetched from DB.
It uses the fields_need_translation dict with structure:
{
'field_name_in_object': 'field_name_in_db'
}
:param db_obj: dict of object fetched from database
:return: modified dict of DB values
"""
result = dict(db_obj)
for field, field_db in cls.fields_need_translation.items():
if field_db in result:
result[field] = result.pop(field_db)
return result
@classmethod
def get_object(cls, context, **kwargs):
"""
@@ -140,7 +185,7 @@ class NeutronDbObject(NeutronObject):
db_obj = obj_db_api.get_object(context, cls.db_model, **kwargs)
if db_obj:
obj = cls(context, **db_obj)
obj = cls(context, **cls.modify_fields_from_db(db_obj))
obj.obj_reset_changes()
return obj
@@ -148,10 +193,12 @@ class NeutronDbObject(NeutronObject):
def get_objects(cls, context, **kwargs):
cls.validate_filters(**kwargs)
db_objs = obj_db_api.get_objects(context, cls.db_model, **kwargs)
objs = [cls(context, **db_obj) for db_obj in db_objs]
for obj in objs:
result = []
for db_obj in db_objs:
obj = cls(context, **cls.modify_fields_from_db(db_obj))
obj.obj_reset_changes()
return objs
result.append(obj)
return result
@classmethod
def is_accessible(cls, context, db_obj):
@@ -181,18 +228,17 @@ class NeutronDbObject(NeutronObject):
fields = self._get_changed_persistent_fields()
try:
db_obj = obj_db_api.create_object(self._context, self.db_model,
fields)
self.modify_fields_to_db(fields))
except obj_exc.DBDuplicateEntry as db_exc:
raise NeutronDbObjectDuplicateEntry(object_class=self.__class__,
db_exception=db_exc)
self.from_db_object(db_obj)
def _get_composite_keys(self):
keys = {}
for key in self.primary_keys:
keys[key] = getattr(self, key)
return keys
return self.modify_fields_to_db(keys)
def update(self):
updates = self._get_changed_persistent_fields()
@@ -200,8 +246,8 @@ class NeutronDbObject(NeutronObject):
if updates:
db_obj = obj_db_api.update_object(self._context, self.db_model,
updates,
**self._get_composite_keys())
self.modify_fields_to_db(updates),
**self._get_composite_keys())
self.from_db_object(self, db_obj)
def delete(self):

View File

@@ -92,6 +92,34 @@ class FakeNeutronObjectCompositePrimaryKey(base.NeutronDbObject):
synthetic_fields = ['field2']
@obj_base.VersionedObjectRegistry.register_if(False)
class FakeNeutronObjectRenamedField(base.NeutronDbObject):
"""
Testing renaming the parameter from DB to NeutronDbObject
For tests:
- db fields: id, field_db, field2
- object: id, field_ovo, field2
"""
# Version 1.0: Initial version
VERSION = '1.0'
db_model = FakeModel
primary_keys = ['id']
fields = {
'id': obj_fields.UUIDField(),
'field_ovo': obj_fields.StringField(),
'field2': obj_fields.StringField()
}
synthetic_fields = ['field2']
fields_no_update = ['id']
fields_need_translation = {'field_ovo': 'field_db'}
@obj_base.VersionedObjectRegistry.register_if(False)
class FakeNeutronObjectCompositePrimaryKeyWithId(base.NeutronDbObject):
# Version 1.0: Initial version
@@ -136,9 +164,15 @@ class _BaseObjectTestCase(object):
self.db_objs = list(self.get_random_fields() for _ in range(3))
self.db_obj = self.db_objs[0]
self.obj_fields = []
for db_obj in self.db_objs:
self.obj_fields.append(
self._test_class.modify_fields_from_db(db_obj))
valid_field = [f for f in self._test_class.fields
if f not in self._test_class.synthetic_fields][0]
self.valid_field_filter = {valid_field: self.db_obj[valid_field]}
self.valid_field_filter = {valid_field:
self.obj_fields[0][valid_field]}
@classmethod
def get_random_fields(cls, obj_cls=None):
@@ -148,7 +182,7 @@ class _BaseObjectTestCase(object):
if field not in obj_cls.synthetic_fields:
generator = FIELD_TYPE_VALUE_GENERATOR_MAP[type(field_obj)]
fields[field] = generator()
return fields
return obj_cls.modify_fields_to_db(fields)
@classmethod
def generate_object_keys(cls, obj_cls):
@@ -175,7 +209,8 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
obj_keys = self.generate_object_keys(self._test_class)
obj = self._test_class.get_object(self.context, **obj_keys)
self.assertTrue(self._is_test_class(obj))
self.assertEqual(self.db_obj, get_obj_db_fields(obj))
self.assertEqual(self.obj_fields[0],
get_obj_db_fields(obj))
get_object_mock.assert_called_once_with(
self.context, self._test_class.db_model, **obj_keys)
@@ -250,7 +285,8 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
def _validate_objects(self, expected, observed):
self.assertTrue(all(self._is_test_class(obj) for obj in observed))
self.assertEqual(
sorted(expected,
sorted([self._test_class.modify_fields_from_db(db_obj)
for db_obj in expected],
key=common_utils.safe_sort_key),
sorted([get_obj_db_fields(obj) for obj in observed],
key=common_utils.safe_sort_key))
@@ -263,25 +299,25 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
def test_create(self):
with mock.patch.object(obj_db_api, 'create_object',
return_value=self.db_obj) as create_mock:
obj = self._test_class(self.context, **self.db_obj)
self._check_equal(obj, self.db_obj)
obj = self._test_class(self.context, **self.obj_fields[0])
self._check_equal(obj, self.obj_fields[0])
obj.create()
self._check_equal(obj, self.db_obj)
self._check_equal(obj, self.obj_fields[0])
create_mock.assert_called_once_with(
self.context, self._test_class.db_model, self.db_obj)
def test_create_updates_from_db_object(self):
with mock.patch.object(obj_db_api, 'create_object',
return_value=self.db_obj):
obj = self._test_class(self.context, **self.db_objs[1])
self._check_equal(obj, self.db_objs[1])
obj = self._test_class(self.context, **self.obj_fields[1])
self._check_equal(obj, self.obj_fields[1])
obj.create()
self._check_equal(obj, self.db_obj)
self._check_equal(obj, self.obj_fields[0])
def test_create_duplicates(self):
with mock.patch.object(obj_db_api, 'create_object',
side_effect=obj_exc.DBDuplicateEntry):
obj = self._test_class(self.context, **self.db_obj)
obj = self._test_class(self.context, **self.obj_fields[0])
self.assertRaises(base.NeutronDbObjectDuplicateEntry, obj.create)
@mock.patch.object(obj_db_api, 'update_object')
@@ -300,7 +336,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
with mock.patch.object(base.NeutronDbObject,
'_get_changed_persistent_fields',
return_value=fields_to_update):
obj = self._test_class(self.context, **self.db_obj)
obj = self._test_class(self.context, **self.obj_fields[0])
obj.update()
update_mock.assert_called_once_with(
self.context, self._test_class.db_model,
@@ -322,20 +358,20 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
def test_update_updates_from_db_object(self):
with mock.patch.object(obj_db_api, 'update_object',
return_value=self.db_obj):
obj = self._test_class(self.context, **self.db_objs[1])
fields_to_update = self.get_updatable_fields(self.db_objs[1])
obj = self._test_class(self.context, **self.obj_fields[1])
fields_to_update = self.get_updatable_fields(self.obj_fields[1])
with mock.patch.object(base.NeutronDbObject,
'_get_changed_persistent_fields',
return_value=fields_to_update):
obj.update()
self._check_equal(obj, self.db_obj)
self._check_equal(obj, self.obj_fields[0])
@mock.patch.object(obj_db_api, 'delete_object')
def test_delete(self, delete_mock):
obj = self._test_class(self.context, **self.db_obj)
self._check_equal(obj, self.db_obj)
obj = self._test_class(self.context, **self.obj_fields[0])
self._check_equal(obj, self.obj_fields[0])
obj.delete()
self._check_equal(obj, self.db_obj)
self._check_equal(obj, self.obj_fields[0])
delete_mock.assert_called_once_with(
self.context, self._test_class.db_model,
**obj._get_composite_keys())
@@ -363,6 +399,11 @@ class BaseDbObjectCompositePrimaryKeyWithIdTestCase(BaseObjectIfaceTestCase):
_test_class = FakeNeutronObjectCompositePrimaryKeyWithId
class BaseDbObjectRenamedFieldTestCase(BaseObjectIfaceTestCase):
_test_class = FakeNeutronObjectRenamedField
class BaseDbObjectTestCase(_BaseObjectTestCase):
def _create_test_network(self):
@@ -386,7 +427,7 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
'device_owner': 'fake_owner'})
def test_get_object_create_update_delete(self):
obj = self._test_class(self.context, **self.db_obj)
obj = self._test_class(self.context, **self.obj_fields[0])
obj.create()
new = self._test_class.get_object(self.context,
@@ -395,7 +436,7 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
obj = new
for key, val in self.get_updatable_fields(self.db_objs[1]).items():
for key, val in self.get_updatable_fields(self.obj_fields[1]).items():
setattr(obj, key, val)
obj.update()
@@ -407,33 +448,33 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
new.delete()
new = self._test_class.get_object(self.context,
**obj._get_composite_keys())
**obj._get_composite_keys())
self.assertIsNone(new)
def test_update_non_existent_object_raises_not_found(self):
obj = self._test_class(self.context, **self.db_obj)
obj = self._test_class(self.context, **self.obj_fields[0])
obj.obj_reset_changes()
for key, val in self.get_updatable_fields(self.db_obj).items():
for key, val in self.get_updatable_fields(self.obj_fields[0]).items():
setattr(obj, key, val)
self.assertRaises(n_exc.ObjectNotFound, obj.update)
def test_delete_non_existent_object_raises_not_found(self):
obj = self._test_class(self.context, **self.db_obj)
obj = self._test_class(self.context, **self.obj_fields[0])
self.assertRaises(n_exc.ObjectNotFound, obj.delete)
@mock.patch(SQLALCHEMY_COMMIT)
def test_create_single_transaction(self, mock_commit):
obj = self._test_class(self.context, **self.db_obj)
obj = self._test_class(self.context, **self.obj_fields[0])
obj.create()
self.assertEqual(1, mock_commit.call_count)
def test_update_single_transaction(self):
obj = self._test_class(self.context, **self.db_obj)
obj = self._test_class(self.context, **self.obj_fields[0])
obj.create()
for key, val in self.get_updatable_fields(self.db_obj).items():
for key, val in self.get_updatable_fields(self.obj_fields[1]).items():
setattr(obj, key, val)
with mock.patch(SQLALCHEMY_COMMIT) as mock_commit:
@@ -441,7 +482,7 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
self.assertEqual(1, mock_commit.call_count)
def test_delete_single_transaction(self):
obj = self._test_class(self.context, **self.db_obj)
obj = self._test_class(self.context, **self.obj_fields[0])
obj.create()
with mock.patch(SQLALCHEMY_COMMIT) as mock_commit:
@@ -455,9 +496,9 @@ class BaseDbObjectTestCase(_BaseObjectTestCase):
@mock.patch(SQLALCHEMY_COMMIT)
def test_get_object_single_transaction(self, mock_commit):
obj = self._test_class(self.context, **self.db_obj)
obj = self._test_class(self.context, **self.obj_fields[0])
obj.create()
obj = self._test_class.get_object(self.context,
**obj._get_composite_keys())
**obj._get_composite_keys())
self.assertEqual(2, mock_commit.call_count)