diff --git a/src/_wrappers.c b/src/_wrappers.c index 6f94cb8..70c2223 100644 --- a/src/_wrappers.c +++ b/src/_wrappers.c @@ -26,6 +26,7 @@ typedef struct { PyObject *wrapper; PyObject *adapter; PyObject *bound_type; + PyObject *parent; } WraptFunctionWrapperObject; PyTypeObject WraptFunctionWrapperBase_Type; @@ -1421,6 +1422,7 @@ static PyObject *WraptFunctionWrapperBase_new(PyTypeObject *type, self->wrapper = NULL; self->adapter = NULL; self->bound_type = NULL; + self->parent = NULL; return (PyObject *)self; } @@ -1429,7 +1431,7 @@ static PyObject *WraptFunctionWrapperBase_new(PyTypeObject *type, static int WraptFunctionWrapperBase_raw_init(WraptFunctionWrapperObject *self, PyObject *wrapped, PyObject *instance, PyObject *wrapper, - PyObject *adapter, PyObject *bound_type) + PyObject *adapter, PyObject *bound_type, PyObject *parent) { int result = 0; @@ -1452,6 +1454,10 @@ static int WraptFunctionWrapperBase_raw_init(WraptFunctionWrapperObject *self, Py_INCREF(bound_type); Py_XDECREF(self->bound_type); self->bound_type = bound_type; + + Py_INCREF(parent); + Py_XDECREF(self->parent); + self->parent = parent; } return result; @@ -1467,17 +1473,19 @@ static int WraptFunctionWrapperBase_init(WraptFunctionWrapperObject *self, PyObject *wrapper = NULL; PyObject *adapter = Py_None; PyObject *bound_type = Py_None; + PyObject *parent = Py_None; static char *kwlist[] = { "wrapped", "instance", "wrapper", - "adapter", "bound_type", NULL }; + "adapter", "bound_type", "parent", NULL }; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OO:FunctionWrapperBase", - kwlist, &wrapped, &instance, &wrapper, &adapter, &bound_type)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OOO:FunctionWrapperBase", + kwlist, &wrapped, &instance, &wrapper, &adapter, &bound_type, + &parent)) { return -1; } return WraptFunctionWrapperBase_raw_init(self, wrapped, instance, wrapper, - adapter, bound_type); + adapter, bound_type, parent); } /* ------------------------------------------------------------------------- */ @@ -1491,6 +1499,7 @@ static int WraptFunctionWrapperBase_traverse(WraptFunctionWrapperObject *self, Py_VISIT(self->wrapper); Py_VISIT(self->adapter); Py_VISIT(self->bound_type); + Py_VISIT(self->parent); return 0; } @@ -1505,6 +1514,7 @@ static int WraptFunctionWrapperBase_clear(WraptFunctionWrapperObject *self) Py_CLEAR(self->wrapper); Py_CLEAR(self->adapter); Py_CLEAR(self->bound_type); + Py_CLEAR(self->parent); return 0; } @@ -1566,8 +1576,8 @@ static PyObject *WraptFunctionWrapperBase_descr_get( type = Py_None; if (descriptor) { - result = PyObject_CallFunction(self->bound_type, "(OOO)", - descriptor, obj, self->wrapper); + result = PyObject_CallFunction(self->bound_type, "(OOOOOO)", + descriptor, obj, self->wrapper, self->adapter, Py_None, self); } Py_XDECREF(descriptor); @@ -1695,6 +1705,20 @@ static PyObject *WraptFunctionWrapperBase_get_self_bound_type( return self->bound_type; } +/* ------------------------------------------------------------------------- */ + +static PyObject *WraptFunctionWrapperBase_get_self_parent( + WraptFunctionWrapperObject *self, void *closure) +{ + if (!self->parent) { + Py_INCREF(Py_None); + return Py_None; + } + + Py_INCREF(self->parent); + return self->parent; +} + /* ------------------------------------------------------------------------- */; static PyGetSetDef WraptFunctionWrapperBase_getset[] = { @@ -1724,6 +1748,8 @@ static PyGetSetDef WraptFunctionWrapperBase_getset[] = { NULL, 0 }, { "_self_bound_type", (getter)WraptFunctionWrapperBase_get_self_bound_type, NULL, 0 }, + { "_self_parent", (getter)WraptFunctionWrapperBase_get_self_parent, + NULL, 0 }, { NULL }, }; @@ -2042,7 +2068,7 @@ static int WraptFunctionWrapper_init(WraptFunctionWrapperObject *self, bound_type = (PyObject *)&WraptBoundMethodWrapper_Type; result = WraptFunctionWrapperBase_raw_init(self, wrapped, Py_None, - wrapper, adapter, bound_type); + wrapper, adapter, bound_type, Py_None); return result; } diff --git a/src/wrappers.py b/src/wrappers.py index e48ed9b..65c17aa 100644 --- a/src/wrappers.py +++ b/src/wrappers.py @@ -373,10 +373,10 @@ class ObjectProxy(six.with_metaclass(_ObjectProxyMetaType)): class _FunctionWrapperBase(ObjectProxy): __slots__ = ('_self_instance', '_self_wrapper', '_self_adapter', - '_self_bound_type') + '_self_bound_type', '_self_parent') def __init__(self, wrapped, instance, wrapper, adapter=None, - bound_type=None): + bound_type=None, parent=None): super(_FunctionWrapperBase, self).__init__(wrapped) @@ -384,6 +384,7 @@ class _FunctionWrapperBase(ObjectProxy): object.__setattr__(self, '_self_wrapper', wrapper) object.__setattr__(self, '_self_adapter', adapter) object.__setattr__(self, '_self_bound_type', bound_type) + object.__setattr__(self, '_self_parent', parent) def __get__(self, instance, owner): # If we have already been bound to an instance of something, we @@ -396,7 +397,7 @@ class _FunctionWrapperBase(ObjectProxy): descriptor = self.__wrapped__.__get__(instance, owner) return self._self_bound_type(descriptor, instance, self._self_wrapper, - self._self_adapter) + self._self_adapter, None, self) def __call__(self, *args, **kwargs): # This is generally invoked when the wrapped function is being diff --git a/tests/test_function_wrapper.py b/tests/test_function_wrapper.py index ad60806..cdda742 100644 --- a/tests/test_function_wrapper.py +++ b/tests/test_function_wrapper.py @@ -153,5 +153,70 @@ class TestAttributeAccess(unittest.TestCase): self.assertEqual(instance.function2._self_instance, instance) self.assertEqual(instance.function2._self_wrapper, decorator1) +class TestParentReference(unittest.TestCase): + + def test_function_decorator(self): + @wrapt.decorator + def _decorator(wrapped, instance, args, kwargs): + return wrapped(*args, **kwargs) + + @_decorator + def function(): + pass + + self.assertEqual(function._self_parent, None) + + def test_class_decorator(self): + @wrapt.decorator + def _decorator(wrapped, instance, args, kwargs): + return wrapped(*args, **kwargs) + + @_decorator + class Class: + pass + + self.assertEqual(Class._self_parent, None) + + def test_instancemethod(self): + @wrapt.decorator + def _decorator(wrapped, instance, args, kwargs): + return wrapped(*args, **kwargs) + + class Class: + @_decorator + def function_im(self): + pass + + c = Class() + + self.assertNotEqual(c.function_im._self_parent, None) + self.assertNotEqual(Class.function_im._self_parent, None) + + def test_classmethod(self): + @wrapt.decorator + def _decorator(wrapped, instance, args, kwargs): + return wrapped(*args, **kwargs) + + class Class: + @_decorator + @classmethod + def function_cm(cls): + pass + + self.assertNotEqual(Class.function_cm._self_parent, None) + + def test_staticmethod_inner(self): + @wrapt.decorator + def _decorator(wrapped, instance, args, kwargs): + return wrapped(*args, **kwargs) + + class Class: + @_decorator + @staticmethod + def function_sm_inner(): + pass + + self.assertNotEqual(Class.function_sm_inner._self_parent, None) + if __name__ == '__main__': unittest.main()