Merge "Raising exception if there aren't start tasks in direct workflow"

This commit is contained in:
Jenkins 2015-09-08 11:54:44 +00:00 committed by Gerrit Code Review
commit f450651551
14 changed files with 406 additions and 208 deletions

View File

@ -26,6 +26,7 @@ class SpecValidationController(rest.RestController):
def __init__(self, parser):
super(SpecValidationController, self).__init__()
self._parse_func = parser
@pecan.expose('json')

View File

@ -96,9 +96,11 @@ class WorkflowSpecValidationTestCase(base.BaseTest):
if not expect_error:
return self._spec_parser(dsl_yaml)
else:
return self.assertRaises(exc.DSLParsingException,
self._spec_parser,
dsl_yaml)
return self.assertRaises(
exc.DSLParsingException,
self._spec_parser,
dsl_yaml
)
class WorkbookSpecValidationTestCase(WorkflowSpecValidationTestCase):

View File

@ -22,12 +22,10 @@ from mistral.tests.unit.workbook.v2 import base
from mistral import utils
from mistral.workbook.v2 import tasks
LOG = logging.getLogger(__name__)
class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
def test_workflow_types(self):
tests = [
({'type': 'direct'}, False),
@ -62,7 +60,12 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
tasks.DirectWfTaskSpecList)
def test_direct_workflow_invalid_task(self):
overlay = {'test': {'type': 'direct', 'tasks': {}}}
overlay = {
'test': {
'type': 'direct',
'tasks': {}
}
}
requires = {'requires': ['echo', 'get']}
utils.merge_dicts(overlay['test']['tasks'], {'email': requires})
@ -71,6 +74,21 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
changes=overlay,
expect_error=True)
def test_direct_workflow_no_start_tasks(self):
overlay = {
'test': {
'type': 'direct',
'tasks': {
'task1': {'on-complete': 'task2'},
'task2': {'on-complete': 'task1'}
}
}
}
self._parse_dsl_spec(add_tasks=False,
changes=overlay,
expect_error=True)
def test_reverse_workflow(self):
overlay = {'test': {'type': 'reverse', 'tasks': {}}}
require = {'requires': ['echo', 'get']}
@ -142,9 +160,12 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
for wf_input, expect_error in tests:
overlay = {'test': wf_input}
self._parse_dsl_spec(add_tasks=True,
changes=overlay,
expect_error=expect_error)
self._parse_dsl_spec(
add_tasks=True,
changes=overlay,
expect_error=expect_error
)
def test_outputs(self):
tests = [
@ -160,9 +181,12 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
for wf_output, expect_error in tests:
overlay = {'test': wf_output}
self._parse_dsl_spec(add_tasks=True,
changes=overlay,
expect_error=expect_error)
self._parse_dsl_spec(
add_tasks=True,
changes=overlay,
expect_error=expect_error
)
def test_vars(self):
tests = [
@ -178,13 +202,18 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
for wf_vars, expect_error in tests:
overlay = {'test': wf_vars}
self._parse_dsl_spec(add_tasks=True,
changes=overlay,
expect_error=expect_error)
self._parse_dsl_spec(
add_tasks=True,
changes=overlay,
expect_error=expect_error
)
def test_tasks_required(self):
exception = self._parse_dsl_spec(add_tasks=False,
expect_error=True)
exception = self._parse_dsl_spec(
add_tasks=False,
expect_error=True
)
self.assertIn("'tasks' is a required property", exception.message)
@ -197,9 +226,12 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
for wf_tasks, expect_error in tests:
overlay = {'test': wf_tasks}
self._parse_dsl_spec(add_tasks=False,
changes=overlay,
expect_error=expect_error)
self._parse_dsl_spec(
add_tasks=False,
changes=overlay,
expect_error=expect_error
)
def test_task_defaults(self):
tests = [
@ -289,9 +321,11 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
utils.merge_dicts(overlay['test']['task-defaults'], default)
self._parse_dsl_spec(add_tasks=True,
changes=overlay,
expect_error=expect_error)
self._parse_dsl_spec(
add_tasks=True,
changes=overlay,
expect_error=expect_error
)
def test_invalid_item(self):
overlay = {'name': 'invalid'}

View File

@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
#
# Copyright 2013 - Mirantis, Inc.
# Copyright 2015 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -19,6 +17,7 @@ from oslo_log import log as logging
from mistral.db.v2 import api as db_api
from mistral.db.v2.sqlalchemy import models
from mistral import exceptions as exc
from mistral.tests import base
from mistral.workbook import parser as spec_parser
from mistral.workflow import direct_workflow as d_wf
@ -26,52 +25,24 @@ from mistral.workflow import states
LOG = logging.getLogger(__name__)
WB = """
---
version: '2.0'
name: my_wb
workflows:
wf:
type: direct
tasks:
task1:
action: std.echo output="Hey"
publish:
res1: <% $.task1 %>
on-complete:
- task2: <% $.res1 = 'Hey' %>
- task3: <% $.res1 = 'Not Hey' %>
task2:
action: std.echo output="Hi"
task3:
action: std.echo output="Hoy"
"""
class DirectWorkflowControllerTest(base.DbTestCase):
def setUp(self):
super(DirectWorkflowControllerTest, self).setUp()
wb_spec = spec_parser.get_workbook_spec_from_yaml(WB)
def _prepare_test(self, wf_text):
wf_spec = spec_parser.get_workflow_list_spec_from_yaml(wf_text)[0]
wf_ex = models.WorkflowExecution()
wf_ex.update({
'id': '1-2-3-4',
'spec': wb_spec.get_workflows().get('wf').to_dict(),
'spec': wf_spec.to_dict(),
'state': states.RUNNING
})
self.wf_ex = wf_ex
self.wb_spec = wb_spec
self.wf_spec = wf_spec
self.wf_ctrl = d_wf.DirectWorkflowController(wf_ex)
def _create_task_execution(self, name, state):
tasks_spec = self.wb_spec.get_workflows()['wf'].get_tasks()
tasks_spec = self.wf_spec.get_tasks()
task_ex = models.TaskExecution(
id=self.getUniqueString('id'),
@ -86,6 +57,30 @@ class DirectWorkflowControllerTest(base.DbTestCase):
@mock.patch.object(db_api, 'get_task_execution')
def test_continue_workflow(self, get_task_execution):
wf_text = """---
version: '2.0'
wf:
type: direct
tasks:
task1:
action: std.echo output="Hey"
publish:
res1: <% $.task1 %>
on-complete:
- task2: <% $.res1 = 'Hey' %>
- task3: <% $.res1 = 'Not Hey' %>
task2:
action: std.echo output="Hi"
task3:
action: std.echo output="Hoy"
"""
self._prepare_test(wf_text)
# Workflow execution is in initial step. No running tasks.
cmds = self.wf_ctrl.continue_workflow()
@ -142,3 +137,23 @@ class DirectWorkflowControllerTest(base.DbTestCase):
task2_ex.processed = True
self.assertEqual(0, len(cmds))
def test_continue_workflow_no_start_tasks(self):
wf_text = """---
version: '2.0'
wf:
description: >
Invalid workflow that doesn't have start tasks (tasks with
no inbound connections).
type: direct
tasks:
task1:
on-complete: task2
task2:
on-complete: task1
"""
self.assertRaises(exc.DSLParsingException, self._prepare_test, wf_text)

View File

@ -1,4 +1,4 @@
# Copyright 2013 - Mirantis, Inc.
# Copyright 2015 - Mirantis, Inc.
# Copyright 2015 - StackStorm, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -44,19 +44,94 @@ ALL = (
PARAMS_PTRN = re.compile("([-_\w]+)=(%s)" % "|".join(ALL))
def instantiate_spec(spec_cls, data):
"""Instantiates specification accounting for specification hierarchies.
:param spec_cls: Specification concrete or base class. In case if base
class or the hierarchy is provided this method relies on attributes
_polymorphic_key and _polymorphic_value in order to find a concrete
class that needs to be instantiated.
:param data: Raw specification data as a dictionary.
"""
if issubclass(spec_cls, BaseSpecList):
# Ignore polymorphic search for specification lists because
# it doesn't make sense for them.
return spec_cls(data)
if not hasattr(spec_cls, '_polymorphic_key'):
spec = spec_cls(data)
spec.validate_semantics()
return spec
key = spec_cls._polymorphic_key
if not isinstance(key, tuple):
key_name = key
key_default = None
else:
key_name = key[0]
key_default = key[1]
for cls in utils.iter_subclasses(spec_cls):
if not hasattr(cls, '_polymorphic_value'):
raise exc.DSLParsingException(
"Class '%s' is expected to have attribute '_polymorphic_value'"
" because it's a part of specification hierarchy inherited "
"from class '%s'." % (cls, spec_cls)
)
if cls._polymorphic_value == data.get(key_name, key_default):
spec = cls(data)
spec.validate_semantics()
return cls(data)
raise exc.DSLParsingException(
'Failed to find a specification class to instantiate '
'[spec_cls=%s, data=%s]' % (spec_cls, data)
)
class BaseSpec(object):
"""Base class for all DSL specifications.
It represents a DSL entity such as workflow or task as a python object
providing more convenient API to analyse DSL than just working with raw
data in form of a dictionary. Specification classes also implement
all required validation logic by overriding instance method 'validate()'.
Note that the specification mechanism allows to have polymorphic entities
in DSL. For example, if we find it more convenient to have separate
specification classes for different types of workflow (i.e. 'direct' and
'reverse') we can do so. In this case, in order to instantiate them
correctly method 'instantiate_spec' must always be used where argument
'spec_cls' must be a root class of the specification hierarchy containing
class attribute '_polymorhpic_key' pointing to a key in raw data relying
on which we can find a concrete class. Concrete classes then must all have
attribute '_polymorhpic_value' corresponding to a value in a raw data.
Attribute '_polymorhpic_key' can be either a string or a tuple of size two
where the first value is a key name itself and the second value is a
default polymorphic value that must be used if raw data doesn't contain
a configured key at all. An example of this situation is when we don't
specify a workflow type in DSL. In this case, we assume it's 'direct'.
"""
# See http://json-schema.org
_schema = {
"type": "object"
'type': 'object'
}
_meta_schema = {
"type": "object"
'type': 'object'
}
_definitions = {}
_version = "1.0"
_version = '1.0'
@classmethod
def get_schema(cls, includes=['meta', 'definitions']):
@ -65,32 +140,60 @@ class BaseSpec(object):
schema['properties'] = utils.merge_dicts(
schema.get('properties', {}),
cls._meta_schema.get('properties', {}),
overwrite=False)
overwrite=False
)
if includes and 'meta' in includes:
schema['required'] = list(
set(schema.get('required', []) +
cls._meta_schema.get('required', [])))
cls._meta_schema.get('required', []))
)
if includes and 'definitions' in includes:
schema['definitions'] = utils.merge_dicts(
schema.get('definitions', {}),
cls._definitions,
overwrite=False)
overwrite=False
)
return schema
def __init__(self, data):
self._data = data
self.validate()
self.validate_schema()
def validate_schema(self):
"""Validates DSL entity schema that this specification represents.
By default, this method just validate schema of DSL entity that this
specification represents using "_schema" class attribute.
Additionally, child classes may implement additional logic to validate
more specific things like YAQL expressions in their fields.
Note that this method is called before construction of specification
fields and validation logic should only rely on raw data provided as
a dictionary accessible through '_data' instance field.
"""
def validate(self):
try:
jsonschema.validate(self._data, self.get_schema())
except jsonschema.ValidationError as e:
raise exc.InvalidModelException("Invalid DSL: %s" % e)
def validate_semantics(self):
"""Validates semantics of specification object.
Child classes may implement validation logic to check things like
integrity of corresponding data structure (e.g. task graph) or
other things that can't be expressed in JSON schema.
This method is called after specification has been built (i.e.
its initializer has finished it's work) so that validation logic
can rely on initialized specification fields.
"""
pass
def validate_yaql_expr(self, dsl_part):
if isinstance(dsl_part, six.string_types):
expr.validate(dsl_part)
@ -106,7 +209,7 @@ class BaseSpec(object):
def _spec_property(self, prop_name, spec_cls):
prop_val = self._data.get(prop_name)
return spec_cls(prop_val) if prop_val else None
return instantiate_spec(spec_cls, prop_val) if prop_val else None
def _group_spec(self, spec_cls, *prop_names):
if not prop_names:
@ -120,7 +223,7 @@ class BaseSpec(object):
if prop_val:
data[prop_name] = prop_val
return spec_cls(data)
return instantiate_spec(spec_cls, data)
def _inject_version(self, prop_names):
for prop_name in prop_names:
@ -139,8 +242,10 @@ class BaseSpec(object):
return prop_val
elif isinstance(prop_val, list):
result = {}
for t in prop_val:
result.update(t if isinstance(t, dict) else {t: ''})
return result
elif isinstance(prop_val, six.string_types):
return {prop_val: ''}
@ -172,6 +277,7 @@ class BaseSpec(object):
cmd = cmd_matcher.group()
params = {}
for k, v in re.findall(PARAMS_PTRN, cmd_str):
# Remove embracing quotes.
v = v.strip()
@ -218,7 +324,7 @@ class BaseListSpec(BaseSpec):
if k != 'version':
v['name'] = k
self._inject_version([k])
self.items.append(self.item_class(v))
self.items.append(instantiate_spec(self.item_class, v))
def validate(self):
super(BaseListSpec, self).validate()
@ -232,6 +338,12 @@ class BaseListSpec(BaseSpec):
def get_items(self):
return self.items
def __getitem__(self, idx):
return self.items[idx]
def __len__(self):
return len(self.items)
class BaseSpecList(object):
item_class = None
@ -245,7 +357,7 @@ class BaseSpecList(object):
if k != 'version':
v['name'] = k
v['version'] = self._version
self.items[k] = self.item_class(v)
self.items[k] = instantiate_spec(self.item_class, v)
def item_keys(self):
return self.items.keys()

View File

@ -17,6 +17,7 @@ import yaml
from yaml import error
from mistral import exceptions as exc
from mistral.workbook import base
from mistral.workbook.v2 import actions as actions_v2
from mistral.workbook.v2 import tasks as tasks_v2
from mistral.workbook.v2 import workbook as wb_v2
@ -61,7 +62,7 @@ def _get_spec_version(spec_dict):
def get_workbook_spec(spec_dict):
if _get_spec_version(spec_dict) == V2_0:
return wb_v2.WorkbookSpec(spec_dict)
return base.instantiate_spec(wb_v2.WorkbookSpec, spec_dict)
return None
@ -72,7 +73,7 @@ def get_workbook_spec_from_yaml(text):
def get_action_spec(spec_dict):
if _get_spec_version(spec_dict) == V2_0:
return actions_v2.ActionSpec(spec_dict)
return base.instantiate_spec(actions_v2.ActionSpec, spec_dict)
return None
@ -86,7 +87,7 @@ def get_action_spec_from_yaml(text, action_name):
def get_action_list_spec(spec_dict):
return actions_v2.ActionListSpec(spec_dict)
return base.instantiate_spec(actions_v2.ActionListSpec, spec_dict)
def get_action_list_spec_from_yaml(text):
@ -95,13 +96,13 @@ def get_action_list_spec_from_yaml(text):
def get_workflow_spec(spec_dict):
if _get_spec_version(spec_dict) == V2_0:
return wf_v2.WorkflowSpec(spec_dict)
return base.instantiate_spec(wf_v2.WorkflowSpec, spec_dict)
return None
def get_workflow_list_spec(spec_dict):
return wf_v2.WorkflowListSpec(spec_dict)
return base.instantiate_spec(wf_v2.WorkflowListSpec, spec_dict)
def get_workflow_spec_from_yaml(text):

View File

@ -49,8 +49,8 @@ class ActionSpec(base.BaseSpec):
utils.merge_dicts(self._base_input, _input)
def validate(self):
super(ActionSpec, self).validate()
def validate_schema(self):
super(ActionSpec, self).validate_schema()
# Validate YAQL expressions.
inline_params = self._parse_cmd_and_input(self._data.get('base'))[1]

View File

@ -55,8 +55,8 @@ class PoliciesSpec(base.BaseSpec):
self._pause_before = data.get('pause-before', False)
self._concurrency = data.get('concurrency', 0)
def validate(self):
super(PoliciesSpec, self).validate()
def validate_schema(self):
super(PoliciesSpec, self).validate_schema()
# Validate YAQL expressions.
self.validate_yaql_expr(self._data.get('wait-before', 0))

View File

@ -70,8 +70,8 @@ class RetrySpec(base.BaseSpec):
return retry
def validate(self):
super(RetrySpec, self).validate()
def validate_schema(self):
super(RetrySpec, self).validate_schema()
# Validate YAQL expressions.
self.validate_yaql_expr(self._data.get('count'))

View File

@ -72,8 +72,8 @@ class TaskDefaultsSpec(base.BaseSpec):
self._on_error = self._as_list_of_tuples("on-error")
self._requires = data.get('requires', [])
def validate(self):
super(TaskDefaultsSpec, self).validate()
def validate_schema(self):
super(TaskDefaultsSpec, self).validate_schema()
# Validate YAQL expressions.
self._validate_transitions('on-complete')

View File

@ -86,8 +86,8 @@ class TaskSpec(base.BaseSpec):
self._inject_type()
self._process_action_and_workflow()
def validate(self):
super(TaskSpec, self).validate()
def validate_schema(self):
super(TaskSpec, self).validate_schema()
action = self._data.get('action')
workflow = self._data.get('workflow')
@ -234,8 +234,8 @@ class DirectWorkflowTaskSpec(TaskSpec):
self._on_success = self._as_list_of_tuples('on-success')
self._on_error = self._as_list_of_tuples('on-error')
def validate(self):
super(DirectWorkflowTaskSpec, self).validate()
def validate_schema(self):
super(DirectWorkflowTaskSpec, self).validate_schema()
if 'join' in self._data:
join = self._data.get('join')

View File

@ -1,4 +1,4 @@
# Copyright 2014 - Mirantis, Inc.
# Copyright 2015 - Mirantis, Inc.
# Copyright 2015 - StackStorm, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -24,35 +24,19 @@ from mistral.workbook.v2 import tasks
class WorkflowSpec(base.BaseSpec):
# See http://json-schema.org
_direct_task_schema = tasks.DirectWorkflowTaskSpec.get_schema(
includes=None)
_reverse_task_schema = tasks.ReverseWorkflowTaskSpec.get_schema(
includes=None)
_polymorphic_key = ('type', 'direct')
_task_defaults_schema = task_defaults.TaskDefaultsSpec.get_schema(
includes=None)
_schema = {
_meta_schema = {
"type": "object",
"properties": {
"type": types.WORKFLOW_TYPE,
"task-defaults": _task_defaults_schema,
"input": types.UNIQUE_STRING_OR_ONE_KEY_DICT_LIST,
"output": types.NONEMPTY_DICT,
"vars": types.NONEMPTY_DICT,
"tasks": {
"type": "object",
"minProperties": 1,
"patternProperties": {
"^\w+$": {
"anyOf": [
_direct_task_schema,
_reverse_task_schema
]
}
}
},
"vars": types.NONEMPTY_DICT
},
"required": ["tasks"],
"additionalProperties": False
@ -64,7 +48,7 @@ class WorkflowSpec(base.BaseSpec):
self._name = data['name']
self._description = data.get('description')
self._tags = data.get('tags', [])
self._type = data['type'] if 'type' in data else "direct"
self._type = data['type'] if 'type' in data else 'direct'
self._input = utils.get_input_dict(data.get('input', []))
self._output = data.get('output', {})
self._vars = data.get('vars', {})
@ -78,8 +62,8 @@ class WorkflowSpec(base.BaseSpec):
tasks.TaskSpecList.get_class(self._type)
)
def validate(self):
super(WorkflowSpec, self).validate()
def validate_schema(self):
super(WorkflowSpec, self).validate_schema()
if not self._data.get('tasks'):
raise exc.InvalidModelException(
@ -90,6 +74,10 @@ class WorkflowSpec(base.BaseSpec):
self.validate_yaql_expr(self._data.get('output', {}))
self.validate_yaql_expr(self._data.get('vars', {}))
def validate_semantics(self):
# Doesn't do anything by default.
pass
def get_name(self):
return self._name
@ -118,6 +106,133 @@ class WorkflowSpec(base.BaseSpec):
return self._tasks
class DirectWorkflowSpec(WorkflowSpec):
_polymorphic_value = 'direct'
_schema = {
"properties": {
"tasks": {
"type": "object",
"minProperties": 1,
"patternProperties": {
"^\w+$":
tasks.DirectWorkflowTaskSpec.get_schema(includes=None)
}
},
}
}
def validate_semantics(self):
# Check if there are start tasks.
if not self.find_start_tasks():
raise exc.DSLParsingException(
'Failed to find start tasks in direct workflow. '
'There must be at least one task without inbound transition.'
'[workflow_name=%s]' % self._name
)
def find_start_tasks(self):
return [
t_s for t_s in self.get_tasks()
if not self.has_inbound_transitions(t_s)
]
def find_inbound_task_specs(self, task_spec):
return [
t_s for t_s in self.get_tasks()
if self.transition_exists(t_s.get_name(), task_spec.get_name())
]
def find_outbound_task_specs(self, task_spec):
return [
t_s for t_s in self.get_tasks()
if self.transition_exists(task_spec.get_name(), t_s.get_name())
]
def has_inbound_transitions(self, task_spec):
return len(self.find_inbound_task_specs(task_spec)) > 0
def has_outbound_transitions(self, task_spec):
return len(self.find_outbound_task_specs(task_spec)) > 0
def transition_exists(self, from_task_name, to_task_name):
t_names = set()
for tup in self.get_on_error_clause(from_task_name):
t_names.add(tup[0])
for tup in self.get_on_success_clause(from_task_name):
t_names.add(tup[0])
for tup in self.get_on_complete_clause(from_task_name):
t_names.add(tup[0])
return to_task_name in t_names
def get_on_error_clause(self, t_name):
result = self.get_tasks()[t_name].get_on_error()
if not result:
t_defaults = self.get_task_defaults()
if t_defaults:
result = self._remove_task_from_clause(
t_defaults.get_on_error(),
t_name
)
return result
def get_on_success_clause(self, t_name):
result = self.get_tasks()[t_name].get_on_success()
if not result:
t_defaults = self.get_task_defaults()
if t_defaults:
result = self._remove_task_from_clause(
t_defaults.get_on_success(),
t_name
)
return result
def get_on_complete_clause(self, t_name):
result = self.get_tasks()[t_name].get_on_complete()
if not result:
t_defaults = self.get_task_defaults()
if t_defaults:
result = self._remove_task_from_clause(
t_defaults.get_on_complete(),
t_name
)
return result
@staticmethod
def _remove_task_from_clause(on_clause, t_name):
return filter(lambda tup: tup[0] != t_name, on_clause)
class ReverseWorkflowSpec(WorkflowSpec):
_polymorphic_value = 'reverse'
_schema = {
"properties": {
"tasks": {
"type": "object",
"minProperties": 1,
"patternProperties": {
"^\w+$":
tasks.ReverseWorkflowTaskSpec.get_schema(includes=None)
}
},
}
}
class WorkflowSpecList(base.BaseSpecList):
item_class = WorkflowSpec

View File

@ -154,14 +154,13 @@ class WorkflowController(object):
if wf_type == wf_ctrl_cls.__workflow_type__:
return wf_ctrl_cls
msg = 'Failed to find a workflow controller [type=%s]' % wf_type
raise exc.NotFoundException(msg)
raise exc.NotFoundException(
'Failed to find a workflow controller [type=%s]' % wf_type
)
@staticmethod
def get_controller(wf_ex, wf_spec=None):
if not wf_spec:
wf_spec = spec_parser.get_workflow_spec(wf_ex['spec'])
ctrl_cls = WorkflowController._get_class(wf_spec.get_type())
return ctrl_cls(wf_ex)
return WorkflowController._get_class(wf_spec.get_type())(wf_ex)

View File

@ -1,4 +1,4 @@
# Copyright 2014 - Mirantis, Inc.
# Copyright 2015 - Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -47,7 +47,7 @@ class DirectWorkflowController(base.WorkflowController):
lambda t_e: self._is_upstream_task_execution(task_spec, t_e),
wf_utils.find_task_executions_by_specs(
self.wf_ex,
self._find_inbound_task_specs(task_spec)
self.wf_spec.find_inbound_task_specs(task_spec)
)
)
@ -80,19 +80,13 @@ class DirectWorkflowController(base.WorkflowController):
return cmds
def _find_start_commands(self):
t_specs = []
for t_s in self.wf_spec.get_tasks():
if not self._has_inbound_transitions(t_s):
t_specs.append(t_s)
return [
commands.RunTask(
self.wf_ex,
t_s,
self._get_task_inbound_context(t_s)
)
for t_s in t_specs
for t_s in self.wf_spec.find_start_tasks()
]
def _find_next_commands_for_task(self, task_ex):
@ -121,7 +115,7 @@ class DirectWorkflowController(base.WorkflowController):
)
# NOTE(xylan): Decide whether or not a join task should run
# immediately
# immediately.
if self._is_unsatisfied_join(cmd):
cmd.wait_flag = True
@ -136,35 +130,6 @@ class DirectWorkflowController(base.WorkflowController):
return cmds
def _has_inbound_transitions(self, task_spec):
return len(self._find_inbound_task_specs(task_spec)) > 0
def _find_inbound_task_specs(self, task_spec):
return [
t_s for t_s in self.wf_spec.get_tasks()
if self._transition_exists(t_s.get_name(), task_spec.get_name())
]
def _find_outbound_task_specs(self, task_spec):
return [
t_s for t_s in self.wf_spec.get_tasks()
if self._transition_exists(task_spec.get_name(), t_s.get_name())
]
def _transition_exists(self, from_task_name, to_task_name):
t_names = set()
for tup in self.get_on_error_clause(from_task_name):
t_names.add(tup[0])
for tup in self.get_on_success_clause(from_task_name):
t_names.add(tup[0])
for tup in self.get_on_complete_clause(from_task_name):
t_names.add(tup[0])
return to_task_name in t_names
# TODO(rakhmerov): Need to refactor this method to be able to pass tasks
# whose contexts need to be merged.
def evaluate_workflow_final_context(self):
@ -179,11 +144,11 @@ class DirectWorkflowController(base.WorkflowController):
return ctx
def is_error_handled_for(self, task_ex):
return bool(self.get_on_error_clause(task_ex.name))
return bool(self.wf_spec.get_on_error_clause(task_ex.name))
def all_errors_handled(self):
for t_ex in wf_utils.find_error_task_executions(self.wf_ex):
if not self.get_on_error_clause(t_ex.name):
if not self.wf_spec.get_on_error_clause(t_ex.name):
return False
return True
@ -204,52 +169,6 @@ class DirectWorkflowController(base.WorkflowController):
if self.wf_spec.get_tasks()[t_name]
])
@staticmethod
def _remove_task_from_clause(on_clause, t_name):
return filter(lambda tup: tup[0] != t_name, on_clause)
def get_on_error_clause(self, t_name):
result = self.wf_spec.get_tasks()[t_name].get_on_error()
if not result:
task_defaults = self.wf_spec.get_task_defaults()
if task_defaults:
result = self._remove_task_from_clause(
task_defaults.get_on_error(),
t_name
)
return result
def get_on_success_clause(self, t_name):
result = self.wf_spec.get_tasks()[t_name].get_on_success()
if not result:
task_defaults = self.wf_spec.get_task_defaults()
if task_defaults:
result = self._remove_task_from_clause(
task_defaults.get_on_success(),
t_name
)
return result
def get_on_complete_clause(self, t_name):
result = self.wf_spec.get_tasks()[t_name].get_on_complete()
if not result:
task_defaults = self.wf_spec.get_task_defaults()
if task_defaults:
result = self._remove_task_from_clause(
task_defaults.get_on_complete(),
t_name
)
return result
def _find_next_task_names(self, task_ex):
t_state = task_ex.state
t_name = task_ex.name
@ -260,19 +179,19 @@ class DirectWorkflowController(base.WorkflowController):
if states.is_completed(t_state):
t_names += self._find_next_task_names_for_clause(
self.get_on_complete_clause(t_name),
self.wf_spec.get_on_complete_clause(t_name),
ctx
)
if t_state == states.ERROR:
t_names += self._find_next_task_names_for_clause(
self.get_on_error_clause(t_name),
self.wf_spec.get_on_error_clause(t_name),
ctx
)
elif t_state == states.SUCCESS:
t_names += self._find_next_task_names_for_clause(
self.get_on_success_clause(t_name),
self.wf_spec.get_on_success_clause(t_name),
ctx
)
@ -323,7 +242,7 @@ class DirectWorkflowController(base.WorkflowController):
if not join_expr:
return False
in_task_specs = self._find_inbound_task_specs(task_spec)
in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)
if not in_task_specs:
return False