diff --git a/neutron/tests/unit/objects/network/extensions/test_port_security.py b/neutron/tests/unit/objects/network/extensions/test_port_security.py index cee5d5ce673..5c24e80d6c7 100644 --- a/neutron/tests/unit/objects/network/extensions/test_port_security.py +++ b/neutron/tests/unit/objects/network/extensions/test_port_security.py @@ -30,7 +30,9 @@ class NetworkPortSecurityDbObjTestCase(obj_test_base.BaseDbObjectTestCase, def setUp(self): super(NetworkPortSecurityDbObjTestCase, self).setUp() - for db_obj, obj_field in zip(self.db_objs, self.obj_fields): + for db_obj, obj_field, obj in zip( + self.db_objs, self.obj_fields, self.objs): network = self._create_network() db_obj['network_id'] = network['id'] obj_field['id'] = network['id'] + obj['id'] = network['id'] diff --git a/neutron/tests/unit/objects/network/test_network_segment.py b/neutron/tests/unit/objects/network/test_network_segment.py index 1008af9572f..67772dde356 100644 --- a/neutron/tests/unit/objects/network/test_network_segment.py +++ b/neutron/tests/unit/objects/network/test_network_segment.py @@ -32,5 +32,5 @@ class NetworkSegmentDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, def setUp(self): super(NetworkSegmentDbObjectTestCase, self).setUp() self._create_test_network() - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['network_id'] = self._network['id'] diff --git a/neutron/tests/unit/objects/port/extensions/test_allowedaddresspairs.py b/neutron/tests/unit/objects/port/extensions/test_allowedaddresspairs.py index fbd23408acb..50b62127ba2 100644 --- a/neutron/tests/unit/objects/port/extensions/test_allowedaddresspairs.py +++ b/neutron/tests/unit/objects/port/extensions/test_allowedaddresspairs.py @@ -34,5 +34,5 @@ class AllowedAddrPairsDbObjTestCase(obj_test_base.BaseDbObjectTestCase, self.context = context.get_admin_context() self._create_test_network() self._create_test_port(self._network) - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['port_id'] = self._port['id'] diff --git a/neutron/tests/unit/objects/port/extensions/test_extra_dhcp_opt.py b/neutron/tests/unit/objects/port/extensions/test_extra_dhcp_opt.py index adb9f737121..d89152251f9 100644 --- a/neutron/tests/unit/objects/port/extensions/test_extra_dhcp_opt.py +++ b/neutron/tests/unit/objects/port/extensions/test_extra_dhcp_opt.py @@ -10,6 +10,8 @@ # License for the specific language governing permissions and limitations # under the License. +import itertools + from neutron.objects.port.extensions import extra_dhcp_opt from neutron.tests.unit.objects import test_base as obj_test_base from neutron.tests.unit import testlib_api @@ -29,7 +31,5 @@ class ExtraDhcpOptDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, super(ExtraDhcpOptDbObjectTestCase, self).setUp() self._create_test_network() self._create_test_port(self._network) - for obj in self.db_objs: - obj['port_id'] = self._port['id'] - for obj in self.obj_fields: + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['port_id'] = self._port['id'] diff --git a/neutron/tests/unit/objects/qos/test_policy.py b/neutron/tests/unit/objects/qos/test_policy.py index 4721229f033..1e9eae8b6a9 100644 --- a/neutron/tests/unit/objects/qos/test_policy.py +++ b/neutron/tests/unit/objects/qos/test_policy.py @@ -82,7 +82,9 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): context_mock.assert_called_once_with() self.get_objects.assert_any_call( admin_context, self._test_class.db_model, _pager=None) - self._validate_objects(self.db_objs, objs) + self.assertItemsEqual( + [test_base.get_obj_db_fields(obj) for obj in self.objs], + [test_base.get_obj_db_fields(obj) for obj in objs]) def test_get_objects_valid_fields(self): admin_context = self.context.elevated() @@ -103,7 +105,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): get_objects_mock.assert_any_call( admin_context, self._test_class.db_model, _pager=None, **self.valid_field_filter) - self._validate_objects([self.db_obj], objs) + self._check_equal(objs[0], self.objs[0]) def test_get_object(self): admin_context = self.context.elevated() @@ -114,7 +116,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): return_value=admin_context) as context_mock: obj = self._test_class.get_object(self.context, id='fake_id') self.assertTrue(self._is_test_class(obj)) - self.assertEqual(self.db_obj, test_base.get_obj_db_fields(obj)) + self._check_equal(obj, self.objs[0]) context_mock.assert_called_once_with() get_object_mock.assert_called_once_with( admin_context, self._test_class.db_model, id='fake_id') @@ -146,9 +148,8 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase, self._create_test_port(self._network) def _create_test_policy(self): - policy_obj = policy.QosPolicy(self.context, **self.db_obj) - policy_obj.create() - return policy_obj + self.objs[0].create() + return self.objs[0] def _create_test_policy_with_rules(self, rule_type, reload_rules=False): policy_obj = self._create_test_policy() diff --git a/neutron/tests/unit/objects/test_base.py b/neutron/tests/unit/objects/test_base.py index 973157066df..306cfea76f7 100644 --- a/neutron/tests/unit/objects/test_base.py +++ b/neutron/tests/unit/objects/test_base.py @@ -27,7 +27,6 @@ from oslo_versionedobjects import fixture import testtools from neutron.common import constants -from neutron.common import utils as common_utils from neutron import context from neutron.db import db_base_plugin_v2 from neutron.db import model_base @@ -48,14 +47,12 @@ OBJECTS_BASE_OBJ_FROM_PRIMITIVE = ('oslo_versionedobjects.base.' TIMESTAMP_FIELDS = ['created_at', 'updated_at', 'revision_number'] -class FakeModel(object): - def __init__(self, *args, **kwargs): - pass +class FakeModel(dict): + pass -class ObjectFieldsModel(object): - def __init__(self, *args, **kwargs): - pass +class ObjectFieldsModel(dict): + pass @obj_base.VersionedObjectRegistry.register_if(False) @@ -396,9 +393,11 @@ FIELD_TYPE_VALUE_GENERATOR_MAP = { } +# TODO(ihrachys) consider renaming into e.g. get_obj_persistent_fields def get_obj_db_fields(obj): return {field: getattr(obj, field) for field in obj.fields - if field not in obj.synthetic_fields} + if field not in obj.synthetic_fields + if field in obj} def get_value(generator, version): @@ -429,11 +428,19 @@ class _BaseObjectTestCase(object): # neutron.objects.db.api from core plugin instance self.setup_coreplugin(self.CORE_PLUGIN) self.context = context.get_admin_context() - self.db_objs = list(self.get_random_fields() for _ in range(3)) + self.db_objs = [ + self._test_class.db_model(**self.get_random_fields()) + for _ in range(3) + ] self.db_obj = self.db_objs[0] + # TODO(ihrachys) remove obj_fields since they duplicate self.objs self.obj_fields = [self._test_class.modify_fields_from_db(db_obj) for db_obj in self.db_objs] + self.objs = [ + self._test_class(self.context, **fields) + for fields in self.obj_fields + ] valid_field = [f for f in self._test_class.fields if f not in self._test_class.synthetic_fields][0] @@ -447,8 +454,10 @@ class _BaseObjectTestCase(object): synthetic_obj_fields = self.get_random_fields(FakeSmallNeutronObject) self.model_map = { self._test_class.db_model: self.db_objs, - ObjectFieldsModel: [synthetic_obj_fields]} + ObjectFieldsModel: [ObjectFieldsModel(**synthetic_obj_fields)]} + # TODO(ihrachys): rename the method to explicitly reflect it returns db + # attributes not object fields @classmethod def get_random_fields(cls, obj_cls=None): obj_cls = obj_cls or cls._test_class @@ -504,6 +513,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self.model_map[self._test_class.db_model] = self.db_objs self.pager_map = collections.defaultdict(lambda: None) + # TODO(ihrachys) document the intent of all common test cases in docstrings def test_get_object(self): with mock.patch.object(obj_db_api, 'get_object', return_value=self.db_obj) as get_object_mock: @@ -512,7 +522,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): obj_keys = self.generate_object_keys(self._test_class) obj = self._test_class.get_object(self.context, **obj_keys) self.assertTrue(self._is_test_class(obj)) - self.assertEqual(self.obj_fields[0], get_obj_db_fields(obj)) + self._check_equal(obj, self.objs[0]) get_object_mock.assert_called_once_with( self.context, self._test_class.db_model, **self._test_class.modify_fields_to_db(obj_keys)) @@ -550,8 +560,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): obj = self._test_class.get_object(self.context, **obj_keys) self.assertTrue(self._is_test_class(obj)) - self.assertEqual(self.obj_fields[0], - get_obj_db_fields(obj)) + self._check_equal(obj, self.objs[0]) get_object_mock.assert_called_once_with( self.context, self._test_class.db_model, **self._test_class.modify_fields_to_db(obj_keys)) @@ -574,37 +583,25 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): return mock_calls def test_get_objects(self): + '''Test that get_objects fetches data from database.''' with mock.patch.object( obj_db_api, 'get_objects', side_effect=self.fake_get_objects) as get_objects_mock: objs = self._test_class.get_objects(self.context) - self._validate_objects(self.db_objs, objs) - mock_calls = [ - mock.call(self.context, self._test_class.db_model, - _pager=self.pager_map[self._test_class.obj_name()]) - ] - mock_calls.extend(self._get_synthetic_fields_get_objects_calls( - self.db_objs)) - get_objects_mock.assert_has_calls(mock_calls) + self.assertItemsEqual( + [get_obj_db_fields(obj) for obj in self.objs], + [get_obj_db_fields(obj) for obj in objs]) + get_objects_mock.assert_any_call( + self.context, self._test_class.db_model, + _pager=self.pager_map[self._test_class.obj_name()] + ) def test_get_objects_valid_fields(self): + '''Test that a valid filter does not raise an error.''' with mock.patch.object( - obj_db_api, 'get_objects', - side_effect=self.fake_get_objects) as get_objects_mock: - - objs = self._test_class.get_objects(self.context, - **self.valid_field_filter) - self._validate_objects(self.db_objs, objs) - mock_calls = [ - mock.call( - self.context, self._test_class.db_model, - _pager=self.pager_map[self._test_class.obj_name()], - **self._test_class.modify_fields_to_db(self.valid_field_filter) - ) - ] - mock_calls.extend(self._get_synthetic_fields_get_objects_calls( - [self.db_obj])) - get_objects_mock.assert_has_calls(mock_calls) + obj_db_api, 'get_objects', side_effect=self.fake_get_objects): + self._test_class.get_objects(self.context, + **self.valid_field_filter) def test_get_objects_mixed_fields(self): synthetic_fields = ( @@ -661,19 +658,11 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): self._test_class.count, self.context, fake_field='xxx') - def _validate_objects(self, expected, observed): - self.assertTrue(all(self._is_test_class(obj) for obj in observed)) - self.assertEqual( - sorted([self._test_class.modify_fields_from_db(db_obj) - for db_obj in expected], - key=common_utils.safe_sort_key), - sorted([get_obj_db_fields(obj) for obj in observed], - key=common_utils.safe_sort_key)) - - def _check_equal(self, obj, db_obj): - self.assertEqual( - sorted(db_obj), - sorted(get_obj_db_fields(obj))) + # TODO(ihrachys) swap the order of arguments to reflect the order of + # self.assert* methods + def _check_equal(self, observed, expected): + self.assertItemsEqual(get_obj_db_fields(expected), + get_obj_db_fields(observed)) def test_create(self): with mock.patch.object(obj_db_api, 'create_object', @@ -681,21 +670,21 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): with mock.patch.object(obj_db_api, 'get_objects', side_effect=self.fake_get_objects): obj = self._test_class(self.context, **self.obj_fields[0]) - self._check_equal(obj, self.obj_fields[0]) + self._check_equal(obj, self.objs[0]) obj.create() - self._check_equal(obj, self.obj_fields[0]) + self._check_equal(obj, self.objs[0]) create_mock.assert_called_once_with( - self.context, self._test_class.db_model, self.db_obj) + self.context, self._test_class.db_model, + self._test_class.modify_fields_to_db( + get_obj_db_fields(self.objs[0]))) def test_create_updates_from_db_object(self): with mock.patch.object(obj_db_api, 'create_object', return_value=self.db_obj): with mock.patch.object(obj_db_api, 'get_objects', side_effect=self.fake_get_objects): - obj = self._test_class(self.context, **self.obj_fields[1]) - self._check_equal(obj, self.obj_fields[1]) - obj.create() - self._check_equal(obj, self.obj_fields[0]) + self.objs[1].create() + self._check_equal(self.objs[1], self.objs[0]) def test_create_duplicates(self): with mock.patch.object(obj_db_api, 'create_object', @@ -781,7 +770,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): side_effect=self.fake_get_objects): obj = self._test_class(self.context, **self.obj_fields[0]) # get new values and fix keys - update_mock.return_value = self.db_objs[1].copy() + update_mock.return_value = self.db_objs[1] fixed_keys = self._test_class.modify_fields_to_db( obj._get_composite_keys()) for key, value in fixed_keys.items(): @@ -822,14 +811,14 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): obj_db_api, 'get_objects', side_effect=self.fake_get_objects): obj.update() - self._check_equal(obj, self.obj_fields[0]) + self._check_equal(obj, self.objs[0]) @mock.patch.object(obj_db_api, 'delete_object') def test_delete(self, delete_mock): obj = self._test_class(self.context, **self.obj_fields[0]) - self._check_equal(obj, self.obj_fields[0]) + self._check_equal(obj, self.objs[0]) obj.delete() - self._check_equal(obj, self.obj_fields[0]) + self._check_equal(obj, self.objs[0]) delete_mock.assert_called_once_with( self.context, self._test_class.db_model, **self._test_class.modify_fields_to_db(obj._get_composite_keys())) @@ -1033,7 +1022,7 @@ class BaseDbObjectTestCase(_BaseObjectTestCase, continue for db_obj in self.db_objs: objclass_fields = self.get_random_fields(objclass) - db_obj[synth_field] = [objclass_fields] + db_obj[synth_field] = [objclass.db_model(**objclass_fields)] def _create_test_network(self): # TODO(ihrachys): replace with network.create() once we get an object @@ -1217,10 +1206,13 @@ class BaseDbObjectTestCase(_BaseObjectTestCase, obj = self._make_object(self.obj_fields[0]) obj.create() - for field in remove_timestamps_from_fields(self.obj_fields[0]): - filters = {field: [self.obj_fields[0][field]]} + for field in remove_timestamps_from_fields(get_obj_db_fields(obj)): + filters = {field: [self.objs[0][field]]} new = self._test_class.get_objects(self.context, **filters) - self.assertEqual([obj], new, 'Filtering by %s failed.' % field) + self.assertItemsEqual( + [obj._get_composite_keys()], + [obj_._get_composite_keys() for obj_ in new], + 'Filtering by %s failed.' % field) def _get_non_synth_fields(self, objclass, db_attrs): fields = objclass.modify_fields_from_db(db_attrs) diff --git a/neutron/tests/unit/objects/test_rbac_db.py b/neutron/tests/unit/objects/test_rbac_db.py index 2ee70e2696b..facc02d2f8f 100644 --- a/neutron/tests/unit/objects/test_rbac_db.py +++ b/neutron/tests/unit/objects/test_rbac_db.py @@ -29,9 +29,8 @@ from neutron.tests.unit.objects import test_base from neutron.tests.unit import testlib_api -class FakeDbModel(object): - def __init__(self, *args, **kwargs): - pass +class FakeDbModel(dict): + pass class FakeRbacModel(rbac_db_models.RBACColumns, model_base.BASEV2): diff --git a/neutron/tests/unit/objects/test_securitygroup.py b/neutron/tests/unit/objects/test_securitygroup.py index be44e834344..b78bdf978b2 100644 --- a/neutron/tests/unit/objects/test_securitygroup.py +++ b/neutron/tests/unit/objects/test_securitygroup.py @@ -118,7 +118,7 @@ class DefaultSecurityGroupDbObjTestCase(test_base.BaseDbObjectTestCase, self.sg_obj = securitygroup.SecurityGroup( self.context, **test_base.remove_timestamps_from_fields(sg_fields)) self.sg_obj.create() - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['security_group_id'] = self.sg_obj['id'] @@ -140,6 +140,6 @@ class SecurityGroupRuleDbObjTestCase(test_base.BaseDbObjectTestCase, self.sg_obj = securitygroup.SecurityGroup( self.context, **test_base.remove_timestamps_from_fields(sg_fields)) self.sg_obj.create() - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['security_group_id'] = self.sg_obj['id'] obj['remote_group_id'] = self.sg_obj['id'] diff --git a/neutron/tests/unit/objects/test_subnet.py b/neutron/tests/unit/objects/test_subnet.py index 76f8caf8906..0706c4512c0 100644 --- a/neutron/tests/unit/objects/test_subnet.py +++ b/neutron/tests/unit/objects/test_subnet.py @@ -39,7 +39,7 @@ class IPAllocationPoolDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, super(IPAllocationPoolDbObjectTestCase, self).setUp() self._create_test_network() self._create_test_subnet(self._network) - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['subnet_id'] = self._subnet['id'] @@ -69,7 +69,7 @@ class DNSNameServerDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, for db_obj in self.db_objs] self._create_test_network() self._create_test_subnet(self._network) - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['subnet_id'] = self._subnet['id'] def _is_objects_unique(self): @@ -128,7 +128,7 @@ class RouteDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, super(RouteDbObjectTestCase, self).setUp() self._create_test_network() self._create_test_subnet(self._network) - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['subnet_id'] = self._subnet['id'] @@ -151,7 +151,7 @@ class SubnetDbObjectTestCase(obj_test_base.BaseDbObjectTestCase, super(SubnetDbObjectTestCase, self).setUp() self._create_test_network() self._create_test_segment(self._network) - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['network_id'] = self._network['id'] obj['segment_id'] = self._segment['id'] diff --git a/neutron/tests/unit/objects/test_subnetpool.py b/neutron/tests/unit/objects/test_subnetpool.py index 63d00866066..3bf8036e4d3 100644 --- a/neutron/tests/unit/objects/test_subnetpool.py +++ b/neutron/tests/unit/objects/test_subnetpool.py @@ -87,5 +87,5 @@ class SubnetPoolPrefixDbObjectTestCase( def setUp(self): super(SubnetPoolPrefixDbObjectTestCase, self).setUp() self._create_test_subnetpool() - for obj in itertools.chain(self.db_objs, self.obj_fields): + for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs): obj['subnetpool_id'] = self._pool.id