Merge "Allow instance methods to be wrapped and unwrapped correctly."

This commit is contained in:
Jenkins
2013-06-27 19:46:04 +00:00
committed by Gerrit Code Review
7 changed files with 120 additions and 79 deletions

View File

@@ -19,14 +19,61 @@
import collections
import functools
import inspect
from taskflow import utils
import types
# These arguments are ones that we will skip when parsing for requirements
# for a function to operate (when used as a task).
AUTO_ARGS = ('self', 'context', 'cls')
def is_decorated(functor):
if not isinstance(functor, (types.MethodType, types.FunctionType)):
return False
return getattr(extract(functor), '__task__', False)
def extract(functor):
# Extract the underlying functor if its a method since we can not set
# attributes on instance methods, this is supposedly fixed in python 3
# and later.
#
# TODO(harlowja): add link to this fix.
assert isinstance(functor, (types.MethodType, types.FunctionType))
if isinstance(functor, types.MethodType):
return functor.__func__
else:
return functor
def _mark_as_task(functor):
setattr(functor, '__task__', True)
def _get_wrapped(function):
"""Get the method at the bottom of a stack of decorators."""
if hasattr(function, '__wrapped__'):
return getattr(function, '__wrapped__')
if not hasattr(function, 'func_closure') or not function.func_closure:
return function
def _get_wrapped_function(function):
if not hasattr(function, 'func_closure') or not function.func_closure:
return None
for closure in function.func_closure:
func = closure.cell_contents
deeper_func = _get_wrapped_function(func)
if deeper_func:
return deeper_func
elif hasattr(closure.cell_contents, '__call__'):
return closure.cell_contents
return _get_wrapped_function(function)
def _take_arg(a):
if a in AUTO_ARGS:
return False
@@ -54,30 +101,35 @@ def task(*args, **kwargs):
that function are set so that the function can be used as a task."""
def decorator(f):
w_f = extract(f)
def noop(*args, **kwargs):
pass
f.revert = kwargs.pop('revert_with', noop)
# Mark as being a task
_mark_as_task(w_f)
# By default don't revert this.
w_f.revert = kwargs.pop('revert_with', noop)
# Associate a name of this task that is the module + function name.
w_f.name = "%s.%s" % (f.__module__, f.__name__)
# Sets the version of the task.
version = kwargs.pop('version', (1, 0))
f = versionize(*version)(f)
f = _versionize(*version)(f)
# Attach any requirements this function needs for running.
requires_what = kwargs.pop('requires', [])
f = requires(*requires_what, **kwargs)(f)
f = _requires(*requires_what, **kwargs)(f)
# Attach any optional requirements this function needs for running.
optional_what = kwargs.pop('optional', [])
f = optional(*optional_what, **kwargs)(f)
f = _optional(*optional_what, **kwargs)(f)
# Attach any items this function provides as output
provides_what = kwargs.pop('provides', [])
f = provides(*provides_what, **kwargs)(f)
# Associate a name of this task that is the module + function name.
f.name = "%s.%s" % (f.__module__, f.__name__)
f = _provides(*provides_what, **kwargs)(f)
@wraps(f)
def wrapper(*args, **kwargs):
@@ -96,7 +148,7 @@ def task(*args, **kwargs):
return decorator
def versionize(major, minor=None):
def _versionize(major, minor=None):
"""A decorator that marks the wrapped function with a major & minor version
number."""
@@ -104,7 +156,8 @@ def versionize(major, minor=None):
minor = 0
def decorator(f):
f.__version__ = (major, minor)
w_f = extract(f)
w_f.version = (major, minor)
@wraps(f)
def wrapper(*args, **kwargs):
@@ -115,15 +168,17 @@ def versionize(major, minor=None):
return decorator
def optional(*args, **kwargs):
def _optional(*args, **kwargs):
"""Attaches a set of items that the decorated function would like as input
to the functions underlying dictionary."""
def decorator(f):
if not hasattr(f, 'optional'):
f.optional = set()
w_f = extract(f)
f.optional.update([a for a in args if _take_arg(a)])
if not hasattr(w_f, 'optional'):
w_f.optional = set()
w_f.optional.update([a for a in args if _take_arg(a)])
@wraps(f)
def wrapper(*args, **kwargs):
@@ -142,24 +197,22 @@ def optional(*args, **kwargs):
return decorator
def requires(*args, **kwargs):
def _requires(*args, **kwargs):
"""Attaches a set of items that the decorated function requires as input
to the functions underlying dictionary."""
def decorator(f):
if not hasattr(f, 'requires'):
f.requires = set()
w_f = extract(f)
if not hasattr(w_f, 'requires'):
w_f.requires = set()
if kwargs.pop('auto_extract', True):
inspect_what = getattr(f, '__wrapped__', None)
if not inspect_what:
inspect_what = utils.get_wrapped_function(f)
inspect_what = _get_wrapped(f)
f_args = inspect.getargspec(inspect_what).args
f.requires.update([a for a in f_args if _take_arg(a)])
w_f.requires.update([a for a in f_args if _take_arg(a)])
f.requires.update([a for a in args if _take_arg(a)])
w_f.requires.update([a for a in args if _take_arg(a)])
@wraps(f)
def wrapper(*args, **kwargs):
@@ -178,15 +231,17 @@ def requires(*args, **kwargs):
return decorator
def provides(*args, **kwargs):
def _provides(*args, **kwargs):
"""Attaches a set of items that the decorated function provides as output
to the functions underlying dictionary."""
def decorator(f):
if not hasattr(f, 'provides'):
f.provides = set()
w_f = extract(f)
f.provides.update([a for a in args if _take_arg(a)])
if not hasattr(f, 'provides'):
w_f.provides = set()
w_f.provides.update([a for a in args if _take_arg(a)])
@wraps(f)
def wrapper(*args, **kwargs):

View File

@@ -32,14 +32,10 @@ LOG = logging.getLogger(__name__)
def _get_task_version(task):
"""Gets a tasks *string* version, whether it is a task object/function."""
task_version = ''
if isinstance(task, types.FunctionType):
task_version = getattr(task, '__version__', '')
if not task_version and hasattr(task, 'version'):
task_version = task.version
task_version = utils.get_attr(task, 'version')
if isinstance(task_version, (list, tuple)):
task_version = utils.join(task_version, with_what=".")
if not isinstance(task_version, basestring):
if task_version is not None and not isinstance(task_version, basestring):
task_version = str(task_version)
return task_version
@@ -47,14 +43,13 @@ def _get_task_version(task):
def _get_task_name(task):
"""Gets a tasks *string* name, whether it is a task object/function."""
task_name = ""
if isinstance(task, types.FunctionType):
if isinstance(task, (types.MethodType, types.FunctionType)):
# If its a function look for the attributes that should have been
# set using the task() decorator provided in the decorators file. If
# those have not been set, then we should at least have enough basic
# information (not a version) to form a useful task name.
if hasattr(task, 'name'):
task_name = str(task.name)
else:
task_name = utils.get_attr(task, 'name')
if not task_name:
name_pieces = [a for a in utils.get_many_attr(task,
'__module__',
'__name__')

View File

@@ -25,6 +25,7 @@ from networkx import exception as g_exc
from taskflow import exceptions as exc
from taskflow.patterns import ordered_flow
from taskflow import utils
LOG = logging.getLogger(__name__)
@@ -46,6 +47,7 @@ class Flow(ordered_flow.Flow):
#
# Only insert the node to start, connect all the edges
# together later after all nodes have been added.
assert isinstance(task, collections.Callable)
self._graph.add_node(task)
self._connected = False
@@ -54,7 +56,8 @@ class Flow(ordered_flow.Flow):
def extract_inputs(place_where, would_like, is_optional=False):
for n in would_like:
for (them, there_result) in self.results:
if not n in set(getattr(them, 'provides', [])):
they_provide = utils.get_attr(them, 'provides', [])
if n not in set(they_provide):
continue
if ((not is_optional and
not self._graph.has_edge(them, task))):
@@ -68,8 +71,8 @@ class Flow(ordered_flow.Flow):
elif not is_optional:
place_where[n].append(None)
required_inputs = set(getattr(task, 'requires', []))
optional_inputs = set(getattr(task, 'optional', []))
required_inputs = set(utils.get_attr(task, 'requires', []))
optional_inputs = set(utils.get_attr(task, 'optional', []))
optional_inputs = optional_inputs - required_inputs
task_inputs = collections.defaultdict(list)
@@ -103,9 +106,9 @@ class Flow(ordered_flow.Flow):
provides_what = collections.defaultdict(list)
requires_what = collections.defaultdict(list)
for t in self._graph.nodes_iter():
for r in getattr(t, 'requires', []):
for r in utils.get_attr(t, 'requires', []):
requires_what[r].append(t)
for p in getattr(t, 'provides', []):
for p in utils.get_attr(t, 'provides', []):
provides_what[p].append(t)
def get_providers(node, want_what):

View File

@@ -16,8 +16,11 @@
# License for the specific language governing permissions and limitations
# under the License.
import collections
from taskflow import exceptions as exc
from taskflow.patterns import ordered_flow
from taskflow import utils
class Flow(ordered_flow.Flow):
@@ -30,14 +33,14 @@ class Flow(ordered_flow.Flow):
self._tasks = []
def _fetch_task_inputs(self, task):
would_like = set(getattr(task, 'requires', []))
would_like.update(getattr(task, 'optional', []))
would_like = set(utils.get_attr(task, 'requires', []))
would_like.update(utils.get_attr(task, 'optional', []))
inputs = {}
for n in would_like:
# Find the last task that provided this.
for (last_task, last_results) in reversed(self.results):
if n not in getattr(last_task, 'provides', []):
if n not in utils.get_attr(last_task, 'provides', []):
continue
if last_results and n in last_results:
inputs[n] = last_results[n]
@@ -50,10 +53,10 @@ class Flow(ordered_flow.Flow):
def _validate_provides(self, task):
# Ensure that some previous task provides this input.
missing_requires = []
for r in getattr(task, 'requires', []):
for r in utils.get_attr(task, 'requires', []):
found_provider = False
for prev_task in reversed(self._tasks):
if r in getattr(prev_task, 'provides', []):
if r in utils.get_attr(prev_task, 'provides', []):
found_provider = True
break
if not found_provider:
@@ -66,6 +69,7 @@ class Flow(ordered_flow.Flow):
raise exc.InvalidStateException(msg)
def add(self, task):
assert isinstance(task, collections.Callable)
self._validate_provides(task)
self._tasks.append(task)

View File

@@ -60,7 +60,7 @@ class LinearFlowTest(unittest2.TestCase):
def test_functor_flow(self):
wf = lw.Flow("the-test-action")
@decorators.provides('a', 'b', 'c')
@decorators.task(provides=['a', 'b', 'c'])
def do_apply1(context):
context['1'] = True
return {
@@ -69,7 +69,7 @@ class LinearFlowTest(unittest2.TestCase):
'c': 3,
}
@decorators.requires('c', 'a', auto_extract=False)
@decorators.task(requires=['c', 'a'], auto_extract=False)
def do_apply2(context, **kwargs):
self.assertTrue('c' in kwargs)
self.assertEquals(1, kwargs['a'])

View File

@@ -52,8 +52,7 @@ class MemoryBackendTest(unittest2.TestCase):
while not poison.isSet():
my_jobs = []
job_board.await(0.05)
job_search_from = None
for j in job_board.posted_after(job_search_from):
for j in job_board.posted_after():
if j.owner is not None:
continue
try:
@@ -61,10 +60,6 @@ class MemoryBackendTest(unittest2.TestCase):
my_jobs.append(j)
except exc.UnclaimableJobException:
pass
if not my_jobs:
# No jobs were claimed, lets not search the past again
# then, since *likely* those jobs will remain claimed...
job_search_from = datetime.datetime.utcnow()
if my_jobs and poison.isSet():
# Oh crap, we need to unclaim and repost the jobs.
for j in my_jobs:

View File

@@ -21,29 +21,18 @@ import logging
import threading
import time
from taskflow import decorators
LOG = logging.getLogger(__name__)
def get_wrapped_function(function):
"""Get the method at the bottom of a stack of decorators."""
if not hasattr(function, 'func_closure') or not function.func_closure:
return function
def _get_wrapped_function(function):
if not hasattr(function, 'func_closure') or not function.func_closure:
return None
for closure in function.func_closure:
func = closure.cell_contents
deeper_func = _get_wrapped_function(func)
if deeper_func:
return deeper_func
elif hasattr(closure.cell_contents, '__call__'):
return closure.cell_contents
return _get_wrapped_function(function)
def get_attr(task, field, default=None):
if decorators.is_decorated(task):
# If its a decorated functor then the attributes will be either
# in the underlying function of the instancemethod or the function
# itself.
task = decorators.extract(task)
return getattr(task, field, default)
def join(itr, with_what=","):
@@ -54,7 +43,7 @@ def join(itr, with_what=","):
def get_many_attr(obj, *attrs):
many = []
for a in attrs:
many.append(getattr(obj, a, None))
many.append(get_attr(obj, a, None))
return many