diff --git a/heat/db/api.py b/heat/db/api.py index 98c53c646..05b9dba10 100644 --- a/heat/db/api.py +++ b/heat/db/api.py @@ -129,6 +129,10 @@ def resource_get_all_active_by_stack(context, stack_id): return IMPL.resource_get_all_active_by_stack(context, stack_id) +def resource_get_all_by_root_stack(context, stack_id, filters=None): + return IMPL.resource_get_all_by_root_stack(context, stack_id, filters) + + def resource_get_by_name_and_stack(context, resource_name, stack_id): return IMPL.resource_get_by_name_and_stack(context, resource_name, stack_id) diff --git a/heat/db/sqlalchemy/api.py b/heat/db/sqlalchemy/api.py index 99944a565..a612cf66b 100644 --- a/heat/db/sqlalchemy/api.py +++ b/heat/db/sqlalchemy/api.py @@ -373,6 +373,19 @@ def resource_get_all_active_by_stack(context, stack_id): return dict((res.id, res) for res in results) +def resource_get_all_by_root_stack(context, stack_id, filters=None): + query = model_query( + context, models.Resource + ).filter_by( + root_stack_id=stack_id + ).options(orm.joinedload("data")) + + query = db_filters.exact_filter(query, models.Resource, filters) + results = query.all() + + return dict((res.id, res) for res in results) + + def stack_get_by_name_and_owner_id(context, stack_name, owner_id): query = soft_delete_aware_query( context, models.Stack diff --git a/heat/objects/resource.py b/heat/objects/resource.py index 6f900c3f9..8ccb01d9b 100644 --- a/heat/objects/resource.py +++ b/heat/objects/resource.py @@ -138,6 +138,10 @@ class Resource( def get_all_by_stack(cls, context, stack_id, filters=None): resources_db = db_api.resource_get_all_by_stack(context, stack_id, filters) + return cls._resources_to_dict(context, resources_db) + + @classmethod + def _resources_to_dict(cls, context, resources_db): resources = [ ( resource_name, @@ -160,6 +164,14 @@ class Resource( ] return dict(resources) + @classmethod + def get_all_by_root_stack(cls, context, stack_id, filters): + resources_db = db_api.resource_get_all_by_root_stack( + context, + stack_id, + filters) + return cls._resources_to_dict(context, resources_db) + @classmethod def get_by_name_and_stack(cls, context, resource_name, stack_id): resource_db = db_api.resource_get_by_name_and_stack( diff --git a/heat/tests/db/test_sqlalchemy_api.py b/heat/tests/db/test_sqlalchemy_api.py index ea54dbf55..9be3a557c 100644 --- a/heat/tests/db/test_sqlalchemy_api.py +++ b/heat/tests/db/test_sqlalchemy_api.py @@ -2278,6 +2278,50 @@ class DBAPIResourceTest(common.HeatTestCase): for rsrc_id, res in resources.items(): self.assertIn(res.name, ['res2', 'res3', 'res4', 'res5', 'res6']) + def test_resource_get_all_by_root_stack(self): + self.stack1 = create_stack(self.ctx, self.template, self.user_creds) + self.stack2 = create_stack(self.ctx, self.template, self.user_creds) + + create_resource(self.ctx, self.stack, name='res1', + root_stack_id=self.stack.id) + create_resource(self.ctx, self.stack, name='res2', + root_stack_id=self.stack.id) + create_resource(self.ctx, self.stack, name='res3', + root_stack_id=self.stack.id) + create_resource(self.ctx, self.stack1, name='res4', + root_stack_id=self.stack.id) + + # Test for all resources in a stack + resources = db_api.resource_get_all_by_root_stack( + self.ctx, self.stack.id) + self.assertEqual(4, len(resources)) + resource_names = [r.name for r in resources.values()] + self.assertEqual(['res1', 'res2', 'res3', 'res4'], + sorted(resource_names)) + + # Test for resources matching single entry + resources = db_api.resource_get_all_by_root_stack( + self.ctx, self.stack.id, filters=dict(name='res1')) + self.assertEqual(1, len(resources)) + resource_names = [r.name for r in resources.values()] + self.assertEqual(['res1'], resource_names) + self.assertEqual(1, len(resources)) + + # Test for resources matching multi entry + resources = db_api.resource_get_all_by_root_stack( + self.ctx, self.stack.id, filters=dict(name=[ + 'res1', + 'res2' + ]) + ) + self.assertEqual(2, len(resources)) + resource_names = [r.name for r in resources.values()] + self.assertEqual(['res1', 'res2'], + sorted(resource_names)) + + self.assertEqual({}, db_api.resource_get_all_by_root_stack( + self.ctx, self.stack2.id)) + class DBAPIStackLockTest(common.HeatTestCase): def setUp(self):