Towards non-locking model: adapt 'join' tasks to work w/o locks

* Added a scheduled job that checks, in a separate transaction, whether
  a 'join' task is allowed to start, in order to prevent phantom reads.
  Architecturally this is done in a generic way, with the idea in mind
  that the reverse workflow also needs to be adapted to work similarly.
  In order to avoid getting duplicates of 'join' tasks, the recently
  added DB API method 'insert_or_ignore' is used.
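  In rough outline (a simplified sketch of the flow added below in the
  task handler, not the verbatim code):

    def _check_task_start_allowed(task_ex_id):
        # Runs as a scheduled (delayed) call, i.e. in a fresh
        # transaction, so it sees a consistent snapshot of the states
        # of all tasks the 'join' task depends on.
        with db_api.transaction():
            task_ex = db_api.get_task_execution(task_ex_id)

            wf_ctrl = wf_base.get_controller(
                task_ex.workflow_execution,
                spec_parser.get_workflow_spec_by_id(task_ex.workflow_id)
            )

            if wf_ctrl.is_task_start_allowed(task_ex):
                continue_task(task_ex)
            else:
                # Preconditions not met yet; check again later.
                _schedule_check_task_start_allowed(task_ex, 1)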

TODO:
* Fix ReverseWorkflowController to work in the non-locking model
* Fix 'with-items' so that it can work in the non-locking model

Partially implements: blueprint mistral-non-locking-tx-model
Change-Id: Ia319965a65d7b3f09eaf28792104d7fd58e9c82e
Renat Akhmerov 2016-08-05 19:35:51 +07:00
parent ca021e1a63
commit 9f236248c1
13 changed files with 333 additions and 143 deletions

View File

@ -292,9 +292,9 @@ def get_task_execution(id):
return IMPL.get_task_execution(id)
def load_task_execution(name):
def load_task_execution(id):
"""Unlike get_task_execution this method is allowed to return None."""
return IMPL.load_task_execution(name)
return IMPL.load_task_execution(id)
def get_task_executions(limit=None, marker=None, sort_keys=['created_at'],
@ -312,6 +312,10 @@ def create_task_execution(values):
return IMPL.create_task_execution(values)
def insert_or_ignore_task_execution(values):
return IMPL.insert_or_ignore_task_execution(values)
def update_task_execution(id, values):
return IMPL.update_task_execution(id, values)

View File

@ -319,9 +319,16 @@ def insert_or_ignore(model_cls, values, session=None):
replace_string=replace
)
res = session.execute(insert, model.to_dict())
return res.rowcount
# NOTE(rakhmerov): As it turned out, the result proxy object
# returned by the insert expression does not provide a valid
# count of affected rows for all supported databases.
# For this reason we shouldn't return anything from this
# method. In order to check whether a new object was really
# inserted, users should rely on a different approach. The
# simplest is to insert an object with an explicitly set id
# and then check whether an object with that id exists in the DB.
# The generated id must be unique for this to work.
session.execute(insert, model.to_dict())
# Workbook definitions.
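For reference, a conflict-ignoring insert can be built along these
lines (a minimal standalone sketch using SQLAlchemy Core's
prefix_with(); the implementation above instead substitutes a
dialect-specific replace_string into the statement, but the effect is
the same):

    import sqlalchemy as sa

    def insert_or_ignore_sketch(table, values, engine):
        # Pick the dialect's conflict-ignoring INSERT variant.
        # (PostgreSQL has no such prefix; it needs ON CONFLICT
        # DO NOTHING instead.)
        prefix = {'sqlite': 'OR IGNORE', 'mysql': 'IGNORE'}.get(
            engine.dialect.name
        )

        stmt = sa.insert(table)

        if prefix:
            stmt = stmt.prefix_with(prefix)

        with engine.begin() as conn:
            conn.execute(stmt, values)

        # Deliberately no return value: rowcount is unreliable here.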
@ -818,6 +825,10 @@ def create_task_execution(values, session=None):
return task_ex
def insert_or_ignore_task_execution(values):
insert_or_ignore(models.TaskExecution, values.copy())
@b.session_aware()
def update_task_execution(id, values, session=None):
task_ex = get_task_execution(id)
@ -869,9 +880,7 @@ def create_delayed_call(values, session=None):
def insert_or_ignore_delayed_call(values):
row_count = insert_or_ignore(models.DelayedCall, values.copy())
return row_count
insert_or_ignore(models.DelayedCall, values.copy())
@b.session_aware()

View File

@ -165,10 +165,12 @@ class TaskExecution(Execution):
sa.Index('%s_scope' % __tablename__, 'scope'),
sa.Index('%s_state' % __tablename__, 'state'),
sa.Index('%s_updated_at' % __tablename__, 'updated_at'),
sa.UniqueConstraint('unique_key')
)
# Main properties.
action_spec = sa.Column(st.JsonLongDictType())
unique_key = sa.Column(sa.String(80), nullable=True)
# Whether the task is fully processed (publishing and calculating commands
# after it). It allows to simplify workflow controller implementations
@ -323,7 +325,7 @@ class DelayedCall(mb.MistralModelBase):
target_method_name = sa.Column(sa.String(80), nullable=False)
method_arguments = sa.Column(st.JsonDictType())
serializers = sa.Column(st.JsonDictType())
unique_key = sa.Column(sa.String(50), nullable=True)
unique_key = sa.Column(sa.String(80), nullable=True)
auth_context = sa.Column(st.JsonDictType())
execution_time = sa.Column(sa.DateTime, nullable=False)
processing = sa.Column(sa.Boolean, default=False, nullable=False)
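The unique constraint is what makes the insert-or-ignore
deduplication work at the database level. Note that NULL keys never
conflict under SQL UNIQUE semantics, which is why the column can stay
nullable for tasks that don't need deduplication. A self-contained
illustration (plain sqlite3; the table layout is reduced to the two
relevant columns):

    import sqlite3

    conn = sqlite3.connect(':memory:')

    conn.execute('CREATE TABLE task_executions '
                 '(id TEXT PRIMARY KEY, unique_key TEXT UNIQUE)')

    # Two engine instances race to create the same 'join' task.
    conn.execute("INSERT OR IGNORE INTO task_executions "
                 "VALUES ('uuid-1', 'join-task-wf1-join_task')")
    conn.execute("INSERT OR IGNORE INTO task_executions "
                 "VALUES ('uuid-2', 'join-task-wf1-join_task')")

    # Only the first insert wins; the duplicate is silently dropped.
    rows = conn.execute('SELECT id FROM task_executions').fetchall()
    assert rows == [('uuid-1',)]

    # The loser detects this by failing to load its own generated id.
    assert conn.execute("SELECT 1 FROM task_executions "
                        "WHERE id = 'uuid-2'").fetchone() is None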

View File

@ -19,10 +19,13 @@ from oslo_log import log as logging
from osprofiler import profiler
import traceback as tb
from mistral.db.v2 import api as db_api
from mistral.engine import tasks
from mistral.engine import workflow_handler as wf_handler
from mistral import exceptions as exc
from mistral.services import scheduler
from mistral.workbook import parser as spec_parser
from mistral.workflow import base as wf_base
from mistral.workflow import commands as wf_cmds
from mistral.workflow import states
@ -31,6 +34,10 @@ from mistral.workflow import states
LOG = logging.getLogger(__name__)
_CHECK_TASK_START_ALLOWED_PATH = (
'mistral.engine.task_handler._check_task_start_allowed'
)
@profiler.trace('task-handler-run-task')
def run_task(wf_cmd):
@ -60,6 +67,9 @@ def run_task(wf_cmd):
return
if task.is_waiting():
_schedule_check_task_start_allowed(task.task_ex)
if task.is_completed():
wf_handler.schedule_on_task_complete(task.task_ex)
@ -127,6 +137,8 @@ def continue_task(task_ex):
)
try:
task.set_state(states.RUNNING, None)
task.run()
except exc.MistralException as e:
wf_ex = task_ex.workflow_execution
@ -194,7 +206,8 @@ def _build_task_from_command(cmd):
cmd.wf_spec,
spec_parser.get_task_spec(cmd.task_ex.spec),
cmd.ctx,
cmd.task_ex
task_ex=cmd.task_ex,
unique_key=cmd.task_ex.unique_key
)
if cmd.reset:
@ -203,7 +216,13 @@ def _build_task_from_command(cmd):
return task
if isinstance(cmd, wf_cmds.RunTask):
task = _create_task(cmd.wf_ex, cmd.wf_spec, cmd.task_spec, cmd.ctx)
task = _create_task(
cmd.wf_ex,
cmd.wf_spec,
cmd.task_spec,
cmd.ctx,
unique_key=cmd.unique_key
)
if cmd.is_waiting():
task.defer()
@ -213,8 +232,69 @@ def _build_task_from_command(cmd):
raise exc.MistralError('Unsupported workflow command: %s' % cmd)
def _create_task(wf_ex, wf_spec, task_spec, ctx, task_ex=None):
def _create_task(wf_ex, wf_spec, task_spec, ctx, task_ex=None,
unique_key=None):
if task_spec.get_with_items():
return tasks.WithItemsTask(wf_ex, wf_spec, task_spec, ctx, task_ex)
return tasks.WithItemsTask(
wf_ex,
wf_spec,
task_spec,
ctx,
task_ex,
unique_key
)
return tasks.RegularTask(wf_ex, wf_spec, task_spec, ctx, task_ex)
return tasks.RegularTask(
wf_ex,
wf_spec,
task_spec,
ctx,
task_ex,
unique_key
)
def _check_task_start_allowed(task_ex_id):
with db_api.transaction():
task_ex = db_api.get_task_execution(task_ex_id)
wf_ctrl = wf_base.get_controller(
task_ex.workflow_execution,
spec_parser.get_workflow_spec_by_id(task_ex.workflow_id)
)
if wf_ctrl.is_task_start_allowed(task_ex):
continue_task(task_ex)
return
# TODO(rakhmerov): Algorithm for increasing rescheduling delay.
_schedule_check_task_start_allowed(task_ex, 1)
def _schedule_check_task_start_allowed(task_ex, delay=0):
"""Schedules task preconditions check.
This method provides transactional decoupling of task preconditions
check from events that can potentially satisfy those preconditions.
It's needed in non-locking model in order to avoid 'phantom read'
phenomena when reading state of multiple tasks to see if a task that
depends on them can start. Just starting a separate transaction
without using scheduler is not safe due to concurrency window that
we'll have in this case (time between transactions) whereas scheduler
is a special component that is designed to be resistant to failures.
:param task_ex: Task execution.
:param delay: Delay.
:return:
"""
key = 'th_c_t_s_a-%s' % task_ex.id
scheduler.schedule_call(
None,
_CHECK_TASK_START_ALLOWED_PATH,
delay,
unique_key=key,
task_ex_id=task_ex.id
)
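Assuming scheduler.schedule_call() feeds unique_key into the
insert_or_ignore_delayed_call() DB API shown above (the DelayedCall
model already carries a unique_key column for this), scheduling the
same check twice is harmless at the DB level:

    # Both calls map to a delayed call with the same unique_key
    # ('th_c_t_s_a-<task execution id>'), so the second insert hits
    # the unique constraint and is silently ignored: only one check
    # ends up being scheduled.
    _schedule_check_task_start_allowed(task_ex)
    _schedule_check_task_start_allowed(task_ex)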

View File

@ -47,15 +47,23 @@ class Task(object):
"""
@profiler.trace('task-create')
def __init__(self, wf_ex, wf_spec, task_spec, ctx, task_ex=None):
def __init__(self, wf_ex, wf_spec, task_spec, ctx, task_ex=None,
unique_key=None):
self.wf_ex = wf_ex
self.task_spec = task_spec
self.ctx = ctx
self.task_ex = task_ex
self.wf_spec = wf_spec
self.unique_key = unique_key
self.waiting = False
self.reset_flag = False
def is_completed(self):
return self.task_ex and states.is_completed(self.task_ex.state)
def is_waiting(self):
return self.waiting
@abc.abstractmethod
def on_action_complete(self, action_ex):
"""Handle action completion.
@ -160,7 +168,7 @@ class Task(object):
wf_ctrl = wf_base.get_controller(self.wf_ex, self.wf_spec)
# Calculate commands to process next.
cmds = wf_ctrl.continue_workflow()
cmds = wf_ctrl.continue_workflow(self.task_ex)
# Mark task as processed after all decisions have been made
# upon its completion.
@ -181,23 +189,39 @@ class Task(object):
p.after_task_complete(self.task_ex, self.task_spec)
def _create_task_execution(self, state=states.RUNNING):
self.task_ex = db_api.create_task_execution({
values = {
'id': utils.generate_unicode_uuid(),
'name': self.task_spec.get_name(),
'workflow_execution_id': self.wf_ex.id,
'workflow_name': self.wf_ex.workflow_name,
'workflow_id': self.wf_ex.workflow_id,
'state': state,
'spec': self.task_spec.to_dict(),
'unique_key': self.unique_key,
'in_context': self.ctx,
'published': {},
'runtime_context': {},
'project_id': self.wf_ex.project_id
})
}
db_api.insert_or_ignore_task_execution(values)
# Since 'insert_or_ignore' cannot return a valid count of affected
# rows, the only reliable way to check whether the insert operation
# has created an object is to try to load this object by the just
# generated uuid.
task_ex = db_api.load_task_execution(values['id'])
if not task_ex:
return False
self.task_ex = task_ex
# Add to collection explicitly so that it's in a proper
# state within the current session.
self.wf_ex.task_executions.append(self.task_ex)
return True
def _get_action_defaults(self):
action_name = self.task_spec.get_action_name()
@ -226,9 +250,6 @@ class RegularTask(Task):
self.complete(state, state_info)
def is_completed(self):
return self.task_ex and states.is_completed(self.task_ex.state)
@profiler.trace('task-run')
def run(self):
if not self.task_ex:
@ -237,20 +258,12 @@ class RegularTask(Task):
self._run_existing()
def _run_new(self):
# NOTE(xylan): Need to think how to get rid of this weird judgment
# to keep it more consistent with the function name.
self.task_ex = wf_utils.find_task_execution_with_state(
self.wf_ex,
self.task_spec,
states.WAITING
)
if not self._create_task_execution():
# Task with the same unique key has already been created.
return
if self.task_ex:
self.set_state(states.RUNNING, None)
self.task_ex.in_context = self.ctx
else:
self._create_task_execution()
if self.waiting:
return
LOG.debug(
'Starting task [workflow=%s, task_spec=%s, init_state=%s]' %
@ -278,10 +291,18 @@ class RegularTask(Task):
self.set_state(states.RUNNING, None, processed=False)
self._update_inbound_context()
self._reset_actions()
self._schedule_actions()
def _update_inbound_context(self):
assert self.task_ex
wf_ctrl = wf_base.get_controller(self.wf_ex, self.wf_spec)
self.ctx = wf_ctrl.get_task_inbound_context(self.task_spec)
self.task_ex.in_context = self.ctx
def _reset_actions(self):
"""Resets task state.
@ -386,7 +407,9 @@ class WithItemsTask(RegularTask):
if with_items.is_completed(self.task_ex):
state = with_items.get_final_state(self.task_ex)
self.complete(state, state_info[state])
return
if (with_items.has_more_iterations(self.task_ex)

View File

@ -282,6 +282,12 @@ class Workflow(object):
# tasks are waiting and there are unhandled errors, then these
# tasks will not reach completion. In this case, mark the
# workflow complete.
# TODO(rakhmerov): In the new non-locking transactional model
# tasks in WAITING state should not be taken into account.
# I'm actually surprised we've done this so far. The point is
# that every task should complete sooner or later and get into
# one of the terminal states. We should not be leaving any tasks
# in a non-terminal state if we know its real state exactly.
incomplete_tasks = wf_utils.find_incomplete_task_executions(self.wf_ex)
if any(not states.is_waiting(t.state) for t in incomplete_tasks):

View File

@ -46,13 +46,11 @@ class InsertOrIgnoreTest(test_base.DbTestCase):
)
def test_insert_or_ignore_without_conflicts(self):
row_count = db_api.insert_or_ignore(
db_api.insert_or_ignore(
db_models.DelayedCall,
DELAYED_CALL.copy()
)
self.assertEqual(1, row_count)
delayed_calls = db_api.get_delayed_calls()
self.assertEqual(1, len(delayed_calls))
@ -67,9 +65,7 @@ class InsertOrIgnoreTest(test_base.DbTestCase):
values['unique_key'] = 'key'
row_count = db_api.insert_or_ignore(db_models.DelayedCall, values)
self.assertEqual(1, row_count)
db_api.insert_or_ignore(db_models.DelayedCall, values)
delayed_calls = db_api.get_delayed_calls()
@ -85,9 +81,7 @@ class InsertOrIgnoreTest(test_base.DbTestCase):
values['unique_key'] = 'key'
row_count = db_api.insert_or_ignore(db_models.DelayedCall, values)
self.assertEqual(0, row_count)
db_api.insert_or_ignore(db_models.DelayedCall, values)
delayed_calls = db_api.get_delayed_calls()

View File

@ -321,18 +321,19 @@ class DirectWorkflowRerunTest(base.EngineTestCase):
self.await_workflow_error(wf_ex.id)
wf_ex = db_api.get_workflow_execution(wf_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(states.ERROR, wf_ex.state)
self.assertIsNotNone(wf_ex.state_info)
self.assertEqual(3, len(wf_ex.task_executions))
self.assertEqual(3, len(task_execs))
self.assertDictEqual(env, wf_ex.params['env'])
self.assertDictEqual(env, wf_ex.context['__env'])
task_exs = wf_ex.task_executions
task_10_ex = self._assert_single_item(task_exs, name='t10')
task_21_ex = self._assert_single_item(task_exs, name='t21')
task_30_ex = self._assert_single_item(task_exs, name='t30')
task_10_ex = self._assert_single_item(task_execs, name='t10')
task_21_ex = self._assert_single_item(task_execs, name='t21')
task_30_ex = self._assert_single_item(task_execs, name='t30')
self.assertEqual(states.SUCCESS, task_10_ex.state)
self.assertEqual(states.ERROR, task_21_ex.state)
@ -347,29 +348,31 @@ class DirectWorkflowRerunTest(base.EngineTestCase):
}
# Resume workflow and re-run failed task.
self.engine.rerun_workflow(task_21_ex.id, env=updated_env)
wf_ex = db_api.get_workflow_execution(wf_ex.id)
wf_ex = self.engine.rerun_workflow(task_21_ex.id, env=updated_env)
self.assertEqual(states.RUNNING, wf_ex.state)
self.assertIsNone(wf_ex.state_info)
self.assertDictEqual(updated_env, wf_ex.params['env'])
self.assertDictEqual(updated_env, wf_ex.context['__env'])
# Await t30 success.
self.await_task_success(task_30_ex.id)
# Wait for the workflow to succeed.
self.await_workflow_success(wf_ex.id)
wf_ex = db_api.get_workflow_execution(wf_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(states.SUCCESS, wf_ex.state)
self.assertIsNone(wf_ex.state_info)
self.assertEqual(4, len(wf_ex.task_executions))
self.assertEqual(4, len(task_execs))
task_exs = wf_ex.task_executions
task_10_ex = self._assert_single_item(task_exs, name='t10')
task_21_ex = self._assert_single_item(task_exs, name='t21')
task_22_ex = self._assert_single_item(task_exs, name='t22')
task_30_ex = self._assert_single_item(task_exs, name='t30')
task_10_ex = self._assert_single_item(task_execs, name='t10')
task_21_ex = self._assert_single_item(task_execs, name='t21')
task_22_ex = self._assert_single_item(task_execs, name='t22')
task_30_ex = self._assert_single_item(task_execs, name='t30')
# Check action executions of task 10.
self.assertEqual(states.SUCCESS, task_10_ex.state)
@ -867,79 +870,92 @@ class DirectWorkflowRerunTest(base.EngineTestCase):
# Run workflow and fail task.
wf_ex = self.engine.start_workflow('wb1.wf1', {})
wf_ex = db_api.get_workflow_execution(wf_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1')
task_2_ex = self._assert_single_item(wf_ex.task_executions, name='t2')
task_1_ex = self._assert_single_item(task_execs, name='t1')
task_2_ex = self._assert_single_item(task_execs, name='t2')
self.await_task_error(task_1_ex.id)
self.await_task_error(task_2_ex.id)
self.await_workflow_error(wf_ex.id)
wf_ex = db_api.get_workflow_execution(wf_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(states.ERROR, wf_ex.state)
self.assertIsNotNone(wf_ex.state_info)
self.assertEqual(2, len(wf_ex.task_executions))
self.assertEqual(2, len(task_execs))
task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1')
task_2_ex = self._assert_single_item(wf_ex.task_executions, name='t2')
task_1_ex = self._assert_single_item(task_execs, name='t1')
self.assertEqual(states.ERROR, task_1_ex.state)
self.assertIsNotNone(task_1_ex.state_info)
task_2_ex = self._assert_single_item(task_execs, name='t2')
self.assertEqual(states.ERROR, task_2_ex.state)
self.assertIsNotNone(task_2_ex.state_info)
# Resume workflow and re-run failed task.
self.engine.rerun_workflow(task_1_ex.id)
wf_ex = db_api.get_workflow_execution(wf_ex.id)
wf_ex = self.engine.rerun_workflow(task_1_ex.id)
self.assertEqual(states.RUNNING, wf_ex.state)
self.assertIsNone(wf_ex.state_info)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
# Wait for the task to succeed.
task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1')
task_1_ex = self._assert_single_item(task_execs, name='t1')
self.await_task_success(task_1_ex.id)
self.await_workflow_error(wf_ex.id)
wf_ex = db_api.get_workflow_execution(wf_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(states.ERROR, wf_ex.state)
self.assertIsNotNone(wf_ex.state_info)
self.assertEqual(3, len(wf_ex.task_executions))
self.assertEqual(3, len(task_execs))
task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1')
task_2_ex = self._assert_single_item(wf_ex.task_executions, name='t2')
task_3_ex = self._assert_single_item(wf_ex.task_executions, name='t3')
task_1_ex = self._assert_single_item(task_execs, name='t1')
task_2_ex = self._assert_single_item(task_execs, name='t2')
task_3_ex = self._assert_single_item(task_execs, name='t3')
self.assertEqual(states.SUCCESS, task_1_ex.state)
self.assertEqual(states.ERROR, task_2_ex.state)
self.assertEqual(states.WAITING, task_3_ex.state)
# Resume workflow and re-run failed task.
self.engine.rerun_workflow(task_2_ex.id)
wf_ex = db_api.get_workflow_execution(wf_ex.id)
wf_ex = self.engine.rerun_workflow(task_2_ex.id)
self.assertEqual(states.RUNNING, wf_ex.state)
self.assertIsNone(wf_ex.state_info)
# Join now should finally complete.
self.await_task_success(task_3_ex.id)
# Wait for the workflow to succeed.
self.await_workflow_success(wf_ex.id)
wf_ex = db_api.get_workflow_execution(wf_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(states.SUCCESS, wf_ex.state)
self.assertIsNone(wf_ex.state_info)
self.assertEqual(3, len(wf_ex.task_executions))
self.assertEqual(3, len(task_execs))
task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1')
task_2_ex = self._assert_single_item(wf_ex.task_executions, name='t2')
task_3_ex = self._assert_single_item(wf_ex.task_executions, name='t3')
task_1_ex = self._assert_single_item(task_execs, name='t1')
task_2_ex = self._assert_single_item(task_execs, name='t2')
task_3_ex = self._assert_single_item(task_execs, name='t3')
# Check action executions of task 1.
self.assertEqual(states.SUCCESS, task_1_ex.state)
@ -969,7 +985,7 @@ class DirectWorkflowRerunTest(base.EngineTestCase):
self.assertEqual(states.SUCCESS, task_3_ex.state)
task_3_action_exs = db_api.get_action_executions(
task_execution_id=wf_ex.task_executions[2].id
task_execution_id=task_execs[2].id
)
self.assertEqual(1, len(task_3_action_exs))

View File

@ -27,6 +27,30 @@ cfg.CONF.set_default('auth_enable', False, group='pecan')
class JoinEngineTest(base.EngineTestCase):
def test_full_join_simple(self):
wf_text = """---
version: '2.0'
wf:
type: direct
tasks:
join_task:
join: all
task1:
on-success: join_task
task2:
on-success: join_task
"""
wf_service.create_workflows(wf_text)
wf_ex = self.engine.start_workflow('wf', {})
self.await_workflow_success(wf_ex.id)
def test_full_join_without_errors(self):
wf_text = """---
version: '2.0'
@ -354,7 +378,12 @@ class JoinEngineTest(base.EngineTestCase):
result5 = task5.published['result5']
self.assertIsNotNone(result5)
self.assertEqual(2, result5.count('True'))
# Depending on how many inbound tasks completed before the 'join'
# task5 started, it can get a different inbound context.
# But at least two inbound results should be accessible at task5,
# which logically corresponds to the 'join' cardinality of 2.
self.assertTrue(result5.count('True') >= 2)
def test_discriminator(self):
wf_text = """---

View File

@ -86,28 +86,25 @@ class WorkflowController(object):
self.wf_spec = wf_spec
@profiler.trace('workflow-controller-continue-workflow')
def continue_workflow(self):
def continue_workflow(self, task_ex=None):
"""Calculates a list of commands to continue the workflow.
Given a workflow specification this method makes required analysis
according to this workflow type rules and identifies a list of
commands needed to continue the workflow.
:param task_ex: Task execution that caused the workflow continuation.
Optional. If not specified, it means that no specific task caused
this operation (e.g. the workflow has just been started or resumed
manually).
:return: List of workflow commands (instances of
mistral.workflow.commands.WorkflowCommand).
"""
# TODO(rakhmerov): We now use this method for two cases:
# 1) to handle task completion
# 2) to resume a workflow after it's been paused
# Moving forward we need to introduce a separate method for
# resuming a workflow because it won't be operating with
# any concrete tasks that caused this operation.
if self._is_paused_or_completed():
return []
return self._find_next_commands()
return self._find_next_commands(task_ex)
def rerun_tasks(self, task_execs, reset=True):
"""Gets commands to rerun existing task executions.
@ -128,11 +125,21 @@ class WorkflowController(object):
return cmds
@abc.abstractmethod
def is_task_start_allowed(self, task_ex):
"""Determines if the given task is allowed to start.
:param task_ex: Task execution.
:return: True if all preconditions are met and the given task
is allowed to start.
"""
raise NotImplementedError
@abc.abstractmethod
def is_error_handled_for(self, task_ex):
"""Determines if error is handled for specific task.
:param task_ex: Task execution perform a check for.
:param task_ex: Task execution.
:return: True if either there is no error at all or
error is considered handled.
"""
@ -162,7 +169,9 @@ class WorkflowController(object):
"""
raise NotImplementedError
def _get_task_inbound_context(self, task_spec):
def get_task_inbound_context(self, task_spec):
# TODO(rakhmerov): This method should also be able to work with
# task_ex to cover the 'split' (aka 'merge') use case.
upstream_task_execs = self._get_upstream_task_executions(task_spec)
upstream_ctx = data_flow.evaluate_upstream_context(upstream_task_execs)
@ -190,7 +199,7 @@ class WorkflowController(object):
raise NotImplementedError
@abc.abstractmethod
def _find_next_commands(self):
def _find_next_commands(self, task_ex):
"""Finds commands that should run next.
A concrete algorithm of finding such tasks depends on a concrete

View File

@ -48,11 +48,13 @@ class RunTask(WorkflowCommand):
super(RunTask, self).__init__(wf_ex, wf_spec, task_spec, ctx)
self.wait = False
self.unique_key = None
def is_waiting(self):
return (self.wait and
isinstance(self.task_spec, tasks.DirectWorkflowTaskSpec) and
self.task_spec.get_join())
return self.wait
def get_unique_key(self):
return self.unique_key
def __repr__(self):
return (
@ -74,6 +76,7 @@ class RunExistingTask(WorkflowCommand):
self.task_ex = task_ex
self.reset = reset
self.unique_key = task_ex.unique_key
class SetWorkflowState(WorkflowCommand):

View File

@ -65,16 +65,21 @@ class DirectWorkflowController(base.WorkflowController):
self.wf_spec.get_tasks()[t_ex_candidate.name]
)
def _find_next_commands(self):
cmds = super(DirectWorkflowController, self)._find_next_commands()
def _find_next_commands(self, task_ex=None):
cmds = super(DirectWorkflowController, self)._find_next_commands(
task_ex
)
if not self.wf_ex.task_executions:
return self._find_start_commands()
task_execs = [
t_ex for t_ex in self.wf_ex.task_executions
if states.is_completed(t_ex.state) and not t_ex.processed
]
if task_ex:
task_execs = [task_ex]
else:
task_execs = [
t_ex for t_ex in self.wf_ex.task_executions
if states.is_completed(t_ex.state) and not t_ex.processed
]
for t_ex in task_execs:
cmds.extend(self._find_next_commands_for_task(t_ex))
@ -87,7 +92,7 @@ class DirectWorkflowController(base.WorkflowController):
self.wf_ex,
self.wf_spec,
t_s,
self._get_task_inbound_context(t_s)
self.get_task_inbound_context(t_s)
)
for t_s in self.wf_spec.find_start_tasks()
]
@ -115,26 +120,33 @@ class DirectWorkflowController(base.WorkflowController):
self.wf_ex,
self.wf_spec,
t_s,
self._get_task_inbound_context(t_s),
data_flow.evaluate_task_outbound_context(task_ex),
params
)
# NOTE(xylan): Decide whether or not a join task should run
# immediately.
if self._is_unsatisfied_join(cmd):
cmd.wait = True
self._configure_if_join(cmd)
cmds.append(cmd)
# We need to remove all "join" tasks that have already started
# (or even completed) to prevent running "join" tasks more than
# once.
cmds = self._remove_started_joins(cmds)
LOG.debug("Found commands: %s" % cmds)
return cmds
def _configure_if_join(self, cmd):
if not isinstance(cmd, commands.RunTask):
return
if not cmd.task_spec.get_join():
return
cmd.unique_key = self._get_join_unique_key(cmd)
if self._is_unsatisfied_join(cmd.task_spec):
cmd.wait = True
def _get_join_unique_key(self, cmd):
return 'join-task-%s-%s' % (self.wf_ex.id, cmd.task_spec.get_name())
# TODO(rakhmerov): Need to refactor this method to be able to pass tasks
# whose contexts need to be merged.
def evaluate_workflow_final_context(self):
@ -148,6 +160,14 @@ class DirectWorkflowController(base.WorkflowController):
return ctx
def is_task_start_allowed(self, task_ex):
task_spec = self.wf_spec.get_tasks()[task_ex.name]
return (
not task_spec.get_join() or
not self._is_unsatisfied_join(task_spec)
)
def is_error_handled_for(self, task_ex):
return bool(self.wf_spec.get_on_error_clause(task_ex.name))
@ -241,28 +261,10 @@ class DirectWorkflowController(base.WorkflowController):
if not condition or expr.evaluate(condition, ctx)
]
def _remove_started_joins(self, cmds):
return list(
filter(lambda cmd: not self._is_started_join(cmd), cmds)
)
def _is_started_join(self, cmd):
if not (isinstance(cmd, commands.RunTask) and
cmd.task_spec.get_join()):
return False
return wf_utils.find_task_execution_not_state(
self.wf_ex,
cmd.task_spec,
states.WAITING
)
def _is_unsatisfied_join(self, cmd):
if not isinstance(cmd, commands.RunTask):
return False
task_spec = cmd.task_spec
def _is_unsatisfied_join(self, task_spec):
# TODO(rakhmerov): We need to use task_ex instead of task_spec
# in order to cover a use case when there's more than one instance
# of the same 'join' task in a workflow.
join_expr = task_spec.get_join()
if not join_expr:

View File

@ -40,13 +40,22 @@ class ReverseWorkflowController(base.WorkflowController):
__workflow_type__ = "reverse"
def _find_next_commands(self):
def _find_next_commands(self, task_ex=None):
"""Finds all tasks with resolved dependencies.
This method finds all tasks with resolved dependencies and
returns them in the form of workflow commands.
"""
cmds = super(ReverseWorkflowController, self)._find_next_commands()
cmds = super(ReverseWorkflowController, self)._find_next_commands(
task_ex
)
# TODO(rakhmerov): Adapt reverse workflow to the non-locking model.
# 1. Task search must use the task_ex parameter.
# 2. When a task has more than one dependency it's possible to
# run into the 'phantom read' phenomenon and create multiple
# instances of the same task. So 'unique_key' in conjunction with
# 'wait_flag' must be used to prevent this.
task_specs = self._find_task_specs_with_satisfied_dependencies()
@ -55,7 +64,7 @@ class ReverseWorkflowController(base.WorkflowController):
self.wf_ex,
self.wf_spec,
t_s,
self._get_task_inbound_context(t_s)
self.get_task_inbound_context(t_s)
)
for t_s in task_specs
]
@ -99,6 +108,10 @@ class ReverseWorkflowController(base.WorkflowController):
return data_flow.evaluate_task_outbound_context(task_execs[0])
def is_task_start_allowed(self, task_ex):
# TODO(rakhmerov): Implement.
return True
def is_error_handled_for(self, task_ex):
return task_ex.state != states.ERROR