Add get_required_callable_args utility function

Move code that gets callable args from functor task to utils.py and
enhance it to support more use cases.

We do not rely on naming conventions for 'self' and 'cls' parameters any
more; we look at the collable if it is bound method instead.

Change-Id: Ie1a9a7cc727b5fbc2780aba28b1d0253e5bc0ea4
This commit is contained in:
Ivan A. Melnikov
2013-08-27 17:09:09 +04:00
committed by Joshua Harlow
parent fbd06b552e
commit d4609ef3d0
4 changed files with 89 additions and 28 deletions

View File

@@ -17,14 +17,9 @@
# under the License. # under the License.
import abc import abc
import inspect
from taskflow import utils from taskflow import utils
# These arguments are ones that we will skip when parsing for requirements
# for a function to operate (when used as a task).
AUTO_ARGS = ('self', 'context', 'cls')
class Task(object): class Task(object):
"""An abstraction that defines a potential piece of work that can be """An abstraction that defines a potential piece of work that can be
@@ -73,16 +68,6 @@ class Task(object):
""" """
def _filter_arg(arg):
if arg in AUTO_ARGS:
return False
# In certain decorator cases it seems like we get the function to be
# decorated as an argument, we don't want to take that as a real argument.
if not isinstance(arg, basestring):
return False
return True
class FunctorTask(Task): class FunctorTask(Task):
"""Adaptor to make task from a callable """Adaptor to make task from a callable
@@ -114,8 +99,8 @@ class FunctorTask(Task):
self.provides.update(kwargs.pop('provides', ())) self.provides.update(kwargs.pop('provides', ()))
self.requires.update(kwargs.pop('requires', ())) self.requires.update(kwargs.pop('requires', ()))
if kwargs.pop('auto_extract', True): if kwargs.pop('auto_extract', True):
f_args = inspect.getargspec(execute_with).args f_args = utils.get_required_callable_args(execute_with)
self.requires.update([arg for arg in f_args if _filter_arg(arg)]) self.requires.update(a for a in f_args if a != 'context')
if kwargs: if kwargs:
raise TypeError('__init__() got an unexpected keyword argument %r' raise TypeError('__init__() got an unexpected keyword argument %r'
% kwargs.keys[0]) % kwargs.keys[0])

View File

@@ -26,15 +26,15 @@ class WrapableObjectsTest(test.TestCase):
def test_simple_function(self): def test_simple_function(self):
values = [] values = []
def revert_one(self, *args, **kwargs): def revert_one(*args, **kwargs):
values.append('revert one') values.append('revert one')
@decorators.task(revert_with=revert_one) @decorators.task(revert_with=revert_one)
def run_one(self, *args, **kwargs): def run_one(*args, **kwargs):
values.append('one') values.append('one')
@decorators.task @decorators.task
def run_fail(self, *args, **kwargs): def run_fail(*args, **kwargs):
values.append('fail') values.append('fail')
raise RuntimeError('Woot!') raise RuntimeError('Woot!')

View File

@@ -18,6 +18,7 @@
import functools import functools
from taskflow import decorators
from taskflow import test from taskflow import test
from taskflow import utils from taskflow import utils
@@ -58,29 +59,35 @@ class UtilTest(test.TestCase):
self.assertEquals(10, len(context)) self.assertEquals(10, len(context))
def mere_function(): def mere_function(a, b):
pass
def function_with_defaults(a, b, optional=None):
pass pass
class Class(object): class Class(object):
def __init__(self): def method(self, c, d):
pass
def method(self):
pass pass
@staticmethod @staticmethod
def static_method(): def static_method(e, f):
pass pass
@classmethod @classmethod
def class_method(): def class_method(cls, g, h):
pass pass
class CallableClass(object): class CallableClass(object):
def __call__(self): def __call__(self, i, j):
pass
class ClassWithInit(object):
def __init__(self, k, l):
pass pass
@@ -120,3 +127,45 @@ class GetCallableNameTest(test.TestCase):
name = utils.get_callable_name(CallableClass().__call__) name = utils.get_callable_name(CallableClass().__call__)
self.assertEquals(name, '.'.join((__name__, 'CallableClass', self.assertEquals(name, '.'.join((__name__, 'CallableClass',
'__call__'))) '__call__')))
class GetRequiredCallableArgsTest(test.TestCase):
def test_mere_function(self):
self.assertEquals(['a', 'b'],
utils.get_required_callable_args(mere_function))
def test_function_with_defaults(self):
self.assertEquals(['a', 'b'],
utils.get_required_callable_args(
function_with_defaults))
def test_method(self):
self.assertEquals(['self', 'c', 'd'],
utils.get_required_callable_args(Class.method))
def test_instance_method(self):
self.assertEquals(['c', 'd'],
utils.get_required_callable_args(Class().method))
def test_class_method(self):
self.assertEquals(['g', 'h'],
utils.get_required_callable_args(
Class.class_method))
def test_class_constructor(self):
self.assertEquals(['k', 'l'],
utils.get_required_callable_args(
ClassWithInit))
def test_class_with_call(self):
self.assertEquals(['i', 'j'],
utils.get_required_callable_args(
CallableClass()))
def test_decorators_work(self):
@decorators.locked
def locked_fun(x, y):
pass
self.assertEquals(['x', 'y'],
utils.get_required_callable_args(locked_fun))

View File

@@ -20,6 +20,7 @@
import collections import collections
import contextlib import contextlib
import copy import copy
import inspect
import logging import logging
import re import re
import sys import sys
@@ -75,6 +76,32 @@ def get_callable_name(function):
return '.'.join(parts) return '.'.join(parts)
def is_bound_method(method):
return getattr(method, 'im_self', None) is not None
def get_required_callable_args(function):
"""Get names of argument required by callable"""
if isinstance(function, type):
bound = True
function = function.__init__
elif isinstance(function, (types.FunctionType, types.MethodType)):
bound = is_bound_method(function)
function = getattr(function, '__wrapped__', function)
else:
function = function.__call__
bound = is_bound_method(function)
argspec = inspect.getargspec(function)
f_args = argspec.args
if argspec.defaults:
f_args = f_args[:-len(argspec.defaults)]
if bound:
f_args = f_args[1:]
return f_args
def get_task_version(task): def get_task_version(task):
"""Gets a tasks *string* version, whether it is a task object/function.""" """Gets a tasks *string* version, whether it is a task object/function."""
task_version = getattr(task, 'version') task_version = getattr(task, 'version')