diff --git a/taskflow/tests/unit/test_types.py b/taskflow/tests/unit/test_types.py index d991b684..92a7973d 100644 --- a/taskflow/tests/unit/test_types.py +++ b/taskflow/tests/unit/test_types.py @@ -511,6 +511,27 @@ CEO self.assertEqual(['mammal', 'reptile', 'horse', 'primate', 'monkey', 'human'], things) + def test_to_diagraph(self): + root = self._make_species() + g = root.to_digraph() + self.assertEqual(root.child_count(only_direct=False) + 1, len(g)) + for node in root.dfs_iter(include_self=True): + self.assertIn(node.item, g) + self.assertEqual([], g.predecessors('animal')) + self.assertEqual(['animal'], g.predecessors('reptile')) + self.assertEqual(['primate'], g.predecessors('human')) + self.assertEqual(['mammal'], g.predecessors('primate')) + self.assertEqual(['animal'], g.predecessors('mammal')) + self.assertEqual(['mammal', 'reptile'], g.successors('animal')) + + def test_to_digraph_retains_metadata(self): + root = tree.Node("chickens", alive=True) + dead_chicken = tree.Node("chicken.1", alive=False) + root.add(dead_chicken) + g = root.to_digraph() + self.assertEqual(g.node['chickens'], {'alive': True}) + self.assertEqual(g.node['chicken.1'], {'alive': False}) + class OrderedSetTest(test.TestCase): diff --git a/taskflow/types/tree.py b/taskflow/types/tree.py index e0f61670..b8df57ad 100644 --- a/taskflow/types/tree.py +++ b/taskflow/types/tree.py @@ -22,6 +22,7 @@ import os import six +from taskflow.types import graph from taskflow.utils import iter_utils from taskflow.utils import misc @@ -390,3 +391,20 @@ class Node(object): return _BFSIter(self, include_self=include_self, right_to_left=right_to_left) + + def to_digraph(self): + """Converts this node + its children into a ordered directed graph. + + The graph returned will have the same structure as the + this node and its children (and tree node metadata will be translated + into graph node metadata). + + :returns: a directed graph + :rtype: :py:class:`taskflow.types.graph.OrderedDiGraph` + """ + g = graph.OrderedDiGraph() + for node in self.bfs_iter(include_self=True, right_to_left=True): + g.add_node(node.item, attr_dict=node.metadata) + if node is not self: + g.add_edge(node.parent.item, node.item) + return g