Fix invalid workflow completion in case of "join"
When we use 'join' clause in workflow execution, the execution should be in running status, until all possible upstream tasks are completed and corresponding conditions have triggered. Do not remove the tasks with unsatisfied join when continuing execution, or the workflow will succeed erroneously. So, a new attribute 'wait_flag' is added to RunTask class, to indicate if the task should run immediately. The difference between 'waiting' and 'delayed' status is, 'waiting' indicates a task needs to wait for its conditions to be satisfied, 'delayed' means the task has to wait because of policy(wait-before, wait-after or retry). Change-Id: Idcd162b9c000b85acf2adabc158f848d0824ac57 Closes-Bug: #1455038
This commit is contained in:
		@@ -354,7 +354,9 @@ class DefaultEngine(base.Engine):
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        for cmd in wf_cmds:
 | 
			
		||||
            if isinstance(cmd, commands.RunTask):
 | 
			
		||||
            if isinstance(cmd, commands.RunTask) and cmd.is_waiting():
 | 
			
		||||
                task_handler.defer_task(cmd)
 | 
			
		||||
            elif isinstance(cmd, commands.RunTask):
 | 
			
		||||
                task_handler.run_new_task(cmd)
 | 
			
		||||
            elif isinstance(cmd, commands.RunExistingTask):
 | 
			
		||||
                task_handler.run_existing_task(cmd.task_ex.id)
 | 
			
		||||
 
 | 
			
		||||
@@ -29,6 +29,7 @@ from mistral.utils import wf_trace
 | 
			
		||||
from mistral.workbook import parser as spec_parser
 | 
			
		||||
from mistral.workflow import data_flow
 | 
			
		||||
from mistral.workflow import states
 | 
			
		||||
from mistral.workflow import utils as wf_utils
 | 
			
		||||
from mistral.workflow import with_items
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -73,6 +74,16 @@ def _run_existing_task(task_ex, task_spec, wf_spec):
 | 
			
		||||
        _schedule_noop_action(task_ex, task_spec)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def defer_task(wf_cmd):
 | 
			
		||||
    """Defers a task"""
 | 
			
		||||
    ctx = wf_cmd.ctx
 | 
			
		||||
    wf_ex = wf_cmd.wf_ex
 | 
			
		||||
    task_spec = wf_cmd.task_spec
 | 
			
		||||
 | 
			
		||||
    if not wf_utils.find_task_execution(wf_ex, task_spec):
 | 
			
		||||
        _create_task_execution(wf_ex, task_spec, ctx, state=states.WAITING)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_new_task(wf_cmd):
 | 
			
		||||
    """Runs a task."""
 | 
			
		||||
    ctx = wf_cmd.ctx
 | 
			
		||||
@@ -80,12 +91,24 @@ def run_new_task(wf_cmd):
 | 
			
		||||
    wf_spec = spec_parser.get_workflow_spec(wf_ex.spec)
 | 
			
		||||
    task_spec = wf_cmd.task_spec
 | 
			
		||||
 | 
			
		||||
    LOG.debug(
 | 
			
		||||
        'Starting workflow task [workflow=%s, task_spec=%s]' %
 | 
			
		||||
        (wf_ex.name, task_spec)
 | 
			
		||||
    # NOTE(xylan): Need to think how to get rid of this weird judgment to keep
 | 
			
		||||
    # it more consistent with the function name.
 | 
			
		||||
    task_ex = wf_utils.find_task_execution_with_state(
 | 
			
		||||
        wf_ex,
 | 
			
		||||
        task_spec,
 | 
			
		||||
        states.WAITING
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    task_ex = _create_task_execution(wf_ex, task_spec, ctx)
 | 
			
		||||
    if task_ex:
 | 
			
		||||
        _set_task_state(task_ex, states.RUNNING)
 | 
			
		||||
        task_ex.in_context = ctx
 | 
			
		||||
    else:
 | 
			
		||||
        task_ex = _create_task_execution(wf_ex, task_spec, ctx)
 | 
			
		||||
 | 
			
		||||
    LOG.debug(
 | 
			
		||||
        'Starting workflow task [workflow=%s, task_spec=%s, init_state=%s]' %
 | 
			
		||||
        (wf_ex.name, task_spec, task_ex.state)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # TODO(rakhmerov): 'concurrency' policy should keep a number of running
 | 
			
		||||
    # actions/workflows under control so it can't be implemented if it runs
 | 
			
		||||
@@ -144,12 +167,12 @@ def on_action_complete(action_ex, result):
 | 
			
		||||
    return task_ex
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _create_task_execution(wf_ex, task_spec, ctx):
 | 
			
		||||
def _create_task_execution(wf_ex, task_spec, ctx, state=states.RUNNING):
 | 
			
		||||
    task_ex = db_api.create_task_execution({
 | 
			
		||||
        'name': task_spec.get_name(),
 | 
			
		||||
        'workflow_execution_id': wf_ex.id,
 | 
			
		||||
        'workflow_name': wf_ex.workflow_name,
 | 
			
		||||
        'state': states.RUNNING,
 | 
			
		||||
        'state': state,
 | 
			
		||||
        'spec': task_spec.to_dict(),
 | 
			
		||||
        'in_context': ctx,
 | 
			
		||||
        'published': {},
 | 
			
		||||
 
 | 
			
		||||
@@ -178,24 +178,28 @@ class JoinEngineTest(base.EngineTestCase):
 | 
			
		||||
        # Start workflow.
 | 
			
		||||
        wf_ex = self.engine.start_workflow('wf', {})
 | 
			
		||||
 | 
			
		||||
        self._await(lambda: self.is_execution_success(wf_ex.id))
 | 
			
		||||
        self._await(
 | 
			
		||||
            lambda:
 | 
			
		||||
            len(db_api.get_workflow_execution(wf_ex.id).task_executions) == 4
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Note: We need to reread execution to access related tasks.
 | 
			
		||||
        wf_ex = db_api.get_workflow_execution(wf_ex.id)
 | 
			
		||||
 | 
			
		||||
        tasks = wf_ex.task_executions
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(3, len(tasks))
 | 
			
		||||
 | 
			
		||||
        task1 = self._assert_single_item(tasks, name='task1')
 | 
			
		||||
        task2 = self._assert_single_item(tasks, name='task2')
 | 
			
		||||
        task3 = self._assert_single_item(tasks, name='task3')
 | 
			
		||||
        task4 = self._assert_single_item(tasks, name='task4')
 | 
			
		||||
 | 
			
		||||
        # NOTE(xylan): We ensure task4 is successful here because of the
 | 
			
		||||
        # uncertainty of its running parallelly with task3.
 | 
			
		||||
        self._await(lambda: self.is_task_success(task4.id))
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(states.RUNNING, wf_ex.state)
 | 
			
		||||
        self.assertEqual(states.SUCCESS, task1.state)
 | 
			
		||||
        self.assertEqual(states.SUCCESS, task2.state)
 | 
			
		||||
        self.assertEqual(states.SUCCESS, task4.state)
 | 
			
		||||
 | 
			
		||||
        self.assertDictEqual({'result': 4}, wf_ex.output)
 | 
			
		||||
        self.assertEqual(states.WAITING, task3.state)
 | 
			
		||||
 | 
			
		||||
    def test_partial_join(self):
 | 
			
		||||
        wf_partial_join = """---
 | 
			
		||||
 
 | 
			
		||||
@@ -71,3 +71,11 @@ class StatesModuleTest(base.BaseTest):
 | 
			
		||||
        self.assertFalse(s.is_valid_transition(s.ERROR, s.DELAYED))
 | 
			
		||||
        self.assertFalse(s.is_valid_transition(s.ERROR, s.SUCCESS))
 | 
			
		||||
        self.assertFalse(s.is_valid_transition(s.ERROR, s.IDLE))
 | 
			
		||||
 | 
			
		||||
        # From WAITING
 | 
			
		||||
        self.assertTrue(s.is_valid_transition(s.WAITING, s.RUNNING))
 | 
			
		||||
        self.assertFalse(s.is_valid_transition(s.WAITING, s.SUCCESS))
 | 
			
		||||
        self.assertFalse(s.is_valid_transition(s.WAITING, s.PAUSED))
 | 
			
		||||
        self.assertFalse(s.is_valid_transition(s.WAITING, s.DELAYED))
 | 
			
		||||
        self.assertFalse(s.is_valid_transition(s.WAITING, s.IDLE))
 | 
			
		||||
        self.assertFalse(s.is_valid_transition(s.WAITING, s.ERROR))
 | 
			
		||||
 
 | 
			
		||||
@@ -13,6 +13,7 @@
 | 
			
		||||
#    limitations under the License.
 | 
			
		||||
 | 
			
		||||
from mistral.workbook import parser as spec_parser
 | 
			
		||||
from mistral.workbook.v2 import tasks
 | 
			
		||||
from mistral.workflow import states
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -41,10 +42,19 @@ class Noop(WorkflowCommand):
 | 
			
		||||
class RunTask(WorkflowCommand):
 | 
			
		||||
    """Instruction to run a workflow task."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, wf_ex, task_spec, ctx):
 | 
			
		||||
        super(RunTask, self).__init__(wf_ex, task_spec, ctx)
 | 
			
		||||
        self.wait_flag = False
 | 
			
		||||
 | 
			
		||||
    def is_waiting(self):
 | 
			
		||||
        return (self.wait_flag and
 | 
			
		||||
                isinstance(self.task_spec, tasks.DirectWorkflowTaskSpec) and
 | 
			
		||||
                self.task_spec.get_join())
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return (
 | 
			
		||||
            "Run task [workflow=%s, task=%s]"
 | 
			
		||||
            % (self.wf_ex.name, self.task_spec.get_name())
 | 
			
		||||
            "Run task [workflow=%s, task=%s, waif_flag=%s]"
 | 
			
		||||
            % (self.wf_ex.name, self.task_spec.get_name(), self.wait_flag)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -116,23 +116,28 @@ class DirectWorkflowController(base.WorkflowController):
 | 
			
		||||
                self.wf_spec.get_tasks()[task_ex.name]
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            cmds.append(
 | 
			
		||||
                commands.create_command(
 | 
			
		||||
                    t_n,
 | 
			
		||||
                    self.wf_ex,
 | 
			
		||||
                    t_s,
 | 
			
		||||
                    self._get_task_inbound_context(t_s)
 | 
			
		||||
                )
 | 
			
		||||
            cmd = commands.create_command(
 | 
			
		||||
                t_n,
 | 
			
		||||
                self.wf_ex,
 | 
			
		||||
                t_s,
 | 
			
		||||
                self._get_task_inbound_context(t_s)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        LOG.debug("Found commands: %s" % cmds)
 | 
			
		||||
            # NOTE(xylan): Decide whether or not a join task should run
 | 
			
		||||
            # immediately
 | 
			
		||||
            if self._is_unsatisfied_join(cmd):
 | 
			
		||||
                cmd.wait_flag = True
 | 
			
		||||
 | 
			
		||||
            cmds.append(cmd)
 | 
			
		||||
 | 
			
		||||
        # We need to remove all "join" tasks that have already started
 | 
			
		||||
        # (or even completed) to prevent running "join" tasks more than
 | 
			
		||||
        # once.
 | 
			
		||||
        cmds = self._remove_started_joins(cmds)
 | 
			
		||||
 | 
			
		||||
        return self._remove_unsatisfied_joins(cmds)
 | 
			
		||||
        LOG.debug("Found commands: %s" % cmds)
 | 
			
		||||
 | 
			
		||||
        return cmds
 | 
			
		||||
 | 
			
		||||
    def _has_inbound_transitions(self, task_spec):
 | 
			
		||||
        return len(self._find_inbound_task_specs(task_spec)) > 0
 | 
			
		||||
@@ -294,14 +299,15 @@ class DirectWorkflowController(base.WorkflowController):
 | 
			
		||||
        return filter(lambda cmd: not self._is_started_join(cmd), cmds)
 | 
			
		||||
 | 
			
		||||
    def _is_started_join(self, cmd):
 | 
			
		||||
        if not isinstance(cmd, commands.RunTask):
 | 
			
		||||
        if not (isinstance(cmd, commands.RunTask) and
 | 
			
		||||
                cmd.task_spec.get_join()):
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        return (cmd.task_spec.get_join()
 | 
			
		||||
                and wf_utils.find_task_execution(self.wf_ex, cmd.task_spec))
 | 
			
		||||
 | 
			
		||||
    def _remove_unsatisfied_joins(self, cmds):
 | 
			
		||||
        return filter(lambda cmd: not self._is_unsatisfied_join(cmd), cmds)
 | 
			
		||||
        return wf_utils.find_task_execution_not_state(
 | 
			
		||||
            self.wf_ex,
 | 
			
		||||
            cmd.task_spec,
 | 
			
		||||
            states.WAITING
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _is_unsatisfied_join(self, cmd):
 | 
			
		||||
        if not isinstance(cmd, commands.RunTask):
 | 
			
		||||
 
 | 
			
		||||
@@ -17,16 +17,18 @@
 | 
			
		||||
"""Valid task and workflow states."""
 | 
			
		||||
 | 
			
		||||
IDLE = 'IDLE'
 | 
			
		||||
WAITING = 'WAITING'
 | 
			
		||||
RUNNING = 'RUNNING'
 | 
			
		||||
PAUSED = 'PAUSED'
 | 
			
		||||
DELAYED = 'DELAYED'
 | 
			
		||||
SUCCESS = 'SUCCESS'
 | 
			
		||||
ERROR = 'ERROR'
 | 
			
		||||
 | 
			
		||||
_ALL = [IDLE, RUNNING, SUCCESS, ERROR, PAUSED, DELAYED]
 | 
			
		||||
_ALL = [IDLE, WAITING, RUNNING, SUCCESS, ERROR, PAUSED, DELAYED]
 | 
			
		||||
 | 
			
		||||
_VALID_TRANSITIONS = {
 | 
			
		||||
    IDLE: [RUNNING, ERROR],
 | 
			
		||||
    WAITING: [RUNNING],
 | 
			
		||||
    RUNNING: [PAUSED, DELAYED, SUCCESS, ERROR],
 | 
			
		||||
    PAUSED: [RUNNING, ERROR],
 | 
			
		||||
    DELAYED: [RUNNING, ERROR],
 | 
			
		||||
@@ -47,6 +49,10 @@ def is_completed(state):
 | 
			
		||||
    return state in [SUCCESS, ERROR]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_waiting(state):
 | 
			
		||||
    return state == WAITING
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_idle(state):
 | 
			
		||||
    return state == IDLE
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -54,6 +54,24 @@ def find_task_execution(wf_ex, task_spec):
 | 
			
		||||
    return task_execs[0] if len(task_execs) > 0 else None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def find_task_execution_not_state(wf_ex, task_spec, state):
 | 
			
		||||
    task_execs = [
 | 
			
		||||
        t for t in wf_ex.task_executions
 | 
			
		||||
        if t.name == task_spec.get_name() and t.state != state
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    return task_execs[0] if len(task_execs) > 0 else None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def find_task_execution_with_state(wf_ex, task_spec, state):
 | 
			
		||||
    task_execs = [
 | 
			
		||||
        t for t in wf_ex.task_executions
 | 
			
		||||
        if t.name == task_spec.get_name() and t.state == state
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    return task_execs[0] if len(task_execs) > 0 else None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def find_task_executions(wf_ex, task_specs):
 | 
			
		||||
    return filter(
 | 
			
		||||
        None,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user