From ade8bb35fa934982d33b086eab1787ea15d0b214 Mon Sep 17 00:00:00 2001
From: Joshua Harlow
Date: Wed, 28 May 2014 18:21:11 -0700
Subject: [PATCH] Add a tree type

A tree module will be very useful for tracking tree structures in
taskflow. So to encourage development and usage of such structures add
a type module and helper classes that can be used to perform tree
operations on tree structures.

Change-Id: I63c0653d051aeb4d1ea8a55f0e25fc25ff9e37f1
---
 taskflow/tests/unit/test_types.py | 132 +++++++++++++++++++++
 taskflow/types/tree.py            | 190 ++++++++++++++++++++++++++++++
 2 files changed, 322 insertions(+)
 create mode 100644 taskflow/tests/unit/test_types.py
 create mode 100644 taskflow/types/tree.py

diff --git a/taskflow/tests/unit/test_types.py b/taskflow/tests/unit/test_types.py
new file mode 100644
index 00000000..5e5b074d
--- /dev/null
+++ b/taskflow/tests/unit/test_types.py
@@ -0,0 +1,132 @@
+# -*- coding: utf-8 -*-
+
+#    Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+import networkx as nx
+
+from taskflow.types import graph
+from taskflow.types import tree
+
+from taskflow import test
+
+
+class GraphTest(test.TestCase):
+    def test_no_successors_no_predecessors(self):
+        g = graph.DiGraph()
+        g.add_node("a")
+        g.add_node("b")
+        g.add_node("c")
+        g.add_edge("b", "c")
+        self.assertEqual(set(['a', 'b']),
+                         set(g.no_predecessors_iter()))
+        self.assertEqual(set(['a', 'c']),
+                         set(g.no_successors_iter()))
+
+    def test_directed(self):
+        g = graph.DiGraph()
+        g.add_node("a")
+        g.add_node("b")
+        g.add_edge("a", "b")
+        self.assertTrue(g.is_directed_acyclic())
+        g.add_edge("b", "a")
+        self.assertFalse(g.is_directed_acyclic())
+
+    def test_frozen(self):
+        g = graph.DiGraph()
+        self.assertFalse(g.frozen)
+        g.add_node("b")
+        g.freeze()
+        self.assertRaises(nx.NetworkXError, g.add_node, "c")
+
+
+class TreeTest(test.TestCase):
+    def _make_species(self):
+        # This is the following tree:
+        #
+        # animal
+        # |__mammal
+        # |  |__horse
+        # |  |__primate
+        # |     |__monkey
+        # |     |__human
+        # |__reptile
+        a = tree.Node("animal")
+        m = tree.Node("mammal")
+        r = tree.Node("reptile")
+        a.add(m)
+        a.add(r)
+        m.add(tree.Node("horse"))
+        p = tree.Node("primate")
+        m.add(p)
+        p.add(tree.Node("monkey"))
+        p.add(tree.Node("human"))
+        return a
+
+    def test_path(self):
+        root = self._make_species()
+        human = root.find("human")
+        self.assertIsNotNone(human)
+        p = [n.item for n in human.path_iter()]
+        self.assertEqual(['human', 'primate', 'mammal', 'animal'], p)
+
+    def test_empty(self):
+        root = tree.Node("josh")
+        self.assertTrue(root.empty())
+
+    def test_not_empty(self):
+        root = self._make_species()
+        self.assertFalse(root.empty())
+
+    def test_node_count(self):
+        root = self._make_species()
+        self.assertEqual(7, 1 + root.child_count(only_direct=False))
+
+    def test_index(self):
+        root = self._make_species()
+        self.assertEqual(0, root.index("mammal"))
+        self.assertEqual(1, root.index("reptile"))
+
+    def test_index_missing(self):
+        root = self._make_species()
+        # A tuple item must still produce a ValueError (not a TypeError
+        # raised while formatting the error message).
+        self.assertRaises(ValueError, root.index, "bird")
+        self.assertRaises(ValueError, root.index, ("bird", "fish"))
+
+    def test_contains(self):
+        root = self._make_species()
+        self.assertIn("monkey", root)
+        self.assertNotIn("bird", root)
+
+    def test_freeze(self):
+        root = self._make_species()
+        root.freeze()
+        self.assertRaises(tree.FrozenNode, root.add, tree.Node("bird"))
+
+    def test_dfs_itr(self):
+        root = self._make_species()
+        things = [n.item for n in root.dfs_iter(include_self=True)]
+        self.assertEqual(set(['animal', 'reptile', 'mammal', 'horse',
+                              'primate', 'monkey', 'human']), set(things))
+
+    def test_dfs_itr_order(self):
+        root = self._make_species()
+        expected = ['animal', 'mammal', 'horse', 'primate', 'monkey',
+                    'human', 'reptile']
+        things = [n.item for n in root.dfs_iter(include_self=True)]
+        self.assertEqual(expected, things)
+        # Excluding the root must not change the order of its descendants.
+        things = [n.item for n in root.dfs_iter(include_self=False)]
+        self.assertEqual(expected[1:], things)
diff --git a/taskflow/types/tree.py b/taskflow/types/tree.py
new file mode 100644
index 00000000..41369b04
--- /dev/null
+++ b/taskflow/types/tree.py
@@ -0,0 +1,190 @@
+# -*- coding: utf-8 -*-
+
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+#    Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+import six
+
+
+class FrozenNode(Exception):
+    """Exception raised when a frozen node is modified."""
+
+
+class _DFSIter(object):
+    """Depth first iterator (non-recursive) over the child nodes.
+
+    Nodes are yielded in pre-order (a node before its children) and the
+    children of every node are visited in the order they were added.
+    """
+
+    def __init__(self, root, include_self=False):
+        self.root = root
+        self.include_self = bool(include_self)
+
+    def __iter__(self):
+        if self.include_self:
+            stack = [self.root]
+        else:
+            # Push the direct children in reverse so that the first added
+            # child is popped (and therefore visited) first; this matches
+            # the order used below for every deeper level of the tree.
+            stack = list(reversed(list(self.root)))
+        while stack:
+            node = stack.pop()
+            # Visit the node.
+            yield node
+            # Then queue up its children (reversed, see above) so that its
+            # whole subtree is visited before any of its siblings.
+            stack.extend(reversed(list(node)))
+
+
+class Node(object):
+    """A n-ary node class that can be used to create tree structures."""
+
+    def __init__(self, item, **kwargs):
+        self.item = item
+        self.parent = None
+        self.metadata = dict(kwargs)
+        self._children = []
+        self._frozen = False
+
+    def _frozen_add(self, child):
+        raise FrozenNode("Frozen node(s) can't be modified")
+
+    def freeze(self):
+        """Freezes this node and all of its descendants (idempotent).
+
+        Afterwards any attempt to add a child to one of the frozen nodes
+        raises a FrozenNode exception.
+        """
+        if not self._frozen:
+            for n in self:
+                n.freeze()
+            # Shadow the class level add() with one that always raises.
+            self.add = self._frozen_add
+            self._frozen = True
+
+    def add(self, child):
+        """Adds a child node to this node (which becomes its parent)."""
+        child.parent = self
+        self._children.append(child)
+
+    def empty(self):
+        """Returns if the node is a leaf node."""
+        return self.child_count() == 0
+
+    def path_iter(self, include_self=True):
+        """Yields back the path from this node to the root node."""
+        if include_self:
+            node = self
+        else:
+            node = self.parent
+        while node is not None:
+            yield node
+            node = node.parent
+
+    def find(self, item):
+        """Returns the node for an item if it exists in this node.
+
+        This will search not only this node but also any children nodes and
+        finally if nothing is found then None is returned instead of a node
+        object.
+        """
+        for n in self.dfs_iter(include_self=True):
+            if n.item == item:
+                return n
+        return None
+
+    def __contains__(self, item):
+        """Returns if this item exists in this node or this nodes children."""
+        return self.find(item) is not None
+
+    def __getitem__(self, index):
+        """Returns the direct child at the given index (in added order)."""
+        return self._children[index]
+
+    def pformat(self):
+        """Recursively formats a node into a nice string representation.
+
+        Example Input:
+            yahoo = tt.Node("CEO")
+            yahoo.add(tt.Node("Infra"))
+            yahoo[0].add(tt.Node("Boss"))
+            yahoo[0][0].add(tt.Node("Me"))
+            yahoo.add(tt.Node("Mobile"))
+            yahoo.add(tt.Node("Mail"))
+
+        Example Output:
+            CEO
+            |__Infra
+            |  |__Boss
+            |     |__Me
+            |__Mobile
+            |__Mail
+        """
+        def _inner_pformat(node, level):
+            if level == 0:
+                yield six.text_type(node.item)
+                prefix = ""
+            else:
+                yield "__%s" % six.text_type(node.item)
+                prefix = " " * 2
+            children = list(node)
+            for (i, child) in enumerate(children):
+                for (j, text) in enumerate(_inner_pformat(child, level + 1)):
+                    # The first line of a child is always connected to its
+                    # parent; continuation lines only need the vertical
+                    # bar when further siblings follow.
+                    if j == 0 or i + 1 < len(children):
+                        text = prefix + "|" + text
+                    else:
+                        text = prefix + " " + text
+                    yield text
+        # One line is produced per node; join them (no trailing newline).
+        return "\n".join(_inner_pformat(self, 0))
+
+    def child_count(self, only_direct=True):
+        """Returns how many children this node has.
+
+        This can be either only the direct children of this node or inclusive
+        of all children nodes of this node (children of children and so-on).
+
+        NOTE(harlowja): it does not account for the current node in this count.
+        """
+        if only_direct:
+            return len(self._children)
+        return sum(1 for _node in self.dfs_iter())
+
+    def __iter__(self):
+        """Iterates over the direct children of this node (in added order)."""
+        for c in self._children:
+            yield c
+
+    def index(self, item):
+        """Finds the child index of a given item, searches in added order.
+
+        Raises ValueError if no direct child holds the given item.
+        """
+        for (i, child) in enumerate(self._children):
+            if child.item == item:
+                return i
+        # NOTE: wrap the item in a tuple so that items which are themselves
+        # tuples get formatted (instead of being unpacked as arguments).
+        raise ValueError("%s is not contained in any child" % (item,))
+
+    def dfs_iter(self, include_self=False):
+        """Depth first iteration (non-recursive) over the child nodes."""
+        return _DFSIter(self, include_self=include_self)