diff --git a/positional/__init__.py b/positional/__init__.py index 7052513..9ba668b 100644 --- a/positional/__init__.py +++ b/positional/__init__.py @@ -10,11 +10,11 @@ # License for the specific language governing permissions and limitations # under the License. -import functools import inspect import warnings import pbr.version +import wrapt __version__ = pbr.version.VersionInfo('positional').version_string() @@ -76,14 +76,21 @@ class positional(object): plural = '' if self._max_positional_args == 1 else 's' - @functools.wraps(func) - def inner(*args, **kwargs): - if len(args) > self._max_positional_args: + @wrapt.decorator + def inner(wrapped, instance, args, kwargs): + + # If called on an instance, adjust args len for the 'self' + # parameter. + args_len = len(args) + if instance: + args_len += 1 + + if args_len > self._max_positional_args: message = ('%(name)s takes at most %(max)d positional ' 'argument%(plural)s (%(given)d given)' % - {'name': func.__name__, + {'name': wrapped.__name__, 'max': self._max_positional_args, - 'given': len(args), + 'given': args_len, 'plural': plural}) if self._enforcement == self.EXCEPT: @@ -91,6 +98,6 @@ class positional(object): elif self._enforcement == self.WARN: warnings.warn(message, DeprecationWarning, stacklevel=2) - return func(*args, **kwargs) + return wrapped(*args, **kwargs) - return inner + return inner(func) diff --git a/positional/tests/test_positional.py b/positional/tests/test_positional.py index 1609bbb..a3b7668 100644 --- a/positional/tests/test_positional.py +++ b/positional/tests/test_positional.py @@ -10,6 +10,7 @@ # License for the specific language governing permissions and limitations # under the License. +import inspect import warnings import testtools @@ -81,3 +82,15 @@ class TestPositional(testtools.TestCase): def test_normal_method(self): self.assertEqual((self, 1, 2), self.normal_method(1, b=2)) self.assertRaises(TypeError, self.normal_method, 1, 2) + + def test_argspec_preserved(self): + + @positional() + def f_wrapped(my_arg=False): + return my_arg + + def f_not_wrapped(my_arg=False): + return my_arg + + self.assertEqual(inspect.getargspec(f_not_wrapped), + inspect.getargspec(f_wrapped)) diff --git a/requirements.txt b/requirements.txt index 006c00d..0a3bb8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ # of appearance. pbr>=1.6 +wrapt