Fix workflow and join completion logic
* Workflow completion logic previously took tasks in WAITING state into account, which no longer makes sense with the new non-locking approach. We assume that every task is expected to eventually reach a terminal state, so normally there should be no hanging tasks. * Join completion logic was too primitive before because it didn't take into account inbound tasks in ERROR state. Using the method _is_unsatisfied_join() alone was not enough because the decision about what state a join should be in is not binary: it may be "ok to run", "still blocked" or "it will never run and hence should fail". Change-Id: Ibc2a7c678eb62ab3309a2279d722f76a3607aa69
This commit is contained in:
parent
d8bb98cc42
commit
8bdef0f84e
@ -35,8 +35,8 @@ from mistral.workflow import states
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
_CHECK_TASK_START_ALLOWED_PATH = (
|
||||
'mistral.engine.task_handler._check_task_start_allowed'
|
||||
_REFRESH_TASK_STATE_PATH = (
|
||||
'mistral.engine.task_handler._refresh_task_state'
|
||||
)
|
||||
|
||||
_SCHEDULED_ON_ACTION_COMPLETE_PATH = (
|
||||
@ -73,7 +73,7 @@ def run_task(wf_cmd):
|
||||
return
|
||||
|
||||
if task.is_waiting():
|
||||
_schedule_check_task_start_allowed(task.task_ex)
|
||||
_schedule_refresh_task_state(task.task_ex)
|
||||
|
||||
if task.is_completed():
|
||||
wf_handler.schedule_on_task_complete(task.task_ex)
|
||||
@ -259,7 +259,7 @@ def _create_task(wf_ex, wf_spec, task_spec, ctx, task_ex=None,
|
||||
)
|
||||
|
||||
|
||||
def _check_task_start_allowed(task_ex_id):
|
||||
def _refresh_task_state(task_ex_id):
|
||||
with db_api.transaction():
|
||||
task_ex = db_api.get_task_execution(task_ex_id)
|
||||
|
||||
@ -268,16 +268,24 @@ def _check_task_start_allowed(task_ex_id):
|
||||
spec_parser.get_workflow_spec_by_id(task_ex.workflow_id)
|
||||
)
|
||||
|
||||
if wf_ctrl.is_task_start_allowed(task_ex):
|
||||
state, state_info = wf_ctrl.get_logical_task_state(task_ex)
|
||||
|
||||
if state == states.RUNNING:
|
||||
continue_task(task_ex)
|
||||
|
||||
return
|
||||
|
||||
# TODO(rakhmerov): Algorithm for increasing rescheduling delay.
|
||||
_schedule_check_task_start_allowed(task_ex, 1)
|
||||
elif state == states.ERROR:
|
||||
fail_task(task_ex, state_info)
|
||||
elif state == states.WAITING:
|
||||
# TODO(rakhmerov): Algorithm for increasing rescheduling delay.
|
||||
_schedule_refresh_task_state(task_ex, 1)
|
||||
else:
|
||||
# Must never get here.
|
||||
raise RuntimeError(
|
||||
'Unexpected logical task state [task_ex=%s, state=%s]' %
|
||||
(task_ex, state)
|
||||
)
|
||||
|
||||
|
||||
def _schedule_check_task_start_allowed(task_ex, delay=0):
|
||||
def _schedule_refresh_task_state(task_ex, delay=0):
|
||||
"""Schedules task preconditions check.
|
||||
|
||||
This method provides transactional decoupling of task preconditions
|
||||
@ -298,7 +306,7 @@ def _schedule_check_task_start_allowed(task_ex, delay=0):
|
||||
|
||||
scheduler.schedule_call(
|
||||
None,
|
||||
_CHECK_TASK_START_ALLOWED_PATH,
|
||||
_REFRESH_TASK_STATE_PATH,
|
||||
delay,
|
||||
unique_key=key,
|
||||
task_ex_id=task_ex.id
|
||||
|
@ -278,19 +278,10 @@ class Workflow(object):
|
||||
return
|
||||
|
||||
# Workflow is not completed if there are any incomplete task
|
||||
# executions that are not in WAITING state. If all incomplete
|
||||
# tasks are waiting and there are unhandled errors, then these
|
||||
# tasks will not reach completion. In this case, mark the
|
||||
# workflow complete.
|
||||
# TODO(rakhmerov): In the new non-locking transactional model
|
||||
# tasks in WAITING state should not be taken into account.
|
||||
# I'm actually surprised why we did this so far. The point is
|
||||
# that every task should complete sooner or later and get into
|
||||
# one of the terminal states. We should not be leaving any tasks
|
||||
# in a non-terminal state if we know exactly its real state.
|
||||
# executions.
|
||||
incomplete_tasks = wf_utils.find_incomplete_task_executions(self.wf_ex)
|
||||
|
||||
if any(not states.is_waiting(t.state) for t in incomplete_tasks):
|
||||
if incomplete_tasks:
|
||||
return
|
||||
|
||||
wf_ctrl = wf_base.get_controller(self.wf_ex, self.wf_spec)
|
||||
|
@ -339,7 +339,7 @@ class DirectWorkflowRerunTest(base.EngineTestCase):
|
||||
self.assertEqual(states.SUCCESS, task_10_ex.state)
|
||||
self.assertEqual(states.ERROR, task_21_ex.state)
|
||||
self.assertIsNotNone(task_21_ex.state_info)
|
||||
self.assertEqual(states.WAITING, task_30_ex.state)
|
||||
self.assertEqual(states.ERROR, task_30_ex.state)
|
||||
|
||||
# Update env in workflow execution with the following.
|
||||
updated_env = {
|
||||
@ -933,7 +933,7 @@ class DirectWorkflowRerunTest(base.EngineTestCase):
|
||||
|
||||
self.assertEqual(states.SUCCESS, task_1_ex.state)
|
||||
self.assertEqual(states.ERROR, task_2_ex.state)
|
||||
self.assertEqual(states.WAITING, task_3_ex.state)
|
||||
self.assertEqual(states.ERROR, task_3_ex.state)
|
||||
|
||||
# Resume workflow and re-run failed task.
|
||||
wf_ex = self.engine.rerun_workflow(task_2_ex.id)
|
||||
|
@ -455,7 +455,8 @@ class JoinEngineTest(base.EngineTestCase):
|
||||
result4 = task4.published['result4']
|
||||
|
||||
self.assertIsNotNone(result4)
|
||||
self.assertEqual(2, result4.count('False'))
|
||||
self.assertTrue(result4.count('False') < 3)
|
||||
self.assertTrue(result4.count('True') >= 1)
|
||||
|
||||
def test_full_join_parallel_published_vars(self):
|
||||
wfs_tasks_join_complex = """---
|
||||
@ -645,6 +646,8 @@ class JoinEngineTest(base.EngineTestCase):
|
||||
|
||||
tasks = wf_ex.task_executions
|
||||
|
||||
self.assertIsNotNone(wf_ex.state_info)
|
||||
|
||||
task10 = self._assert_single_item(tasks, name='task10')
|
||||
task21 = self._assert_single_item(tasks, name='task21')
|
||||
task22 = self._assert_single_item(tasks, name='task22')
|
||||
@ -656,7 +659,7 @@ class JoinEngineTest(base.EngineTestCase):
|
||||
self.assertEqual(states.SUCCESS, task22.state)
|
||||
self.assertEqual(states.ERROR, task31.state)
|
||||
self.assertNotIn('task32', [task.name for task in tasks])
|
||||
self.assertEqual(states.WAITING, task40.state)
|
||||
self.assertEqual(states.ERROR, task40.state)
|
||||
|
||||
def test_diamond_join_all(self):
|
||||
wf_text = """---
|
||||
|
@ -126,12 +126,12 @@ class WorkflowController(object):
|
||||
return cmds
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_task_start_allowed(self, task_ex):
|
||||
"""Determines if the given task is allowed to start.
|
||||
def get_logical_task_state(self, task_ex):
|
||||
"""Determines a logical state of the given task.
|
||||
|
||||
:param task_ex: Task execution.
|
||||
:return: True if all preconditions are met and the given task
|
||||
is allowed to start.
|
||||
:return: Tuple (state, state_info) which the given task should have
|
||||
according to workflow rules and current states of other tasks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -60,11 +60,13 @@ class DirectWorkflowController(base.WorkflowController):
|
||||
if not t_spec.get_join():
|
||||
return not t_ex_candidate.processed
|
||||
|
||||
return self._triggers_join(
|
||||
t_spec,
|
||||
self.wf_spec.get_tasks()[t_ex_candidate.name]
|
||||
induced_state = self._get_induced_join_state(
|
||||
self.wf_spec.get_tasks()[t_ex_candidate.name],
|
||||
t_spec
|
||||
)
|
||||
|
||||
return induced_state == states.RUNNING
|
||||
|
||||
def _find_next_commands(self, task_ex=None):
|
||||
cmds = super(DirectWorkflowController, self)._find_next_commands(
|
||||
task_ex
|
||||
@ -140,9 +142,7 @@ class DirectWorkflowController(base.WorkflowController):
|
||||
return
|
||||
|
||||
cmd.unique_key = self._get_join_unique_key(cmd)
|
||||
|
||||
if self._is_unsatisfied_join(cmd.task_spec):
|
||||
cmd.wait = True
|
||||
cmd.wait = True
|
||||
|
||||
def _get_join_unique_key(self, cmd):
|
||||
return 'join-task-%s-%s' % (self.wf_ex.id, cmd.task_spec.get_name())
|
||||
@ -160,13 +160,16 @@ class DirectWorkflowController(base.WorkflowController):
|
||||
|
||||
return ctx
|
||||
|
||||
def is_task_start_allowed(self, task_ex):
|
||||
def get_logical_task_state(self, task_ex):
|
||||
task_spec = self.wf_spec.get_tasks()[task_ex.name]
|
||||
|
||||
return (
|
||||
not task_spec.get_join() or
|
||||
not self._is_unsatisfied_join(task_spec)
|
||||
)
|
||||
if not task_spec.get_join():
|
||||
# A simple 'non-join' task does not have any preconditions
|
||||
# based on state of other tasks so its logical state always
|
||||
# equals its real state.
|
||||
return task_ex.state, task_ex.state_info
|
||||
|
||||
return self._get_join_logical_state(task_spec)
|
||||
|
||||
def is_error_handled_for(self, task_ex):
|
||||
return bool(self.wf_spec.get_on_error_clause(task_ex.name))
|
||||
@ -261,54 +264,115 @@ class DirectWorkflowController(base.WorkflowController):
|
||||
if not condition or expr.evaluate(condition, ctx)
|
||||
]
|
||||
|
||||
def _is_unsatisfied_join(self, task_spec):
|
||||
def _get_join_logical_state(self, task_spec):
|
||||
# TODO(rakhmerov): We need to use task_ex instead of task_spec
|
||||
# in order to cover a use case when there's more than one instance
|
||||
# of the same 'join' task in a workflow.
|
||||
join_expr = task_spec.get_join()
|
||||
|
||||
if not join_expr:
|
||||
return False
|
||||
|
||||
in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)
|
||||
|
||||
if not in_task_specs:
|
||||
return False
|
||||
return states.RUNNING
|
||||
|
||||
# We need to count a number of triggering inbound transitions.
|
||||
num = len([1 for in_t_s in in_task_specs
|
||||
if self._triggers_join(task_spec, in_t_s)])
|
||||
# List of tuples (task_name, state).
|
||||
induced_states = [
|
||||
(t_s.get_name(), self._get_induced_join_state(t_s, task_spec))
|
||||
for t_s in in_task_specs
|
||||
]
|
||||
|
||||
# If "join" is configured as a number.
|
||||
if isinstance(join_expr, int) and num < join_expr:
|
||||
return True
|
||||
def count(state):
|
||||
return len(list(filter(lambda s: s[1] == state, induced_states)))
|
||||
|
||||
if join_expr == 'all' and len(in_task_specs) > num:
|
||||
return True
|
||||
error_count = count(states.ERROR)
|
||||
running_count = count(states.RUNNING)
|
||||
total_count = len(induced_states)
|
||||
|
||||
if join_expr == 'one' and num == 0:
|
||||
return True
|
||||
def _blocked_message():
|
||||
return (
|
||||
'Blocked by tasks: %s' %
|
||||
[s[0] for s in induced_states if s[1] == states.WAITING]
|
||||
)
|
||||
|
||||
return False
|
||||
def _failed_message():
|
||||
return (
|
||||
'Failed by tasks: %s' %
|
||||
[s[0] for s in induced_states if s[1] == states.ERROR]
|
||||
)
|
||||
|
||||
# If "join" is configured as a number or 'one'.
|
||||
if isinstance(join_expr, int) or join_expr == 'one':
|
||||
cardinality = 1 if join_expr == 'one' else join_expr
|
||||
|
||||
if running_count >= cardinality:
|
||||
return states.RUNNING, None
|
||||
|
||||
# E.g. 'join: 3' with inbound [ERROR, ERROR, RUNNING, WAITING]
|
||||
# No chance to get 3 RUNNING states.
|
||||
if error_count > (total_count - cardinality):
|
||||
return states.ERROR, _failed_message()
|
||||
|
||||
return states.WAITING, _blocked_message()
|
||||
|
||||
if join_expr == 'all':
|
||||
if total_count == running_count:
|
||||
return states.RUNNING, None
|
||||
|
||||
if error_count > 0:
|
||||
return states.ERROR, _failed_message()
|
||||
|
||||
return states.WAITING, _blocked_message()
|
||||
|
||||
raise RuntimeError('Unexpected join expression: %s' % join_expr)
|
||||
|
||||
# TODO(rakhmerov): Method signature is incorrect given that
|
||||
# we may have multiple task executions for a task. It should
|
||||
# accept inbound task execution rather than a spec.
|
||||
def _triggers_join(self, join_task_spec, inbound_task_spec):
|
||||
def _get_induced_join_state(self, inbound_task_spec, join_task_spec):
|
||||
join_task_name = join_task_spec.get_name()
|
||||
|
||||
in_task_ex = self._find_task_execution_by_spec(inbound_task_spec)
|
||||
|
||||
if not in_task_ex:
|
||||
if self._possible_route(inbound_task_spec):
|
||||
return states.WAITING
|
||||
else:
|
||||
return states.ERROR
|
||||
|
||||
if not states.is_completed(in_task_ex.state):
|
||||
return states.WAITING
|
||||
|
||||
if join_task_name not in self._find_next_task_names(in_task_ex):
|
||||
return states.ERROR
|
||||
|
||||
return states.RUNNING
|
||||
|
||||
def _find_task_execution_by_spec(self, task_spec):
|
||||
in_t_execs = wf_utils.find_task_executions_by_spec(
|
||||
self.wf_ex,
|
||||
inbound_task_spec
|
||||
task_spec
|
||||
)
|
||||
|
||||
# TODO(rakhmerov): Temporary hack. See the previous comment.
|
||||
in_t_ex = in_t_execs[-1] if in_t_execs else None
|
||||
return in_t_execs[-1] if in_t_execs else None
|
||||
|
||||
if not in_t_ex or not states.is_completed(in_t_ex.state):
|
||||
return False
|
||||
def _possible_route(self, task_spec):
|
||||
# TODO(rakhmerov): In some cases this method will be expensive because
|
||||
# it uses a multistep recursive search with DB queries.
|
||||
# It will be optimized with Workflow Execution Graph moving forward.
|
||||
in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)
|
||||
|
||||
return list(
|
||||
filter(
|
||||
lambda t_name: join_task_spec.get_name() == t_name,
|
||||
self._find_next_task_names(in_t_ex)
|
||||
)
|
||||
)
|
||||
if not in_task_specs:
|
||||
return True
|
||||
|
||||
for t_s in in_task_specs:
|
||||
t_ex = self._find_task_execution_by_spec(t_s)
|
||||
|
||||
if not t_ex:
|
||||
if self._possible_route(t_s):
|
||||
return True
|
||||
else:
|
||||
if task_spec.get_name() in self._find_next_task_names(t_ex):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
@ -108,9 +108,9 @@ class ReverseWorkflowController(base.WorkflowController):
|
||||
|
||||
return data_flow.evaluate_task_outbound_context(task_execs[0])
|
||||
|
||||
def is_task_start_allowed(self, task_ex):
|
||||
def get_logical_task_state(self, task_ex):
|
||||
# TODO(rakhmerov): Implement.
|
||||
return True
|
||||
return task_ex.state, task_ex.state_info
|
||||
|
||||
def is_error_handled_for(self, task_ex):
|
||||
return task_ex.state != states.ERROR
|
||||
|
Loading…
Reference in New Issue
Block a user