diff --git a/src/wrappers.py b/src/wrappers.py index d1e35c8..2ead039 100644 --- a/src/wrappers.py +++ b/src/wrappers.py @@ -513,18 +513,18 @@ try: except ImportError: pass -def _weak_function_proxy_callback(ref, self, callback): - if self._self_expired: +def _weak_function_proxy_callback(ref, proxy, callback): + if proxy._self_expired: return - self._self_expired = True + proxy._self_expired = True # This could raise an exception. We let it propagate back and let # the weakref.proxy() deal with it, at which point it generally # prints out a short error message direct to stderr and keeps going. if callback is not None: - callback(self) + callback(proxy) class WeakFunctionProxy(ObjectProxy): @@ -542,12 +542,13 @@ class WeakFunctionProxy(ObjectProxy): # the callback here so as not to cause any odd reference cycles. _callback = callback and functools.partial( - _weak_function_proxy_callback, self=self, callback=callback) + _weak_function_proxy_callback, proxy=self, + callback=callback) self._self_expired = False try: - self._self_instance = weakref.proxy(wrapped.__self__, _callback) + self._self_instance = weakref.ref(wrapped.__self__, _callback) super(WeakFunctionProxy, self).__init__( weakref.proxy(wrapped.__func__, _callback)) @@ -563,7 +564,7 @@ class WeakFunctionProxy(ObjectProxy): # function as that will trigger the reference error prior to # calling if the reference had expired. - instance = self._self_instance and self._self_instance + instance = self._self_instance and self._self_instance() function = self._self_wrapped and self._self_wrapped # If the wrapped function was originally a bound function, for diff --git a/tests/test_weak_function_proxy.py b/tests/test_weak_function_proxy.py new file mode 100644 index 0000000..9535b45 --- /dev/null +++ b/tests/test_weak_function_proxy.py @@ -0,0 +1,175 @@ +from __future__ import print_function + +import unittest +import gc + +import wrapt + +class TestWeakFunctionProxy(unittest.TestCase): + + def test_isinstance(self): + def function(a, b): + return a, b + + proxy = wrapt.WeakFunctionProxy(function) + + self.assertTrue(isinstance(proxy, type(function))) + + def test_no_callback(self): + def function(a, b): + return a, b + + proxy = wrapt.WeakFunctionProxy(function) + + self.assertEqual(proxy(1, 2), (1, 2)) + + function = None + gc.collect() + + def test_call_expired(self): + def function(a, b): + return a, b + + proxy = wrapt.WeakFunctionProxy(function) + + self.assertEqual(proxy(1, 2), (1, 2)) + + function = None + gc.collect() + + def run(*args): + proxy() + + self.assertRaises(ReferenceError, run, ()) + + def test_function(self): + def function(a, b): + return a, b + + result = [] + + def callback(proxy): + result.append(id(proxy)) + + proxy = wrapt.WeakFunctionProxy(function, callback) + + self.assertEqual(proxy(1, 2), (1, 2)) + + function = None + gc.collect() + + self.assertEqual(len(result), 1) + self.assertEqual(id(proxy), result[0]) + + def test_instancemethod_delete_instance(self): + class Class(object): + def function(self, a, b): + return a, b + + result = [] + + def callback(proxy): + result.append(id(proxy)) + + c = Class() + + proxy = wrapt.WeakFunctionProxy(c.function, callback) + + self.assertEqual(proxy(1, 2), (1, 2)) + + c = None + gc.collect() + + self.assertEqual(len(result), 1) + self.assertEqual(id(proxy), result[0]) + + def test_instancemethod_delete_function(self): + class Class(object): + def function(self, a, b): + return a, b + + result = [] + + def callback(proxy): + result.append(id(proxy)) + + c = Class() + + proxy = wrapt.WeakFunctionProxy(c.function, callback) + + self.assertEqual(proxy(1, 2), (1, 2)) + + del Class.function + gc.collect() + + self.assertEqual(len(result), 1) + self.assertEqual(id(proxy), result[0]) + + def test_instancemethod_delete_function_and_instance(self): + class Class(object): + def function(self, a, b): + return a, b + + result = [] + + def callback(proxy): + result.append(id(proxy)) + + c = Class() + + proxy = wrapt.WeakFunctionProxy(c.function, callback) + + self.assertEqual(proxy(1, 2), (1, 2)) + + c = None + del Class.function + gc.collect() + + self.assertEqual(len(result), 1) + self.assertEqual(id(proxy), result[0]) + + def test_classmethod(self): + class Class(object): + @classmethod + def function(cls, a, b): + self.assertEqual(cls, Class) + return a, b + + result = [] + + def callback(proxy): + result.append(id(proxy)) + + proxy = wrapt.WeakFunctionProxy(Class.function, callback) + + self.assertEqual(proxy(1, 2), (1, 2)) + + Class = None + gc.collect() + + self.assertEqual(len(result), 1) + self.assertEqual(id(proxy), result[0]) + + def test_staticmethod(self): + class Class(object): + @staticmethod + def function(a, b): + return a, b + + result = [] + + def callback(proxy): + result.append(id(proxy)) + + proxy = wrapt.WeakFunctionProxy(Class.function, callback) + + self.assertEqual(proxy(1, 2), (1, 2)) + + Class = None + gc.collect() + + self.assertEqual(len(result), 1) + self.assertEqual(id(proxy), result[0]) + +if __name__ == '__main__': + unittest.main()