deb-mistral/mistral/workflow/direct_workflow.py
Renat Akhmerov d05dba2dd5 Renaming _find_completed_tasks to _find_successful_tasks
* Renamed the method according to its semantics
* Added a couple of new utility methods to find completed tasks
  and error tasks

Change-Id: Ib900e9d5f1ae032774304a4bddd3db008b134b2a
2015-02-13 19:13:51 +06:00


# Copyright 2014 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mistral.engine1 import commands
from mistral import expressions as expr
from mistral.openstack.common import log as logging
from mistral import utils
from mistral.workflow import base
from mistral.workflow import data_flow
from mistral.workflow import states
from mistral.workflow import utils as wf_utils
LOG = logging.getLogger(__name__)


class DirectWorkflowHandler(base.WorkflowHandler):
"""'Direct workflow' handler.
This handler implements the workflow pattern which is based on
direct transitions between tasks, i.e. after each task completion
a decision should be made which tasks should run next based on
result of task execution.
Note, that tasks can run in parallel. For example, if there's a workflow
consisting of three tasks 'A', 'B' and 'C' where 'A' starts first then
'B' and 'C' can start second if certain associated with transition
'A'->'B' and 'A'->'C' evaluate to true.
"""
def start_workflow(self, **params):
self._set_execution_state(states.RUNNING)
return self._find_start_commands()
def get_upstream_tasks(self, task_spec):
# TODO(rakhmerov): Temporary solution, take transition conditions into account.
return self._find_inbound_task_specs(task_spec)
def _find_start_commands(self):
start_task_specs = []
for t_s in self.wf_spec.get_tasks():
if not self._has_inbound_transitions(t_s):
start_task_specs.append(t_s)
return [commands.RunTask(t_s) for t_s in start_task_specs]
def _has_inbound_transitions(self, task_spec):
for t_s in self.wf_spec.get_tasks():
if self._transition_exists(t_s.get_name(), task_spec.get_name()):
return True
return False
def _find_inbound_task_specs(self, task_spec):
return [
t_s for t_s in self.wf_spec.get_tasks()
if self._transition_exists(t_s.get_name(), task_spec.get_name())
]
def _find_outbound_task_specs(self, task_spec):
return [
t_s for t_s in self.wf_spec.get_tasks()
if self._transition_exists(task_spec.get_name(), t_s.get_name())
]
def _transition_exists(self, from_task_name, to_task_name):
t_names = set()
for tup in self.get_on_error_clause(from_task_name):
t_names.add(tup[0])
for tup in self.get_on_success_clause(from_task_name):
t_names.add(tup[0])
for tup in self.get_on_complete_clause(from_task_name):
t_names.add(tup[0])
return to_task_name in t_names
def _evaluate_workflow_final_context(self, cause_task_db):
ctx = {}
for t_db in self._find_end_tasks():
ctx = utils.merge_dicts(
ctx,
data_flow.evaluate_task_outbound_context(t_db)
)
return ctx
def _find_end_tasks(self):
return filter(
lambda t_db: not self._has_outbound_tasks(t_db),
wf_utils.find_successful_tasks(self.exec_db)
)
def _has_outbound_tasks(self, task_db):
t_specs = self._find_outbound_task_specs(
self.wf_spec.get_tasks()[task_db.name]
)
return any(
[wf_utils.find_db_task(self.exec_db, t_s) for t_s in t_specs]
)
def _find_next_commands(self, task_db, remove_unsatisfied_joins=True):
"""Finds commands that should run after completing given task.
Expression 'on_complete' is not mutually exclusive to 'on_success'
and 'on_error'.
:param task_db: Task DB model.
:param remove_unsatisfied_joins: True if incomplete "join"
tasks must be excluded from the list of commands.
:return: List of task specifications.
"""
cmds = []
t_name = task_db.name
t_state = task_db.state
ctx = data_flow.evaluate_task_outbound_context(task_db)
if states.is_completed(t_state):
on_complete = self.get_on_complete_clause(t_name)
if on_complete:
cmds += self._get_next_commands(on_complete, ctx)
if t_state == states.ERROR:
on_error = self.get_on_error_clause(t_name)
if on_error:
cmds += self._get_next_commands(on_error, ctx)
elif t_state == states.SUCCESS:
on_success = self.get_on_success_clause(t_name)
if on_success:
cmds += self._get_next_commands(on_success, ctx)
LOG.debug("Found commands: %s", cmds)
# We need to remove all "join" tasks that have already started
# (or even completed) to prevent running "join" tasks more than
# once.
cmds = self._remove_started_joins(cmds)
if remove_unsatisfied_joins:
return self._remove_unsatisfied_joins(cmds)
else:
return cmds
def _is_error_handled(self, task_db):
return self.get_on_error_clause(task_db.name)
@staticmethod
def _remove_task_from_clause(on_clause, t_name):
return filter(lambda tup: tup[0] != t_name, on_clause)
def get_on_error_clause(self, t_name):
result = self.wf_spec.get_tasks()[t_name].get_on_error()
if not result:
task_defaults = self.wf_spec.get_task_defaults()
if task_defaults:
result = self._remove_task_from_clause(
task_defaults.get_on_error(),
t_name
)
return result
def get_on_success_clause(self, t_name):
result = self.wf_spec.get_tasks()[t_name].get_on_success()
if not result:
task_defaults = self.wf_spec.get_task_defaults()
if task_defaults:
result = self._remove_task_from_clause(
task_defaults.get_on_success(),
t_name
)
return result
def get_on_complete_clause(self, t_name):
result = self.wf_spec.get_tasks()[t_name].get_on_complete()
if not result:
task_defaults = self.wf_spec.get_task_defaults()
if task_defaults:
result = self._remove_task_from_clause(
task_defaults.get_on_complete(),
t_name
)
return result
def _get_next_commands(self, cmd_conditions, ctx):
cmds = []
for t_name, condition in cmd_conditions:
if not condition or expr.evaluate(condition, ctx):
cmds.append(self._build_command(t_name))
return cmds
def _build_command(self, cmd_name):
cmd = commands.get_reserved_command(cmd_name)
return cmd or commands.RunTask(self.wf_spec.get_tasks()[cmd_name])
def _remove_started_joins(self, cmds):
return filter(lambda cmd: not self._is_started_join(cmd), cmds)
def _is_started_join(self, cmd):
if not isinstance(cmd, commands.RunTask):
return False
return (cmd.task_spec.get_join()
and wf_utils.find_db_task(self.exec_db, cmd.task_spec))
def _remove_unsatisfied_joins(self, cmds):
return filter(lambda cmd: not self._is_unsatisfied_join(cmd), cmds)
def _is_unsatisfied_join(self, cmd):
if not isinstance(cmd, commands.RunTask):
return False
task_spec = cmd.task_spec
join_expr = task_spec.get_join()
if not join_expr:
return False
in_task_specs = self._find_inbound_task_specs(task_spec)
if not in_task_specs:
return False
# Count the number of inbound transitions that have actually triggered
# this "join" task.
num = len([1 for in_t_s in in_task_specs
if self._triggers_join(task_spec, in_t_s)])
# If "join" is configured as a number.
if isinstance(join_expr, int) and num < join_expr:
return True
if join_expr == 'all' and len(in_task_specs) > num:
return True
if join_expr == 'one' and num == 0:
return True
return False
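# Illustrative sketch of the "join" semantics checked above (the task names
# and DSL layout are examples only): a task 'D' with three inbound
# transitions may declare
#
#   D:
#     join: all    # wait for all inbound transitions to trigger 'D'
#
# or "join: one" (a single triggering transition is enough), or an integer
# such as "join: 2" (any two triggering transitions). Until the requirement
# is met, the command for 'D' is treated as unsatisfied and filtered out.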
def _triggers_join(self, join_task_spec, inbound_task_spec):
in_t_db = wf_utils.find_db_task(self.exec_db, inbound_task_spec)
if not in_t_db or not states.is_completed(in_t_db.state):
return False
def is_join_task(cmd):
return (isinstance(cmd, commands.RunTask)
and cmd.task_spec == join_task_spec)
return filter(
lambda cmd: is_join_task(cmd),
self._find_next_commands(in_t_db, False)
)