diff --git a/neutron/objects/base.py b/neutron/objects/base.py index 0c56d0ff817..dafbfdfaf94 100644 --- a/neutron/objects/base.py +++ b/neutron/objects/base.py @@ -69,6 +69,29 @@ def get_updatable_fields(cls, fields): return fields +class Pager(object): + ''' + This class represents a pager object. It is consumed by get_objects to + specify sorting and pagination criteria. + ''' + def __init__(self, sorts=None, limit=None, page_reverse=None, marker=None): + self.sorts = sorts + self.limit = limit + self.page_reverse = page_reverse + self.marker = marker + + def to_kwargs(self, context, model): + res = { + attr: getattr(self, attr) + for attr in ('sorts', 'limit', 'page_reverse') + if getattr(self, attr) is not None + } + if self.marker and self.limit: + res['marker_obj'] = obj_db_api.get_object( + context, model, id=self.marker) + return res + + @six.add_metaclass(abc.ABCMeta) class NeutronObject(obj_base.VersionedObject, obj_base.VersionedObjectDictCompat, @@ -126,7 +149,7 @@ class NeutronObject(obj_base.VersionedObject, @classmethod @abc.abstractmethod - def get_objects(cls, context, **kwargs): + def get_objects(cls, context, _pager=None, **kwargs): raise NotImplementedError() def create(self): @@ -241,11 +264,10 @@ class NeutronDbObject(NeutronObject): @classmethod def get_object(cls, context, **kwargs): """ - This method fetches object from DB and convert it to versioned - object. + Fetch object from DB and convert it to a versioned object. :param context: - :param kwargs: multiple primary keys defined key=value pairs + :param kwargs: multiple keys defined by key=value pairs :return: single object of NeutronDbObject class """ missing_keys = set(cls.primary_keys).difference(kwargs.keys()) @@ -259,10 +281,20 @@ class NeutronDbObject(NeutronObject): return cls._load_object(context, db_obj) @classmethod - def get_objects(cls, context, **kwargs): + def get_objects(cls, context, _pager=None, **kwargs): + """ + Fetch objects from DB and convert them to versioned objects. + + :param context: + :param _pager: a Pager object representing advanced sorting/pagination + criteria + :param kwargs: multiple keys defined by key=value pairs + :return: list of objects of NeutronDbObject class + """ cls.validate_filters(**kwargs) with db_api.autonested_transaction(context.session): - db_objs = obj_db_api.get_objects(context, cls.db_model, **kwargs) + db_objs = obj_db_api.get_objects( + context, cls.db_model, _pager=_pager, **kwargs) return [cls._load_object(context, db_obj) for db_obj in db_objs] @classmethod diff --git a/neutron/objects/db/api.py b/neutron/objects/db/api.py index 670b285c083..8c37dec515b 100644 --- a/neutron/objects/db/api.py +++ b/neutron/objects/db/api.py @@ -14,6 +14,7 @@ from neutron_lib import exceptions as n_exc from oslo_utils import uuidutils from neutron.db import common_db_mixin +from neutron import manager # Common database operation implementations @@ -24,11 +25,22 @@ def get_object(context, model, **kwargs): .first()) -def get_objects(context, model, **kwargs): +def _kwargs_to_filters(**kwargs): + return {k: [v] + for k, v in kwargs.items()} + + +def get_objects(context, model, _pager=None, **kwargs): with context.session.begin(subtransactions=True): - return (common_db_mixin.model_query(context, model) - .filter_by(**kwargs) - .all()) + filters = _kwargs_to_filters(**kwargs) + # TODO(ihrachys): decompose _get_collection from plugin instance + plugin = manager.NeutronManager.get_plugin() + return plugin._get_collection( + context, model, + # TODO(ihrachys): avoid this no-op call per model found + lambda obj, fields: obj, + filters=filters, + **(_pager.to_kwargs(context, model) if _pager else {})) def create_object(context, model, values): diff --git a/neutron/tests/unit/objects/db/__init__.py b/neutron/tests/unit/objects/db/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/neutron/tests/unit/objects/db/test_api.py b/neutron/tests/unit/objects/db/test_api.py new file mode 100644 index 00000000000..2f967d62b59 --- /dev/null +++ b/neutron/tests/unit/objects/db/test_api.py @@ -0,0 +1,49 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock + +from neutron import context +from neutron import manager +from neutron.objects import base +from neutron.objects.db import api +from neutron.tests import base as test_base + + +PLUGIN_NAME = 'neutron.db.db_base_plugin_v2.NeutronDbPluginV2' + + +class GetObjectsTestCase(test_base.BaseTestCase): + + def setUp(self): + super(GetObjectsTestCase, self).setUp() + # TODO(ihrachys): revisit plugin setup once we decouple + # objects.db.objects.api from core plugin instance + self.setup_coreplugin(PLUGIN_NAME) + + def test_get_objects_pass_marker_obj_when_limit_and_marker_passed(self): + ctxt = context.get_admin_context() + model = mock.sentinel.model + marker = mock.sentinel.marker + limit = mock.sentinel.limit + pager = base.Pager(marker=marker, limit=limit) + + plugin = manager.NeutronManager.get_plugin() + with mock.patch.object(plugin, '_get_collection') as get_collection: + with mock.patch.object(api, 'get_object') as get_object: + api.get_objects(ctxt, model, _pager=pager) + get_object.assert_called_with(ctxt, model, id=marker) + get_collection.assert_called_with( + ctxt, model, mock.ANY, + filters={}, + limit=limit, + marker_obj=get_object.return_value) diff --git a/neutron/tests/unit/objects/qos/test_policy.py b/neutron/tests/unit/objects/qos/test_policy.py index cb60901b8c5..290e9c9e0cc 100644 --- a/neutron/tests/unit/objects/qos/test_policy.py +++ b/neutron/tests/unit/objects/qos/test_policy.py @@ -65,7 +65,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): objs = self._test_class.get_objects(self.context) context_mock.assert_called_once_with() self.get_objects.assert_any_call( - admin_context, self._test_class.db_model) + admin_context, self._test_class.db_model, _pager=None) self._validate_objects(self.db_objs, objs) def test_get_objects_valid_fields(self): @@ -85,7 +85,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase): **self.valid_field_filter) context_mock.assert_called_once_with() get_objects_mock.assert_any_call( - admin_context, self._test_class.db_model, + admin_context, self._test_class.db_model, _pager=None, **self.valid_field_filter) self._validate_objects([self.db_obj], objs) diff --git a/neutron/tests/unit/objects/test_base.py b/neutron/tests/unit/objects/test_base.py index 9ed3a2727e4..5ffdaa448a9 100644 --- a/neutron/tests/unit/objects/test_base.py +++ b/neutron/tests/unit/objects/test_base.py @@ -364,7 +364,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): field].objname)[0] mock_calls.append( mock.call( - self.context, obj_class.db_model, + self.context, obj_class.db_model, _pager=None, **{k: db_obj[v] for k, v in obj_class.foreign_keys.items()})) return mock_calls @@ -375,7 +375,9 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): side_effect=self.fake_get_objects) as get_objects_mock: objs = self._test_class.get_objects(self.context) self._validate_objects(self.db_objs, objs) - mock_calls = [mock.call(self.context, self._test_class.db_model)] + mock_calls = [ + mock.call(self.context, self._test_class.db_model, _pager=None) + ] mock_calls.extend(self._get_synthetic_fields_get_objects_calls( self.db_objs)) get_objects_mock.assert_has_calls(mock_calls) @@ -389,8 +391,10 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): **self.valid_field_filter) self._validate_objects(self.db_objs, objs) - mock_calls = [mock.call(self.context, self._test_class.db_model, - **self.valid_field_filter)] + mock_calls = [ + mock.call(self.context, self._test_class.db_model, _pager=None, + **self.valid_field_filter) + ] mock_calls.extend(self._get_synthetic_fields_get_objects_calls( [self.db_obj])) get_objects_mock.assert_has_calls(mock_calls) @@ -601,6 +605,13 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase): dict_ = obj.to_dict() self.assertEqual(child_dict, dict_[field]) + def test_get_objects_pager_is_passed_through(self): + with mock.patch.object(obj_db_api, 'get_objects') as get_objects: + pager = base.Pager() + self._test_class.get_objects(self.context, _pager=pager) + get_objects.assert_called_once_with( + mock.ANY, self._test_class.db_model, _pager=pager) + class BaseDbObjectNonStandardPrimaryKeyTestCase(BaseObjectIfaceTestCase): @@ -635,6 +646,14 @@ class BaseDbObjectMultipleForeignKeysTestCase(_BaseObjectTestCase, class BaseDbObjectTestCase(_BaseObjectTestCase): + CORE_PLUGIN = 'neutron.db.db_base_plugin_v2.NeutronDbPluginV2' + + def setUp(self): + super(BaseDbObjectTestCase, self).setUp() + # TODO(ihrachys): revisit plugin setup once we decouple + # neutron.objects.db.api from core plugin instance + self.setup_coreplugin(self.CORE_PLUGIN) + def _create_test_network(self): # TODO(ihrachys): replace with network.create() once we get an object # implementation for networks diff --git a/neutron/tests/unit/plugins/ml2/test_plugin.py b/neutron/tests/unit/plugins/ml2/test_plugin.py index 64241ae44f4..60f9e3ea0b3 100644 --- a/neutron/tests/unit/plugins/ml2/test_plugin.py +++ b/neutron/tests/unit/plugins/ml2/test_plugin.py @@ -2031,8 +2031,12 @@ class TestML2PluggableIPAM(test_ipam.UseIpamMixin, TestMl2SubnetsV2): class TestMl2PluginCreateUpdateDeletePort(base.BaseTestCase): + def setUp(self): super(TestMl2PluginCreateUpdateDeletePort, self).setUp() + # TODO(ihrachys): revisit plugin setup once we decouple + # neutron.objects.db.api from core plugin instance + self.setup_coreplugin(PLUGIN_NAME) self.context = mock.MagicMock() self.notify_p = mock.patch('neutron.callbacks.registry.notify') self.notify = self.notify_p.start()