194 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			194 lines
		
	
	
		
			6.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # -*- coding: utf-8 -*-
 | |
| 
 | |
| # vim: tabstop=4 shiftwidth=4 softtabstop=4
 | |
| 
 | |
| #    Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
 | |
| #
 | |
| #    Licensed under the Apache License, Version 2.0 (the "License"); you may
 | |
| #    not use this file except in compliance with the License. You may obtain
 | |
| #    a copy of the License at
 | |
| #
 | |
| #         http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| #    Unless required by applicable law or agreed to in writing, software
 | |
| #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | |
| #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | |
| #    License for the specific language governing permissions and limitations
 | |
| #    under the License.
 | |
| 
 | |
| import six
 | |
| 
 | |
| 
 | |
class FrozenNode(Exception):
    """Error signalling an attempt to change a node that was frozen."""

    def __init__(self):
        # The message is fixed; callers never supply their own text.
        message = "Frozen node(s) can't be modified"
        super(FrozenNode, self).__init__(message)
 | |
| 
 | |
| 
 | |
class _DFSIter(object):
    """Depth first iterator (non-recursive) over the child nodes.

    Nodes are yielded in pre-order: a node is yielded before any of its
    children, and children are visited in the order that iterating over
    their parent produces them. The only thing required of a node is that
    iterating over it yields its direct children.

    :param root: node at which the traversal starts
    :param include_self: when true the root itself is yielded first,
                         otherwise only its descendants are yielded
    """

    def __init__(self, root, include_self=False):
        self.root = root
        self.include_self = bool(include_self)

    def __iter__(self):
        stack = []
        if self.include_self:
            stack.append(self.root)
        else:
            # Push the roots children reversed so that the first child is
            # popped (and therefore visited) first; this keeps the top
            # level ordering the same as the ordering used for every
            # deeper level (and the same as when the root is included).
            stack.extend(reversed(list(self.root)))
        while stack:
            node = stack.pop()
            # Visit the node.
            yield node
            # Then schedule its children; they are reversed since the
            # stack is last-in first-out and the first child must come
            # off of the stack next.
            stack.extend(reversed(list(node)))
 | |
| 
 | |
| 
 | |
class Node(object):
    """A n-ary node class that can be used to create tree structures.

    Every node wraps an arbitrary ``item``, references its ``parent``
    (``None`` until the node is added to another node), keeps any extra
    keyword arguments it was created with in ``metadata`` and tracks its
    direct children in the order in which they were added.
    """

    def __init__(self, item, **kwargs):
        self.item = item
        self.parent = None
        self.metadata = dict(kwargs)
        self.frozen = False
        self._children = []

    def freeze(self):
        """Freezes this node and all of its children (no-op if frozen)."""
        if not self.frozen:
            # Freeze all the children first (each child recursively does
            # the same for its own children) and only once that has worked
            # do we freeze ourselves; this makes it so that we don't become
            # frozen if a child node fails to perform the freeze operation.
            for n in self:
                n.freeze()
            self.frozen = True

    def add(self, child):
        """Adds a child node to this node (and becomes that nodes parent).

        :raises FrozenNode: if this node has been frozen
        """
        if self.frozen:
            raise FrozenNode()
        child.parent = self
        self._children.append(child)

    def empty(self):
        """Returns if the node is a leaf node."""
        return self.child_count() == 0

    def path_iter(self, include_self=True):
        """Yields back the path from this node to the root node.

        :param include_self: when false the path starts at this nodes
                             parent instead of at this node
        """
        node = self if include_self else self.parent
        while node is not None:
            yield node
            node = node.parent

    def find(self, item):
        """Returns the node for an item if it exists in this node.

        This will search not only this node but also any children nodes and
        finally if nothing is found then None is returned instead of a node
        object.
        """
        for n in self.dfs_iter(include_self=True):
            if n.item == item:
                return n
        return None

    def __contains__(self, item):
        """Returns if this item exists in this node or this nodes children."""
        return self.find(item) is not None

    def __getitem__(self, index):
        """Returns the direct child at the given index.

        Index 0 is the child that was added first and ``len - 1`` is the
        child that was added last.
        """
        return self._children[index]

    def pformat(self):
        """Recursively formats a node into a nice string representation.

        **Example**::

            >>> from taskflow.types import tree
            >>> yahoo = tree.Node("CEO")
            >>> yahoo.add(tree.Node("Infra"))
            >>> yahoo[0].add(tree.Node("Boss"))
            >>> yahoo[0][0].add(tree.Node("Me"))
            >>> yahoo.add(tree.Node("Mobile"))
            >>> yahoo.add(tree.Node("Mail"))
            >>> print(yahoo.pformat())
            CEO
            |__Infra
            |  |__Boss
            |     |__Me
            |__Mobile
            |__Mail
        """
        def _inner_pformat(node, level):
            # Yields one line for the node followed by the (prefixed) lines
            # of each of its children.
            if level == 0:
                yield six.text_type(node.item)
                prefix = ""
            else:
                yield "__%s" % six.text_type(node.item)
                prefix = " " * 2
            children = list(node)
            for (i, child) in enumerate(children):
                has_later_sibling = i + 1 < len(children)
                for (j, text) in enumerate(_inner_pformat(child, level + 1)):
                    # The childs own line (j == 0) always gets a connector;
                    # the lines of its descendants only continue the
                    # vertical bar when another sibling follows below.
                    if j == 0 or has_later_sibling:
                        text = prefix + "|" + text
                    else:
                        text = prefix + " " + text
                    yield text
        # Every node contributes exactly one line, so joining the lines
        # places a newline between each of them and none after the last.
        return "\n".join(_inner_pformat(self, 0))

    def child_count(self, only_direct=True):
        """Returns how many children this node has.

        This can be either only the direct children of this node or inclusive
        of all children nodes of this node (children of children and so-on).

        NOTE(harlowja): it does not account for the current node in this count.
        """
        if only_direct:
            return len(self._children)
        return sum(1 for _node in self.dfs_iter())

    def __iter__(self):
        """Iterates over the direct children of this node (in added order)."""
        for c in self._children:
            yield c

    def reverse_iter(self):
        """Iterates over the direct children of this node (last added first)."""
        for c in reversed(self._children):
            yield c

    def index(self, item):
        """Finds the child index of a given item, searches in added order.

        Only the direct children of this node are examined.

        :raises ValueError: if no direct child holds the given item
        """
        for (i, child) in enumerate(self._children):
            if child.item == item:
                return i
        # NOTE: the item must be wrapped in a one element tuple, otherwise
        # an item that is itself a tuple would be unpacked as the format
        # arguments (and trigger a TypeError instead of this ValueError).
        raise ValueError("%s is not contained in any child" % (item,))

    def dfs_iter(self, include_self=False):
        """Depth first iteration (non-recursive) over the child nodes."""
        return _DFSIter(self, include_self=include_self)
 | 
