diff --git a/mistral/tests/unit/engine/test_direct_workflow.py b/mistral/tests/unit/engine/test_direct_workflow.py index 6394c85b..f321b4f8 100644 --- a/mistral/tests/unit/engine/test_direct_workflow.py +++ b/mistral/tests/unit/engine/test_direct_workflow.py @@ -91,26 +91,6 @@ class DirectWorkflowEngineTest(base.EngineTestCase): self.assertTrue(wf_ex.state, states.ERROR) - def test_direct_workflow_wrong_task_name(self): - wf_text = """ - version: '2.0' - - wf: - tasks: - task1: - action: std.echo output="Echo" - on-success: - - wrong_name - """ - - wf_service.create_workflows(wf_text) - wf_ex = self.engine.start_workflow('wf', {}) - - self._await(lambda: self.is_execution_error(wf_ex.id)) - wf_ex = db_api.get_workflow_execution(wf_ex.id) - - self.assertIn("'wrong_name' not found", wf_ex.state_info) - def test_direct_workflow_change_state_after_success(self): wf_text = """ version: '2.0' @@ -342,3 +322,25 @@ class DirectWorkflowEngineTest(base.EngineTestCase): wf_ex = self.engine.start_workflow('wf', {}) self._await(lambda: self.is_execution_success(wf_ex.id)) + + def test_inconsistent_task_names(self): + wf_text = """ + version: '2.0' + + wf: + tasks: + task1: + action: std.noop + on-success: task3 + + task2: + action: std.noop + """ + + exception = self.assertRaises( + exc.InvalidModelException, + wf_service.create_workflows, + wf_text + ) + + self.assertIn("Task 'task3' not found", exception.message) diff --git a/mistral/tests/unit/engine/test_reverse_workflow.py b/mistral/tests/unit/engine/test_reverse_workflow.py index a354ea21..1c52cc14 100644 --- a/mistral/tests/unit/engine/test_reverse_workflow.py +++ b/mistral/tests/unit/engine/test_reverse_workflow.py @@ -137,28 +137,6 @@ class ReverseWorkflowEngineTest(base.EngineTestCase): self.assertDictEqual({'result2': 'a & b'}, task2_ex.published) - def test_reverse_workflow_wrong_task_name(self): - wf_text = """--- - version: '2.0' - wf_wrong_task: - type: reverse - tasks: - task2: - requires: [wrong_name] - """ - - wf_service.create_workflows(wf_text) - - exception = self.assertRaises( - exc.WorkflowException, - self.engine.start_workflow, - "wf_wrong_task", - {}, - task_name='task2' - ) - - self.assertIn("wrong_name", exception.message) - def test_one_line_requires_syntax(self): wf_input = {'param1': 'a', 'param2': 'b'} @@ -176,3 +154,27 @@ class ReverseWorkflowEngineTest(base.EngineTestCase): self._assert_single_item(tasks, name='task4', state=states.SUCCESS) self._assert_single_item(tasks, name='task3', state=states.SUCCESS) + + def test_inconsistent_task_names(self): + wf_text = """ + version: '2.0' + + wf: + type: reverse + + tasks: + task2: + action: std.noop + + task3: + action: std.noop + requires: [task1] + """ + + exception = self.assertRaises( + exc.InvalidModelException, + wf_service.create_workflows, + wf_text + ) + + self.assertIn("Task 'task1' not found", exception.message) diff --git a/mistral/workbook/base.py b/mistral/workbook/base.py index b90f8be8..470fce80 100644 --- a/mistral/workbook/base.py +++ b/mistral/workbook/base.py @@ -326,8 +326,8 @@ class BaseListSpec(BaseSpec): self._inject_version([k]) self.items.append(instantiate_spec(self.item_class, v)) - def validate(self): - super(BaseListSpec, self).validate() + def validate_schema(self): + super(BaseListSpec, self).validate_schema() if len(self._data.keys()) < 2: raise exc.InvalidModelException( diff --git a/mistral/workbook/v2/tasks.py b/mistral/workbook/v2/tasks.py index c213efd8..3a89610d 100644 --- a/mistral/workbook/v2/tasks.py +++ b/mistral/workbook/v2/tasks.py @@ -29,6 +29,12 @@ from mistral.workbook.v2 import policies WITH_ITEMS_PTRN = re.compile( "\s*([\w\d_\-]+)\s*in\s*(\[.+\]|%s)" % expr.INLINE_YAQL_REGEXP ) +RESERVED_TASK_NAMES = [ + 'noop', + 'fail', + 'succeed', + 'pause' +] class TaskSpec(base.BaseSpec): diff --git a/mistral/workbook/v2/workflows.py b/mistral/workbook/v2/workflows.py index ddb26374..01fbc810 100644 --- a/mistral/workbook/v2/workflows.py +++ b/mistral/workbook/v2/workflows.py @@ -81,6 +81,20 @@ class WorkflowSpec(base.BaseSpec): # Doesn't do anything by default. pass + def _validate_task_link(self, task_name, allow_engine_cmds=True): + valid_task = self._task_exists(task_name) + + if allow_engine_cmds: + valid_task |= task_name in tasks.RESERVED_TASK_NAMES + + if not valid_task: + raise exc.InvalidModelException( + "Task '%s' not found." % task_name + ) + + def _task_exists(self, task_name): + return self.get_tasks()[task_name] is not None + def get_name(self): return self._name @@ -134,6 +148,15 @@ class DirectWorkflowSpec(WorkflowSpec): '[workflow_name=%s]' % self._name ) + self._check_workflow_integrity() + + def _check_workflow_integrity(self): + for t_s in self.get_tasks(): + out_task_names = self.find_outbound_task_names(t_s.get_name()) + + for out_t_name in out_task_names: + self._validate_task_link(out_t_name) + def find_start_tasks(self): return [ t_s for t_s in self.get_tasks() @@ -158,18 +181,23 @@ class DirectWorkflowSpec(WorkflowSpec): def has_outbound_transitions(self, task_spec): return len(self.find_outbound_task_specs(task_spec)) > 0 - def transition_exists(self, from_task_name, to_task_name): + def find_outbound_task_names(self, task_name): t_names = set() - for tup in self.get_on_error_clause(from_task_name): + for tup in self.get_on_error_clause(task_name): t_names.add(tup[0]) - for tup in self.get_on_success_clause(from_task_name): + for tup in self.get_on_success_clause(task_name): t_names.add(tup[0]) - for tup in self.get_on_complete_clause(from_task_name): + for tup in self.get_on_complete_clause(task_name): t_names.add(tup[0]) + return t_names + + def transition_exists(self, from_task_name, to_task_name): + t_names = self.find_outbound_task_names(from_task_name) + return to_task_name in t_names def get_on_error_clause(self, t_name): @@ -235,6 +263,26 @@ class ReverseWorkflowSpec(WorkflowSpec): } } + def validate_semantics(self): + self._check_workflow_integrity() + + def _check_workflow_integrity(self): + for t_s in self.get_tasks(): + for req in self.get_task_requires(t_s): + self._validate_task_link(req, allow_engine_cmds=False) + + def get_task_requires(self, task_spec): + requires = set(task_spec.get_requires()) + + defaults = self.get_task_defaults() + + if defaults: + requires |= set(defaults.get_requires()) + + requires.discard(task_spec.get_name()) + + return list(requires) + class WorkflowSpecList(base.BaseSpecList): item_class = WorkflowSpec diff --git a/mistral/workflow/commands.py b/mistral/workflow/commands.py index f6786b93..5fa8f1e0 100644 --- a/mistral/workflow/commands.py +++ b/mistral/workflow/commands.py @@ -131,12 +131,14 @@ class PauseWorkflow(SetWorkflowState): return "Pause [workflow=%s]" % self.wf_ex.name -RESERVED_CMDS = { - 'noop': Noop, - 'fail': FailWorkflow, - 'succeed': SucceedWorkflow, - 'pause': PauseWorkflow -} +RESERVED_CMDS = dict(zip( + tasks.RESERVED_TASK_NAMES, [ + Noop, + FailWorkflow, + SucceedWorkflow, + PauseWorkflow + ] +)) def get_command_class(cmd_name): diff --git a/mistral/workflow/reverse_workflow.py b/mistral/workflow/reverse_workflow.py index d1881412..f422dca5 100644 --- a/mistral/workflow/reverse_workflow.py +++ b/mistral/workflow/reverse_workflow.py @@ -73,7 +73,7 @@ class ReverseWorkflowController(base.WorkflowController): def _get_upstream_task_executions(self, task_spec): t_specs = [ self.wf_spec.get_tasks()[t_name] - for t_name in self._get_task_requires(task_spec) + for t_name in self.wf_spec.get_task_requires(task_spec) or [] ] @@ -120,22 +120,11 @@ class ReverseWorkflowController(base.WorkflowController): if self._is_satisfied_task(t_s) ] - def _task_exists(self, task_name): - return self.wf_spec.get_tasks()[task_name] is not None - def _is_satisfied_task(self, task_spec): - task_requires = self._get_task_requires(task_spec) - - for req in task_requires: - if not self._task_exists(req): - raise exc.WorkflowException( - "Task '%s' not found." % req - ) - if wf_utils.find_task_executions_by_spec(self.wf_ex, task_spec): return False - if not self._get_task_requires(task_spec): + if not self.wf_spec.get_task_requires(task_spec): return True success_t_names = set() @@ -144,7 +133,9 @@ class ReverseWorkflowController(base.WorkflowController): if t_ex.state == states.SUCCESS: success_t_names.add(t_ex.name) - return not (set(self._get_task_requires(task_spec)) - success_t_names) + return not ( + set(self.wf_spec.get_task_requires(task_spec)) - success_t_names + ) def _build_graph(self, tasks_spec): graph = nx.DiGraph() @@ -161,7 +152,7 @@ class ReverseWorkflowController(base.WorkflowController): return graph def _get_dependency_tasks(self, tasks_spec, task_spec): - dep_task_names = self._get_task_requires(task_spec) + dep_task_names = self.wf_spec.get_task_requires(task_spec) if len(dep_task_names) == 0: return [] @@ -174,15 +165,3 @@ class ReverseWorkflowController(base.WorkflowController): dep_t_specs.add(t_spec) return dep_t_specs - - def _get_task_requires(self, task_spec): - requires = set(task_spec.get_requires()) - - task_defaults = self.wf_spec.get_task_defaults() - - if task_defaults: - requires |= set(task_defaults.get_requires()) - - requires.discard(task_spec.get_name()) - - return list(requires)