mistral/mistral/workflow/direct_workflow.py

554 lines
18 KiB
Python

# Copyright 2015 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from oslo_log import log as logging
from osprofiler import profiler
from mistral.db.v2 import api as db_api
from mistral import exceptions as exc
from mistral import expressions as expr
from mistral.workflow import base
from mistral.workflow import commands
from mistral.workflow import data_flow
from mistral.workflow import states
from mistral_lib import utils
# Module-level logger used by the controller below.
LOG = logging.getLogger(__name__)
# Recursion limit used by DirectWorkflowController._find_all_parent_task_names()
# when collecting ancestor task names of a 'join' task (those names are then
# used to pre-load task executions into a cache). Bounding the depth also keeps
# that recursion finite when the workflow graph contains cycles.
MAX_SEARCH_DEPTH = 5
class DirectWorkflowController(base.WorkflowController):
    """'Direct workflow' controller.

    This handler implements the workflow pattern which is based on
    direct transitions between tasks, i.e. after each task completion
    a decision should be made which tasks should run next based on
    result of task execution.

    Note, that tasks can run in parallel. For example, if there's a workflow
    consisting of three tasks 'A', 'B' and 'C' where 'A' starts first then
    'B' and 'C' can start second if certain conditions associated with
    transitions 'A'->'B' and 'A'->'C' evaluate to true.
    """

    __workflow_type__ = "direct"

    def _get_upstream_task_executions(self, task_spec):
        """Returns completed executions of tasks preceding the given task.

        :param task_spec: Task specification whose upstream (inbound)
            task executions need to be found.
        :return: Task executions in SUCCESS, ERROR or CANCELLED state.
            For a non-'join' task these are the processed executions of
            its (first) inbound task. For a 'join' task these are the
            executions of any inbound task whose 'next_tasks' reference
            the given task.
        """
        t_specs_names = [
            t_spec.get_name()
            for t_spec in self.wf_spec.find_inbound_task_specs(task_spec)
        ]

        if not t_specs_names:
            return []

        completed_states = {
            'in': (states.SUCCESS, states.ERROR, states.CANCELLED)
        }

        if not task_spec.get_join():
            # NOTE(review): a non-join task is assumed to have just one
            # parent; if several tasks transition to it only the first
            # inbound task name is queried here.
            return self._get_task_executions(
                name=t_specs_names[0],
                state=completed_states,
                processed=True
            )

        t_execs_candidates = self._get_task_executions(
            name={'in': t_specs_names},
            state=completed_states,
        )

        task_name = task_spec.get_name()

        # 'next_tasks' holds (task name, event name) pairs. Guard against
        # an empty value the same way _get_induced_join_state() does.
        return [
            t_ex for t_ex in t_execs_candidates
            if any(t[0] == task_name for t in (t_ex.next_tasks or []))
        ]

    def _find_next_commands(self, task_ex=None):
        """Finds commands that should run next.

        :param task_ex: Optional task execution. If given, only this task
            execution is inspected for outgoing transitions; otherwise all
            completed task executions not yet processed are inspected.
            Commands returned by the base controller are always included,
            unless the workflow has no task executions at all, in which
            case only the start commands are returned.
        :return: List of workflow commands.
        """
        cmds = super(DirectWorkflowController, self)._find_next_commands(
            task_ex
        )

        # Checking if task_ex is empty is a serious optimization here
        # because 'self.wf_ex.task_executions' leads to initialization of
        # the entire collection which in case of highly-parallel workflows
        # may be very expensive.
        if not task_ex and not self.wf_ex.task_executions:
            return self._find_start_commands()

        if task_ex:
            task_execs = [task_ex]
        else:
            task_execs = [
                t_ex for t_ex in self.wf_ex.task_executions
                if states.is_completed(t_ex.state) and not t_ex.processed
            ]

        for t_ex in task_execs:
            cmds.extend(self._find_next_commands_for_task(t_ex))

        return cmds

    def _find_start_commands(self):
        """Returns a RunTask command for every start task of the workflow."""
        return [
            commands.RunTask(
                self.wf_ex,
                self.wf_spec,
                t_s,
                self.get_task_inbound_context(t_s)
            )
            for t_s in self.wf_spec.find_start_tasks()
        ]

    @profiler.trace(
        'direct-wf-controller-find-next-commands-for-task',
        hide_args=True
    )
    def _find_next_commands_for_task(self, task_ex):
        """Finds next commands based on the state of the given task.

        :param task_ex: Task execution for which next commands need
            to be found.
        :return: List of workflow commands.
        :raises WorkflowException: If a transition refers to a name that is
            neither a task of the workflow nor an engine command.
        """
        cmds = []

        ctx = data_flow.evaluate_task_outbound_context(task_ex)

        wf_tasks = self.wf_spec.get_tasks()

        for t_n, params, event_name in self._find_next_tasks(task_ex, ctx):
            t_s = wf_tasks[t_n]

            if not (t_s or t_n in commands.ENGINE_CMD_CLS):
                raise exc.WorkflowException("Task '%s' not found." % t_n)
            elif not t_s:
                # An engine command rather than a real task: associate it
                # with the spec of the task whose transition issued it.
                t_s = self.wf_spec.get_task(task_ex.name)

            triggered_by = [
                {
                    'task_id': task_ex.id,
                    'event': event_name
                }
            ]

            cmd = commands.create_command(
                t_n,
                self.wf_ex,
                self.wf_spec,
                t_s,
                ctx,
                params=params,
                triggered_by=triggered_by,
                handles_error=(event_name == 'on-error')
            )

            self._configure_if_join(cmd)

            cmds.append(cmd)

        LOG.debug("Found commands: %s", cmds)

        return cmds

    def _configure_if_join(self, cmd):
        """Configures a task-running command whose task is a 'join'.

        Only RunTask/RunExistingTask commands for tasks with a 'join'
        clause are touched: they get a unique key (see
        _get_join_unique_key()) and 'wait' set to True. Presumably this
        lets the engine deduplicate and defer the 'join' -- confirm
        against the command implementation.
        """
        if not isinstance(cmd, (commands.RunTask, commands.RunExistingTask)):
            return

        if not cmd.task_spec.get_join():
            return

        cmd.unique_key = self._get_join_unique_key(cmd)

        cmd.wait = True

    def _get_join_unique_key(self, cmd):
        """Builds a key unique per workflow execution and 'join' task name."""
        return 'join-task-%s-%s' % (self.wf_ex.id, cmd.task_spec.get_name())

    def rerun_tasks(self, task_execs, reset=True):
        """Delegates to the base controller and configures 'join' commands.

        :param task_execs: Task executions to rerun.
        :param reset: Passed through to the base implementation.
        :return: List of workflow commands.
        """
        cmds = super(DirectWorkflowController, self).rerun_tasks(
            task_execs,
            reset
        )

        for cmd in cmds:
            self._configure_if_join(cmd)

        return cmds

    # TODO(rakhmerov): Need to refactor this method to be able to pass tasks
    # whose contexts need to be merged.
    def evaluate_workflow_final_context(self):
        """Merges outbound contexts of the workflow's end task executions.

        End task executions are those returned by
        db_api.get_completed_task_executions_as_batches() with
        has_next_tasks=False. Contexts are merged in iteration order.

        :return: Merged context dict.
        """
        ctx = {}

        for batch in self._find_end_task_executions_as_batches():
            for t_ex in batch:
                ctx = utils.merge_dicts(
                    ctx,
                    data_flow.evaluate_task_outbound_context(t_ex)
                )

        return ctx

    def get_logical_task_state(self, task_ex):
        """Returns the logical state of the given task execution.

        :param task_ex: Task execution.
        :return: TaskLogicalState. For a non-'join' task it mirrors the
            real state; for a 'join' task it is computed from inbound tasks.
        """
        task_spec = self.wf_spec.get_tasks()[task_ex.name]

        if not task_spec.get_join():
            # A simple 'non-join' task does not have any preconditions
            # based on state of other tasks so its logical state always
            # equals to its real state.
            return base.TaskLogicalState(task_ex.state, task_ex.state_info)

        return self._get_join_logical_state(task_spec)

    def find_indirectly_affected_task_executions(self, t_name):
        """Finds executions of 'join' tasks reachable from the given task.

        The transition graph is walked forward from the given task. The walk
        stops at 'join' tasks (it does not continue past them), skips
        engine commands and visits every task at most once, so cycles
        are handled.

        :param t_name: Name of the task to start from.
        :return: Set of existing task executions of the reached 'join'
            tasks (loaded with the 'id' and 'name' fields only).
        """
        wf_tasks = self.wf_spec.get_tasks()

        all_joins = {
            task_spec.get_name()
            for task_spec in wf_tasks
            if task_spec.get_join()
        }

        t_execs_cache = {
            t_ex.name: t_ex for t_ex in self._get_task_executions(
                fields=('id', 'name'),
                name={'in': all_joins}
            )
        } if all_joins else {}

        visited_task_names = {t_name}

        # Copy the result: the set is mutated below via pop()/update() and
        # must not alias any data the spec might keep.
        clauses = set(self.wf_spec.find_outbound_task_names(t_name))

        res = set()

        while clauses:
            t_name = clauses.pop()

            # Handle cycles.
            if t_name in visited_task_names:
                continue

            visited_task_names.add(t_name)

            # Encountered an engine command.
            if not wf_tasks[t_name]:
                continue

            if t_name in all_joins:
                if t_name in t_execs_cache:
                    res.add(t_execs_cache[t_name])

                continue

            clauses.update(self.wf_spec.find_outbound_task_names(t_name))

        return res

    def is_error_handled_for(self, task_ex):
        """Returns True if the task's spec has a non-empty 'on-error' clause.

        :param task_ex: Task execution.
        """
        # TODO(rakhmerov): The method works in a different way than
        # all_errors_handled(). It doesn't evaluate expressions under
        # "on-error" clause.
        return bool(self.wf_spec.get_on_error_clause(task_ex.name))

    def all_errors_handled(self):
        """Returns True if no ERROR task execution has an unhandled error."""
        cnt = db_api.get_task_executions_count(
            workflow_execution_id=self.wf_ex.id,
            state=states.ERROR,
            error_handled=False
        )

        return cnt == 0

    def _find_end_task_executions_as_batches(self):
        """Yields batches of completed task executions with no next tasks."""
        batches = db_api.get_completed_task_executions_as_batches(
            workflow_execution_id=self.wf_ex.id,
            has_next_tasks=False
        )

        for batch in batches:
            yield batch

    def may_complete_workflow(self, task_ex):
        """Like the base check, but False if the task has next tasks."""
        res = super(DirectWorkflowController, self).may_complete_workflow(
            task_ex
        )

        return res and not task_ex.has_next_tasks

    @profiler.trace('direct-wf-controller-find-next-tasks', hide_args=True)
    def _find_next_tasks(self, task_ex, ctx):
        """Finds transitions to follow after the given task execution.

        Clauses are considered in this order: 'on-error' (if the task is in
        ERROR), 'on-success' (if in SUCCESS) and 'on-complete' (if in any
        completed state except CANCELLED). A transition is taken when it
        has no condition or its condition evaluates to true; its 'params'
        are then evaluated recursively as expressions.

        :param task_ex: Task execution.
        :param ctx: Outbound context of the task execution.
        :return: List of tuples (task_name, params, event_name) where
            event_name is 'on-error', 'on-success' or 'on-complete'.
        """
        t_n = task_ex.name
        t_s = task_ex.state

        ctx_view = data_flow.ContextView(
            data_flow.get_current_task_dict(task_ex),
            ctx,
            data_flow.get_workflow_environment_dict(self.wf_ex),
            self.wf_ex.context,
            self.wf_ex.input
        )

        # (event name, clause getter) pairs in evaluation order.
        clause_getters = []

        if t_s == states.ERROR:
            clause_getters.append(
                ('on-error', self.wf_spec.get_on_error_clause)
            )

        if t_s == states.SUCCESS:
            clause_getters.append(
                ('on-success', self.wf_spec.get_on_success_clause)
            )

        if states.is_completed(t_s) and not states.is_cancelled(t_s):
            clause_getters.append(
                ('on-complete', self.wf_spec.get_on_complete_clause)
            )

        # [(task_name, params, 'on-success'|'on-error'|'on-complete'), ...]
        result = []

        for event_name, get_clause in clause_getters:
            for name, cond, params in get_clause(t_n):
                if not cond or expr.evaluate(cond, ctx_view):
                    params = expr.evaluate_recursively(params, ctx_view)

                    result.append((name, params, event_name))

        return result

    @profiler.trace(
        'direct-wf-controller-get-join-logical-state',
        hide_args=True
    )
    def _get_join_logical_state(self, task_spec):
        """Evaluates logical state of 'join' task.

        :param task_spec: 'join' task specification.
        :return: TaskLogicalState (state, state_info, cardinality,
            triggered_by) where 'state' and 'state_info' describe the logical
            state of the given 'join' task and 'cardinality' gives the
            remaining number of unfulfilled preconditions. If logical state
            is not WAITING then 'cardinality' should always be 0.
        :raises RuntimeError: If the 'join' expression is neither an
            integer, 'one' nor 'all'.
        """
        # TODO(rakhmerov): We need to use task_ex instead of task_spec
        # in order to cover a use case when there's more than one instance
        # of the same 'join' task in a workflow.

        join_expr = task_spec.get_join()

        in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)

        if not in_task_specs:
            return base.TaskLogicalState(states.RUNNING)

        t_execs_cache = self._prepare_task_executions_cache(task_spec)

        # List of tuples (task_name, task_ex, state, depth, event_name).
        induced_states = []

        for t_s in in_task_specs:
            t_ex = t_execs_cache[t_s.get_name()]

            state, depth, event_name = self._get_induced_join_state(
                t_s,
                t_ex,
                task_spec,
                t_execs_cache
            )

            induced_states.append(
                (t_s.get_name(), t_ex, state, depth, event_name)
            )

        def _names_with(state):
            return [s[0] for s in induced_states if s[2] == state]

        def _count(state):
            return sum(1 for s in induced_states if s[2] == state)

        errors_cnt = _count(states.ERROR)
        runnings_cnt = _count(states.RUNNING)
        total_count = len(induced_states)

        def _blocked_message():
            return 'Blocked by tasks: %s' % _names_with(states.WAITING)

        def _failed_message():
            return 'Failed by tasks: %s' % _names_with(states.ERROR)

        def _triggered_by(state):
            return [
                {'task_id': s[1].id, 'event': s[4]}
                for s in induced_states
                if s[2] == state and s[1] is not None
            ]

        # If "join" is configured as a number or 'one'.
        if isinstance(join_expr, int) or join_expr == 'one':
            spec_cardinality = 1 if join_expr == 'one' else join_expr

            if runnings_cnt >= spec_cardinality:
                return base.TaskLogicalState(
                    states.RUNNING,
                    triggered_by=_triggered_by(states.RUNNING)
                )

            # E.g. 'join: 3' with inbound [ERROR, ERROR, RUNNING, WAITING]
            # No chance to get 3 RUNNING states.
            if errors_cnt > (total_count - spec_cardinality):
                # Report the failing inbound tasks, as the 'all' branch does.
                return base.TaskLogicalState(
                    states.ERROR,
                    _failed_message(),
                    triggered_by=_triggered_by(states.ERROR)
                )

            # Calculate how many tasks need to finish to trigger this 'join'.
            cardinality = spec_cardinality - runnings_cnt

            return base.TaskLogicalState(
                states.WAITING,
                _blocked_message(),
                cardinality=cardinality
            )

        if join_expr == 'all':
            if total_count == runnings_cnt:
                return base.TaskLogicalState(
                    states.RUNNING,
                    triggered_by=_triggered_by(states.RUNNING)
                )

            if errors_cnt > 0:
                return base.TaskLogicalState(
                    states.ERROR,
                    _failed_message(),
                    triggered_by=_triggered_by(states.ERROR)
                )

            # Remaining cardinality is just a difference between all tasks and
            # a number of those tasks that induce RUNNING state.
            cardinality = total_count - runnings_cnt

            return base.TaskLogicalState(
                states.WAITING,
                _blocked_message(),
                cardinality=cardinality
            )

        raise RuntimeError('Unexpected join expression: %s' % join_expr)

    # TODO(rakhmerov): Method signature is incorrect given that
    # we may have multiple task executions for a task. It should
    # accept inbound task execution rather than a spec.
    @profiler.trace(
        'direct-wf-controller-get-induced-join-state',
        hide_args=True
    )
    def _get_induced_join_state(self, in_task_spec, in_task_ex,
                                join_task_spec, t_execs_cache):
        """Computes the state an inbound task induces on a 'join' task.

        :param in_task_spec: Spec of the inbound task.
        :param in_task_ex: Execution of the inbound task, or None if it
            has none.
        :param join_task_spec: Spec of the 'join' task.
        :param t_execs_cache: Task name -> task execution (or None) cache;
            may be extended in place.
        :return: Tuple (state, depth, event_name):
            - WAITING if the inbound task isn't completed yet, or has no
              execution but may still be reached;
            - ERROR with 'impossible route' if it has no execution and can't
              be reached any more, or with "not triggered" if it completed
              without transitioning to the 'join' task;
            - RUNNING with the transition's event name otherwise.
        """
        join_task_name = join_task_spec.get_name()

        if not in_task_ex:
            possible, depth = self._possible_route(
                in_task_spec,
                t_execs_cache
            )

            if possible:
                return states.WAITING, depth, None
            else:
                return states.ERROR, depth, 'impossible route'

        if not states.is_completed(in_task_ex.state):
            return states.WAITING, 1, None

        # [(task name, event name), ...]
        next_tasks_tuples = in_task_ex.next_tasks or []

        next_tasks_dict = {tup[0]: tup[1] for tup in next_tasks_tuples}

        if join_task_name not in next_tasks_dict:
            return states.ERROR, 1, "not triggered"

        return states.RUNNING, 1, next_tasks_dict[join_task_name]

    def _possible_route(self, task_spec, t_execs_cache, depth=1,
                        visited=None):
        """Checks whether a task without an execution may still be reached.

        Inbound transitions are walked backwards. The route is possible if
        the task has no inbound tasks, or if some inbound task execution is
        not completed yet or has the task in its 'next_tasks'. Inbound
        tasks without an execution are searched recursively.

        :param task_spec: Spec of the task to check.
        :param t_execs_cache: Task name -> task execution (or None if the
            task has no execution). Extended in place when ancestors that
            are not cached yet are encountered.
        :param depth: Current search depth, starting at 1.
        :param visited: Names of tasks already explored during this search.
            It is shared across recursive calls so that cycles in the
            workflow graph don't cause endless recursion. Callers normally
            omit it.
        :return: Tuple (possible, depth).
        """
        if visited is None:
            visited = set()

        t_name = task_spec.get_name()

        visited.add(t_name)

        in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)

        if not in_task_specs:
            return True, depth

        for t_s in in_task_specs:
            in_name = t_s.get_name()

            if in_name not in t_execs_cache:
                # Load executions of more ancestors; the cache built for
                # 'task_spec' includes its direct parent 't_s'.
                t_execs_cache.update(
                    self._prepare_task_executions_cache(task_spec)
                )

            t_ex = t_execs_cache.get(in_name)

            if t_ex:
                if not states.is_completed(t_ex.state):
                    return True, depth

                if any(t[0] == t_name for t in (t_ex.next_tasks or [])):
                    return True, depth
            elif in_name not in visited:
                # Keep the current 'depth' intact so that siblings explored
                # after a failed branch are searched from the right level.
                possible, route_depth = self._possible_route(
                    t_s,
                    t_execs_cache,
                    depth + 1,
                    visited
                )

                if possible:
                    return True, route_depth

        return False, depth

    def _find_all_parent_task_names(self, task_spec, depth=1):
        """Collects names of ancestor tasks of the given task.

        Inbound transitions are followed recursively. The search stops at
        tasks without inbound tasks and at MAX_SEARCH_DEPTH; in both cases
        that task's own name is returned. On the top-level call (depth 1)
        the given task's own name is included only if it has no inbound
        tasks.

        :param task_spec: Task specification to start from.
        :param depth: Current recursion depth.
        :return: Set of task names.
        """
        if depth == MAX_SEARCH_DEPTH:
            return {task_spec.get_name()}

        in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)

        if not in_task_specs:
            return {task_spec.get_name()}

        names = set()

        for t_s in in_task_specs:
            names.update(self._find_all_parent_task_names(t_s, depth + 1))

        if depth > 1:
            names.add(task_spec.get_name())

        return names

    def _prepare_task_executions_cache(self, task_spec):
        """Builds a task name -> task execution cache for task's ancestors.

        Executions are loaded with the 'id', 'name', 'state' and
        'next_tasks' fields only. Ancestor names without an execution
        map to None.

        :param task_spec: Task specification whose ancestors to load.
        :return: Dict task name -> task execution or None.
        """
        names = self._find_all_parent_task_names(task_spec)

        t_execs_cache = {
            t_ex.name: t_ex for t_ex in self._get_task_executions(
                fields=('id', 'name', 'state', 'next_tasks'),
                name={'in': names}
            )
        } if names else {}  # don't perform a db request if 'names' are empty

        for name in names:
            t_execs_cache.setdefault(name, None)

        return t_execs_cache