diff --git a/src/_wrappers.c b/src/_wrappers.c index 00766d2..8752b07 100644 --- a/src/_wrappers.c +++ b/src/_wrappers.c @@ -1575,18 +1575,31 @@ static int WraptFunctionWrapperBase_init(WraptFunctionWrapperObject *self, PyObject *instance = NULL; PyObject *wrapper = NULL; PyObject *enabled = Py_None; - PyObject *binding = Py_None; + PyObject *binding = NULL; PyObject *parent = Py_None; + static PyObject *function_str = NULL; + static char *kwlist[] = { "wrapped", "instance", "wrapper", "enabled", "binding", "parent", NULL }; + if (!function_str) { +#if PY_MAJOR_VERSION >= 3 + function_str = PyUnicode_InternFromString("function"); +#else + function_str = PyString_InternFromString("function"); +#endif + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OOO:FunctionWrapperBase", kwlist, &wrapped, &instance, &wrapper, &enabled, &binding, &parent)) { return -1; } + if (!binding) + binding = function_str; + return WraptFunctionWrapperBase_raw_init(self, wrapped, instance, wrapper, enabled, binding, parent); } @@ -1684,6 +1697,7 @@ static PyObject *WraptFunctionWrapperBase_descr_get( PyObject *result = NULL; static PyObject *bound_type_str = NULL; + static PyObject *function_str = NULL; if (!bound_type_str) { #if PY_MAJOR_VERSION >= 3 @@ -1695,46 +1709,97 @@ static PyObject *WraptFunctionWrapperBase_descr_get( #endif } - /* - * If we have already been bound to an instance of something, we do - * not do it again and return ourselves again. This appears to - * mirror what Python itself does. Determine this by looking to see - * if we have a parent as it is the least expensive. - */ - - if (self->parent != Py_None) { - Py_INCREF(self); - return (PyObject *)self; + if (!function_str) { +#if PY_MAJOR_VERSION >= 3 + function_str = PyUnicode_InternFromString("function"); +#else + function_str = PyString_InternFromString("function"); +#endif } - descriptor = (Py_TYPE(self->object_proxy.wrapped)->tp_descr_get)( - self->object_proxy.wrapped, obj, type); + if (self->parent == Py_None) { + if (!obj) + obj = Py_None; + if (!type) + type = Py_TYPE(obj); - /* No point looking up bound type if not a derived class. */ + descriptor = (Py_TYPE(self->object_proxy.wrapped)->tp_descr_get)( + self->object_proxy.wrapped, obj, type); - if (Py_TYPE(self) != &WraptFunctionWrapper_Type) { - bound_type = PyObject_GenericGetAttr((PyObject *)self, bound_type_str); + if (Py_TYPE(self) != &WraptFunctionWrapper_Type) { + bound_type = PyObject_GenericGetAttr((PyObject *)self, + bound_type_str); - if (!bound_type) - PyErr_Clear(); + if (!bound_type) + PyErr_Clear(); + } + + if (descriptor) { + result = PyObject_CallFunctionObjArgs(bound_type ? bound_type : + (PyObject *)&WraptBoundFunctionWrapper_Type, descriptor, + obj, self->wrapper, self->enabled, self->binding, + self, NULL); + } + + Py_XDECREF(bound_type); + Py_XDECREF(descriptor); + + return result; } - if (!obj) - obj = Py_None; - if (!type) - type = Py_None; + if (self->instance == Py_None && (self->binding == function_str || + PyObject_RichCompareBool(self->binding, function_str, + Py_EQ) == 1)) { - if (descriptor) { - result = PyObject_CallFunctionObjArgs(bound_type ? bound_type : - (PyObject *)&WraptBoundFunctionWrapper_Type, descriptor, - obj, self->wrapper, self->enabled, self->binding, - self, NULL); + PyObject *wrapped = NULL; + + static PyObject *wrapped_str = NULL; + + if (!wrapped_str) { +#if PY_MAJOR_VERSION >= 3 + wrapped_str = PyUnicode_InternFromString("__wrapped__"); +#else + wrapped_str = PyString_InternFromString("__wrapped__"); +#endif + } + + wrapped = PyObject_GetAttr(self->parent, wrapped_str); + + if (!wrapped) + return NULL; + + if (!obj) + obj = Py_None; + if (!type) + type = Py_TYPE(obj); + + descriptor = (Py_TYPE(wrapped)->tp_descr_get)(wrapped, obj, type); + + Py_DECREF(wrapped); + + if (Py_TYPE(self->parent) != &WraptFunctionWrapper_Type) { + bound_type = PyObject_GenericGetAttr((PyObject *)self->parent, + bound_type_str); + + if (!bound_type) + PyErr_Clear(); + } + + if (descriptor) { + result = PyObject_CallFunctionObjArgs(bound_type ? bound_type : + (PyObject *)&WraptBoundFunctionWrapper_Type, descriptor, + obj, self->wrapper, self->enabled, self->binding, + self->parent, NULL); + } + + Py_XDECREF(bound_type); + Py_XDECREF(descriptor); + + return result; } - Py_XDECREF(bound_type); - Py_XDECREF(descriptor); - - return result; + Py_INCREF(self); + return (PyObject *)self; } /* ------------------------------------------------------------------------- */ @@ -1890,7 +1955,7 @@ static PyObject *WraptBoundFunctionWrapper_call( PyObject *result = NULL; - static PyObject *instancemethod_str = NULL; + static PyObject *function_str = NULL; if (self->enabled != Py_None) { if (PyCallable_Check(self->enabled)) { @@ -1913,11 +1978,11 @@ static PyObject *WraptBoundFunctionWrapper_call( } } - if (!instancemethod_str) { + if (!function_str) { #if PY_MAJOR_VERSION >= 3 - instancemethod_str = PyUnicode_InternFromString("instancemethod"); + function_str = PyUnicode_InternFromString("function"); #else - instancemethod_str = PyString_InternFromString("instancemethod"); + function_str = PyString_InternFromString("function"); #endif } @@ -1931,8 +1996,8 @@ static PyObject *WraptBoundFunctionWrapper_call( * wrapping an instance method vs a static method or class method. */ - if (self->binding == instancemethod_str || PyObject_RichCompareBool( - self->binding, instancemethod_str, Py_EQ) == 1) { + if (self->binding == function_str || PyObject_RichCompareBool( + self->binding, function_str, Py_EQ) == 1) { if (self->instance == Py_None) { /* @@ -2108,7 +2173,7 @@ static int WraptFunctionWrapper_init(WraptFunctionWrapperObject *self, static PyObject *classmethod_str = NULL; static PyObject *staticmethod_str = NULL; - static PyObject *instancemethod_str = NULL; + static PyObject *function_str = NULL; int result = 0; @@ -2135,11 +2200,11 @@ static int WraptFunctionWrapper_init(WraptFunctionWrapperObject *self, #endif } - if (!instancemethod_str) { + if (!function_str) { #if PY_MAJOR_VERSION >= 3 - instancemethod_str = PyUnicode_InternFromString("instancemethod"); + function_str = PyUnicode_InternFromString("function"); #else - instancemethod_str = PyString_InternFromString("instancemethod"); + function_str = PyString_InternFromString("function"); #endif } @@ -2148,7 +2213,7 @@ static int WraptFunctionWrapper_init(WraptFunctionWrapperObject *self, else if (PyObject_IsInstance(wrapped, (PyObject *)&PyStaticMethod_Type)) binding = staticmethod_str; else - binding = instancemethod_str; + binding = function_str; result = WraptFunctionWrapperBase_raw_init(self, wrapped, Py_None, wrapper, enabled, binding, Py_None); diff --git a/src/wrappers.py b/src/wrappers.py index 9ad0250..fb06153 100644 --- a/src/wrappers.py +++ b/src/wrappers.py @@ -384,7 +384,7 @@ class _FunctionWrapperBase(ObjectProxy): '_self_binding', '_self_parent') def __init__(self, wrapped, instance, wrapper, enabled=None, - binding=None, parent=None): + binding='function', parent=None): super(_FunctionWrapperBase, self).__init__(wrapped) @@ -395,19 +395,38 @@ class _FunctionWrapperBase(ObjectProxy): object.__setattr__(self, '_self_parent', parent) def __get__(self, instance, owner): + # If we are called in an unbound wrapper, then perform the binding. + # Note that we do this even if instance is None and accessing an + # unbound instance method from a class. This is because we need to + # be able to later detect that specific case as we will need to + # extract the instance from the first argument of those passed in. + # For the binding against an instance of None case, we also need to + # allow rebinding below. + + if self._self_parent is None: + descriptor = self.__wrapped__.__get__(instance, owner) + + return self.__bound_function_wrapper__(descriptor, instance, + self._self_wrapper, self._self_enabled, + self._self_binding, self) + # If we have already been bound to an instance of something, we - # do not do it again and return ourselves again. This appears to - # mirror what Python itself does. Determine this by looking to - # see if we have a parent as it is the least expensive. + # would usually return ourselves again. This mirrors what Python + # does. The exception is where we were originally bound to an + # instance of None and we were an instance method. In that case + # we rebind against the original wrapped function from the parent + # again. - if self._self_parent is not None: - return self + if self._self_instance is None and self._self_binding == 'function': + descriptor = self._self_parent.__wrapped__.__get__( + instance, owner) - descriptor = self.__wrapped__.__get__(instance, owner) + return self._self_parent.__bound_function_wrapper__( + descriptor, instance, self._self_wrapper, + self._self_enabled, self._self_binding, + self._self_parent) - return self.__bound_function_wrapper__(descriptor, instance, - self._self_wrapper, self._self_enabled, - self._self_binding, self) + return self def __call__(self, *args, **kwargs): # If enabled has been specified, then evaluate it at this point @@ -452,7 +471,7 @@ class BoundFunctionWrapper(_FunctionWrapperBase): # likely wrapping an instance method vs a static method or class # method. - if self._self_binding == 'instancemethod': + if self._self_binding == 'function': if self._self_instance is None: # This situation can occur where someone is calling the # instancemethod via the class type and passing the instance @@ -522,7 +541,7 @@ class FunctionWrapper(_FunctionWrapperBase): elif isinstance(wrapped, staticmethod): binding = 'staticmethod' else: - binding = 'instancemethod' + binding = 'function' super(FunctionWrapper, self).__init__(wrapped, None, wrapper, enabled, binding) diff --git a/tests/test_function_wrapper.py b/tests/test_function_wrapper.py index d9b3d39..50d8e7a 100644 --- a/tests/test_function_wrapper.py +++ b/tests/test_function_wrapper.py @@ -90,7 +90,7 @@ class TestAttributeAccess(unittest.TestCase): self.assertEqual(function2.__wrapped__, function1) self.assertEqual(function2._self_wrapper, decorator1) - self.assertEqual(function2._self_binding, 'instancemethod') + self.assertEqual(function2._self_binding, 'function') def test_instancemethod_attributes(self): def decorator1(wrapped, instance, args, kwargs): @@ -104,7 +104,7 @@ class TestAttributeAccess(unittest.TestCase): self.assertEqual(function2.__wrapped__, function1) self.assertEqual(function2._self_wrapper, decorator1) - self.assertEqual(function2._self_binding, 'instancemethod') + self.assertEqual(function2._self_binding, 'function') instance = Class() @@ -384,15 +384,54 @@ class TestFunctionBinding(unittest.TestCase): _bound_wrapper_1 = _wrapper.__get__(instance, type(instance)) + self.assertTrue(_bound_wrapper_1._self_parent is _wrapper) + self.assertTrue(isinstance(_bound_wrapper_1, wrapt.BoundFunctionWrapper)) + self.assertEqual(_bound_wrapper_1._self_instance, instance) _bound_wrapper_2 = _bound_wrapper_1.__get__(instance, type(instance)) + self.assertTrue(_bound_wrapper_2._self_parent is _wrapper) + self.assertTrue(isinstance(_bound_wrapper_2, wrapt.BoundFunctionWrapper)) + self.assertEqual(_bound_wrapper_2._self_instance, + _bound_wrapper_1._self_instance) self.assertTrue(_bound_wrapper_1 is _bound_wrapper_2) + def test_re_bind_after_none(self): + + def function(): + pass + + def wrapper(wrapped, instance, args, kwargs): + return wrapped(*args, **kwargs) + + _wrapper = wrapt.FunctionWrapper(function, wrapper) + + self.assertTrue(isinstance(_wrapper, wrapt.FunctionWrapper)) + + instance = object() + + _bound_wrapper_1 = _wrapper.__get__(None, type(instance)) + + self.assertTrue(_bound_wrapper_1._self_parent is _wrapper) + + self.assertTrue(isinstance(_bound_wrapper_1, + wrapt.BoundFunctionWrapper)) + self.assertEqual(_bound_wrapper_1._self_instance, None) + + _bound_wrapper_2 = _bound_wrapper_1.__get__(instance, type(instance)) + + self.assertTrue(_bound_wrapper_2._self_parent is _wrapper) + + self.assertTrue(isinstance(_bound_wrapper_2, + wrapt.BoundFunctionWrapper)) + self.assertEqual(_bound_wrapper_2._self_instance, instance) + + self.assertTrue(_bound_wrapper_1 is not _bound_wrapper_2) + if __name__ == '__main__': unittest.main()