Raising exception if there aren't start tasks in direct workflow

* Now if direct workflow graph doesn't have start tasks (ones with no
  inbound transitions) then an exception gets raised.
* Refactoring workflow specifications: moving specification related method
  from direct workflow controller into workflow specification.
* Implemented the mechanism of polymorphic DSL entities. At this point
  there's a hierarchy of specification classes to represent different types
  of workflow.
* Specification validation logic is now explicitly split into two methods:
  validate_schema() and validate_semantics() where the second one is supposed
  to implement integrity checks and other more high-level rules that are
  impossible to define in JSON schema.
* Other minor refactoring and style changes.

Change-Id: I60937b77e39133e3b254fed574e6aec6aa402eb0
This commit is contained in:
Renat Akhmerov 2015-09-04 15:41:29 +06:00
parent 371ba27dcd
commit 49baf0311e
14 changed files with 406 additions and 208 deletions

View File

@ -26,6 +26,7 @@ class SpecValidationController(rest.RestController):
def __init__(self, parser): def __init__(self, parser):
super(SpecValidationController, self).__init__() super(SpecValidationController, self).__init__()
self._parse_func = parser self._parse_func = parser
@pecan.expose('json') @pecan.expose('json')

View File

@ -96,9 +96,11 @@ class WorkflowSpecValidationTestCase(base.BaseTest):
if not expect_error: if not expect_error:
return self._spec_parser(dsl_yaml) return self._spec_parser(dsl_yaml)
else: else:
return self.assertRaises(exc.DSLParsingException, return self.assertRaises(
self._spec_parser, exc.DSLParsingException,
dsl_yaml) self._spec_parser,
dsl_yaml
)
class WorkbookSpecValidationTestCase(WorkflowSpecValidationTestCase): class WorkbookSpecValidationTestCase(WorkflowSpecValidationTestCase):

View File

@ -22,12 +22,10 @@ from mistral.tests.unit.workbook.v2 import base
from mistral import utils from mistral import utils
from mistral.workbook.v2 import tasks from mistral.workbook.v2 import tasks
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase): class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
def test_workflow_types(self): def test_workflow_types(self):
tests = [ tests = [
({'type': 'direct'}, False), ({'type': 'direct'}, False),
@ -62,7 +60,12 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
tasks.DirectWfTaskSpecList) tasks.DirectWfTaskSpecList)
def test_direct_workflow_invalid_task(self): def test_direct_workflow_invalid_task(self):
overlay = {'test': {'type': 'direct', 'tasks': {}}} overlay = {
'test': {
'type': 'direct',
'tasks': {}
}
}
requires = {'requires': ['echo', 'get']} requires = {'requires': ['echo', 'get']}
utils.merge_dicts(overlay['test']['tasks'], {'email': requires}) utils.merge_dicts(overlay['test']['tasks'], {'email': requires})
@ -71,6 +74,21 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
changes=overlay, changes=overlay,
expect_error=True) expect_error=True)
def test_direct_workflow_no_start_tasks(self):
overlay = {
'test': {
'type': 'direct',
'tasks': {
'task1': {'on-complete': 'task2'},
'task2': {'on-complete': 'task1'}
}
}
}
self._parse_dsl_spec(add_tasks=False,
changes=overlay,
expect_error=True)
def test_reverse_workflow(self): def test_reverse_workflow(self):
overlay = {'test': {'type': 'reverse', 'tasks': {}}} overlay = {'test': {'type': 'reverse', 'tasks': {}}}
require = {'requires': ['echo', 'get']} require = {'requires': ['echo', 'get']}
@ -142,9 +160,12 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
for wf_input, expect_error in tests: for wf_input, expect_error in tests:
overlay = {'test': wf_input} overlay = {'test': wf_input}
self._parse_dsl_spec(add_tasks=True,
changes=overlay, self._parse_dsl_spec(
expect_error=expect_error) add_tasks=True,
changes=overlay,
expect_error=expect_error
)
def test_outputs(self): def test_outputs(self):
tests = [ tests = [
@ -160,9 +181,12 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
for wf_output, expect_error in tests: for wf_output, expect_error in tests:
overlay = {'test': wf_output} overlay = {'test': wf_output}
self._parse_dsl_spec(add_tasks=True,
changes=overlay, self._parse_dsl_spec(
expect_error=expect_error) add_tasks=True,
changes=overlay,
expect_error=expect_error
)
def test_vars(self): def test_vars(self):
tests = [ tests = [
@ -178,13 +202,18 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
for wf_vars, expect_error in tests: for wf_vars, expect_error in tests:
overlay = {'test': wf_vars} overlay = {'test': wf_vars}
self._parse_dsl_spec(add_tasks=True,
changes=overlay, self._parse_dsl_spec(
expect_error=expect_error) add_tasks=True,
changes=overlay,
expect_error=expect_error
)
def test_tasks_required(self): def test_tasks_required(self):
exception = self._parse_dsl_spec(add_tasks=False, exception = self._parse_dsl_spec(
expect_error=True) add_tasks=False,
expect_error=True
)
self.assertIn("'tasks' is a required property", exception.message) self.assertIn("'tasks' is a required property", exception.message)
@ -197,9 +226,12 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
for wf_tasks, expect_error in tests: for wf_tasks, expect_error in tests:
overlay = {'test': wf_tasks} overlay = {'test': wf_tasks}
self._parse_dsl_spec(add_tasks=False,
changes=overlay, self._parse_dsl_spec(
expect_error=expect_error) add_tasks=False,
changes=overlay,
expect_error=expect_error
)
def test_task_defaults(self): def test_task_defaults(self):
tests = [ tests = [
@ -289,9 +321,11 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase):
utils.merge_dicts(overlay['test']['task-defaults'], default) utils.merge_dicts(overlay['test']['task-defaults'], default)
self._parse_dsl_spec(add_tasks=True, self._parse_dsl_spec(
changes=overlay, add_tasks=True,
expect_error=expect_error) changes=overlay,
expect_error=expect_error
)
def test_invalid_item(self): def test_invalid_item(self):
overlay = {'name': 'invalid'} overlay = {'name': 'invalid'}

View File

@ -1,6 +1,4 @@
# -*- coding: utf-8 -*- # Copyright 2015 - Mirantis, Inc.
#
# Copyright 2013 - Mirantis, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -19,6 +17,7 @@ from oslo_log import log as logging
from mistral.db.v2 import api as db_api from mistral.db.v2 import api as db_api
from mistral.db.v2.sqlalchemy import models from mistral.db.v2.sqlalchemy import models
from mistral import exceptions as exc
from mistral.tests import base from mistral.tests import base
from mistral.workbook import parser as spec_parser from mistral.workbook import parser as spec_parser
from mistral.workflow import direct_workflow as d_wf from mistral.workflow import direct_workflow as d_wf
@ -26,52 +25,24 @@ from mistral.workflow import states
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
WB = """
---
version: '2.0'
name: my_wb
workflows:
wf:
type: direct
tasks:
task1:
action: std.echo output="Hey"
publish:
res1: <% $.task1 %>
on-complete:
- task2: <% $.res1 = 'Hey' %>
- task3: <% $.res1 = 'Not Hey' %>
task2:
action: std.echo output="Hi"
task3:
action: std.echo output="Hoy"
"""
class DirectWorkflowControllerTest(base.DbTestCase): class DirectWorkflowControllerTest(base.DbTestCase):
def setUp(self): def _prepare_test(self, wf_text):
super(DirectWorkflowControllerTest, self).setUp() wf_spec = spec_parser.get_workflow_list_spec_from_yaml(wf_text)[0]
wb_spec = spec_parser.get_workbook_spec_from_yaml(WB)
wf_ex = models.WorkflowExecution() wf_ex = models.WorkflowExecution()
wf_ex.update({ wf_ex.update({
'id': '1-2-3-4', 'id': '1-2-3-4',
'spec': wb_spec.get_workflows().get('wf').to_dict(), 'spec': wf_spec.to_dict(),
'state': states.RUNNING 'state': states.RUNNING
}) })
self.wf_ex = wf_ex self.wf_ex = wf_ex
self.wb_spec = wb_spec self.wf_spec = wf_spec
self.wf_ctrl = d_wf.DirectWorkflowController(wf_ex) self.wf_ctrl = d_wf.DirectWorkflowController(wf_ex)
def _create_task_execution(self, name, state): def _create_task_execution(self, name, state):
tasks_spec = self.wb_spec.get_workflows()['wf'].get_tasks() tasks_spec = self.wf_spec.get_tasks()
task_ex = models.TaskExecution( task_ex = models.TaskExecution(
id=self.getUniqueString('id'), id=self.getUniqueString('id'),
@ -86,6 +57,30 @@ class DirectWorkflowControllerTest(base.DbTestCase):
@mock.patch.object(db_api, 'get_task_execution') @mock.patch.object(db_api, 'get_task_execution')
def test_continue_workflow(self, get_task_execution): def test_continue_workflow(self, get_task_execution):
wf_text = """---
version: '2.0'
wf:
type: direct
tasks:
task1:
action: std.echo output="Hey"
publish:
res1: <% $.task1 %>
on-complete:
- task2: <% $.res1 = 'Hey' %>
- task3: <% $.res1 = 'Not Hey' %>
task2:
action: std.echo output="Hi"
task3:
action: std.echo output="Hoy"
"""
self._prepare_test(wf_text)
# Workflow execution is in initial step. No running tasks. # Workflow execution is in initial step. No running tasks.
cmds = self.wf_ctrl.continue_workflow() cmds = self.wf_ctrl.continue_workflow()
@ -142,3 +137,23 @@ class DirectWorkflowControllerTest(base.DbTestCase):
task2_ex.processed = True task2_ex.processed = True
self.assertEqual(0, len(cmds)) self.assertEqual(0, len(cmds))
def test_continue_workflow_no_start_tasks(self):
wf_text = """---
version: '2.0'
wf:
description: >
Invalid workflow that doesn't have start tasks (tasks with
no inbound connections).
type: direct
tasks:
task1:
on-complete: task2
task2:
on-complete: task1
"""
self.assertRaises(exc.DSLParsingException, self._prepare_test, wf_text)

View File

@ -1,4 +1,4 @@
# Copyright 2013 - Mirantis, Inc. # Copyright 2015 - Mirantis, Inc.
# Copyright 2015 - StackStorm, Inc. # Copyright 2015 - StackStorm, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -44,19 +44,94 @@ ALL = (
PARAMS_PTRN = re.compile("([-_\w]+)=(%s)" % "|".join(ALL)) PARAMS_PTRN = re.compile("([-_\w]+)=(%s)" % "|".join(ALL))
def instantiate_spec(spec_cls, data):
"""Instantiates specification accounting for specification hierarchies.
:param spec_cls: Specification concrete or base class. In case if base
class or the hierarchy is provided this method relies on attributes
_polymorphic_key and _polymorphic_value in order to find a concrete
class that needs to be instantiated.
:param data: Raw specification data as a dictionary.
"""
if issubclass(spec_cls, BaseSpecList):
# Ignore polymorphic search for specification lists because
# it doesn't make sense for them.
return spec_cls(data)
if not hasattr(spec_cls, '_polymorphic_key'):
spec = spec_cls(data)
spec.validate_semantics()
return spec
key = spec_cls._polymorphic_key
if not isinstance(key, tuple):
key_name = key
key_default = None
else:
key_name = key[0]
key_default = key[1]
for cls in utils.iter_subclasses(spec_cls):
if not hasattr(cls, '_polymorphic_value'):
raise exc.DSLParsingException(
"Class '%s' is expected to have attribute '_polymorphic_value'"
" because it's a part of specification hierarchy inherited "
"from class '%s'." % (cls, spec_cls)
)
if cls._polymorphic_value == data.get(key_name, key_default):
spec = cls(data)
spec.validate_semantics()
return cls(data)
raise exc.DSLParsingException(
'Failed to find a specification class to instantiate '
'[spec_cls=%s, data=%s]' % (spec_cls, data)
)
class BaseSpec(object): class BaseSpec(object):
"""Base class for all DSL specifications.
It represents a DSL entity such as workflow or task as a python object
providing more convenient API to analyse DSL than just working with raw
data in form of a dictionary. Specification classes also implement
all required validation logic by overriding instance method 'validate()'.
Note that the specification mechanism allows to have polymorphic entities
in DSL. For example, if we find it more convenient to have separate
specification classes for different types of workflow (i.e. 'direct' and
'reverse') we can do so. In this case, in order to instantiate them
correctly method 'instantiate_spec' must always be used where argument
'spec_cls' must be a root class of the specification hierarchy containing
class attribute '_polymorhpic_key' pointing to a key in raw data relying
on which we can find a concrete class. Concrete classes then must all have
attribute '_polymorhpic_value' corresponding to a value in a raw data.
Attribute '_polymorhpic_key' can be either a string or a tuple of size two
where the first value is a key name itself and the second value is a
default polymorphic value that must be used if raw data doesn't contain
a configured key at all. An example of this situation is when we don't
specify a workflow type in DSL. In this case, we assume it's 'direct'.
"""
# See http://json-schema.org # See http://json-schema.org
_schema = { _schema = {
"type": "object" 'type': 'object'
} }
_meta_schema = { _meta_schema = {
"type": "object" 'type': 'object'
} }
_definitions = {} _definitions = {}
_version = "1.0" _version = '1.0'
@classmethod @classmethod
def get_schema(cls, includes=['meta', 'definitions']): def get_schema(cls, includes=['meta', 'definitions']):
@ -65,32 +140,60 @@ class BaseSpec(object):
schema['properties'] = utils.merge_dicts( schema['properties'] = utils.merge_dicts(
schema.get('properties', {}), schema.get('properties', {}),
cls._meta_schema.get('properties', {}), cls._meta_schema.get('properties', {}),
overwrite=False) overwrite=False
)
if includes and 'meta' in includes: if includes and 'meta' in includes:
schema['required'] = list( schema['required'] = list(
set(schema.get('required', []) + set(schema.get('required', []) +
cls._meta_schema.get('required', []))) cls._meta_schema.get('required', []))
)
if includes and 'definitions' in includes: if includes and 'definitions' in includes:
schema['definitions'] = utils.merge_dicts( schema['definitions'] = utils.merge_dicts(
schema.get('definitions', {}), schema.get('definitions', {}),
cls._definitions, cls._definitions,
overwrite=False) overwrite=False
)
return schema return schema
def __init__(self, data): def __init__(self, data):
self._data = data self._data = data
self.validate() self.validate_schema()
def validate_schema(self):
"""Validates DSL entity schema that this specification represents.
By default, this method just validate schema of DSL entity that this
specification represents using "_schema" class attribute.
Additionally, child classes may implement additional logic to validate
more specific things like YAQL expressions in their fields.
Note that this method is called before construction of specification
fields and validation logic should only rely on raw data provided as
a dictionary accessible through '_data' instance field.
"""
def validate(self):
try: try:
jsonschema.validate(self._data, self.get_schema()) jsonschema.validate(self._data, self.get_schema())
except jsonschema.ValidationError as e: except jsonschema.ValidationError as e:
raise exc.InvalidModelException("Invalid DSL: %s" % e) raise exc.InvalidModelException("Invalid DSL: %s" % e)
def validate_semantics(self):
"""Validates semantics of specification object.
Child classes may implement validation logic to check things like
integrity of corresponding data structure (e.g. task graph) or
other things that can't be expressed in JSON schema.
This method is called after specification has been built (i.e.
its initializer has finished it's work) so that validation logic
can rely on initialized specification fields.
"""
pass
def validate_yaql_expr(self, dsl_part): def validate_yaql_expr(self, dsl_part):
if isinstance(dsl_part, six.string_types): if isinstance(dsl_part, six.string_types):
expr.validate(dsl_part) expr.validate(dsl_part)
@ -106,7 +209,7 @@ class BaseSpec(object):
def _spec_property(self, prop_name, spec_cls): def _spec_property(self, prop_name, spec_cls):
prop_val = self._data.get(prop_name) prop_val = self._data.get(prop_name)
return spec_cls(prop_val) if prop_val else None return instantiate_spec(spec_cls, prop_val) if prop_val else None
def _group_spec(self, spec_cls, *prop_names): def _group_spec(self, spec_cls, *prop_names):
if not prop_names: if not prop_names:
@ -120,7 +223,7 @@ class BaseSpec(object):
if prop_val: if prop_val:
data[prop_name] = prop_val data[prop_name] = prop_val
return spec_cls(data) return instantiate_spec(spec_cls, data)
def _inject_version(self, prop_names): def _inject_version(self, prop_names):
for prop_name in prop_names: for prop_name in prop_names:
@ -139,8 +242,10 @@ class BaseSpec(object):
return prop_val return prop_val
elif isinstance(prop_val, list): elif isinstance(prop_val, list):
result = {} result = {}
for t in prop_val: for t in prop_val:
result.update(t if isinstance(t, dict) else {t: ''}) result.update(t if isinstance(t, dict) else {t: ''})
return result return result
elif isinstance(prop_val, six.string_types): elif isinstance(prop_val, six.string_types):
return {prop_val: ''} return {prop_val: ''}
@ -172,6 +277,7 @@ class BaseSpec(object):
cmd = cmd_matcher.group() cmd = cmd_matcher.group()
params = {} params = {}
for k, v in re.findall(PARAMS_PTRN, cmd_str): for k, v in re.findall(PARAMS_PTRN, cmd_str):
# Remove embracing quotes. # Remove embracing quotes.
v = v.strip() v = v.strip()
@ -218,7 +324,7 @@ class BaseListSpec(BaseSpec):
if k != 'version': if k != 'version':
v['name'] = k v['name'] = k
self._inject_version([k]) self._inject_version([k])
self.items.append(self.item_class(v)) self.items.append(instantiate_spec(self.item_class, v))
def validate(self): def validate(self):
super(BaseListSpec, self).validate() super(BaseListSpec, self).validate()
@ -232,6 +338,12 @@ class BaseListSpec(BaseSpec):
def get_items(self): def get_items(self):
return self.items return self.items
def __getitem__(self, idx):
return self.items[idx]
def __len__(self):
return len(self.items)
class BaseSpecList(object): class BaseSpecList(object):
item_class = None item_class = None
@ -245,7 +357,7 @@ class BaseSpecList(object):
if k != 'version': if k != 'version':
v['name'] = k v['name'] = k
v['version'] = self._version v['version'] = self._version
self.items[k] = self.item_class(v) self.items[k] = instantiate_spec(self.item_class, v)
def item_keys(self): def item_keys(self):
return self.items.keys() return self.items.keys()

View File

@ -17,6 +17,7 @@ import yaml
from yaml import error from yaml import error
from mistral import exceptions as exc from mistral import exceptions as exc
from mistral.workbook import base
from mistral.workbook.v2 import actions as actions_v2 from mistral.workbook.v2 import actions as actions_v2
from mistral.workbook.v2 import tasks as tasks_v2 from mistral.workbook.v2 import tasks as tasks_v2
from mistral.workbook.v2 import workbook as wb_v2 from mistral.workbook.v2 import workbook as wb_v2
@ -61,7 +62,7 @@ def _get_spec_version(spec_dict):
def get_workbook_spec(spec_dict): def get_workbook_spec(spec_dict):
if _get_spec_version(spec_dict) == V2_0: if _get_spec_version(spec_dict) == V2_0:
return wb_v2.WorkbookSpec(spec_dict) return base.instantiate_spec(wb_v2.WorkbookSpec, spec_dict)
return None return None
@ -72,7 +73,7 @@ def get_workbook_spec_from_yaml(text):
def get_action_spec(spec_dict): def get_action_spec(spec_dict):
if _get_spec_version(spec_dict) == V2_0: if _get_spec_version(spec_dict) == V2_0:
return actions_v2.ActionSpec(spec_dict) return base.instantiate_spec(actions_v2.ActionSpec, spec_dict)
return None return None
@ -86,7 +87,7 @@ def get_action_spec_from_yaml(text, action_name):
def get_action_list_spec(spec_dict): def get_action_list_spec(spec_dict):
return actions_v2.ActionListSpec(spec_dict) return base.instantiate_spec(actions_v2.ActionListSpec, spec_dict)
def get_action_list_spec_from_yaml(text): def get_action_list_spec_from_yaml(text):
@ -95,13 +96,13 @@ def get_action_list_spec_from_yaml(text):
def get_workflow_spec(spec_dict): def get_workflow_spec(spec_dict):
if _get_spec_version(spec_dict) == V2_0: if _get_spec_version(spec_dict) == V2_0:
return wf_v2.WorkflowSpec(spec_dict) return base.instantiate_spec(wf_v2.WorkflowSpec, spec_dict)
return None return None
def get_workflow_list_spec(spec_dict): def get_workflow_list_spec(spec_dict):
return wf_v2.WorkflowListSpec(spec_dict) return base.instantiate_spec(wf_v2.WorkflowListSpec, spec_dict)
def get_workflow_spec_from_yaml(text): def get_workflow_spec_from_yaml(text):

View File

@ -49,8 +49,8 @@ class ActionSpec(base.BaseSpec):
utils.merge_dicts(self._base_input, _input) utils.merge_dicts(self._base_input, _input)
def validate(self): def validate_schema(self):
super(ActionSpec, self).validate() super(ActionSpec, self).validate_schema()
# Validate YAQL expressions. # Validate YAQL expressions.
inline_params = self._parse_cmd_and_input(self._data.get('base'))[1] inline_params = self._parse_cmd_and_input(self._data.get('base'))[1]

View File

@ -55,8 +55,8 @@ class PoliciesSpec(base.BaseSpec):
self._pause_before = data.get('pause-before', False) self._pause_before = data.get('pause-before', False)
self._concurrency = data.get('concurrency', 0) self._concurrency = data.get('concurrency', 0)
def validate(self): def validate_schema(self):
super(PoliciesSpec, self).validate() super(PoliciesSpec, self).validate_schema()
# Validate YAQL expressions. # Validate YAQL expressions.
self.validate_yaql_expr(self._data.get('wait-before', 0)) self.validate_yaql_expr(self._data.get('wait-before', 0))

View File

@ -70,8 +70,8 @@ class RetrySpec(base.BaseSpec):
return retry return retry
def validate(self): def validate_schema(self):
super(RetrySpec, self).validate() super(RetrySpec, self).validate_schema()
# Validate YAQL expressions. # Validate YAQL expressions.
self.validate_yaql_expr(self._data.get('count')) self.validate_yaql_expr(self._data.get('count'))

View File

@ -72,8 +72,8 @@ class TaskDefaultsSpec(base.BaseSpec):
self._on_error = self._as_list_of_tuples("on-error") self._on_error = self._as_list_of_tuples("on-error")
self._requires = data.get('requires', []) self._requires = data.get('requires', [])
def validate(self): def validate_schema(self):
super(TaskDefaultsSpec, self).validate() super(TaskDefaultsSpec, self).validate_schema()
# Validate YAQL expressions. # Validate YAQL expressions.
self._validate_transitions('on-complete') self._validate_transitions('on-complete')

View File

@ -86,8 +86,8 @@ class TaskSpec(base.BaseSpec):
self._inject_type() self._inject_type()
self._process_action_and_workflow() self._process_action_and_workflow()
def validate(self): def validate_schema(self):
super(TaskSpec, self).validate() super(TaskSpec, self).validate_schema()
action = self._data.get('action') action = self._data.get('action')
workflow = self._data.get('workflow') workflow = self._data.get('workflow')
@ -234,8 +234,8 @@ class DirectWorkflowTaskSpec(TaskSpec):
self._on_success = self._as_list_of_tuples('on-success') self._on_success = self._as_list_of_tuples('on-success')
self._on_error = self._as_list_of_tuples('on-error') self._on_error = self._as_list_of_tuples('on-error')
def validate(self): def validate_schema(self):
super(DirectWorkflowTaskSpec, self).validate() super(DirectWorkflowTaskSpec, self).validate_schema()
if 'join' in self._data: if 'join' in self._data:
join = self._data.get('join') join = self._data.get('join')

View File

@ -1,4 +1,4 @@
# Copyright 2014 - Mirantis, Inc. # Copyright 2015 - Mirantis, Inc.
# Copyright 2015 - StackStorm, Inc. # Copyright 2015 - StackStorm, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -24,35 +24,19 @@ from mistral.workbook.v2 import tasks
class WorkflowSpec(base.BaseSpec): class WorkflowSpec(base.BaseSpec):
# See http://json-schema.org # See http://json-schema.org
_direct_task_schema = tasks.DirectWorkflowTaskSpec.get_schema( _polymorphic_key = ('type', 'direct')
includes=None)
_reverse_task_schema = tasks.ReverseWorkflowTaskSpec.get_schema(
includes=None)
_task_defaults_schema = task_defaults.TaskDefaultsSpec.get_schema( _task_defaults_schema = task_defaults.TaskDefaultsSpec.get_schema(
includes=None) includes=None)
_schema = { _meta_schema = {
"type": "object", "type": "object",
"properties": { "properties": {
"type": types.WORKFLOW_TYPE, "type": types.WORKFLOW_TYPE,
"task-defaults": _task_defaults_schema, "task-defaults": _task_defaults_schema,
"input": types.UNIQUE_STRING_OR_ONE_KEY_DICT_LIST, "input": types.UNIQUE_STRING_OR_ONE_KEY_DICT_LIST,
"output": types.NONEMPTY_DICT, "output": types.NONEMPTY_DICT,
"vars": types.NONEMPTY_DICT, "vars": types.NONEMPTY_DICT
"tasks": {
"type": "object",
"minProperties": 1,
"patternProperties": {
"^\w+$": {
"anyOf": [
_direct_task_schema,
_reverse_task_schema
]
}
}
},
}, },
"required": ["tasks"], "required": ["tasks"],
"additionalProperties": False "additionalProperties": False
@ -64,7 +48,7 @@ class WorkflowSpec(base.BaseSpec):
self._name = data['name'] self._name = data['name']
self._description = data.get('description') self._description = data.get('description')
self._tags = data.get('tags', []) self._tags = data.get('tags', [])
self._type = data['type'] if 'type' in data else "direct" self._type = data['type'] if 'type' in data else 'direct'
self._input = utils.get_input_dict(data.get('input', [])) self._input = utils.get_input_dict(data.get('input', []))
self._output = data.get('output', {}) self._output = data.get('output', {})
self._vars = data.get('vars', {}) self._vars = data.get('vars', {})
@ -78,8 +62,8 @@ class WorkflowSpec(base.BaseSpec):
tasks.TaskSpecList.get_class(self._type) tasks.TaskSpecList.get_class(self._type)
) )
def validate(self): def validate_schema(self):
super(WorkflowSpec, self).validate() super(WorkflowSpec, self).validate_schema()
if not self._data.get('tasks'): if not self._data.get('tasks'):
raise exc.InvalidModelException( raise exc.InvalidModelException(
@ -90,6 +74,10 @@ class WorkflowSpec(base.BaseSpec):
self.validate_yaql_expr(self._data.get('output', {})) self.validate_yaql_expr(self._data.get('output', {}))
self.validate_yaql_expr(self._data.get('vars', {})) self.validate_yaql_expr(self._data.get('vars', {}))
def validate_semantics(self):
# Doesn't do anything by default.
pass
def get_name(self): def get_name(self):
return self._name return self._name
@ -118,6 +106,133 @@ class WorkflowSpec(base.BaseSpec):
return self._tasks return self._tasks
class DirectWorkflowSpec(WorkflowSpec):
_polymorphic_value = 'direct'
_schema = {
"properties": {
"tasks": {
"type": "object",
"minProperties": 1,
"patternProperties": {
"^\w+$":
tasks.DirectWorkflowTaskSpec.get_schema(includes=None)
}
},
}
}
def validate_semantics(self):
# Check if there are start tasks.
if not self.find_start_tasks():
raise exc.DSLParsingException(
'Failed to find start tasks in direct workflow. '
'There must be at least one task without inbound transition.'
'[workflow_name=%s]' % self._name
)
def find_start_tasks(self):
return [
t_s for t_s in self.get_tasks()
if not self.has_inbound_transitions(t_s)
]
def find_inbound_task_specs(self, task_spec):
return [
t_s for t_s in self.get_tasks()
if self.transition_exists(t_s.get_name(), task_spec.get_name())
]
def find_outbound_task_specs(self, task_spec):
return [
t_s for t_s in self.get_tasks()
if self.transition_exists(task_spec.get_name(), t_s.get_name())
]
def has_inbound_transitions(self, task_spec):
return len(self.find_inbound_task_specs(task_spec)) > 0
def has_outbound_transitions(self, task_spec):
return len(self.find_outbound_task_specs(task_spec)) > 0
def transition_exists(self, from_task_name, to_task_name):
t_names = set()
for tup in self.get_on_error_clause(from_task_name):
t_names.add(tup[0])
for tup in self.get_on_success_clause(from_task_name):
t_names.add(tup[0])
for tup in self.get_on_complete_clause(from_task_name):
t_names.add(tup[0])
return to_task_name in t_names
def get_on_error_clause(self, t_name):
result = self.get_tasks()[t_name].get_on_error()
if not result:
t_defaults = self.get_task_defaults()
if t_defaults:
result = self._remove_task_from_clause(
t_defaults.get_on_error(),
t_name
)
return result
def get_on_success_clause(self, t_name):
result = self.get_tasks()[t_name].get_on_success()
if not result:
t_defaults = self.get_task_defaults()
if t_defaults:
result = self._remove_task_from_clause(
t_defaults.get_on_success(),
t_name
)
return result
def get_on_complete_clause(self, t_name):
result = self.get_tasks()[t_name].get_on_complete()
if not result:
t_defaults = self.get_task_defaults()
if t_defaults:
result = self._remove_task_from_clause(
t_defaults.get_on_complete(),
t_name
)
return result
@staticmethod
def _remove_task_from_clause(on_clause, t_name):
return filter(lambda tup: tup[0] != t_name, on_clause)
class ReverseWorkflowSpec(WorkflowSpec):
_polymorphic_value = 'reverse'
_schema = {
"properties": {
"tasks": {
"type": "object",
"minProperties": 1,
"patternProperties": {
"^\w+$":
tasks.ReverseWorkflowTaskSpec.get_schema(includes=None)
}
},
}
}
class WorkflowSpecList(base.BaseSpecList): class WorkflowSpecList(base.BaseSpecList):
item_class = WorkflowSpec item_class = WorkflowSpec

View File

@ -154,14 +154,13 @@ class WorkflowController(object):
if wf_type == wf_ctrl_cls.__workflow_type__: if wf_type == wf_ctrl_cls.__workflow_type__:
return wf_ctrl_cls return wf_ctrl_cls
msg = 'Failed to find a workflow controller [type=%s]' % wf_type raise exc.NotFoundException(
raise exc.NotFoundException(msg) 'Failed to find a workflow controller [type=%s]' % wf_type
)
@staticmethod @staticmethod
def get_controller(wf_ex, wf_spec=None): def get_controller(wf_ex, wf_spec=None):
if not wf_spec: if not wf_spec:
wf_spec = spec_parser.get_workflow_spec(wf_ex['spec']) wf_spec = spec_parser.get_workflow_spec(wf_ex['spec'])
ctrl_cls = WorkflowController._get_class(wf_spec.get_type()) return WorkflowController._get_class(wf_spec.get_type())(wf_ex)
return ctrl_cls(wf_ex)

View File

@ -1,4 +1,4 @@
# Copyright 2014 - Mirantis, Inc. # Copyright 2015 - Mirantis, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -47,7 +47,7 @@ class DirectWorkflowController(base.WorkflowController):
lambda t_e: self._is_upstream_task_execution(task_spec, t_e), lambda t_e: self._is_upstream_task_execution(task_spec, t_e),
wf_utils.find_task_executions_by_specs( wf_utils.find_task_executions_by_specs(
self.wf_ex, self.wf_ex,
self._find_inbound_task_specs(task_spec) self.wf_spec.find_inbound_task_specs(task_spec)
) )
) )
@ -80,19 +80,13 @@ class DirectWorkflowController(base.WorkflowController):
return cmds return cmds
def _find_start_commands(self): def _find_start_commands(self):
t_specs = []
for t_s in self.wf_spec.get_tasks():
if not self._has_inbound_transitions(t_s):
t_specs.append(t_s)
return [ return [
commands.RunTask( commands.RunTask(
self.wf_ex, self.wf_ex,
t_s, t_s,
self._get_task_inbound_context(t_s) self._get_task_inbound_context(t_s)
) )
for t_s in t_specs for t_s in self.wf_spec.find_start_tasks()
] ]
def _find_next_commands_for_task(self, task_ex): def _find_next_commands_for_task(self, task_ex):
@ -121,7 +115,7 @@ class DirectWorkflowController(base.WorkflowController):
) )
# NOTE(xylan): Decide whether or not a join task should run # NOTE(xylan): Decide whether or not a join task should run
# immediately # immediately.
if self._is_unsatisfied_join(cmd): if self._is_unsatisfied_join(cmd):
cmd.wait_flag = True cmd.wait_flag = True
@ -136,35 +130,6 @@ class DirectWorkflowController(base.WorkflowController):
return cmds return cmds
def _has_inbound_transitions(self, task_spec):
return len(self._find_inbound_task_specs(task_spec)) > 0
def _find_inbound_task_specs(self, task_spec):
return [
t_s for t_s in self.wf_spec.get_tasks()
if self._transition_exists(t_s.get_name(), task_spec.get_name())
]
def _find_outbound_task_specs(self, task_spec):
return [
t_s for t_s in self.wf_spec.get_tasks()
if self._transition_exists(task_spec.get_name(), t_s.get_name())
]
def _transition_exists(self, from_task_name, to_task_name):
t_names = set()
for tup in self.get_on_error_clause(from_task_name):
t_names.add(tup[0])
for tup in self.get_on_success_clause(from_task_name):
t_names.add(tup[0])
for tup in self.get_on_complete_clause(from_task_name):
t_names.add(tup[0])
return to_task_name in t_names
# TODO(rakhmerov): Need to refactor this method to be able to pass tasks # TODO(rakhmerov): Need to refactor this method to be able to pass tasks
# whose contexts need to be merged. # whose contexts need to be merged.
def evaluate_workflow_final_context(self): def evaluate_workflow_final_context(self):
@ -179,11 +144,11 @@ class DirectWorkflowController(base.WorkflowController):
return ctx return ctx
def is_error_handled_for(self, task_ex): def is_error_handled_for(self, task_ex):
return bool(self.get_on_error_clause(task_ex.name)) return bool(self.wf_spec.get_on_error_clause(task_ex.name))
def all_errors_handled(self): def all_errors_handled(self):
for t_ex in wf_utils.find_error_task_executions(self.wf_ex): for t_ex in wf_utils.find_error_task_executions(self.wf_ex):
if not self.get_on_error_clause(t_ex.name): if not self.wf_spec.get_on_error_clause(t_ex.name):
return False return False
return True return True
@ -204,52 +169,6 @@ class DirectWorkflowController(base.WorkflowController):
if self.wf_spec.get_tasks()[t_name] if self.wf_spec.get_tasks()[t_name]
]) ])
@staticmethod
def _remove_task_from_clause(on_clause, t_name):
return filter(lambda tup: tup[0] != t_name, on_clause)
def get_on_error_clause(self, t_name):
result = self.wf_spec.get_tasks()[t_name].get_on_error()
if not result:
task_defaults = self.wf_spec.get_task_defaults()
if task_defaults:
result = self._remove_task_from_clause(
task_defaults.get_on_error(),
t_name
)
return result
def get_on_success_clause(self, t_name):
result = self.wf_spec.get_tasks()[t_name].get_on_success()
if not result:
task_defaults = self.wf_spec.get_task_defaults()
if task_defaults:
result = self._remove_task_from_clause(
task_defaults.get_on_success(),
t_name
)
return result
def get_on_complete_clause(self, t_name):
result = self.wf_spec.get_tasks()[t_name].get_on_complete()
if not result:
task_defaults = self.wf_spec.get_task_defaults()
if task_defaults:
result = self._remove_task_from_clause(
task_defaults.get_on_complete(),
t_name
)
return result
def _find_next_task_names(self, task_ex): def _find_next_task_names(self, task_ex):
t_state = task_ex.state t_state = task_ex.state
t_name = task_ex.name t_name = task_ex.name
@ -260,19 +179,19 @@ class DirectWorkflowController(base.WorkflowController):
if states.is_completed(t_state): if states.is_completed(t_state):
t_names += self._find_next_task_names_for_clause( t_names += self._find_next_task_names_for_clause(
self.get_on_complete_clause(t_name), self.wf_spec.get_on_complete_clause(t_name),
ctx ctx
) )
if t_state == states.ERROR: if t_state == states.ERROR:
t_names += self._find_next_task_names_for_clause( t_names += self._find_next_task_names_for_clause(
self.get_on_error_clause(t_name), self.wf_spec.get_on_error_clause(t_name),
ctx ctx
) )
elif t_state == states.SUCCESS: elif t_state == states.SUCCESS:
t_names += self._find_next_task_names_for_clause( t_names += self._find_next_task_names_for_clause(
self.get_on_success_clause(t_name), self.wf_spec.get_on_success_clause(t_name),
ctx ctx
) )
@ -323,7 +242,7 @@ class DirectWorkflowController(base.WorkflowController):
if not join_expr: if not join_expr:
return False return False
in_task_specs = self._find_inbound_task_specs(task_spec) in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)
if not in_task_specs: if not in_task_specs:
return False return False