Add parent back reference so that bound wrapper knows from which wrapper is was created.

This commit is contained in:
Graham Dumpleton
2013-09-23 21:04:23 +10:00
parent fc71e94995
commit 5377af8a3a
3 changed files with 103 additions and 11 deletions

View File

@@ -26,6 +26,7 @@ typedef struct {
PyObject *wrapper;
PyObject *adapter;
PyObject *bound_type;
PyObject *parent;
} WraptFunctionWrapperObject;
PyTypeObject WraptFunctionWrapperBase_Type;
@@ -1421,6 +1422,7 @@ static PyObject *WraptFunctionWrapperBase_new(PyTypeObject *type,
self->wrapper = NULL;
self->adapter = NULL;
self->bound_type = NULL;
self->parent = NULL;
return (PyObject *)self;
}
@@ -1429,7 +1431,7 @@ static PyObject *WraptFunctionWrapperBase_new(PyTypeObject *type,
static int WraptFunctionWrapperBase_raw_init(WraptFunctionWrapperObject *self,
PyObject *wrapped, PyObject *instance, PyObject *wrapper,
PyObject *adapter, PyObject *bound_type)
PyObject *adapter, PyObject *bound_type, PyObject *parent)
{
int result = 0;
@@ -1452,6 +1454,10 @@ static int WraptFunctionWrapperBase_raw_init(WraptFunctionWrapperObject *self,
Py_INCREF(bound_type);
Py_XDECREF(self->bound_type);
self->bound_type = bound_type;
Py_INCREF(parent);
Py_XDECREF(self->parent);
self->parent = parent;
}
return result;
@@ -1467,17 +1473,19 @@ static int WraptFunctionWrapperBase_init(WraptFunctionWrapperObject *self,
PyObject *wrapper = NULL;
PyObject *adapter = Py_None;
PyObject *bound_type = Py_None;
PyObject *parent = Py_None;
static char *kwlist[] = { "wrapped", "instance", "wrapper",
"adapter", "bound_type", NULL };
"adapter", "bound_type", "parent", NULL };
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OO:FunctionWrapperBase",
kwlist, &wrapped, &instance, &wrapper, &adapter, &bound_type)) {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OOO:FunctionWrapperBase",
kwlist, &wrapped, &instance, &wrapper, &adapter, &bound_type,
&parent)) {
return -1;
}
return WraptFunctionWrapperBase_raw_init(self, wrapped, instance, wrapper,
adapter, bound_type);
adapter, bound_type, parent);
}
/* ------------------------------------------------------------------------- */
@@ -1491,6 +1499,7 @@ static int WraptFunctionWrapperBase_traverse(WraptFunctionWrapperObject *self,
Py_VISIT(self->wrapper);
Py_VISIT(self->adapter);
Py_VISIT(self->bound_type);
Py_VISIT(self->parent);
return 0;
}
@@ -1505,6 +1514,7 @@ static int WraptFunctionWrapperBase_clear(WraptFunctionWrapperObject *self)
Py_CLEAR(self->wrapper);
Py_CLEAR(self->adapter);
Py_CLEAR(self->bound_type);
Py_CLEAR(self->parent);
return 0;
}
@@ -1566,8 +1576,8 @@ static PyObject *WraptFunctionWrapperBase_descr_get(
type = Py_None;
if (descriptor) {
result = PyObject_CallFunction(self->bound_type, "(OOO)",
descriptor, obj, self->wrapper);
result = PyObject_CallFunction(self->bound_type, "(OOOOOO)",
descriptor, obj, self->wrapper, self->adapter, Py_None, self);
}
Py_XDECREF(descriptor);
@@ -1695,6 +1705,20 @@ static PyObject *WraptFunctionWrapperBase_get_self_bound_type(
return self->bound_type;
}
/* ------------------------------------------------------------------------- */
static PyObject *WraptFunctionWrapperBase_get_self_parent(
WraptFunctionWrapperObject *self, void *closure)
{
if (!self->parent) {
Py_INCREF(Py_None);
return Py_None;
}
Py_INCREF(self->parent);
return self->parent;
}
/* ------------------------------------------------------------------------- */;
static PyGetSetDef WraptFunctionWrapperBase_getset[] = {
@@ -1724,6 +1748,8 @@ static PyGetSetDef WraptFunctionWrapperBase_getset[] = {
NULL, 0 },
{ "_self_bound_type", (getter)WraptFunctionWrapperBase_get_self_bound_type,
NULL, 0 },
{ "_self_parent", (getter)WraptFunctionWrapperBase_get_self_parent,
NULL, 0 },
{ NULL },
};
@@ -2042,7 +2068,7 @@ static int WraptFunctionWrapper_init(WraptFunctionWrapperObject *self,
bound_type = (PyObject *)&WraptBoundMethodWrapper_Type;
result = WraptFunctionWrapperBase_raw_init(self, wrapped, Py_None,
wrapper, adapter, bound_type);
wrapper, adapter, bound_type, Py_None);
return result;
}

View File

@@ -373,10 +373,10 @@ class ObjectProxy(six.with_metaclass(_ObjectProxyMetaType)):
class _FunctionWrapperBase(ObjectProxy):
__slots__ = ('_self_instance', '_self_wrapper', '_self_adapter',
'_self_bound_type')
'_self_bound_type', '_self_parent')
def __init__(self, wrapped, instance, wrapper, adapter=None,
bound_type=None):
bound_type=None, parent=None):
super(_FunctionWrapperBase, self).__init__(wrapped)
@@ -384,6 +384,7 @@ class _FunctionWrapperBase(ObjectProxy):
object.__setattr__(self, '_self_wrapper', wrapper)
object.__setattr__(self, '_self_adapter', adapter)
object.__setattr__(self, '_self_bound_type', bound_type)
object.__setattr__(self, '_self_parent', parent)
def __get__(self, instance, owner):
# If we have already been bound to an instance of something, we
@@ -396,7 +397,7 @@ class _FunctionWrapperBase(ObjectProxy):
descriptor = self.__wrapped__.__get__(instance, owner)
return self._self_bound_type(descriptor, instance, self._self_wrapper,
self._self_adapter)
self._self_adapter, None, self)
def __call__(self, *args, **kwargs):
# This is generally invoked when the wrapped function is being

View File

@@ -153,5 +153,70 @@ class TestAttributeAccess(unittest.TestCase):
self.assertEqual(instance.function2._self_instance, instance)
self.assertEqual(instance.function2._self_wrapper, decorator1)
class TestParentReference(unittest.TestCase):
def test_function_decorator(self):
@wrapt.decorator
def _decorator(wrapped, instance, args, kwargs):
return wrapped(*args, **kwargs)
@_decorator
def function():
pass
self.assertEqual(function._self_parent, None)
def test_class_decorator(self):
@wrapt.decorator
def _decorator(wrapped, instance, args, kwargs):
return wrapped(*args, **kwargs)
@_decorator
class Class:
pass
self.assertEqual(Class._self_parent, None)
def test_instancemethod(self):
@wrapt.decorator
def _decorator(wrapped, instance, args, kwargs):
return wrapped(*args, **kwargs)
class Class:
@_decorator
def function_im(self):
pass
c = Class()
self.assertNotEqual(c.function_im._self_parent, None)
self.assertNotEqual(Class.function_im._self_parent, None)
def test_classmethod(self):
@wrapt.decorator
def _decorator(wrapped, instance, args, kwargs):
return wrapped(*args, **kwargs)
class Class:
@_decorator
@classmethod
def function_cm(cls):
pass
self.assertNotEqual(Class.function_cm._self_parent, None)
def test_staticmethod_inner(self):
@wrapt.decorator
def _decorator(wrapped, instance, args, kwargs):
return wrapped(*args, **kwargs)
class Class:
@_decorator
@staticmethod
def function_sm_inner():
pass
self.assertNotEqual(Class.function_sm_inner._self_parent, None)
if __name__ == '__main__':
unittest.main()