Fix workflow and join completion logic

* Workflow completion logic previously took tasks in the
  WAITING state into account, which no longer makes sense
  with the new non-locking approach. We assume that every task
  is expected to reach a terminal state, so we should normally
  have no hanging tasks.
* Join completion logic was previously too primitive because
  it didn't take inbound tasks in the ERROR state into account.
  Just using the method _is_unsatisfied_join() was not enough
  because the decision on what the join state should be is not
  binary: it may be "ok to run", "still blocked" or
  "it will never run and hence should fail".

Change-Id: Ibc2a7c678eb62ab3309a2279d722f76a3607aa69
This commit is contained in:
Renat Akhmerov 2016-08-10 01:44:15 +07:00
parent d8bb98cc42
commit 8bdef0f84e
7 changed files with 137 additions and 71 deletions

View File

@ -35,8 +35,8 @@ from mistral.workflow import states
LOG = logging.getLogger(__name__)
_CHECK_TASK_START_ALLOWED_PATH = (
'mistral.engine.task_handler._check_task_start_allowed'
_REFRESH_TASK_STATE_PATH = (
'mistral.engine.task_handler._refresh_task_state'
)
_SCHEDULED_ON_ACTION_COMPLETE_PATH = (
@ -73,7 +73,7 @@ def run_task(wf_cmd):
return
if task.is_waiting():
_schedule_check_task_start_allowed(task.task_ex)
_schedule_refresh_task_state(task.task_ex)
if task.is_completed():
wf_handler.schedule_on_task_complete(task.task_ex)
@ -259,7 +259,7 @@ def _create_task(wf_ex, wf_spec, task_spec, ctx, task_ex=None,
)
def _check_task_start_allowed(task_ex_id):
def _refresh_task_state(task_ex_id):
with db_api.transaction():
task_ex = db_api.get_task_execution(task_ex_id)
@ -268,16 +268,24 @@ def _check_task_start_allowed(task_ex_id):
spec_parser.get_workflow_spec_by_id(task_ex.workflow_id)
)
if wf_ctrl.is_task_start_allowed(task_ex):
state, state_info = wf_ctrl.get_logical_task_state(task_ex)
if state == states.RUNNING:
continue_task(task_ex)
return
# TODO(rakhmerov): Algorithm for increasing rescheduling delay.
_schedule_check_task_start_allowed(task_ex, 1)
elif state == states.ERROR:
fail_task(task_ex, state_info)
elif state == states.WAITING:
# TODO(rakhmerov): Algorithm for increasing rescheduling delay.
_schedule_refresh_task_state(task_ex, 1)
else:
# Must never get here.
raise RuntimeError(
'Unexpected logical task state [task_ex=%s, state=%s]' %
(task_ex, state)
)
def _schedule_check_task_start_allowed(task_ex, delay=0):
def _schedule_refresh_task_state(task_ex, delay=0):
"""Schedules task preconditions check.
This method provides transactional decoupling of task preconditions
@ -298,7 +306,7 @@ def _schedule_check_task_start_allowed(task_ex, delay=0):
scheduler.schedule_call(
None,
_CHECK_TASK_START_ALLOWED_PATH,
_REFRESH_TASK_STATE_PATH,
delay,
unique_key=key,
task_ex_id=task_ex.id

View File

@ -278,19 +278,10 @@ class Workflow(object):
return
# Workflow is not completed if there are any incomplete task
# executions that are not in WAITING state. If all incomplete
# tasks are waiting and there are unhandled errors, then these
# tasks will not reach completion. In this case, mark the
# workflow complete.
# TODO(rakhmerov): In the new non-locking transactional model
# tasks in WAITING state should not be taken into account.
# I'm actually surprised why we did this so far. The point is
# that every task should complete sooner or later and get into
# one of the terminal state. We should not be leaving any tasks
# in a non-terminal state if we know exactly it's real state.
# executions.
incomplete_tasks = wf_utils.find_incomplete_task_executions(self.wf_ex)
if any(not states.is_waiting(t.state) for t in incomplete_tasks):
if incomplete_tasks:
return
wf_ctrl = wf_base.get_controller(self.wf_ex, self.wf_spec)

View File

@ -339,7 +339,7 @@ class DirectWorkflowRerunTest(base.EngineTestCase):
self.assertEqual(states.SUCCESS, task_10_ex.state)
self.assertEqual(states.ERROR, task_21_ex.state)
self.assertIsNotNone(task_21_ex.state_info)
self.assertEqual(states.WAITING, task_30_ex.state)
self.assertEqual(states.ERROR, task_30_ex.state)
# Update env in workflow execution with the following.
updated_env = {
@ -933,7 +933,7 @@ class DirectWorkflowRerunTest(base.EngineTestCase):
self.assertEqual(states.SUCCESS, task_1_ex.state)
self.assertEqual(states.ERROR, task_2_ex.state)
self.assertEqual(states.WAITING, task_3_ex.state)
self.assertEqual(states.ERROR, task_3_ex.state)
# Resume workflow and re-run failed task.
wf_ex = self.engine.rerun_workflow(task_2_ex.id)

View File

@ -455,7 +455,8 @@ class JoinEngineTest(base.EngineTestCase):
result4 = task4.published['result4']
self.assertIsNotNone(result4)
self.assertEqual(2, result4.count('False'))
self.assertTrue(result4.count('False') < 3)
self.assertTrue(result4.count('True') >= 1)
def test_full_join_parallel_published_vars(self):
wfs_tasks_join_complex = """---
@ -645,6 +646,8 @@ class JoinEngineTest(base.EngineTestCase):
tasks = wf_ex.task_executions
self.assertIsNotNone(wf_ex.state_info)
task10 = self._assert_single_item(tasks, name='task10')
task21 = self._assert_single_item(tasks, name='task21')
task22 = self._assert_single_item(tasks, name='task22')
@ -656,7 +659,7 @@ class JoinEngineTest(base.EngineTestCase):
self.assertEqual(states.SUCCESS, task22.state)
self.assertEqual(states.ERROR, task31.state)
self.assertNotIn('task32', [task.name for task in tasks])
self.assertEqual(states.WAITING, task40.state)
self.assertEqual(states.ERROR, task40.state)
def test_diamond_join_all(self):
wf_text = """---

View File

@ -126,12 +126,12 @@ class WorkflowController(object):
return cmds
@abc.abstractmethod
def is_task_start_allowed(self, task_ex):
"""Determines if the given task is allowed to start.
def get_logical_task_state(self, task_ex):
"""Determines a logical state of the given task.
:param task_ex: Task execution.
:return: True if all preconditions are met and the given task
is allowed to start.
:return: Tuple (state, state_info) which the given task should have
according to workflow rules and current states of other tasks.
"""
raise NotImplementedError

View File

@ -60,11 +60,13 @@ class DirectWorkflowController(base.WorkflowController):
if not t_spec.get_join():
return not t_ex_candidate.processed
return self._triggers_join(
t_spec,
self.wf_spec.get_tasks()[t_ex_candidate.name]
induced_state = self._get_induced_join_state(
self.wf_spec.get_tasks()[t_ex_candidate.name],
t_spec
)
return induced_state == states.RUNNING
def _find_next_commands(self, task_ex=None):
cmds = super(DirectWorkflowController, self)._find_next_commands(
task_ex
@ -140,9 +142,7 @@ class DirectWorkflowController(base.WorkflowController):
return
cmd.unique_key = self._get_join_unique_key(cmd)
if self._is_unsatisfied_join(cmd.task_spec):
cmd.wait = True
cmd.wait = True
def _get_join_unique_key(self, cmd):
return 'join-task-%s-%s' % (self.wf_ex.id, cmd.task_spec.get_name())
@ -160,13 +160,16 @@ class DirectWorkflowController(base.WorkflowController):
return ctx
def is_task_start_allowed(self, task_ex):
def get_logical_task_state(self, task_ex):
task_spec = self.wf_spec.get_tasks()[task_ex.name]
return (
not task_spec.get_join() or
not self._is_unsatisfied_join(task_spec)
)
if not task_spec.get_join():
# A simple 'non-join' task does not have any preconditions
# based on state of other tasks so its logical state always
# equals to its real state.
return task_ex.state, task_ex.state_info
return self._get_join_logical_state(task_spec)
def is_error_handled_for(self, task_ex):
return bool(self.wf_spec.get_on_error_clause(task_ex.name))
@ -261,54 +264,115 @@ class DirectWorkflowController(base.WorkflowController):
if not condition or expr.evaluate(condition, ctx)
]
def _is_unsatisfied_join(self, task_spec):
def _get_join_logical_state(self, task_spec):
# TODO(rakhmerov): We need to use task_ex instead of task_spec
# in order to cover a use case when there's more than one instance
# of the same 'join' task in a workflow.
join_expr = task_spec.get_join()
if not join_expr:
return False
in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)
if not in_task_specs:
return False
return states.RUNNING
# We need to count a number of triggering inbound transitions.
num = len([1 for in_t_s in in_task_specs
if self._triggers_join(task_spec, in_t_s)])
# List of tuples (task_name, state).
induced_states = [
(t_s.get_name(), self._get_induced_join_state(t_s, task_spec))
for t_s in in_task_specs
]
# If "join" is configured as a number.
if isinstance(join_expr, int) and num < join_expr:
return True
def count(state):
return len(list(filter(lambda s: s[1] == state, induced_states)))
if join_expr == 'all' and len(in_task_specs) > num:
return True
error_count = count(states.ERROR)
running_count = count(states.RUNNING)
total_count = len(induced_states)
if join_expr == 'one' and num == 0:
return True
def _blocked_message():
return (
'Blocked by tasks: %s' %
[s[0] for s in induced_states if s[1] == states.WAITING]
)
return False
def _failed_message():
return (
'Failed by tasks: %s' %
[s[0] for s in induced_states if s[1] == states.ERROR]
)
# If "join" is configured as a number or 'one'.
if isinstance(join_expr, int) or join_expr == 'one':
cardinality = 1 if join_expr == 'one' else join_expr
if running_count >= cardinality:
return states.RUNNING, None
# E.g. 'join: 3' with inbound [ERROR, ERROR, RUNNING, WAITING]
# No chance to get 3 RUNNING states.
if error_count > (total_count - cardinality):
return states.ERROR, _failed_message()
return states.WAITING, _blocked_message()
if join_expr == 'all':
if total_count == running_count:
return states.RUNNING, None
if error_count > 0:
return states.ERROR, _failed_message()
return states.WAITING, _blocked_message()
raise RuntimeError('Unexpected join expression: %s' % join_expr)
# TODO(rakhmerov): Method signature is incorrect given that
# we may have multiple task executions for a task. It should
# accept inbound task execution rather than a spec.
def _triggers_join(self, join_task_spec, inbound_task_spec):
def _get_induced_join_state(self, inbound_task_spec, join_task_spec):
join_task_name = join_task_spec.get_name()
in_task_ex = self._find_task_execution_by_spec(inbound_task_spec)
if not in_task_ex:
if self._possible_route(inbound_task_spec):
return states.WAITING
else:
return states.ERROR
if not states.is_completed(in_task_ex.state):
return states.WAITING
if join_task_name not in self._find_next_task_names(in_task_ex):
return states.ERROR
return states.RUNNING
def _find_task_execution_by_spec(self, task_spec):
in_t_execs = wf_utils.find_task_executions_by_spec(
self.wf_ex,
inbound_task_spec
task_spec
)
# TODO(rakhmerov): Temporary hack. See the previous comment.
in_t_ex = in_t_execs[-1] if in_t_execs else None
return in_t_execs[-1] if in_t_execs else None
if not in_t_ex or not states.is_completed(in_t_ex.state):
return False
def _possible_route(self, task_spec):
# TODO(rakhmerov): In some cases this method will be expensive because
# it uses a multistep recursive search with DB queries.
# It will be optimized with Workflow Execution Graph moving forward.
in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)
return list(
filter(
lambda t_name: join_task_spec.get_name() == t_name,
self._find_next_task_names(in_t_ex)
)
)
if not in_task_specs:
return True
for t_s in in_task_specs:
t_ex = self._find_task_execution_by_spec(t_s)
if not t_ex:
if self._possible_route(t_s):
return True
else:
if task_spec.get_name() in self._find_next_task_names(t_ex):
return True
return False

View File

@ -108,9 +108,9 @@ class ReverseWorkflowController(base.WorkflowController):
return data_flow.evaluate_task_outbound_context(task_execs[0])
def is_task_start_allowed(self, task_ex):
def get_logical_task_state(self, task_ex):
# TODO(rakhmerov): Implement.
return True
return task_ex.state, task_ex.state_info
def is_error_handled_for(self, task_ex):
return task_ex.state != states.ERROR