objects: introduce count() API to count matching objects

This is used in several places in current database code. The default
implementation relies on counting get_objects results, but for
NeutronDbObject based objects, we rely on .count API from SQLAlchemy.

Change-Id: I43fa371588ee5f0e5fd75cabc20f8f6fffed61bf
Partially-Implements: blueprint adopt-oslo-versioned-objects-for-db
This commit is contained in:
Ihar Hrachyshka 2016-08-03 09:22:05 +02:00
parent 17d85e4748
commit 0fce7de103
3 changed files with 86 additions and 5 deletions

View File

@ -202,6 +202,11 @@ class NeutronObject(obj_base.VersionedObject,
def delete(self):
raise NotImplementedError()
@classmethod
def count(cls, context, **kwargs):
'''Count the number of objects matching filtering criteria.'''
return len(cls.get_objects(context, **kwargs))
class DeclarativeObject(abc.ABCMeta):
@ -537,3 +542,17 @@ class NeutronDbObject(NeutronObject):
obj_db_api.delete_object(self.obj_context, self.db_model,
**self.modify_fields_to_db(
self._get_composite_keys()))
@classmethod
def count(cls, context, **kwargs):
"""
Count the number of objects matching filtering criteria.
:param context:
:param kwargs: multiple keys defined by key=value pairs
:return: number of matching objects
"""
cls.validate_filters(**kwargs)
return obj_db_api.count(
context, cls.db_model, **cls.modify_fields_to_db(kwargs)
)

View File

@ -17,13 +17,21 @@ from neutron import manager
# Common database operation implementations
def get_object(context, model, **kwargs):
def _get_filter_query(context, model, **kwargs):
# TODO(jlibosva): decompose _get_collection_query from plugin instance
plugin = manager.NeutronManager.get_plugin()
with context.session.begin(subtransactions=True):
filters = _kwargs_to_filters(**kwargs)
query = plugin._get_collection_query(context, model, filters)
return query.first()
return query
def get_object(context, model, **kwargs):
return _get_filter_query(context, model, **kwargs).first()
def count(context, model, **kwargs):
return _get_filter_query(context, model, **kwargs).count()
def _kwargs_to_filters(**kwargs):

View File

@ -93,7 +93,7 @@ class FakeWeirdKeySmallNeutronObject(base.NeutronDbObject):
@obj_base.VersionedObjectRegistry.register_if(False)
class FakeNeutronObject(base.NeutronDbObject):
class FakeNeutronDbObject(base.NeutronDbObject):
# Version 1.0: Initial version
VERSION = '1.0'
@ -263,6 +263,33 @@ class FakeNeutronObjectWithProjectId(base.NeutronDbObject):
}
@obj_base.VersionedObjectRegistry.register_if(False)
class FakeNeutronObject(base.NeutronObject):
# Version 1.0: Initial version
VERSION = '1.0'
fields = {
'id': obj_fields.UUIDField(),
'project_id': obj_fields.StringField(),
'field2': obj_fields.UUIDField(),
}
@classmethod
def get_object(cls, context, **kwargs):
if not hasattr(cls, '_obj'):
cls._obj = FakeNeutronObject(id=uuidutils.generate_uuid(),
project_id='fake-id',
field2=uuidutils.generate_uuid())
return cls._obj
@classmethod
def get_objects(cls, context, _pager=None, count=1, **kwargs):
return [
cls.get_object(context, **kwargs)
for i in range(count)
]
def get_random_dscp_mark():
return random.choice(constants.VALID_DSCP_MARKS)
@ -319,7 +346,7 @@ def get_non_synthetic_fields(objclass, obj_fields):
class _BaseObjectTestCase(object):
_test_class = FakeNeutronObject
_test_class = FakeNeutronDbObject
CORE_PLUGIN = 'neutron.db.db_base_plugin_v2.NeutronDbPluginV2'
@ -538,6 +565,19 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
self._test_class.get_objects, self.context,
fake_field='xxx')
def test_count(self):
if not isinstance(self._test_class, base.NeutronDbObject):
self.skipTest('Class %s does not inherit from NeutronDbObject' %
self._test_class)
expected = 10
with mock.patch.object(obj_db_api, 'count', return_value=expected):
self.assertEqual(expected, self._test_class.count(self.context))
def test_count_invalid_fields(self):
self.assertRaises(base.exceptions.InvalidInput,
self._test_class.count, self.context,
fake_field='xxx')
def _validate_objects(self, expected, observed):
self.assertTrue(all(self._is_test_class(obj) for obj in observed))
self.assertEqual(
@ -805,6 +845,14 @@ class UniqueKeysTestCase(test_base.BaseTestCase):
self.assertEqual(expected, observed)
class NeutronObjectCountTestCase(test_base.BaseTestCase):
def test_count(self):
expected = 10
self.assertEqual(
expected, FakeNeutronObject.count(None, count=expected))
class BaseDbObjectCompositePrimaryKeyWithIdTestCase(BaseObjectIfaceTestCase):
_test_class = FakeNeutronObjectCompositePrimaryKeyWithId
@ -1115,6 +1163,12 @@ class BaseDbObjectTestCase(_BaseObjectTestCase,
self._assert_object_list_queries_constant(_create, self._test_class)
def test_count(self):
for fields in self.obj_fields:
self._make_object(fields).create()
self.assertEqual(
len(self.obj_fields), self._test_class.count(self.context))
class UniqueObjectBase(test_base.BaseTestCase):
def setUp(self):
@ -1151,7 +1205,7 @@ class RegisterFilterHookOnModelTestCase(UniqueObjectBase):
self.assertNotIn(
filter_name, self.registered_object.extra_filter_names)
base.register_filter_hook_on_model(
FakeNeutronObject.db_model, filter_name)
FakeNeutronDbObject.db_model, filter_name)
self.assertIn(filter_name, self.registered_object.extra_filter_names)