diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index 39782766a..4ddd47682 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -54,7 +54,6 @@ class ActionEngine(base.EngineBase): def __init__(self, flow, flow_detail, backend, conf): super(ActionEngine, self).__init__(flow, flow_detail, backend, conf) - self._failures = {} # task uuid => failure self._root = None self._lock = threading.RLock() self._state_lock = threading.RLock() @@ -72,7 +71,8 @@ class ActionEngine(base.EngineBase): self._change_state(state) if state == states.SUSPENDED: return - misc.Failure.reraise_if_any(self._failures.values()) + failures = self.storage.get_failures() + misc.Failure.reraise_if_any(failures.values()) if current_failure: current_failure.reraise() @@ -104,7 +104,7 @@ class ActionEngine(base.EngineBase): missing = self._flow.requires - external_provides if missing: raise exc.MissingDependencies(self._flow, sorted(missing)) - if self._failures: + if self.storage.has_failures(): self._revert() else: self._run() @@ -145,8 +145,6 @@ class ActionEngine(base.EngineBase): a given state with a given result. This is a *internal* to the action engine and its associated action classes, not for use externally. 
""" - if isinstance(result, misc.Failure): - self._failures[task_action.uuid] = result details = dict(engine=self, task_name=task_action.name, task_uuid=task_action.uuid, @@ -160,7 +158,6 @@ class ActionEngine(base.EngineBase): task_uuid=uuid, result=None) self.task_notifier.notify(states.PENDING, details) - self._failures = {} self._change_state(states.PENDING) @lock_utils.locked @@ -177,8 +174,6 @@ class ActionEngine(base.EngineBase): self._change_state(states.RESUMING) # does nothing in PENDING state task_graph = flow_utils.flatten(self._flow) self._root = self._graph_action(task_graph) - loaded_failures = {} - for task in task_graph.nodes_iter(): try: task_id = self.storage.get_uuid_by_name(task.name) @@ -187,24 +182,9 @@ class ActionEngine(base.EngineBase): task_version = misc.get_version_string(task) self.storage.add_task(task_name=task.name, uuid=task_id, task_version=task_version) - try: - result = self.storage.get(task_id) - except exc.NotFound: - result = None - - if isinstance(result, misc.Failure): - # NOTE(imelnikov): old failure may have exc_info which - # might get lost during serialization, so we preserve - # old failure object if possible. 
- old_failure = self._failures.get(task_id, None) - if result.matches(old_failure): - loaded_failures[task_id] = old_failure - else: - loaded_failures[task_id] = result self.storage.set_result_mapping(task_id, task.save_as) self._root.add(task, task_action.TaskAction(task, task_id)) - self._failures = loaded_failures self._change_state(states.SUSPENDED) # does nothing in PENDING state @property diff --git a/taskflow/storage.py b/taskflow/storage.py index d72d661ee..47aa4ad06 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -65,6 +65,11 @@ class Storage(object): self._backend = backend self._flowdetail = flow_detail + # NOTE(imelnikov): failure serialization loses information, + # so we cache failures here, in task name -> misc.Failure mapping + self._failures = {} + self._reload_failures() + injector_td = self._flowdetail.find_by_name(self.injector_name) if injector_td is not None and injector_td.results is not None: names = six.iterkeys(injector_td.results) @@ -214,21 +219,51 @@ class Storage(object): if state == states.FAILURE and isinstance(data, misc.Failure): td.results = None td.failure = data + self._failures[td.name] = data else: td.results = data td.failure = None self._check_all_results_provided(uuid, td.name, data) self._with_connection(self._save_task_detail, task_detail=td) + def _cache_failure(self, name, fail): + """Ensure that cache has matching failure for task with this name. + + We leave cached version if it matches as it may contain more + information. Returns cached failure. 
+ """ + cached = self._failures.get(name) + if fail.matches(cached): + return cached + self._failures[name] = fail + return fail + + def _reload_failures(self): + """Refresh failures cache""" + for td in self._flowdetail: + if td.failure is not None: + self._cache_failure(td.name, td.failure) + def get(self, uuid): """Get result for task with id 'uuid' to storage""" td = self._taskdetail_by_uuid(uuid) - if td.failure: - return td.failure + if td.failure is not None: + return self._cache_failure(td.name, td.failure) if td.state not in STATES_WITH_RESULTS: raise exceptions.NotFound("Result for task %r is not known" % uuid) return td.results + def get_failures(self): + """Get a copy of the task name -> failure mapping for this flow. + + No order guaranteed. + """ + return self._failures.copy() + + def has_failures(self): + """Returns True if there are failed tasks in the storage""" + return bool(self._failures) + def _reset_task(self, td, state): if td.name == self.injector_name: return False @@ -237,6 +272,7 @@ class Storage(object): td.results = None td.failure = None td.state = state + self._failures.pop(td.name, None) return True def reset(self, uuid, state=states.PENDING): diff --git a/taskflow/tests/unit/test_storage.py b/taskflow/tests/unit/test_storage.py index 31d1b1848..4cd33b208 100644 --- a/taskflow/tests/unit/test_storage.py +++ b/taskflow/tests/unit/test_storage.py @@ -94,6 +94,8 @@ class StorageTest(test.TestCase): s.save('42', fail, states.FAILURE) self.assertEqual(s.get('42'), fail) self.assertEqual(s.get_task_state('42'), states.FAILURE) + self.assertIs(s.has_failures(), True) + self.assertEqual(s.get_failures(), {'my task': fail}) def test_get_failure_from_reverted_task(self): fail = misc.Failure(exc_info=(RuntimeError, RuntimeError(), None)) @@ -107,6 +109,18 @@ class StorageTest(test.TestCase): s.set_task_state('42', states.REVERTED) self.assertEqual(s.get('42'), fail) + def test_get_failure_after_reload(self): + fail = misc.Failure(exc_info=(RuntimeError, 
RuntimeError(), None)) + s = self._get_storage() + s.add_task('42', 'my task') + s.save('42', fail, states.FAILURE) + + s2 = storage.Storage(backend=self.backend, flow_detail=s._flowdetail) + self.assertIs(s2.has_failures(), True) + self.assertEqual(s2.get_failures(), {'my task': fail}) + self.assertEqual(s2.get('42'), fail) + self.assertEqual(s2.get_task_state('42'), states.FAILURE) + def test_get_non_existing_var(self): s = self._get_storage() s.add_task('42', 'my task') diff --git a/taskflow/tests/unit/test_suspend_flow.py b/taskflow/tests/unit/test_suspend_flow.py index 8ba299e64..b49c38b35 100644 --- a/taskflow/tests/unit/test_suspend_flow.py +++ b/taskflow/tests/unit/test_suspend_flow.py @@ -152,12 +152,6 @@ class SuspendFlowTest(utils.EngineTestBase): engine = self._make_engine(flow) engine.storage.inject({'engine': engine}) engine.run() - self.assertEqual(engine.storage.get_flow_state(), states.SUSPENDED) - self.assertEqual( - self.values, - ['a', 'b', - 'c reverted(Failure: RuntimeError: Woot!)', - 'b reverted(5)']) # pretend we are resuming engine2 = self._make_engine(flow, engine.storage._flowdetail) @@ -172,6 +166,33 @@ class SuspendFlowTest(utils.EngineTestBase): 'b reverted(5)', 'a reverted(5)']) + def test_suspend_and_revert_even_if_task_is_gone(self): + flow = lf.Flow('linear').add( + TestTask(self.values, 'a'), + AutoSuspendingTaskOnRevert(self.values, 'b'), + FailingTask(self.values, 'c') + ) + engine = self._make_engine(flow) + engine.storage.inject({'engine': engine}) + engine.run() + + # pretend we are resuming, but task 'c' gone when flow got updated + flow2 = lf.Flow('linear').add( + TestTask(self.values, 'a'), + AutoSuspendingTaskOnRevert(self.values, 'b') + ) + engine2 = self._make_engine(flow2, engine.storage._flowdetail) + with self.assertRaisesRegexp(RuntimeError, '^Woot'): + engine2.run() + self.assertEqual(engine2.storage.get_flow_state(), states.REVERTED) + self.assertEqual( + self.values, + ['a', + 'b', + 'c reverted(Failure: 
RuntimeError: Woot!)', + 'b reverted(5)', + 'a reverted(5)']) + def test_storage_is_rechecked(self): flow = lf.Flow('linear').add( AutoSuspendingTask(self.values, 'b'),