diff --git a/mistral/api/controllers/v2/validation.py b/mistral/api/controllers/v2/validation.py index 380aa0990..e97b9752b 100644 --- a/mistral/api/controllers/v2/validation.py +++ b/mistral/api/controllers/v2/validation.py @@ -26,6 +26,7 @@ class SpecValidationController(rest.RestController): def __init__(self, parser): super(SpecValidationController, self).__init__() + self._parse_func = parser @pecan.expose('json') diff --git a/mistral/tests/unit/workbook/v2/base.py b/mistral/tests/unit/workbook/v2/base.py index 652227977..937ba04c1 100644 --- a/mistral/tests/unit/workbook/v2/base.py +++ b/mistral/tests/unit/workbook/v2/base.py @@ -96,9 +96,11 @@ class WorkflowSpecValidationTestCase(base.BaseTest): if not expect_error: return self._spec_parser(dsl_yaml) else: - return self.assertRaises(exc.DSLParsingException, - self._spec_parser, - dsl_yaml) + return self.assertRaises( + exc.DSLParsingException, + self._spec_parser, + dsl_yaml + ) class WorkbookSpecValidationTestCase(WorkflowSpecValidationTestCase): diff --git a/mistral/tests/unit/workbook/v2/test_workflows.py b/mistral/tests/unit/workbook/v2/test_workflows.py index e3d03b017..994f0d95c 100644 --- a/mistral/tests/unit/workbook/v2/test_workflows.py +++ b/mistral/tests/unit/workbook/v2/test_workflows.py @@ -22,12 +22,10 @@ from mistral.tests.unit.workbook.v2 import base from mistral import utils from mistral.workbook.v2 import tasks - LOG = logging.getLogger(__name__) class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase): - def test_workflow_types(self): tests = [ ({'type': 'direct'}, False), @@ -62,7 +60,12 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase): tasks.DirectWfTaskSpecList) def test_direct_workflow_invalid_task(self): - overlay = {'test': {'type': 'direct', 'tasks': {}}} + overlay = { + 'test': { + 'type': 'direct', + 'tasks': {} + } + } requires = {'requires': ['echo', 'get']} utils.merge_dicts(overlay['test']['tasks'], {'email': requires}) @@ -71,6 +74,21 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase): changes=overlay, expect_error=True) + def test_direct_workflow_no_start_tasks(self): + overlay = { + 'test': { + 'type': 'direct', + 'tasks': { + 'task1': {'on-complete': 'task2'}, + 'task2': {'on-complete': 'task1'} + } + } + } + + self._parse_dsl_spec(add_tasks=False, + changes=overlay, + expect_error=True) + def test_reverse_workflow(self): overlay = {'test': {'type': 'reverse', 'tasks': {}}} require = {'requires': ['echo', 'get']} @@ -142,9 +160,12 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase): for wf_input, expect_error in tests: overlay = {'test': wf_input} - self._parse_dsl_spec(add_tasks=True, - changes=overlay, - expect_error=expect_error) + + self._parse_dsl_spec( + add_tasks=True, + changes=overlay, + expect_error=expect_error + ) def test_outputs(self): tests = [ @@ -160,9 +181,12 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase): for wf_output, expect_error in tests: overlay = {'test': wf_output} - self._parse_dsl_spec(add_tasks=True, - changes=overlay, - expect_error=expect_error) + + self._parse_dsl_spec( + add_tasks=True, + changes=overlay, + expect_error=expect_error + ) def test_vars(self): tests = [ @@ -178,13 +202,18 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase): for wf_vars, expect_error in tests: overlay = {'test': wf_vars} - self._parse_dsl_spec(add_tasks=True, - changes=overlay, - expect_error=expect_error) + + self._parse_dsl_spec( + add_tasks=True, + changes=overlay, + expect_error=expect_error + ) def test_tasks_required(self): - exception = self._parse_dsl_spec(add_tasks=False, - expect_error=True) + exception = self._parse_dsl_spec( + add_tasks=False, + expect_error=True + ) self.assertIn("'tasks' is a required property", exception.message) @@ -197,9 +226,12 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase): for wf_tasks, expect_error in tests: overlay = {'test': wf_tasks} - self._parse_dsl_spec(add_tasks=False, - changes=overlay, - expect_error=expect_error) + + self._parse_dsl_spec( + add_tasks=False, + changes=overlay, + expect_error=expect_error + ) def test_task_defaults(self): tests = [ @@ -289,9 +321,11 @@ class WorkflowSpecValidation(base.WorkflowSpecValidationTestCase): utils.merge_dicts(overlay['test']['task-defaults'], default) - self._parse_dsl_spec(add_tasks=True, - changes=overlay, - expect_error=expect_error) + self._parse_dsl_spec( + add_tasks=True, + changes=overlay, + expect_error=expect_error + ) def test_invalid_item(self): overlay = {'name': 'invalid'} diff --git a/mistral/tests/unit/workflow/test_direct_workflow.py b/mistral/tests/unit/workflow/test_direct_workflow.py index 4f26f3fef..cfd608e9f 100644 --- a/mistral/tests/unit/workflow/test_direct_workflow.py +++ b/mistral/tests/unit/workflow/test_direct_workflow.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- -# -# Copyright 2013 - Mirantis, Inc. +# Copyright 2015 - Mirantis, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +17,7 @@ from oslo_log import log as logging from mistral.db.v2 import api as db_api from mistral.db.v2.sqlalchemy import models +from mistral import exceptions as exc from mistral.tests import base from mistral.workbook import parser as spec_parser from mistral.workflow import direct_workflow as d_wf @@ -26,52 +25,24 @@ from mistral.workflow import states LOG = logging.getLogger(__name__) -WB = """ ---- -version: '2.0' - -name: my_wb - -workflows: - wf: - type: direct - - tasks: - task1: - action: std.echo output="Hey" - publish: - res1: <% $.task1 %> - on-complete: - - task2: <% $.res1 = 'Hey' %> - - task3: <% $.res1 = 'Not Hey' %> - - task2: - action: std.echo output="Hi" - - task3: - action: std.echo output="Hoy" -""" - class DirectWorkflowControllerTest(base.DbTestCase): - def setUp(self): - super(DirectWorkflowControllerTest, self).setUp() - - wb_spec = spec_parser.get_workbook_spec_from_yaml(WB) + def _prepare_test(self, wf_text): + wf_spec = spec_parser.get_workflow_list_spec_from_yaml(wf_text)[0] wf_ex = models.WorkflowExecution() wf_ex.update({ 'id': '1-2-3-4', - 'spec': wb_spec.get_workflows().get('wf').to_dict(), + 'spec': wf_spec.to_dict(), 'state': states.RUNNING }) self.wf_ex = wf_ex - self.wb_spec = wb_spec + self.wf_spec = wf_spec self.wf_ctrl = d_wf.DirectWorkflowController(wf_ex) def _create_task_execution(self, name, state): - tasks_spec = self.wb_spec.get_workflows()['wf'].get_tasks() + tasks_spec = self.wf_spec.get_tasks() task_ex = models.TaskExecution( id=self.getUniqueString('id'), @@ -86,6 +57,30 @@ class DirectWorkflowControllerTest(base.DbTestCase): @mock.patch.object(db_api, 'get_task_execution') def test_continue_workflow(self, get_task_execution): + wf_text = """--- + version: '2.0' + + wf: + type: direct + + tasks: + task1: + action: std.echo output="Hey" + publish: + res1: <% $.task1 %> + on-complete: + - task2: <% $.res1 = 'Hey' %> + - task3: <% $.res1 = 'Not Hey' %> + + task2: + action: std.echo output="Hi" + + task3: + action: std.echo output="Hoy" + """ + + self._prepare_test(wf_text) + # Workflow execution is in initial step. No running tasks. cmds = self.wf_ctrl.continue_workflow() @@ -142,3 +137,23 @@ class DirectWorkflowControllerTest(base.DbTestCase): task2_ex.processed = True self.assertEqual(0, len(cmds)) + + def test_continue_workflow_no_start_tasks(self): + wf_text = """--- + version: '2.0' + + wf: + description: > + Invalid workflow that doesn't have start tasks (tasks with + no inbound connections). + type: direct + + tasks: + task1: + on-complete: task2 + + task2: + on-complete: task1 + """ + + self.assertRaises(exc.DSLParsingException, self._prepare_test, wf_text) diff --git a/mistral/workbook/base.py b/mistral/workbook/base.py index 651ed6849..9a0fe6ab0 100644 --- a/mistral/workbook/base.py +++ b/mistral/workbook/base.py @@ -1,4 +1,4 @@ -# Copyright 2013 - Mirantis, Inc. +# Copyright 2015 - Mirantis, Inc. # Copyright 2015 - StackStorm, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -44,19 +44,94 @@ ALL = ( PARAMS_PTRN = re.compile("([-_\w]+)=(%s)" % "|".join(ALL)) +def instantiate_spec(spec_cls, data): + """Instantiates specification accounting for specification hierarchies. + + :param spec_cls: Specification concrete or base class. In case if base + class or the hierarchy is provided this method relies on attributes + _polymorphic_key and _polymorphic_value in order to find a concrete + class that needs to be instantiated. + :param data: Raw specification data as a dictionary. + """ + + if issubclass(spec_cls, BaseSpecList): + # Ignore polymorphic search for specification lists because + # it doesn't make sense for them. + return spec_cls(data) + + if not hasattr(spec_cls, '_polymorphic_key'): + spec = spec_cls(data) + + spec.validate_semantics() + + return spec + + key = spec_cls._polymorphic_key + + if not isinstance(key, tuple): + key_name = key + key_default = None + else: + key_name = key[0] + key_default = key[1] + + for cls in utils.iter_subclasses(spec_cls): + if not hasattr(cls, '_polymorphic_value'): + raise exc.DSLParsingException( + "Class '%s' is expected to have attribute '_polymorphic_value'" + " because it's a part of specification hierarchy inherited " + "from class '%s'." % (cls, spec_cls) + ) + + if cls._polymorphic_value == data.get(key_name, key_default): + spec = cls(data) + + spec.validate_semantics() + + return cls(data) + + raise exc.DSLParsingException( + 'Failed to find a specification class to instantiate ' + '[spec_cls=%s, data=%s]' % (spec_cls, data) + ) + + class BaseSpec(object): + """Base class for all DSL specifications. + + It represents a DSL entity such as workflow or task as a python object + providing more convenient API to analyse DSL than just working with raw + data in form of a dictionary. Specification classes also implement + all required validation logic by overriding instance method 'validate()'. + + Note that the specification mechanism allows to have polymorphic entities + in DSL. For example, if we find it more convenient to have separate + specification classes for different types of workflow (i.e. 'direct' and + 'reverse') we can do so. In this case, in order to instantiate them + correctly method 'instantiate_spec' must always be used where argument + 'spec_cls' must be a root class of the specification hierarchy containing + class attribute '_polymorhpic_key' pointing to a key in raw data relying + on which we can find a concrete class. Concrete classes then must all have + attribute '_polymorhpic_value' corresponding to a value in a raw data. + Attribute '_polymorhpic_key' can be either a string or a tuple of size two + where the first value is a key name itself and the second value is a + default polymorphic value that must be used if raw data doesn't contain + a configured key at all. An example of this situation is when we don't + specify a workflow type in DSL. In this case, we assume it's 'direct'. + """ + # See http://json-schema.org _schema = { - "type": "object" + 'type': 'object' } _meta_schema = { - "type": "object" + 'type': 'object' } _definitions = {} - _version = "1.0" + _version = '1.0' @classmethod def get_schema(cls, includes=['meta', 'definitions']): @@ -65,32 +140,60 @@ class BaseSpec(object): schema['properties'] = utils.merge_dicts( schema.get('properties', {}), cls._meta_schema.get('properties', {}), - overwrite=False) + overwrite=False + ) if includes and 'meta' in includes: schema['required'] = list( set(schema.get('required', []) + - cls._meta_schema.get('required', []))) + cls._meta_schema.get('required', [])) + ) if includes and 'definitions' in includes: schema['definitions'] = utils.merge_dicts( schema.get('definitions', {}), cls._definitions, - overwrite=False) + overwrite=False + ) return schema def __init__(self, data): self._data = data - self.validate() + self.validate_schema() + + def validate_schema(self): + """Validates DSL entity schema that this specification represents. + + By default, this method just validate schema of DSL entity that this + specification represents using "_schema" class attribute. + Additionally, child classes may implement additional logic to validate + more specific things like YAQL expressions in their fields. + + Note that this method is called before construction of specification + fields and validation logic should only rely on raw data provided as + a dictionary accessible through '_data' instance field. + """ - def validate(self): try: jsonschema.validate(self._data, self.get_schema()) except jsonschema.ValidationError as e: raise exc.InvalidModelException("Invalid DSL: %s" % e) + def validate_semantics(self): + """Validates semantics of specification object. + + Child classes may implement validation logic to check things like + integrity of corresponding data structure (e.g. task graph) or + other things that can't be expressed in JSON schema. + + This method is called after specification has been built (i.e. + its initializer has finished it's work) so that validation logic + can rely on initialized specification fields. + """ + pass + def validate_yaql_expr(self, dsl_part): if isinstance(dsl_part, six.string_types): expr.validate(dsl_part) @@ -106,7 +209,7 @@ class BaseSpec(object): def _spec_property(self, prop_name, spec_cls): prop_val = self._data.get(prop_name) - return spec_cls(prop_val) if prop_val else None + return instantiate_spec(spec_cls, prop_val) if prop_val else None def _group_spec(self, spec_cls, *prop_names): if not prop_names: @@ -120,7 +223,7 @@ class BaseSpec(object): if prop_val: data[prop_name] = prop_val - return spec_cls(data) + return instantiate_spec(spec_cls, data) def _inject_version(self, prop_names): for prop_name in prop_names: @@ -139,8 +242,10 @@ class BaseSpec(object): return prop_val elif isinstance(prop_val, list): result = {} + for t in prop_val: result.update(t if isinstance(t, dict) else {t: ''}) + return result elif isinstance(prop_val, six.string_types): return {prop_val: ''} @@ -172,6 +277,7 @@ class BaseSpec(object): cmd = cmd_matcher.group() params = {} + for k, v in re.findall(PARAMS_PTRN, cmd_str): # Remove embracing quotes. v = v.strip() @@ -218,7 +324,7 @@ class BaseListSpec(BaseSpec): if k != 'version': v['name'] = k self._inject_version([k]) - self.items.append(self.item_class(v)) + self.items.append(instantiate_spec(self.item_class, v)) def validate(self): super(BaseListSpec, self).validate() @@ -232,6 +338,12 @@ class BaseListSpec(BaseSpec): def get_items(self): return self.items + def __getitem__(self, idx): + return self.items[idx] + + def __len__(self): + return len(self.items) + class BaseSpecList(object): item_class = None @@ -245,7 +357,7 @@ class BaseSpecList(object): if k != 'version': v['name'] = k v['version'] = self._version - self.items[k] = self.item_class(v) + self.items[k] = instantiate_spec(self.item_class, v) def item_keys(self): return self.items.keys() diff --git a/mistral/workbook/parser.py b/mistral/workbook/parser.py index 2898eb1dc..983686f24 100644 --- a/mistral/workbook/parser.py +++ b/mistral/workbook/parser.py @@ -17,6 +17,7 @@ import yaml from yaml import error from mistral import exceptions as exc +from mistral.workbook import base from mistral.workbook.v2 import actions as actions_v2 from mistral.workbook.v2 import tasks as tasks_v2 from mistral.workbook.v2 import workbook as wb_v2 @@ -61,7 +62,7 @@ def _get_spec_version(spec_dict): def get_workbook_spec(spec_dict): if _get_spec_version(spec_dict) == V2_0: - return wb_v2.WorkbookSpec(spec_dict) + return base.instantiate_spec(wb_v2.WorkbookSpec, spec_dict) return None @@ -72,7 +73,7 @@ def get_workbook_spec_from_yaml(text): def get_action_spec(spec_dict): if _get_spec_version(spec_dict) == V2_0: - return actions_v2.ActionSpec(spec_dict) + return base.instantiate_spec(actions_v2.ActionSpec, spec_dict) return None @@ -86,7 +87,7 @@ def get_action_spec_from_yaml(text, action_name): def get_action_list_spec(spec_dict): - return actions_v2.ActionListSpec(spec_dict) + return base.instantiate_spec(actions_v2.ActionListSpec, spec_dict) def get_action_list_spec_from_yaml(text): @@ -95,13 +96,13 @@ def get_action_list_spec_from_yaml(text): def get_workflow_spec(spec_dict): if _get_spec_version(spec_dict) == V2_0: - return wf_v2.WorkflowSpec(spec_dict) + return base.instantiate_spec(wf_v2.WorkflowSpec, spec_dict) return None def get_workflow_list_spec(spec_dict): - return wf_v2.WorkflowListSpec(spec_dict) + return base.instantiate_spec(wf_v2.WorkflowListSpec, spec_dict) def get_workflow_spec_from_yaml(text): diff --git a/mistral/workbook/v2/actions.py b/mistral/workbook/v2/actions.py index 856789688..b97acb695 100644 --- a/mistral/workbook/v2/actions.py +++ b/mistral/workbook/v2/actions.py @@ -49,8 +49,8 @@ class ActionSpec(base.BaseSpec): utils.merge_dicts(self._base_input, _input) - def validate(self): - super(ActionSpec, self).validate() + def validate_schema(self): + super(ActionSpec, self).validate_schema() # Validate YAQL expressions. inline_params = self._parse_cmd_and_input(self._data.get('base'))[1] diff --git a/mistral/workbook/v2/policies.py b/mistral/workbook/v2/policies.py index 66716af89..e817e43eb 100644 --- a/mistral/workbook/v2/policies.py +++ b/mistral/workbook/v2/policies.py @@ -55,8 +55,8 @@ class PoliciesSpec(base.BaseSpec): self._pause_before = data.get('pause-before', False) self._concurrency = data.get('concurrency', 0) - def validate(self): - super(PoliciesSpec, self).validate() + def validate_schema(self): + super(PoliciesSpec, self).validate_schema() # Validate YAQL expressions. self.validate_yaql_expr(self._data.get('wait-before', 0)) diff --git a/mistral/workbook/v2/retry_policy.py b/mistral/workbook/v2/retry_policy.py index aa363a9a3..aa52f073b 100644 --- a/mistral/workbook/v2/retry_policy.py +++ b/mistral/workbook/v2/retry_policy.py @@ -70,8 +70,8 @@ class RetrySpec(base.BaseSpec): return retry - def validate(self): - super(RetrySpec, self).validate() + def validate_schema(self): + super(RetrySpec, self).validate_schema() # Validate YAQL expressions. self.validate_yaql_expr(self._data.get('count')) diff --git a/mistral/workbook/v2/task_defaults.py b/mistral/workbook/v2/task_defaults.py index 58a053881..86e06329a 100644 --- a/mistral/workbook/v2/task_defaults.py +++ b/mistral/workbook/v2/task_defaults.py @@ -72,8 +72,8 @@ class TaskDefaultsSpec(base.BaseSpec): self._on_error = self._as_list_of_tuples("on-error") self._requires = data.get('requires', []) - def validate(self): - super(TaskDefaultsSpec, self).validate() + def validate_schema(self): + super(TaskDefaultsSpec, self).validate_schema() # Validate YAQL expressions. self._validate_transitions('on-complete') diff --git a/mistral/workbook/v2/tasks.py b/mistral/workbook/v2/tasks.py index e7930619e..5ce81d01b 100644 --- a/mistral/workbook/v2/tasks.py +++ b/mistral/workbook/v2/tasks.py @@ -86,8 +86,8 @@ class TaskSpec(base.BaseSpec): self._inject_type() self._process_action_and_workflow() - def validate(self): - super(TaskSpec, self).validate() + def validate_schema(self): + super(TaskSpec, self).validate_schema() action = self._data.get('action') workflow = self._data.get('workflow') @@ -234,8 +234,8 @@ class DirectWorkflowTaskSpec(TaskSpec): self._on_success = self._as_list_of_tuples('on-success') self._on_error = self._as_list_of_tuples('on-error') - def validate(self): - super(DirectWorkflowTaskSpec, self).validate() + def validate_schema(self): + super(DirectWorkflowTaskSpec, self).validate_schema() if 'join' in self._data: join = self._data.get('join') diff --git a/mistral/workbook/v2/workflows.py b/mistral/workbook/v2/workflows.py index 6e2ff8a65..9bcfce471 100644 --- a/mistral/workbook/v2/workflows.py +++ b/mistral/workbook/v2/workflows.py @@ -1,4 +1,4 @@ -# Copyright 2014 - Mirantis, Inc. +# Copyright 2015 - Mirantis, Inc. # Copyright 2015 - StackStorm, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,35 +24,19 @@ from mistral.workbook.v2 import tasks class WorkflowSpec(base.BaseSpec): # See http://json-schema.org - _direct_task_schema = tasks.DirectWorkflowTaskSpec.get_schema( - includes=None) - - _reverse_task_schema = tasks.ReverseWorkflowTaskSpec.get_schema( - includes=None) + _polymorphic_key = ('type', 'direct') _task_defaults_schema = task_defaults.TaskDefaultsSpec.get_schema( includes=None) - _schema = { + _meta_schema = { "type": "object", "properties": { "type": types.WORKFLOW_TYPE, "task-defaults": _task_defaults_schema, "input": types.UNIQUE_STRING_OR_ONE_KEY_DICT_LIST, "output": types.NONEMPTY_DICT, - "vars": types.NONEMPTY_DICT, - "tasks": { - "type": "object", - "minProperties": 1, - "patternProperties": { - "^\w+$": { - "anyOf": [ - _direct_task_schema, - _reverse_task_schema - ] - } - } - }, + "vars": types.NONEMPTY_DICT }, "required": ["tasks"], "additionalProperties": False @@ -64,7 +48,7 @@ class WorkflowSpec(base.BaseSpec): self._name = data['name'] self._description = data.get('description') self._tags = data.get('tags', []) - self._type = data['type'] if 'type' in data else "direct" + self._type = data['type'] if 'type' in data else 'direct' self._input = utils.get_input_dict(data.get('input', [])) self._output = data.get('output', {}) self._vars = data.get('vars', {}) @@ -78,8 +62,8 @@ class WorkflowSpec(base.BaseSpec): tasks.TaskSpecList.get_class(self._type) ) - def validate(self): - super(WorkflowSpec, self).validate() + def validate_schema(self): + super(WorkflowSpec, self).validate_schema() if not self._data.get('tasks'): raise exc.InvalidModelException( @@ -90,6 +74,10 @@ class WorkflowSpec(base.BaseSpec): self.validate_yaql_expr(self._data.get('output', {})) self.validate_yaql_expr(self._data.get('vars', {})) + def validate_semantics(self): + # Doesn't do anything by default. + pass + def get_name(self): return self._name @@ -118,6 +106,133 @@ class WorkflowSpec(base.BaseSpec): return self._tasks +class DirectWorkflowSpec(WorkflowSpec): + _polymorphic_value = 'direct' + + _schema = { + "properties": { + "tasks": { + "type": "object", + "minProperties": 1, + "patternProperties": { + "^\w+$": + tasks.DirectWorkflowTaskSpec.get_schema(includes=None) + } + }, + } + } + + def validate_semantics(self): + # Check if there are start tasks. + if not self.find_start_tasks(): + raise exc.DSLParsingException( + 'Failed to find start tasks in direct workflow. ' + 'There must be at least one task without inbound transition.' + '[workflow_name=%s]' % self._name + ) + + def find_start_tasks(self): + return [ + t_s for t_s in self.get_tasks() + if not self.has_inbound_transitions(t_s) + ] + + def find_inbound_task_specs(self, task_spec): + return [ + t_s for t_s in self.get_tasks() + if self.transition_exists(t_s.get_name(), task_spec.get_name()) + ] + + def find_outbound_task_specs(self, task_spec): + return [ + t_s for t_s in self.get_tasks() + if self.transition_exists(task_spec.get_name(), t_s.get_name()) + ] + + def has_inbound_transitions(self, task_spec): + return len(self.find_inbound_task_specs(task_spec)) > 0 + + def has_outbound_transitions(self, task_spec): + return len(self.find_outbound_task_specs(task_spec)) > 0 + + def transition_exists(self, from_task_name, to_task_name): + t_names = set() + + for tup in self.get_on_error_clause(from_task_name): + t_names.add(tup[0]) + + for tup in self.get_on_success_clause(from_task_name): + t_names.add(tup[0]) + + for tup in self.get_on_complete_clause(from_task_name): + t_names.add(tup[0]) + + return to_task_name in t_names + + def get_on_error_clause(self, t_name): + result = self.get_tasks()[t_name].get_on_error() + + if not result: + t_defaults = self.get_task_defaults() + + if t_defaults: + result = self._remove_task_from_clause( + t_defaults.get_on_error(), + t_name + ) + + return result + + def get_on_success_clause(self, t_name): + result = self.get_tasks()[t_name].get_on_success() + + if not result: + t_defaults = self.get_task_defaults() + + if t_defaults: + result = self._remove_task_from_clause( + t_defaults.get_on_success(), + t_name + ) + + return result + + def get_on_complete_clause(self, t_name): + result = self.get_tasks()[t_name].get_on_complete() + + if not result: + t_defaults = self.get_task_defaults() + + if t_defaults: + result = self._remove_task_from_clause( + t_defaults.get_on_complete(), + t_name + ) + + return result + + @staticmethod + def _remove_task_from_clause(on_clause, t_name): + return filter(lambda tup: tup[0] != t_name, on_clause) + + +class ReverseWorkflowSpec(WorkflowSpec): + _polymorphic_value = 'reverse' + + _schema = { + "properties": { + "tasks": { + "type": "object", + "minProperties": 1, + "patternProperties": { + "^\w+$": + tasks.ReverseWorkflowTaskSpec.get_schema(includes=None) + } + }, + } + } + + class WorkflowSpecList(base.BaseSpecList): item_class = WorkflowSpec diff --git a/mistral/workflow/base.py b/mistral/workflow/base.py index 330a98d69..4e7e96f6c 100644 --- a/mistral/workflow/base.py +++ b/mistral/workflow/base.py @@ -154,14 +154,13 @@ class WorkflowController(object): if wf_type == wf_ctrl_cls.__workflow_type__: return wf_ctrl_cls - msg = 'Failed to find a workflow controller [type=%s]' % wf_type - raise exc.NotFoundException(msg) + raise exc.NotFoundException( + 'Failed to find a workflow controller [type=%s]' % wf_type + ) @staticmethod def get_controller(wf_ex, wf_spec=None): if not wf_spec: wf_spec = spec_parser.get_workflow_spec(wf_ex['spec']) - ctrl_cls = WorkflowController._get_class(wf_spec.get_type()) - - return ctrl_cls(wf_ex) + return WorkflowController._get_class(wf_spec.get_type())(wf_ex) diff --git a/mistral/workflow/direct_workflow.py b/mistral/workflow/direct_workflow.py index 191ad11c9..f1219bd1e 100644 --- a/mistral/workflow/direct_workflow.py +++ b/mistral/workflow/direct_workflow.py @@ -1,4 +1,4 @@ -# Copyright 2014 - Mirantis, Inc. +# Copyright 2015 - Mirantis, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -47,7 +47,7 @@ class DirectWorkflowController(base.WorkflowController): lambda t_e: self._is_upstream_task_execution(task_spec, t_e), wf_utils.find_task_executions_by_specs( self.wf_ex, - self._find_inbound_task_specs(task_spec) + self.wf_spec.find_inbound_task_specs(task_spec) ) ) @@ -80,19 +80,13 @@ class DirectWorkflowController(base.WorkflowController): return cmds def _find_start_commands(self): - t_specs = [] - - for t_s in self.wf_spec.get_tasks(): - if not self._has_inbound_transitions(t_s): - t_specs.append(t_s) - return [ commands.RunTask( self.wf_ex, t_s, self._get_task_inbound_context(t_s) ) - for t_s in t_specs + for t_s in self.wf_spec.find_start_tasks() ] def _find_next_commands_for_task(self, task_ex): @@ -121,7 +115,7 @@ class DirectWorkflowController(base.WorkflowController): ) # NOTE(xylan): Decide whether or not a join task should run - # immediately + # immediately. if self._is_unsatisfied_join(cmd): cmd.wait_flag = True @@ -136,35 +130,6 @@ class DirectWorkflowController(base.WorkflowController): return cmds - def _has_inbound_transitions(self, task_spec): - return len(self._find_inbound_task_specs(task_spec)) > 0 - - def _find_inbound_task_specs(self, task_spec): - return [ - t_s for t_s in self.wf_spec.get_tasks() - if self._transition_exists(t_s.get_name(), task_spec.get_name()) - ] - - def _find_outbound_task_specs(self, task_spec): - return [ - t_s for t_s in self.wf_spec.get_tasks() - if self._transition_exists(task_spec.get_name(), t_s.get_name()) - ] - - def _transition_exists(self, from_task_name, to_task_name): - t_names = set() - - for tup in self.get_on_error_clause(from_task_name): - t_names.add(tup[0]) - - for tup in self.get_on_success_clause(from_task_name): - t_names.add(tup[0]) - - for tup in self.get_on_complete_clause(from_task_name): - t_names.add(tup[0]) - - return to_task_name in t_names - # TODO(rakhmerov): Need to refactor this method to be able to pass tasks # whose contexts need to be merged. def evaluate_workflow_final_context(self): @@ -179,11 +144,11 @@ class DirectWorkflowController(base.WorkflowController): return ctx def is_error_handled_for(self, task_ex): - return bool(self.get_on_error_clause(task_ex.name)) + return bool(self.wf_spec.get_on_error_clause(task_ex.name)) def all_errors_handled(self): for t_ex in wf_utils.find_error_task_executions(self.wf_ex): - if not self.get_on_error_clause(t_ex.name): + if not self.wf_spec.get_on_error_clause(t_ex.name): return False return True @@ -204,52 +169,6 @@ class DirectWorkflowController(base.WorkflowController): if self.wf_spec.get_tasks()[t_name] ]) - @staticmethod - def _remove_task_from_clause(on_clause, t_name): - return filter(lambda tup: tup[0] != t_name, on_clause) - - def get_on_error_clause(self, t_name): - result = self.wf_spec.get_tasks()[t_name].get_on_error() - - if not result: - task_defaults = self.wf_spec.get_task_defaults() - - if task_defaults: - result = self._remove_task_from_clause( - task_defaults.get_on_error(), - t_name - ) - - return result - - def get_on_success_clause(self, t_name): - result = self.wf_spec.get_tasks()[t_name].get_on_success() - - if not result: - task_defaults = self.wf_spec.get_task_defaults() - - if task_defaults: - result = self._remove_task_from_clause( - task_defaults.get_on_success(), - t_name - ) - - return result - - def get_on_complete_clause(self, t_name): - result = self.wf_spec.get_tasks()[t_name].get_on_complete() - - if not result: - task_defaults = self.wf_spec.get_task_defaults() - - if task_defaults: - result = self._remove_task_from_clause( - task_defaults.get_on_complete(), - t_name - ) - - return result - def _find_next_task_names(self, task_ex): t_state = task_ex.state t_name = task_ex.name @@ -260,19 +179,19 @@ class DirectWorkflowController(base.WorkflowController): if states.is_completed(t_state): t_names += self._find_next_task_names_for_clause( - self.get_on_complete_clause(t_name), + self.wf_spec.get_on_complete_clause(t_name), ctx ) if t_state == states.ERROR: t_names += self._find_next_task_names_for_clause( - self.get_on_error_clause(t_name), + self.wf_spec.get_on_error_clause(t_name), ctx ) elif t_state == states.SUCCESS: t_names += self._find_next_task_names_for_clause( - self.get_on_success_clause(t_name), + self.wf_spec.get_on_success_clause(t_name), ctx ) @@ -323,7 +242,7 @@ class DirectWorkflowController(base.WorkflowController): if not join_expr: return False - in_task_specs = self._find_inbound_task_specs(task_spec) + in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec) if not in_task_specs: return False