diff --git a/taskflow/flow.py b/taskflow/flow.py index 56786d4d..4d93edf4 100644 --- a/taskflow/flow.py +++ b/taskflow/flow.py @@ -98,6 +98,15 @@ class Flow(object): * ``meta`` is link metadata, a dictionary. """ + @abc.abstractmethod + def iter_nodes(self): + """Iterate over nodes of the flow. + + Iterates over 2-tuples ``(A, meta)``, where + * ``A`` is a child (atom or subflow) of current flow; + * ``meta`` is link metadata, a dictionary. + """ + def __str__(self): return "%s: %s(len=%d)" % (reflection.get_class_name(self), self.name, len(self)) diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index 37da34a6..c0745e1e 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -266,12 +266,16 @@ class Flow(flow.Flow): return self._get_subgraph().number_of_nodes() def __iter__(self): - for n in self._get_subgraph().topological_sort(): + for n, _n_data in self.iter_nodes(): yield n def iter_links(self): - for (u, v, e_data) in self._get_subgraph().edges_iter(data=True): - yield (u, v, e_data) + return self._get_subgraph().edges_iter(data=True) + + def iter_nodes(self): + g = self._get_subgraph() + for n in g.topological_sort(): + yield n, g.node[n] @property def requires(self): diff --git a/taskflow/patterns/linear_flow.py b/taskflow/patterns/linear_flow.py index 3067076c..f581ce45 100644 --- a/taskflow/patterns/linear_flow.py +++ b/taskflow/patterns/linear_flow.py @@ -60,3 +60,7 @@ class Flow(flow.Flow): def iter_links(self): for src, dst in zip(self._children[:-1], self._children[1:]): yield (src, dst, _LINK_METADATA.copy()) + + def iter_nodes(self): + for n in self._children: + yield (n, {}) diff --git a/taskflow/patterns/unordered_flow.py b/taskflow/patterns/unordered_flow.py index 52bd286e..036ca2f9 100644 --- a/taskflow/patterns/unordered_flow.py +++ b/taskflow/patterns/unordered_flow.py @@ -48,6 +48,10 @@ class Flow(flow.Flow): # between each other due to invariants retained during construction. return iter(()) + def iter_nodes(self): + for n in self._children: + yield (n, {}) + @property def requires(self): requires = set() diff --git a/taskflow/tests/unit/patterns/test_graph_flow.py b/taskflow/tests/unit/patterns/test_graph_flow.py index 588361c9..1d876d56 100644 --- a/taskflow/tests/unit/patterns/test_graph_flow.py +++ b/taskflow/tests/unit/patterns/test_graph_flow.py @@ -212,6 +212,31 @@ class GraphFlowTest(test.TestCase): f = gf.Flow('test').add(task1, task2) self.assertRaises(exc.DependencyFailure, f.add, task3) + def test_iter_nodes(self): + task1 = _task('task1', provides=['a'], requires=['c']) + task2 = _task('task2', provides=['b'], requires=['a']) + task3 = _task('task3', provides=['c']) + f1 = gf.Flow('nested') + f1.add(task3) + tasks = set([task1, task2, f1]) + f = gf.Flow('test').add(task1, task2, f1) + for (n, data) in f.iter_nodes(): + self.assertTrue(n in tasks) + self.assertDictEqual({}, data) + + def test_iter_links(self): + task1 = _task('task1') + task2 = _task('task2') + task3 = _task('task3') + f1 = gf.Flow('nested') + f1.add(task3) + tasks = set([task1, task2, f1]) + f = gf.Flow('test').add(task1, task2, f1) + for (u, v, data) in f.iter_links(): + self.assertTrue(u in tasks) + self.assertTrue(v in tasks) + self.assertDictEqual({}, data) + class TargetedGraphFlowTest(test.TestCase): diff --git a/taskflow/tests/unit/patterns/test_linear_flow.py b/taskflow/tests/unit/patterns/test_linear_flow.py index 48f8f8de..fa39e173 100644 --- a/taskflow/tests/unit/patterns/test_linear_flow.py +++ b/taskflow/tests/unit/patterns/test_linear_flow.py @@ -118,3 +118,24 @@ class LinearFlowTest(test.TestCase): self.assertEqual(f.requires, set(['a'])) self.assertEqual(f.provides, set(['b'])) + + def test_iter_nodes(self): + task1 = _task(name='task1') + task2 = _task(name='task2') + task3 = _task(name='task3') + f = lf.Flow('test').add(task1, task2, task3) + tasks = set([task1, task2, task3]) + for (node, data) in f.iter_nodes(): + self.assertTrue(node in tasks) + self.assertDictEqual({}, data) + + def test_iter_links(self): + task1 = _task(name='task1') + task2 = _task(name='task2') + task3 = _task(name='task3') + f = lf.Flow('test').add(task1, task2, task3) + tasks = set([task1, task2, task3]) + for (u, v, data) in f.iter_links(): + self.assertTrue(u in tasks) + self.assertTrue(v in tasks) + self.assertDictEqual(lf._LINK_METADATA, data) diff --git a/taskflow/tests/unit/patterns/test_unordered_flow.py b/taskflow/tests/unit/patterns/test_unordered_flow.py index 195516b6..eeb3bb2b 100644 --- a/taskflow/tests/unit/patterns/test_unordered_flow.py +++ b/taskflow/tests/unit/patterns/test_unordered_flow.py @@ -108,3 +108,21 @@ class UnorderedFlowTest(test.TestCase): self.assertEqual(ret.name, 'test_retry') self.assertEqual(f.requires, set([])) self.assertEqual(f.provides, set(['b', 'a'])) + + def test_iter_nodes(self): + task1 = _task(name='task1', provides=['a', 'b']) + task2 = _task(name='task2', provides=['a', 'c']) + tasks = set([task1, task2]) + f = uf.Flow('test') + f.add(task2, task1) + for (node, data) in f.iter_nodes(): + self.assertTrue(node in tasks) + self.assertDictEqual({}, data) + + def test_iter_links(self): + task1 = _task(name='task1', provides=['a', 'b']) + task2 = _task(name='task2', provides=['a', 'c']) + f = uf.Flow('test') + f.add(task2, task1) + for (u, v, data) in f.iter_links(): + raise AssertionError('links iterator should be empty')