Adding validation of workflow graph

* To avoid circular dependency between mistral.workbook.parser
 and mistral.workflow.commands reserved task names were moved to
 workbook.tasks module.

Closes-Bug: #1494635

Change-Id: I21337a9b616f55c3bbdc957538b25d36246b3903
This commit is contained in:
Nikolay Mahotkin 2015-09-14 18:01:19 +03:00 committed by Renat Akhmerov
parent 529638fde6
commit 002949a804
7 changed files with 120 additions and 81 deletions

View File

@ -91,26 +91,6 @@ class DirectWorkflowEngineTest(base.EngineTestCase):
self.assertTrue(wf_ex.state, states.ERROR) self.assertTrue(wf_ex.state, states.ERROR)
def test_direct_workflow_wrong_task_name(self):
wf_text = """
version: '2.0'
wf:
tasks:
task1:
action: std.echo output="Echo"
on-success:
- wrong_name
"""
wf_service.create_workflows(wf_text)
wf_ex = self.engine.start_workflow('wf', {})
self._await(lambda: self.is_execution_error(wf_ex.id))
wf_ex = db_api.get_workflow_execution(wf_ex.id)
self.assertIn("'wrong_name' not found", wf_ex.state_info)
def test_direct_workflow_change_state_after_success(self): def test_direct_workflow_change_state_after_success(self):
wf_text = """ wf_text = """
version: '2.0' version: '2.0'
@ -342,3 +322,25 @@ class DirectWorkflowEngineTest(base.EngineTestCase):
wf_ex = self.engine.start_workflow('wf', {}) wf_ex = self.engine.start_workflow('wf', {})
self._await(lambda: self.is_execution_success(wf_ex.id)) self._await(lambda: self.is_execution_success(wf_ex.id))
def test_inconsistent_task_names(self):
wf_text = """
version: '2.0'
wf:
tasks:
task1:
action: std.noop
on-success: task3
task2:
action: std.noop
"""
exception = self.assertRaises(
exc.InvalidModelException,
wf_service.create_workflows,
wf_text
)
self.assertIn("Task 'task3' not found", exception.message)

View File

@ -137,28 +137,6 @@ class ReverseWorkflowEngineTest(base.EngineTestCase):
self.assertDictEqual({'result2': 'a & b'}, task2_ex.published) self.assertDictEqual({'result2': 'a & b'}, task2_ex.published)
def test_reverse_workflow_wrong_task_name(self):
wf_text = """---
version: '2.0'
wf_wrong_task:
type: reverse
tasks:
task2:
requires: [wrong_name]
"""
wf_service.create_workflows(wf_text)
exception = self.assertRaises(
exc.WorkflowException,
self.engine.start_workflow,
"wf_wrong_task",
{},
task_name='task2'
)
self.assertIn("wrong_name", exception.message)
def test_one_line_requires_syntax(self): def test_one_line_requires_syntax(self):
wf_input = {'param1': 'a', 'param2': 'b'} wf_input = {'param1': 'a', 'param2': 'b'}
@ -176,3 +154,27 @@ class ReverseWorkflowEngineTest(base.EngineTestCase):
self._assert_single_item(tasks, name='task4', state=states.SUCCESS) self._assert_single_item(tasks, name='task4', state=states.SUCCESS)
self._assert_single_item(tasks, name='task3', state=states.SUCCESS) self._assert_single_item(tasks, name='task3', state=states.SUCCESS)
def test_inconsistent_task_names(self):
wf_text = """
version: '2.0'
wf:
type: reverse
tasks:
task2:
action: std.noop
task3:
action: std.noop
requires: [task1]
"""
exception = self.assertRaises(
exc.InvalidModelException,
wf_service.create_workflows,
wf_text
)
self.assertIn("Task 'task1' not found", exception.message)

View File

@ -326,8 +326,8 @@ class BaseListSpec(BaseSpec):
self._inject_version([k]) self._inject_version([k])
self.items.append(instantiate_spec(self.item_class, v)) self.items.append(instantiate_spec(self.item_class, v))
def validate(self): def validate_schema(self):
super(BaseListSpec, self).validate() super(BaseListSpec, self).validate_schema()
if len(self._data.keys()) < 2: if len(self._data.keys()) < 2:
raise exc.InvalidModelException( raise exc.InvalidModelException(

View File

@ -29,6 +29,12 @@ from mistral.workbook.v2 import policies
WITH_ITEMS_PTRN = re.compile( WITH_ITEMS_PTRN = re.compile(
"\s*([\w\d_\-]+)\s*in\s*(\[.+\]|%s)" % expr.INLINE_YAQL_REGEXP "\s*([\w\d_\-]+)\s*in\s*(\[.+\]|%s)" % expr.INLINE_YAQL_REGEXP
) )
RESERVED_TASK_NAMES = [
'noop',
'fail',
'succeed',
'pause'
]
class TaskSpec(base.BaseSpec): class TaskSpec(base.BaseSpec):

View File

@ -81,6 +81,20 @@ class WorkflowSpec(base.BaseSpec):
# Doesn't do anything by default. # Doesn't do anything by default.
pass pass
def _validate_task_link(self, task_name, allow_engine_cmds=True):
valid_task = self._task_exists(task_name)
if allow_engine_cmds:
valid_task |= task_name in tasks.RESERVED_TASK_NAMES
if not valid_task:
raise exc.InvalidModelException(
"Task '%s' not found." % task_name
)
def _task_exists(self, task_name):
return self.get_tasks()[task_name] is not None
def get_name(self): def get_name(self):
return self._name return self._name
@ -134,6 +148,15 @@ class DirectWorkflowSpec(WorkflowSpec):
'[workflow_name=%s]' % self._name '[workflow_name=%s]' % self._name
) )
self._check_workflow_integrity()
def _check_workflow_integrity(self):
for t_s in self.get_tasks():
out_task_names = self.find_outbound_task_names(t_s.get_name())
for out_t_name in out_task_names:
self._validate_task_link(out_t_name)
def find_start_tasks(self): def find_start_tasks(self):
return [ return [
t_s for t_s in self.get_tasks() t_s for t_s in self.get_tasks()
@ -158,18 +181,23 @@ class DirectWorkflowSpec(WorkflowSpec):
def has_outbound_transitions(self, task_spec): def has_outbound_transitions(self, task_spec):
return len(self.find_outbound_task_specs(task_spec)) > 0 return len(self.find_outbound_task_specs(task_spec)) > 0
def transition_exists(self, from_task_name, to_task_name): def find_outbound_task_names(self, task_name):
t_names = set() t_names = set()
for tup in self.get_on_error_clause(from_task_name): for tup in self.get_on_error_clause(task_name):
t_names.add(tup[0]) t_names.add(tup[0])
for tup in self.get_on_success_clause(from_task_name): for tup in self.get_on_success_clause(task_name):
t_names.add(tup[0]) t_names.add(tup[0])
for tup in self.get_on_complete_clause(from_task_name): for tup in self.get_on_complete_clause(task_name):
t_names.add(tup[0]) t_names.add(tup[0])
return t_names
def transition_exists(self, from_task_name, to_task_name):
t_names = self.find_outbound_task_names(from_task_name)
return to_task_name in t_names return to_task_name in t_names
def get_on_error_clause(self, t_name): def get_on_error_clause(self, t_name):
@ -235,6 +263,26 @@ class ReverseWorkflowSpec(WorkflowSpec):
} }
} }
def validate_semantics(self):
self._check_workflow_integrity()
def _check_workflow_integrity(self):
for t_s in self.get_tasks():
for req in self.get_task_requires(t_s):
self._validate_task_link(req, allow_engine_cmds=False)
def get_task_requires(self, task_spec):
requires = set(task_spec.get_requires())
defaults = self.get_task_defaults()
if defaults:
requires |= set(defaults.get_requires())
requires.discard(task_spec.get_name())
return list(requires)
class WorkflowSpecList(base.BaseSpecList): class WorkflowSpecList(base.BaseSpecList):
item_class = WorkflowSpec item_class = WorkflowSpec

View File

@ -131,12 +131,14 @@ class PauseWorkflow(SetWorkflowState):
return "Pause [workflow=%s]" % self.wf_ex.name return "Pause [workflow=%s]" % self.wf_ex.name
RESERVED_CMDS = { RESERVED_CMDS = dict(zip(
'noop': Noop, tasks.RESERVED_TASK_NAMES, [
'fail': FailWorkflow, Noop,
'succeed': SucceedWorkflow, FailWorkflow,
'pause': PauseWorkflow SucceedWorkflow,
} PauseWorkflow
]
))
def get_command_class(cmd_name): def get_command_class(cmd_name):

View File

@ -73,7 +73,7 @@ class ReverseWorkflowController(base.WorkflowController):
def _get_upstream_task_executions(self, task_spec): def _get_upstream_task_executions(self, task_spec):
t_specs = [ t_specs = [
self.wf_spec.get_tasks()[t_name] self.wf_spec.get_tasks()[t_name]
for t_name in self._get_task_requires(task_spec) for t_name in self.wf_spec.get_task_requires(task_spec)
or [] or []
] ]
@ -120,22 +120,11 @@ class ReverseWorkflowController(base.WorkflowController):
if self._is_satisfied_task(t_s) if self._is_satisfied_task(t_s)
] ]
def _task_exists(self, task_name):
return self.wf_spec.get_tasks()[task_name] is not None
def _is_satisfied_task(self, task_spec): def _is_satisfied_task(self, task_spec):
task_requires = self._get_task_requires(task_spec)
for req in task_requires:
if not self._task_exists(req):
raise exc.WorkflowException(
"Task '%s' not found." % req
)
if wf_utils.find_task_executions_by_spec(self.wf_ex, task_spec): if wf_utils.find_task_executions_by_spec(self.wf_ex, task_spec):
return False return False
if not self._get_task_requires(task_spec): if not self.wf_spec.get_task_requires(task_spec):
return True return True
success_t_names = set() success_t_names = set()
@ -144,7 +133,9 @@ class ReverseWorkflowController(base.WorkflowController):
if t_ex.state == states.SUCCESS: if t_ex.state == states.SUCCESS:
success_t_names.add(t_ex.name) success_t_names.add(t_ex.name)
return not (set(self._get_task_requires(task_spec)) - success_t_names) return not (
set(self.wf_spec.get_task_requires(task_spec)) - success_t_names
)
def _build_graph(self, tasks_spec): def _build_graph(self, tasks_spec):
graph = nx.DiGraph() graph = nx.DiGraph()
@ -161,7 +152,7 @@ class ReverseWorkflowController(base.WorkflowController):
return graph return graph
def _get_dependency_tasks(self, tasks_spec, task_spec): def _get_dependency_tasks(self, tasks_spec, task_spec):
dep_task_names = self._get_task_requires(task_spec) dep_task_names = self.wf_spec.get_task_requires(task_spec)
if len(dep_task_names) == 0: if len(dep_task_names) == 0:
return [] return []
@ -174,15 +165,3 @@ class ReverseWorkflowController(base.WorkflowController):
dep_t_specs.add(t_spec) dep_t_specs.add(t_spec)
return dep_t_specs return dep_t_specs
def _get_task_requires(self, task_spec):
requires = set(task_spec.get_requires())
task_defaults = self.wf_spec.get_task_defaults()
if task_defaults:
requires |= set(task_defaults.get_requires())
requires.discard(task_spec.get_name())
return list(requires)