Merge "Add a tree type"
This commit is contained in:
115
taskflow/tests/unit/test_types.py
Normal file
115
taskflow/tests/unit/test_types.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from taskflow.types import graph
|
||||
from taskflow.types import tree
|
||||
|
||||
from taskflow import test
|
||||
|
||||
|
||||
class GraphTest(test.TestCase):
|
||||
def test_no_successors_no_predecessors(self):
|
||||
g = graph.DiGraph()
|
||||
g.add_node("a")
|
||||
g.add_node("b")
|
||||
g.add_node("c")
|
||||
g.add_edge("b", "c")
|
||||
self.assertEqual(set(['a', 'b']),
|
||||
set(g.no_predecessors_iter()))
|
||||
self.assertEqual(set(['a', 'c']),
|
||||
set(g.no_successors_iter()))
|
||||
|
||||
def test_directed(self):
|
||||
g = graph.DiGraph()
|
||||
g.add_node("a")
|
||||
g.add_node("b")
|
||||
g.add_edge("a", "b")
|
||||
self.assertTrue(g.is_directed_acyclic())
|
||||
g.add_edge("b", "a")
|
||||
self.assertFalse(g.is_directed_acyclic())
|
||||
|
||||
def test_frozen(self):
|
||||
g = graph.DiGraph()
|
||||
self.assertFalse(g.frozen)
|
||||
g.add_node("b")
|
||||
g.freeze()
|
||||
self.assertRaises(nx.NetworkXError, g.add_node, "c")
|
||||
|
||||
|
||||
class TreeTest(test.TestCase):
|
||||
def _make_species(self):
|
||||
# This is the following tree:
|
||||
#
|
||||
# animal
|
||||
# |__mammal
|
||||
# | |__horse
|
||||
# | |__primate
|
||||
# | |__monkey
|
||||
# | |__human
|
||||
# |__reptile
|
||||
a = tree.Node("animal")
|
||||
m = tree.Node("mammal")
|
||||
r = tree.Node("reptile")
|
||||
a.add(m)
|
||||
a.add(r)
|
||||
m.add(tree.Node("horse"))
|
||||
p = tree.Node("primate")
|
||||
m.add(p)
|
||||
p.add(tree.Node("monkey"))
|
||||
p.add(tree.Node("human"))
|
||||
return a
|
||||
|
||||
def test_path(self):
|
||||
root = self._make_species()
|
||||
human = root.find("human")
|
||||
self.assertIsNotNone(human)
|
||||
p = list([n.item for n in human.path_iter()])
|
||||
self.assertEqual(['human', 'primate', 'mammal', 'animal'], p)
|
||||
|
||||
def test_empty(self):
|
||||
root = tree.Node("josh")
|
||||
self.assertTrue(root.empty())
|
||||
|
||||
def test_not_empty(self):
|
||||
root = self._make_species()
|
||||
self.assertFalse(root.empty())
|
||||
|
||||
def test_node_count(self):
|
||||
root = self._make_species()
|
||||
self.assertEqual(7, 1 + root.child_count(only_direct=False))
|
||||
|
||||
def test_index(self):
|
||||
root = self._make_species()
|
||||
self.assertEqual(0, root.index("mammal"))
|
||||
self.assertEqual(1, root.index("reptile"))
|
||||
|
||||
def test_contains(self):
|
||||
root = self._make_species()
|
||||
self.assertIn("monkey", root)
|
||||
self.assertNotIn("bird", root)
|
||||
|
||||
def test_freeze(self):
|
||||
root = self._make_species()
|
||||
root.freeze()
|
||||
self.assertRaises(tree.FrozenNode, root.add, "bird")
|
||||
|
||||
def test_dfs_itr(self):
|
||||
root = self._make_species()
|
||||
things = list([n.item for n in root.dfs_iter(include_self=True)])
|
||||
self.assertEqual(set(['animal', 'reptile', 'mammal', 'horse',
|
||||
'primate', 'monkey', 'human']), set(things))
|
||||
182
taskflow/types/tree.py
Normal file
182
taskflow/types/tree.py
Normal file
@@ -0,0 +1,182 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import six
|
||||
|
||||
|
||||
class FrozenNode(Exception):
|
||||
"""Exception raised when a frozen node is modified."""
|
||||
|
||||
|
||||
class _DFSIter(object):
|
||||
"""Depth first iterator (non-recursive) over the child nodes."""
|
||||
|
||||
def __init__(self, root, include_self=False):
|
||||
self.root = root
|
||||
self.include_self = bool(include_self)
|
||||
|
||||
def __iter__(self):
|
||||
stack = []
|
||||
if self.include_self:
|
||||
stack.append(self.root)
|
||||
else:
|
||||
for child_node in self.root:
|
||||
stack.append(child_node)
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
# Visit the node.
|
||||
yield node
|
||||
# Traverse the left & right subtree.
|
||||
for child_node in reversed(list(node)):
|
||||
stack.append(child_node)
|
||||
|
||||
|
||||
class Node(object):
|
||||
"""A n-ary node class that can be used to create tree structures."""
|
||||
|
||||
def __init__(self, item, **kwargs):
|
||||
self.item = item
|
||||
self.parent = None
|
||||
self.metadata = dict(kwargs)
|
||||
self._children = []
|
||||
self._frozen = False
|
||||
|
||||
def _frozen_add(self, child):
|
||||
raise FrozenNode("Frozen node(s) can't be modified")
|
||||
|
||||
def freeze(self):
|
||||
if not self._frozen:
|
||||
for n in self:
|
||||
n.freeze()
|
||||
self.add = self._frozen_add
|
||||
self._frozen = True
|
||||
|
||||
def add(self, child):
|
||||
child.parent = self
|
||||
self._children.append(child)
|
||||
|
||||
def empty(self):
|
||||
"""Returns if the node is a leaf node."""
|
||||
return self.child_count() == 0
|
||||
|
||||
def path_iter(self, include_self=True):
|
||||
"""Yields back the path from this node to the root node."""
|
||||
if include_self:
|
||||
node = self
|
||||
else:
|
||||
node = self.parent
|
||||
while node is not None:
|
||||
yield node
|
||||
node = node.parent
|
||||
|
||||
def find(self, item):
|
||||
"""Returns the node for an item if it exists in this node.
|
||||
|
||||
This will search not only this node but also any children nodes and
|
||||
finally if nothing is found then None is returned instead of a node
|
||||
object.
|
||||
"""
|
||||
for n in self.dfs_iter(include_self=True):
|
||||
if n.item == item:
|
||||
return n
|
||||
return None
|
||||
|
||||
def __contains__(self, item):
|
||||
"""Returns if this item exists in this node or this nodes children."""
|
||||
return self.find(item) is not None
|
||||
|
||||
def __getitem__(self, index):
|
||||
# NOTE(harlowja): 0 is the right most index, len - 1 is the left most
|
||||
return self._children[index]
|
||||
|
||||
def pformat(self):
|
||||
"""Recursively formats a node into a nice string representation.
|
||||
|
||||
Example Input:
|
||||
yahoo = tt.Node("CEO")
|
||||
yahoo.add(tt.Node("Infra"))
|
||||
yahoo[0].add(tt.Node("Boss"))
|
||||
yahoo[0][0].add(tt.Node("Me"))
|
||||
yahoo.add(tt.Node("Mobile"))
|
||||
yahoo.add(tt.Node("Mail"))
|
||||
|
||||
Example Output:
|
||||
CEO
|
||||
|__Infra
|
||||
| |__Boss
|
||||
| |__Me
|
||||
|__Mobile
|
||||
|__Mail
|
||||
"""
|
||||
def _inner_pformat(node, level):
|
||||
if level == 0:
|
||||
yield six.text_type(node.item)
|
||||
prefix = ""
|
||||
else:
|
||||
yield "__%s" % six.text_type(node.item)
|
||||
prefix = " " * 2
|
||||
children = list(node)
|
||||
for (i, child) in enumerate(children):
|
||||
for (j, text) in enumerate(_inner_pformat(child, level + 1)):
|
||||
if j == 0 or i + 1 < len(children):
|
||||
text = prefix + "|" + text
|
||||
else:
|
||||
text = prefix + " " + text
|
||||
yield text
|
||||
expected_lines = self.child_count(only_direct=False)
|
||||
accumulator = six.StringIO()
|
||||
for i, line in enumerate(_inner_pformat(self, 0)):
|
||||
accumulator.write(line)
|
||||
if i < expected_lines:
|
||||
accumulator.write('\n')
|
||||
return accumulator.getvalue()
|
||||
|
||||
def child_count(self, only_direct=True):
|
||||
"""Returns how many children this node has.
|
||||
|
||||
This can be either only the direct children of this node or inclusive
|
||||
of all children nodes of this node (children of children and so-on).
|
||||
|
||||
NOTE(harlowja): it does not account for the current node in this count.
|
||||
"""
|
||||
if not only_direct:
|
||||
count = 0
|
||||
for _node in self.dfs_iter():
|
||||
count += 1
|
||||
return count
|
||||
return len(self._children)
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterates over the direct children of this node (right->left)."""
|
||||
for c in self._children:
|
||||
yield c
|
||||
|
||||
def index(self, item):
|
||||
"""Finds the child index of a given item, searchs in added order."""
|
||||
index_at = None
|
||||
for (i, child) in enumerate(self._children):
|
||||
if child.item == item:
|
||||
index_at = i
|
||||
break
|
||||
if index_at is None:
|
||||
raise ValueError("%s is not contained in any child" % (item))
|
||||
return index_at
|
||||
|
||||
def dfs_iter(self, include_self=False):
|
||||
"""Depth first iteration (non-recursive) over the child nodes."""
|
||||
return _DFSIter(self, include_self=include_self)
|
||||
Reference in New Issue
Block a user