Instead of apply use __call__

Instead of requiring a apply() function just use the
built-in one provided by objects or functions implementing
__call__. Also change how requires/provides may not be found
if functors are just passed in (since functors implement
__call__).
This commit is contained in:
Joshua Harlow
2013-05-26 12:10:04 -07:00
parent de1af74125
commit f60ee1db6b
7 changed files with 22 additions and 13 deletions

View File

@@ -24,6 +24,7 @@ from networkx.algorithms import dag
from networkx.classes import digraph
from taskflow import exceptions as exc
from taskflow import utils
from taskflow.patterns import ordered_flow
LOG = logging.getLogger(__name__)
@@ -52,10 +53,10 @@ class Flow(ordered_flow.Flow):
def _fetch_task_inputs(self, task):
inputs = collections.defaultdict(list)
for n in task.requires:
for n in utils.safe_attr(task, 'requires', []):
for (them, there_result) in self.results:
if (not self._graph.has_edge(them, task) or
not n in them.provides):
not n in utils.safe_attr(them, 'provides', [])):
continue
if there_result and n in there_result:
inputs[n].append(there_result[n])
@@ -89,9 +90,9 @@ class Flow(ordered_flow.Flow):
provides_what = collections.defaultdict(list)
requires_what = collections.defaultdict(list)
for t in self._graph.nodes_iter():
for r in t.requires:
for r in utils.safe_attr(t, 'requires', []):
requires_what[r].append(t)
for p in t.provides:
for p in utils.safe_attr(t, 'provides', []):
provides_what[p].append(t)
def get_providers(node, want_what):

View File

@@ -17,6 +17,7 @@
# under the License.
from taskflow import exceptions as exc
from taskflow import utils
from taskflow.patterns import ordered_flow
@@ -31,10 +32,10 @@ class Flow(ordered_flow.Flow):
def _fetch_task_inputs(self, task):
inputs = {}
for r in task.requires:
for r in utils.safe_attr(task, 'requires', []):
# Find the last task that provided this.
for (last_task, last_results) in reversed(self.results):
if r not in last_task.provides:
if r not in utils.safe_attr(last_task, 'provides', []):
continue
if last_results and r in last_results:
inputs[r] = last_results[r]
@@ -47,10 +48,10 @@ class Flow(ordered_flow.Flow):
def _validate_provides(self, task):
# Ensure that some previous task provides this input.
missing_requires = []
for r in task.requires:
for r in utils.safe_attr(task, 'requires', []):
found_provider = False
for prev_task in reversed(self._tasks):
if r in prev_task.provides:
if r in utils.safe_attr(prev_task, 'provides', []):
found_provider = True
break
if not found_provider:

View File

@@ -17,6 +17,7 @@
# under the License.
import abc
import collections
import copy
import functools
import logging
@@ -49,7 +50,9 @@ class RollbackTask(object):
return str(self.task)
def __call__(self, cause):
self.task.revert(self.context, self.result, cause)
if (hasattr(self.task, "revert") and
isinstance(self.task.revert, collections.Callable)):
self.task.revert(self.context, self.result, cause)
class Flow(object):
@@ -140,7 +143,7 @@ class Flow(object):
if not inputs:
inputs = {}
inputs.update(kwargs)
result = task.apply(context, *args, **inputs)
result = task(context, *args, **inputs)
# Keep a pristine copy of the result
# so that if said result is altered by other further
# states the one here will not be. This ensures that

View File

@@ -38,7 +38,7 @@ class Task(object):
return "Task: %s" % (self.name)
@abc.abstractmethod
def apply(self, context, *args, **kwargs):
def __call__(self, context, *args, **kwargs):
"""Activate a given task which will perform some operation and return.
This method can be used to apply some given context and given set

View File

@@ -40,7 +40,7 @@ class ProvidesRequiresTask(task.Task):
self.provides = provides
self.requires = requires
def apply(self, context, *args, **kwargs):
def __call__(self, context, *args, **kwargs):
outs = {
KWARGS_KEY: dict(kwargs),
ARGS_KEY: list(args),

View File

@@ -24,6 +24,10 @@ import time
LOG = logging.getLogger(__name__)
def safe_attr(obj, name, default=None):
return getattr(obj, name, default)
def await(check_functor, timeout=None):
if timeout is not None:
end_time = time.time() + max(0, timeout)

View File

@@ -45,7 +45,7 @@ class FunctorTask(task.Task):
continue
self.requires.add(arg_name)
def apply(self, context, *args, **kwargs):
def __call__(self, context, *args, **kwargs):
return self._apply_functor(context, *args, **kwargs)
def revert(self, context, result, cause):