From 2c223fd04f9dbbe179fefb51a5767985c9e84ad1 Mon Sep 17 00:00:00 2001 From: Richard Lee Date: Mon, 3 Feb 2014 14:18:10 -0500 Subject: [PATCH] Replace stack_get_all_by_tenant with stack_get_all Preparing the database for the unscoped list stacks, this removes stack_get_all_by_tenant in favor of a single function that is tenant aware. This way, you can call stack_get_all(tenant_safe=False) and get all stacks, regardless of the tenants. tenant_safe defaults to True. Implements: blueprint management-api (partial) Change-Id: Ifdcfc44a3483a089ae803cfd044151262957674a --- heat/db/api.py | 10 +--- heat/db/sqlalchemy/api.py | 25 +++------ heat/engine/service.py | 6 +-- heat/tests/test_engine_service.py | 19 ++++--- heat/tests/test_sqlalchemy_api.py | 87 +++++++++++++++---------------- 5 files changed, 64 insertions(+), 83 deletions(-) diff --git a/heat/db/api.py b/heat/db/api.py index 097a8dafcc..5bb36f254c 100644 --- a/heat/db/api.py +++ b/heat/db/api.py @@ -121,21 +121,15 @@ def stack_get_by_name(context, stack_name): def stack_get_all(context, limit=None, sort_keys=None, marker=None, - sort_dir=None, filters=None): + sort_dir=None, filters=None, tenant_safe=True): return IMPL.stack_get_all(context, limit, sort_keys, - marker, sort_dir, filters) + marker, sort_dir, filters, tenant_safe) def stack_get_all_by_owner_id(context, owner_id): return IMPL.stack_get_all_by_owner_id(context, owner_id) -def stack_get_all_by_tenant(context, limit=None, sort_keys=None, - marker=None, sort_dir=None, filters=None): - return IMPL.stack_get_all_by_tenant(context, limit, sort_keys, - marker, sort_dir, filters) - - def stack_count_all_by_tenant(context, filters=None): return IMPL.stack_count_all_by_tenant(context, filters=filters) diff --git a/heat/db/sqlalchemy/api.py b/heat/db/sqlalchemy/api.py index 60660745ce..5a1a14c4ae 100644 --- a/heat/db/sqlalchemy/api.py +++ b/heat/db/sqlalchemy/api.py @@ -321,30 +321,19 @@ def _paginate_query(context, query, model, limit=None, sort_keys=None, return query -def _query_stack_get_all_by_tenant(context): - query = _query_stack_get_all(context).\ - filter_by(tenant=context.tenant_id) - - return query - - -def _query_stack_get_all(context): +def _query_stack_get_all(context, tenant_safe=True): query = soft_delete_aware_query(context, models.Stack).\ filter_by(owner_id=None) + if tenant_safe: + query = query.filter_by(tenant=context.tenant_id) + return query def stack_get_all(context, limit=None, sort_keys=None, marker=None, - sort_dir=None, filters=None): - query = _query_stack_get_all(context) - return _filter_and_page_query(context, query, limit, sort_keys, - marker, sort_dir, filters).all() - - -def stack_get_all_by_tenant(context, limit=None, sort_keys=None, marker=None, - sort_dir=None, filters=None): - query = _query_stack_get_all_by_tenant(context) + sort_dir=None, filters=None, tenant_safe=True): + query = _query_stack_get_all(context, tenant_safe) return _filter_and_page_query(context, query, limit, sort_keys, marker, sort_dir, filters).all() @@ -366,7 +355,7 @@ def _filter_and_page_query(context, query, limit=None, sort_keys=None, def stack_count_all_by_tenant(context, filters=None): - query = _query_stack_get_all_by_tenant(context) + query = _query_stack_get_all(context) query = db_filters.exact_filter(query, models.Stack, filters) return query.count() diff --git a/heat/engine/service.py b/heat/engine/service.py index aa23ce819a..980ee23f25 100644 --- a/heat/engine/service.py +++ b/heat/engine/service.py @@ -290,7 +290,7 @@ class EngineService(service.Service): if stack_identity is not None: stacks = [self._get_stack(cnxt, stack_identity, show_deleted=True)] else: - stacks = db_api.stack_get_all_by_tenant(cnxt) or [] + stacks = db_api.stack_get_all(cnxt) or [] def format_stack_detail(s): stack = parser.Stack.load(cnxt, stack=s) @@ -330,8 +330,8 @@ class EngineService(service.Service): else: yield api.format_stack(stack) - stacks = db_api.stack_get_all_by_tenant(cnxt, limit, sort_keys, marker, - sort_dir, filters) or [] + stacks = db_api.stack_get_all(cnxt, limit, sort_keys, marker, + sort_dir, filters) or [] return list(format_stack_details(stacks)) @request_context diff --git a/heat/tests/test_engine_service.py b/heat/tests/test_engine_service.py index d52b0192a2..2d879933d9 100644 --- a/heat/tests/test_engine_service.py +++ b/heat/tests/test_engine_service.py @@ -1625,18 +1625,17 @@ class StackServiceTest(HeatTestCase): self.m.VerifyAll() - @mock.patch.object(db_api, 'stack_get_all_by_tenant') - def test_stack_list_passes_filtering_info(self, mock_stack_get_all_by_t): - + @mock.patch.object(db_api, 'stack_get_all') + def test_stack_list_passes_filtering_info(self, mock_stack_get_all): filters = {'foo': 'bar'} self.eng.list_stacks(self.ctx, filters=filters) - mock_stack_get_all_by_t.assert_called_once_with(mock.ANY, - mock.ANY, - mock.ANY, - mock.ANY, - mock.ANY, - filters - ) + mock_stack_get_all.assert_called_once_with(mock.ANY, + mock.ANY, + mock.ANY, + mock.ANY, + mock.ANY, + filters + ) @stack_context('service_abandon_stack') def test_abandon_stack(self): diff --git a/heat/tests/test_sqlalchemy_api.py b/heat/tests/test_sqlalchemy_api.py index 5d3f84c761..64a17c51fb 100644 --- a/heat/tests/test_sqlalchemy_api.py +++ b/heat/tests/test_sqlalchemy_api.py @@ -365,111 +365,96 @@ class SqlAlchemyTest(HeatTestCase): st_db = db_api.stack_get_all(self.ctx) self.assertEqual(1, len(st_db)) - def test_stack_get_all_by_tenant(self): - stacks = [self._setup_test_stack('stack', x)[1] for x in UUIDs] - - st_db = db_api.stack_get_all_by_tenant(self.ctx) - self.assertEqual(3, len(st_db)) - - stacks[0].delete() - st_db = db_api.stack_get_all_by_tenant(self.ctx) - self.assertEqual(2, len(st_db)) - - stacks[1].delete() - st_db = db_api.stack_get_all_by_tenant(self.ctx) - self.assertEqual(1, len(st_db)) - - def test_stack_get_all_by_tenant_and_filters(self): + def test_stack_get_all_with_filters(self): self._setup_test_stack('foo', UUID1) self._setup_test_stack('bar', UUID2) filters = {'name': 'foo'} - results = db_api.stack_get_all_by_tenant(self.ctx, - filters=filters) + results = db_api.stack_get_all(self.ctx, + filters=filters) self.assertEqual(1, len(results)) self.assertEqual('foo', results[0]['name']) - def test_stack_get_all_by_tenant_filter_matches_in_list(self): + def test_stack_get_all_filter_matches_in_list(self): self._setup_test_stack('foo', UUID1) self._setup_test_stack('bar', UUID2) filters = {'name': ['bar', 'quux']} - results = db_api.stack_get_all_by_tenant(self.ctx, - filters=filters) + results = db_api.stack_get_all(self.ctx, + filters=filters) self.assertEqual(1, len(results)) self.assertEqual('bar', results[0]['name']) - def test_stack_get_all_by_tenant_returns_all_if_no_filters(self): + def test_stack_get_all_returns_all_if_no_filters(self): self._setup_test_stack('foo', UUID1) self._setup_test_stack('bar', UUID2) filters = None - results = db_api.stack_get_all_by_tenant(self.ctx, - filters=filters) + results = db_api.stack_get_all(self.ctx, + filters=filters) self.assertEqual(2, len(results)) - def test_stack_get_all_by_tenant_default_sort_keys_and_dir(self): + def test_stack_get_all_default_sort_keys_and_dir(self): stacks = [self._setup_test_stack('stack', x)[1] for x in UUIDs] - st_db = db_api.stack_get_all_by_tenant(self.ctx) + st_db = db_api.stack_get_all(self.ctx) self.assertEqual(3, len(st_db)) self.assertEqual(stacks[2].id, st_db[0].id) self.assertEqual(stacks[1].id, st_db[1].id) self.assertEqual(stacks[0].id, st_db[2].id) - def test_stack_get_all_by_tenant_default_sort_dir(self): + def test_stack_get_all_default_sort_dir(self): stacks = [self._setup_test_stack('stack', x)[1] for x in UUIDs] - st_db = db_api.stack_get_all_by_tenant(self.ctx, sort_dir='asc') + st_db = db_api.stack_get_all(self.ctx, sort_dir='asc') self.assertEqual(3, len(st_db)) self.assertEqual(stacks[0].id, st_db[0].id) self.assertEqual(stacks[1].id, st_db[1].id) self.assertEqual(stacks[2].id, st_db[2].id) - def test_stack_get_all_by_tenant_str_sort_keys(self): + def test_stack_get_all_str_sort_keys(self): stacks = [self._setup_test_stack('stack', x)[1] for x in UUIDs] - st_db = db_api.stack_get_all_by_tenant(self.ctx, - sort_keys='created_at') + st_db = db_api.stack_get_all(self.ctx, + sort_keys='created_at') self.assertEqual(3, len(st_db)) self.assertEqual(stacks[0].id, st_db[0].id) self.assertEqual(stacks[1].id, st_db[1].id) self.assertEqual(stacks[2].id, st_db[2].id) @mock.patch.object(db_api.utils, 'paginate_query') - def test_stack_get_all_by_tenant_filters_sort_keys(self, mock_paginate): + def test_stack_get_all_filters_sort_keys(self, mock_paginate): sort_keys = ['name', 'status', 'created_at', 'updated_at', 'username'] - db_api.stack_get_all_by_tenant(self.ctx, - sort_keys=sort_keys) + db_api.stack_get_all(self.ctx, sort_keys=sort_keys) - args, _ = mock_paginate.call_args + args = mock_paginate.call_args[0] used_sort_keys = set(args[3]) expected_keys = set(['name', 'status', 'created_at', 'updated_at', 'id']) self.assertEqual(expected_keys, used_sort_keys) - def test_stack_get_all_by_tenant_marker(self): + def test_stack_get_all_marker(self): stacks = [self._setup_test_stack('stack', x)[1] for x in UUIDs] - st_db = db_api.stack_get_all_by_tenant(self.ctx, marker=stacks[1].id) + st_db = db_api.stack_get_all(self.ctx, marker=stacks[1].id) self.assertEqual(1, len(st_db)) self.assertEqual(stacks[0].id, st_db[0].id) - def test_stack_get_all_by_tenant_non_existing_marker(self): + def test_stack_get_all_non_existing_marker(self): [self._setup_test_stack('stack', x)[1] for x in UUIDs] uuid = 'this stack doesnt exist' - st_db = db_api.stack_get_all_by_tenant(self.ctx, marker=uuid) + st_db = db_api.stack_get_all(self.ctx, marker=uuid) self.assertEqual(3, len(st_db)) - def test_stack_get_all_by_tenant_doesnt_mutate_sort_keys(self): + def test_stack_get_all_doesnt_mutate_sort_keys(self): [self._setup_test_stack('stack', x)[1] for x in UUIDs] sort_keys = ['id'] - db_api.stack_get_all_by_tenant(self.ctx, sort_keys=sort_keys) + db_api.stack_get_all(self.ctx, sort_keys=sort_keys) self.assertEqual(['id'], sort_keys) def test_stack_count_all_by_tenant(self): @@ -1045,7 +1030,7 @@ class DBAPIStackTest(HeatTestCase): parent_stack2.id) self.assertEqual(2, len(stack2_children)) - def test_stack_get_all_by_tenant(self): + def test_stack_get_all_with_regular_tenant(self): values = [ {'tenant': UUID1}, {'tenant': UUID1}, @@ -1057,15 +1042,29 @@ class DBAPIStackTest(HeatTestCase): **val) for val in values] self.ctx.tenant_id = UUID1 - stacks = db_api.stack_get_all_by_tenant(self.ctx) + stacks = db_api.stack_get_all(self.ctx) self.assertEqual(2, len(stacks)) self.ctx.tenant_id = UUID2 - stacks = db_api.stack_get_all_by_tenant(self.ctx) + stacks = db_api.stack_get_all(self.ctx) self.assertEqual(3, len(stacks)) self.ctx.tenant_id = UUID3 - self.assertEqual([], db_api.stack_get_all_by_tenant(self.ctx)) + self.assertEqual([], db_api.stack_get_all(self.ctx)) + + def test_stack_get_all_with_tenant_safe_false(self): + values = [ + {'tenant': UUID1}, + {'tenant': UUID1}, + {'tenant': UUID2}, + {'tenant': UUID2}, + {'tenant': UUID2}, + ] + [create_stack(self.ctx, self.template, self.user_creds, + **val) for val in values] + + stacks = db_api.stack_get_all(self.ctx, tenant_safe=False) + self.assertEqual(5, len(stacks)) def test_stack_count_all_by_tenant(self): values = [