diff --git a/taskflow/task.py b/taskflow/task.py index 3c8f3b021..2a27b97af 100644 --- a/taskflow/task.py +++ b/taskflow/task.py @@ -17,14 +17,9 @@ # under the License. import abc -import inspect from taskflow import utils -# These arguments are ones that we will skip when parsing for requirements -# for a function to operate (when used as a task). -AUTO_ARGS = ('self', 'context', 'cls') - class Task(object): """An abstraction that defines a potential piece of work that can be @@ -73,16 +68,6 @@ class Task(object): """ -def _filter_arg(arg): - if arg in AUTO_ARGS: - return False - # In certain decorator cases it seems like we get the function to be - # decorated as an argument, we don't want to take that as a real argument. - if not isinstance(arg, basestring): - return False - return True - - class FunctorTask(Task): """Adaptor to make task from a callable @@ -114,8 +99,8 @@ class FunctorTask(Task): self.provides.update(kwargs.pop('provides', ())) self.requires.update(kwargs.pop('requires', ())) if kwargs.pop('auto_extract', True): - f_args = inspect.getargspec(execute_with).args - self.requires.update([arg for arg in f_args if _filter_arg(arg)]) + f_args = utils.get_required_callable_args(execute_with) + self.requires.update(a for a in f_args if a != 'context') if kwargs: raise TypeError('__init__() got an unexpected keyword argument %r' % kwargs.keys[0]) diff --git a/taskflow/tests/unit/test_decorators.py b/taskflow/tests/unit/test_decorators.py index 466890204..193a65581 100644 --- a/taskflow/tests/unit/test_decorators.py +++ b/taskflow/tests/unit/test_decorators.py @@ -26,15 +26,15 @@ class WrapableObjectsTest(test.TestCase): def test_simple_function(self): values = [] - def revert_one(self, *args, **kwargs): + def revert_one(*args, **kwargs): values.append('revert one') @decorators.task(revert_with=revert_one) - def run_one(self, *args, **kwargs): + def run_one(*args, **kwargs): values.append('one') @decorators.task - def run_fail(self, *args, **kwargs): + def run_fail(*args, **kwargs): values.append('fail') raise RuntimeError('Woot!') diff --git a/taskflow/tests/unit/test_utils.py b/taskflow/tests/unit/test_utils.py index c2ed575f7..eb3d76010 100644 --- a/taskflow/tests/unit/test_utils.py +++ b/taskflow/tests/unit/test_utils.py @@ -18,6 +18,7 @@ import functools +from taskflow import decorators from taskflow import test from taskflow import utils @@ -58,29 +59,35 @@ class UtilTest(test.TestCase): self.assertEquals(10, len(context)) -def mere_function(): +def mere_function(a, b): + pass + + +def function_with_defaults(a, b, optional=None): pass class Class(object): - def __init__(self): - pass - - def method(self): + def method(self, c, d): pass @staticmethod - def static_method(): + def static_method(e, f): pass @classmethod - def class_method(): + def class_method(cls, g, h): pass class CallableClass(object): - def __call__(self): + def __call__(self, i, j): + pass + + +class ClassWithInit(object): + def __init__(self, k, l): pass @@ -120,3 +127,45 @@ class GetCallableNameTest(test.TestCase): name = utils.get_callable_name(CallableClass().__call__) self.assertEquals(name, '.'.join((__name__, 'CallableClass', '__call__'))) + + +class GetRequiredCallableArgsTest(test.TestCase): + + def test_mere_function(self): + self.assertEquals(['a', 'b'], + utils.get_required_callable_args(mere_function)) + + def test_function_with_defaults(self): + self.assertEquals(['a', 'b'], + utils.get_required_callable_args( + function_with_defaults)) + + def test_method(self): + self.assertEquals(['self', 'c', 'd'], + utils.get_required_callable_args(Class.method)) + + def test_instance_method(self): + self.assertEquals(['c', 'd'], + utils.get_required_callable_args(Class().method)) + + def test_class_method(self): + self.assertEquals(['g', 'h'], + utils.get_required_callable_args( + Class.class_method)) + + def test_class_constructor(self): + self.assertEquals(['k', 'l'], + utils.get_required_callable_args( + ClassWithInit)) + + def test_class_with_call(self): + self.assertEquals(['i', 'j'], + utils.get_required_callable_args( + CallableClass())) + + def test_decorators_work(self): + @decorators.locked + def locked_fun(x, y): + pass + self.assertEquals(['x', 'y'], + utils.get_required_callable_args(locked_fun)) diff --git a/taskflow/utils.py b/taskflow/utils.py index c2033bdb1..df5a89516 100644 --- a/taskflow/utils.py +++ b/taskflow/utils.py @@ -20,6 +20,7 @@ import collections import contextlib import copy +import inspect import logging import re import sys @@ -75,6 +76,32 @@ def get_callable_name(function): return '.'.join(parts) +def is_bound_method(method): + return getattr(method, 'im_self', None) is not None + + +def get_required_callable_args(function): + """Get names of argument required by callable""" + + if isinstance(function, type): + bound = True + function = function.__init__ + elif isinstance(function, (types.FunctionType, types.MethodType)): + bound = is_bound_method(function) + function = getattr(function, '__wrapped__', function) + else: + function = function.__call__ + bound = is_bound_method(function) + + argspec = inspect.getargspec(function) + f_args = argspec.args + if argspec.defaults: + f_args = f_args[:-len(argspec.defaults)] + if bound: + f_args = f_args[1:] + return f_args + + def get_task_version(task): """Gets a tasks *string* version, whether it is a task object/function.""" task_version = getattr(task, 'version')