diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index 584bd0e0..83a7433a 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -64,6 +64,19 @@ class Flow(linear_flow.Flow): lines.append(" State: %s" % (self.state)) return "\n".join(lines) + @decorators.locked + def remove(self, task_uuid): + remove_nodes = [] + for r in self._graph.nodes_iter(): + if r.uuid == task_uuid: + remove_nodes.append(r) + if not remove_nodes: + raise IndexError("No task found with uuid %s" % (task_uuid)) + else: + for r in remove_nodes: + self._graph.remove_node(r) + self._runners = [] + def _ordering(self): try: return self._connect() diff --git a/taskflow/patterns/linear_flow.py b/taskflow/patterns/linear_flow.py index 2d711c1c..81228451 100644 --- a/taskflow/patterns/linear_flow.py +++ b/taskflow/patterns/linear_flow.py @@ -62,6 +62,7 @@ class Flow(base.Flow): self._left_off_at = 0 # All runners to run are collected here. self._runners = [] + self._connected = False @decorators.locked def add_many(self, tasks): @@ -76,7 +77,7 @@ class Flow(base.Flow): assert isinstance(task, collections.Callable) r = utils.Runner(task) r.runs_before = list(reversed(self._runners)) - self._associate_providers(r) + self._connected = False self._runners.append(r) return r.uuid @@ -108,9 +109,31 @@ class Flow(base.Flow): lines.append(" State: %s" % (self.state)) return "\n".join(lines) - def _ordering(self): + @decorators.locked + def remove(self, task_uuid): + removed = False + for (i, r) in enumerate(self._runners): + if r.uuid == task_uuid: + self._runners.pop(i) + self._connected = False + removed = True + break + if not removed: + raise IndexError("No task found with uuid %s" % (task_uuid)) + + def _connect(self): + if self._connected: + return self._runners + for r in self._runners: + r.providers = {} + for r in reversed(self._runners): + self._associate_providers(r) + self._connected = True return self._runners + def _ordering(self): + return self._connect() + @decorators.locked def run(self, context, *args, **kwargs): super(Flow, self).run(context, *args, **kwargs) @@ -241,6 +264,7 @@ class Flow(base.Flow): self.result_fetcher = None self._accumulator.reset() self._left_off_at = 0 + self._connected = False @decorators.locked def rollback(self, context, cause): diff --git a/taskflow/tests/unit/test_linear_flow.py b/taskflow/tests/unit/test_linear_flow.py index 57dc0498..c60a8e83 100644 --- a/taskflow/tests/unit/test_linear_flow.py +++ b/taskflow/tests/unit/test_linear_flow.py @@ -169,8 +169,8 @@ class LinearFlowTest(unittest2.TestCase): pass wf.add(task_a) - self.assertRaises(exc.InvalidStateException, - wf.add, task_b) + wf.add(task_b) + self.assertRaises(exc.InvalidStateException, wf.run, {}) def test_not_satisfied_inputs_no_previous(self): wf = lw.Flow("the-test-action") @@ -179,8 +179,8 @@ class LinearFlowTest(unittest2.TestCase): def task_a(context, c, *args, **kwargs): pass - self.assertRaises(exc.InvalidStateException, - wf.add, task_a) + wf.add(task_a) + self.assertRaises(exc.InvalidStateException, wf.run, {}) def test_flow_add_order(self): wf = lw.Flow("the-test-action") @@ -189,11 +189,12 @@ class LinearFlowTest(unittest2.TestCase): requires=set(), provides=['a', 'b'])) # This one should fail to add since it requires 'c' - self.assertRaises(exc.InvalidStateException, - wf.add, - utils.ProvidesRequiresTask('test-2', - requires=['c'], - provides=[])) + uuid = wf.add(utils.ProvidesRequiresTask('test-2', + requires=['c'], + provides=[])) + self.assertRaises(exc.InvalidStateException, wf.run, {}) + wf.remove(uuid) + wf.add(utils.ProvidesRequiresTask('test-2', requires=['a', 'b'], provides=['c', 'd'])) @@ -209,6 +210,8 @@ class LinearFlowTest(unittest2.TestCase): wf.add(utils.ProvidesRequiresTask('test-6', requires=['d'], provides=[])) + wf.reset() + wf.run({}) def test_interrupt_flow(self): wf = lw.Flow("the-int-action")