diff --git a/heat/db/api.py b/heat/db/api.py index 3c89b0f260..4e2fe20e65 100644 --- a/heat/db/api.py +++ b/heat/db/api.py @@ -120,9 +120,9 @@ def stack_get_all_by_owner_id(context, owner_id): def stack_get_all_by_tenant(context, limit=None, sort_keys=None, - marker=None, sort_dir=None): + marker=None, sort_dir=None, filters=None): return IMPL.stack_get_all_by_tenant(context, limit, sort_keys, - marker, sort_dir) + marker, sort_dir, filters) def stack_count_all_by_tenant(context): diff --git a/heat/db/sqlalchemy/api.py b/heat/db/sqlalchemy/api.py index 16a303eb4b..d6c4bed3e6 100644 --- a/heat/db/sqlalchemy/api.py +++ b/heat/db/sqlalchemy/api.py @@ -28,6 +28,7 @@ from heat.openstack.common.gettextutils import _ from heat.common import crypt from heat.common import exception +from heat.db.sqlalchemy import filters as db_filters from heat.db.sqlalchemy import migration from heat.db.sqlalchemy import models from heat.openstack.common.db.sqlalchemy import session as db_session @@ -284,8 +285,12 @@ def _query_stack_get_all_by_tenant(context): def stack_get_all_by_tenant(context, limit=None, sort_keys=None, marker=None, - sort_dir=None): + sort_dir=None, filters=None): + if filters is None: + filters = {} + query = _query_stack_get_all_by_tenant(context) + query = db_filters.exact_filter(query, models.Stack, filters) return _paginate_query(context, query, models.Stack, limit, sort_keys, marker, sort_dir).all() diff --git a/heat/db/sqlalchemy/filters.py b/heat/db/sqlalchemy/filters.py new file mode 100644 index 0000000000..d8483ced63 --- /dev/null +++ b/heat/db/sqlalchemy/filters.py @@ -0,0 +1,44 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +def exact_filter(query, model, filters): + """Applies exact match filtering to a query. + + Returns the updated query. Modifies filters argument to remove + filters consumed. + + :param query: query to apply filters to + :param model: model object the query applies to, for IN-style + filtering + :param filters: dictionary of filters; values that are lists, + tuples, sets, or frozensets cause an 'IN' test to + be performed, while exact matching ('==' operator) + is used for other values + """ + + filter_dict = {} + + for key, value in filters.iteritems(): + if isinstance(value, (list, tuple, set, frozenset)): + column_attr = getattr(model, key) + query = query.filter(column_attr.in_(value)) + else: + filter_dict[key] = value + + if filter_dict: + query = query.filter_by(**filter_dict) + + return query diff --git a/heat/tests/test_sqlalchemy_api.py b/heat/tests/test_sqlalchemy_api.py index 5f68c088e4..a20fcd83eb 100644 --- a/heat/tests/test_sqlalchemy_api.py +++ b/heat/tests/test_sqlalchemy_api.py @@ -217,6 +217,41 @@ class SqlAlchemyTest(HeatTestCase): st_db = db_api.stack_get_all_by_tenant(self.ctx) self.assertEqual(1, len(st_db)) + def test_stack_get_all_by_tenant_and_filters(self): + stack1 = self._setup_test_stack('foo', UUIDs[0]) + stack2 = self._setup_test_stack('bar', UUIDs[1]) + stacks = [stack1, stack2] + + filters = {'name': 'foo'} + results = db_api.stack_get_all_by_tenant(self.ctx, + filters=filters) + + self.assertEqual(1, len(results)) + self.assertEqual('foo', results[0]['name']) + + def test_stack_get_all_by_tenant_filter_matches_in_list(self): + stack1 = self._setup_test_stack('foo', UUIDs[0]) + stack2 = self._setup_test_stack('bar', UUIDs[1]) + stacks = [stack1, stack2] + + filters = {'name': ['bar', 'quux']} + results = db_api.stack_get_all_by_tenant(self.ctx, + filters=filters) + + self.assertEqual(1, len(results)) + self.assertEqual('bar', results[0]['name']) + + def test_stack_get_all_by_tenant_returns_all_if_no_filters(self): + stack1 = self._setup_test_stack('foo', UUIDs[0]) + stack2 = self._setup_test_stack('bar', UUIDs[1]) + stacks = [stack1, stack2] + + filters = None + results = db_api.stack_get_all_by_tenant(self.ctx, + filters=filters) + + self.assertEqual(2, len(results)) + def test_stack_get_all_by_tenant_default_sort_keys_and_dir(self): stacks = [self._setup_test_stack('stack', x)[1] for x in UUIDs] diff --git a/heat/tests/test_sqlalchemy_filters.py b/heat/tests/test_sqlalchemy_filters.py new file mode 100644 index 0000000000..b93e1b4cb1 --- /dev/null +++ b/heat/tests/test_sqlalchemy_filters.py @@ -0,0 +1,42 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock + +from heat.db.sqlalchemy import filters as db_filters +from heat.tests.common import HeatTestCase + + +class ExactFilterTest(HeatTestCase): + def setUp(self): + super(ExactFilterTest, self).setUp() + self.query = mock.Mock() + self.model = mock.Mock() + + def test_returns_same_query_for_empty_filters(self): + filters = {} + db_filters.exact_filter(self.query, self.model, filters) + self.assertEqual(0, self.query.call_count) + + def test_add_exact_match_clause_for_single_values(self): + filters = {'cat': 'foo'} + db_filters.exact_filter(self.query, self.model, filters) + + self.query.filter_by.assert_called_once_with(cat='foo') + + def test_adds_an_in_clause_for_multiple_values(self): + self.model.cat.in_.return_value = 'fake in clause' + filters = {'cat': ['foo', 'quux']} + db_filters.exact_filter(self.query, self.model, filters) + + self.query.filter.assert_called_once_with('fake in clause') + self.model.cat.in_.assert_called_once_with(['foo', 'quux'])