From f60ee1db6bdd743c2257bd0a36bc4f4ace3ad595 Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Sun, 26 May 2013 12:10:04 -0700 Subject: [PATCH] Instead of apply use __call__ Instead of requiring a apply() function just use the built-in one provided by objects or functions implementing __call__. Also change how requires/provides may not be found if functors are just passed in (since functors implement __call__). --- taskflow/patterns/graph_flow.py | 9 +++++---- taskflow/patterns/linear_flow.py | 9 +++++---- taskflow/patterns/ordered_flow.py | 7 +++++-- taskflow/task.py | 2 +- taskflow/tests/utils.py | 2 +- taskflow/utils.py | 4 ++++ taskflow/wrappers.py | 2 +- 7 files changed, 22 insertions(+), 13 deletions(-) diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index df2cc23d9..5c560891a 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -24,6 +24,7 @@ from networkx.algorithms import dag from networkx.classes import digraph from taskflow import exceptions as exc +from taskflow import utils from taskflow.patterns import ordered_flow LOG = logging.getLogger(__name__) @@ -52,10 +53,10 @@ class Flow(ordered_flow.Flow): def _fetch_task_inputs(self, task): inputs = collections.defaultdict(list) - for n in task.requires: + for n in utils.safe_attr(task, 'requires', []): for (them, there_result) in self.results: if (not self._graph.has_edge(them, task) or - not n in them.provides): + not n in utils.safe_attr(them, 'provides', [])): continue if there_result and n in there_result: inputs[n].append(there_result[n]) @@ -89,9 +90,9 @@ class Flow(ordered_flow.Flow): provides_what = collections.defaultdict(list) requires_what = collections.defaultdict(list) for t in self._graph.nodes_iter(): - for r in t.requires: + for r in utils.safe_attr(t, 'requires', []): requires_what[r].append(t) - for p in t.provides: + for p in utils.safe_attr(t, 'provides', []): provides_what[p].append(t) def get_providers(node, want_what): diff --git a/taskflow/patterns/linear_flow.py b/taskflow/patterns/linear_flow.py index 532dade71..59653a74e 100644 --- a/taskflow/patterns/linear_flow.py +++ b/taskflow/patterns/linear_flow.py @@ -17,6 +17,7 @@ # under the License. from taskflow import exceptions as exc +from taskflow import utils from taskflow.patterns import ordered_flow @@ -31,10 +32,10 @@ class Flow(ordered_flow.Flow): def _fetch_task_inputs(self, task): inputs = {} - for r in task.requires: + for r in utils.safe_attr(task, 'requires', []): # Find the last task that provided this. for (last_task, last_results) in reversed(self.results): - if r not in last_task.provides: + if r not in utils.safe_attr(last_task, 'provides', []): continue if last_results and r in last_results: inputs[r] = last_results[r] @@ -47,10 +48,10 @@ class Flow(ordered_flow.Flow): def _validate_provides(self, task): # Ensure that some previous task provides this input. missing_requires = [] - for r in task.requires: + for r in utils.safe_attr(task, 'requires', []): found_provider = False for prev_task in reversed(self._tasks): - if r in prev_task.provides: + if r in utils.safe_attr(prev_task, 'provides', []): found_provider = True break if not found_provider: diff --git a/taskflow/patterns/ordered_flow.py b/taskflow/patterns/ordered_flow.py index 5d20ec823..21acb7a64 100644 --- a/taskflow/patterns/ordered_flow.py +++ b/taskflow/patterns/ordered_flow.py @@ -17,6 +17,7 @@ # under the License. import abc +import collections import copy import functools import logging @@ -49,7 +50,9 @@ class RollbackTask(object): return str(self.task) def __call__(self, cause): - self.task.revert(self.context, self.result, cause) + if (hasattr(self.task, "revert") and + isinstance(self.task.revert, collections.Callable)): + self.task.revert(self.context, self.result, cause) class Flow(object): @@ -140,7 +143,7 @@ class Flow(object): if not inputs: inputs = {} inputs.update(kwargs) - result = task.apply(context, *args, **inputs) + result = task(context, *args, **inputs) # Keep a pristine copy of the result # so that if said result is altered by other further # states the one here will not be. This ensures that diff --git a/taskflow/task.py b/taskflow/task.py index 70fa6915a..51601101e 100644 --- a/taskflow/task.py +++ b/taskflow/task.py @@ -38,7 +38,7 @@ class Task(object): return "Task: %s" % (self.name) @abc.abstractmethod - def apply(self, context, *args, **kwargs): + def __call__(self, context, *args, **kwargs): """Activate a given task which will perform some operation and return. This method can be used to apply some given context and given set diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py index ffe424dce..00e3e1986 100644 --- a/taskflow/tests/utils.py +++ b/taskflow/tests/utils.py @@ -40,7 +40,7 @@ class ProvidesRequiresTask(task.Task): self.provides = provides self.requires = requires - def apply(self, context, *args, **kwargs): + def __call__(self, context, *args, **kwargs): outs = { KWARGS_KEY: dict(kwargs), ARGS_KEY: list(args), diff --git a/taskflow/utils.py b/taskflow/utils.py index 730bc817b..dbb5add59 100644 --- a/taskflow/utils.py +++ b/taskflow/utils.py @@ -24,6 +24,10 @@ import time LOG = logging.getLogger(__name__) +def safe_attr(obj, name, default=None): + return getattr(obj, name, default) + + def await(check_functor, timeout=None): if timeout is not None: end_time = time.time() + max(0, timeout) diff --git a/taskflow/wrappers.py b/taskflow/wrappers.py index c97e9e987..714bedac1 100644 --- a/taskflow/wrappers.py +++ b/taskflow/wrappers.py @@ -45,7 +45,7 @@ class FunctorTask(task.Task): continue self.requires.add(arg_name) - def apply(self, context, *args, **kwargs): + def __call__(self, context, *args, **kwargs): return self._apply_functor(context, *args, **kwargs) def revert(self, context, result, cause):