From b7e2a25bcf9c52cab8f17d1e9703b9e15058893c Mon Sep 17 00:00:00 2001 From: Graham Dumpleton Date: Fri, 23 Aug 2013 21:17:22 +1000 Subject: [PATCH] Allow *args and **kwargs in the additional parameters for the wrapper for the decorator. --- src/decorators.py | 29 +++++++------ tests/test_decorators.py | 89 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 13 deletions(-) diff --git a/src/decorators.py b/src/decorators.py index 4fd3813..fe056d3 100644 --- a/src/decorators.py +++ b/src/decorators.py @@ -60,7 +60,8 @@ def decorator(wrapper=None, target=None): wrapper_defaults = (wrapper_argspec.defaults and wrapper_arglist[-len(wrapper_argspec.defaults):] or []) - if len(wrapper_arglist) > len(WRAPPER_ARGLIST): + if (len(wrapper_arglist) > len(WRAPPER_ARGLIST) or + wrapper_argspec.varargs or wrapper_argspec.keywords): # For the case where the user decorator is able to accept # parameters, return a partial wrapper to collect the # parameters. @@ -79,10 +80,11 @@ def decorator(wrapper=None, target=None): expected_names = wrapper_arglist[len(WRAPPER_ARGLIST):] if len(decorator_args) > len(expected_names): - raise UnexpectedParameters('Expected at most %r ' - 'positional parameters for decorator %r, ' - 'but received %r.' % (len(expected_names), - wrapper.__name__, len(decorator_args))) + if not wrapper_argspec.varargs: + raise UnexpectedParameters('Expected at most %r ' + 'positional parameters for decorator %r, ' + 'but received %r.' % (len(expected_names), + wrapper.__name__, len(decorator_args))) unexpected_params = [] for name in decorator_kwargs: @@ -90,21 +92,21 @@ def decorator(wrapper=None, target=None): unexpected_params.append(name) if unexpected_params: - raise UnexpectedParameters('Unexpected parameters ' - '%r supplied for decorator %r.' % ( - unexpected_params, wrapper.__name__)) + if not wrapper_argspec.keywords: + raise UnexpectedParameters('Unexpected parameters ' + '%r supplied for decorator %r.' % ( + unexpected_params, wrapper.__name__)) - for i, arg in enumerate(decorator_args): + received_names = set(wrapper_defaults) + for i in range(min(len(decorator_args), len(expected_names))): if expected_names[i] in decorator_kwargs: raise UnexpectedParameters('Positional parameter ' '%r also supplied as keyword parameter ' 'to decorator %r.' % (expected_names[i], wrapper.__name__)) - decorator_kwargs[expected_names[i]] = arg + received_names.add(expected_names[i]) - received_names = set(wrapper_defaults) received_names.update(decorator_kwargs.keys()) - for name in expected_names: if name not in received_names: raise MissingParameter('Expected value for ' @@ -116,7 +118,8 @@ def decorator(wrapper=None, target=None): def _wrapper(func): result = FunctionWrapper(wrapped=func, - wrapper=wrapper, kwargs=decorator_kwargs) + wrapper=wrapper, args=decorator_args, + kwargs=decorator_kwargs) if target: _update_adapter(result, target) return result diff --git a/tests/test_decorators.py b/tests/test_decorators.py index b10a773..040099d 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -192,5 +192,94 @@ class TestDecorator(unittest.TestCase): self.assertRaises(wrapt.exceptions.UnexpectedParameters, run, ()) + def test_varargs_parameters(self): + _args = (1, 2) + _kwargs = { 'one': 1, 'two': 2 } + + @wrapt.decorator + def _decorator(wrapped, instance, args, kwargs, *wrapper_args): + self.assertEqual(wrapper_args, tuple(reversed(_args))) + return wrapped(*args, **kwargs) + + @_decorator(*reversed(_args)) + def _function(*args, **kwargs): + return args, kwargs + + result = _function(*_args, **_kwargs) + + self.assertEqual(result, (_args, _kwargs)) + + def test_args_plus_varargs_parameters(self): + _args = (1, 2) + _kwargs = { 'one': 1, 'two': 2 } + + @wrapt.decorator + def _decorator(wrapped, instance, args, kwargs, p1, *wrapper_args): + self.assertEqual(p1, 2) + self.assertEqual(wrapper_args, tuple(reversed(_args))[1:]) + return wrapped(*args, **kwargs) + + @_decorator(*reversed(_args)) + def _function(*args, **kwargs): + return args, kwargs + + result = _function(*_args, **_kwargs) + + self.assertEqual(result, (_args, _kwargs)) + + def test_keyword_parameters(self): + _args = (1, 2) + _kwargs = { 'one': 1, 'two': 2 } + + @wrapt.decorator + def _decorator(wrapped, instance, args, kwargs, **wrapper_kwargs): + self.assertEqual(wrapper_kwargs, _kwargs) + return wrapped(*args, **kwargs) + + @_decorator(**_kwargs) + def _function(*args, **kwargs): + return args, kwargs + + result = _function(*_args, **_kwargs) + + self.assertEqual(result, (_args, _kwargs)) + + def test_args_plus_keyword_parameters(self): + _args = (1, 2) + _kwargs = { 'one': 1, 'two': 2 } + + @wrapt.decorator + def _decorator(wrapped, instance, args, kwargs, one, **wrapper_kwargs): + self.assertEqual(one, 1) + self.assertEqual(wrapper_kwargs, {'two': 2}) + return wrapped(*args, **kwargs) + + @_decorator(**_kwargs) + def _function(*args, **kwargs): + return args, kwargs + + result = _function(*_args, **_kwargs) + + self.assertEqual(result, (_args, _kwargs)) + + def test_varargs_plus_keyword_parameters(self): + _args = (1, 2) + _kwargs = { 'one': 1, 'two': 2 } + + @wrapt.decorator + def _decorator(wrapped, instance, args, kwargs, + *wrapper_args, **wrapper_kwargs): + self.assertEqual(wrapper_args, tuple(reversed(_args))) + self.assertEqual(wrapper_kwargs, _kwargs) + return wrapped(*args, **kwargs) + + @_decorator(*reversed(_args), **_kwargs) + def _function(*args, **kwargs): + return args, kwargs + + result = _function(*_args, **_kwargs) + + self.assertEqual(result, (_args, _kwargs)) + if __name__ == '__main__': unittest.main()