diff --git a/taskflow/engines/action_engine/graph_analyzer.py b/taskflow/engines/action_engine/graph_analyzer.py index a0fcde35..dcede332 100644 --- a/taskflow/engines/action_engine/graph_analyzer.py +++ b/taskflow/engines/action_engine/graph_analyzer.py @@ -14,6 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. +from networkx.algorithms import traversal import six from taskflow import retry as r @@ -124,22 +125,8 @@ class GraphAnalyzer(object): """Iterates a subgraph connected to current retry controller, including nested retry controllers and its nodes. """ - visited_nodes = set() - retries_scope = set() - retries_scope.add(retry) - - nodes = self._graph.successors(retry) - while nodes: - next_nodes = [] - for node in nodes: - if node not in visited_nodes: - visited_nodes.add(node) - if self.find_atom_retry(node) in retries_scope: - yield node - if isinstance(node, r.Retry): - retries_scope.add(node) - next_nodes += self._graph.successors(node) - nodes = next_nodes + for _src, dst in traversal.dfs_edges(self._graph, retry): + yield dst def iterate_retries(self, state=None): """Iterates retry controllers of a graph with given state or all diff --git a/taskflow/tests/unit/test_flattening.py b/taskflow/tests/unit/test_flattening.py index a530ceed..0efc2d25 100644 --- a/taskflow/tests/unit/test_flattening.py +++ b/taskflow/tests/unit/test_flattening.py @@ -137,6 +137,22 @@ class FlattenTest(test.TestCase): lb = g.subgraph([c, d]) self.assertEqual(1, lb.number_of_edges()) + def test_unordered_nested_in_linear_flatten(self): + a, b, c, d = _make_many(4) + flo = lf.Flow('lt').add( + a, + uf.Flow('ut').add(b, c), + d) + + g = f_utils.flatten(flo) + self.assertEqual(4, len(g)) + self.assertItemsEqual(g.edges(), [ + (a, b), + (a, c), + (b, d), + (c, d) + ]) + def test_graph_flatten(self): a, b, c, d = _make_many(4) flo = gf.Flow("test") @@ -218,10 +234,10 @@ class FlattenTest(test.TestCase): g = f_utils.flatten(flo) self.assertEqual(3, len(g)) self.assertItemsEqual(g.edges(data=True), [ - (a, b, {'reasons': set(['x'])}), + (a, c, {'reasons': set(['x'])}), (b, c, {'invariant': True}) ]) - self.assertItemsEqual([a], g_utils.get_no_predecessors(g)) + self.assertItemsEqual([a, b], g_utils.get_no_predecessors(g)) self.assertItemsEqual([c], g_utils.get_no_successors(g)) def test_graph_flatten_nested_provides(self): @@ -237,10 +253,10 @@ class FlattenTest(test.TestCase): self.assertEqual(3, len(g)) self.assertItemsEqual(g.edges(data=True), [ (b, c, {'invariant': True}), - (c, a, {'reasons': set(['x'])}) + (b, a, {'reasons': set(['x'])}) ]) self.assertItemsEqual([b], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([a], g_utils.get_no_successors(g)) + self.assertItemsEqual([a, c], g_utils.get_no_successors(g)) def test_flatten_checks_for_dups(self): flo = gf.Flow("test").add( diff --git a/taskflow/tests/unit/test_retries.py b/taskflow/tests/unit/test_retries.py index 1c75ed61..d09944f5 100644 --- a/taskflow/tests/unit/test_retries.py +++ b/taskflow/tests/unit/test_retries.py @@ -493,7 +493,7 @@ class RetryTest(utils.EngineTestBase): self.assertEqual(self.values, expected) def test_for_each_with_set(self): - collection = ([3, 2, 5]) + collection = set([3, 2, 5]) retry1 = retry.ForEach(collection, 'r1', provides='x') flow = lf.Flow('flow-1', retry1).add(utils.FailingTaskWithOneArg('t1')) engine = self._make_engine(flow) @@ -636,6 +636,53 @@ class RetryTest(utils.EngineTestBase): self.assertEqual(r.history[0][1], {}) self.assertEqual(isinstance(r.history[0][0], misc.Failure), True) + def test_nested_provides_graph_reverts_correctly(self): + flow = gf.Flow("test").add( + utils.SaveOrderTask('a', requires=['x']), + lf.Flow("test2", retry=retry.Times(2)).add( + utils.SaveOrderTask('b', provides='x'), + utils.FailingTask('c'))) + engine = self._make_engine(flow) + engine.compile() + engine.storage.save('test2_retry', 1) + engine.storage.save('b', 11) + engine.storage.save('a', 10) + self.assertRaisesRegexp(RuntimeError, '^Woot', engine.run) + self.assertItemsEqual(self.values[:3], [ + 'a reverted(10)', + 'c reverted(Failure: RuntimeError: Woot!)', + 'b reverted(11)', + ]) + # Task 'a' was or was not executed again, both cases are ok. + self.assertIsSuperAndSubsequence(self.values[3:], [ + 'b', + 'c reverted(Failure: RuntimeError: Woot!)', + 'b reverted(5)' + ]) + self.assertEqual(engine.storage.get_flow_state(), st.REVERTED) + + def test_nested_provides_graph_retried_correctly(self): + flow = gf.Flow("test").add( + utils.SaveOrderTask('a', requires=['x']), + lf.Flow("test2", retry=retry.Times(2)).add( + utils.SaveOrderTask('b', provides='x'), + utils.SaveOrderTask('c'))) + engine = self._make_engine(flow) + engine.compile() + engine.storage.save('test2_retry', 1) + engine.storage.save('b', 11) + # pretend that 'c' failed + fail = misc.Failure.from_exception(RuntimeError('Woot!')) + engine.storage.save('c', fail, st.FAILURE) + + engine.run() + self.assertItemsEqual(self.values[:2], [ + 'c reverted(Failure: RuntimeError: Woot!)', + 'b reverted(11)', + ]) + self.assertItemsEqual(self.values[2:], ['b', 'c', 'a']) + self.assertEqual(engine.storage.get_flow_state(), st.SUCCESS) + class RetryParallelExecutionTest(utils.EngineTestBase): diff --git a/taskflow/utils/flow_utils.py b/taskflow/utils/flow_utils.py index 0d1d2bbf..c6d86ec5 100644 --- a/taskflow/utils/flow_utils.py +++ b/taskflow/utils/flow_utils.py @@ -46,6 +46,7 @@ class Flattener(object): with the following edge attributes (defaulting to the class provided edge_data if None), if the edge does not already exist. """ + nodes_to = list(nodes_to) for u in nodes_from: for v in nodes_to: if not graph.has_edge(u, v): @@ -103,14 +104,25 @@ class Flattener(object): subgraph = self._flatten(item) subgraph_map[item] = subgraph graph = gu.merge_graphs([graph, subgraph]) + # Reconnect all node edges to their corresponding subgraphs. - for (u, v, u_v_attrs) in flow.iter_links(): - # Connect the ones with no predecessors in v to the ones with no - # successors in u (thus maintaining the edge dependency). - self._add_new_edges(graph, - list(gu.get_no_successors(subgraph_map[u])), - list(gu.get_no_predecessors(subgraph_map[v])), - edge_attrs=u_v_attrs) + for (u, v, attrs) in flow.iter_links(): + if any(attrs.get(k) for k in ('invariant', 'manual', 'retry')): + # Connect nodes with no predecessors in v to nodes with + # no successors in u (thus maintaining the edge dependency). + self._add_new_edges(graph, + gu.get_no_successors(subgraph_map[u]), + gu.get_no_predecessors(subgraph_map[v]), + edge_attrs=attrs) + else: + # This is dependency-only edge, connect corresponding + # providers and consumers. + for provider in subgraph_map[u]: + for consumer in subgraph_map[v]: + reasons = provider.provides & consumer.requires + if reasons: + graph.add_edge(provider, consumer, reasons=reasons) + if flow.retry is not None: self._connect_retry(flow.retry, graph) return graph