diff --git a/taskflow/engines/action_engine/task_action.py b/taskflow/engines/action_engine/task_action.py index e8c97045c..75496ca02 100644 --- a/taskflow/engines/action_engine/task_action.py +++ b/taskflow/engines/action_engine/task_action.py @@ -112,6 +112,7 @@ class TaskAction(base.Action): engine=engine): kwargs = engine.storage.fetch_mapped_args(self._task.rebind) kwargs['result'] = engine.storage.get(self._id) + kwargs['flow_failures'] = engine.storage.get_failures() try: self._task.revert(**kwargs) except Exception: diff --git a/taskflow/tests/unit/test_action_engine.py b/taskflow/tests/unit/test_action_engine.py index e5114ee19..f803785d6 100644 --- a/taskflow/tests/unit/test_action_engine.py +++ b/taskflow/tests/unit/test_action_engine.py @@ -30,8 +30,10 @@ import taskflow.engines from taskflow.engines.action_engine import engine as eng from taskflow.persistence import logbook from taskflow import states +from taskflow import task from taskflow import test from taskflow.tests import utils +from taskflow.utils import misc from taskflow.utils import persistence_utils as p_utils @@ -186,6 +188,26 @@ class EngineLinearFlowTest(utils.EngineTestBase): 'fail reverted(Failure: RuntimeError: Woot!)', 'task2 reverted(5)', 'task1 reverted(5)']) + def test_flow_failures_are_passed_to_revert(self): + class CheckingTask(task.Task): + def execute(m_self): + return 'RESULT' + + def revert(m_self, result, flow_failures): + self.assertEqual(result, 'RESULT') + self.assertEqual(flow_failures.keys(), ['fail1']) + fail = flow_failures['fail1'] + self.assertIsInstance(fail, misc.Failure) + self.assertEqual(str(fail), 'Failure: RuntimeError: Woot!') + + flow = lf.Flow('test').add( + CheckingTask(), + utils.FailingTask(self.values, 'fail1') + ) + engine = self._make_engine(flow) + with self.assertRaisesRegexp(RuntimeError, '^Woot'): + engine.run() + class EngineParallelFlowTest(utils.EngineTestBase): diff --git a/taskflow/tests/unit/test_suspend_flow.py b/taskflow/tests/unit/test_suspend_flow.py index b49c38b35..1a53ba873 100644 --- a/taskflow/tests/unit/test_suspend_flow.py +++ b/taskflow/tests/unit/test_suspend_flow.py @@ -75,7 +75,7 @@ class AutoSuspendingTask(TestTask): engine.suspend() return result - def revert(self, engine, result): + def revert(self, engine, result, flow_failures): super(AutoSuspendingTask, self).revert(**{'result': result}) @@ -84,7 +84,7 @@ class AutoSuspendingTaskOnRevert(TestTask): def execute(self, engine): return super(AutoSuspendingTaskOnRevert, self).execute() - def revert(self, engine, result): + def revert(self, engine, result, flow_failures): super(AutoSuspendingTaskOnRevert, self).revert(**{'result': result}) engine.suspend()