diff --git a/taskflow/tests/unit/test_types.py b/taskflow/tests/unit/test_types.py index 5f750d20..7c05dd2c 100644 --- a/taskflow/tests/unit/test_types.py +++ b/taskflow/tests/unit/test_types.py @@ -136,6 +136,16 @@ class TreeTest(test.TestCase): root.freeze() self.assertRaises(tree.FrozenNode, root.add, "bird") + def test_find(self): + root = self._make_species() + self.assertIsNone(root.find('monkey', only_direct=True)) + self.assertIsNotNone(root.find('monkey', only_direct=False)) + self.assertIsNotNone(root.find('animal', only_direct=True)) + self.assertIsNotNone(root.find('reptile', only_direct=True)) + self.assertIsNone(root.find('animal', include_self=False)) + self.assertIsNone(root.find('animal', + include_self=False, only_direct=True)) + def test_dfs_itr(self): root = self._make_species() things = list([n.item for n in root.dfs_iter(include_self=True)]) diff --git a/taskflow/types/tree.py b/taskflow/types/tree.py index c4273149..8694c5b5 100644 --- a/taskflow/types/tree.py +++ b/taskflow/types/tree.py @@ -17,6 +17,7 @@ # under the License. import collections +import itertools import os import six @@ -118,7 +119,7 @@ class Node(object): yield node node = node.parent - def find(self, item): + def find(self, item, only_direct=False, include_self=True): """Returns the node for an item if it exists in this node. This will search not only this node but also any children nodes and @@ -126,9 +127,19 @@ class Node(object): object. :param item: item to lookup. + :param only_direct: only look at current node and its direct children. + :param include_self: include the current node during searching. + :returns: the node for an item if it exists in this node """ - for n in self.dfs_iter(include_self=True): + if only_direct: + if include_self: + it = itertools.chain([self], self.reverse_iter()) + else: + it = self.reverse_iter() + else: + it = self.dfs_iter(include_self=include_self) + for n in it: if n.item == item: return n return None