Add get_required_callable_args utility function
Move code that gets callable args from functor task to utils.py and enhance it to support more use cases. We do not rely on naming conventions for 'self' and 'cls' parameters any more; we look at the collable if it is bound method instead. Change-Id: Ie1a9a7cc727b5fbc2780aba28b1d0253e5bc0ea4
This commit is contained in:
committed by
Joshua Harlow
parent
fbd06b552e
commit
d4609ef3d0
@@ -17,14 +17,9 @@
|
||||
# under the License.
|
||||
|
||||
import abc
|
||||
import inspect
|
||||
|
||||
from taskflow import utils
|
||||
|
||||
# These arguments are ones that we will skip when parsing for requirements
|
||||
# for a function to operate (when used as a task).
|
||||
AUTO_ARGS = ('self', 'context', 'cls')
|
||||
|
||||
|
||||
class Task(object):
|
||||
"""An abstraction that defines a potential piece of work that can be
|
||||
@@ -73,16 +68,6 @@ class Task(object):
|
||||
"""
|
||||
|
||||
|
||||
def _filter_arg(arg):
|
||||
if arg in AUTO_ARGS:
|
||||
return False
|
||||
# In certain decorator cases it seems like we get the function to be
|
||||
# decorated as an argument, we don't want to take that as a real argument.
|
||||
if not isinstance(arg, basestring):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class FunctorTask(Task):
|
||||
"""Adaptor to make task from a callable
|
||||
|
||||
@@ -114,8 +99,8 @@ class FunctorTask(Task):
|
||||
self.provides.update(kwargs.pop('provides', ()))
|
||||
self.requires.update(kwargs.pop('requires', ()))
|
||||
if kwargs.pop('auto_extract', True):
|
||||
f_args = inspect.getargspec(execute_with).args
|
||||
self.requires.update([arg for arg in f_args if _filter_arg(arg)])
|
||||
f_args = utils.get_required_callable_args(execute_with)
|
||||
self.requires.update(a for a in f_args if a != 'context')
|
||||
if kwargs:
|
||||
raise TypeError('__init__() got an unexpected keyword argument %r'
|
||||
% kwargs.keys[0])
|
||||
|
||||
@@ -26,15 +26,15 @@ class WrapableObjectsTest(test.TestCase):
|
||||
def test_simple_function(self):
|
||||
values = []
|
||||
|
||||
def revert_one(self, *args, **kwargs):
|
||||
def revert_one(*args, **kwargs):
|
||||
values.append('revert one')
|
||||
|
||||
@decorators.task(revert_with=revert_one)
|
||||
def run_one(self, *args, **kwargs):
|
||||
def run_one(*args, **kwargs):
|
||||
values.append('one')
|
||||
|
||||
@decorators.task
|
||||
def run_fail(self, *args, **kwargs):
|
||||
def run_fail(*args, **kwargs):
|
||||
values.append('fail')
|
||||
raise RuntimeError('Woot!')
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
|
||||
import functools
|
||||
|
||||
from taskflow import decorators
|
||||
from taskflow import test
|
||||
from taskflow import utils
|
||||
|
||||
@@ -58,29 +59,35 @@ class UtilTest(test.TestCase):
|
||||
self.assertEquals(10, len(context))
|
||||
|
||||
|
||||
def mere_function():
|
||||
def mere_function(a, b):
|
||||
pass
|
||||
|
||||
|
||||
def function_with_defaults(a, b, optional=None):
|
||||
pass
|
||||
|
||||
|
||||
class Class(object):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def method(self):
|
||||
def method(self, c, d):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def static_method():
|
||||
def static_method(e, f):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def class_method():
|
||||
def class_method(cls, g, h):
|
||||
pass
|
||||
|
||||
|
||||
class CallableClass(object):
|
||||
def __call__(self):
|
||||
def __call__(self, i, j):
|
||||
pass
|
||||
|
||||
|
||||
class ClassWithInit(object):
|
||||
def __init__(self, k, l):
|
||||
pass
|
||||
|
||||
|
||||
@@ -120,3 +127,45 @@ class GetCallableNameTest(test.TestCase):
|
||||
name = utils.get_callable_name(CallableClass().__call__)
|
||||
self.assertEquals(name, '.'.join((__name__, 'CallableClass',
|
||||
'__call__')))
|
||||
|
||||
|
||||
class GetRequiredCallableArgsTest(test.TestCase):
|
||||
|
||||
def test_mere_function(self):
|
||||
self.assertEquals(['a', 'b'],
|
||||
utils.get_required_callable_args(mere_function))
|
||||
|
||||
def test_function_with_defaults(self):
|
||||
self.assertEquals(['a', 'b'],
|
||||
utils.get_required_callable_args(
|
||||
function_with_defaults))
|
||||
|
||||
def test_method(self):
|
||||
self.assertEquals(['self', 'c', 'd'],
|
||||
utils.get_required_callable_args(Class.method))
|
||||
|
||||
def test_instance_method(self):
|
||||
self.assertEquals(['c', 'd'],
|
||||
utils.get_required_callable_args(Class().method))
|
||||
|
||||
def test_class_method(self):
|
||||
self.assertEquals(['g', 'h'],
|
||||
utils.get_required_callable_args(
|
||||
Class.class_method))
|
||||
|
||||
def test_class_constructor(self):
|
||||
self.assertEquals(['k', 'l'],
|
||||
utils.get_required_callable_args(
|
||||
ClassWithInit))
|
||||
|
||||
def test_class_with_call(self):
|
||||
self.assertEquals(['i', 'j'],
|
||||
utils.get_required_callable_args(
|
||||
CallableClass()))
|
||||
|
||||
def test_decorators_work(self):
|
||||
@decorators.locked
|
||||
def locked_fun(x, y):
|
||||
pass
|
||||
self.assertEquals(['x', 'y'],
|
||||
utils.get_required_callable_args(locked_fun))
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
import collections
|
||||
import contextlib
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
@@ -75,6 +76,32 @@ def get_callable_name(function):
|
||||
return '.'.join(parts)
|
||||
|
||||
|
||||
def is_bound_method(method):
|
||||
return getattr(method, 'im_self', None) is not None
|
||||
|
||||
|
||||
def get_required_callable_args(function):
|
||||
"""Get names of argument required by callable"""
|
||||
|
||||
if isinstance(function, type):
|
||||
bound = True
|
||||
function = function.__init__
|
||||
elif isinstance(function, (types.FunctionType, types.MethodType)):
|
||||
bound = is_bound_method(function)
|
||||
function = getattr(function, '__wrapped__', function)
|
||||
else:
|
||||
function = function.__call__
|
||||
bound = is_bound_method(function)
|
||||
|
||||
argspec = inspect.getargspec(function)
|
||||
f_args = argspec.args
|
||||
if argspec.defaults:
|
||||
f_args = f_args[:-len(argspec.defaults)]
|
||||
if bound:
|
||||
f_args = f_args[1:]
|
||||
return f_args
|
||||
|
||||
|
||||
def get_task_version(task):
|
||||
"""Gets a tasks *string* version, whether it is a task object/function."""
|
||||
task_version = getattr(task, 'version')
|
||||
|
||||
Reference in New Issue
Block a user