diff --git a/heat/db/api.py b/heat/db/api.py index 90537bb2c..98c53c646 100644 --- a/heat/db/api.py +++ b/heat/db/api.py @@ -121,8 +121,12 @@ def resource_exchange_stacks(context, resource_id1, resource_id2): return IMPL.resource_exchange_stacks(context, resource_id1, resource_id2) -def resource_get_all_by_stack(context, stack_id, key_id=False, filters=None): - return IMPL.resource_get_all_by_stack(context, stack_id, key_id, filters) +def resource_get_all_by_stack(context, stack_id, filters=None): + return IMPL.resource_get_all_by_stack(context, stack_id, filters) + + +def resource_get_all_active_by_stack(context, stack_id): + return IMPL.resource_get_all_active_by_stack(context, stack_id) def resource_get_by_name_and_stack(context, resource_name, stack_id): diff --git a/heat/db/sqlalchemy/api.py b/heat/db/sqlalchemy/api.py index 1a5d1e3ee..99944a565 100644 --- a/heat/db/sqlalchemy/api.py +++ b/heat/db/sqlalchemy/api.py @@ -341,7 +341,7 @@ def resource_create(context, values): return resource_ref -def resource_get_all_by_stack(context, stack_id, key_id=False, filters=None): +def resource_get_all_by_stack(context, stack_id, filters=None): query = model_query( context, models.Resource ).filter_by( @@ -354,10 +354,23 @@ def resource_get_all_by_stack(context, stack_id, key_id=False, filters=None): if not results: raise exception.NotFound(_("no resources for stack_id %s were found") % stack_id) - if key_id: - return dict((res.id, res) for res in results) - else: - return dict((res.name, res) for res in results) + + return dict((res.name, res) for res in results) + + +def resource_get_all_active_by_stack(context, stack_id): + filters = {'stack_id': stack_id, 'action': 'DELETE', 'status': 'COMPLETE'} + subquery = model_query(context, models.Resource.id).filter_by(**filters) + + results = model_query(context, models.Resource).filter_by( + stack_id=stack_id).filter( + models.Resource.id.notin_(subquery.as_scalar()) + ).options(orm.joinedload("data")).all() + + if not results: + raise exception.NotFound(_("no active resources for stack_id %s were" + " found") % stack_id) + return dict((res.id, res) for res in results) def stack_get_by_name_and_owner_id(context, stack_name, owner_id): diff --git a/heat/engine/stack.py b/heat/engine/stack.py index 7e0902dff..956bf2979 100644 --- a/heat/engine/stack.py +++ b/heat/engine/stack.py @@ -314,7 +314,7 @@ class Stack(collections.Mapping): def _find_filtered_resources(self, filters): for rsc in six.itervalues( resource_objects.Resource.get_all_by_stack( - self.context, self.id, True, filters)): + self.context, self.id, filters)): yield self.resources[rsc.name] def iter_resources(self, nested_depth=0, filters=None): @@ -338,10 +338,10 @@ class Stack(collections.Mapping): for nested_res in nested_stack.iter_resources(nested_depth - 1): yield nested_res - def _db_resources_get(self, key_id=False): + def db_active_resources_get(self): try: - return resource_objects.Resource.get_all_by_stack( - self.context, self.id, key_id) + return resource_objects.Resource.get_all_active_by_stack( + self.context, self.id) except exception.NotFound: return None @@ -349,9 +349,13 @@ class Stack(collections.Mapping): if not self.id: return None if self._db_resources is None: - self._db_resources = self._db_resources_get() - - return self._db_resources.get(name) if self._db_resources else None + try: + _db_resources = resource_objects.Resource.get_all_by_stack( + self.context, self.id) + self._db_resources = _db_resources + except exception.NotFound: + return None + return self._db_resources.get(name) @property def dependencies(self): @@ -1276,7 +1280,7 @@ class Stack(collections.Mapping): return candidate def _update_or_store_resources(self): - self.ext_rsrcs_db = self._db_resources_get(key_id=True) + self.ext_rsrcs_db = self.db_active_resources_get() curr_name_translated_dep = self.dependencies.translate(lambda res: res.name) diff --git a/heat/objects/resource.py b/heat/objects/resource.py index 526450e3e..6f900c3f9 100644 --- a/heat/objects/resource.py +++ b/heat/objects/resource.py @@ -135,16 +135,28 @@ class Resource( resource_id2) @classmethod - def get_all_by_stack(cls, context, stack_id, key_id=False, filters=None): - resources_db = db_api.resource_get_all_by_stack(context, - stack_id, key_id, + def get_all_by_stack(cls, context, stack_id, filters=None): + resources_db = db_api.resource_get_all_by_stack(context, stack_id, filters) resources = [ ( - resource_key, + resource_name, cls._from_db_object(cls(context), context, resource_db) ) - for resource_key, resource_db in six.iteritems(resources_db) + for resource_name, resource_db in six.iteritems(resources_db) + ] + return dict(resources) + + @classmethod + def get_all_active_by_stack(cls, context, stack_id): + resources_db = db_api.resource_get_all_active_by_stack(context, + stack_id) + resources = [ + ( + resource_id, + cls._from_db_object(cls(context), context, resource_db) + ) + for resource_id, resource_db in six.iteritems(resources_db) ] return dict(resources) diff --git a/heat/tests/db/test_sqlalchemy_api.py b/heat/tests/db/test_sqlalchemy_api.py index 9e4f78d4a..ea54dbf55 100644 --- a/heat/tests/db/test_sqlalchemy_api.py +++ b/heat/tests/db/test_sqlalchemy_api.py @@ -2256,6 +2256,28 @@ class DBAPIResourceTest(common.HeatTestCase): self.assertRaises(exception.NotFound, db_api.resource_get_all_by_stack, self.ctx, self.stack2.id) + def test_resource_get_all_active_by_stack(self): + values = [ + {'name': 'res1', 'action': rsrc.Resource.DELETE, + 'status': rsrc.Resource.COMPLETE}, + {'name': 'res2', 'action': rsrc.Resource.DELETE, + 'status': rsrc.Resource.IN_PROGRESS}, + {'name': 'res3', 'action': rsrc.Resource.UPDATE, + 'status': rsrc.Resource.IN_PROGRESS}, + {'name': 'res4', 'action': rsrc.Resource.UPDATE, + 'status': rsrc.Resource.COMPLETE}, + {'name': 'res5', 'action': rsrc.Resource.INIT, + 'status': rsrc.Resource.COMPLETE}, + {'name': 'res6'}, + ] + [create_resource(self.ctx, self.stack, **val) for val in values] + + resources = db_api.resource_get_all_active_by_stack(self.ctx, + self.stack.id) + self.assertEqual(5, len(resources)) + for rsrc_id, res in resources.items(): + self.assertIn(res.name, ['res2', 'res3', 'res4', 'res5', 'res6']) + class DBAPIStackLockTest(common.HeatTestCase): def setUp(self): diff --git a/heat/tests/test_convg_stack.py b/heat/tests/test_convg_stack.py index 4bb4c1b4a..06ed17938 100644 --- a/heat/tests/test_convg_stack.py +++ b/heat/tests/test_convg_stack.py @@ -131,7 +131,7 @@ class StackConvergenceCreateUpdateDeleteTest(common.HeatTestCase): is_update, None)) self.assertEqual(expected_calls, mock_cr.mock_calls) - def _mock_convg_db_update_requires(self, key_id=False): + def _mock_convg_db_update_requires(self): """Updates requires column of resources. Required for testing the generation of convergence dependency graph @@ -144,8 +144,8 @@ class StackConvergenceCreateUpdateDeleteTest(common.HeatTestCase): rsrc_id, is_update)) requires[rsrc_id] = list({id for id, is_update in reqs}) - rsrcs_db = resource_objects.Resource.get_all_by_stack( - self.stack.context, self.stack.id, key_id=key_id) + rsrcs_db = resource_objects.Resource.get_all_active_by_stack( + self.stack.context, self.stack.id) for rsrc_id, rsrc in rsrcs_db.items(): if rsrc.id in requires: @@ -172,7 +172,7 @@ class StackConvergenceCreateUpdateDeleteTest(common.HeatTestCase): # rsrc.requires. Mock the same behavior here. self.stack = stack with mock.patch.object( - parser.Stack, '_db_resources_get', + parser.Stack, 'db_active_resources_get', side_effect=self._mock_convg_db_update_requires): curr_stack.converge_stack(template=template2, action=stack.UPDATE) @@ -297,7 +297,7 @@ class StackConvergenceCreateUpdateDeleteTest(common.HeatTestCase): # rsrc.requires. Mock the same behavior here. self.stack = stack with mock.patch.object( - parser.Stack, '_db_resources_get', + parser.Stack, 'db_active_resources_get', side_effect=self._mock_convg_db_update_requires): curr_stack.converge_stack(template=template2, action=stack.DELETE) diff --git a/heat/tests/test_stack.py b/heat/tests/test_stack.py index 9abcc3d6a..5fee0d57a 100644 --- a/heat/tests/test_stack.py +++ b/heat/tests/test_stack.py @@ -272,7 +272,6 @@ class StackTest(common.HeatTestCase): # Verify, the db query is called with expected filter mock_db_call.assert_called_once_with(self.ctx, self.stack.id, - True, dict(name=['A'])) # Make sure it returns only one resource. self.assertEqual(1, len(all_resources))