diff --git a/taskflow/task.py b/taskflow/task.py index 32bde962..7d17101b 100644 --- a/taskflow/task.py +++ b/taskflow/task.py @@ -18,6 +18,7 @@ # under the License. import abc +import collections import logging import six @@ -131,8 +132,8 @@ class BaseTask(object): # can be useful in resuming older versions of tasks. Standard # major, minor version semantics apply. self.version = (1, 0) - # Map of events => callback functions to invoke on task events. - self._events_listeners = {} + # Map of events => lists of callbacks to invoke on task events. + self._events_listeners = collections.defaultdict(list) @property def name(self): @@ -181,43 +182,51 @@ class BaseTask(object): def _trigger(self, event, *args, **kwargs): """Execute all handlers for the given event type.""" - if event in self._events_listeners: - for handler in self._events_listeners[event]: - event_data = self._events_listeners[event][handler] - try: - handler(self, event_data, *args, **kwargs) - except Exception: - LOG.exception("Failed calling `%s` on event '%s'", - reflection.get_callable_name(handler), event) + for (handler, event_data) in self._events_listeners.get(event, []): + try: + handler(self, event_data, *args, **kwargs) + except Exception: + LOG.exception("Failed calling `%s` on event '%s'", + reflection.get_callable_name(handler), event) def bind(self, event, handler, **kwargs): """Attach a handler to an event for the task. :param event: event type - :param handler: function to execute each time event is triggered + :param handler: callback to execute each time event is triggered :param kwargs: optional named parameters that will be passed to the event handler :raises ValueError: if invalid event type passed """ if event not in self.TASK_EVENTS: - raise ValueError("Unknown task event %s" % event) - if event not in self._events_listeners: - self._events_listeners[event] = {} - self._events_listeners[event][handler] = kwargs + raise ValueError("Unknown task event '%s', can only bind" + " to events %s" % (event, self.TASK_EVENTS)) + assert six.callable(handler), "Handler must be callable" + self._events_listeners[event].append((handler, kwargs)) def unbind(self, event, handler=None): """Remove a previously-attached event handler from the task. If handler - function not passed, then unbind all event handlers. + function not passed, then unbind all event handlers for the provided + event. If multiple of the same handlers are bound, then the first + match is removed (and only the first match). :param event: event type - :param handler: previously attached to event function + :param handler: handler previously bound + + :rtype: boolean + :return: whether anything was removed """ - if event in self._events_listeners: - if not handler: - self._events_listeners[event] = {} - else: - if handler in self._events_listeners[event]: - self._events_listeners[event].pop(handler) + removed_any = False + if not handler: + removed_any = self._events_listeners.pop(event, removed_any) + else: + event_listeners = self._events_listeners.get(event, []) + for i, (handler2, _event_data) in enumerate(event_listeners): + if reflection.is_same_callback(handler, handler2): + event_listeners.pop(i) + removed_any = True + break + return bool(removed_any) class Task(BaseTask): diff --git a/taskflow/utils/reflection.py b/taskflow/utils/reflection.py index 22900ab7..0674628f 100644 --- a/taskflow/utils/reflection.py +++ b/taskflow/utils/reflection.py @@ -17,9 +17,10 @@ # under the License. import inspect -import six import types +import six + def get_member_names(obj, exclude_hidden=True): """Get all the member names for a object."""