Allow *args and **kwargs in the additional parameters for the wrapper for the decorator.

This commit is contained in:
Graham Dumpleton
2013-08-23 21:17:22 +10:00
parent 9f062fbb5a
commit b7e2a25bcf
2 changed files with 105 additions and 13 deletions

View File

@@ -60,7 +60,8 @@ def decorator(wrapper=None, target=None):
wrapper_defaults = (wrapper_argspec.defaults and
wrapper_arglist[-len(wrapper_argspec.defaults):] or [])
if len(wrapper_arglist) > len(WRAPPER_ARGLIST):
if (len(wrapper_arglist) > len(WRAPPER_ARGLIST) or
wrapper_argspec.varargs or wrapper_argspec.keywords):
# For the case where the user decorator is able to accept
# parameters, return a partial wrapper to collect the
# parameters.
@@ -79,10 +80,11 @@ def decorator(wrapper=None, target=None):
expected_names = wrapper_arglist[len(WRAPPER_ARGLIST):]
if len(decorator_args) > len(expected_names):
raise UnexpectedParameters('Expected at most %r '
'positional parameters for decorator %r, '
'but received %r.' % (len(expected_names),
wrapper.__name__, len(decorator_args)))
if not wrapper_argspec.varargs:
raise UnexpectedParameters('Expected at most %r '
'positional parameters for decorator %r, '
'but received %r.' % (len(expected_names),
wrapper.__name__, len(decorator_args)))
unexpected_params = []
for name in decorator_kwargs:
@@ -90,21 +92,21 @@ def decorator(wrapper=None, target=None):
unexpected_params.append(name)
if unexpected_params:
raise UnexpectedParameters('Unexpected parameters '
'%r supplied for decorator %r.' % (
unexpected_params, wrapper.__name__))
if not wrapper_argspec.keywords:
raise UnexpectedParameters('Unexpected parameters '
'%r supplied for decorator %r.' % (
unexpected_params, wrapper.__name__))
for i, arg in enumerate(decorator_args):
received_names = set(wrapper_defaults)
for i in range(min(len(decorator_args), len(expected_names))):
if expected_names[i] in decorator_kwargs:
raise UnexpectedParameters('Positional parameter '
'%r also supplied as keyword parameter '
'to decorator %r.' % (expected_names[i],
wrapper.__name__))
decorator_kwargs[expected_names[i]] = arg
received_names.add(expected_names[i])
received_names = set(wrapper_defaults)
received_names.update(decorator_kwargs.keys())
for name in expected_names:
if name not in received_names:
raise MissingParameter('Expected value for '
@@ -116,7 +118,8 @@ def decorator(wrapper=None, target=None):
def _wrapper(func):
result = FunctionWrapper(wrapped=func,
wrapper=wrapper, kwargs=decorator_kwargs)
wrapper=wrapper, args=decorator_args,
kwargs=decorator_kwargs)
if target:
_update_adapter(result, target)
return result

View File

@@ -192,5 +192,94 @@ class TestDecorator(unittest.TestCase):
self.assertRaises(wrapt.exceptions.UnexpectedParameters, run, ())
def test_varargs_parameters(self):
_args = (1, 2)
_kwargs = { 'one': 1, 'two': 2 }
@wrapt.decorator
def _decorator(wrapped, instance, args, kwargs, *wrapper_args):
self.assertEqual(wrapper_args, tuple(reversed(_args)))
return wrapped(*args, **kwargs)
@_decorator(*reversed(_args))
def _function(*args, **kwargs):
return args, kwargs
result = _function(*_args, **_kwargs)
self.assertEqual(result, (_args, _kwargs))
def test_args_plus_varargs_parameters(self):
_args = (1, 2)
_kwargs = { 'one': 1, 'two': 2 }
@wrapt.decorator
def _decorator(wrapped, instance, args, kwargs, p1, *wrapper_args):
self.assertEqual(p1, 2)
self.assertEqual(wrapper_args, tuple(reversed(_args))[1:])
return wrapped(*args, **kwargs)
@_decorator(*reversed(_args))
def _function(*args, **kwargs):
return args, kwargs
result = _function(*_args, **_kwargs)
self.assertEqual(result, (_args, _kwargs))
def test_keyword_parameters(self):
_args = (1, 2)
_kwargs = { 'one': 1, 'two': 2 }
@wrapt.decorator
def _decorator(wrapped, instance, args, kwargs, **wrapper_kwargs):
self.assertEqual(wrapper_kwargs, _kwargs)
return wrapped(*args, **kwargs)
@_decorator(**_kwargs)
def _function(*args, **kwargs):
return args, kwargs
result = _function(*_args, **_kwargs)
self.assertEqual(result, (_args, _kwargs))
def test_args_plus_keyword_parameters(self):
_args = (1, 2)
_kwargs = { 'one': 1, 'two': 2 }
@wrapt.decorator
def _decorator(wrapped, instance, args, kwargs, one, **wrapper_kwargs):
self.assertEqual(one, 1)
self.assertEqual(wrapper_kwargs, {'two': 2})
return wrapped(*args, **kwargs)
@_decorator(**_kwargs)
def _function(*args, **kwargs):
return args, kwargs
result = _function(*_args, **_kwargs)
self.assertEqual(result, (_args, _kwargs))
def test_varargs_plus_keyword_parameters(self):
_args = (1, 2)
_kwargs = { 'one': 1, 'two': 2 }
@wrapt.decorator
def _decorator(wrapped, instance, args, kwargs,
*wrapper_args, **wrapper_kwargs):
self.assertEqual(wrapper_args, tuple(reversed(_args)))
self.assertEqual(wrapper_kwargs, _kwargs)
return wrapped(*args, **kwargs)
@_decorator(*reversed(_args), **_kwargs)
def _function(*args, **kwargs):
return args, kwargs
result = _function(*_args, **_kwargs)
self.assertEqual(result, (_args, _kwargs))
if __name__ == '__main__':
unittest.main()