From 4fa911fef5cbd42e7d0e84bde7bc27e81985eae0 Mon Sep 17 00:00:00 2001 From: Graham Dumpleton Date: Sun, 11 Aug 2013 22:38:51 +0800 Subject: [PATCH] Proxy through query/update of annotations in C version of wrapper. --- src/_wrappers.c | 34 ++++++++++++++++++++++++++ tests/test_update_attributes.py | 43 +++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/src/_wrappers.c b/src/_wrappers.c index 11d5861..3298da7 100644 --- a/src/_wrappers.c +++ b/src/_wrappers.c @@ -301,6 +301,32 @@ static int WraptWrapperBase_set_doc(WraptWrapperBaseObject *self, /* ------------------------------------------------------------------------- */ +static PyObject *WraptWrapperBase_get_annotations( + WraptWrapperBaseObject *self) +{ + if (!self->wrapped) { + PyErr_SetString(PyExc_ValueError, "wrapper has not been initialised"); + return NULL; + } + + return PyObject_GetAttrString(self->wrapped, "__annotations__"); +} + +/* ------------------------------------------------------------------------- */ + +static int WraptWrapperBase_set_annotations(WraptWrapperBaseObject *self, + PyObject *value) +{ + if (!self->wrapped) { + PyErr_SetString(PyExc_ValueError, "wrapper has not been initialised"); + return -1; + } + + return PyObject_SetAttrString(self->wrapped, "__annotations__", value); +} + +/* ------------------------------------------------------------------------- */ + static PyObject *WraptWrapperBase_getattro( WraptWrapperBaseObject *self, PyObject *name) { @@ -398,6 +424,8 @@ static PyGetSetDef WraptWrapperBase_getset[] = { (setter)WraptWrapperBase_set_module, 0 }, { "__doc__", (getter)WraptWrapperBase_get_doc, (setter)WraptWrapperBase_set_doc, 0 }, + { "__annotations__", (getter)WraptWrapperBase_get_annotations, + (setter)WraptWrapperBase_set_annotations, 0 }, { NULL }, }; @@ -640,6 +668,8 @@ static PyGetSetDef WraptFunctionWrapper_getset[] = { (setter)WraptWrapperBase_set_module, 0 }, { "__doc__", (getter)WraptWrapperBase_get_doc, (setter)WraptWrapperBase_set_doc, 0 }, + { "__annotations__", (getter)WraptWrapperBase_get_annotations, + (setter)WraptWrapperBase_set_annotations, 0 }, { NULL }, }; @@ -850,6 +880,8 @@ static PyGetSetDef WraptBoundFunctionWrapper_getset[] = { (setter)WraptWrapperBase_set_module, 0 }, { "__doc__", (getter)WraptWrapperBase_get_doc, (setter)WraptWrapperBase_set_doc, 0 }, + { "__annotations__", (getter)WraptWrapperBase_get_annotations, + (setter)WraptWrapperBase_set_annotations, 0 }, { NULL }, }; @@ -1112,6 +1144,8 @@ static PyGetSetDef WraptBoundMethodWrapper_getset[] = { (setter)WraptWrapperBase_set_module, 0 }, { "__doc__", (getter)WraptWrapperBase_get_doc, (setter)WraptWrapperBase_set_doc, 0 }, + { "__annotations__", (getter)WraptWrapperBase_get_annotations, + (setter)WraptWrapperBase_set_annotations, 0 }, { NULL }, }; diff --git a/tests/test_update_attributes.py b/tests/test_update_attributes.py index 465503e..055d9af 100644 --- a/tests/test_update_attributes.py +++ b/tests/test_update_attributes.py @@ -129,5 +129,48 @@ class TestUpdateAttributes(unittest.TestCase): self.assertEqual(function.__doc__, 'override_doc') self.assertEqual(instance.__doc__, 'override_doc') + def test_update_annotations(self): + @passthru_decorator + def function(): + pass + + if six.PY3: + self.assertEqual(function.__annotations__, {}) + + else: + def run(*args): + function.__annotations__ + + self.assertRaises(AttributeError, run, ()) + + override_annotations = { 'override_annotations': '' } + function.__annotations__ = override_annotations + + self.assertEqual(function.__annotations__, override_annotations) + + def test_update_annotations_modified_on_original(self): + def function(): + pass + + def wrapper(wrapped, instance, args, kwargs): + return wrapped(*args, **kwargs) + + instance = wrapt.FunctionWrapper(function, wrapper) + + if six.PY3: + self.assertEqual(instance.__annotations__, {}) + + else: + def run(*args): + instance.__annotations__ + + self.assertRaises(AttributeError, run, ()) + + override_annotations = { 'override_annotations': '' } + instance.__annotations__ = override_annotations + + self.assertEqual(function.__annotations__, override_annotations) + self.assertEqual(instance.__annotations__, override_annotations) + if __name__ == '__main__': unittest.main()