diff --git a/taskflow/tests/unit/test_types.py b/taskflow/tests/unit/test_types.py index 7c05dd2c..61ce0056 100644 --- a/taskflow/tests/unit/test_types.py +++ b/taskflow/tests/unit/test_types.py @@ -113,6 +113,44 @@ class TreeTest(test.TestCase): root = tree.Node("josh") self.assertTrue(root.empty()) + def test_removal(self): + root = self._make_species() + self.assertIsNotNone(root.remove('reptile')) + self.assertRaises(ValueError, root.remove, 'reptile') + self.assertIsNone(root.find('reptile')) + + def test_removal_direct(self): + root = self._make_species() + self.assertRaises(ValueError, root.remove, 'human', + only_direct=True) + + def test_removal_self(self): + root = self._make_species() + n = root.find('horse') + self.assertIsNotNone(n.parent) + n.remove('horse', include_self=True) + self.assertIsNone(n.parent) + self.assertIsNone(root.find('horse')) + + def test_disassociate(self): + root = self._make_species() + n = root.find('horse') + self.assertIsNotNone(n.parent) + c = n.disassociate() + self.assertEqual(1, c) + self.assertIsNone(n.parent) + self.assertIsNone(root.find('horse')) + + def test_disassociate_many(self): + root = self._make_species() + n = root.find('horse') + n.parent.add(n) + n.parent.add(n) + c = n.disassociate() + self.assertEqual(3, c) + self.assertIsNone(n.parent) + self.assertIsNone(root.find('horse')) + def test_not_empty(self): root = self._make_species() self.assertFalse(root.empty()) diff --git a/taskflow/types/tree.py b/taskflow/types/tree.py index 34d4ef11..97059d64 100644 --- a/taskflow/types/tree.py +++ b/taskflow/types/tree.py @@ -149,6 +149,44 @@ class Node(object): return n return None + def disassociate(self): + """Removes this node from its parent (if any). + + :returns: occurences of this node that were removed from its parent. + """ + occurrences = 0 + if self.parent is not None: + p = self.parent + self.parent = None + # Remove all instances of this node from its parent. + while True: + try: + p._children.remove(self) + except ValueError: + break + else: + occurrences += 1 + return occurrences + + def remove(self, item, only_direct=False, include_self=True): + """Removes a item from this nodes children. + + This will search not only this node but also any children nodes and + finally if nothing is found then a value error is raised instead of + the normally returned *removed* node object. + + :param item: item to lookup. + :param only_direct: only look at current node and its direct children. + :param include_self: include the current node during searching. + """ + node = self.find(item, only_direct=only_direct, + include_self=include_self) + if node is None: + raise ValueError("Item '%s' not found to remove" % item) + else: + node.disassociate() + return node + def __contains__(self, item): """Returns whether item exists in this node or this nodes children.