diff --git a/heat/db/sqlalchemy/api.py b/heat/db/sqlalchemy/api.py index 27790dbb0..4b3dab183 100644 --- a/heat/db/sqlalchemy/api.py +++ b/heat/db/sqlalchemy/api.py @@ -585,25 +585,10 @@ def stack_get_root_id(context, stack_id): def stack_count_total_resources(context, stack_id): - - # start with a stack_get to confirm the context can access the stack - if stack_id is None or stack_get(context, stack_id) is None: - return 0 - - def nested_stack_ids(sid): - yield sid - for child in stack_get_all_by_owner_id(context, sid): - for stack in nested_stack_ids(child.id): - yield stack - - stack_ids = list(nested_stack_ids(stack_id)) - - # count all resources which belong to the stacks + # count all resources which belong to the root stack results = model_query( context, models.Resource - ).filter( - models.Resource.stack_id.in_(stack_ids) - ).count() + ).filter(models.Resource.root_stack_id == stack_id).count() return results diff --git a/heat/tests/db/test_sqlalchemy_api.py b/heat/tests/db/test_sqlalchemy_api.py index b8abd0e50..fd51b829f 100644 --- a/heat/tests/db/test_sqlalchemy_api.py +++ b/heat/tests/db/test_sqlalchemy_api.py @@ -1871,10 +1871,14 @@ class DBAPIStackTest(common.HeatTestCase): def test_stack_count_total_resources(self): - def add_resources(stack, count): + def add_resources(stack, count, root_stack_id): for i in range(count): create_resource( - self.ctx, stack, name='%s-%s' % (stack.name, i)) + self.ctx, + stack, + name='%s-%s' % (stack.name, i), + root_stack_id=root_stack_id + ) root = create_stack(self.ctx, self.template, self.user_creds, name='root stack') @@ -1904,38 +1908,19 @@ class DBAPIStackTest(common.HeatTestCase): s_4 = create_stack(self.ctx, self.template, self.user_creds, name='s_4', owner_id=root.id) - add_resources(root, 3) - add_resources(s_1, 2) - add_resources(s_1_1, 4) - add_resources(s_1_2, 5) - add_resources(s_1_3, 6) + add_resources(root, 3, root.id) + add_resources(s_1, 2, root.id) + add_resources(s_1_1, 4, root.id) + add_resources(s_1_2, 5, root.id) + add_resources(s_1_3, 6, root.id) - add_resources(s_2, 1) - add_resources(s_2_1_1_1, 1) - add_resources(s_3, 4) + add_resources(s_2, 1, root.id) + add_resources(s_2_1_1_1, 1, root.id) + add_resources(s_3, 4, root.id) self.assertEqual(26, db_api.stack_count_total_resources( self.ctx, root.id)) - self.assertEqual(17, db_api.stack_count_total_resources( - self.ctx, s_1.id)) - self.assertEqual(4, db_api.stack_count_total_resources( - self.ctx, s_1_1.id)) - self.assertEqual(5, db_api.stack_count_total_resources( - self.ctx, s_1_2.id)) - self.assertEqual(6, db_api.stack_count_total_resources( - self.ctx, s_1_3.id)) - - self.assertEqual(2, db_api.stack_count_total_resources( - self.ctx, s_2.id)) - self.assertEqual(1, db_api.stack_count_total_resources( - self.ctx, s_2_1.id)) - self.assertEqual(1, db_api.stack_count_total_resources( - self.ctx, s_2_1_1.id)) - self.assertEqual(1, db_api.stack_count_total_resources( - self.ctx, s_2_1_1_1.id)) - self.assertEqual(4, db_api.stack_count_total_resources( - self.ctx, s_3.id)) self.assertEqual(0, db_api.stack_count_total_resources( self.ctx, s_4.id)) self.assertEqual(0, db_api.stack_count_total_resources(