diff --git a/mistral/dsl.py b/mistral/dsl.py index 82e1f4c4..390fa0d5 100644 --- a/mistral/dsl.py +++ b/mistral/dsl.py @@ -63,7 +63,10 @@ class Parser(object): return tasks def get_task(self, task_name): - return self.get_tasks().get(task_name, {}) + task = self.get_tasks().get(task_name, {}) + if task: + task['name'] = task_name + return task def get_task_dsl_property(self, task_name, property_name): task_dsl = self.get_task(task_name) @@ -71,15 +74,21 @@ class Parser(object): def get_task_on_error(self, task_name): task = self.get_task_dsl_property(task_name, "on-error") - return task if isinstance(task, dict) else {task: ''} + if task: + return task if isinstance(task, dict) else {task: ''} + return None def get_task_on_success(self, task_name): task = self.get_task_dsl_property(task_name, "on-success") - return task if isinstance(task, dict) else {task: ''} + if task: + return task if isinstance(task, dict) else {task: ''} + return None def get_task_on_finish(self, task_name): task = self.get_task_dsl_property(task_name, "on-finish") - return task if isinstance(task, dict) else {task: ''} + if task: + return task if isinstance(task, dict) else {task: ''} + return None def get_task_input(self, task_name): return self.get_task_dsl_property(task_name, "input") diff --git a/mistral/engine/abstract_engine.py b/mistral/engine/abstract_engine.py index 3ce90fe0..e0966a13 100644 --- a/mistral/engine/abstract_engine.py +++ b/mistral/engine/abstract_engine.py @@ -53,7 +53,7 @@ class AbstractEngine(object): finally: db_api.end_tx() - cls._run_tasks(workflow.find_tasks_to_start(tasks)) + cls._run_tasks(workflow.find_resolved_tasks(tasks)) return execution @@ -61,16 +61,21 @@ class AbstractEngine(object): def convey_task_result(cls, workbook_name, execution_id, task_id, state, result): db_api.start_tx() - + wb_dsl = cls._get_wb_dsl(workbook_name) #TODO(rakhmerov): validate state transition # Update task state. task = db_api.task_update(workbook_name, execution_id, task_id, {"state": state, "result": result}) execution = db_api.execution_get(workbook_name, execution_id) + cls._create_next_tasks(task, + wb_dsl, + workbook_name, + execution_id) # Determine what tasks need to be started. tasks = db_api.tasks_get(workbook_name, execution_id) + # TODO(nmakhotkin) merge result into context try: new_exec_state = cls._determine_execution_state(execution, tasks) @@ -93,7 +98,7 @@ class AbstractEngine(object): return task if tasks: - cls._run_tasks(workflow.find_tasks_to_start(tasks)) + cls._run_tasks(workflow.find_resolved_tasks(tasks)) return task @@ -130,6 +135,14 @@ class AbstractEngine(object): "state": states.RUNNING }) + @classmethod + def _create_next_tasks(cls, task, wb_dsl, + workbook_name, execution_id): + dsl_tasks = workflow.find_tasks_after_completion(task, wb_dsl) + tasks = cls._create_tasks(dsl_tasks, wb_dsl, + workbook_name, execution_id) + return workflow.find_resolved_tasks(tasks) + @classmethod def _create_tasks(cls, dsl_tasks, wb_dsl, workbook_name, execution_id): tasks = [] @@ -158,7 +171,7 @@ class AbstractEngine(object): if workflow.is_error(tasks): return states.ERROR - if workflow.is_success(tasks): + if workflow.is_success(tasks) or workflow.is_finished(tasks): return states.SUCCESS return execution['state'] diff --git a/mistral/engine/workflow.py b/mistral/engine/workflow.py index e7735017..687d2fe6 100644 --- a/mistral/engine/workflow.py +++ b/mistral/engine/workflow.py @@ -18,6 +18,10 @@ import networkx as nx from networkx.algorithms import traversal from mistral.engine import states +from mistral.openstack.common import log as logging + + +LOG = logging.getLogger(__name__) def find_workflow_tasks(wb_dsl, task_name): @@ -38,11 +42,60 @@ def find_workflow_tasks(wb_dsl, task_name): return tasks -def find_tasks_to_start(tasks): +def find_resolved_tasks(tasks): # We need to analyse graph and see which tasks are ready to start return _get_resolved_tasks(tasks) +def _get_checked_tasks(target_tasks): + checked_tasks = [] + for t in target_tasks: + #TODO(nmakhotkin): see and evaluate YAQL with data from context + checked_tasks.append(t) + return checked_tasks + + +def _get_tasks_to_schedule(target_tasks, wb_dsl): + tasks_to_schedule = _get_checked_tasks(target_tasks) + return [wb_dsl.get_task(t_name) for t_name in tasks_to_schedule] + + +def find_tasks_after_completion(task, wb_dsl): + """Determine tasks which should be scheduled after completing + given task. Expression 'on_finish' is not mutually exclusive to + 'on_success' and 'on_error'. + + :param task: Task object + :param wb_dsl: DSL Parser + :return: list of DSL tasks. + """ + state = task['state'] + found_tasks = [] + LOG.debug("Recieved task %s: %s" % (task['name'], state)) + + if state == states.ERROR: + tasks_on_error = wb_dsl.get_task_on_error(task['name']) + if tasks_on_error: + found_tasks = _get_tasks_to_schedule(tasks_on_error, wb_dsl) + + elif state == states.SUCCESS: + tasks_on_success = wb_dsl.get_task_on_success(task['name']) + if tasks_on_success: + found_tasks = _get_tasks_to_schedule(tasks_on_success, wb_dsl) + + if states.is_finished(state): + tasks_on_finish = wb_dsl.get_task_on_finish(task['name']) + if tasks_on_finish: + found_tasks += _get_tasks_to_schedule(tasks_on_finish, wb_dsl) + + LOG.debug("Found tasks: %s" % found_tasks) + workflow_tasks = [] + for t in found_tasks: + workflow_tasks += find_workflow_tasks(wb_dsl, t['name']) + LOG.debug("Workflow tasks to schedule: %s" % workflow_tasks) + return workflow_tasks + + def is_finished(tasks): return all(states.is_finished(task['state']) for task in tasks) diff --git a/mistral/tests/resources/test_rest.yaml b/mistral/tests/resources/test_rest.yaml index 8cf553a0..b485179c 100644 --- a/mistral/tests/resources/test_rest.yaml +++ b/mistral/tests/resources/test_rest.yaml @@ -82,7 +82,7 @@ Workflow: image_id: 1234 flavor_id: 2 - test: + test_subsequent: action: MyRest:backup-vm parameters: server_id: @@ -90,6 +90,8 @@ Workflow: attach-volumes on-error: backup-vms: $.status != 'OK' + on-finish: + create-vms events: create-vms: diff --git a/mistral/tests/unit/engine/local/test_engine.py b/mistral/tests/unit/engine/local/test_engine.py index 44e1b0fc..d9ed67f2 100644 --- a/mistral/tests/unit/engine/local/test_engine.py +++ b/mistral/tests/unit/engine/local/test_engine.py @@ -114,3 +114,86 @@ class TestLocalEngine(test_base.DbTestCase): execution = db_api.execution_get(WB_NAME, execution['id']) self.assertEqual(execution['state'], states.SUCCESS) self.assertEqual(task['state'], states.SUCCESS) + + @mock.patch.object(db_api, "workbook_get", + mock.MagicMock(return_value={ + 'definition': get_cfg("test_rest.yaml") + })) + @mock.patch.object(actions.RestAction, "run", + mock.MagicMock(return_value={'state': states.SUCCESS})) + def test_engine_tasks_on_success_finish(self): + execution = ENGINE.start_workflow_execution(WB_NAME, + "test_subsequent") + tasks = db_api.tasks_get(WB_NAME, execution['id']) + self.assertEqual(len(tasks), 1) + execution = db_api.execution_get(WB_NAME, execution['id']) + ENGINE.convey_task_result(WB_NAME, execution['id'], + tasks[0]['id'], + states.SUCCESS, None) + + tasks = db_api.tasks_get(WB_NAME, execution['id']) + self.assertEqual(len(tasks), 4) + attach_volumes = [t for t in tasks if t['name'] == 'attach-volumes'][0] + self.assertIn(attach_volumes, tasks) + self.assertEqual(tasks[0]['state'], states.SUCCESS) + self.assertEqual(tasks[1]['state'], states.IDLE) + self.assertEqual(tasks[2]['state'], states.RUNNING) + ENGINE.convey_task_result(WB_NAME, execution['id'], + tasks[2]['id'], + states.SUCCESS, None) + ENGINE.convey_task_result(WB_NAME, execution['id'], + tasks[3]['id'], + states.SUCCESS, None) + + tasks = db_api.tasks_get(WB_NAME, execution['id']) + self.assertEqual(tasks[2]['state'], states.SUCCESS) + self.assertEqual(tasks[1]['state'], states.RUNNING) + ENGINE.convey_task_result(WB_NAME, execution['id'], + tasks[1]['id'], + states.SUCCESS, None) + + tasks = db_api.tasks_get(WB_NAME, execution['id']) + execution = db_api.execution_get(WB_NAME, execution['id']) + self.assertEqual(tasks[1]['state'], states.SUCCESS) + self.assertEqual(execution['state'], states.SUCCESS) + + @mock.patch.object(db_api, "workbook_get", + mock.MagicMock(return_value={ + 'definition': get_cfg("test_rest.yaml") + })) + @mock.patch.object(actions.RestAction, "run", + mock.MagicMock(return_value={'state': states.SUCCESS})) + def test_engine_tasks_on_error_finish(self): + execution = ENGINE.start_workflow_execution(WB_NAME, + "test_subsequent") + tasks = db_api.tasks_get(WB_NAME, execution['id']) + execution = db_api.execution_get(WB_NAME, execution['id']) + ENGINE.convey_task_result(WB_NAME, execution['id'], + tasks[0]['id'], + states.ERROR, None) + + tasks = db_api.tasks_get(WB_NAME, execution['id']) + self.assertEqual(len(tasks), 4) + backup_vms = [t for t in tasks if t['name'] == 'backup-vms'][0] + self.assertIn(backup_vms, tasks) + self.assertEqual(tasks[0]['state'], states.ERROR) + self.assertEqual(tasks[1]['state'], states.IDLE) + self.assertEqual(tasks[2]['state'], states.RUNNING) + ENGINE.convey_task_result(WB_NAME, execution['id'], + tasks[2]['id'], + states.SUCCESS, None) + ENGINE.convey_task_result(WB_NAME, execution['id'], + tasks[3]['id'], + states.SUCCESS, None) + + tasks = db_api.tasks_get(WB_NAME, execution['id']) + self.assertEqual(tasks[2]['state'], states.SUCCESS) + self.assertEqual(tasks[1]['state'], states.RUNNING) + ENGINE.convey_task_result(WB_NAME, execution['id'], + tasks[1]['id'], + states.SUCCESS, None) + + tasks = db_api.tasks_get(WB_NAME, execution['id']) + execution = db_api.execution_get(WB_NAME, execution['id']) + self.assertEqual(tasks[1]['state'], states.SUCCESS) + self.assertEqual(execution['state'], states.SUCCESS) diff --git a/mistral/tests/unit/engine/test_workflow.py b/mistral/tests/unit/engine/test_workflow.py index 965981d5..a45676ee 100644 --- a/mistral/tests/unit/engine/test_workflow.py +++ b/mistral/tests/unit/engine/test_workflow.py @@ -54,5 +54,5 @@ class WorkflowTest(base.DbTestCase): self.assertEqual(tasks[1]['name'], 'create-vms') def test_tasks_to_start(self): - tasks_to_start = workflow.find_tasks_to_start(TASKS) + tasks_to_start = workflow.find_resolved_tasks(TASKS) self.assertEqual(len(tasks_to_start), 2) diff --git a/mistral/tests/unit/test_parser.py b/mistral/tests/unit/test_parser.py index 582463ef..85a32485 100644 --- a/mistral/tests/unit/test_parser.py +++ b/mistral/tests/unit/test_parser.py @@ -60,9 +60,9 @@ class DSLParserTest(unittest2.TestCase): self.assertEqual(task, {}) def test_task_property(self): - on_success = self.dsl.get_task_on_success("test") + on_success = self.dsl.get_task_on_success("test_subsequent") self.assertEqual(on_success, {"attach-volumes": ''}) - on_error = self.dsl.get_task_on_error("test") + on_error = self.dsl.get_task_on_error("test_subsequent") self.assertEqual(on_error, {"backup-vms": "$.status != 'OK'"}) def test_actions(self):