Refactor task policies

* The purpose of this patch is to improve encapsulation of task
  execution state management. We already have the class Task
  (engine.tasks.Task) that represents an engine task and it is
  supposed to be responsible for everything related to managing
  persistent state of the corresponding task execution object.
  However, we break this encapsulation in many places and various
  modules manipulate task execution state directly. This fact
  leads to what is called "spaghetti code" because important
  things are often spread out across the system and it's hard to
  maintain. It also leads to lots of duplication. So this patch
  refactors policies so that they manipulate a task execution
  through an instance of Task, which hides low-level aspects.

Change-Id: Ie728bf950c4244db3fec0f3dadd5e195ad42081d
This commit is contained in:
Renat Akhmerov 2020-05-28 17:37:12 +07:00
parent c4c11c9b85
commit ddf9577785
12 changed files with 330 additions and 279 deletions

View File

@ -20,7 +20,8 @@ import jsonschema
import six
from mistral import exceptions as exc
from mistral.workflow import data_flow
from mistral import utils
from mistral_lib.utils import inspect_utils
@ -154,43 +155,21 @@ class TaskPolicy(object):
"""
_schema = {}
def before_task_start(self, task_ex, task_spec):
def before_task_start(self, task):
"""Called right before task start.
:param task_ex: DB model for task that is about to start.
:param task_spec: Task specification.
:param task: Engine task. Instance of engine.tasks.Task.
"""
wf_ex = task_ex.workflow_execution
ctx_view = data_flow.ContextView(
task_ex.in_context,
data_flow.get_current_task_dict(task_ex),
data_flow.get_workflow_environment_dict(wf_ex),
wf_ex.context,
wf_ex.input
)
data_flow.evaluate_object_fields(self, ctx_view)
utils.evaluate_object_fields(self, task.get_expression_context())
self._validate()
def after_task_complete(self, task_ex, task_spec):
def after_task_complete(self, task):
"""Called right after task completes.
:param task_ex: Completed task DB model.
:param task_spec: Completed task specification.
:param task: Engine task. Instance of engine.tasks.Task.
"""
wf_ex = task_ex.workflow_execution
ctx_view = data_flow.ContextView(
task_ex.in_context,
data_flow.get_current_task_dict(task_ex),
data_flow.get_workflow_environment_dict(wf_ex),
wf_ex.context,
wf_ex.input
)
data_flow.evaluate_object_fields(self, ctx_view)
utils.evaluate_object_fields(self, task.get_expression_context())
self._validate()

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import six
from mistral.db import utils as db_utils
from mistral.db.v2 import api as db_api
from mistral.engine import base
@ -24,7 +26,6 @@ from mistral.utils import wf_trace
from mistral.workflow import data_flow
from mistral.workflow import states
import six
_CONTINUE_TASK_PATH = 'mistral.engine.policies._continue_task'
_COMPLETE_TASK_PATH = 'mistral.engine.policies._complete_task'
@ -33,14 +34,6 @@ _FAIL_IF_INCOMPLETE_TASK_PATH = (
)
def _log_task_delay(task_ex, delay_sec, state=states.RUNNING_DELAYED):
wf_trace.info(
task_ex,
"Task '%s' [%s -> %s, delay = %s sec]" %
(task_ex.name, task_ex.state, state, delay_sec)
)
def build_policies(policies_spec, wf_spec):
task_defaults = wf_spec.get_task_defaults()
wf_policies = task_defaults.get_policies() if task_defaults else None
@ -184,51 +177,37 @@ class WaitBeforePolicy(base.TaskPolicy):
def __init__(self, delay):
self.delay = delay
def before_task_start(self, task_ex, task_spec):
super(WaitBeforePolicy, self).before_task_start(task_ex, task_spec)
def before_task_start(self, task):
super(WaitBeforePolicy, self).before_task_start(task)
# No need to wait for a task if delay is 0
if self.delay == 0:
return
context_key = 'wait_before_policy'
ctx_key = 'wait_before_policy'
runtime_context = _ensure_context_has_key(
task_ex.runtime_context,
context_key
)
policy_ctx = task.get_policy_context(ctx_key)
task_ex.runtime_context = runtime_context
policy_context = runtime_context[context_key]
if policy_context.get('skip'):
if policy_ctx.get('skip'):
# Unset state 'RUNNING_DELAYED'.
wf_trace.info(
task_ex,
"Task '%s' [%s -> %s]"
% (task_ex.name, states.RUNNING_DELAYED, states.RUNNING)
)
task_ex.state = states.RUNNING
task.set_state(states.RUNNING, None)
return
if task_ex.state != states.IDLE:
policy_context.update({'skip': True})
if task.get_state() != states.IDLE:
policy_ctx.update({'skip': True})
_log_task_delay(task_ex, self.delay)
task_ex.state = states.RUNNING_DELAYED
task.set_state(
states.RUNNING_DELAYED,
"Delayed by 'wait-before' policy [delay=%s]" % self.delay
)
sched = sched_base.get_system_scheduler()
job = sched_base.SchedulerJob(
run_after=self.delay,
func_name=_CONTINUE_TASK_PATH,
func_args={
'task_ex_id': task_ex.id
}
func_args={'task_ex_id': task.get_id()}
)
sched.schedule(job)
@ -247,41 +226,30 @@ class WaitAfterPolicy(base.TaskPolicy):
def __init__(self, delay):
self.delay = delay
def after_task_complete(self, task_ex, task_spec):
super(WaitAfterPolicy, self).after_task_complete(task_ex, task_spec)
def after_task_complete(self, task):
super(WaitAfterPolicy, self).after_task_complete(task)
# No need to postpone a task if delay is 0
if self.delay == 0:
return
context_key = 'wait_after_policy'
ctx_key = 'wait_after_policy'
runtime_context = _ensure_context_has_key(
task_ex.runtime_context,
context_key
)
policy_ctx = task.get_policy_context(ctx_key)
task_ex.runtime_context = runtime_context
policy_context = runtime_context[context_key]
if policy_context.get('skip'):
if policy_ctx.get('skip'):
# Skip, already processed.
return
policy_context.update({'skip': True})
policy_ctx.update({'skip': True})
_log_task_delay(task_ex, self.delay)
end_state = task.get_state()
end_state_info = task.get_state_info()
end_state = task_ex.state
end_state_info = task_ex.state_info
# TODO(rakhmerov): Policies probably need to have tasks.Task
# interface in order to manage task state safely.
# Set task state to 'RUNNING_DELAYED'.
task_ex.state = states.RUNNING_DELAYED
task_ex.state_info = (
'Suspended by wait-after policy for %s seconds' % self.delay
task.set_state(
states.RUNNING_DELAYED,
"Delayed by 'wait-after' policy [delay=%s]" % self.delay
)
# Schedule to change task state to RUNNING again.
@ -291,7 +259,7 @@ class WaitAfterPolicy(base.TaskPolicy):
run_after=self.delay,
func_name=_COMPLETE_TASK_PATH,
func_args={
'task_ex_id': task_ex.id,
'task_ex_id': task.get_id(),
'state': end_state,
'state_info': end_state_info
}
@ -320,7 +288,7 @@ class RetryPolicy(base.TaskPolicy):
self._break_on_clause = break_on
self._continue_on_clause = continue_on
def after_task_complete(self, task_ex, task_spec):
def after_task_complete(self, task):
"""Possible Cases:
1. state = SUCCESS
@ -336,7 +304,7 @@ class RetryPolicy(base.TaskPolicy):
3. retry:count = 5, current:count = 4, state = ERROR
Iterations complete therefore state = #{state}, current:count = 4.
"""
super(RetryPolicy, self).after_task_complete(task_ex, task_spec)
super(RetryPolicy, self).after_task_complete(task)
# There is nothing to repeat
if self.count == 0:
@ -347,53 +315,42 @@ class RetryPolicy(base.TaskPolicy):
# then the retry_no in the runtime_context of the task_ex will not
# be updated accurately. To be exact, the retry_no will be one
# iteration behind.
ex = task_ex.executions # noqa
ex = task.task_ex.executions # noqa
context_key = 'retry_task_policy'
ctx_key = 'retry_task_policy'
runtime_context = _ensure_context_has_key(
task_ex.runtime_context,
context_key
)
wf_ex = task_ex.workflow_execution
ctx_view = data_flow.ContextView(
data_flow.get_current_task_dict(task_ex),
data_flow.evaluate_task_outbound_context(task_ex),
wf_ex.context,
wf_ex.input
expr_ctx = task.get_expression_context(
ctx=data_flow.evaluate_task_outbound_context(task.task_ex)
)
continue_on_evaluation = expressions.evaluate(
self._continue_on_clause,
ctx_view
expr_ctx
)
break_on_evaluation = expressions.evaluate(
self._break_on_clause,
ctx_view
expr_ctx
)
task_ex.runtime_context = runtime_context
state = task_ex.state
state = task.get_state()
if not states.is_completed(state) or states.is_cancelled(state):
return
policy_context = runtime_context[context_key]
policy_ctx = task.get_policy_context(ctx_key)
retry_no = 0
if 'retry_no' in policy_context:
retry_no = policy_context['retry_no']
del policy_context['retry_no']
if 'retry_no' in policy_ctx:
retry_no = policy_ctx['retry_no']
del policy_ctx['retry_no']
retries_remain = retry_no < self.count
stop_continue_flag = (
task_ex.state == states.SUCCESS and
task.get_state() == states.SUCCESS and
not self._continue_on_clause
)
@ -403,55 +360,49 @@ class RetryPolicy(base.TaskPolicy):
)
break_triggered = (
task_ex.state == states.ERROR and
task.get_state() == states.ERROR and
break_on_evaluation
)
if not retries_remain or break_triggered or stop_continue_flag:
return
data_flow.invalidate_task_execution_result(task_ex)
task.invalidate_result()
policy_context['retry_no'] = retry_no + 1
runtime_context[context_key] = policy_context
policy_ctx['retry_no'] = retry_no + 1
task.touch_runtime_context()
# NOTE(vgvoleg): join tasks in direct workflows can't be
# retried as-is, because these tasks can't start without
# a correct logical state.
if hasattr(task_spec, "get_join") and task_spec.get_join():
if hasattr(task.task_spec, "get_join") and task.task_spec.get_join():
from mistral.engine import task_handler as t_h
_log_task_delay(task_ex, self.delay, states.WAITING)
task.set_state(
states.WAITING,
"Delayed by 'retry' policy [delay=%s]" % self.delay
)
task_ex.state = states.WAITING
t_h._schedule_refresh_task_state(task_ex.id, self.delay)
t_h._schedule_refresh_task_state(task.get_id(), self.delay)
return
_log_task_delay(task_ex, self.delay)
task_ex.state = states.RUNNING_DELAYED
task.set_state(
states.RUNNING_DELAYED,
"Delayed by 'retry' policy [delay=%s]" % self.delay
)
sched = sched_base.get_system_scheduler()
job = sched_base.SchedulerJob(
run_after=self.delay,
func_name=_CONTINUE_TASK_PATH,
func_args={'task_ex_id': task_ex.id}
func_args={'task_ex_id': task.get_id()}
)
sched.schedule(job)
@staticmethod
def refresh_runtime_context(task_ex):
runtime_context = task_ex.runtime_context or {}
retry_task_policy = runtime_context.get('retry_task_policy')
if retry_task_policy:
retry_task_policy['retry_no'] = 0
task_ex.runtime_context['retry_task_policy'] = retry_task_policy
class TimeoutPolicy(base.TaskPolicy):
_schema = {
@ -466,8 +417,8 @@ class TimeoutPolicy(base.TaskPolicy):
def __init__(self, timeout_sec):
self.delay = timeout_sec
def before_task_start(self, task_ex, task_spec):
super(TimeoutPolicy, self).before_task_start(task_ex, task_spec)
def before_task_start(self, task):
super(TimeoutPolicy, self).before_task_start(task)
# No timeout if delay is 0
if self.delay == 0:
@ -479,7 +430,7 @@ class TimeoutPolicy(base.TaskPolicy):
run_after=self.delay,
func_name=_FAIL_IF_INCOMPLETE_TASK_PATH,
func_args={
'task_ex_id': task_ex.id,
'task_ex_id': task.get_id(),
'timeout': self.delay
}
)
@ -487,9 +438,9 @@ class TimeoutPolicy(base.TaskPolicy):
sched.schedule(job)
wf_trace.info(
task_ex,
task.task_ex,
"Timeout check scheduled [task=%s, timeout(s)=%s]." %
(task_ex.id, self.delay)
(task.get_id(), self.delay)
)
@ -503,20 +454,25 @@ class PauseBeforePolicy(base.TaskPolicy):
def __init__(self, expression):
self.expr = expression
def before_task_start(self, task_ex, task_spec):
super(PauseBeforePolicy, self).before_task_start(task_ex, task_spec)
def before_task_start(self, task):
super(PauseBeforePolicy, self).before_task_start(task)
if not self.expr:
return
wf_trace.info(
task_ex,
task.task_ex,
"Workflow paused before task '%s' [%s -> %s]" %
(task_ex.name, task_ex.workflow_execution.state, states.PAUSED)
(
task.get_name(),
task.wf_ex.state,
states.PAUSED
)
)
task_ex.state = states.IDLE
wf_handler.pause_workflow(task_ex.workflow_execution)
task.set_state(states.IDLE, "Set by 'pause-before' policy")
wf_handler.pause_workflow(task.wf_ex)
class ConcurrencyPolicy(base.TaskPolicy):
@ -532,8 +488,8 @@ class ConcurrencyPolicy(base.TaskPolicy):
def __init__(self, concurrency):
self.concurrency = concurrency
def before_task_start(self, task_ex, task_spec):
super(ConcurrencyPolicy, self).before_task_start(task_ex, task_spec)
def before_task_start(self, task):
super(ConcurrencyPolicy, self).before_task_start(task)
if self.concurrency == 0:
return
@ -542,15 +498,9 @@ class ConcurrencyPolicy(base.TaskPolicy):
# property value and setting a variable into task runtime context.
# This variable is then used to define how many action executions
# may be started in parallel.
context_key = 'concurrency'
ctx_key = 'concurrency'
runtime_context = _ensure_context_has_key(
task_ex.runtime_context,
context_key
)
runtime_context[context_key] = self.concurrency
task_ex.runtime_context = runtime_context
task.set_runtime_context_value(ctx_key, self.concurrency)
class FailOnPolicy(base.TaskPolicy):
@ -563,18 +513,17 @@ class FailOnPolicy(base.TaskPolicy):
def __init__(self, fail_on):
self.fail_on = fail_on
def before_task_start(self, task_ex, task_spec):
def before_task_start(self, task):
pass
def after_task_complete(self, task_ex, task_spec):
if task_ex.state != states.SUCCESS:
def after_task_complete(self, task):
super(FailOnPolicy, self).after_task_complete(task)
if task.get_state() != states.SUCCESS:
return
super(FailOnPolicy, self).after_task_complete(task_ex, task_spec)
if self.fail_on:
task_ex.state = states.ERROR
task_ex.state_info = 'Failed by fail-on policy'
task.set_state(states.ERROR, "Failed by 'fail-on' policy")
@db_utils.retry_on_db_error

View File

@ -83,7 +83,7 @@ def run_task(wf_cmd):
def mark_task_running(task_ex, wf_spec):
task = _build_task_from_execution(wf_spec, task_ex)
task = build_task_from_execution(wf_spec, task_ex)
old_task_state = task_ex.state
@ -209,7 +209,7 @@ def force_fail_task(task_ex, msg, task=None):
task_ex.workflow_execution_id
)
task = _build_task_from_execution(wf_spec, task_ex)
task = build_task_from_execution(wf_spec, task_ex)
old_task_state = task_ex.state
task.set_state(states.ERROR, msg)
@ -228,7 +228,7 @@ def continue_task(task_ex):
task_ex.workflow_execution_id
)
task = _build_task_from_execution(wf_spec, task_ex)
task = build_task_from_execution(wf_spec, task_ex)
try:
task.set_state(states.RUNNING, None)
@ -257,7 +257,7 @@ def complete_task(task_ex, state, state_info):
task_ex.workflow_execution_id
)
task = _build_task_from_execution(wf_spec, task_ex)
task = build_task_from_execution(wf_spec, task_ex)
try:
task.complete(state, state_info)
@ -324,7 +324,7 @@ def _check_affected_tasks(task):
)
def _build_task_from_execution(wf_spec, task_ex):
def build_task_from_execution(wf_spec, task_ex):
return _create_task(
task_ex.workflow_execution,
wf_spec,

View File

@ -90,34 +90,46 @@ class Task(object):
if not filtered_publishers:
return
def _convert_to_notification_data():
return {
"id": self.task_ex.id,
"name": self.task_ex.name,
"workflow_execution_id": self.task_ex.workflow_execution_id,
"workflow_name": self.task_ex.workflow_name,
"workflow_namespace": self.task_ex.workflow_namespace,
"workflow_id": self.task_ex.workflow_id,
"state": self.task_ex.state,
"state_info": self.task_ex.state_info,
"type": self.task_ex.type,
"project_id": self.task_ex.project_id,
"created_at": utils.datetime_to_str(self.task_ex.created_at),
"updated_at": utils.datetime_to_str(self.task_ex.updated_at),
"started_at": utils.datetime_to_str(self.task_ex.started_at),
"finished_at": utils.datetime_to_str(self.task_ex.finished_at)
}
data = {
"id": self.task_ex.id,
"name": self.task_ex.name,
"workflow_execution_id": self.task_ex.workflow_execution_id,
"workflow_name": self.task_ex.workflow_name,
"workflow_namespace": self.task_ex.workflow_namespace,
"workflow_id": self.task_ex.workflow_id,
"state": self.task_ex.state,
"state_info": self.task_ex.state_info,
"type": self.task_ex.type,
"project_id": self.task_ex.project_id,
"created_at": utils.datetime_to_str(self.task_ex.created_at),
"updated_at": utils.datetime_to_str(self.task_ex.updated_at),
"started_at": utils.datetime_to_str(self.task_ex.started_at),
"finished_at": utils.datetime_to_str(self.task_ex.finished_at)
}
def _send_notification():
notifier.notify(
self.task_ex.id,
_convert_to_notification_data(),
data,
event,
self.task_ex.updated_at,
filtered_publishers
)
post_tx_queue.register_operation(_send_notification)
def get_id(self):
return self.task_ex.id if self.task_ex else None
def get_name(self):
return self.task_ex.name if self.task_ex else None
def get_state(self):
return self.task_ex.state if self.task_ex else None
def get_state_info(self):
return self.task_ex.state_info if self.task_ex else None
def is_completed(self):
return self.task_ex and states.is_completed(self.task_ex.state)
@ -130,6 +142,85 @@ class Task(object):
def is_state_changed(self):
return self.state_changed
def get_expression_context(self, ctx=None):
assert self.task_ex
return data_flow.ContextView(
data_flow.get_current_task_dict(self.task_ex),
data_flow.get_workflow_environment_dict(self.wf_ex),
ctx or {},
self.task_ex.in_context,
self.wf_ex.context,
self.wf_ex.input,
)
def evaluate(self, data, ctx=None):
"""Evaluates data against the task context.
The data is evaluated against the standard task context that includes
all standard components like workflow environment, workflow input etc.
However, if needed, the method also takes an additional context that
can be virtually merged with the standard one.
:param data: Data (a string, dict or a list) that possibly contain
YAQL/Jinja expressions to evaluate.
:param ctx: Additional context.
:return:
"""
return expr.evaluate_recursively(
data,
self.get_expression_context(ctx)
)
def set_runtime_context_value(self, key, value):
assert self.task_ex
if not self.task_ex.runtime_context:
self.task_ex.runtime_context = {}
self.task_ex.runtime_context[key] = value
def touch_runtime_context(self):
"""Must be called after any update of the runtime context.
The ORM framework can't trace updates happening within
deep structures like dictionaries. So every time we update
something inside "runtime_context" of the task execution
we need to call this method that does a fake update of the
field.
"""
runtime_ctx = self.task_ex.runtime_context
if runtime_ctx:
random_key = next(iter(runtime_ctx.keys()))
runtime_ctx[random_key] = runtime_ctx[random_key]
def cleanup_runtime_context(self):
runtime_context = self.task_ex.runtime_context
if runtime_context:
runtime_context.clear()
def get_policy_context(self, key):
assert self.task_ex
if not self.task_ex.runtime_context:
self.task_ex.runtime_context = {}
if key not in self.task_ex.runtime_context:
self.task_ex.runtime_context.update({key: {}})
return self.task_ex.runtime_context[key]
def invalidate_result(self):
if not self.task_ex:
return
for ex in self.task_ex.executions:
ex.accepted = False
@abc.abstractmethod
def on_action_complete(self, action_ex):
"""Handle action completion.
@ -396,29 +487,24 @@ class Task(object):
policies_spec = self.task_spec.get_policies()
for p in policies.build_policies(policies_spec, self.wf_spec):
p.before_task_start(self.task_ex, self.task_spec)
p.before_task_start(self)
def _after_task_complete(self):
policies_spec = self.task_spec.get_policies()
for p in policies.build_policies(policies_spec, self.wf_spec):
p.after_task_complete(self.task_ex, self.task_spec)
p.after_task_complete(self)
@profiler.trace('task-create-task-execution')
def _create_task_execution(self, state=states.RUNNING, state_info=None):
task_id = utils.generate_unicode_uuid()
task_name = self.task_spec.get_name()
task_type = self.task_spec.get_type()
task_tags = self.task_spec.get_tags()
values = {
'id': task_id,
'name': task_name,
'id': utils.generate_unicode_uuid(),
'name': self.task_spec.get_name(),
'workflow_execution_id': self.wf_ex.id,
'workflow_name': self.wf_ex.workflow_name,
'workflow_namespace': self.wf_ex.workflow_namespace,
'workflow_id': self.wf_ex.workflow_id,
'tags': task_tags,
'tags': self.task_spec.get_tags(),
'state': state,
'state_info': state_info,
'spec': self.task_spec.to_dict(),
@ -427,7 +513,7 @@ class Task(object):
'published': {},
'runtime_context': {},
'project_id': self.wf_ex.project_id,
'type': task_type
'type': self.task_spec.get_type()
}
if self.triggered_by:
@ -647,7 +733,7 @@ class RegularTask(Task):
input_spec = self.task_spec.get_input()
input_dict = (
self._evaluate_expression(input_spec, ctx) if input_spec else {}
self.evaluate(input_spec, ctx) if input_spec else {}
)
if not isinstance(input_dict, dict):
@ -663,18 +749,6 @@ class RegularTask(Task):
overwrite=False
)
def _evaluate_expression(self, expression, ctx=None):
ctx_view = data_flow.ContextView(
data_flow.get_current_task_dict(self.task_ex),
data_flow.get_workflow_environment_dict(self.wf_ex),
ctx or {},
self.task_ex.in_context,
self.wf_ex.context,
self.wf_ex.input,
)
return expr.evaluate_recursively(expression, ctx_view)
def _build_action(self):
action_name = self.task_spec.get_action_name()
wf_name = self.task_spec.get_workflow_name()
@ -682,13 +756,13 @@ class RegularTask(Task):
# For dynamic workflow evaluation we regenerate the action.
if wf_name:
return actions.WorkflowAction(
wf_name=self._evaluate_expression(wf_name),
wf_name=self.evaluate(wf_name),
task_ex=self.task_ex
)
# For dynamic action evaluation we just regenerate the name.
if action_name:
action_name = self._evaluate_expression(action_name)
action_name = self.evaluate(action_name)
if not action_name:
action_name = 'std.noop'
@ -701,9 +775,12 @@ class RegularTask(Task):
)
if action_def.spec:
return actions.AdHocAction(action_def, task_ex=self.task_ex,
task_ctx=self.ctx,
wf_ctx=self.wf_ex.context)
return actions.AdHocAction(
action_def,
task_ex=self.task_ex,
task_ctx=self.ctx,
wf_ctx=self.wf_ex.context
)
return actions.PythonAction(action_def, task_ex=self.task_ex)
@ -833,7 +910,7 @@ class WithItemsTask(RegularTask):
:return: Evaluated 'with-items' expression values.
"""
exp_res = self._evaluate_expression(self.task_spec.get_with_items())
exp_res = self.evaluate(self.task_spec.get_with_items())
# Expression result may contain iterables instead of lists in the
# dictionary values. So we need to convert them into lists and
@ -977,6 +1054,7 @@ class WithItemsTask(RegularTask):
indices += list(six.moves.range(max(candidates) + 1, count))
else:
i = self._get_next_start_index()
indices = list(six.moves.range(i, count))
return indices[:capacity]

View File

@ -202,6 +202,7 @@ def pause_workflow(wf_ex, msg=None):
# If any subworkflows failed to pause for temporary reason, this
# allows pause to be executed again on the main workflow.
wf = workflows.Workflow(wf_ex=wf_ex)
wf.pause(msg=msg)
@ -209,9 +210,14 @@ def rerun_workflow(wf_ex, task_ex, reset=True, env=None):
if wf_ex.state == states.PAUSED:
return wf_ex.get_clone()
# To break cyclic dependency.
from mistral.engine import task_handler
wf = workflows.Workflow(wf_ex=wf_ex)
wf.rerun(task_ex, reset=reset, env=env)
task = task_handler.build_task_from_execution(wf.wf_spec, task_ex)
wf.rerun(task, reset=reset, env=env)
_schedule_check_and_fix_integrity(
wf_ex,
@ -242,6 +248,7 @@ def resume_workflow(wf_ex, env=None):
# Resume current workflow here so to trigger continue workflow only
# after all other subworkflows are placed back in running state.
wf = workflows.Workflow(wf_ex=wf_ex)
wf.resume(env=env)

View File

@ -88,26 +88,25 @@ class Workflow(object):
if not filtered_publishers:
return
def _convert_to_notification_data():
return {
"id": self.wf_ex.id,
"name": self.wf_ex.name,
"workflow_name": self.wf_ex.workflow_name,
"workflow_namespace": self.wf_ex.workflow_namespace,
"workflow_id": self.wf_ex.workflow_id,
"state": self.wf_ex.state,
"state_info": self.wf_ex.state_info,
"project_id": self.wf_ex.project_id,
"task_execution_id": self.wf_ex.task_execution_id,
"root_execution_id": self.wf_ex.root_execution_id,
"created_at": utils.datetime_to_str(self.wf_ex.created_at),
"updated_at": utils.datetime_to_str(self.wf_ex.updated_at)
}
data = {
"id": self.wf_ex.id,
"name": self.wf_ex.name,
"workflow_name": self.wf_ex.workflow_name,
"workflow_namespace": self.wf_ex.workflow_namespace,
"workflow_id": self.wf_ex.workflow_id,
"state": self.wf_ex.state,
"state_info": self.wf_ex.state_info,
"project_id": self.wf_ex.project_id,
"task_execution_id": self.wf_ex.task_execution_id,
"root_execution_id": self.wf_ex.root_execution_id,
"created_at": utils.datetime_to_str(self.wf_ex.created_at),
"updated_at": utils.datetime_to_str(self.wf_ex.updated_at)
}
def _send_notification():
notifier.notify(
self.wf_ex.id,
_convert_to_notification_data(),
data,
event,
self.wf_ex.updated_at,
filtered_publishers
@ -248,10 +247,11 @@ class Workflow(object):
self.wf_spec.__class__.__name__
)
def rerun(self, task_ex, reset=True, env=None):
def rerun(self, task, reset=True, env=None):
"""Rerun workflow from the given task.
:param task_ex: Task execution that the workflow needs to rerun from.
:param task: An engine task associated with the task the workflow
needs to rerun from.
:param reset: If True, reset task state including deleting its action
executions.
:param env: Environment.
@ -266,13 +266,10 @@ class Workflow(object):
wf_ctrl = wf_base.get_controller(self.wf_ex)
# Calculate commands to process next.
cmds = wf_ctrl.rerun_tasks([task_ex], reset=reset)
cmds = wf_ctrl.rerun_tasks([task.task_ex], reset=reset)
if cmds:
# Import the task_handler module here to avoid circular reference.
from mistral.engine import policies
policies.RetryPolicy.refresh_runtime_context(task_ex)
task.cleanup_runtime_context()
self._continue_workflow(cmds)

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from eventlet import timeout
from unittest import mock
from oslo_config import cfg
@ -22,6 +21,7 @@ from mistral.actions import std_actions
from mistral.db.v2 import api as db_api
from mistral.db.v2.sqlalchemy import models
from mistral.engine import policies
from mistral.engine import tasks
from mistral import exceptions as exc
from mistral.lang import parser as spec_parser
from mistral.rpc import clients as rpc
@ -345,6 +345,7 @@ class PoliciesTest(base.EngineTestCase):
)
self.assertEqual(4, len(arr))
p = self._assert_single_item(arr, delay=2)
self.assertIsInstance(p, policies.WaitBeforePolicy)
@ -384,8 +385,10 @@ class PoliciesTest(base.EngineTestCase):
policy.delay = "<% $.int_var %>"
engine_task = tasks.RegularTask(wf_ex, None, None, {}, task_ex)
# Validation is ok.
policy.before_task_start(task_ex, None)
policy.before_task_start(engine_task)
policy.delay = "some_string"
@ -393,8 +396,7 @@ class PoliciesTest(base.EngineTestCase):
exception = self.assertRaises(
exc.InvalidModelException,
policy.before_task_start,
task_ex,
None
engine_task,
)
self.assertIn("Invalid data type in TaskPolicy", str(exception))
@ -597,6 +599,7 @@ class PoliciesTest(base.EngineTestCase):
pass
else:
self.fail("Shouldn't happen")
self.await_task_success(task_ex.id)
def test_wait_after_policy_negative_number(self):
@ -652,6 +655,7 @@ class PoliciesTest(base.EngineTestCase):
pass
else:
self.fail("Shouldn't happen")
self.await_task_success(task_ex.id)
def test_wait_after_policy_from_var_negative_number(self):
@ -1257,10 +1261,15 @@ class PoliciesTest(base.EngineTestCase):
with db_api.transaction():
# Note: We need to reread execution to access related tasks.
wf_ex = db_api.get_workflow_execution(wf_ex.id)
tasks = wf_ex.task_executions
self._assert_single_item(tasks, name="task2", state=states.ERROR)
self._assert_single_item(tasks, name="join_task", state=states.ERROR)
task_execs = wf_ex.task_executions
self._assert_single_item(task_execs, name="task2", state=states.ERROR)
self._assert_single_item(
task_execs,
name="join_task",
state=states.ERROR
)
def test_retry_join_task_after_idle_task(self):
retry_wb = """---
@ -1293,10 +1302,15 @@ class PoliciesTest(base.EngineTestCase):
with db_api.transaction():
# Note: We need to reread execution to access related tasks.
wf_ex = db_api.get_workflow_execution(wf_ex.id)
tasks = wf_ex.task_executions
self._assert_single_item(tasks, name="task2", state=states.ERROR)
self._assert_single_item(tasks, name="join_task", state=states.ERROR)
task_execs = wf_ex.task_executions
self._assert_single_item(task_execs, name="task2", state=states.ERROR)
self._assert_single_item(
task_execs,
name="join_task",
state=states.ERROR
)
@mock.patch.object(
std_actions.EchoAction,
@ -1327,6 +1341,7 @@ class PoliciesTest(base.EngineTestCase):
"""
wf_service.create_workflows(retry_wf)
wf_ex = self.engine.start_workflow('wf1')
self.await_workflow_success(wf_ex.id)
@ -1364,6 +1379,7 @@ class PoliciesTest(base.EngineTestCase):
"""
wf_service.create_workflows(retry_wf)
wf_ex = self.engine.start_workflow('repeated_retry')
self.await_workflow_running(wf_ex.id)
@ -1372,15 +1388,18 @@ class PoliciesTest(base.EngineTestCase):
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_ex = wf_ex.task_executions[0]
self.await_task_running(task_ex.id)
first_action_ex = task_ex.executions[0]
self.await_action_state(first_action_ex.id, states.RUNNING)
complete_action_params = (
first_action_ex.id,
ml_actions.Result(error="mock")
)
rpc.get_engine_client().on_action_complete(*complete_action_params)
for _ in range(2):
@ -1391,6 +1410,7 @@ class PoliciesTest(base.EngineTestCase):
)
self.await_task_running(task_ex.id)
with db_api.transaction():
task_ex = db_api.get_task_execution(task_ex.id)
action_exs = task_ex.executions
@ -1572,6 +1592,7 @@ class PoliciesTest(base.EngineTestCase):
def test_retry_with_input(self):
wf_text = """---
version: '2.0'
wf1:
tasks:
task1:
@ -1580,16 +1601,19 @@ class PoliciesTest(base.EngineTestCase):
- task2
publish:
success: task4
task2:
action: std.noop
on-success:
- task4: <% $.success = 'task4' %>
- task5: <% $.success = 'task5' %>
task4:
on-complete:
- task5
publish:
param: data
task5:
action: std.echo
input:
@ -1605,7 +1629,9 @@ class PoliciesTest(base.EngineTestCase):
delay: 1
"""
wf_service.create_workflows(wf_text)
wf_ex = self.engine.start_workflow('wf1')
self.await_workflow_success(wf_ex.id)
def test_action_timeout(self):
@ -1619,6 +1645,7 @@ class PoliciesTest(base.EngineTestCase):
"""
wf_service.create_workflows(wf_text)
wf_ex = self.engine.start_workflow('wf1')
with db_api.transaction():
@ -1626,10 +1653,9 @@ class PoliciesTest(base.EngineTestCase):
task_ex = wf_ex.task_executions[0]
action_ex = task_ex.action_executions[0]
with timeout.Timeout(8):
self.await_workflow_error(wf_ex.id)
self.await_task_error(task_ex.id)
self.await_action_error(action_ex.id)
self.await_workflow_error(wf_ex.id, delay=8)
self.await_task_error(task_ex.id, delay=8)
self.await_action_error(action_ex.id, delay=8)
def test_pause_before_policy(self):
wb_service.create_workbook_v2(PAUSE_BEFORE_WB)
@ -1881,10 +1907,11 @@ class PoliciesTest(base.EngineTestCase):
def test_retry_policy_break_on_with_dict(self, run_method):
run_method.return_value = types.Result(error={'key-1': 15})
wf_retry_break_on_with_dictionary = """---
wb_text = """---
version: '2.0'
name: wb
workflows:
wf1:
tasks:
@ -1896,7 +1923,7 @@ class PoliciesTest(base.EngineTestCase):
break-on: <% task().result['key-1'] = 15 %>
"""
wb_service.create_workbook_v2(wf_retry_break_on_with_dictionary)
wb_service.create_workbook_v2(wb_text)
# Start workflow.
wf_ex = self.engine.start_workflow('wb.wf1')
@ -1932,6 +1959,7 @@ class PoliciesTest(base.EngineTestCase):
# Start workflow.
wf_ex = self.engine.start_workflow('wb.wf1')
self.await_workflow_error(wf_ex.id)
with db_api.transaction():
@ -1957,6 +1985,7 @@ class PoliciesTest(base.EngineTestCase):
# Start workflow.
wf_ex = self.engine.start_workflow('wb.wf1')
self.await_workflow_success(wf_ex.id)
with db_api.transaction():
@ -1983,6 +2012,7 @@ class PoliciesTest(base.EngineTestCase):
# Start workflow.
wf_ex = self.engine.start_workflow('wb.wf1')
self.await_workflow_error(wf_ex.id)
with db_api.transaction():
@ -2016,6 +2046,7 @@ class PoliciesTest(base.EngineTestCase):
# Start workflow.
wf_ex = self.engine.start_workflow('wb.wf1')
self.await_workflow_success(wf_ex.id)
with db_api.transaction():
@ -2054,6 +2085,7 @@ class PoliciesTest(base.EngineTestCase):
# Start workflow.
wf_ex = self.engine.start_workflow('wb.wf1')
self.await_workflow_success(wf_ex.id)
with db_api.transaction():

View File

@ -609,7 +609,7 @@ class WithItemsEngineTest(base.EngineTestCase):
self.assertIn(task1_ex.published['result'], ['Guy'])
def test_with_items_concurrency_1(self):
wf_with_concurrency_1 = """---
wf_text = """---
version: "2.0"
wf:
@ -623,7 +623,7 @@ class WithItemsEngineTest(base.EngineTestCase):
concurrency: 1
"""
wf_service.create_workflows(wf_with_concurrency_1)
wf_service.create_workflows(wf_text)
# Start workflow.
wf_ex = self.engine.start_workflow('wf')

View File

@ -42,7 +42,6 @@ def log_event(ctx, ex_id, data, event, timestamp, **kwargs):
class NotifyEventsTest(base.NotifierTestCase):
def setUp(self):
super(NotifyEventsTest, self).setUp()
@ -57,12 +56,14 @@ class NotifyEventsTest(base.NotifierTestCase):
self.publishers['noop'].publish.reset_mock()
del EVENT_LOGS[:]
cfg.CONF.set_default('type', 'local', group='notifier')
def tearDown(self):
cfg.CONF.set_default('notify', None, group='notifier')
super(NotifyEventsTest, self).tearDown()
cfg.CONF.set_default('notify', None, group='notifier')
def test_notify_all_explicit(self):
wf_def = """
version: '2.0'
@ -93,6 +94,7 @@ class NotifyEventsTest(base.NotifierTestCase):
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_exs = wf_ex.task_executions
self.assertEqual(states.SUCCESS, wf_ex.state)

View File

@ -24,7 +24,11 @@ import threading
from oslo_concurrency import processutils
from oslo_serialization import jsonutils
from mistral_lib.utils import inspect_utils
from mistral import exceptions as exc
from mistral import expressions as expr
# Thread local storage.
_th_loc_storage = threading.local()
@ -152,3 +156,21 @@ def from_json_str(json_str):
return None
return jsonutils.loads(json_str)
def evaluate_object_fields(obj, ctx):
"""Evaluates all expressions recursively contained in the object fields.
Some of the given object fields may be strings or data structures that
contain YAQL/Jinja expressions. The method evaluates them and updates
the corresponding object fields with the evaluated values.
:param obj: The object to inspect.
:param ctx: Expression context.
"""
fields = inspect_utils.get_public_fields(obj)
evaluated_fields = expr.evaluate_recursively(fields, ctx)
for k, v in evaluated_fields.items():
setattr(obj, k, v)

View File

@ -19,6 +19,7 @@ from oslo_log import log as logging
from mistral.db.v2.sqlalchemy import models
cfg.CONF.import_opt('workflow_trace_log_name', 'mistral.config')
WF_TRACE = logging.getLogger(cfg.CONF.workflow_trace_log_name)

View File

@ -24,7 +24,6 @@ from mistral import expressions as expr
from mistral.lang import parser as spec_parser
from mistral.workflow import states
from mistral_lib import utils
from mistral_lib.utils import inspect_utils
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
@ -157,16 +156,10 @@ def _extract_execution_result(ex):
return ex.output['result']
def invalidate_task_execution_result(task_ex):
for ex in task_ex.executions:
ex.accepted = False
def get_task_execution_result(task_ex):
execs = task_ex.executions
execs.sort(
key=lambda x: x.runtime_context.get('index')
)
execs.sort(key=lambda x: x.runtime_context.get('index'))
results = [
_extract_execution_result(ex)
@ -327,15 +320,6 @@ def add_workflow_variables_to_context(wf_ex, wf_spec):
utils.merge_dicts(wf_ex.context, wf_vars)
def evaluate_object_fields(obj, context):
fields = inspect_utils.get_public_fields(obj)
evaluated_fields = expr.evaluate_recursively(fields, context)
for k, v in evaluated_fields.items():
setattr(obj, k, v)
def get_workflow_environment_dict(wf_ex):
if not wf_ex:
return {}