diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index c9683cb4..426a236b 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -65,18 +65,19 @@ class Flow(linear_flow.Flow): return "; ".join(lines) @decorators.locked - def remove(self, task_uuid): - remove_nodes = [] + def remove(self, uuid): + runner = None for r in self._graph.nodes_iter(): - if r.uuid == task_uuid: - remove_nodes.append(r) - if not remove_nodes: - raise IndexError("No task found with uuid %s" % (task_uuid)) + if r.uuid == uuid: + runner = r + break + if not runner: + raise ValueError("No runner found with uuid %s" % (uuid)) else: - for r in remove_nodes: - self._graph.remove_node(r) - self._runners = [] - self._leftoff_at = None + # Ensure that we reset out internal state after said removal + self._graph.remove_node(runner) + self._runners = [] + self._leftoff_at = None def _ordering(self): try: diff --git a/taskflow/patterns/linear_flow.py b/taskflow/patterns/linear_flow.py index 47c9b185..9a9df158 100644 --- a/taskflow/patterns/linear_flow.py +++ b/taskflow/patterns/linear_flow.py @@ -103,17 +103,26 @@ class Flow(base.Flow): return "; ".join(lines) @decorators.locked - def remove(self, task_uuid): - removed = False + def remove(self, uuid): + index_removed = -1 for (i, r) in enumerate(self._runners): - if r.uuid == task_uuid: - self._runners.pop(i) - self._connected = False - self._leftoff_at = None - removed = True + if r.uuid == uuid: + index_removed = i break - if not removed: - raise IndexError("No task found with uuid %s" % (task_uuid)) + if index_removed == -1: + raise ValueError("No runner found with uuid %s" % (uuid)) + else: + # Ensure that we reset out internal state after said removal. + removed = self._runners.pop(index_removed) + self._connected = False + self._leftoff_at = None + # Go and remove it from any runner after the removed runner since + # those runners may have had an attachment to it. + for r in self._runners[index_removed:]: + try: + r.runs_before.remove(removed) + except (IndexError, ValueError): + pass def _connect(self): if self._connected: