diff --git a/taskflow/task.py b/taskflow/task.py index 3a0395f8..c34a13bf 100644 --- a/taskflow/task.py +++ b/taskflow/task.py @@ -18,6 +18,7 @@ import abc import collections import contextlib +import copy import logging import six @@ -133,6 +134,24 @@ class BaseTask(atom.Atom): This works the same as :meth:`.post_execute`, but for the revert phase. """ + def copy(self, retain_listeners=True): + """Clone/copy this task. + + :param retain_listeners: retain the attached listeners when cloning, + when false the listeners will be emptied, when + true the listeners will be copied and retained + + :rtype: task + :return: the copied task + """ + c = copy.copy(self) + c._events_listeners = c._events_listeners.copy() + c._events_listeners.clear() + if retain_listeners: + for event_name, listeners in six.iteritems(self._events_listeners): + c._events_listeners[event_name] = listeners[:] + return c + def update_progress(self, progress, **kwargs): """Update task progress and notify all registered listeners. diff --git a/taskflow/tests/unit/test_task.py b/taskflow/tests/unit/test_task.py index 3a963963..b80c0c6a 100644 --- a/taskflow/tests/unit/test_task.py +++ b/taskflow/tests/unit/test_task.py @@ -276,6 +276,29 @@ class TaskTest(test.TestCase): task = MyTask() self.assertRaises(ValueError, task.bind, 'update_progress', 2) + def test_copy_no_listeners(self): + handler1 = lambda: None + a_task = MyTask() + a_task.bind(task.EVENT_UPDATE_PROGRESS, handler1) + b_task = a_task.copy(retain_listeners=False) + self.assertEqual(len(list(a_task.listeners_iter())), 1) + self.assertEqual(len(list(b_task.listeners_iter())), 0) + + def test_copy_listeners(self): + handler1 = lambda: None + handler2 = lambda: None + a_task = MyTask() + a_task.bind(task.EVENT_UPDATE_PROGRESS, handler1) + b_task = a_task.copy() + self.assertEqual(len(list(b_task.listeners_iter())), 1) + self.assertTrue(a_task.unbind(task.EVENT_UPDATE_PROGRESS)) + self.assertEqual(len(list(a_task.listeners_iter())), 0) + self.assertEqual(len(list(b_task.listeners_iter())), 1) + b_task.bind(task.EVENT_UPDATE_PROGRESS, handler2) + listeners = dict(list(b_task.listeners_iter())) + self.assertEqual(len(listeners[task.EVENT_UPDATE_PROGRESS]), 2) + self.assertEqual(len(list(a_task.listeners_iter())), 0) + class FunctorTaskTest(test.TestCase):