Fixes how instances methods are not deregistered

It appears that instance methods do not get removed
due to not passing the 'is' check, so switch to using
a new reflection utility method which can handle this
equality check in a way that actually works.

Also adds tests to make sure this does not occur again.

Closes-Bug: 1257550

Change-Id: Iab47dd62cb61de0d93d0fe8d90e59772beebeaeb
This commit is contained in:
Joshua Harlow
2013-12-03 17:50:44 -08:00
parent 4400723399
commit 9e272c6e2f
3 changed files with 158 additions and 2 deletions

View File

@@ -16,8 +16,11 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import collections
import functools
import sys import sys
from taskflow import states
from taskflow import test from taskflow import test
from taskflow.tests import utils as test_utils from taskflow.tests import utils as test_utils
from taskflow.utils import lock_utils from taskflow.utils import lock_utils
@@ -61,6 +64,46 @@ class ClassWithInit(object):
pass pass
class CallbackEqualityTest(test.TestCase):
def test_different_simple_callbacks(self):
def a():
pass
def b():
pass
self.assertFalse(reflection.is_same_callback(a, b))
def test_static_instance_callbacks(self):
class A(object):
@staticmethod
def b(a, b, c):
pass
a = A()
b = A()
self.assertTrue(reflection.is_same_callback(a.b, b.b))
def test_different_instance_callbacks(self):
class A(object):
def b(self):
pass
def __eq__(self, other):
return True
b = A()
c = A()
self.assertFalse(reflection.is_same_callback(b.b, c.b))
self.assertTrue(reflection.is_same_callback(b.b, c.b, strict=False))
class GetCallableNameTest(test.TestCase): class GetCallableNameTest(test.TestCase):
def test_mere_function(self): def test_mere_function(self):
@@ -99,6 +142,88 @@ class GetCallableNameTest(test.TestCase):
'__call__'))) '__call__')))
class NotifierTest(test.TestCase):
def test_notify_called(self):
call_collector = []
def call_me(state, details):
call_collector.append((state, details))
notifier = misc.TransitionNotifier()
notifier.register(misc.TransitionNotifier.ANY, call_me)
notifier.notify(states.SUCCESS, {})
notifier.notify(states.SUCCESS, {})
self.assertEqual(2, len(call_collector))
self.assertEqual(1, len(notifier))
def test_notify_register_deregister(self):
def call_me(state, details):
pass
class A(object):
def call_me_too(self, state, details):
pass
notifier = misc.TransitionNotifier()
notifier.register(misc.TransitionNotifier.ANY, call_me)
a = A()
notifier.register(misc.TransitionNotifier.ANY, a.call_me_too)
self.assertEqual(2, len(notifier))
notifier.deregister(misc.TransitionNotifier.ANY, call_me)
notifier.deregister(misc.TransitionNotifier.ANY, a.call_me_too)
self.assertEqual(0, len(notifier))
def test_notify_reset(self):
def call_me(state, details):
pass
notifier = misc.TransitionNotifier()
notifier.register(misc.TransitionNotifier.ANY, call_me)
self.assertEqual(1, len(notifier))
notifier.reset()
self.assertEqual(0, len(notifier))
def test_bad_notify(self):
def call_me(state, details):
pass
notifier = misc.TransitionNotifier()
self.assertRaises(KeyError, notifier.register,
misc.TransitionNotifier.ANY, call_me,
kwargs={'details': 5})
def test_selective_notify(self):
call_counts = collections.defaultdict(list)
def call_me_on(registered_state, state, details):
call_counts[registered_state].append((state, details))
notifier = misc.TransitionNotifier()
notifier.register(states.SUCCESS,
functools.partial(call_me_on, states.SUCCESS))
notifier.register(misc.TransitionNotifier.ANY,
functools.partial(call_me_on,
misc.TransitionNotifier.ANY))
self.assertEqual(2, len(notifier))
notifier.notify(states.SUCCESS, {})
self.assertEqual(1, len(call_counts[misc.TransitionNotifier.ANY]))
self.assertEqual(1, len(call_counts[states.SUCCESS]))
notifier.notify(states.FAILURE, {})
self.assertEqual(2, len(call_counts[misc.TransitionNotifier.ANY]))
self.assertEqual(1, len(call_counts[states.SUCCESS]))
self.assertEqual(2, len(call_counts))
class GetCallableArgsTest(test.TestCase): class GetCallableArgsTest(test.TestCase):
def test_mere_function(self): def test_mere_function(self):

View File

@@ -214,8 +214,16 @@ class TransitionNotifier(object):
def __init__(self): def __init__(self):
self._listeners = collections.defaultdict(list) self._listeners = collections.defaultdict(list)
def __len__(self):
"""Returns how many callbacks are registered"""
count = 0
for (_s, callbacks) in six.iteritems(self._listeners):
count += len(callbacks)
return count
def reset(self): def reset(self):
self._listeners = collections.defaultdict(list) self._listeners.clear()
def notify(self, state, details): def notify(self, state, details):
listeners = list(self._listeners.get(self.ANY, [])) listeners = list(self._listeners.get(self.ANY, []))
@@ -255,7 +263,7 @@ class TransitionNotifier(object):
if state not in self._listeners: if state not in self._listeners:
return return
for i, (cb, args, kwargs) in enumerate(self._listeners[state]): for i, (cb, args, kwargs) in enumerate(self._listeners[state]):
if cb is callback: if reflection.is_same_callback(cb, callback):
self._listeners[state].pop(i) self._listeners[state].pop(i)
break break

View File

@@ -93,6 +93,29 @@ def get_method_self(method):
return None return None
def is_same_callback(callback1, callback2, strict=True):
"""Returns if the two callbacks are the same."""
if callback1 is callback2:
# This happens when plain methods are given (or static/non-bound
# methods).
return True
if callback1 == callback2:
if not strict:
return True
# If two bound method are equal if functions themselves are equal
# and objects they are applied to are equal. This means that a bound
# method could be the same bound method on another object if the
# objects have __eq__ methods that return true (when in fact it is a
# different bound method). Python u so crazy!
try:
self1 = six.get_method_self(callback1)
self2 = six.get_method_self(callback2)
return self1 is self2
except AttributeError:
pass
return False
def is_bound_method(method): def is_bound_method(method):
"""Returns if the method given is a bound to a object or not.""" """Returns if the method given is a bound to a object or not."""
return bool(get_method_self(method)) return bool(get_method_self(method))