From 9f236248c1aaee238e9b476f50350b17f2dbb259 Mon Sep 17 00:00:00 2001
From: Renat Akhmerov
Date: Fri, 5 Aug 2016 19:35:51 +0700
Subject: [PATCH] Towards non-locking model: adapt 'join' tasks to work w/o
 locks

* Added a scheduled job to check if a 'join' task is allowed to start
  in a separate transaction to prevent phantom reads. Architecturally
  this is done in a generic way, keeping in mind that we will also
  need to adapt the reverse workflow to work similarly. In order to
  avoid getting duplicates for 'join' tasks, the recently added DB API
  method 'insert_or_ignore' is used.

TODO:
 * Fix ReverseWorkflowController to work in the non-locking model
 * Fix 'with-items' so that it can work in the non-locking model

Partially implements: blueprint mistral-non-locking-tx-model

Change-Id: Ia319965a65d7b3f09eaf28792104d7fd58e9c82e
---
 mistral/db/v2/api.py                               |  8 +-
 mistral/db/v2/sqlalchemy/api.py                    | 21 ++--
 mistral/db/v2/sqlalchemy/models.py                 |  4 +-
 mistral/engine/task_handler.py                     | 90 ++++++++++++++++-
 mistral/engine/tasks.py                            | 65 ++++++++----
 mistral/engine/workflows.py                        |  6 ++
 .../tests/unit/db/v2/test_insert_or_ignore.py      | 12 +--
 .../unit/engine/test_direct_workflow_rerun.py      | 98 +++++++++++--------
 mistral/tests/unit/engine/test_join.py             | 31 +++++-
 mistral/workflow/base.py                           | 33 ++++---
 mistral/workflow/commands.py                       |  9 +-
 mistral/workflow/direct_workflow.py                | 80 +++++++--------
 mistral/workflow/reverse_workflow.py               | 19 +++-
 13 files changed, 333 insertions(+), 143 deletions(-)

diff --git a/mistral/db/v2/api.py b/mistral/db/v2/api.py
index e50a6ea0..37417011 100644
--- a/mistral/db/v2/api.py
+++ b/mistral/db/v2/api.py
@@ -292,9 +292,9 @@ def get_task_execution(id):
     return IMPL.get_task_execution(id)
 
 
-def load_task_execution(name):
+def load_task_execution(id):
     """Unlike get_task_execution this method is allowed to return None."""
-    return IMPL.load_task_execution(name)
+    return IMPL.load_task_execution(id)
 
 
 def get_task_executions(limit=None, marker=None, sort_keys=['created_at'],
@@ -312,6 +312,10 @@ def create_task_execution(values):
     return IMPL.create_task_execution(values)
 
 
+def insert_or_ignore_task_execution(values):
+    return IMPL.insert_or_ignore_task_execution(values)
+
+
 def update_task_execution(id, values):
     return IMPL.update_task_execution(id, values)
 
diff --git a/mistral/db/v2/sqlalchemy/api.py b/mistral/db/v2/sqlalchemy/api.py
index ea812e57..0f307fc6 100644
--- a/mistral/db/v2/sqlalchemy/api.py
+++ b/mistral/db/v2/sqlalchemy/api.py
@@ -319,9 +319,16 @@ def insert_or_ignore(model_cls, values, session=None):
         replace_string=replace
    )
 
-    res = session.execute(insert, model.to_dict())
-
-    return res.rowcount
+    # NOTE(rakhmerov): As it turned out, the result proxy object
+    # returned by the insert expression does not provide a valid
+    # count of affected rows for all supported databases. For this
+    # reason we shouldn't return anything from this method. In order
+    # to check whether a new object was really inserted, callers
+    # should use a different approach. The simplest is to insert an
+    # object with an explicitly set id and then check whether an
+    # object with that id exists in the DB. The generated id must be
+    # unique for this to work.
+    session.execute(insert, model.to_dict())
 
 
 # Workbook definitions.
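The check-by-generated-id pattern that this NOTE recommends looks roughly as
follows (a sketch only; it reuses the helpers added elsewhere in this patch,
while the wrapper name 'create_once' is hypothetical):

    from mistral.db.v2 import api as db_api
    from mistral import utils


    def create_once(values):
        # Choose the id up front so we can recognize our own row later.
        values['id'] = utils.generate_unicode_uuid()

        # This may silently insert nothing if a row with the same
        # 'unique_key' already exists; the reported rowcount cannot be
        # trusted across all supported databases.
        db_api.insert_or_ignore_task_execution(values)

        # Re-reading by the just generated id is the only portable check:
        # a hit means our insert won, a miss (None) means a concurrent
        # transaction created the conflicting row first.
        return db_api.load_task_execution(values['id'])

This is exactly the pattern that _create_task_execution() in
mistral/engine/tasks.py uses below to decide whether it created a new task
execution or lost the race to a concurrent transaction.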
@@ -818,6 +825,10 @@ def create_task_execution(values, session=None): return task_ex +def insert_or_ignore_task_execution(values): + insert_or_ignore(models.TaskExecution, values.copy()) + + @b.session_aware() def update_task_execution(id, values, session=None): task_ex = get_task_execution(id) @@ -869,9 +880,7 @@ def create_delayed_call(values, session=None): def insert_or_ignore_delayed_call(values): - row_count = insert_or_ignore(models.DelayedCall, values.copy()) - - return row_count + insert_or_ignore(models.DelayedCall, values.copy()) @b.session_aware() diff --git a/mistral/db/v2/sqlalchemy/models.py b/mistral/db/v2/sqlalchemy/models.py index 5092da03..14a7a6d9 100644 --- a/mistral/db/v2/sqlalchemy/models.py +++ b/mistral/db/v2/sqlalchemy/models.py @@ -165,10 +165,12 @@ class TaskExecution(Execution): sa.Index('%s_scope' % __tablename__, 'scope'), sa.Index('%s_state' % __tablename__, 'state'), sa.Index('%s_updated_at' % __tablename__, 'updated_at'), + sa.UniqueConstraint('unique_key') ) # Main properties. action_spec = sa.Column(st.JsonLongDictType()) + unique_key = sa.Column(sa.String(80), nullable=True) # Whether the task is fully processed (publishing and calculating commands # after it). It allows to simplify workflow controller implementations @@ -323,7 +325,7 @@ class DelayedCall(mb.MistralModelBase): target_method_name = sa.Column(sa.String(80), nullable=False) method_arguments = sa.Column(st.JsonDictType()) serializers = sa.Column(st.JsonDictType()) - unique_key = sa.Column(sa.String(50), nullable=True) + unique_key = sa.Column(sa.String(80), nullable=True) auth_context = sa.Column(st.JsonDictType()) execution_time = sa.Column(sa.DateTime, nullable=False) processing = sa.Column(sa.Boolean, default=False, nullable=False) diff --git a/mistral/engine/task_handler.py b/mistral/engine/task_handler.py index 0ebc7f22..4ba61fc1 100644 --- a/mistral/engine/task_handler.py +++ b/mistral/engine/task_handler.py @@ -19,10 +19,13 @@ from oslo_log import log as logging from osprofiler import profiler import traceback as tb +from mistral.db.v2 import api as db_api from mistral.engine import tasks from mistral.engine import workflow_handler as wf_handler from mistral import exceptions as exc +from mistral.services import scheduler from mistral.workbook import parser as spec_parser +from mistral.workflow import base as wf_base from mistral.workflow import commands as wf_cmds from mistral.workflow import states @@ -31,6 +34,10 @@ from mistral.workflow import states LOG = logging.getLogger(__name__) +_CHECK_TASK_START_ALLOWED_PATH = ( + 'mistral.engine.task_handler._check_task_start_allowed' +) + @profiler.trace('task-handler-run-task') def run_task(wf_cmd): @@ -60,6 +67,9 @@ def run_task(wf_cmd): return + if task.is_waiting(): + _schedule_check_task_start_allowed(task.task_ex) + if task.is_completed(): wf_handler.schedule_on_task_complete(task.task_ex) @@ -127,6 +137,8 @@ def continue_task(task_ex): ) try: + task.set_state(states.RUNNING, None) + task.run() except exc.MistralException as e: wf_ex = task_ex.workflow_execution @@ -194,7 +206,8 @@ def _build_task_from_command(cmd): cmd.wf_spec, spec_parser.get_task_spec(cmd.task_ex.spec), cmd.ctx, - cmd.task_ex + task_ex=cmd.task_ex, + unique_key=cmd.task_ex.unique_key ) if cmd.reset: @@ -203,7 +216,13 @@ def _build_task_from_command(cmd): return task if isinstance(cmd, wf_cmds.RunTask): - task = _create_task(cmd.wf_ex, cmd.wf_spec, cmd.task_spec, cmd.ctx) + task = _create_task( + cmd.wf_ex, + cmd.wf_spec, + cmd.task_spec, + cmd.ctx, + 
            unique_key=cmd.unique_key
+        )
 
         if cmd.is_waiting():
             task.defer()
 
         return task
 
     raise exc.MistralError('Unsupported workflow command: %s' % cmd)
 
 
-def _create_task(wf_ex, wf_spec, task_spec, ctx, task_ex=None):
+def _create_task(wf_ex, wf_spec, task_spec, ctx, task_ex=None,
+                 unique_key=None):
     if task_spec.get_with_items():
-        return tasks.WithItemsTask(wf_ex, wf_spec, task_spec, ctx, task_ex)
+        return tasks.WithItemsTask(
+            wf_ex,
+            wf_spec,
+            task_spec,
+            ctx,
+            task_ex,
+            unique_key
+        )
 
-    return tasks.RegularTask(wf_ex, wf_spec, task_spec, ctx, task_ex)
+    return tasks.RegularTask(
+        wf_ex,
+        wf_spec,
+        task_spec,
+        ctx,
+        task_ex,
+        unique_key
+    )
+
+
+def _check_task_start_allowed(task_ex_id):
+    with db_api.transaction():
+        task_ex = db_api.get_task_execution(task_ex_id)
+
+        wf_ctrl = wf_base.get_controller(
+            task_ex.workflow_execution,
+            spec_parser.get_workflow_spec_by_id(task_ex.workflow_id)
+        )
+
+        if wf_ctrl.is_task_start_allowed(task_ex):
+            continue_task(task_ex)
+
+            return
+
+        # TODO(rakhmerov): Algorithm for increasing rescheduling delay.
+        _schedule_check_task_start_allowed(task_ex, 1)
+
+
+def _schedule_check_task_start_allowed(task_ex, delay=0):
+    """Schedules a task preconditions check.
+
+    This method provides transactional decoupling of the task
+    preconditions check from the events that can potentially satisfy
+    those preconditions.
+
+    It's needed in the non-locking model in order to avoid the
+    'phantom read' phenomenon when reading the state of multiple tasks
+    to see whether a task that depends on them can start. Just starting
+    a separate transaction without using the scheduler is not safe due
+    to the concurrency window we'd have in that case (the time between
+    transactions), whereas the scheduler is a special component that is
+    designed to be resistant to failures.
+
+    :param task_ex: Task execution.
+    :param delay: Delay in seconds before running the check.
+    """
+    key = 'th_c_t_s_a-%s' % task_ex.id
+
+    scheduler.schedule_call(
+        None,
+        _CHECK_TASK_START_ALLOWED_PATH,
+        delay,
+        unique_key=key,
+        task_ex_id=task_ex.id
+    )
diff --git a/mistral/engine/tasks.py b/mistral/engine/tasks.py
index 65db99bf..5b442f92 100644
--- a/mistral/engine/tasks.py
+++ b/mistral/engine/tasks.py
@@ -47,15 +47,23 @@ class Task(object):
     """
 
     @profiler.trace('task-create')
-    def __init__(self, wf_ex, wf_spec, task_spec, ctx, task_ex=None):
+    def __init__(self, wf_ex, wf_spec, task_spec, ctx, task_ex=None,
+                 unique_key=None):
         self.wf_ex = wf_ex
         self.task_spec = task_spec
         self.ctx = ctx
         self.task_ex = task_ex
         self.wf_spec = wf_spec
+        self.unique_key = unique_key
         self.waiting = False
         self.reset_flag = False
 
+    def is_completed(self):
+        return self.task_ex and states.is_completed(self.task_ex.state)
+
+    def is_waiting(self):
+        return self.waiting
+
     @abc.abstractmethod
     def on_action_complete(self, action_ex):
         """Handle action completion.
@@ -160,7 +168,7 @@ class Task(object):
         wf_ctrl = wf_base.get_controller(self.wf_ex, self.wf_spec)
 
         # Calculate commands to process next.
-        cmds = wf_ctrl.continue_workflow()
+        cmds = wf_ctrl.continue_workflow(self.task_ex)
 
         # Mark task as processed after all decisions have been made
         # upon its completion.
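As an aside on the TODO in _check_task_start_allowed() above: the increasing
rescheduling delay could, for example, be a capped exponential backoff along
these lines (a sketch only, not part of this patch; MAX_DELAY and _next_delay
are made-up names, and the current delay would have to be threaded through
the scheduled call's keyword arguments):

    MAX_DELAY = 60  # seconds; an assumed upper bound


    def _next_delay(previous_delay):
        # 0 -> 1 -> 2 -> 4 -> ... -> MAX_DELAY
        return min(max(previous_delay * 2, 1), MAX_DELAY)

With that, the recursive re-scheduling call would become
_schedule_check_task_start_allowed(task_ex, _next_delay(delay)) instead of
using the constant 1.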
@@ -181,23 +189,39 @@ class Task(object):
             p.after_task_complete(self.task_ex, self.task_spec)
 
     def _create_task_execution(self, state=states.RUNNING):
-        self.task_ex = db_api.create_task_execution({
+        values = {
+            'id': utils.generate_unicode_uuid(),
             'name': self.task_spec.get_name(),
             'workflow_execution_id': self.wf_ex.id,
             'workflow_name': self.wf_ex.workflow_name,
             'workflow_id': self.wf_ex.workflow_id,
             'state': state,
             'spec': self.task_spec.to_dict(),
+            'unique_key': self.unique_key,
             'in_context': self.ctx,
             'published': {},
             'runtime_context': {},
             'project_id': self.wf_ex.project_id
-        })
+        }
+
+        db_api.insert_or_ignore_task_execution(values)
+
+        # Since 'insert_or_ignore' cannot return a valid count of
+        # affected rows, the only reliable way to check whether the
+        # insert operation actually created an object is to try to
+        # load this object by the just generated uuid.
+        task_ex = db_api.load_task_execution(values['id'])
+
+        if not task_ex:
+            return False
+
+        self.task_ex = task_ex
 
         # Add to collection explicitly so that it's in a proper
         # state within the current session.
         self.wf_ex.task_executions.append(self.task_ex)
 
+        return True
+
     def _get_action_defaults(self):
         action_name = self.task_spec.get_action_name()
 
@@ -226,9 +250,6 @@ class RegularTask(Task):
 
         self.complete(state, state_info)
 
-    def is_completed(self):
-        return self.task_ex and states.is_completed(self.task_ex.state)
-
     @profiler.trace('task-run')
     def run(self):
         if not self.task_ex:
@@ -237,20 +258,12 @@ class RegularTask(Task):
             self._run_existing()
 
     def _run_new(self):
-        # NOTE(xylan): Need to think how to get rid of this weird judgment
-        # to keep it more consistent with the function name.
-        self.task_ex = wf_utils.find_task_execution_with_state(
-            self.wf_ex,
-            self.task_spec,
-            states.WAITING
-        )
+        if not self._create_task_execution():
+            # A task with the same unique key has already been created.
+            return
 
-        if self.task_ex:
-            self.set_state(states.RUNNING, None)
-
-            self.task_ex.in_context = self.ctx
-        else:
-            self._create_task_execution()
+        if self.waiting:
+            return
 
         LOG.debug(
             'Starting task [workflow=%s, task_spec=%s, init_state=%s]' %
@@ -278,10 +291,18 @@ class RegularTask(Task):
 
         self.set_state(states.RUNNING, None, processed=False)
 
+        self._update_inbound_context()
         self._reset_actions()
-
         self._schedule_actions()
 
+    def _update_inbound_context(self):
+        assert self.task_ex
+
+        wf_ctrl = wf_base.get_controller(self.wf_ex, self.wf_spec)
+
+        self.ctx = wf_ctrl.get_task_inbound_context(self.task_spec)
+        self.task_ex.in_context = self.ctx
+
     def _reset_actions(self):
         """Resets task state.
 
@@ -386,7 +407,9 @@ class WithItemsTask(RegularTask):
 
         if with_items.is_completed(self.task_ex):
             state = with_items.get_final_state(self.task_ex)
+
             self.complete(state, state_info[state])
+            return
 
         if (with_items.has_more_iterations(self.task_ex)
diff --git a/mistral/engine/workflows.py b/mistral/engine/workflows.py
index e3e4374b..ab344767 100644
--- a/mistral/engine/workflows.py
+++ b/mistral/engine/workflows.py
@@ -282,6 +282,12 @@ class Workflow(object):
             # tasks are waiting and there are unhandled errors, then these
             # tasks will not reach completion. In this case, mark the
             # workflow complete.
+            # TODO(rakhmerov): In the new non-locking transactional model,
+            # tasks in WAITING state should not be taken into account.
+            # I'm actually surprised that we did this so far. The point is
+            # that every task should complete sooner or later and get into
+            # one of the terminal states. We should not leave a task in a
+            # non-terminal state if we know exactly what its real state is.
incomplete_tasks = wf_utils.find_incomplete_task_executions(self.wf_ex) if any(not states.is_waiting(t.state) for t in incomplete_tasks): diff --git a/mistral/tests/unit/db/v2/test_insert_or_ignore.py b/mistral/tests/unit/db/v2/test_insert_or_ignore.py index 15e78bbe..ce12087e 100644 --- a/mistral/tests/unit/db/v2/test_insert_or_ignore.py +++ b/mistral/tests/unit/db/v2/test_insert_or_ignore.py @@ -46,13 +46,11 @@ class InsertOrIgnoreTest(test_base.DbTestCase): ) def test_insert_or_ignore_without_conflicts(self): - row_count = db_api.insert_or_ignore( + db_api.insert_or_ignore( db_models.DelayedCall, DELAYED_CALL.copy() ) - self.assertEqual(1, row_count) - delayed_calls = db_api.get_delayed_calls() self.assertEqual(1, len(delayed_calls)) @@ -67,9 +65,7 @@ class InsertOrIgnoreTest(test_base.DbTestCase): values['unique_key'] = 'key' - row_count = db_api.insert_or_ignore(db_models.DelayedCall, values) - - self.assertEqual(1, row_count) + db_api.insert_or_ignore(db_models.DelayedCall, values) delayed_calls = db_api.get_delayed_calls() @@ -85,9 +81,7 @@ class InsertOrIgnoreTest(test_base.DbTestCase): values['unique_key'] = 'key' - row_count = db_api.insert_or_ignore(db_models.DelayedCall, values) - - self.assertEqual(0, row_count) + db_api.insert_or_ignore(db_models.DelayedCall, values) delayed_calls = db_api.get_delayed_calls() diff --git a/mistral/tests/unit/engine/test_direct_workflow_rerun.py b/mistral/tests/unit/engine/test_direct_workflow_rerun.py index 9bcbf61e..fe4e042c 100644 --- a/mistral/tests/unit/engine/test_direct_workflow_rerun.py +++ b/mistral/tests/unit/engine/test_direct_workflow_rerun.py @@ -321,18 +321,19 @@ class DirectWorkflowRerunTest(base.EngineTestCase): self.await_workflow_error(wf_ex.id) - wf_ex = db_api.get_workflow_execution(wf_ex.id) + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + task_execs = wf_ex.task_executions self.assertEqual(states.ERROR, wf_ex.state) self.assertIsNotNone(wf_ex.state_info) - self.assertEqual(3, len(wf_ex.task_executions)) + self.assertEqual(3, len(task_execs)) self.assertDictEqual(env, wf_ex.params['env']) self.assertDictEqual(env, wf_ex.context['__env']) - task_exs = wf_ex.task_executions - task_10_ex = self._assert_single_item(task_exs, name='t10') - task_21_ex = self._assert_single_item(task_exs, name='t21') - task_30_ex = self._assert_single_item(task_exs, name='t30') + task_10_ex = self._assert_single_item(task_execs, name='t10') + task_21_ex = self._assert_single_item(task_execs, name='t21') + task_30_ex = self._assert_single_item(task_execs, name='t30') self.assertEqual(states.SUCCESS, task_10_ex.state) self.assertEqual(states.ERROR, task_21_ex.state) @@ -347,29 +348,31 @@ class DirectWorkflowRerunTest(base.EngineTestCase): } # Resume workflow and re-run failed task. - self.engine.rerun_workflow(task_21_ex.id, env=updated_env) - - wf_ex = db_api.get_workflow_execution(wf_ex.id) + wf_ex = self.engine.rerun_workflow(task_21_ex.id, env=updated_env) self.assertEqual(states.RUNNING, wf_ex.state) self.assertIsNone(wf_ex.state_info) self.assertDictEqual(updated_env, wf_ex.params['env']) self.assertDictEqual(updated_env, wf_ex.context['__env']) + # Await t30 success. + self.await_task_success(task_30_ex.id) + # Wait for the workflow to succeed. 
self.await_workflow_success(wf_ex.id) - wf_ex = db_api.get_workflow_execution(wf_ex.id) + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + task_execs = wf_ex.task_executions self.assertEqual(states.SUCCESS, wf_ex.state) self.assertIsNone(wf_ex.state_info) - self.assertEqual(4, len(wf_ex.task_executions)) + self.assertEqual(4, len(task_execs)) - task_exs = wf_ex.task_executions - task_10_ex = self._assert_single_item(task_exs, name='t10') - task_21_ex = self._assert_single_item(task_exs, name='t21') - task_22_ex = self._assert_single_item(task_exs, name='t22') - task_30_ex = self._assert_single_item(task_exs, name='t30') + task_10_ex = self._assert_single_item(task_execs, name='t10') + task_21_ex = self._assert_single_item(task_execs, name='t21') + task_22_ex = self._assert_single_item(task_execs, name='t22') + task_30_ex = self._assert_single_item(task_execs, name='t30') # Check action executions of task 10. self.assertEqual(states.SUCCESS, task_10_ex.state) @@ -867,79 +870,92 @@ class DirectWorkflowRerunTest(base.EngineTestCase): # Run workflow and fail task. wf_ex = self.engine.start_workflow('wb1.wf1', {}) - wf_ex = db_api.get_workflow_execution(wf_ex.id) + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + task_execs = wf_ex.task_executions - task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1') - task_2_ex = self._assert_single_item(wf_ex.task_executions, name='t2') + task_1_ex = self._assert_single_item(task_execs, name='t1') + task_2_ex = self._assert_single_item(task_execs, name='t2') self.await_task_error(task_1_ex.id) self.await_task_error(task_2_ex.id) self.await_workflow_error(wf_ex.id) - wf_ex = db_api.get_workflow_execution(wf_ex.id) + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + task_execs = wf_ex.task_executions self.assertEqual(states.ERROR, wf_ex.state) self.assertIsNotNone(wf_ex.state_info) - self.assertEqual(2, len(wf_ex.task_executions)) + self.assertEqual(2, len(task_execs)) - task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1') - task_2_ex = self._assert_single_item(wf_ex.task_executions, name='t2') + task_1_ex = self._assert_single_item(task_execs, name='t1') self.assertEqual(states.ERROR, task_1_ex.state) self.assertIsNotNone(task_1_ex.state_info) + + task_2_ex = self._assert_single_item(task_execs, name='t2') + self.assertEqual(states.ERROR, task_2_ex.state) self.assertIsNotNone(task_2_ex.state_info) # Resume workflow and re-run failed task. - self.engine.rerun_workflow(task_1_ex.id) - - wf_ex = db_api.get_workflow_execution(wf_ex.id) + wf_ex = self.engine.rerun_workflow(task_1_ex.id) self.assertEqual(states.RUNNING, wf_ex.state) self.assertIsNone(wf_ex.state_info) + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + task_execs = wf_ex.task_executions + # Wait for the task to succeed. 
- task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1') + task_1_ex = self._assert_single_item(task_execs, name='t1') self.await_task_success(task_1_ex.id) self.await_workflow_error(wf_ex.id) - wf_ex = db_api.get_workflow_execution(wf_ex.id) + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + task_execs = wf_ex.task_executions self.assertEqual(states.ERROR, wf_ex.state) self.assertIsNotNone(wf_ex.state_info) - self.assertEqual(3, len(wf_ex.task_executions)) + self.assertEqual(3, len(task_execs)) - task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1') - task_2_ex = self._assert_single_item(wf_ex.task_executions, name='t2') - task_3_ex = self._assert_single_item(wf_ex.task_executions, name='t3') + task_1_ex = self._assert_single_item(task_execs, name='t1') + task_2_ex = self._assert_single_item(task_execs, name='t2') + task_3_ex = self._assert_single_item(task_execs, name='t3') self.assertEqual(states.SUCCESS, task_1_ex.state) self.assertEqual(states.ERROR, task_2_ex.state) self.assertEqual(states.WAITING, task_3_ex.state) # Resume workflow and re-run failed task. - self.engine.rerun_workflow(task_2_ex.id) - - wf_ex = db_api.get_workflow_execution(wf_ex.id) + wf_ex = self.engine.rerun_workflow(task_2_ex.id) self.assertEqual(states.RUNNING, wf_ex.state) self.assertIsNone(wf_ex.state_info) + # Join now should finally complete. + self.await_task_success(task_3_ex.id) + # Wait for the workflow to succeed. self.await_workflow_success(wf_ex.id) - wf_ex = db_api.get_workflow_execution(wf_ex.id) + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + task_execs = wf_ex.task_executions self.assertEqual(states.SUCCESS, wf_ex.state) self.assertIsNone(wf_ex.state_info) - self.assertEqual(3, len(wf_ex.task_executions)) + self.assertEqual(3, len(task_execs)) - task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1') - task_2_ex = self._assert_single_item(wf_ex.task_executions, name='t2') - task_3_ex = self._assert_single_item(wf_ex.task_executions, name='t3') + task_1_ex = self._assert_single_item(task_execs, name='t1') + task_2_ex = self._assert_single_item(task_execs, name='t2') + task_3_ex = self._assert_single_item(task_execs, name='t3') # Check action executions of task 1. 
        self.assertEqual(states.SUCCESS, task_1_ex.state)
@@ -969,7 +985,7 @@ class DirectWorkflowRerunTest(base.EngineTestCase):
         self.assertEqual(states.SUCCESS, task_3_ex.state)
 
         task_3_action_exs = db_api.get_action_executions(
-            task_execution_id=wf_ex.task_executions[2].id
+            task_execution_id=task_execs[2].id
         )
 
         self.assertEqual(1, len(task_3_action_exs))
diff --git a/mistral/tests/unit/engine/test_join.py b/mistral/tests/unit/engine/test_join.py
index e782dc5d..1fff346f 100644
--- a/mistral/tests/unit/engine/test_join.py
+++ b/mistral/tests/unit/engine/test_join.py
@@ -27,6 +27,30 @@ cfg.CONF.set_default('auth_enable', False, group='pecan')
 
 
 class JoinEngineTest(base.EngineTestCase):
+    def test_full_join_simple(self):
+        wf_text = """---
+        version: '2.0'
+
+        wf:
+          type: direct
+
+          tasks:
+            join_task:
+              join: all
+
+            task1:
+              on-success: join_task
+
+            task2:
+              on-success: join_task
+        """
+
+        wf_service.create_workflows(wf_text)
+
+        wf_ex = self.engine.start_workflow('wf', {})
+
+        self.await_workflow_success(wf_ex.id)
+
     def test_full_join_without_errors(self):
         wf_text = """---
         version: '2.0'
@@ -354,7 +378,12 @@ class JoinEngineTest(base.EngineTestCase):
         result5 = task5.published['result5']
 
         self.assertIsNotNone(result5)
-        self.assertEqual(2, result5.count('True'))
+
+        # Depending on how many inbound tasks completed before 'join'
+        # task5 started, it can get a different inbound context. But at
+        # least two inbound results should be accessible at task5,
+        # which logically corresponds to a 'join' cardinality of 2.
+        self.assertTrue(result5.count('True') >= 2)
 
     def test_discriminator(self):
         wf_text = """---
diff --git a/mistral/workflow/base.py b/mistral/workflow/base.py
index 06d72d05..11ff0163 100644
--- a/mistral/workflow/base.py
+++ b/mistral/workflow/base.py
@@ -86,28 +86,25 @@ class WorkflowController(object):
         self.wf_spec = wf_spec
 
     @profiler.trace('workflow-controller-continue-workflow')
-    def continue_workflow(self):
+    def continue_workflow(self, task_ex=None):
         """Calculates a list of commands to continue the workflow.
 
         Given a workflow specification this method makes the required
         analysis according to this workflow type's rules and identifies
         a list of commands needed to continue the workflow.
 
+        :param task_ex: Task execution that caused workflow continuation.
+            Optional. If not specified, it means that no specific task
+            caused this operation (e.g. the workflow has just been
+            started or resumed manually).
         :return: List of workflow commands (instances of
             mistral.workflow.commands.WorkflowCommand).
         """
-        # TODO(rakhmerov): We now use this method for two cases:
-        # 1) to handle task completion
-        # 2) to resume a workflow after it's been paused
-        # Moving forward we need to introduce a separate method for
-        # resuming a workflow because it won't be operating with
-        # any concrete tasks that caused this operation.
-
         if self._is_paused_or_completed():
             return []
 
-        return self._find_next_commands()
+        return self._find_next_commands(task_ex)
 
     def rerun_tasks(self, task_execs, reset=True):
         """Gets commands to rerun existing task executions.
@@ -128,11 +125,21 @@ class WorkflowController(object):
 
         return cmds
 
+    @abc.abstractmethod
+    def is_task_start_allowed(self, task_ex):
+        """Determines if the given task is allowed to start.
+
+        :param task_ex: Task execution.
+        :return: True if all preconditions are met and the given task
+            is allowed to start.
+        """
+        raise NotImplementedError
+
     @abc.abstractmethod
     def is_error_handled_for(self, task_ex):
         """Determines if error is handled for specific task.
- :param task_ex: Task execution perform a check for. + :param task_ex: Task execution. :return: True if either there is no error at all or error is considered handled. """ @@ -162,7 +169,9 @@ class WorkflowController(object): """ raise NotImplementedError - def _get_task_inbound_context(self, task_spec): + def get_task_inbound_context(self, task_spec): + # TODO(rakhmerov): This method should also be able to work with task_ex + # to cover 'split' (aka 'merge') use case. upstream_task_execs = self._get_upstream_task_executions(task_spec) upstream_ctx = data_flow.evaluate_upstream_context(upstream_task_execs) @@ -190,7 +199,7 @@ class WorkflowController(object): raise NotImplementedError @abc.abstractmethod - def _find_next_commands(self): + def _find_next_commands(self, task_ex): """Finds commands that should run next. A concrete algorithm of finding such tasks depends on a concrete diff --git a/mistral/workflow/commands.py b/mistral/workflow/commands.py index 7b76431e..8ef92860 100644 --- a/mistral/workflow/commands.py +++ b/mistral/workflow/commands.py @@ -48,11 +48,13 @@ class RunTask(WorkflowCommand): super(RunTask, self).__init__(wf_ex, wf_spec, task_spec, ctx) self.wait = False + self.unique_key = None def is_waiting(self): - return (self.wait and - isinstance(self.task_spec, tasks.DirectWorkflowTaskSpec) and - self.task_spec.get_join()) + return self.wait + + def get_unique_key(self): + return self.unique_key def __repr__(self): return ( @@ -74,6 +76,7 @@ class RunExistingTask(WorkflowCommand): self.task_ex = task_ex self.reset = reset + self.unique_key = task_ex.unique_key class SetWorkflowState(WorkflowCommand): diff --git a/mistral/workflow/direct_workflow.py b/mistral/workflow/direct_workflow.py index 3ef8990e..d02e94ff 100644 --- a/mistral/workflow/direct_workflow.py +++ b/mistral/workflow/direct_workflow.py @@ -65,16 +65,21 @@ class DirectWorkflowController(base.WorkflowController): self.wf_spec.get_tasks()[t_ex_candidate.name] ) - def _find_next_commands(self): - cmds = super(DirectWorkflowController, self)._find_next_commands() + def _find_next_commands(self, task_ex=None): + cmds = super(DirectWorkflowController, self)._find_next_commands( + task_ex + ) if not self.wf_ex.task_executions: return self._find_start_commands() - task_execs = [ - t_ex for t_ex in self.wf_ex.task_executions - if states.is_completed(t_ex.state) and not t_ex.processed - ] + if task_ex: + task_execs = [task_ex] + else: + task_execs = [ + t_ex for t_ex in self.wf_ex.task_executions + if states.is_completed(t_ex.state) and not t_ex.processed + ] for t_ex in task_execs: cmds.extend(self._find_next_commands_for_task(t_ex)) @@ -87,7 +92,7 @@ class DirectWorkflowController(base.WorkflowController): self.wf_ex, self.wf_spec, t_s, - self._get_task_inbound_context(t_s) + self.get_task_inbound_context(t_s) ) for t_s in self.wf_spec.find_start_tasks() ] @@ -115,26 +120,33 @@ class DirectWorkflowController(base.WorkflowController): self.wf_ex, self.wf_spec, t_s, - self._get_task_inbound_context(t_s), + data_flow.evaluate_task_outbound_context(task_ex), params ) - # NOTE(xylan): Decide whether or not a join task should run - # immediately. - if self._is_unsatisfied_join(cmd): - cmd.wait = True + self._configure_if_join(cmd) cmds.append(cmd) - # We need to remove all "join" tasks that have already started - # (or even completed) to prevent running "join" tasks more than - # once. 
-        cmds = self._remove_started_joins(cmds)
-
         LOG.debug("Found commands: %s" % cmds)
 
         return cmds
 
+    def _configure_if_join(self, cmd):
+        if not isinstance(cmd, commands.RunTask):
+            return
+
+        if not cmd.task_spec.get_join():
+            return
+
+        cmd.unique_key = self._get_join_unique_key(cmd)
+
+        if self._is_unsatisfied_join(cmd.task_spec):
+            cmd.wait = True
+
+    def _get_join_unique_key(self, cmd):
+        return 'join-task-%s-%s' % (self.wf_ex.id, cmd.task_spec.get_name())
+
     # TODO(rakhmerov): Need to refactor this method to be able to pass tasks
     # whose contexts need to be merged.
     def evaluate_workflow_final_context(self):
@@ -148,6 +160,14 @@ class DirectWorkflowController(base.WorkflowController):
 
         return ctx
 
+    def is_task_start_allowed(self, task_ex):
+        task_spec = self.wf_spec.get_tasks()[task_ex.name]
+
+        return (
+            not task_spec.get_join() or
+            not self._is_unsatisfied_join(task_spec)
+        )
+
     def is_error_handled_for(self, task_ex):
         return bool(self.wf_spec.get_on_error_clause(task_ex.name))
 
@@ -241,28 +261,10 @@ class DirectWorkflowController(base.WorkflowController):
             if not condition or expr.evaluate(condition, ctx)
         ]
 
-    def _remove_started_joins(self, cmds):
-        return list(
-            filter(lambda cmd: not self._is_started_join(cmd), cmds)
-        )
-
-    def _is_started_join(self, cmd):
-        if not (isinstance(cmd, commands.RunTask) and
-                cmd.task_spec.get_join()):
-            return False
-
-        return wf_utils.find_task_execution_not_state(
-            self.wf_ex,
-            cmd.task_spec,
-            states.WAITING
-        )
-
-    def _is_unsatisfied_join(self, cmd):
-        if not isinstance(cmd, commands.RunTask):
-            return False
-
-        task_spec = cmd.task_spec
-
+    def _is_unsatisfied_join(self, task_spec):
+        # TODO(rakhmerov): We need to use task_ex instead of task_spec
+        # in order to cover a use case when there's more than one instance
+        # of the same 'join' task in a workflow.
         join_expr = task_spec.get_join()
 
         if not join_expr:
diff --git a/mistral/workflow/reverse_workflow.py b/mistral/workflow/reverse_workflow.py
index c94a1195..dfc79e97 100644
--- a/mistral/workflow/reverse_workflow.py
+++ b/mistral/workflow/reverse_workflow.py
@@ -40,13 +40,22 @@ class ReverseWorkflowController(base.WorkflowController):
 
     __workflow_type__ = "reverse"
 
-    def _find_next_commands(self):
+    def _find_next_commands(self, task_ex=None):
         """Finds all tasks with resolved dependencies.
 
         This method finds all tasks with resolved dependencies and
         returns them in the form of workflow commands.
         """
-        cmds = super(ReverseWorkflowController, self)._find_next_commands()
+        cmds = super(ReverseWorkflowController, self)._find_next_commands(
+            task_ex
+        )
+
+        # TODO(rakhmerov): Adapt reverse workflow to the non-locking model.
+        #  1. Task search must use the task_ex parameter.
+        #  2. When a task has more than one dependency it's possible to
+        #  run into the 'phantom read' phenomenon and create multiple
+        #  instances of the same task. So 'unique_key' in conjunction with
+        #  'wait_flag' must be used to prevent this.
 
         task_specs = self._find_task_specs_with_satisfied_dependencies()
 
@@ -55,7 +64,7 @@ class ReverseWorkflowController(base.WorkflowController):
                 self.wf_ex,
                 self.wf_spec,
                 t_s,
-                self._get_task_inbound_context(t_s)
+                self.get_task_inbound_context(t_s)
             )
             for t_s in task_specs
         ]
@@ -99,6 +108,10 @@ class ReverseWorkflowController(base.WorkflowController):
 
         return data_flow.evaluate_task_outbound_context(task_execs[0])
 
+    def is_task_start_allowed(self, task_ex):
+        # TODO(rakhmerov): Implement.
+        return True
+
     def is_error_handled_for(self, task_ex):
         return task_ex.state != states.ERROR
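For the reverse controller, the is_task_start_allowed() TODO would presumably
amount to checking the task's 'requires' clause against the set of already
succeeded tasks, roughly like this (a sketch under that assumption, not the
actual implementation; it assumes reverse task specs expose get_requires()
the way the rest of the controller code suggests):

    def is_task_start_allowed(self, task_ex):
        task_spec = self.wf_spec.get_tasks()[task_ex.name]

        # Names of all tasks that already have a SUCCESS execution.
        succeeded = set(
            t_ex.name
            for t_ex in self.wf_ex.task_executions
            if t_ex.state == states.SUCCESS
        )

        # A reverse workflow task may start once every dependency
        # listed in its 'requires' clause has succeeded.
        return set(task_spec.get_requires() or []) <= succeeded

Combined with the 'unique_key' and wait-flag mechanics mentioned in the TODO
above, this would let the reverse controller defer a multi-dependency task
the same way the direct controller defers an unsatisfied 'join'.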