diff --git a/taskflow/decorators.py b/taskflow/decorators.py index 2739edc2..a41f1717 100644 --- a/taskflow/decorators.py +++ b/taskflow/decorators.py @@ -20,9 +20,11 @@ import collections import functools import inspect +from taskflow import utils + # These arguments are ones that we will skip when parsing for requirements # for a function to operate (when used as a task). -AUTO_ARGS = ('self', 'context',) +AUTO_ARGS = ('self', 'context', 'cls') def _take_arg(a): @@ -118,7 +120,11 @@ def requires(*args, **kwargs): f.requires = set() if kwargs.pop('auto_extract', True): - inspect_what = getattr(f, '__wrapped__', f) + inspect_what = getattr(f, '__wrapped__', None) + + if not inspect_what: + inspect_what = utils.get_wrapped_function(f) + f_args = inspect.getargspec(inspect_what).args f.requires.update([a for a in f_args if _take_arg(a)]) diff --git a/taskflow/utils.py b/taskflow/utils.py index dfec1f6e..c296c2f4 100644 --- a/taskflow/utils.py +++ b/taskflow/utils.py @@ -24,6 +24,28 @@ import time LOG = logging.getLogger(__name__) +def get_wrapped_function(function): + """Get the method at the bottom of a stack of decorators.""" + + if not hasattr(function, 'func_closure') or not function.func_closure: + return function + + def _get_wrapped_function(function): + if not hasattr(function, 'func_closure') or not function.func_closure: + return None + + for closure in function.func_closure: + func = closure.cell_contents + + deeper_func = _get_wrapped_function(func) + if deeper_func: + return deeper_func + elif hasattr(closure.cell_contents, '__call__'): + return closure.cell_contents + + return _get_wrapped_function(function) + + def join(itr, with_what=","): pieces = [str(i) for i in itr] return with_what.join(pieces)