From 039aecad50275b163c63aa37b83a659a67e6bdf1 Mon Sep 17 00:00:00 2001 From: Graham Dumpleton Date: Sun, 11 Aug 2013 22:20:15 +0800 Subject: [PATCH] Propagate __name__ and __qualname__ updates as well as saving them against the wrapper. --- src/_wrappers.c | 15 ++++ src/wrappers.py | 11 ++- tests/test_update_attributes.py | 133 ++++++++++++++++++++++++++++++++ 3 files changed, 155 insertions(+), 4 deletions(-) create mode 100644 tests/test_update_attributes.py diff --git a/src/_wrappers.c b/src/_wrappers.c index 64ad541..11d5861 100644 --- a/src/_wrappers.c +++ b/src/_wrappers.c @@ -327,6 +327,8 @@ static int WraptWrapperBase_setattro( WraptWrapperBaseObject *self, PyObject *name, PyObject *value) { PyObject *self_prefix = NULL; + PyObject *attr_name = NULL; + PyObject *attr_qualname = NULL; PyObject *match = NULL; @@ -337,8 +339,12 @@ static int WraptWrapperBase_setattro( #if PY_MAJOR_VERSION >= 3 self_prefix = PyUnicode_FromString("_self_"); + attr_name = PyUnicode_FromString("__name__"); + attr_qualname = PyUnicode_FromString("__qualname__"); #else self_prefix = PyString_FromString("_self_"); + attr_name = PyString_FromString("__name__"); + attr_qualname = PyString_FromString("__qualname__"); #endif match = PyEval_CallMethod(name, "startswith", "(O)", self_prefix); @@ -353,6 +359,15 @@ static int WraptWrapperBase_setattro( Py_XDECREF(match); + if (PyObject_RichCompareBool(name, attr_name, Py_EQ) == 1 || + PyObject_RichCompareBool(name, attr_qualname, Py_EQ) == 1) { + + if (PyObject_GenericSetAttr((PyObject *)self, name, value) == -1) + return -1; + + return PyObject_SetAttr(self->wrapped, name, value); + } + return PyObject_SetAttr(self->wrapped, name, value); } diff --git a/src/wrappers.py b/src/wrappers.py index 2b6d622..f45d99a 100644 --- a/src/wrappers.py +++ b/src/wrappers.py @@ -61,9 +61,9 @@ class _WrapperBase(six.with_metaclass(_WrapperBaseMetaType)): try: if target is None: - object.__setattr__(self, '__qualname__', wrapped.__qualname__) + self.__qualname__ = wrapped.__qualname__ else: - object.__setattr__(self, '__qualname__', target.__qualname__) + self.__qualname__ = target.__qualname__ except AttributeError: pass @@ -74,15 +74,18 @@ class _WrapperBase(six.with_metaclass(_WrapperBaseMetaType)): try: if target is None: - object.__setattr__(self, '__name__', wrapped.__name__) + self. __name__ = wrapped.__name__ else: - object.__setattr__(self, '__name__', target.__name__) + self.__name__ = target.__name__ except AttributeError: pass def __setattr__(self, name, value): if name.startswith('_self_'): object.__setattr__(self, name, value) + elif name in ('__name__', '__qualname__'): + object.__setattr__(self, name, value) + setattr(self._self_wrapped, name, value) else: setattr(self._self_wrapped, name, value) diff --git a/tests/test_update_attributes.py b/tests/test_update_attributes.py new file mode 100644 index 0000000..465503e --- /dev/null +++ b/tests/test_update_attributes.py @@ -0,0 +1,133 @@ +from __future__ import print_function + +import unittest + +import wrapt + +from wrapt import six + +@wrapt.decorator +def passthru_decorator(wrapped, instance, args, kwargs): + return wrapped(*args, **kwargs) + +class TestUpdateAttributes(unittest.TestCase): + + def test_update_name(self): + @passthru_decorator + def function(): + pass + + self.assertEqual(function.__name__, 'function') + + function.__name__ = 'override_name' + + self.assertEqual(function.__name__, 'override_name') + + def test_update_name_modified_on_original(self): + def function(): + pass + + def wrapper(wrapped, instance, args, kwargs): + return wrapped(*args, **kwargs) + + instance = wrapt.FunctionWrapper(function, wrapper) + + self.assertEqual(instance.__name__, 'function') + + instance.__name__ = 'override_name' + + self.assertEqual(function.__name__, 'override_name') + self.assertEqual(instance.__name__, 'override_name') + + def test_update_qualname(self): + + @passthru_decorator + def function(): + pass + + if six.PY3: + method = self.test_update_qualname + self.assertEqual(function.__qualname__, + (method.__qualname__ + '..function')) + + function.__qualname__ = 'override_qualname' + + self.assertEqual(function.__qualname__, 'override_qualname') + + def test_update_qualname_modified_on_original(self): + def function(): + pass + + def wrapper(wrapped, instance, args, kwargs): + return wrapped(*args, **kwargs) + + instance = wrapt.FunctionWrapper(function, wrapper) + + if six.PY3: + method = self.test_update_qualname_modified_on_original + self.assertEqual(instance.__qualname__, + (method.__qualname__ + '..function')) + + instance.__qualname__ = 'override_qualname' + + self.assertEqual(function.__qualname__, 'override_qualname') + self.assertEqual(instance.__qualname__, 'override_qualname') + + def test_update_module(self): + @passthru_decorator + def function(): + pass + + self.assertEqual(function.__module__, __name__) + + function.__module__ = 'override_module' + + self.assertEqual(function.__module__, 'override_module') + + def test_update_module_modified_on_original(self): + def function(): + pass + + def wrapper(wrapped, instance, args, kwargs): + return wrapped(*args, **kwargs) + + instance = wrapt.FunctionWrapper(function, wrapper) + + self.assertEqual(instance.__module__, __name__) + + instance.__module__ = 'override_module' + + self.assertEqual(function.__module__, 'override_module') + self.assertEqual(instance.__module__, 'override_module') + + def test_update_doc(self): + @passthru_decorator + def function(): + """documentation""" + pass + + self.assertEqual(function.__doc__, "documentation") + + function.__doc__ = 'override_doc' + + self.assertEqual(function.__doc__, 'override_doc') + + def test_update_doc_modified_on_original(self): + def function(): + """documentation""" + pass + + def wrapper(wrapped, instance, args, kwargs): + return wrapped(*args, **kwargs) + + instance = wrapt.FunctionWrapper(function, wrapper) + + self.assertEqual(instance.__doc__, "documentation") + + instance.__doc__ = 'override_doc' + + self.assertEqual(function.__doc__, 'override_doc') + self.assertEqual(instance.__doc__, 'override_doc') + +if __name__ == '__main__': + unittest.main()