Instead of apply use __call__
Instead of requiring a apply() function just use the built-in one provided by objects or functions implementing __call__. Also change how requires/provides may not be found if functors are just passed in (since functors implement __call__).
This commit is contained in:
@@ -24,6 +24,7 @@ from networkx.algorithms import dag
|
||||
from networkx.classes import digraph
|
||||
|
||||
from taskflow import exceptions as exc
|
||||
from taskflow import utils
|
||||
from taskflow.patterns import ordered_flow
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@@ -52,10 +53,10 @@ class Flow(ordered_flow.Flow):
|
||||
def _fetch_task_inputs(self, task):
|
||||
inputs = collections.defaultdict(list)
|
||||
|
||||
for n in task.requires:
|
||||
for n in utils.safe_attr(task, 'requires', []):
|
||||
for (them, there_result) in self.results:
|
||||
if (not self._graph.has_edge(them, task) or
|
||||
not n in them.provides):
|
||||
not n in utils.safe_attr(them, 'provides', [])):
|
||||
continue
|
||||
if there_result and n in there_result:
|
||||
inputs[n].append(there_result[n])
|
||||
@@ -89,9 +90,9 @@ class Flow(ordered_flow.Flow):
|
||||
provides_what = collections.defaultdict(list)
|
||||
requires_what = collections.defaultdict(list)
|
||||
for t in self._graph.nodes_iter():
|
||||
for r in t.requires:
|
||||
for r in utils.safe_attr(t, 'requires', []):
|
||||
requires_what[r].append(t)
|
||||
for p in t.provides:
|
||||
for p in utils.safe_attr(t, 'provides', []):
|
||||
provides_what[p].append(t)
|
||||
|
||||
def get_providers(node, want_what):
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
# under the License.
|
||||
|
||||
from taskflow import exceptions as exc
|
||||
from taskflow import utils
|
||||
from taskflow.patterns import ordered_flow
|
||||
|
||||
|
||||
@@ -31,10 +32,10 @@ class Flow(ordered_flow.Flow):
|
||||
|
||||
def _fetch_task_inputs(self, task):
|
||||
inputs = {}
|
||||
for r in task.requires:
|
||||
for r in utils.safe_attr(task, 'requires', []):
|
||||
# Find the last task that provided this.
|
||||
for (last_task, last_results) in reversed(self.results):
|
||||
if r not in last_task.provides:
|
||||
if r not in utils.safe_attr(last_task, 'provides', []):
|
||||
continue
|
||||
if last_results and r in last_results:
|
||||
inputs[r] = last_results[r]
|
||||
@@ -47,10 +48,10 @@ class Flow(ordered_flow.Flow):
|
||||
def _validate_provides(self, task):
|
||||
# Ensure that some previous task provides this input.
|
||||
missing_requires = []
|
||||
for r in task.requires:
|
||||
for r in utils.safe_attr(task, 'requires', []):
|
||||
found_provider = False
|
||||
for prev_task in reversed(self._tasks):
|
||||
if r in prev_task.provides:
|
||||
if r in utils.safe_attr(prev_task, 'provides', []):
|
||||
found_provider = True
|
||||
break
|
||||
if not found_provider:
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
# under the License.
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import copy
|
||||
import functools
|
||||
import logging
|
||||
@@ -49,7 +50,9 @@ class RollbackTask(object):
|
||||
return str(self.task)
|
||||
|
||||
def __call__(self, cause):
|
||||
self.task.revert(self.context, self.result, cause)
|
||||
if (hasattr(self.task, "revert") and
|
||||
isinstance(self.task.revert, collections.Callable)):
|
||||
self.task.revert(self.context, self.result, cause)
|
||||
|
||||
|
||||
class Flow(object):
|
||||
@@ -140,7 +143,7 @@ class Flow(object):
|
||||
if not inputs:
|
||||
inputs = {}
|
||||
inputs.update(kwargs)
|
||||
result = task.apply(context, *args, **inputs)
|
||||
result = task(context, *args, **inputs)
|
||||
# Keep a pristine copy of the result
|
||||
# so that if said result is altered by other further
|
||||
# states the one here will not be. This ensures that
|
||||
|
||||
@@ -38,7 +38,7 @@ class Task(object):
|
||||
return "Task: %s" % (self.name)
|
||||
|
||||
@abc.abstractmethod
|
||||
def apply(self, context, *args, **kwargs):
|
||||
def __call__(self, context, *args, **kwargs):
|
||||
"""Activate a given task which will perform some operation and return.
|
||||
|
||||
This method can be used to apply some given context and given set
|
||||
|
||||
@@ -40,7 +40,7 @@ class ProvidesRequiresTask(task.Task):
|
||||
self.provides = provides
|
||||
self.requires = requires
|
||||
|
||||
def apply(self, context, *args, **kwargs):
|
||||
def __call__(self, context, *args, **kwargs):
|
||||
outs = {
|
||||
KWARGS_KEY: dict(kwargs),
|
||||
ARGS_KEY: list(args),
|
||||
|
||||
@@ -24,6 +24,10 @@ import time
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def safe_attr(obj, name, default=None):
|
||||
return getattr(obj, name, default)
|
||||
|
||||
|
||||
def await(check_functor, timeout=None):
|
||||
if timeout is not None:
|
||||
end_time = time.time() + max(0, timeout)
|
||||
|
||||
@@ -45,7 +45,7 @@ class FunctorTask(task.Task):
|
||||
continue
|
||||
self.requires.add(arg_name)
|
||||
|
||||
def apply(self, context, *args, **kwargs):
|
||||
def __call__(self, context, *args, **kwargs):
|
||||
return self._apply_functor(context, *args, **kwargs)
|
||||
|
||||
def revert(self, context, result, cause):
|
||||
|
||||
Reference in New Issue
Block a user