Optimize the logic that checks whether a 'join' task is allowed to start
* Moved DB-related lookup functions from workflow/utils to a separate module, lookup_utils
* Optimized the data access pattern used to calculate 'join' states induced by upstream tasks, and applied caching to some lookup operations
* Added caching of inbound and outbound task specs in the direct workflow specification, because calculating them may be expensive when the workflow is large
* Added an adaptive delay between calls that refresh the 'join' task state, based on the number of unfulfilled preconditions returned by the workflow controller

Change-Id: I383fa52f2f05877df7522048020cc7ff280324a2
This commit is contained in:
parent
a0f6c7ae3f
commit
9d06a61fe4
@ -73,7 +73,7 @@ def run_task(wf_cmd):
|
||||
return
|
||||
|
||||
if task.is_waiting() and (task.is_created() or task.is_state_changed()):
|
||||
_schedule_refresh_task_state(task.task_ex)
|
||||
_schedule_refresh_task_state(task.task_ex, 1)
|
||||
|
||||
|
||||
@profiler.trace('task-handler-on-action-complete')
|
||||
@ -251,15 +251,27 @@ def _refresh_task_state(task_ex_id):
|
||||
wf_spec
|
||||
)
|
||||
|
||||
state, state_info = wf_ctrl.get_logical_task_state(task_ex)
|
||||
state, state_info, cardinality = wf_ctrl.get_logical_task_state(
|
||||
task_ex
|
||||
)
|
||||
|
||||
if state == states.RUNNING:
|
||||
continue_task(task_ex)
|
||||
elif state == states.ERROR:
|
||||
fail_task(task_ex, state_info)
|
||||
elif state == states.WAITING:
|
||||
# TODO(rakhmerov): Algorithm for increasing rescheduling delay.
|
||||
_schedule_refresh_task_state(task_ex, 1)
|
||||
# Let's assume that a task takes 0.01 sec on average to complete
|
||||
# and based on this assumption calculate a time of the next check.
|
||||
# The estimation is very rough, of course, but this delay will be
|
||||
# decreasing as task preconditions will be completing which will
|
||||
# give a decent asymptotic approximation.
|
||||
# For example, if a 'join' task has 100 inbound incomplete tasks
|
||||
# then the next 'refresh_task_state' call will happen in 1
|
||||
# second. For 500 tasks it will be 5 seconds. The larger the
|
||||
# workflow is, the more beneficial this mechanism will be.
|
||||
delay = int(cardinality * 0.01)
|
||||
|
||||
_schedule_refresh_task_state(task_ex, max(1, delay))
|
||||
else:
|
||||
# Must never get here.
|
||||
raise RuntimeError(
|
||||
|
@ -34,6 +34,7 @@ from mistral.workbook import parser as spec_parser
|
||||
from mistral.workflow import base as wf_base
|
||||
from mistral.workflow import commands
|
||||
from mistral.workflow import data_flow
|
||||
from mistral.workflow import lookup_utils
|
||||
from mistral.workflow import states
|
||||
from mistral.workflow import utils as wf_utils
|
||||
|
||||
@ -158,6 +159,11 @@ class Workflow(object):
|
||||
|
||||
assert self.wf_ex
|
||||
|
||||
# Since some lookup utils functions may use cache for completed tasks
|
||||
# we need to clean caches to make sure that stale objects can't be
|
||||
# retrieved.
|
||||
lookup_utils.clean_caches()
|
||||
|
||||
wf_service.update_workflow_execution_env(self.wf_ex, env)
|
||||
|
||||
self.set_state(states.RUNNING, recursive=True)
|
||||
@ -429,7 +435,7 @@ def _build_fail_info_message(wf_ctrl, wf_ex):
|
||||
failed_tasks = sorted(
|
||||
filter(
|
||||
lambda t: not wf_ctrl.is_error_handled_for(t),
|
||||
wf_utils.find_error_task_executions(wf_ex)
|
||||
lookup_utils.find_error_task_executions(wf_ex.id)
|
||||
),
|
||||
key=lambda t: t.name
|
||||
)
|
||||
@ -468,7 +474,7 @@ def _build_fail_info_message(wf_ctrl, wf_ex):
|
||||
def _build_cancel_info_message(wf_ctrl, wf_ex):
|
||||
# Try to find where cancel is exactly.
|
||||
cancelled_tasks = sorted(
|
||||
wf_utils.find_cancelled_task_executions(wf_ex),
|
||||
lookup_utils.find_cancelled_task_executions(wf_ex.id),
|
||||
key=lambda t: t.name
|
||||
)
|
||||
|
||||
|
@ -36,6 +36,7 @@ from mistral.tests.unit import config as test_config
|
||||
from mistral.utils import inspect_utils as i_utils
|
||||
from mistral import version
|
||||
from mistral.workbook import parser as spec_parser
|
||||
from mistral.workflow import lookup_utils
|
||||
|
||||
RESOURCES_PATH = 'tests/resources/'
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -244,6 +245,8 @@ class DbTestCase(BaseTest):
|
||||
action_manager.sync_db()
|
||||
|
||||
def _clean_db(self):
|
||||
lookup_utils.clean_caches()
|
||||
|
||||
contexts = [
|
||||
get_context(default=False),
|
||||
get_context(default=True)
|
||||
|
@ -13,13 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from oslo_serialization import jsonutils
|
||||
from stevedore import extension
|
||||
import yaql
|
||||
|
||||
from mistral.db.v2 import api as db_api
|
||||
from mistral import utils
|
||||
from mistral.workflow import utils as wf_utils
|
||||
from oslo_serialization import jsonutils
|
||||
from stevedore import extension
|
||||
|
||||
|
||||
ROOT_CONTEXT = None
|
||||
@ -87,8 +86,6 @@ def task_(context, task_name):
|
||||
# Importing data_flow in order to break cycle dependency between modules.
|
||||
from mistral.workflow import data_flow
|
||||
|
||||
wf_ex = db_api.get_workflow_execution(context['__execution']['id'])
|
||||
|
||||
# This section may not exist in a context if it's calculated not in
|
||||
# task scope.
|
||||
cur_task = context['__task_execution']
|
||||
@ -96,7 +93,10 @@ def task_(context, task_name):
|
||||
if cur_task and cur_task['name'] == task_name:
|
||||
task_ex = db_api.get_task_execution(cur_task['id'])
|
||||
else:
|
||||
task_execs = wf_utils.find_task_executions_by_name(wf_ex, task_name)
|
||||
task_execs = db_api.get_task_executions(
|
||||
workflow_execution_id=context['__execution']['id'],
|
||||
name=task_name
|
||||
)
|
||||
|
||||
# TODO(rakhmerov): Account for multiple executions (i.e. in case of
|
||||
# cycles).
|
||||
|
@ -15,6 +15,7 @@
|
||||
|
||||
from oslo_utils import uuidutils
|
||||
import six
|
||||
import threading
|
||||
|
||||
from mistral import exceptions as exc
|
||||
from mistral import utils
|
||||
@ -150,6 +151,18 @@ class DirectWorkflowSpec(WorkflowSpec):
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self, data):
    """Builds the specification and prepares its transition caches.

    :param data: Specification data, handed over to the parent
        constructor unchanged.
    """
    super(DirectWorkflowSpec, self).__init__(data)

    # Plain dictionaries are used as caches for inbound and outbound
    # task specifications. No dedicated cache implementation is needed
    # because these structures can't grow indefinitely. Each cache is
    # paired with its own re-entrant lock.
    self.inbound_tasks_cache = {}
    self.inbound_tasks_cache_lock = threading.RLock()

    self.outbound_tasks_cache = {}
    self.outbound_tasks_cache_lock = threading.RLock()
|
||||
|
||||
def validate_semantics(self):
|
||||
super(DirectWorkflowSpec, self).validate_semantics()
|
||||
|
||||
@ -211,17 +224,43 @@ class DirectWorkflowSpec(WorkflowSpec):
|
||||
]
|
||||
|
||||
def find_inbound_task_specs(self, task_spec):
    """Finds specifications of tasks that have a transition to the task.

    The result is remembered in 'inbound_tasks_cache' under the task name
    so repeated calls for the same task return the same list object.

    :param task_spec: Task specification to find inbound tasks for.
    :return: List of task specifications (possibly a cached value).
    """
    target_name = task_spec.get_name()

    with self.inbound_tasks_cache_lock:
        cached = self.inbound_tasks_cache.get(target_name)

    # An empty list is a valid cached result, hence the explicit
    # comparison with None.
    if cached is not None:
        return cached

    # The lock is not held while the list is being calculated.
    inbound = []

    for candidate in self.get_tasks():
        if self.transition_exists(candidate.get_name(), target_name):
            inbound.append(candidate)

    with self.inbound_tasks_cache_lock:
        self.inbound_tasks_cache[target_name] = inbound

    return inbound
|
||||
|
||||
def find_outbound_task_specs(self, task_spec):
    """Finds specifications of tasks the given task has a transition to.

    The result is remembered in 'outbound_tasks_cache' under the task name
    so repeated calls for the same task return the same list object.

    :param task_spec: Task specification to find outbound tasks for.
    :return: List of task specifications (possibly a cached value).
    """
    source_name = task_spec.get_name()

    with self.outbound_tasks_cache_lock:
        cached = self.outbound_tasks_cache.get(source_name)

    # An empty list is a valid cached result, hence the explicit
    # comparison with None.
    if cached is not None:
        return cached

    # The lock is not held while the list is being calculated.
    outbound = []

    for candidate in self.get_tasks():
        if self.transition_exists(source_name, candidate.get_name()):
            outbound.append(candidate)

    with self.outbound_tasks_cache_lock:
        self.outbound_tasks_cache[source_name] = outbound

    return outbound
|
||||
|
||||
def has_inbound_transitions(self, task_spec):
|
||||
return len(self.find_inbound_task_specs(task_spec)) > 0
|
||||
|
||||
|
@ -26,13 +26,14 @@ from mistral import utils as u
|
||||
from mistral.workbook import parser as spec_parser
|
||||
from mistral.workflow import commands
|
||||
from mistral.workflow import data_flow
|
||||
from mistral.workflow import lookup_utils
|
||||
from mistral.workflow import states
|
||||
from mistral.workflow import utils as wf_utils
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@profiler.trace('wf-controller-get-controller')
|
||||
def get_controller(wf_ex, wf_spec=None):
|
||||
"""Gets a workflow controller instance by given workflow execution object.
|
||||
|
||||
@ -130,8 +131,13 @@ class WorkflowController(object):
|
||||
"""Determines a logical state of the given task.
|
||||
|
||||
:param task_ex: Task execution.
|
||||
:return: Tuple (state, state_info) which the given task should have
|
||||
according to workflow rules and current states of other tasks.
|
||||
:return: Tuple (state, state_info, cardinality) where 'state' and
|
||||
'state_info' are the corresponding values which the given
|
||||
task should have according to workflow rules and current
|
||||
states of other tasks. 'cardinality' gives the estimation on
|
||||
the number of preconditions that are not yet met in case if
|
||||
state is WAITING. This number can be used to estimate how
|
||||
frequently we can refresh the state of this task.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -159,7 +165,9 @@ class WorkflowController(object):
|
||||
|
||||
:return: True if there is one or more tasks in cancelled state.
|
||||
"""
|
||||
return len(wf_utils.find_cancelled_task_executions(self.wf_ex)) > 0
|
||||
t_execs = lookup_utils.find_cancelled_task_executions(self.wf_ex.id)
|
||||
|
||||
return len(t_execs) > 0
|
||||
|
||||
@abc.abstractmethod
|
||||
def evaluate_workflow_final_context(self):
|
||||
@ -214,8 +222,8 @@ class WorkflowController(object):
|
||||
return []
|
||||
|
||||
# Add all tasks in IDLE state.
|
||||
idle_tasks = wf_utils.find_task_executions_with_state(
|
||||
self.wf_ex,
|
||||
idle_tasks = lookup_utils.find_task_executions_with_state(
|
||||
self.wf_ex.id,
|
||||
states.IDLE
|
||||
)
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from oslo_log import log as logging
|
||||
from osprofiler import profiler
|
||||
|
||||
from mistral import exceptions as exc
|
||||
from mistral import expressions as expr
|
||||
@ -20,8 +21,8 @@ from mistral import utils
|
||||
from mistral.workflow import base
|
||||
from mistral.workflow import commands
|
||||
from mistral.workflow import data_flow
|
||||
from mistral.workflow import lookup_utils
|
||||
from mistral.workflow import states
|
||||
from mistral.workflow import utils as wf_utils
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -46,8 +47,8 @@ class DirectWorkflowController(base.WorkflowController):
|
||||
return list(
|
||||
filter(
|
||||
lambda t_e: self._is_upstream_task_execution(task_spec, t_e),
|
||||
wf_utils.find_task_executions_by_specs(
|
||||
self.wf_ex,
|
||||
lookup_utils.find_task_executions_by_specs(
|
||||
self.wf_ex.id,
|
||||
self.wf_spec.find_inbound_task_specs(task_spec)
|
||||
)
|
||||
)
|
||||
@ -60,7 +61,7 @@ class DirectWorkflowController(base.WorkflowController):
|
||||
if not t_spec.get_join():
|
||||
return not t_ex_candidate.processed
|
||||
|
||||
induced_state = self._get_induced_join_state(
|
||||
induced_state, _ = self._get_induced_join_state(
|
||||
self.wf_spec.get_tasks()[t_ex_candidate.name],
|
||||
t_spec
|
||||
)
|
||||
@ -173,7 +174,7 @@ class DirectWorkflowController(base.WorkflowController):
|
||||
# A simple 'non-join' task does not have any preconditions
|
||||
# based on state of other tasks so its logical state always
|
||||
# equals to its real state.
|
||||
return task_ex.state, task_ex.state_info
|
||||
return task_ex.state, task_ex.state_info, 0
|
||||
|
||||
return self._get_join_logical_state(task_spec)
|
||||
|
||||
@ -181,8 +182,7 @@ class DirectWorkflowController(base.WorkflowController):
|
||||
return bool(self.wf_spec.get_on_error_clause(task_ex.name))
|
||||
|
||||
def all_errors_handled(self):
|
||||
for t_ex in wf_utils.find_error_task_executions(self.wf_ex):
|
||||
|
||||
for t_ex in lookup_utils.find_error_task_executions(self.wf_ex.id):
|
||||
tasks_on_error = self._find_next_tasks_for_clause(
|
||||
self.wf_spec.get_on_error_clause(t_ex.name),
|
||||
data_flow.evaluate_task_outbound_context(t_ex)
|
||||
@ -197,7 +197,7 @@ class DirectWorkflowController(base.WorkflowController):
|
||||
return list(
|
||||
filter(
|
||||
lambda t_ex: not self._has_outbound_tasks(t_ex),
|
||||
wf_utils.find_successful_task_executions(self.wf_ex)
|
||||
lookup_utils.find_successful_task_executions(self.wf_ex.id)
|
||||
)
|
||||
)
|
||||
|
||||
@ -270,64 +270,94 @@ class DirectWorkflowController(base.WorkflowController):
|
||||
if not condition or expr.evaluate(condition, ctx)
|
||||
]
|
||||
|
||||
@profiler.trace('direct-wf-controller-get-join-logical-state')
|
||||
def _get_join_logical_state(self, task_spec):
    """Evaluates logical state of 'join' task.

    :param task_spec: 'join' task specification.
    :return: Tuple (state, state_info, cardinality) where 'state' and
        'state_info' describe the logical state of the given 'join'
        task and 'cardinality' gives the remaining number of
        unfulfilled preconditions. If logical state is not WAITING then
        'cardinality' is always 0.
    :raise RuntimeError: If the 'join' expression is not an integer,
        'one' or 'all'.
    """

    # TODO(rakhmerov): We need to use task_ex instead of task_spec
    # in order to cover a use case when there's more than one instance
    # of the same 'join' task in a workflow.
    # TODO(rakhmerov): In some cases this method will be expensive because
    # it uses a multistep recursive search. We need to optimize it moving
    # forward (e.g. with Workflow Execution Graph).

    join_expr = task_spec.get_join()

    in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)

    if not in_task_specs:
        return states.RUNNING, None, 0

    # List of tuples (task_name, (state, depth)).
    induced_states = [
        (t_s.get_name(), self._get_induced_join_state(t_s, task_spec))
        for t_s in in_task_specs
    ]

    def count(state):
        # Number of inbound tasks that induce the given state.
        return sum(1 for s in induced_states if s[1][0] == state)

    error_count = count(states.ERROR)
    running_count = count(states.RUNNING)
    total_count = len(induced_states)

    def _blocked_message():
        return (
            'Blocked by tasks: %s' %
            [s[0] for s in induced_states if s[1][0] == states.WAITING]
        )

    def _failed_message():
        return (
            'Failed by tasks: %s' %
            [s[0] for s in induced_states if s[1][0] == states.ERROR]
        )

    # If "join" is configured as a number or 'one'.
    if isinstance(join_expr, int) or join_expr == 'one':
        spec_cardinality = 1 if join_expr == 'one' else join_expr

        if running_count >= spec_cardinality:
            return states.RUNNING, None, 0

        # E.g. 'join: 3' with inbound [ERROR, ERROR, RUNNING, WAITING]
        # No chance to get 3 RUNNING states.
        if error_count > (total_count - spec_cardinality):
            return states.ERROR, _failed_message(), 0

        # Calculate how many tasks need to finish to trigger this 'join'.
        cardinality = spec_cardinality - running_count

        return states.WAITING, _blocked_message(), cardinality

    if join_expr == 'all':
        if total_count == running_count:
            return states.RUNNING, None, 0

        if error_count > 0:
            return states.ERROR, _failed_message(), 0

        # Remaining cardinality is just a difference between all tasks and
        # a number of those tasks that induce RUNNING state. The count is
        # used here (not the summed depth) so that this branch agrees with
        # the numeric 'join' branch above.
        cardinality = total_count - running_count

        return states.WAITING, _blocked_message(), cardinality

    raise RuntimeError('Unexpected join expression: %s' % join_expr)
|
||||
|
||||
@ -337,51 +367,54 @@ class DirectWorkflowController(base.WorkflowController):
|
||||
def _get_induced_join_state(self, inbound_task_spec, join_task_spec):
    """Calculates the state that an inbound task induces on a 'join' task.

    :param inbound_task_spec: Specification of the inbound task.
    :param join_task_spec: Specification of the 'join' task.
    :return: Tuple (state, depth). 'depth' is 1 when an execution of the
        inbound task exists, otherwise it is the depth returned by
        '_possible_route'.
    """
    join_task_name = join_task_spec.get_name()

    in_task_ex = self._find_task_execution_by_name(
        inbound_task_spec.get_name()
    )

    if in_task_ex:
        # The inbound task has started: the induced state depends only on
        # whether it completed and whether it routes to the 'join' task.
        if not states.is_completed(in_task_ex.state):
            return states.WAITING, 1

        if join_task_name in self._find_next_task_names(in_task_ex):
            return states.RUNNING, 1

        return states.ERROR, 1

    # No execution yet: the 'join' can only wait for this task if there
    # is still a route that leads to it.
    possible, depth = self._possible_route(inbound_task_spec)

    induced_state = states.WAITING if possible else states.ERROR

    return induced_state, depth
|
||||
|
||||
def _find_task_execution_by_name(self, t_name):
    """Finds a task execution of this workflow by task name.

    :param t_name: Task name.
    :return: The last of the task executions found for the name, or None
        if there are none.
    """
    # Note: in case of 'join' completion check it's better to initialize
    # the entire task_executions collection to avoid too many DB queries.
    found = lookup_utils.find_task_executions_by_name(
        self.wf_ex.id,
        t_name
    )

    if not found:
        return None

    # TODO(rakhmerov): Temporary hack. Only the last execution is taken
    # into account.
    return found[-1]
|
||||
|
||||
def _possible_route(self, task_spec, depth=1):
    """Checks whether the given task can still be reached.

    Inbound tasks are examined one by one. An inbound task that has no
    execution yet is checked recursively; an inbound task that has an
    execution makes the route possible if it is not completed or if it
    routes to the given task.

    :param task_spec: Specification of the task to check.
    :param depth: Depth of the current call in the recursive search,
        1 for the initial call.
    :return: Tuple (possible, depth). For a possible route 'depth' is the
        depth at which the search succeeded, otherwise it is the depth of
        this call.
    """
    in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)

    if not in_task_specs:
        return True, depth

    t_name = task_spec.get_name()

    for t_s in in_task_specs:
        t_ex = self._find_task_execution_by_name(t_s.get_name())

        if not t_ex:
            # The recursive call returns a tuple, and a non-empty tuple is
            # always truthy, so it must be unpacked before being tested.
            possible, route_depth = self._possible_route(t_s, depth + 1)

            if possible:
                return True, route_depth
        elif (not states.is_completed(t_ex.state) or
                t_name in self._find_next_task_names(t_ex)):
            return True, depth

    return False, depth
|
||||
|
109
mistral/workflow/lookup_utils.py
Normal file
109
mistral/workflow/lookup_utils.py
Normal file
@ -0,0 +1,109 @@
|
||||
# Copyright 2015 - Mirantis, Inc.
|
||||
# Copyright 2015 - StackStorm, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
The intention of the module is providing various DB related lookup functions
|
||||
for more convenient usage within the workflow engine.
|
||||
|
||||
Some of the functions may provide caching capabilities.
|
||||
|
||||
WARNING: Oftentimes, persistent objects returned by the methods in this
|
||||
module won't be attached to the current DB SQLAlchemy session because
|
||||
they are returned from the cache and therefore they need to be used
|
||||
carefully without trying to do any lazy loading etc.
|
||||
These objects are also not suitable for re-attaching them to a session
|
||||
in order to update their persistent DB state.
|
||||
Mostly, they are useful for doing any kind of fast lookups in order
|
||||
to make some decision based on their state.
|
||||
"""
|
||||
|
||||
import cachetools
|
||||
import threading
|
||||
|
||||
from mistral.db.v2 import api as db_api
|
||||
from mistral.workflow import states
|
||||
|
||||
_TASK_EXECUTIONS_CACHE_LOCK = threading.RLock()
|
||||
_TASK_EXECUTIONS_CACHE = cachetools.LRUCache(maxsize=20000)
|
||||
|
||||
|
||||
def find_task_executions_by_name(wf_ex_id, task_name):
    """Finds task executions by workflow execution id and task name.

    :param wf_ex_id: Workflow execution id.
    :param task_name: Task name.
    :return: Task executions (possibly a cached value).
    """
    key = (wf_ex_id, task_name)

    with _TASK_EXECUTIONS_CACHE_LOCK:
        cached = _TASK_EXECUTIONS_CACHE.get(key)

    if cached:
        return cached

    # Cache miss: the lock is not held while the DB is queried.
    fetched = db_api.get_task_executions(
        workflow_execution_id=wf_ex_id,
        name=task_name
    )

    # We can cache only finished tasks because they won't change.
    # An empty result is never cached.
    if fetched and all(states.is_completed(t.state) for t in fetched):
        with _TASK_EXECUTIONS_CACHE_LOCK:
            _TASK_EXECUTIONS_CACHE[key] = fetched

    return fetched
|
||||
|
||||
|
||||
def find_task_executions_by_spec(wf_ex_id, task_spec):
    """Finds task executions by workflow execution id and task spec.

    :param wf_ex_id: Workflow execution id.
    :param task_spec: Task specification; only its name is used.
    :return: Task executions (possibly a cached value).
    """
    task_name = task_spec.get_name()

    return find_task_executions_by_name(wf_ex_id, task_name)
|
||||
|
||||
|
||||
def find_task_executions_by_specs(wf_ex_id, task_specs):
    """Finds task executions for several task specifications.

    :param wf_ex_id: Workflow execution id.
    :param task_specs: Iterable of task specifications.
    :return: New list with the task executions found for every given
        specification, in the order of 'task_specs'.
    """
    res = []

    for t_s in task_specs:
        # Extend in place instead of 'res = res + ...' so the accumulated
        # list is not copied on every iteration. Lists returned by the
        # lookup (which may be cached) are never modified.
        res.extend(find_task_executions_by_spec(wf_ex_id, t_s))

    return res
|
||||
|
||||
|
||||
def find_task_executions_with_state(wf_ex_id, state):
    """Finds task executions of a workflow execution by their state.

    :param wf_ex_id: Workflow execution id.
    :param state: Task execution state to filter by.
    :return: Result of the DB API query.
    """
    filters = {
        'workflow_execution_id': wf_ex_id,
        'state': state
    }

    return db_api.get_task_executions(**filters)
|
||||
|
||||
|
||||
def find_successful_task_executions(wf_ex_id):
    """Finds task executions of a workflow execution in SUCCESS state.

    :param wf_ex_id: Workflow execution id.
    :return: Task executions in SUCCESS state.
    """
    return find_task_executions_with_state(wf_ex_id, state=states.SUCCESS)
|
||||
|
||||
|
||||
def find_error_task_executions(wf_ex_id):
    """Finds task executions of a workflow execution in ERROR state.

    :param wf_ex_id: Workflow execution id.
    :return: Task executions in ERROR state.
    """
    return find_task_executions_with_state(wf_ex_id, state=states.ERROR)
|
||||
|
||||
|
||||
def find_cancelled_task_executions(wf_ex_id):
    """Finds task executions of a workflow execution in CANCELLED state.

    :param wf_ex_id: Workflow execution id.
    :return: Task executions in CANCELLED state.
    """
    return find_task_executions_with_state(wf_ex_id, state=states.CANCELLED)
|
||||
|
||||
|
||||
def clean_caches():
    """Removes all entries from the task executions cache."""
    # The lock is the same one the lookup function takes around its cache
    # reads and writes.
    with _TASK_EXECUTIONS_CACHE_LOCK:
        _TASK_EXECUTIONS_CACHE.clear()
|
@ -19,8 +19,8 @@ from mistral import exceptions as exc
|
||||
from mistral.workflow import base
|
||||
from mistral.workflow import commands
|
||||
from mistral.workflow import data_flow
|
||||
from mistral.workflow import lookup_utils
|
||||
from mistral.workflow import states
|
||||
from mistral.workflow import utils as wf_utils
|
||||
|
||||
|
||||
class ReverseWorkflowController(base.WorkflowController):
|
||||
@ -92,13 +92,16 @@ class ReverseWorkflowController(base.WorkflowController):
|
||||
return list(
|
||||
filter(
|
||||
lambda t_e: t_e.state == states.SUCCESS,
|
||||
wf_utils.find_task_executions_by_specs(self.wf_ex, t_specs)
|
||||
lookup_utils.find_task_executions_by_specs(
|
||||
self.wf_ex.id,
|
||||
t_specs
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def evaluate_workflow_final_context(self):
|
||||
task_execs = wf_utils.find_task_executions_by_spec(
|
||||
self.wf_ex,
|
||||
task_execs = lookup_utils.find_task_executions_by_spec(
|
||||
self.wf_ex.id,
|
||||
self._get_target_task_specification()
|
||||
)
|
||||
|
||||
@ -110,13 +113,15 @@ class ReverseWorkflowController(base.WorkflowController):
|
||||
|
||||
def get_logical_task_state(self, task_ex):
    """Determines a logical state of the given task.

    :param task_ex: Task execution.
    :return: Tuple (state, state_info, cardinality). The first two values
        are taken from the task execution as is and 'cardinality' is
        always 0.
    """
    # TODO(rakhmerov): Implement.
    logical_state = (task_ex.state, task_ex.state_info, 0)

    return logical_state
|
||||
|
||||
def is_error_handled_for(self, task_ex):
    """Tells whether an error of the given task is considered handled.

    :param task_ex: Task execution.
    :return: False if the task execution is in ERROR state, True
        otherwise.
    """
    current_state = task_ex.state

    return current_state != states.ERROR
|
||||
|
||||
def all_errors_handled(self):
    """Tells whether the workflow has no task executions in ERROR state.

    :return: True if no task executions in ERROR state are found for this
        workflow execution, False otherwise.
    """
    error_execs = lookup_utils.find_error_task_executions(self.wf_ex.id)

    return len(error_execs) == 0
|
||||
|
||||
def _find_task_specs_with_satisfied_dependencies(self):
|
||||
"""Given a target task name finds tasks with no dependencies.
|
||||
@ -139,7 +144,8 @@ class ReverseWorkflowController(base.WorkflowController):
|
||||
]
|
||||
|
||||
def _is_satisfied_task(self, task_spec):
|
||||
if wf_utils.find_task_executions_by_spec(self.wf_ex, task_spec):
|
||||
if lookup_utils.find_task_executions_by_spec(
|
||||
self.wf_ex.id, task_spec):
|
||||
return False
|
||||
|
||||
if not self.wf_spec.get_task_requires(task_spec):
|
||||
|
@ -14,9 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mistral.db.v2 import api as db_api
|
||||
from mistral.utils import serializers
|
||||
from mistral.workflow import states
|
||||
|
||||
|
||||
class Result(object):
|
||||
@ -72,46 +70,3 @@ class ResultSerializer(serializers.Serializer):
|
||||
entity['error'],
|
||||
entity.get('cancel', False)
|
||||
)
|
||||
|
||||
|
||||
def find_task_executions_by_name(wf_ex, task_name):
|
||||
return db_api.get_task_executions(
|
||||
workflow_execution_id=wf_ex.id,
|
||||
name=task_name
|
||||
)
|
||||
|
||||
|
||||
def find_task_executions_by_spec(wf_ex, task_spec):
|
||||
return find_task_executions_by_name(wf_ex, task_spec.get_name())
|
||||
|
||||
|
||||
def find_task_executions_by_specs(wf_ex, task_specs):
|
||||
res = []
|
||||
|
||||
for t_s in task_specs:
|
||||
res = res + find_task_executions_by_spec(wf_ex, t_s)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def find_task_executions_with_state(wf_ex, state):
|
||||
return db_api.get_task_executions(
|
||||
workflow_execution_id=wf_ex.id,
|
||||
state=state
|
||||
)
|
||||
|
||||
|
||||
def find_running_task_executions(wf_ex):
|
||||
return find_task_executions_with_state(wf_ex, states.RUNNING)
|
||||
|
||||
|
||||
def find_successful_task_executions(wf_ex):
|
||||
return find_task_executions_with_state(wf_ex, states.SUCCESS)
|
||||
|
||||
|
||||
def find_error_task_executions(wf_ex):
|
||||
return find_task_executions_with_state(wf_ex, states.ERROR)
|
||||
|
||||
|
||||
def find_cancelled_task_executions(wf_ex):
|
||||
return find_task_executions_with_state(wf_ex, states.CANCELLED)
|
||||
|
Loading…
Reference in New Issue
Block a user