diff --git a/doc/source/utils.rst b/doc/source/utils.rst index d8da6c0c..7878cbcb 100644 --- a/doc/source/utils.rst +++ b/doc/source/utils.rst @@ -23,6 +23,11 @@ Eventlet .. automodule:: taskflow.utils.eventlet_utils +Iterators +~~~~~~~~~ + +.. automodule:: taskflow.utils.iter_utils + Kazoo ~~~~~ diff --git a/taskflow/engines/action_engine/compiler.py b/taskflow/engines/action_engine/compiler.py index dc6c24e1..22b130a8 100644 --- a/taskflow/engines/action_engine/compiler.py +++ b/taskflow/engines/action_engine/compiler.py @@ -25,6 +25,7 @@ from taskflow import logging from taskflow import task from taskflow.types import graph as gr from taskflow.types import tree as tr +from taskflow.utils import iter_utils from taskflow.utils import misc LOG = logging.getLogger(__name__) @@ -232,8 +233,8 @@ class _FlowCompiler(object): @staticmethod def _occurence_detector(to_graph, from_graph): - return sum(1 for node in from_graph.nodes_iter() - if node in to_graph) + return iter_utils.count(node for node in from_graph.nodes_iter() + if node in to_graph) def _decompose_flow(self, flow, parent=None): """Decomposes a flow into a graph, tree node + decomposed subgraphs.""" diff --git a/taskflow/tests/unit/test_types.py b/taskflow/tests/unit/test_types.py index 1d3f5410..79044923 100644 --- a/taskflow/tests/unit/test_types.py +++ b/taskflow/tests/unit/test_types.py @@ -140,6 +140,245 @@ class TreeTest(test.TestCase): p.add(tree.Node("human")) return a + def test_pformat_species(self): + root = self._make_species() + expected = """ +animal +|__mammal +| |__horse +| |__primate +| |__monkey +| |__human +|__reptile +""" + self.assertEqual(expected.strip(), root.pformat()) + + def test_pformat_flat(self): + root = tree.Node("josh") + root.add(tree.Node("josh.1")) + expected = """ +josh +|__josh.1 +""" + self.assertEqual(expected.strip(), root.pformat()) + + root[0].add(tree.Node("josh.1.1")) + expected = """ +josh +|__josh.1 + |__josh.1.1 +""" + self.assertEqual(expected.strip(), root.pformat()) + + root[0][0].add(tree.Node("josh.1.1.1")) + expected = """ +josh +|__josh.1 + |__josh.1.1 + |__josh.1.1.1 +""" + self.assertEqual(expected.strip(), root.pformat()) + + root[0][0][0].add(tree.Node("josh.1.1.1.1")) + expected = """ +josh +|__josh.1 + |__josh.1.1 + |__josh.1.1.1 + |__josh.1.1.1.1 +""" + self.assertEqual(expected.strip(), root.pformat()) + + def test_pformat_partial_species(self): + root = self._make_species() + + expected = """ +reptile +""" + self.assertEqual(expected.strip(), root[1].pformat()) + + expected = """ +mammal +|__horse +|__primate + |__monkey + |__human +""" + self.assertEqual(expected.strip(), root[0].pformat()) + + expected = """ +primate +|__monkey +|__human +""" + self.assertEqual(expected.strip(), root[0][1].pformat()) + + expected = """ +monkey +""" + self.assertEqual(expected.strip(), root[0][1][0].pformat()) + + def test_pformat(self): + + root = tree.Node("CEO") + + expected = """ +CEO +""" + + self.assertEqual(expected.strip(), root.pformat()) + + root.add(tree.Node("Infra")) + + expected = """ +CEO +|__Infra +""" + self.assertEqual(expected.strip(), root.pformat()) + + root[0].add(tree.Node("Infra.1")) + expected = """ +CEO +|__Infra + |__Infra.1 +""" + self.assertEqual(expected.strip(), root.pformat()) + + root.add(tree.Node("Mail")) + expected = """ +CEO +|__Infra +| |__Infra.1 +|__Mail +""" + self.assertEqual(expected.strip(), root.pformat()) + + root.add(tree.Node("Search")) + expected = """ +CEO +|__Infra +| |__Infra.1 +|__Mail +|__Search +""" + self.assertEqual(expected.strip(), root.pformat()) + + root[-1].add(tree.Node("Search.1")) + expected = """ +CEO +|__Infra +| |__Infra.1 +|__Mail +|__Search + |__Search.1 +""" + self.assertEqual(expected.strip(), root.pformat()) + + root[-1].add(tree.Node("Search.2")) + expected = """ +CEO +|__Infra +| |__Infra.1 +|__Mail +|__Search + |__Search.1 + |__Search.2 +""" + self.assertEqual(expected.strip(), root.pformat()) + + root[0].add(tree.Node("Infra.2")) + expected = """ +CEO +|__Infra +| |__Infra.1 +| |__Infra.2 +|__Mail +|__Search + |__Search.1 + |__Search.2 +""" + self.assertEqual(expected.strip(), root.pformat()) + + root[0].add(tree.Node("Infra.3")) + expected = """ +CEO +|__Infra +| |__Infra.1 +| |__Infra.2 +| |__Infra.3 +|__Mail +|__Search + |__Search.1 + |__Search.2 +""" + self.assertEqual(expected.strip(), root.pformat()) + + root[0][-1].add(tree.Node("Infra.3.1")) + expected = """ +CEO +|__Infra +| |__Infra.1 +| |__Infra.2 +| |__Infra.3 +| |__Infra.3.1 +|__Mail +|__Search + |__Search.1 + |__Search.2 +""" + self.assertEqual(expected.strip(), root.pformat()) + + root[-1][0].add(tree.Node("Search.1.1")) + expected = """ +CEO +|__Infra +| |__Infra.1 +| |__Infra.2 +| |__Infra.3 +| |__Infra.3.1 +|__Mail +|__Search + |__Search.1 + | |__Search.1.1 + |__Search.2 +""" + self.assertEqual(expected.strip(), root.pformat()) + + root[1].add(tree.Node("Mail.1")) + expected = """ +CEO +|__Infra +| |__Infra.1 +| |__Infra.2 +| |__Infra.3 +| |__Infra.3.1 +|__Mail +| |__Mail.1 +|__Search + |__Search.1 + | |__Search.1.1 + |__Search.2 +""" + self.assertEqual(expected.strip(), root.pformat()) + + root[1][0].add(tree.Node("Mail.1.1")) + expected = """ +CEO +|__Infra +| |__Infra.1 +| |__Infra.2 +| |__Infra.3 +| |__Infra.3.1 +|__Mail +| |__Mail.1 +| |__Mail.1.1 +|__Search + |__Search.1 + | |__Search.1.1 + |__Search.2 +""" + self.assertEqual(expected.strip(), root.pformat()) + def test_path(self): root = self._make_species() human = root.find("human") diff --git a/taskflow/tests/unit/test_utils_iter_utils.py b/taskflow/tests/unit/test_utils_iter_utils.py new file mode 100644 index 00000000..82d470f3 --- /dev/null +++ b/taskflow/tests/unit/test_utils_iter_utils.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- + +# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import string + +from six.moves import range as compat_range + +from taskflow import test +from taskflow.utils import iter_utils + + +def forever_it(): + i = 0 + while True: + yield i + i += 1 + + +class IterUtilsTest(test.TestCase): + def test_find_first_match(self): + it = forever_it() + self.assertEqual(100, iter_utils.find_first_match(it, + lambda v: v == 100)) + + def test_find_first_match_not_found(self): + it = iter(string.ascii_lowercase) + self.assertIsNone(iter_utils.find_first_match(it, + lambda v: v == '')) + + def test_count(self): + self.assertEqual(0, iter_utils.count([])) + self.assertEqual(1, iter_utils.count(['a'])) + self.assertEqual(10, iter_utils.count(compat_range(0, 10))) + self.assertEqual(1000, iter_utils.count(compat_range(0, 1000))) + self.assertEqual(0, iter_utils.count(compat_range(0))) + self.assertEqual(0, iter_utils.count(compat_range(-1))) + + def test_while_is_not(self): + it = iter(string.ascii_lowercase) + self.assertEqual(['a'], + list(iter_utils.while_is_not(it, 'a'))) + it = iter(string.ascii_lowercase) + self.assertEqual(['a', 'b'], + list(iter_utils.while_is_not(it, 'b'))) + self.assertEqual(list(string.ascii_lowercase[2:]), + list(iter_utils.while_is_not(it, 'zzz'))) + it = iter(string.ascii_lowercase) + self.assertEqual(list(string.ascii_lowercase), + list(iter_utils.while_is_not(it, ''))) diff --git a/taskflow/types/tree.py b/taskflow/types/tree.py index 4faa8291..94c009e1 100644 --- a/taskflow/types/tree.py +++ b/taskflow/types/tree.py @@ -22,6 +22,7 @@ import os import six +from taskflow.utils import iter_utils from taskflow.utils import misc @@ -77,11 +78,25 @@ class _BFSIter(object): class Node(object): """A n-ary node class that can be used to create tree structures.""" - # Constants used when pretty formatting the node (and its children). + #: Default string prefix used in :py:meth:`.pformat`. STARTING_PREFIX = "" + + #: Default string used to create empty space used in :py:meth:`.pformat`. EMPTY_SPACE_SEP = " " + HORIZONTAL_CONN = "__" + """ + Default string used to horizontally connect a node to its + parent (used in :py:meth:`.pformat`.). + """ + VERTICAL_CONN = "|" + """ + Default string used to vertically connect a node to its + parent (used in :py:meth:`.pformat`). + """ + + #: Default line separator used in :py:meth:`.pformat`. LINE_SEP = os.linesep def __init__(self, item, **kwargs): @@ -124,18 +139,22 @@ class Node(object): yield node node = node.parent - def find(self, item, only_direct=False, include_self=True): - """Returns the node for an item if it exists in this node. + def find_first_match(self, matcher, only_direct=False, include_self=True): + """Finds the *first* node that matching callback returns true. - This will search not only this node but also any children nodes and - finally if nothing is found then None is returned instead of a node - object. + This will search not only this node but also any children nodes (in + depth first order, from right to left) and finally if nothing is + matched then ``None`` is returned instead of a node object. - :param item: item to lookup. - :param only_direct: only look at current node and its direct children. + :param matcher: callback that takes one positional argument (a node) + and returns true if it matches desired node or false + if not. + :param only_direct: only look at current node and its + direct children (implies that this does not + search using depth first). :param include_self: include the current node during searching. - :returns: the node for an item if it exists in this node + :returns: the node that matched (or ``None``) """ if only_direct: if include_self: @@ -144,10 +163,26 @@ class Node(object): it = self.reverse_iter() else: it = self.dfs_iter(include_self=include_self) - for n in it: - if n.item == item: - return n - return None + return iter_utils.find_first_match(it, matcher) + + def find(self, item, only_direct=False, include_self=True): + """Returns the *first* node for an item if it exists in this node. + + This will search not only this node but also any children nodes (in + depth first order, from right to left) and finally if nothing is + matched then ``None`` is returned instead of a node object. + + :param item: item to look for. + :param only_direct: only look at current node and its + direct children (implies that this does not + search using depth first). + :param include_self: include the current node during searching. + + :returns: the node that matched provided item (or ``None``) + """ + return self.find_first_match(lambda n: n.item == item, + only_direct=only_direct, + include_self=include_self) def disassociate(self): """Removes this node from its parent (if any). @@ -176,7 +211,9 @@ class Node(object): the normally returned *removed* node object. :param item: item to lookup. - :param only_direct: only look at current node and its direct children. + :param only_direct: only look at current node and its + direct children (implies that this does not + search using depth first). :param include_self: include the current node during searching. """ node = self.find(item, only_direct=only_direct, @@ -200,8 +237,11 @@ class Node(object): # NOTE(harlowja): 0 is the right most index, len - 1 is the left most return self._children[index] - def pformat(self, stringify_node=None): - """Recursively formats a node into a nice string representation. + def pformat(self, stringify_node=None, + linesep=LINE_SEP, vertical_conn=VERTICAL_CONN, + horizontal_conn=HORIZONTAL_CONN, empty_space=EMPTY_SPACE_SEP, + starting_prefix=STARTING_PREFIX): + """Formats this node + children into a nice string representation. **Example**:: @@ -220,33 +260,73 @@ class Node(object): |__Mobile |__Mail """ - def _inner_pformat(node, level, stringify_node): - if level == 0: - yield stringify_node(node) - prefix = self.STARTING_PREFIX - else: - yield self.HORIZONTAL_CONN + stringify_node(node) - prefix = self.EMPTY_SPACE_SEP * len(self.HORIZONTAL_CONN) - child_count = node.child_count() - for (i, child) in enumerate(node): - for (j, text) in enumerate(_inner_pformat(child, - level + 1, - stringify_node)): - if j == 0 or i + 1 < child_count: - text = prefix + self.VERTICAL_CONN + text - else: - text = prefix + self.EMPTY_SPACE_SEP + text - yield text if stringify_node is None: # Default to making a unicode string out of the nodes item... stringify_node = lambda node: six.text_type(node.item) - expected_lines = self.child_count(only_direct=False) - accumulator = six.StringIO() - for i, line in enumerate(_inner_pformat(self, 0, stringify_node)): - accumulator.write(line) - if i < expected_lines: - accumulator.write(self.LINE_SEP) - return accumulator.getvalue() + expected_lines = self.child_count(only_direct=False) + 1 + buff = six.StringIO() + conn = vertical_conn + horizontal_conn + stop_at_parent = self + for i, node in enumerate(self.dfs_iter(include_self=True), 1): + prefix = [] + connected_to_parent = False + last_node = node + # Walk through *most* of this nodes parents, and form the expected + # prefix that each parent should require, repeat this until we + # hit the root node (self) and use that as our nodes prefix + # string... + parent_node_it = iter_utils.while_is_not( + node.path_iter(include_self=True), stop_at_parent) + for j, parent_node in enumerate(parent_node_it): + if parent_node is stop_at_parent: + if j > 0: + if not connected_to_parent: + prefix.append(conn) + connected_to_parent = True + else: + # If the node was connected already then it must + # have had more than one parent, so we want to put + # the right final starting prefix on (which may be + # a empty space or another vertical connector)... + last_node = self._children[-1] + m = last_node.find_first_match(lambda n: n is node, + include_self=False, + only_direct=False) + if m is not None: + prefix.append(empty_space) + else: + prefix.append(vertical_conn) + elif parent_node is node: + # Skip ourself... (we only include ourself so that + # we can use the 'j' variable to determine if the only + # node requested is ourself in the first place); used + # in the first conditional here... + pass + else: + if not connected_to_parent: + prefix.append(conn) + spaces = len(horizontal_conn) + connected_to_parent = True + else: + # If we have already been connected to our parent + # then determine if this current node is the last + # node of its parent (and in that case just put + # on more spaces), otherwise put a vertical connector + # on and less spaces... + if parent_node[-1] is not last_node: + prefix.append(vertical_conn) + spaces = len(horizontal_conn) + else: + spaces = len(conn) + prefix.append(empty_space * spaces) + last_node = parent_node + prefix.append(starting_prefix) + for prefix_piece in reversed(prefix): + buff.write(prefix_piece) + buff.write(stringify_node(node)) + if i != expected_lines: + buff.write(linesep) + return buff.getvalue() def child_count(self, only_direct=True): """Returns how many children this node has. @@ -257,10 +337,7 @@ class Node(object): NOTE(harlowja): it does not account for the current node in this count. """ if not only_direct: - count = 0 - for _node in self.dfs_iter(): - count += 1 - return count + return iter_utils.count(self.dfs_iter()) return len(self._children) def __iter__(self): diff --git a/taskflow/utils/iter_utils.py b/taskflow/utils/iter_utils.py new file mode 100644 index 00000000..68810e8c --- /dev/null +++ b/taskflow/utils/iter_utils.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- + +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (C) 2015 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +def count(it): + """Returns how many values in the iterator (depletes the iterator).""" + return sum(1 for _value in it) + + +def find_first_match(it, matcher, not_found_value=None): + """Searches iterator for first value that matcher callback returns true.""" + for value in it: + if matcher(value): + return value + return not_found_value + + +def while_is_not(it, stop_value): + """Yields given values from iterator until stop value is passed. + + This uses the ``is`` operator to determine equivalency (and not the + ``==`` operator). + """ + for value in it: + yield value + if value is stop_value: + break