diff --git a/mistral/tests/unit/engine/test_dataflow.py b/mistral/tests/unit/engine/test_dataflow.py index a002872e..0ce5a78b 100644 --- a/mistral/tests/unit/engine/test_dataflow.py +++ b/mistral/tests/unit/engine/test_dataflow.py @@ -371,6 +371,58 @@ class DataFlowEngineTest(engine_test_base.EngineTestCase): task4.published ) + def test_sequential_tasks_publishing_same_structured(self): + var_overwrite_wf = """--- + version: '2.0' + + wf: + type: direct + + tasks: + task1: + publish: + greeting: {"a": "b"} + on-success: + - task2 + + task2: + publish: + greeting: {} + on-success: + - task3 + + task3: + publish: + result: <% $.greeting %> + """ + + wf_service.create_workflows(var_overwrite_wf) + + # Start workflow. + wf_ex = self.engine.start_workflow( + 'wf', + {}, + env={'from': 'Neo'} + ) + + self.await_workflow_success(wf_ex.id) + + # Note: We need to reread execution to access related tasks. + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + self.assertEqual(states.SUCCESS, wf_ex.state) + + tasks = wf_ex.task_executions + + task1 = self._assert_single_item(tasks, name='task1') + task2 = self._assert_single_item(tasks, name='task2') + task3 = self._assert_single_item(tasks, name='task3') + + self.assertEqual(states.SUCCESS, task3.state) + self.assertDictEqual({'greeting': {'a': 'b'}}, task1.published) + self.assertDictEqual({'greeting': {}}, task2.published) + self.assertDictEqual({'result': {}}, task3.published) + def test_linear_dataflow_implicit_publish(self): linear_wf = """--- version: '2.0' diff --git a/mistral/utils/__init__.py b/mistral/utils/__init__.py index c13923f0..ee988390 100644 --- a/mistral/utils/__init__.py +++ b/mistral/utils/__init__.py @@ -148,6 +148,25 @@ def merge_dicts(left, right, overwrite=True): return left +def update_dict(left, right): + """Updates left dict with content from right dict + + :param left: Left dict. + :param right: Right dict. + :return: the updated left dictionary. + """ + + if left is None: + return right + + if right is None: + return left + + left.update(right) + + return left + + def get_file_list(directory): base_path = pkg.resource_filename( version.version_info.package, @@ -190,7 +209,7 @@ def iter_subclasses(cls, _seen=None): try: subs = cls.__subclasses__() - except TypeError: # fails only when cls is type + except TypeError: # fails only when cls is type subs = cls.__subclasses__(cls) for sub in subs: diff --git a/mistral/workflow/data_flow.py b/mistral/workflow/data_flow.py index eb006d9b..48fd088b 100644 --- a/mistral/workflow/data_flow.py +++ b/mistral/workflow/data_flow.py @@ -116,7 +116,7 @@ def evaluate_task_outbound_context(task_ex): in_context = (copy.deepcopy(dict(task_ex.in_context)) if task_ex.in_context is not None else {}) - return utils.merge_dicts(in_context, task_ex.published) + return utils.update_dict(in_context, task_ex.published) def evaluate_workflow_output(wf_spec, ctx):