From 040da134af654b920348330b1f1443b9d6c1a83a Mon Sep 17 00:00:00 2001 From: Graham Dumpleton Date: Mon, 7 Oct 2013 16:47:20 +1100 Subject: [PATCH] When enabled option is callable, it must be checked on the actual call for a bound method, not when binding occurs. --- src/_wrappers.c | 45 +++++++++++++++++----------------- src/wrappers.py | 29 +++++++++++----------- tests/test_function_wrapper.py | 2 +- 3 files changed, 38 insertions(+), 38 deletions(-) diff --git a/src/_wrappers.c b/src/_wrappers.c index 778b8ec..00766d2 100644 --- a/src/_wrappers.c +++ b/src/_wrappers.c @@ -1710,28 +1710,6 @@ static PyObject *WraptFunctionWrapperBase_descr_get( descriptor = (Py_TYPE(self->object_proxy.wrapped)->tp_descr_get)( self->object_proxy.wrapped, obj, type); - if (self->enabled != Py_None) { - if (PyCallable_Check(self->enabled)) { - PyObject *object = NULL; - - object = PyObject_CallFunctionObjArgs(self->enabled, NULL); - - if (!object) { - Py_DECREF(descriptor); - return NULL; - } - - if (PyObject_Not(object)) { - Py_DECREF(object); - return descriptor; - } - - Py_DECREF(object); - } - else if (PyObject_Not(self->enabled)) - return descriptor; - } - /* No point looking up bound type if not a derived class. */ if (Py_TYPE(self) != &WraptFunctionWrapper_Type) { @@ -1749,7 +1727,7 @@ static PyObject *WraptFunctionWrapperBase_descr_get( if (descriptor) { result = PyObject_CallFunctionObjArgs(bound_type ? bound_type : (PyObject *)&WraptBoundFunctionWrapper_Type, descriptor, - obj, self->wrapper, Py_None, self->binding, + obj, self->wrapper, self->enabled, self->binding, self, NULL); } @@ -1914,6 +1892,27 @@ static PyObject *WraptBoundFunctionWrapper_call( static PyObject *instancemethod_str = NULL; + if (self->enabled != Py_None) { + if (PyCallable_Check(self->enabled)) { + PyObject *object = NULL; + + object = PyObject_CallFunctionObjArgs(self->enabled, NULL); + + if (!object) + return NULL; + + if (PyObject_Not(object)) { + Py_DECREF(object); + return PyObject_Call(self->object_proxy.wrapped, args, kwds); + } + + Py_DECREF(object); + } + else if (PyObject_Not(self->enabled)) { + return PyObject_Call(self->object_proxy.wrapped, args, kwds); + } + } + if (!instancemethod_str) { #if PY_MAJOR_VERSION >= 3 instancemethod_str = PyUnicode_InternFromString("instancemethod"); diff --git a/src/wrappers.py b/src/wrappers.py index e54283c..9ad0250 100644 --- a/src/wrappers.py +++ b/src/wrappers.py @@ -405,21 +405,9 @@ class _FunctionWrapperBase(ObjectProxy): descriptor = self.__wrapped__.__get__(instance, owner) - # If enabled has been specified, then evaluate it at this point - # and if the wrapper is not to be executed, then simply return - # the bound function rather than a bound wrapper for the bound - # function. When evaluating enabled, if it is callable we call - # it, otherwise we evaluate it as a boolean. - - if self._self_enabled is not None: - if callable(self._self_enabled): - if not self._self_enabled(): - return descriptor - elif not self._self_enabled: - return descriptor - return self.__bound_function_wrapper__(descriptor, instance, - self._self_wrapper, None, self._self_binding, self) + self._self_wrapper, self._self_enabled, + self._self_binding, self) def __call__(self, *args, **kwargs): # If enabled has been specified, then evaluate it at this point @@ -447,6 +435,19 @@ class _FunctionWrapperBase(ObjectProxy): class BoundFunctionWrapper(_FunctionWrapperBase): def __call__(self, *args, **kwargs): + # If enabled has been specified, then evaluate it at this point + # and if the wrapper is not to be executed, then simply return + # the bound function rather than a bound wrapper for the bound + # function. When evaluating enabled, if it is callable we call + # it, otherwise we evaluate it as a boolean. + + if self._self_enabled is not None: + if callable(self._self_enabled): + if not self._self_enabled(): + return self.__wrapped__(*args, **kwargs) + elif not self._self_enabled: + return self.__wrapped__(*args, **kwargs) + # We need to do things different depending on whether we are # likely wrapping an instance method vs a static method or class # method. diff --git a/tests/test_function_wrapper.py b/tests/test_function_wrapper.py index d000f5c..d9b3d39 100644 --- a/tests/test_function_wrapper.py +++ b/tests/test_function_wrapper.py @@ -333,7 +333,7 @@ class TestGuardArgument(unittest.TestCase): result = [] value = False - self.assertFalse(isinstance(c.function, wrapt.BoundFunctionWrapper)) + self.assertTrue(isinstance(c.function, wrapt.BoundFunctionWrapper)) c.function()