From 9cd230397c8ccec42f981bf5aa96379b07b0a95a Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Mon, 29 Aug 2016 14:37:40 +0000 Subject: [PATCH] tests: refactor objects test cases to use db models instead of dicts This should reflect the code under test better, and is needed for one of patches in the review queue (I130609194f15b89df89e5606fb8193849edd14d8) to pass some of those test cases. Partially-Implements: blueprint adopt-oslo-versioned-objects-for-db Change-Id: Id1ca4ce7b134d9729e68661cedb2f5556e58d6ff --- .../network/extensions/test_port_security.py | 4 +- .../objects/network/test_network_segment.py | 2 +- .../extensions/test_allowedaddresspairs.py | 2 +- .../port/extensions/test_extra_dhcp_opt.py | 6 +- neutron/tests/unit/objects/qos/test_policy.py | 13 +- neutron/tests/unit/objects/test_base.py | 122 ++++++++---------- neutron/tests/unit/objects/test_rbac_db.py | 5 +- .../tests/unit/objects/test_securitygroup.py | 4 +- neutron/tests/unit/objects/test_subnet.py | 8 +- neutron/tests/unit/objects/test_subnetpool.py | 2 +- 10 files changed, 81 insertions(+), 87 deletions(-) diff --git a/neutron/tests/unit/objects/network/extensions/test_port_security.py b/neutron/tests/unit/objects/network/extensions/test_port_security.py index cee5d5ce673..5c24e80d6c7 100644 --- a/neutron/tests/unit/objects/network/extensions/test_port_security.py +++ b/neutron/tests/unit/objects/network/extensions/test_port_security.py @@ -30,7 +30,9 @@ class NetworkPortSecurityDbObjTestCase(obj_test_base.BaseDbObjectTestCase, def setUp(self): super(NetworkPortSecurityDbObjTestCase, self).setUp() - for db_obj, obj_field in zip(self.db_objs, self.obj_fields): + for db_obj, obj_field, obj in zip( + self.db_objs, self.obj_fields, self.objs): network = self._create_network() db_obj['network_id'] = network['id'] obj_field['id'] = network['id'] + obj['id'] = network['id'] diff --git a/neutron/tests/unit/objects/network/test_network_segment.py b/neutron/tests/unit/objects/network/test_network_segment.py index 1008af9572f..67772dde356 100644 --- a/neutron/tests/unit/objects/network/test_network_segment.py +++ b/neutron/tests/unit/objects/network/test_network_segment.py @@ -32,5 +32,5 @@ class NetworkSegmentDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, def setUp(self): super(NetworkSegmentDbObjectTestCase, self).setUp() self._create_test_network() - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['network_id'] = self._network['id'] diff --git a/neutron/tests/unit/objects/port/extensions/test_allowedaddresspairs.py b/neutron/tests/unit/objects/port/extensions/test_allowedaddresspairs.py index fbd23408acb..50b62127ba2 100644 --- a/neutron/tests/unit/objects/port/extensions/test_allowedaddresspairs.py +++ b/neutron/tests/unit/objects/port/extensions/test_allowedaddresspairs.py @@ -34,5 +34,5 @@ class AllowedAddrPairsDbObjTestCase(obj_test_base.BaseDbObjectTestCase, self.context = context.get_admin_context() self._create_test_network() self._create_test_port(self._network) - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['port_id'] = self._port['id'] diff --git a/neutron/tests/unit/objects/port/extensions/test_extra_dhcp_opt.py b/neutron/tests/unit/objects/port/extensions/test_extra_dhcp_opt.py index adb9f737121..d89152251f9 100644 --- a/neutron/tests/unit/objects/port/extensions/test_extra_dhcp_opt.py +++ b/neutron/tests/unit/objects/port/extensions/test_extra_dhcp_opt.py @@ -10,6 +10,8 @@ # License for the specific language governing permissions and limitations # under the License. +import itertools + from neutron.objects.port.extensions import extra_dhcp_opt from neutron.tests.unit.objects import test_base as obj_test_base from neutron.tests.unit import testlib_api @@ -29,7 +31,5 @@ class ExtraDhcpOptDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, super(ExtraDhcpOptDbObjectTestCase, self).setUp() self._create_test_network() self._create_test_port(self._network) - for obj in self.db_objs: - obj['port_id'] = self._port['id'] - for obj in self.obj_fields: + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['port_id'] = self._port['id'] diff --git a/neutron/tests/unit/objects/qos/test_policy.py b/neutron/tests/unit/objects/qos/test_policy.py index 94ee7f40eef..8fc91f76f9e 100644 --- a/neutron/tests/unit/objects/qos/test_policy.py +++ b/neutron/tests/unit/objects/qos/test_policy.py @@ -82,7 +82,9 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): context_mock.assert_called_once_with() self.get_objects.assert_any_call( admin_context, self._test_class.db_model, _pager=None) - self._validate_objects(self.db_objs, objs) + self.assertItemsEqual( + [test_base.get_obj_db_fields(obj) for obj in self.objs], + [test_base.get_obj_db_fields(obj) for obj in objs]) def test_get_objects_valid_fields(self): admin_context = self.context.elevated() @@ -103,7 +105,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): get_objects_mock.assert_any_call( admin_context, self._test_class.db_model, _pager=None, **self.valid_field_filter) - self._validate_objects([self.db_obj], objs) + self._check_equal(objs[0], self.objs[0]) def test_get_object(self): admin_context = self.context.elevated() @@ -114,7 +116,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): return_value=admin_context) as context_mock: obj = self._test_class.get_object(self.context, id='fake_id') self.assertTrue(self._is_test_class(obj)) - self.assertEqual(self.db_obj, test_base.get_obj_db_fields(obj)) + self._check_equal(obj, self.objs[0]) context_mock.assert_called_once_with() get_object_mock.assert_called_once_with( admin_context, self._test_class.db_model, id='fake_id') @@ -139,9 +141,8 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase, self._create_test_port(self._network) def _create_test_policy(self): - policy_obj = policy.QosPolicy(self.context, **self.db_obj) - policy_obj.create() - return policy_obj + self.objs[0].create() + return self.objs[0] def _create_test_policy_with_rules(self, rule_type, reload_rules=False): policy_obj = self._create_test_policy() diff --git a/neutron/tests/unit/objects/test_base.py b/neutron/tests/unit/objects/test_base.py index 4a6d5113234..aff0203db70 100644 --- a/neutron/tests/unit/objects/test_base.py +++ b/neutron/tests/unit/objects/test_base.py @@ -27,7 +27,6 @@ from oslo_versionedobjects import fixture import testtools from neutron.common import constants -from neutron.common import utils as common_utils from neutron import context from neutron.db import db_base_plugin_v2 from neutron.db import model_base @@ -48,14 +47,12 @@ OBJECTS_BASE_OBJ_FROM_PRIMITIVE = ('oslo_versionedobjects.base.' TIMESTAMP_FIELDS = ['created_at', 'updated_at', 'revision_number'] -class FakeModel(object): - def __init__(self, *args, **kwargs): - pass +class FakeModel(dict): + pass -class ObjectFieldsModel(object): - def __init__(self, *args, **kwargs): - pass +class ObjectFieldsModel(dict): + pass @obj_base.VersionedObjectRegistry.register_if(False) @@ -396,9 +393,11 @@ FIELD_TYPE_VALUE_GENERATOR_MAP = { } +# TODO(ihrachys) consider renaming into e.g. get_obj_persistent_fields def get_obj_db_fields(obj): return {field: getattr(obj, field) for field in obj.fields - if field not in obj.synthetic_fields} + if field not in obj.synthetic_fields + if field in obj} def get_value(generator, version): @@ -429,11 +428,19 @@ class _BaseObjectTestCase(object): # neutron.objects.db.api from core plugin instance self.setup_coreplugin(self.CORE_PLUGIN) self.context = context.get_admin_context() - self.db_objs = list(self.get_random_fields() for _ in range(3)) + self.db_objs = [ + self._test_class.db_model(**self.get_random_fields()) + for _ in range(3) + ] self.db_obj = self.db_objs[0] + # TODO(ihrachys) remove obj_fields since they duplicate self.objs self.obj_fields = [self._test_class.modify_fields_from_db(db_obj) for db_obj in self.db_objs] + self.objs = [ + self._test_class(self.context, **fields) + for fields in self.obj_fields + ] valid_field = [f for f in self._test_class.fields if f not in self._test_class.synthetic_fields][0] @@ -447,8 +454,10 @@ class _BaseObjectTestCase(object): synthetic_obj_fields = self.get_random_fields(FakeSmallNeutronObject) self.model_map = { self._test_class.db_model: self.db_objs, - ObjectFieldsModel: [synthetic_obj_fields]} + ObjectFieldsModel: [ObjectFieldsModel(**synthetic_obj_fields)]} + # TODO(ihrachys): rename the method to explicitly reflect it returns db + # attributes not object fields @classmethod def get_random_fields(cls, obj_cls=None): obj_cls = obj_cls or cls._test_class @@ -504,6 +513,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self.model_map[self._test_class.db_model] = self.db_objs self.pager_map = collections.defaultdict(lambda: None) + # TODO(ihrachys) document the intent of all common test cases in docstrings def test_get_object(self): with mock.patch.object(obj_db_api, 'get_object', return_value=self.db_obj) as get_object_mock: @@ -512,7 +522,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): obj_keys = self.generate_object_keys(self._test_class) obj = self._test_class.get_object(self.context, **obj_keys) self.assertTrue(self._is_test_class(obj)) - self.assertEqual(self.obj_fields[0], get_obj_db_fields(obj)) + self._check_equal(obj, self.objs[0]) get_object_mock.assert_called_once_with( self.context, self._test_class.db_model, **self._test_class.modify_fields_to_db(obj_keys)) @@ -550,8 +560,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): obj = self._test_class.get_object(self.context, **obj_keys) self.assertTrue(self._is_test_class(obj)) - self.assertEqual(self.obj_fields[0], - get_obj_db_fields(obj)) + self._check_equal(obj, self.objs[0]) get_object_mock.assert_called_once_with( self.context, self._test_class.db_model, **self._test_class.modify_fields_to_db(obj_keys)) @@ -574,37 +583,25 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): return mock_calls def test_get_objects(self): + '''Test that get_objects fetches data from database.''' with mock.patch.object( obj_db_api, 'get_objects', side_effect=self.fake_get_objects) as get_objects_mock: objs = self._test_class.get_objects(self.context) - self._validate_objects(self.db_objs, objs) - mock_calls = [ - mock.call(self.context, self._test_class.db_model, - _pager=self.pager_map[self._test_class.obj_name()]) - ] - mock_calls.extend(self._get_synthetic_fields_get_objects_calls( - self.db_objs)) - get_objects_mock.assert_has_calls(mock_calls) + self.assertItemsEqual( + [get_obj_db_fields(obj) for obj in self.objs], + [get_obj_db_fields(obj) for obj in objs]) + get_objects_mock.assert_any_call( + self.context, self._test_class.db_model, + _pager=self.pager_map[self._test_class.obj_name()] + ) def test_get_objects_valid_fields(self): + '''Test that a valid filter does not raise an error.''' with mock.patch.object( - obj_db_api, 'get_objects', - side_effect=self.fake_get_objects) as get_objects_mock: - - objs = self._test_class.get_objects(self.context, - **self.valid_field_filter) - self._validate_objects(self.db_objs, objs) - mock_calls = [ - mock.call( - self.context, self._test_class.db_model, - _pager=self.pager_map[self._test_class.obj_name()], - **self._test_class.modify_fields_to_db(self.valid_field_filter) - ) - ] - mock_calls.extend(self._get_synthetic_fields_get_objects_calls( - [self.db_obj])) - get_objects_mock.assert_has_calls(mock_calls) + obj_db_api, 'get_objects', side_effect=self.fake_get_objects): + self._test_class.get_objects(self.context, + **self.valid_field_filter) def test_get_objects_mixed_fields(self): synthetic_fields = ( @@ -661,19 +658,11 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self._test_class.count, self.context, fake_field='xxx') - def _validate_objects(self, expected, observed): - self.assertTrue(all(self._is_test_class(obj) for obj in observed)) - self.assertEqual( - sorted([self._test_class.modify_fields_from_db(db_obj) - for db_obj in expected], - key=common_utils.safe_sort_key), - sorted([get_obj_db_fields(obj) for obj in observed], - key=common_utils.safe_sort_key)) - - def _check_equal(self, obj, db_obj): - self.assertEqual( - sorted(db_obj), - sorted(get_obj_db_fields(obj))) + # TODO(ihrachys) swap the order of arguments to reflect the order of + # self.assert* methods + def _check_equal(self, observed, expected): + self.assertItemsEqual(get_obj_db_fields(expected), + get_obj_db_fields(observed)) def test_create(self): with mock.patch.object(obj_db_api, 'create_object', @@ -681,21 +670,21 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): with mock.patch.object(obj_db_api, 'get_objects', side_effect=self.fake_get_objects): obj = self._test_class(self.context, **self.obj_fields[0]) - self._check_equal(obj, self.obj_fields[0]) + self._check_equal(obj, self.objs[0]) obj.create() - self._check_equal(obj, self.obj_fields[0]) + self._check_equal(obj, self.objs[0]) create_mock.assert_called_once_with( - self.context, self._test_class.db_model, self.db_obj) + self.context, self._test_class.db_model, + self._test_class.modify_fields_to_db( + get_obj_db_fields(self.objs[0]))) def test_create_updates_from_db_object(self): with mock.patch.object(obj_db_api, 'create_object', return_value=self.db_obj): with mock.patch.object(obj_db_api, 'get_objects', side_effect=self.fake_get_objects): - obj = self._test_class(self.context, **self.obj_fields[1]) - self._check_equal(obj, self.obj_fields[1]) - obj.create() - self._check_equal(obj, self.obj_fields[0]) + self.objs[1].create() + self._check_equal(self.objs[1], self.objs[0]) def test_create_duplicates(self): with mock.patch.object(obj_db_api, 'create_object', @@ -772,7 +761,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): side_effect=self.fake_get_objects): obj = self._test_class(self.context, **self.obj_fields[0]) # get new values and fix keys - update_mock.return_value = self.db_objs[1].copy() + update_mock.return_value = self.db_objs[1] fixed_keys = self._test_class.modify_fields_to_db( obj._get_composite_keys()) for key, value in fixed_keys.items(): @@ -813,14 +802,14 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): obj_db_api, 'get_objects', side_effect=self.fake_get_objects): obj.update() - self._check_equal(obj, self.obj_fields[0]) + self._check_equal(obj, self.objs[0]) @mock.patch.object(obj_db_api, 'delete_object') def test_delete(self, delete_mock): obj = self._test_class(self.context, **self.obj_fields[0]) - self._check_equal(obj, self.obj_fields[0]) + self._check_equal(obj, self.objs[0]) obj.delete() - self._check_equal(obj, self.obj_fields[0]) + self._check_equal(obj, self.objs[0]) delete_mock.assert_called_once_with( self.context, self._test_class.db_model, **self._test_class.modify_fields_to_db(obj._get_composite_keys())) @@ -1024,7 +1013,7 @@ class BaseDbObjectTestCase(_BaseObjectTestCase, continue for db_obj in self.db_objs: objclass_fields = self.get_random_fields(objclass) - db_obj[synth_field] = [objclass_fields] + db_obj[synth_field] = [objclass.db_model(**objclass_fields)] def _create_test_network(self): # TODO(ihrachys): replace with network.create() once we get an object @@ -1208,10 +1197,13 @@ class BaseDbObjectTestCase(_BaseObjectTestCase, obj = self._make_object(self.obj_fields[0]) obj.create() - for field in remove_timestamps_from_fields(self.obj_fields[0]): - filters = {field: [self.obj_fields[0][field]]} + for field in remove_timestamps_from_fields(get_obj_db_fields(obj)): + filters = {field: [self.objs[0][field]]} new = self._test_class.get_objects(self.context, **filters) - self.assertEqual([obj], new, 'Filtering by %s failed.' % field) + self.assertItemsEqual( + [obj._get_composite_keys()], + [obj_._get_composite_keys() for obj_ in new], + 'Filtering by %s failed.' % field) def _get_non_synth_fields(self, objclass, db_attrs): fields = objclass.modify_fields_from_db(db_attrs) diff --git a/neutron/tests/unit/objects/test_rbac_db.py b/neutron/tests/unit/objects/test_rbac_db.py index 2ee70e2696b..facc02d2f8f 100644 --- a/neutron/tests/unit/objects/test_rbac_db.py +++ b/neutron/tests/unit/objects/test_rbac_db.py @@ -29,9 +29,8 @@ from neutron.tests.unit.objects import test_base from neutron.tests.unit import testlib_api -class FakeDbModel(object): - def __init__(self, *args, **kwargs): - pass +class FakeDbModel(dict): + pass class FakeRbacModel(rbac_db_models.RBACColumns, model_base.BASEV2): diff --git a/neutron/tests/unit/objects/test_securitygroup.py b/neutron/tests/unit/objects/test_securitygroup.py index be44e834344..b78bdf978b2 100644 --- a/neutron/tests/unit/objects/test_securitygroup.py +++ b/neutron/tests/unit/objects/test_securitygroup.py @@ -118,7 +118,7 @@ class DefaultSecurityGroupDbObjTestCase(test_base.BaseDbObjectTestCase, self.sg_obj = securitygroup.SecurityGroup( self.context, **test_base.remove_timestamps_from_fields(sg_fields)) self.sg_obj.create() - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['security_group_id'] = self.sg_obj['id'] @@ -140,6 +140,6 @@ class SecurityGroupRuleDbObjTestCase(test_base.BaseDbObjectTestCase, self.sg_obj = securitygroup.SecurityGroup( self.context, **test_base.remove_timestamps_from_fields(sg_fields)) self.sg_obj.create() - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['security_group_id'] = self.sg_obj['id'] obj['remote_group_id'] = self.sg_obj['id'] diff --git a/neutron/tests/unit/objects/test_subnet.py b/neutron/tests/unit/objects/test_subnet.py index 76f8caf8906..0706c4512c0 100644 --- a/neutron/tests/unit/objects/test_subnet.py +++ b/neutron/tests/unit/objects/test_subnet.py @@ -39,7 +39,7 @@ class IPAllocationPoolDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, super(IPAllocationPoolDbObjectTestCase, self).setUp() self._create_test_network() self._create_test_subnet(self._network) - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['subnet_id'] = self._subnet['id'] @@ -69,7 +69,7 @@ class DNSNameServerDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, for db_obj in self.db_objs] self._create_test_network() self._create_test_subnet(self._network) - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['subnet_id'] = self._subnet['id'] def _is_objects_unique(self): @@ -128,7 +128,7 @@ class RouteDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, super(RouteDbObjectTestCase, self).setUp() self._create_test_network() self._create_test_subnet(self._network) - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['subnet_id'] = self._subnet['id'] @@ -151,7 +151,7 @@ class SubnetDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, super(SubnetDbObjectTestCase, self).setUp() self._create_test_network() self._create_test_segment(self._network) - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['network_id'] = self._network['id'] obj['segment_id'] = self._segment['id'] diff --git a/neutron/tests/unit/objects/test_subnetpool.py b/neutron/tests/unit/objects/test_subnetpool.py index 63d00866066..3bf8036e4d3 100644 --- a/neutron/tests/unit/objects/test_subnetpool.py +++ b/neutron/tests/unit/objects/test_subnetpool.py @@ -87,5 +87,5 @@ class SubnetPoolPrefixDbObjectTestCase( def setUp(self): super(SubnetPoolPrefixDbObjectTestCase, self).setUp() self._create_test_subnetpool() - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['subnetpool_id'] = self._pool.id