deb-mistral/mistral/engine/policies.py
Renat Akhmerov 633eb0fe6d Add proper error handling for task continuation
* When a task needs to be continued, e.g. in the case of the 'wait-before'
  policy, which inserts a delay into the normal task execution flow
  (between creation of the task policy and scheduling of actions), possible
  exceptions also need to be handled properly (move the task and workflow
  into ERROR). This patch adds error handling and a test to check it.
* Other minor changes addressing a few TODOs across the engine
  code.

Change-Id: I525f193a149e3b0341aa8d0ffa0858ded96ba94f
2016-07-08 15:08:51 +07:00

471 lines
13 KiB
Python

# Copyright 2014 - Mirantis, Inc.
# Copyright 2015 - StackStorm, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mistral.db.v2 import api as db_api
from mistral.engine import base
from mistral import expressions
from mistral.services import scheduler
from mistral.utils import wf_trace
from mistral.workflow import data_flow
from mistral.workflow import states
import six
_CONTINUE_TASK_PATH = 'mistral.engine.policies._continue_task'
_COMPLETE_TASK_PATH = 'mistral.engine.policies._complete_task'
def _log_task_delay(task_ex, delay_sec):
    """Record in the workflow trace that a task is being delayed.

    The message shows the transition from the task's current state to
    RUNNING_DELAYED together with the length of the delay.

    :param task_ex: task execution that is about to be delayed.
    :param delay_sec: delay length in seconds (the calling policy's delay).
    """
    transition = (
        task_ex.name,
        task_ex.state,
        states.RUNNING_DELAYED,
        delay_sec
    )

    wf_trace.info(
        task_ex,
        "Task '%s' [%s -> %s, delay = %s sec]" % transition
    )
def build_policies(policies_spec, wf_spec):
    """Create the policy objects that apply to one task.

    Task-level policies come from ``policies_spec``. Workflow-wide policies
    come from the workflow's task defaults (``wf_spec.get_task_defaults()``),
    if it has any.

    :param policies_spec: policies spec of the task (may be falsy).
    :param wf_spec: spec of the workflow that owns the task.
    :return: list of policy objects; empty when neither source has policies.
    """
    defaults = wf_spec.get_task_defaults()

    wf_policies = None

    if defaults:
        wf_policies = defaults.get_policies()

    if not policies_spec and not wf_policies:
        return []

    return construct_policies_list(policies_spec, wf_policies)
def get_policy_factories():
    """Return the policy builder functions in the order they are applied.

    Each builder takes a policies spec and returns a policy object or None.
    The order here is the order of the resulting policies list.
    """
    factories = [
        build_pause_before_policy,
        build_wait_before_policy,
        build_wait_after_policy,
        build_retry_policy,
        build_timeout_policy,
        build_concurrency_policy
    ]

    return factories
def construct_policies_list(policies_spec, wf_policies):
    """Build one policy per factory, preferring task-level settings.

    For every builder from get_policy_factories(), the task-level
    ``policies_spec`` is tried first. ``wf_policies`` is consulted only when
    it is set and the task-level spec produced nothing for that builder.

    :return: list of the policies that were actually produced.
    """
    result = []

    for build in get_policy_factories():
        candidate = build(policies_spec)

        # Fall back to workflow-level defaults for this policy type only.
        if wf_policies and not candidate:
            candidate = build(wf_policies)

        if candidate:
            result.append(candidate)

    return result
def build_wait_before_policy(policies_spec):
    """Create a WaitBeforePolicy from the spec's wait-before value.

    :param policies_spec: policies spec (may be None or otherwise falsy).
    :return: WaitBeforePolicy when the value is a string (presumably an
        expression to be resolved later; not shown here) or a number
        greater than zero, otherwise None.
    """
    if not policies_spec:
        return None

    delay = policies_spec.get_wait_before()

    if not isinstance(delay, six.string_types) and not delay > 0:
        return None

    return WaitBeforePolicy(delay)
def build_wait_after_policy(policies_spec):
    """Create a WaitAfterPolicy from the spec's wait-after value.

    :param policies_spec: policies spec (may be None or otherwise falsy).
    :return: WaitAfterPolicy when the value is a string (presumably an
        expression to be resolved later; not shown here) or a number
        greater than zero, otherwise None.
    """
    if not policies_spec:
        return None

    delay = policies_spec.get_wait_after()

    if not isinstance(delay, six.string_types) and not delay > 0:
        return None

    return WaitAfterPolicy(delay)
def build_retry_policy(policies_spec):
    """Create a RetryPolicy from the spec's retry section.

    :param policies_spec: policies spec (may be None or otherwise falsy).
    :return: RetryPolicy built from the retry count, delay, break-on and
        continue-on values, or None when there is no spec or no retry
        section.
    """
    retry = policies_spec.get_retry() if policies_spec else None

    if not retry:
        return None

    return RetryPolicy(
        count=retry.get_count(),
        delay=retry.get_delay(),
        break_on=retry.get_break_on(),
        continue_on=retry.get_continue_on()
    )
def build_timeout_policy(policies_spec):
    """Create a TimeoutPolicy from the spec's timeout value.

    :param policies_spec: policies spec (may be None or otherwise falsy).
    :return: TimeoutPolicy when the value is a string (presumably an
        expression to be resolved later; not shown here) or a number
        greater than zero, otherwise None.
    """
    if not policies_spec:
        return None

    timeout = policies_spec.get_timeout()

    if not isinstance(timeout, six.string_types) and not timeout > 0:
        return None

    return TimeoutPolicy(timeout)
def build_pause_before_policy(policies_spec):
    """Create a PauseBeforePolicy from the spec's pause-before value.

    :param policies_spec: policies spec (may be None or otherwise falsy).
    :return: PauseBeforePolicy when the value is truthy, otherwise None.
    """
    if not policies_spec:
        return None

    pause_before = policies_spec.get_pause_before()

    if not pause_before:
        return None

    return PauseBeforePolicy(pause_before)
def build_concurrency_policy(policies_spec):
    """Create a ConcurrencyPolicy from the spec's concurrency value.

    :param policies_spec: policies spec (may be None or otherwise falsy).
    :return: ConcurrencyPolicy when the value is truthy, otherwise None.
    """
    if not policies_spec:
        return None

    concurrency = policies_spec.get_concurrency()

    if not concurrency:
        return None

    return ConcurrencyPolicy(concurrency)
def _ensure_context_has_key(runtime_context, key):
if not runtime_context:
runtime_context = {}
if key not in runtime_context:
runtime_context.update({key: {}})
return runtime_context
class WaitBeforePolicy(base.TaskPolicy):
    """Delays the start of a task by ``delay`` seconds.

    The first time the hook runs for a task that is not IDLE, the task is
    put into RUNNING_DELAYED, a 'skip' marker is stored in the
    'wait_before_policy' section of its runtime context, and
    ``_continue_task`` is scheduled after the delay. When the hook later
    finds the 'skip' marker (presumably after ``_continue_task`` resumes the
    task), it only switches the task back to RUNNING.
    """

    _schema = {
        "properties": {
            "delay": {"type": "integer"}
        }
    }

    def __init__(self, delay):
        self.delay = delay

    def before_task_start(self, task_ex, task_spec):
        super(WaitBeforePolicy, self).before_task_start(task_ex, task_spec)

        section = 'wait_before_policy'

        ctx = _ensure_context_has_key(task_ex.runtime_context, section)
        task_ex.runtime_context = ctx

        policy_ctx = ctx[section]

        if policy_ctx.get('skip'):
            # The delay has already been applied: leave RUNNING_DELAYED.
            wf_trace.info(
                task_ex,
                "Task '%s' [%s -> %s]"
                % (task_ex.name, states.RUNNING_DELAYED, states.RUNNING)
            )

            task_ex.state = states.RUNNING

            return

        if task_ex.state == states.IDLE:
            # An IDLE task is left untouched. No 'skip' marker is set, so the
            # delay is still applied if this hook runs again later.
            return

        policy_ctx.update({'skip': True})

        _log_task_delay(task_ex, self.delay)

        task_ex.state = states.RUNNING_DELAYED

        scheduler.schedule_call(
            None,
            _CONTINUE_TASK_PATH,
            self.delay,
            task_ex_id=task_ex.id,
        )
class WaitAfterPolicy(base.TaskPolicy):
    """Postpones the final state of a completed task by ``delay`` seconds.

    On the first ``after_task_complete`` call, the task's end state and state
    info are saved, the task is put into RUNNING_DELAYED, and
    ``_complete_task`` is scheduled to complete it with the saved state once
    the delay has passed. A 'skip' marker in the 'wait_after_policy' section
    of the runtime context keeps the policy from applying twice.
    """

    _schema = {
        "properties": {
            "delay": {"type": "integer"}
        }
    }

    def __init__(self, delay):
        self.delay = delay

    def after_task_complete(self, task_ex, task_spec):
        super(WaitAfterPolicy, self).after_task_complete(task_ex, task_spec)

        section = 'wait_after_policy'

        ctx = _ensure_context_has_key(task_ex.runtime_context, section)
        task_ex.runtime_context = ctx

        policy_ctx = ctx[section]

        if policy_ctx.get('skip'):
            # This task has already been delayed once; nothing more to do.
            return

        policy_ctx.update({'skip': True})

        _log_task_delay(task_ex, self.delay)

        # Remember where the task actually ended up, so that it can be
        # completed with the same state after the delay.
        final_state, final_state_info = task_ex.state, task_ex.state_info

        # TODO(rakhmerov): Policies probably needs to have tasks.Task
        # interface in order to change manage task state safely.
        task_ex.state = states.RUNNING_DELAYED
        task_ex.state_info = (
            'Suspended by wait-after policy for %s seconds' % self.delay
        )

        # After the delay, _complete_task completes the task with the saved
        # end state and state info.
        scheduler.schedule_call(
            None,
            _COMPLETE_TASK_PATH,
            self.delay,
            task_ex_id=task_ex.id,
            state=final_state,
            state_info=final_state_info
        )
class RetryPolicy(base.TaskPolicy):
    """Re-runs a completed task according to its retry settings.

    Attributes: ``count`` (retry count, see the note in
    after_task_complete), ``delay`` (seconds to wait before scheduling the
    next attempt), ``break_on`` (when truthy, a task in ERROR is not
    retried) and ``_continue_on_clause`` (when set, must evaluate truthy
    for retrying to go on).
    """

    _schema = {
        "properties": {
            "delay": {"type": "integer"},
            "count": {"type": "integer"}
        }
    }

    def __init__(self, count, delay, break_on, continue_on):
        self.count = count
        self.delay = delay
        self.break_on = break_on
        # Evaluated explicitly in after_task_complete() against the task's
        # outbound context.
        self._continue_on_clause = continue_on

    def after_task_complete(self, task_ex, task_spec):
        """Decide whether a just-completed task should be run again.

        The number of retries already made is kept as 'retry_no' in the
        'retry_task_policy' section of the task's runtime context. For a
        completed task it is removed from the context and written back
        (incremented) only when another attempt is scheduled, so it is gone
        once the loop stops.

        Possible cases:

        1. The task is not in a completed state: nothing happens.
        2. ``retry_no + 1 >= count``: the loop stops regardless of the task
           state and of the continue-on condition.
        3. State is ERROR and ``break_on`` is truthy: the loop stops.
        4. State is SUCCESS and no continue-on clause is configured: the
           loop stops.
        5. A continue-on clause is configured and evaluates to a falsy
           value: the loop stops.
        6. Otherwise the task's execution result is invalidated, the task
           is put into RUNNING_DELAYED, 'retry_no' is incremented and
           ``_continue_task`` is scheduled to run after ``delay`` seconds.

        NOTE(review): because of the ``+ 1`` in the check, ``count = N``
        allows N - 1 retries (N runs in total). Confirm this is the
        intended meaning of the retry count setting.
        """
        super(RetryPolicy, self).after_task_complete(task_ex, task_spec)

        # TODO(m4dcoder): If the task_ex.executions collection is not called,
        # then the retry_no in the runtime_context of the task_ex will not
        # be updated accurately. To be exact, the retry_no will be one
        # iteration behind. task_ex.executions was originally called in
        # get_task_execution_result but it was refactored to use
        # db_api.get_action_executions to support session-less use cases.
        action_ex = task_ex.executions  # noqa

        context_key = 'retry_task_policy'

        runtime_context = _ensure_context_has_key(
            task_ex.runtime_context,
            context_key
        )

        # Evaluated on every call (before the state check below), against
        # the context returned by data_flow.evaluate_task_outbound_context.
        continue_on_evaluation = expressions.evaluate(
            self._continue_on_clause,
            data_flow.evaluate_task_outbound_context(task_ex)
        )

        task_ex.runtime_context = runtime_context

        state = task_ex.state

        if not states.is_completed(state):
            return

        policy_context = runtime_context[context_key]

        # Take the retry counter out of the context; it is written back
        # below only if another attempt gets scheduled.
        retry_no = 0

        if 'retry_no' in policy_context:
            retry_no = policy_context['retry_no']
            del policy_context['retry_no']

        retries_remain = retry_no + 1 < self.count

        # Stop when a successful task has no continue-on clause, or when a
        # configured continue-on clause evaluated to a falsy value.
        stop_continue_flag = (task_ex.state == states.SUCCESS and
                              not self._continue_on_clause)
        stop_continue_flag = (stop_continue_flag or
                              (self._continue_on_clause and
                               not continue_on_evaluation))
        # NOTE(review): break_on is only tested for truthiness here; unlike
        # the continue-on clause it is not passed to expressions.evaluate()
        # in this method. If it can be an expression, confirm it is
        # evaluated elsewhere (e.g. by the base class) before this point.
        break_triggered = task_ex.state == states.ERROR and self.break_on

        if not retries_remain or break_triggered or stop_continue_flag:
            return

        _log_task_delay(task_ex, self.delay)

        # Presumably discards the result of the attempt that just finished;
        # see data_flow.invalidate_task_execution_result.
        data_flow.invalidate_task_execution_result(task_ex)

        task_ex.state = states.RUNNING_DELAYED

        policy_context['retry_no'] = retry_no + 1
        runtime_context[context_key] = policy_context

        # Resume the task through _continue_task once the delay has passed.
        scheduler.schedule_call(
            None,
            _CONTINUE_TASK_PATH,
            self.delay,
            task_ex_id=task_ex.id,
        )
class TimeoutPolicy(base.TaskPolicy):
    """Fails a task that has not completed within a given time.

    Before the task starts, a call to ``_fail_task_if_incomplete`` is
    scheduled ``delay`` seconds ahead. That callback moves the task to ERROR
    if it is still not completed by then.
    """

    _schema = {
        "properties": {
            "delay": {"type": "integer"}
        }
    }

    def __init__(self, timeout_sec):
        # Stored as 'delay', the property name used in _schema.
        self.delay = timeout_sec

    def before_task_start(self, task_ex, task_spec):
        super(TimeoutPolicy, self).before_task_start(task_ex, task_spec)

        timeout = self.delay

        scheduler.schedule_call(
            None,
            'mistral.engine.policies._fail_task_if_incomplete',
            timeout,
            task_ex_id=task_ex.id,
            timeout=timeout
        )

        msg = "Timeout check scheduled [task=%s, timeout(s)=%s]." % (
            task_ex.id,
            timeout
        )

        wf_trace.info(task_ex, msg)
class PauseBeforePolicy(base.TaskPolicy):
    """Pauses the workflow before the task starts if ``expr`` is truthy.

    The workflow execution is moved to PAUSED and the task to IDLE.
    """

    _schema = {
        "properties": {
            "expr": {"type": "boolean"}
        }
    }

    def __init__(self, expression):
        self.expr = expression

    def before_task_start(self, task_ex, task_spec):
        super(PauseBeforePolicy, self).before_task_start(task_ex, task_spec)

        if self.expr:
            msg = "Workflow paused before task '%s' [%s -> %s]" % (
                task_ex.name,
                task_ex.workflow_execution.state,
                states.PAUSED
            )

            wf_trace.info(task_ex, msg)

            task_ex.workflow_execution.state = states.PAUSED
            task_ex.state = states.IDLE
class ConcurrencyPolicy(base.TaskPolicy):
    """Stores the task's concurrency limit in its runtime context.

    The policy does not enforce anything itself. Apart from having its
    'concurrency' property validated, it only writes the value under the
    'concurrency' key of the task's runtime context. That value is then used
    to decide how many action executions may be started in parallel.
    """

    _schema = {
        "properties": {
            "concurrency": {"type": "integer"},
        }
    }

    def __init__(self, concurrency):
        self.concurrency = concurrency

    def before_task_start(self, task_ex, task_spec):
        super(ConcurrencyPolicy, self).before_task_start(task_ex, task_spec)

        key = 'concurrency'

        ctx = _ensure_context_has_key(task_ex.runtime_context, key)

        # Replace whatever the key held (possibly the empty-dict placeholder
        # added by the helper) with the concurrency limit.
        ctx[key] = self.concurrency

        task_ex.runtime_context = ctx
def _continue_task(task_ex_id):
    """Continue a delayed task inside a DB transaction.

    Target of the calls that WaitBeforePolicy and RetryPolicy schedule via
    _CONTINUE_TASK_PATH.

    :param task_ex_id: id of the task execution to continue.
    """
    # Imported locally rather than at module level (possibly to avoid an
    # import cycle; not visible from this file).
    from mistral.engine import task_handler

    with db_api.transaction():
        task_ex = db_api.get_task_execution(task_ex_id)

        task_handler.continue_task(task_ex)
def _complete_task(task_ex_id, state, state_info):
    """Complete a task with the given state inside a DB transaction.

    Target of the call that WaitAfterPolicy schedules via
    _COMPLETE_TASK_PATH.

    :param task_ex_id: id of the task execution to complete.
    :param state: state to complete the task with.
    :param state_info: state info to complete the task with.
    """
    # Imported locally rather than at module level (possibly to avoid an
    # import cycle; not visible from this file).
    from mistral.engine import task_handler

    with db_api.transaction():
        task_ex = db_api.get_task_execution(task_ex_id)

        task_handler.complete_task(task_ex, state, state_info)
def _fail_task_if_incomplete(task_ex_id, timeout):
    """Move a task to ERROR if it has not completed by the time this runs.

    Target of the call that TimeoutPolicy schedules. Nothing happens if the
    task is already in a completed state.

    :param task_ex_id: id of the task execution to check.
    :param timeout: timeout value; only used in the error message.
    """
    # Imported locally rather than at module level (possibly to avoid an
    # import cycle; not visible from this file).
    from mistral.engine import task_handler

    with db_api.transaction():
        task_ex = db_api.get_task_execution(task_ex_id)

        if states.is_completed(task_ex.state):
            return

        msg = 'Task timed out [timeout(s)=%s].' % timeout

        # Reuse the execution loaded above instead of querying it again.
        task_handler.complete_task(task_ex, states.ERROR, msg)