diff --git a/mistral/engine/task_handler.py b/mistral/engine/task_handler.py index 0c1e7bfb..5d2f4768 100644 --- a/mistral/engine/task_handler.py +++ b/mistral/engine/task_handler.py @@ -240,16 +240,21 @@ def _create_task(wf_ex, wf_spec, task_spec, ctx, task_ex=None, @profiler.trace('task-handler-refresh-task-state') def _refresh_task_state(task_ex_id): with db_api.transaction(): - task_ex = db_api.get_task_execution(task_ex_id) + task_ex = db_api.load_task_execution(task_ex_id) + + if not task_ex: + return + + wf_ex = task_ex.workflow_execution + + if states.is_completed(wf_ex.state): + return wf_spec = spec_parser.get_workflow_spec_by_execution_id( task_ex.workflow_execution_id ) - wf_ctrl = wf_base.get_controller( - task_ex.workflow_execution, - wf_spec - ) + wf_ctrl = wf_base.get_controller(wf_ex, wf_spec) state, state_info, cardinality = wf_ctrl.get_logical_task_state( task_ex diff --git a/mistral/engine/workflow_handler.py b/mistral/engine/workflow_handler.py index 6d49c98e..02ed46ca 100644 --- a/mistral/engine/workflow_handler.py +++ b/mistral/engine/workflow_handler.py @@ -81,7 +81,10 @@ def cancel_workflow(wf_ex, msg=None): def _check_and_complete(wf_ex_id): # Note: This method can only be called via scheduler. with db_api.transaction(): - wf_ex = db_api.get_workflow_execution(wf_ex_id) + wf_ex = db_api.load_workflow_execution(wf_ex_id) + + if not wf_ex or states.is_completed(wf_ex.state): + return wf = workflows.Workflow( db_api.get_workflow_definition(wf_ex.workflow_id), @@ -167,6 +170,10 @@ def set_workflow_state(wf_ex, state, msg=None): ) +def _get_completion_check_key(wf_ex): + return 'wfh_on_c_a_c-%s' % wf_ex.id + + @profiler.trace('workflow-handler-schedule-check-and-complete') def _schedule_check_and_complete(wf_ex, delay=0): """Schedules workflow completion check. @@ -183,8 +190,7 @@ def _schedule_check_and_complete(wf_ex, delay=0): :param delay: Minimum amount of time before task completion check should be made. """ - # TODO(rakhmerov): update docstring - key = 'wfh_on_c_a_c-%s' % wf_ex.id + key = _get_completion_check_key(wf_ex) scheduler.schedule_call( None, diff --git a/mistral/tests/unit/engine/test_default_engine.py b/mistral/tests/unit/engine/test_default_engine.py index c0b57e1a..ab7af07c 100644 --- a/mistral/tests/unit/engine/test_default_engine.py +++ b/mistral/tests/unit/engine/test_default_engine.py @@ -398,7 +398,14 @@ class DefaultEngineTest(base.DbTestCase): def test_stop_workflow_fail(self): # Start workflow. wf_ex = self.engine.start_workflow( - 'wb.wf', {'param1': 'Hey', 'param2': 'Hi'}, task_name="task2") + 'wb.wf', + { + 'param1': 'Hey', + 'param2': 'Hi' + }, + task_name="task2" + ) + # Re-read execution to access related tasks. wf_ex = db_api.get_workflow_execution(wf_ex.id) @@ -413,7 +420,14 @@ class DefaultEngineTest(base.DbTestCase): def test_stop_workflow_succeed(self): # Start workflow. wf_ex = self.engine.start_workflow( - 'wb.wf', {'param1': 'Hey', 'param2': 'Hi'}, task_name="task2") + 'wb.wf', + { + 'param1': 'Hey', + 'param2': 'Hi' + }, + task_name="task2" + ) + # Re-read execution to access related tasks. wf_ex = db_api.get_workflow_execution(wf_ex.id) @@ -427,7 +441,14 @@ class DefaultEngineTest(base.DbTestCase): def test_stop_workflow_bad_status(self): wf_ex = self.engine.start_workflow( - 'wb.wf', {'param1': 'Hey', 'param2': 'Hi'}, task_name="task2") + 'wb.wf', + { + 'param1': 'Hey', + 'param2': 'Hi' + }, + task_name="task2" + ) + # Re-read execution to access related tasks. wf_ex = db_api.get_workflow_execution(wf_ex.id) diff --git a/mistral/tests/unit/engine/test_direct_workflow.py b/mistral/tests/unit/engine/test_direct_workflow.py index 85e737f8..78f2daa3 100644 --- a/mistral/tests/unit/engine/test_direct_workflow.py +++ b/mistral/tests/unit/engine/test_direct_workflow.py @@ -582,3 +582,30 @@ class DirectWorkflowEngineTest(base.EngineTestCase): ) self.assertIn("Task 'task3' not found", exception.message) + + def test_delete_workflow_completion_check_on_stop(self): + wf_text = """--- + version: '2.0' + + wf: + tasks: + async_task: + action: std.async_noop + """ + + wf_service.create_workflows(wf_text) + + wf_ex = self.engine.start_workflow('wf', {}) + + calls = db_api.get_delayed_calls() + + mtd_name = 'mistral.engine.workflow_handler._check_and_complete' + + self._assert_single_item(calls, target_method_name=mtd_name) + + self.engine.stop_workflow(wf_ex.id, state=states.CANCELLED) + + self._await( + lambda: + len(db_api.get_delayed_calls(target_method_name=mtd_name)) == 0 + ) diff --git a/mistral/tests/unit/engine/test_join.py b/mistral/tests/unit/engine/test_join.py index a665a8b6..02127ef3 100644 --- a/mistral/tests/unit/engine/test_join.py +++ b/mistral/tests/unit/engine/test_join.py @@ -811,3 +811,58 @@ class JoinEngineTest(base.EngineTestCase): self.assertEqual(4, len(task_execs)) self._assert_multiple_items(task_execs, 4, state=states.SUCCESS) + + def delete_join_completion_check_on_stop(self): + wf_text = """--- + version: '2.0' + + wf: + tasks: + task1: + action: std.noop + on-success: join_task + + task2: + description: Never ends + action: std.async_noop + on-success: join_task + + join_task: + join: all + """ + + wf_service.create_workflows(wf_text) + + wf_ex = self.engine.start_workflow('wf', {}) + + tasks = db_api.get_task_executions(workflow_execution_id=wf_ex.id) + + self.assertTrue(len(tasks) >= 2) + + task1 = self._assert_single_item(tasks, name='task1') + + self.await_task_success(task1.id) + + # Once task1 is finished we know that join_task must be created. + + tasks = db_api.get_task_executions(workflow_execution_id=wf_ex.id) + + self._assert_single_item( + tasks, + name='join_task', + state=states.WAITING + ) + + calls = db_api.get_delayed_calls() + + mtd_name = 'mistral.engine.task_handler._refresh_task_state' + + self._assert_single_item(calls, target_method_name=mtd_name) + + # Stop the workflow. + self.engine.stop_workflow(wf_ex.id, state=states.CANCELLED) + + self._await( + lambda: + len(db_api.get_delayed_calls(target_method_name=mtd_name)) == 0 + )