diff --git a/mistral/workflow/direct_workflow.py b/mistral/workflow/direct_workflow.py index ae0d0ee5a..a9eb93198 100644 --- a/mistral/workflow/direct_workflow.py +++ b/mistral/workflow/direct_workflow.py @@ -48,28 +48,28 @@ class DirectWorkflowController(base.WorkflowController): def _get_upstream_task_executions(self, task_spec): t_specs_names = [t_spec.get_name() for t_spec in self.wf_spec.find_inbound_task_specs(task_spec)] - t_execs = self._get_task_executions(name={'in': t_specs_names}) - return [t_ex for t_ex in t_execs - if self._is_upstream_task_execution(task_spec, t_ex)] + if not t_specs_names: + return [] - def _is_upstream_task_execution(self, t_spec, t_ex_candidate): - if not states.is_completed(t_ex_candidate.state): - return False + if not task_spec.get_join(): + return self._get_task_executions( + name=t_specs_names[0], # not a join, has just one parent + state={'in': (states.SUCCESS, states.ERROR, states.CANCELLED)}, + processed=True + ) - if not t_spec.get_join(): - return t_ex_candidate.processed - - t_execs_cache = self._prepare_task_executions_cache(t_spec) - - induced_state, _, _ = self._get_induced_join_state( - self.wf_spec.get_tasks()[t_ex_candidate.name], - t_ex_candidate, - t_spec, - t_execs_cache + t_execs_candidates = self._get_task_executions( + name={'in': t_specs_names}, + state={'in': (states.SUCCESS, states.ERROR, states.CANCELLED)}, ) - return induced_state == states.RUNNING + t_execs = [] + for t_ex in t_execs_candidates: + if task_spec.get_name() in [t[0] for t in t_ex.next_tasks]: + t_execs.append(t_ex) + + return t_execs def _find_next_commands(self, task_ex=None): cmds = super(DirectWorkflowController, self)._find_next_commands(