diff --git a/mistral/tests/unit/engine/test_direct_workflow.py b/mistral/tests/unit/engine/test_direct_workflow.py index c4571109..6d4508a0 100644 --- a/mistral/tests/unit/engine/test_direct_workflow.py +++ b/mistral/tests/unit/engine/test_direct_workflow.py @@ -93,6 +93,45 @@ class DirectWorkflowEngineTest(base.EngineTestCase): self.assertTrue(wf_ex.state, states.ERROR) + def test_direct_workflow_condition_transition_not_triggering(self): + wf_text = """--- + version: '2.0' + + wf: + input: + - var: null + + tasks: + task1: + action: std.fail + on-success: + - task2 + on-error: + - task3: <% $.var != null %> + + task2: + action: std.noop + + task3: + action: std.noop + """ + + wf_service.create_workflows(wf_text) + wf_ex = self.engine.start_workflow('wf', {}) + + self._await(lambda: self.is_execution_error(wf_ex.id)) + + wf_ex = db_api.get_workflow_execution(wf_ex.id) + tasks = wf_ex.task_executions + + task1 = self._assert_single_item(tasks, name='task1') + + self.assertEqual(1, len(tasks)) + + self._await(lambda: self.is_task_error(task1.id)) + + self.assertTrue(wf_ex.state, states.ERROR) + def test_direct_workflow_change_state_after_success(self): wf_text = """ version: '2.0' diff --git a/mistral/workflow/direct_workflow.py b/mistral/workflow/direct_workflow.py index 95427394..d0a29466 100644 --- a/mistral/workflow/direct_workflow.py +++ b/mistral/workflow/direct_workflow.py @@ -150,7 +150,13 @@ class DirectWorkflowController(base.WorkflowController): def all_errors_handled(self): for t_ex in wf_utils.find_error_task_executions(self.wf_ex): - if not self.wf_spec.get_on_error_clause(t_ex.name): + + tasks_on_error = self._find_next_task_names_for_clause( + self.wf_spec.get_on_error_clause(t_ex.name), + data_flow.evaluate_task_outbound_context(t_ex) + ) + + if not tasks_on_error: return False return True