diff --git a/taskflow/tests/unit/test_flattening.py b/taskflow/tests/unit/test_flattening.py index 4e948d9d..a530ceed 100644 --- a/taskflow/tests/unit/test_flattening.py +++ b/taskflow/tests/unit/test_flattening.py @@ -39,6 +39,22 @@ def _make_many(amount): class FlattenTest(test.TestCase): + def test_flatten_task(self): + task = t_utils.DummyTask(name='a') + g = f_utils.flatten(task) + + self.assertEqual(list(g.nodes()), [task]) + self.assertEqual(list(g.edges()), []) + + def test_flatten_retry(self): + r = retry.AlwaysRevert('r1') + msg_regex = "^Retry controller .* is used not as a flow parameter" + self.assertRaisesRegexp(TypeError, msg_regex, f_utils.flatten, r) + + def test_flatten_wrong_object(self): + msg_regex = '^Unknown type requested to flatten' + self.assertRaisesRegexp(TypeError, msg_regex, f_utils.flatten, 42) + def test_linear_flatten(self): a, b, c, d = _make_many(4) flo = lf.Flow("test") @@ -53,6 +69,8 @@ class FlattenTest(test.TestCase): order = nx.topological_sort(g) self.assertEqual([a, b, c, d], order) self.assertTrue(g.has_edge(c, d)) + self.assertEqual(g.get_edge_data(c, d), {'invariant': True}) + self.assertEqual([d], list(g_utils.get_no_successors(g))) self.assertEqual([a], list(g_utils.get_no_predecessors(g))) @@ -86,8 +104,9 @@ class FlattenTest(test.TestCase): self.assertEqual(4, len(g)) lb = g.subgraph([a, b]) - self.assertTrue(lb.has_edge(a, b)) self.assertFalse(lb.has_edge(b, a)) + self.assertTrue(lb.has_edge(a, b)) + self.assertEqual(g.get_edge_data(a, b), {'invariant': True}) ub = g.subgraph([c, d]) self.assertEqual(0, ub.number_of_edges()) @@ -109,8 +128,9 @@ class FlattenTest(test.TestCase): for n in [a, b]: self.assertFalse(g.has_edge(n, c)) self.assertFalse(g.has_edge(n, d)) - self.assertTrue(g.has_edge(c, d)) self.assertFalse(g.has_edge(d, c)) + self.assertTrue(g.has_edge(c, d)) + self.assertEqual(g.get_edge_data(c, d), {'invariant': True}) ub = g.subgraph([a, b]) self.assertEqual(0, ub.number_of_edges()) @@ -135,9 +155,12 @@ class FlattenTest(test.TestCase): flo2.add(e, f, g) flo.add(flo2) - g = f_utils.flatten(flo) - self.assertEqual(7, len(g)) - self.assertEqual(2, g.number_of_edges()) + graph = f_utils.flatten(flo) + self.assertEqual(7, len(graph)) + self.assertItemsEqual(graph.edges(data=True), [ + (e, f, {'invariant': True}), + (f, g, {'invariant': True}) + ]) def test_graph_flatten_nested_graph(self): a, b, c, d, e, f, g = _make_many(7) @@ -162,11 +185,62 @@ class FlattenTest(test.TestCase): g = f_utils.flatten(flo) self.assertEqual(4, len(g)) - self.assertEqual(3, g.number_of_edges()) - self.assertEqual(set([a]), - set(g_utils.get_no_predecessors(g))) - self.assertEqual(set([d]), - set(g_utils.get_no_successors(g))) + self.assertItemsEqual(g.edges(data=True), [ + (a, b, {'manual': True}), + (b, c, {'manual': True}), + (c, d, {'manual': True}), + ]) + self.assertItemsEqual([a], g_utils.get_no_predecessors(g)) + self.assertItemsEqual([d], g_utils.get_no_successors(g)) + + def test_graph_flatten_dependencies(self): + a = t_utils.ProvidesRequiresTask('a', provides=['x'], requires=[]) + b = t_utils.ProvidesRequiresTask('b', provides=[], requires=['x']) + flo = gf.Flow("test").add(a, b) + + g = f_utils.flatten(flo) + self.assertEqual(2, len(g)) + self.assertItemsEqual(g.edges(data=True), [ + (a, b, {'reasons': set(['x'])}) + ]) + self.assertItemsEqual([a], g_utils.get_no_predecessors(g)) + self.assertItemsEqual([b], g_utils.get_no_successors(g)) + + def test_graph_flatten_nested_requires(self): + a = t_utils.ProvidesRequiresTask('a', provides=['x'], requires=[]) + b = t_utils.ProvidesRequiresTask('b', provides=[], requires=[]) + c = t_utils.ProvidesRequiresTask('c', provides=[], requires=['x']) + flo = gf.Flow("test").add( + a, + lf.Flow("test2").add(b, c) + ) + + g = f_utils.flatten(flo) + self.assertEqual(3, len(g)) + self.assertItemsEqual(g.edges(data=True), [ + (a, b, {'reasons': set(['x'])}), + (b, c, {'invariant': True}) + ]) + self.assertItemsEqual([a], g_utils.get_no_predecessors(g)) + self.assertItemsEqual([c], g_utils.get_no_successors(g)) + + def test_graph_flatten_nested_provides(self): + a = t_utils.ProvidesRequiresTask('a', provides=[], requires=['x']) + b = t_utils.ProvidesRequiresTask('b', provides=['x'], requires=[]) + c = t_utils.ProvidesRequiresTask('c', provides=[], requires=[]) + flo = gf.Flow("test").add( + a, + lf.Flow("test2").add(b, c) + ) + + g = f_utils.flatten(flo) + self.assertEqual(3, len(g)) + self.assertItemsEqual(g.edges(data=True), [ + (b, c, {'invariant': True}), + (c, a, {'reasons': set(['x'])}) + ]) + self.assertItemsEqual([b], g_utils.get_no_predecessors(g)) + self.assertItemsEqual([a], g_utils.get_no_successors(g)) def test_flatten_checks_for_dups(self): flo = gf.Flow("test").add( @@ -208,12 +282,14 @@ class FlattenTest(test.TestCase): c2 = retry.AlwaysRevert("c2") flo = lf.Flow("test", c1).add(lf.Flow("test2", c2)) g = f_utils.flatten(flo) + self.assertEqual(2, len(g)) - self.assertEqual(1, g.number_of_edges()) - self.assertEqual(set([c1]), - set(g_utils.get_no_predecessors(g))) - self.assertEqual(set([c2]), - set(g_utils.get_no_successors(g))) + self.assertItemsEqual(g.edges(data=True), [ + (c1, c2, {'retry': True}) + ]) + self.assertIs(c1, g.node[c2]['retry']) + self.assertItemsEqual([c1], g_utils.get_no_predecessors(g)) + self.assertItemsEqual([c2], g_utils.get_no_successors(g)) def test_flatten_retry_in_linear_flow_with_tasks(self): c = retry.AlwaysRevert("c") @@ -221,13 +297,15 @@ class FlattenTest(test.TestCase): flo = lf.Flow("test", c).add(a, b) g = f_utils.flatten(flo) self.assertEqual(3, len(g)) - self.assertEqual(2, g.number_of_edges()) - self.assertEqual(set([c]), - set(g_utils.get_no_predecessors(g))) - self.assertEqual(set([b]), - set(g_utils.get_no_successors(g))) - self.assertEqual(c, g.node[a]['retry']) - self.assertEqual(c, g.node[b]['retry']) + self.assertItemsEqual(g.edges(data=True), [ + (a, b, {'invariant': True}), + (c, a, {'retry': True}) + ]) + + self.assertItemsEqual([c], g_utils.get_no_predecessors(g)) + self.assertItemsEqual([b], g_utils.get_no_successors(g)) + self.assertIs(c, g.node[a]['retry']) + self.assertIs(c, g.node[b]['retry']) def test_flatten_retry_in_unordered_flow_with_tasks(self): c = retry.AlwaysRevert("c") @@ -235,28 +313,34 @@ class FlattenTest(test.TestCase): flo = uf.Flow("test", c).add(a, b) g = f_utils.flatten(flo) self.assertEqual(3, len(g)) - self.assertEqual(2, g.number_of_edges()) - self.assertEqual(set([c]), - set(g_utils.get_no_predecessors(g))) - self.assertEqual(set([a, b]), - set(g_utils.get_no_successors(g))) - self.assertEqual(c, g.node[a]['retry']) - self.assertEqual(c, g.node[b]['retry']) + self.assertItemsEqual(g.edges(data=True), [ + (c, a, {'retry': True}), + (c, b, {'retry': True}) + ]) + + self.assertItemsEqual([c], g_utils.get_no_predecessors(g)) + self.assertItemsEqual([a, b], g_utils.get_no_successors(g)) + self.assertIs(c, g.node[a]['retry']) + self.assertIs(c, g.node[b]['retry']) def test_flatten_retry_in_graph_flow_with_tasks(self): - c = retry.AlwaysRevert("cp") - a, b, d = _make_many(3) - flo = gf.Flow("test", c).add(a, b, d).link(b, d) + r = retry.AlwaysRevert("cp") + a, b, c = _make_many(3) + flo = gf.Flow("test", r).add(a, b, c).link(b, c) g = f_utils.flatten(flo) self.assertEqual(4, len(g)) - self.assertEqual(3, g.number_of_edges()) - self.assertEqual(set([c]), - set(g_utils.get_no_predecessors(g))) - self.assertEqual(set([a, d]), - set(g_utils.get_no_successors(g))) - self.assertEqual(c, g.node[a]['retry']) - self.assertEqual(c, g.node[b]['retry']) - self.assertEqual(c, g.node[d]['retry']) + + self.assertItemsEqual(g.edges(data=True), [ + (r, a, {'retry': True}), + (r, b, {'retry': True}), + (b, c, {'manual': True}) + ]) + + self.assertItemsEqual([r], g_utils.get_no_predecessors(g)) + self.assertItemsEqual([a, c], g_utils.get_no_successors(g)) + self.assertIs(r, g.node[a]['retry']) + self.assertIs(r, g.node[b]['retry']) + self.assertIs(r, g.node[c]['retry']) def test_flatten_retries_hierarchy(self): c1 = retry.AlwaysRevert("cp1") @@ -268,13 +352,19 @@ class FlattenTest(test.TestCase): d) g = f_utils.flatten(flo) self.assertEqual(6, len(g)) - self.assertEqual(5, g.number_of_edges()) - self.assertEqual(c1, g.node[a]['retry']) - self.assertEqual(c1, g.node[d]['retry']) - self.assertEqual(c2, g.node[b]['retry']) - self.assertEqual(c2, g.node[c]['retry']) - self.assertEqual(c1, g.node[c2]['retry']) - self.assertEqual(None, g.node[c1].get('retry')) + self.assertItemsEqual(g.edges(data=True), [ + (c1, a, {'retry': True}), + (a, c2, {'invariant': True}), + (c2, b, {'retry': True}), + (b, c, {'invariant': True}), + (c, d, {'invariant': True}), + ]) + self.assertIs(c1, g.node[a]['retry']) + self.assertIs(c1, g.node[d]['retry']) + self.assertIs(c2, g.node[b]['retry']) + self.assertIs(c2, g.node[c]['retry']) + self.assertIs(c1, g.node[c2]['retry']) + self.assertIs(None, g.node[c1].get('retry')) def test_flatten_retry_subflows_hierarchy(self): c1 = retry.AlwaysRevert("cp1") @@ -285,9 +375,14 @@ class FlattenTest(test.TestCase): d) g = f_utils.flatten(flo) self.assertEqual(5, len(g)) - self.assertEqual(4, g.number_of_edges()) - self.assertEqual(c1, g.node[a]['retry']) - self.assertEqual(c1, g.node[d]['retry']) - self.assertEqual(c1, g.node[b]['retry']) - self.assertEqual(c1, g.node[c]['retry']) - self.assertEqual(None, g.node[c1].get('retry')) + self.assertItemsEqual(g.edges(data=True), [ + (c1, a, {'retry': True}), + (a, b, {'invariant': True}), + (b, c, {'invariant': True}), + (c, d, {'invariant': True}), + ]) + self.assertIs(c1, g.node[a]['retry']) + self.assertIs(c1, g.node[d]['retry']) + self.assertIs(c1, g.node[b]['retry']) + self.assertIs(c1, g.node[c]['retry']) + self.assertIs(None, g.node[c1].get('retry')) diff --git a/taskflow/utils/flow_utils.py b/taskflow/utils/flow_utils.py index 36c5149e..0d1d2bbf 100644 --- a/taskflow/utils/flow_utils.py +++ b/taskflow/utils/flow_utils.py @@ -15,7 +15,6 @@ # under the License. import logging -import threading import networkx as nx @@ -24,17 +23,14 @@ from taskflow import flow from taskflow import retry from taskflow import task from taskflow.utils import graph_utils as gu -from taskflow.utils import lock_utils as lu from taskflow.utils import misc LOG = logging.getLogger(__name__) -# Use the 'flatten' attribute as the need to add an edge here, which is useful -# for doing later analysis of the edges (to determine why the edges were -# created). -FLATTEN_EDGE_DATA = { - 'flatten': True, + +RETRY_EDGE_DATA = { + 'retry': True, } @@ -44,19 +40,12 @@ class Flattener(object): self._graph = None self._history = set() self._freeze = bool(freeze) - self._lock = threading.Lock() - self._edge_data = FLATTEN_EDGE_DATA.copy() - def _add_new_edges(self, graph, nodes_from, nodes_to, edge_attrs=None): + def _add_new_edges(self, graph, nodes_from, nodes_to, edge_attrs): """Adds new edges from nodes to other nodes in the specified graph, with the following edge attributes (defaulting to the class provided edge_data if None), if the edge does not already exist. """ - if edge_attrs is None: - edge_attrs = self._edge_data - else: - edge_attrs = edge_attrs.copy() - edge_attrs.update(self._edge_data) for u in nodes_from: for v in nodes_to: if not graph.has_edge(u, v): @@ -88,18 +77,16 @@ class Flattener(object): def _connect_retry(self, retry, graph): graph.add_node(retry) - # All graph nodes that has not predecessors should be depended on its - # retry - for n in gu.get_no_predecessors(graph): - if n != retry: - # modified that the same copy isn't modified. - graph.add_edge(retry, n, FLATTEN_EDGE_DATA.copy()) + + # All graph nodes that have no predecessors should depend on its retry + nodes_to = [n for n in gu.get_no_predecessors(graph) if n != retry] + self._add_new_edges(graph, [retry], nodes_to, RETRY_EDGE_DATA) # Add link to retry for each node of subgraph that hasn't # a parent retry for n in graph.nodes_iter(): if n != retry and 'retry' not in graph.node[n]: - graph.add_node(n, {'retry': retry}) + graph.node[n]['retry'] = retry def _flatten_task(self, task): """Flattens a individual task.""" @@ -116,7 +103,7 @@ class Flattener(object): subgraph = self._flatten(item) subgraph_map[item] = subgraph graph = gu.merge_graphs([graph, subgraph]) - # Reconnect all node edges to there corresponding subgraphs. + # Reconnect all node edges to their corresponding subgraphs. for (u, v, u_v_attrs) in flow.iter_links(): # Connect the ones with no predecessors in v to the ones with no # successors in u (thus maintaining the edge dependency). @@ -162,7 +149,6 @@ class Flattener(object): "found: %s" % (dup_names)) self._history.clear() - @lu.locked def flatten(self): """Flattens a item (a task or flow) into a single execution graph.""" if self._graph is not None: