From f57bad764a50896dd6b7189bfb6f30419ebc8e53 Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Wed, 1 Jan 2014 10:08:36 -0800 Subject: [PATCH] Refactor task handler binding Instead of using a dictionary of handler mappings use a dict[event] -> list of handler mappings, this allows the same function/handler to be bound more than once (which can be useful to do, depending on handler specifications) and to allow for removal of these handlers find the first match and remove it using the reflection callback comparison routine. Change-Id: Ied31dd893502ec91a48f7d2e1ef2cd4553a07f89 --- taskflow/task.py | 55 +++++++++++++++++++++--------------- taskflow/utils/reflection.py | 3 +- 2 files changed, 34 insertions(+), 24 deletions(-) diff --git a/taskflow/task.py b/taskflow/task.py index 32bde962..7d17101b 100644 --- a/taskflow/task.py +++ b/taskflow/task.py @@ -18,6 +18,7 @@ # under the License. import abc +import collections import logging import six @@ -131,8 +132,8 @@ class BaseTask(object): # can be useful in resuming older versions of tasks. Standard # major, minor version semantics apply. self.version = (1, 0) - # Map of events => callback functions to invoke on task events. - self._events_listeners = {} + # Map of events => lists of callbacks to invoke on task events. + self._events_listeners = collections.defaultdict(list) @property def name(self): @@ -181,43 +182,51 @@ class BaseTask(object): def _trigger(self, event, *args, **kwargs): """Execute all handlers for the given event type.""" - if event in self._events_listeners: - for handler in self._events_listeners[event]: - event_data = self._events_listeners[event][handler] - try: - handler(self, event_data, *args, **kwargs) - except Exception: - LOG.exception("Failed calling `%s` on event '%s'", - reflection.get_callable_name(handler), event) + for (handler, event_data) in self._events_listeners.get(event, []): + try: + handler(self, event_data, *args, **kwargs) + except Exception: + LOG.exception("Failed calling `%s` on event '%s'", + reflection.get_callable_name(handler), event) def bind(self, event, handler, **kwargs): """Attach a handler to an event for the task. :param event: event type - :param handler: function to execute each time event is triggered + :param handler: callback to execute each time event is triggered :param kwargs: optional named parameters that will be passed to the event handler :raises ValueError: if invalid event type passed """ if event not in self.TASK_EVENTS: - raise ValueError("Unknown task event %s" % event) - if event not in self._events_listeners: - self._events_listeners[event] = {} - self._events_listeners[event][handler] = kwargs + raise ValueError("Unknown task event '%s', can only bind" + " to events %s" % (event, self.TASK_EVENTS)) + assert six.callable(handler), "Handler must be callable" + self._events_listeners[event].append((handler, kwargs)) def unbind(self, event, handler=None): """Remove a previously-attached event handler from the task. If handler - function not passed, then unbind all event handlers. + function not passed, then unbind all event handlers for the provided + event. If multiple of the same handlers are bound, then the first + match is removed (and only the first match). :param event: event type - :param handler: previously attached to event function + :param handler: handler previously bound + + :rtype: boolean + :return: whether anything was removed """ - if event in self._events_listeners: - if not handler: - self._events_listeners[event] = {} - else: - if handler in self._events_listeners[event]: - self._events_listeners[event].pop(handler) + removed_any = False + if not handler: + removed_any = self._events_listeners.pop(event, removed_any) + else: + event_listeners = self._events_listeners.get(event, []) + for i, (handler2, _event_data) in enumerate(event_listeners): + if reflection.is_same_callback(handler, handler2): + event_listeners.pop(i) + removed_any = True + break + return bool(removed_any) class Task(BaseTask): diff --git a/taskflow/utils/reflection.py b/taskflow/utils/reflection.py index 22900ab7..0674628f 100644 --- a/taskflow/utils/reflection.py +++ b/taskflow/utils/reflection.py @@ -17,9 +17,10 @@ # under the License. import inspect -import six import types +import six + def get_member_names(obj, exclude_hidden=True): """Get all the member names for a object."""