Adding validation of workflow graph
* To avoid circular dependency between mistral.workbook.parser and mistral.workflow.commands reserved task names were moved to workbook.tasks module. Closes-Bug: #1494635 Change-Id: I21337a9b616f55c3bbdc957538b25d36246b3903
This commit is contained in:
parent
529638fde6
commit
002949a804
@ -91,26 +91,6 @@ class DirectWorkflowEngineTest(base.EngineTestCase):
|
||||
|
||||
self.assertTrue(wf_ex.state, states.ERROR)
|
||||
|
||||
def test_direct_workflow_wrong_task_name(self):
|
||||
wf_text = """
|
||||
version: '2.0'
|
||||
|
||||
wf:
|
||||
tasks:
|
||||
task1:
|
||||
action: std.echo output="Echo"
|
||||
on-success:
|
||||
- wrong_name
|
||||
"""
|
||||
|
||||
wf_service.create_workflows(wf_text)
|
||||
wf_ex = self.engine.start_workflow('wf', {})
|
||||
|
||||
self._await(lambda: self.is_execution_error(wf_ex.id))
|
||||
wf_ex = db_api.get_workflow_execution(wf_ex.id)
|
||||
|
||||
self.assertIn("'wrong_name' not found", wf_ex.state_info)
|
||||
|
||||
def test_direct_workflow_change_state_after_success(self):
|
||||
wf_text = """
|
||||
version: '2.0'
|
||||
@ -342,3 +322,25 @@ class DirectWorkflowEngineTest(base.EngineTestCase):
|
||||
wf_ex = self.engine.start_workflow('wf', {})
|
||||
|
||||
self._await(lambda: self.is_execution_success(wf_ex.id))
|
||||
|
||||
def test_inconsistent_task_names(self):
|
||||
wf_text = """
|
||||
version: '2.0'
|
||||
|
||||
wf:
|
||||
tasks:
|
||||
task1:
|
||||
action: std.noop
|
||||
on-success: task3
|
||||
|
||||
task2:
|
||||
action: std.noop
|
||||
"""
|
||||
|
||||
exception = self.assertRaises(
|
||||
exc.InvalidModelException,
|
||||
wf_service.create_workflows,
|
||||
wf_text
|
||||
)
|
||||
|
||||
self.assertIn("Task 'task3' not found", exception.message)
|
||||
|
@ -137,28 +137,6 @@ class ReverseWorkflowEngineTest(base.EngineTestCase):
|
||||
|
||||
self.assertDictEqual({'result2': 'a & b'}, task2_ex.published)
|
||||
|
||||
def test_reverse_workflow_wrong_task_name(self):
|
||||
wf_text = """---
|
||||
version: '2.0'
|
||||
wf_wrong_task:
|
||||
type: reverse
|
||||
tasks:
|
||||
task2:
|
||||
requires: [wrong_name]
|
||||
"""
|
||||
|
||||
wf_service.create_workflows(wf_text)
|
||||
|
||||
exception = self.assertRaises(
|
||||
exc.WorkflowException,
|
||||
self.engine.start_workflow,
|
||||
"wf_wrong_task",
|
||||
{},
|
||||
task_name='task2'
|
||||
)
|
||||
|
||||
self.assertIn("wrong_name", exception.message)
|
||||
|
||||
def test_one_line_requires_syntax(self):
|
||||
wf_input = {'param1': 'a', 'param2': 'b'}
|
||||
|
||||
@ -176,3 +154,27 @@ class ReverseWorkflowEngineTest(base.EngineTestCase):
|
||||
|
||||
self._assert_single_item(tasks, name='task4', state=states.SUCCESS)
|
||||
self._assert_single_item(tasks, name='task3', state=states.SUCCESS)
|
||||
|
||||
def test_inconsistent_task_names(self):
|
||||
wf_text = """
|
||||
version: '2.0'
|
||||
|
||||
wf:
|
||||
type: reverse
|
||||
|
||||
tasks:
|
||||
task2:
|
||||
action: std.noop
|
||||
|
||||
task3:
|
||||
action: std.noop
|
||||
requires: [task1]
|
||||
"""
|
||||
|
||||
exception = self.assertRaises(
|
||||
exc.InvalidModelException,
|
||||
wf_service.create_workflows,
|
||||
wf_text
|
||||
)
|
||||
|
||||
self.assertIn("Task 'task1' not found", exception.message)
|
||||
|
@ -326,8 +326,8 @@ class BaseListSpec(BaseSpec):
|
||||
self._inject_version([k])
|
||||
self.items.append(instantiate_spec(self.item_class, v))
|
||||
|
||||
def validate(self):
|
||||
super(BaseListSpec, self).validate()
|
||||
def validate_schema(self):
|
||||
super(BaseListSpec, self).validate_schema()
|
||||
|
||||
if len(self._data.keys()) < 2:
|
||||
raise exc.InvalidModelException(
|
||||
|
@ -29,6 +29,12 @@ from mistral.workbook.v2 import policies
|
||||
WITH_ITEMS_PTRN = re.compile(
|
||||
"\s*([\w\d_\-]+)\s*in\s*(\[.+\]|%s)" % expr.INLINE_YAQL_REGEXP
|
||||
)
|
||||
RESERVED_TASK_NAMES = [
|
||||
'noop',
|
||||
'fail',
|
||||
'succeed',
|
||||
'pause'
|
||||
]
|
||||
|
||||
|
||||
class TaskSpec(base.BaseSpec):
|
||||
|
@ -81,6 +81,20 @@ class WorkflowSpec(base.BaseSpec):
|
||||
# Doesn't do anything by default.
|
||||
pass
|
||||
|
||||
def _validate_task_link(self, task_name, allow_engine_cmds=True):
|
||||
valid_task = self._task_exists(task_name)
|
||||
|
||||
if allow_engine_cmds:
|
||||
valid_task |= task_name in tasks.RESERVED_TASK_NAMES
|
||||
|
||||
if not valid_task:
|
||||
raise exc.InvalidModelException(
|
||||
"Task '%s' not found." % task_name
|
||||
)
|
||||
|
||||
def _task_exists(self, task_name):
|
||||
return self.get_tasks()[task_name] is not None
|
||||
|
||||
def get_name(self):
|
||||
return self._name
|
||||
|
||||
@ -134,6 +148,15 @@ class DirectWorkflowSpec(WorkflowSpec):
|
||||
'[workflow_name=%s]' % self._name
|
||||
)
|
||||
|
||||
self._check_workflow_integrity()
|
||||
|
||||
def _check_workflow_integrity(self):
|
||||
for t_s in self.get_tasks():
|
||||
out_task_names = self.find_outbound_task_names(t_s.get_name())
|
||||
|
||||
for out_t_name in out_task_names:
|
||||
self._validate_task_link(out_t_name)
|
||||
|
||||
def find_start_tasks(self):
|
||||
return [
|
||||
t_s for t_s in self.get_tasks()
|
||||
@ -158,18 +181,23 @@ class DirectWorkflowSpec(WorkflowSpec):
|
||||
def has_outbound_transitions(self, task_spec):
|
||||
return len(self.find_outbound_task_specs(task_spec)) > 0
|
||||
|
||||
def transition_exists(self, from_task_name, to_task_name):
|
||||
def find_outbound_task_names(self, task_name):
|
||||
t_names = set()
|
||||
|
||||
for tup in self.get_on_error_clause(from_task_name):
|
||||
for tup in self.get_on_error_clause(task_name):
|
||||
t_names.add(tup[0])
|
||||
|
||||
for tup in self.get_on_success_clause(from_task_name):
|
||||
for tup in self.get_on_success_clause(task_name):
|
||||
t_names.add(tup[0])
|
||||
|
||||
for tup in self.get_on_complete_clause(from_task_name):
|
||||
for tup in self.get_on_complete_clause(task_name):
|
||||
t_names.add(tup[0])
|
||||
|
||||
return t_names
|
||||
|
||||
def transition_exists(self, from_task_name, to_task_name):
|
||||
t_names = self.find_outbound_task_names(from_task_name)
|
||||
|
||||
return to_task_name in t_names
|
||||
|
||||
def get_on_error_clause(self, t_name):
|
||||
@ -235,6 +263,26 @@ class ReverseWorkflowSpec(WorkflowSpec):
|
||||
}
|
||||
}
|
||||
|
||||
def validate_semantics(self):
|
||||
self._check_workflow_integrity()
|
||||
|
||||
def _check_workflow_integrity(self):
|
||||
for t_s in self.get_tasks():
|
||||
for req in self.get_task_requires(t_s):
|
||||
self._validate_task_link(req, allow_engine_cmds=False)
|
||||
|
||||
def get_task_requires(self, task_spec):
|
||||
requires = set(task_spec.get_requires())
|
||||
|
||||
defaults = self.get_task_defaults()
|
||||
|
||||
if defaults:
|
||||
requires |= set(defaults.get_requires())
|
||||
|
||||
requires.discard(task_spec.get_name())
|
||||
|
||||
return list(requires)
|
||||
|
||||
|
||||
class WorkflowSpecList(base.BaseSpecList):
|
||||
item_class = WorkflowSpec
|
||||
|
@ -131,12 +131,14 @@ class PauseWorkflow(SetWorkflowState):
|
||||
return "Pause [workflow=%s]" % self.wf_ex.name
|
||||
|
||||
|
||||
RESERVED_CMDS = {
|
||||
'noop': Noop,
|
||||
'fail': FailWorkflow,
|
||||
'succeed': SucceedWorkflow,
|
||||
'pause': PauseWorkflow
|
||||
}
|
||||
RESERVED_CMDS = dict(zip(
|
||||
tasks.RESERVED_TASK_NAMES, [
|
||||
Noop,
|
||||
FailWorkflow,
|
||||
SucceedWorkflow,
|
||||
PauseWorkflow
|
||||
]
|
||||
))
|
||||
|
||||
|
||||
def get_command_class(cmd_name):
|
||||
|
@ -73,7 +73,7 @@ class ReverseWorkflowController(base.WorkflowController):
|
||||
def _get_upstream_task_executions(self, task_spec):
|
||||
t_specs = [
|
||||
self.wf_spec.get_tasks()[t_name]
|
||||
for t_name in self._get_task_requires(task_spec)
|
||||
for t_name in self.wf_spec.get_task_requires(task_spec)
|
||||
or []
|
||||
]
|
||||
|
||||
@ -120,22 +120,11 @@ class ReverseWorkflowController(base.WorkflowController):
|
||||
if self._is_satisfied_task(t_s)
|
||||
]
|
||||
|
||||
def _task_exists(self, task_name):
|
||||
return self.wf_spec.get_tasks()[task_name] is not None
|
||||
|
||||
def _is_satisfied_task(self, task_spec):
|
||||
task_requires = self._get_task_requires(task_spec)
|
||||
|
||||
for req in task_requires:
|
||||
if not self._task_exists(req):
|
||||
raise exc.WorkflowException(
|
||||
"Task '%s' not found." % req
|
||||
)
|
||||
|
||||
if wf_utils.find_task_executions_by_spec(self.wf_ex, task_spec):
|
||||
return False
|
||||
|
||||
if not self._get_task_requires(task_spec):
|
||||
if not self.wf_spec.get_task_requires(task_spec):
|
||||
return True
|
||||
|
||||
success_t_names = set()
|
||||
@ -144,7 +133,9 @@ class ReverseWorkflowController(base.WorkflowController):
|
||||
if t_ex.state == states.SUCCESS:
|
||||
success_t_names.add(t_ex.name)
|
||||
|
||||
return not (set(self._get_task_requires(task_spec)) - success_t_names)
|
||||
return not (
|
||||
set(self.wf_spec.get_task_requires(task_spec)) - success_t_names
|
||||
)
|
||||
|
||||
def _build_graph(self, tasks_spec):
|
||||
graph = nx.DiGraph()
|
||||
@ -161,7 +152,7 @@ class ReverseWorkflowController(base.WorkflowController):
|
||||
return graph
|
||||
|
||||
def _get_dependency_tasks(self, tasks_spec, task_spec):
|
||||
dep_task_names = self._get_task_requires(task_spec)
|
||||
dep_task_names = self.wf_spec.get_task_requires(task_spec)
|
||||
|
||||
if len(dep_task_names) == 0:
|
||||
return []
|
||||
@ -174,15 +165,3 @@ class ReverseWorkflowController(base.WorkflowController):
|
||||
dep_t_specs.add(t_spec)
|
||||
|
||||
return dep_t_specs
|
||||
|
||||
def _get_task_requires(self, task_spec):
|
||||
requires = set(task_spec.get_requires())
|
||||
|
||||
task_defaults = self.wf_spec.get_task_defaults()
|
||||
|
||||
if task_defaults:
|
||||
requires |= set(task_defaults.get_requires())
|
||||
|
||||
requires.discard(task_spec.get_name())
|
||||
|
||||
return list(requires)
|
||||
|
Loading…
Reference in New Issue
Block a user