objects: support advanced criteria for get_objects

Those are needed to accommodate to API request needs without handling
sorting or pagination in Python.

Instead of adding four new arguments to get_objects interface, they are
consolidated using a single Pager object that is passed through the
_pager argument. The name uses underscore to avoid breaking objects that
want to pass 'pager' filter into SQLAlchemy. Hopefully, no objects will
ever have '_pager' attribute in their models.

Related-Bug: #1541928
Change-Id: I7dafc4dbd80f0ac35dbc2c2f30e56e441f5b1fc0
This commit is contained in:
Ihar Hrachyshka 2016-05-18 19:19:48 +02:00
parent 067a5c2a47
commit 5decc850f1
7 changed files with 132 additions and 16 deletions

View File

@ -69,6 +69,29 @@ def get_updatable_fields(cls, fields):
return fields
class Pager(object):
'''
This class represents a pager object. It is consumed by get_objects to
specify sorting and pagination criteria.
'''
def __init__(self, sorts=None, limit=None, page_reverse=None, marker=None):
self.sorts = sorts
self.limit = limit
self.page_reverse = page_reverse
self.marker = marker
def to_kwargs(self, context, model):
res = {
attr: getattr(self, attr)
for attr in ('sorts', 'limit', 'page_reverse')
if getattr(self, attr) is not None
}
if self.marker and self.limit:
res['marker_obj'] = obj_db_api.get_object(
context, model, id=self.marker)
return res
@six.add_metaclass(abc.ABCMeta)
class NeutronObject(obj_base.VersionedObject,
obj_base.VersionedObjectDictCompat,
@ -126,7 +149,7 @@ class NeutronObject(obj_base.VersionedObject,
@classmethod
@abc.abstractmethod
def get_objects(cls, context, **kwargs):
def get_objects(cls, context, _pager=None, **kwargs):
raise NotImplementedError()
def create(self):
@ -241,11 +264,10 @@ class NeutronDbObject(NeutronObject):
@classmethod
def get_object(cls, context, **kwargs):
"""
This method fetches object from DB and convert it to versioned
object.
Fetch object from DB and convert it to a versioned object.
:param context:
:param kwargs: multiple primary keys defined key=value pairs
:param kwargs: multiple keys defined by key=value pairs
:return: single object of NeutronDbObject class
"""
missing_keys = set(cls.primary_keys).difference(kwargs.keys())
@ -259,10 +281,20 @@ class NeutronDbObject(NeutronObject):
return cls._load_object(context, db_obj)
@classmethod
def get_objects(cls, context, **kwargs):
def get_objects(cls, context, _pager=None, **kwargs):
"""
Fetch objects from DB and convert them to versioned objects.
:param context:
:param _pager: a Pager object representing advanced sorting/pagination
criteria
:param kwargs: multiple keys defined by key=value pairs
:return: list of objects of NeutronDbObject class
"""
cls.validate_filters(**kwargs)
with db_api.autonested_transaction(context.session):
db_objs = obj_db_api.get_objects(context, cls.db_model, **kwargs)
db_objs = obj_db_api.get_objects(
context, cls.db_model, _pager=_pager, **kwargs)
return [cls._load_object(context, db_obj) for db_obj in db_objs]
@classmethod

View File

@ -14,6 +14,7 @@ from neutron_lib import exceptions as n_exc
from oslo_utils import uuidutils
from neutron.db import common_db_mixin
from neutron import manager
# Common database operation implementations
@ -24,11 +25,22 @@ def get_object(context, model, **kwargs):
.first())
def get_objects(context, model, **kwargs):
def _kwargs_to_filters(**kwargs):
return {k: [v]
for k, v in kwargs.items()}
def get_objects(context, model, _pager=None, **kwargs):
with context.session.begin(subtransactions=True):
return (common_db_mixin.model_query(context, model)
.filter_by(**kwargs)
.all())
filters = _kwargs_to_filters(**kwargs)
# TODO(ihrachys): decompose _get_collection from plugin instance
plugin = manager.NeutronManager.get_plugin()
return plugin._get_collection(
context, model,
# TODO(ihrachys): avoid this no-op call per model found
lambda obj, fields: obj,
filters=filters,
**(_pager.to_kwargs(context, model) if _pager else {}))
def create_object(context, model, values):

View File

@ -0,0 +1,49 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from neutron import context
from neutron import manager
from neutron.objects import base
from neutron.objects.db import api
from neutron.tests import base as test_base
PLUGIN_NAME = 'neutron.db.db_base_plugin_v2.NeutronDbPluginV2'
class GetObjectsTestCase(test_base.BaseTestCase):
def setUp(self):
super(GetObjectsTestCase, self).setUp()
# TODO(ihrachys): revisit plugin setup once we decouple
# objects.db.objects.api from core plugin instance
self.setup_coreplugin(PLUGIN_NAME)
def test_get_objects_pass_marker_obj_when_limit_and_marker_passed(self):
ctxt = context.get_admin_context()
model = mock.sentinel.model
marker = mock.sentinel.marker
limit = mock.sentinel.limit
pager = base.Pager(marker=marker, limit=limit)
plugin = manager.NeutronManager.get_plugin()
with mock.patch.object(plugin, '_get_collection') as get_collection:
with mock.patch.object(api, 'get_object') as get_object:
api.get_objects(ctxt, model, _pager=pager)
get_object.assert_called_with(ctxt, model, id=marker)
get_collection.assert_called_with(
ctxt, model, mock.ANY,
filters={},
limit=limit,
marker_obj=get_object.return_value)

View File

@ -65,7 +65,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
objs = self._test_class.get_objects(self.context)
context_mock.assert_called_once_with()
self.get_objects.assert_any_call(
admin_context, self._test_class.db_model)
admin_context, self._test_class.db_model, _pager=None)
self._validate_objects(self.db_objs, objs)
def test_get_objects_valid_fields(self):
@ -85,7 +85,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
**self.valid_field_filter)
context_mock.assert_called_once_with()
get_objects_mock.assert_any_call(
admin_context, self._test_class.db_model,
admin_context, self._test_class.db_model, _pager=None,
**self.valid_field_filter)
self._validate_objects([self.db_obj], objs)

View File

@ -364,7 +364,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
field].objname)[0]
mock_calls.append(
mock.call(
self.context, obj_class.db_model,
self.context, obj_class.db_model, _pager=None,
**{k: db_obj[v]
for k, v in obj_class.foreign_keys.items()}))
return mock_calls
@ -375,7 +375,9 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
side_effect=self.fake_get_objects) as get_objects_mock:
objs = self._test_class.get_objects(self.context)
self._validate_objects(self.db_objs, objs)
mock_calls = [mock.call(self.context, self._test_class.db_model)]
mock_calls = [
mock.call(self.context, self._test_class.db_model, _pager=None)
]
mock_calls.extend(self._get_synthetic_fields_get_objects_calls(
self.db_objs))
get_objects_mock.assert_has_calls(mock_calls)
@ -389,8 +391,10 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
**self.valid_field_filter)
self._validate_objects(self.db_objs, objs)
mock_calls = [mock.call(self.context, self._test_class.db_model,
**self.valid_field_filter)]
mock_calls = [
mock.call(self.context, self._test_class.db_model, _pager=None,
**self.valid_field_filter)
]
mock_calls.extend(self._get_synthetic_fields_get_objects_calls(
[self.db_obj]))
get_objects_mock.assert_has_calls(mock_calls)
@ -601,6 +605,13 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
dict_ = obj.to_dict()
self.assertEqual(child_dict, dict_[field])
def test_get_objects_pager_is_passed_through(self):
with mock.patch.object(obj_db_api, 'get_objects') as get_objects:
pager = base.Pager()
self._test_class.get_objects(self.context, _pager=pager)
get_objects.assert_called_once_with(
mock.ANY, self._test_class.db_model, _pager=pager)
class BaseDbObjectNonStandardPrimaryKeyTestCase(BaseObjectIfaceTestCase):
@ -635,6 +646,14 @@ class BaseDbObjectMultipleForeignKeysTestCase(_BaseObjectTestCase,
class BaseDbObjectTestCase(_BaseObjectTestCase):
CORE_PLUGIN = 'neutron.db.db_base_plugin_v2.NeutronDbPluginV2'
def setUp(self):
super(BaseDbObjectTestCase, self).setUp()
# TODO(ihrachys): revisit plugin setup once we decouple
# neutron.objects.db.api from core plugin instance
self.setup_coreplugin(self.CORE_PLUGIN)
def _create_test_network(self):
# TODO(ihrachys): replace with network.create() once we get an object
# implementation for networks

View File

@ -2031,8 +2031,12 @@ class TestML2PluggableIPAM(test_ipam.UseIpamMixin, TestMl2SubnetsV2):
class TestMl2PluginCreateUpdateDeletePort(base.BaseTestCase):
def setUp(self):
super(TestMl2PluginCreateUpdateDeletePort, self).setUp()
# TODO(ihrachys): revisit plugin setup once we decouple
# neutron.objects.db.api from core plugin instance
self.setup_coreplugin(PLUGIN_NAME)
self.context = mock.MagicMock()
self.notify_p = mock.patch('neutron.callbacks.registry.notify')
self.notify = self.notify_p.start()