diff --git a/taskflow/tests/unit/test_notifier.py b/taskflow/tests/unit/test_notifier.py index ab536077a..d9d40001e 100644 --- a/taskflow/tests/unit/test_notifier.py +++ b/taskflow/tests/unit/test_notifier.py @@ -91,11 +91,16 @@ class NotifierTest(test.TestCase): call_counts[registered_state].append((state, details)) notifier = nt.Notifier() - notifier.register(states.SUCCESS, - functools.partial(call_me_on, states.SUCCESS)) - notifier.register(nt.Notifier.ANY, - functools.partial(call_me_on, - nt.Notifier.ANY)) + + call_me_on_success = functools.partial(call_me_on, states.SUCCESS) + notifier.register(states.SUCCESS, call_me_on_success) + self.assertTrue(notifier.is_registered(states.SUCCESS, + call_me_on_success)) + + call_me_on_any = functools.partial(call_me_on, nt.Notifier.ANY) + notifier.register(nt.Notifier.ANY, call_me_on_any) + self.assertTrue(notifier.is_registered(nt.Notifier.ANY, + call_me_on_any)) self.assertEqual(2, len(notifier)) notifier.notify(states.SUCCESS, {}) @@ -107,3 +112,62 @@ class NotifierTest(test.TestCase): self.assertEqual(2, len(call_counts[nt.Notifier.ANY])) self.assertEqual(1, len(call_counts[states.SUCCESS])) self.assertEqual(2, len(call_counts)) + + def test_details_filter(self): + call_counts = collections.defaultdict(list) + + def call_me_on(registered_state, state, details): + call_counts[registered_state].append((state, details)) + + def when_red(details): + return details.get('color') == 'red' + + notifier = nt.Notifier() + + call_me_on_success = functools.partial(call_me_on, states.SUCCESS) + notifier.register(states.SUCCESS, call_me_on_success, + details_filter=when_red) + self.assertEqual(1, len(notifier)) + self.assertTrue(notifier.is_registered( + states.SUCCESS, call_me_on_success, details_filter=when_red)) + + notifier.notify(states.SUCCESS, {}) + self.assertEqual(0, len(call_counts[states.SUCCESS])) + notifier.notify(states.SUCCESS, {'color': 'red'}) + self.assertEqual(1, len(call_counts[states.SUCCESS])) + notifier.notify(states.SUCCESS, {'color': 'green'}) + self.assertEqual(1, len(call_counts[states.SUCCESS])) + + def test_different_details_filter(self): + call_counts = collections.defaultdict(list) + + def call_me_on(registered_state, state, details): + call_counts[registered_state].append((state, details)) + + def when_red(details): + return details.get('color') == 'red' + + def when_blue(details): + return details.get('color') == 'blue' + + notifier = nt.Notifier() + + call_me_on_success = functools.partial(call_me_on, states.SUCCESS) + notifier.register(states.SUCCESS, call_me_on_success, + details_filter=when_red) + notifier.register(states.SUCCESS, call_me_on_success, + details_filter=when_blue) + self.assertEqual(2, len(notifier)) + self.assertTrue(notifier.is_registered( + states.SUCCESS, call_me_on_success, details_filter=when_blue)) + self.assertTrue(notifier.is_registered( + states.SUCCESS, call_me_on_success, details_filter=when_red)) + + notifier.notify(states.SUCCESS, {}) + self.assertEqual(0, len(call_counts[states.SUCCESS])) + notifier.notify(states.SUCCESS, {'color': 'red'}) + self.assertEqual(1, len(call_counts[states.SUCCESS])) + notifier.notify(states.SUCCESS, {'color': 'blue'}) + self.assertEqual(2, len(call_counts[states.SUCCESS])) + notifier.notify(states.SUCCESS, {'color': 'green'}) + self.assertEqual(2, len(call_counts[states.SUCCESS])) diff --git a/taskflow/types/notifier.py b/taskflow/types/notifier.py index b8ce8c5fd..8e58302fa 100644 --- a/taskflow/types/notifier.py +++ b/taskflow/types/notifier.py @@ -15,7 +15,6 @@ # under the License. import collections -import copy import logging import six @@ -25,6 +24,56 @@ from taskflow.utils import reflection LOG = logging.getLogger(__name__) +class _Listener(object): + """Internal helper that represents a notification listener/target.""" + + def __init__(self, callback, args=None, kwargs=None, details_filter=None): + self._callback = callback + self._details_filter = details_filter + if not args: + self._args = () + else: + self._args = args[:] + if not kwargs: + self._kwargs = {} + else: + self._kwargs = kwargs.copy() + + def __call__(self, event_type, details): + if self._details_filter is not None: + if not self._details_filter(details): + return + kwargs = self._kwargs.copy() + kwargs['details'] = details + self._callback(event_type, *self._args, **kwargs) + + def __repr__(self): + repr_msg = "%s object at 0x%x calling into '%r'" % ( + reflection.get_class_name(self), id(self), self._callback) + if self._details_filter is not None: + repr_msg += " using details filter '%r'" % self._details_filter + return "<%s>" % repr_msg + + def is_equivalent(self, callback, details_filter=None): + if not reflection.is_same_callback(self._callback, callback): + return False + if details_filter is not None: + if self._details_filter is None: + return False + else: + return reflection.is_same_callback(self._details_filter, + details_filter) + else: + return self._details_filter is None + + def __eq__(self, other): + if isinstance(other, _Listener): + return self.is_equivalent(other._callback, + details_filter=other._details_filter) + else: + return NotImplemented + + class Notifier(object): """A notification helper class. @@ -34,7 +83,7 @@ class Notifier(object): notification occurs. """ - #: Keys that can not be used in callbacks arguments + #: Keys that can *not* be used in callbacks arguments RESERVED_KEYS = ('details',) #: Kleene star constant that is used to recieve all notifications @@ -46,15 +95,14 @@ class Notifier(object): def __len__(self): """Returns how many callbacks are registered.""" count = 0 - for (_event_type, callbacks) in six.iteritems(self._listeners): - count += len(callbacks) + for (_event_type, listeners) in six.iteritems(self._listeners): + count += len(listeners) return count - def is_registered(self, event_type, callback): + def is_registered(self, event_type, callback, details_filter=None): """Check if a callback is registered.""" - listeners = list(self._listeners.get(event_type, [])) - for (cb, _args, _kwargs) in listeners: - if reflection.is_same_callback(cb, callback): + for listener in self._listeners.get(event_type, []): + if listener.is_equivalent(callback, details_filter=details_filter): return True return False @@ -72,52 +120,55 @@ class Notifier(object): :param details: addition event details """ listeners = list(self._listeners.get(self.ANY, [])) - for i in self._listeners[event_type]: - if i not in listeners: - listeners.append(i) + for listener in self._listeners[event_type]: + if listener not in listeners: + listeners.append(listener) if not listeners: return - for (callback, args, kwargs) in listeners: - if args is None: - args = [] - if kwargs is None: - kwargs = {} - kwargs['details'] = details + for listener in listeners: try: - callback(event_type, *args, **kwargs) + listener(event_type, details) except Exception: - LOG.warn("Failure calling callback %s to notify about event" - " %s, details: %s", callback, event_type, + LOG.warn("Failure calling listener %s to notify about event" + " %s, details: %s", listener, event_type, details, exc_info=True) - def register(self, event_type, callback, args=None, kwargs=None): + def register(self, event_type, callback, + args=None, kwargs=None, details_filter=None): """Register a callback to be called when event of a given type occurs. Callback will be called with provided ``args`` and ``kwargs`` and when event type occurs (or on any event if ``event_type`` equals to :attr:`.ANY`). It will also get additional keyword argument, ``details``, that will hold event details provided to the - :meth:`.notify` method. + :meth:`.notify` method (if a details filter callback is provided then + the target callback will *only* be triggered if the details filter + callback returns a truthy value). """ if not six.callable(callback): - raise ValueError("Notification callback must be callable") - if self.is_registered(event_type, callback): - raise ValueError("Notification callback already registered") + raise ValueError("Event callback must be callable") + if details_filter is not None: + if not six.callable(details_filter): + raise ValueError("Details filter must be callable") + if self.is_registered(event_type, callback, + details_filter=details_filter): + raise ValueError("Event callback already registered with" + " equivalent details filter") if kwargs: for k in self.RESERVED_KEYS: if k in kwargs: - raise KeyError(("Reserved key '%s' not allowed in " - "kwargs") % k) - kwargs = copy.copy(kwargs) - if args: - args = copy.copy(args) - self._listeners[event_type].append((callback, args, kwargs)) + raise KeyError("Reserved key '%s' not allowed in " + "kwargs" % k) + self._listeners[event_type].append( + _Listener(callback, + args=args, kwargs=kwargs, + details_filter=details_filter)) - def deregister(self, event_type, callback): + def deregister(self, event_type, callback, details_filter=None): """Remove a single callback from listening to event ``event_type``.""" if event_type not in self._listeners: return - for i, (cb, args, kwargs) in enumerate(self._listeners[event_type]): - if reflection.is_same_callback(cb, callback): + for i, listener in enumerate(self._listeners[event_type]): + if listener.is_equivalent(callback, details_filter=details_filter): self._listeners[event_type].pop(i) break