diff --git a/taskflow/engines/action_engine/actions/base.py b/taskflow/engines/action_engine/actions/base.py
index 5595268ab..869ef228e 100644
--- a/taskflow/engines/action_engine/actions/base.py
+++ b/taskflow/engines/action_engine/actions/base.py
@@ -32,10 +32,9 @@ SAVE_RESULT_STATES = (states.SUCCESS, states.FAILURE)
 class Action(object):
     """An action that handles executing, state changes, ... of atoms."""
 
-    def __init__(self, storage, notifier, walker_factory):
+    def __init__(self, storage, notifier):
         self._storage = storage
         self._notifier = notifier
-        self._walker_factory = walker_factory
 
     @abc.abstractmethod
     def handles(self, atom):
diff --git a/taskflow/engines/action_engine/actions/retry.py b/taskflow/engines/action_engine/actions/retry.py
index 05496d96c..f69d5a5b8 100644
--- a/taskflow/engines/action_engine/actions/retry.py
+++ b/taskflow/engines/action_engine/actions/retry.py
@@ -44,8 +44,8 @@ def _revert_retry(retry, arguments):
 class RetryAction(base.Action):
     """An action that handles executing, state changes, ... of retry atoms."""
 
-    def __init__(self, storage, notifier, walker_factory):
-        super(RetryAction, self).__init__(storage, notifier, walker_factory)
+    def __init__(self, storage, notifier):
+        super(RetryAction, self).__init__(storage, notifier)
         self._executor = futures.SynchronousExecutor()
 
     @staticmethod
@@ -53,11 +53,9 @@ class RetryAction(base.Action):
         return isinstance(atom, retry_atom.Retry)
 
     def _get_retry_args(self, retry, addons=None):
-        scope_walker = self._walker_factory(retry)
         arguments = self._storage.fetch_mapped_args(
             retry.rebind,
             atom_name=retry.name,
-            scope_walker=scope_walker,
             optional_args=retry.optional
         )
         history = self._storage.get_retry_history(retry.name)
diff --git a/taskflow/engines/action_engine/actions/task.py b/taskflow/engines/action_engine/actions/task.py
index 8c64931ab..2a11bf8df 100644
--- a/taskflow/engines/action_engine/actions/task.py
+++ b/taskflow/engines/action_engine/actions/task.py
@@ -28,8 +28,8 @@ LOG = logging.getLogger(__name__)
 class TaskAction(base.Action):
     """An action that handles scheduling, state changes, ... of task atoms."""
 
-    def __init__(self, storage, notifier, walker_factory, task_executor):
-        super(TaskAction, self).__init__(storage, notifier, walker_factory)
+    def __init__(self, storage, notifier, task_executor):
+        super(TaskAction, self).__init__(storage, notifier)
         self._task_executor = task_executor
 
     @staticmethod
@@ -100,11 +100,9 @@ class TaskAction(base.Action):
 
     def schedule_execution(self, task):
         self.change_state(task, states.RUNNING, progress=0.0)
-        scope_walker = self._walker_factory(task)
         arguments = self._storage.fetch_mapped_args(
             task.rebind,
             atom_name=task.name,
-            scope_walker=scope_walker,
             optional_args=task.optional
         )
         if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
@@ -126,11 +124,9 @@ class TaskAction(base.Action):
 
     def schedule_reversion(self, task):
         self.change_state(task, states.REVERTING, progress=0.0)
-        scope_walker = self._walker_factory(task)
         arguments = self._storage.fetch_mapped_args(
             task.rebind,
             atom_name=task.name,
-            scope_walker=scope_walker,
             optional_args=task.optional
         )
         task_uuid = self._storage.get_atom_uuid(task.name)
diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py
index 9fa0f5a59..68bd3458b 100644
--- a/taskflow/engines/action_engine/engine.py
+++ b/taskflow/engines/action_engine/engine.py
@@ -29,6 +29,7 @@ from taskflow.engines.action_engine import runtime
 from taskflow.engines import base
 from taskflow import exceptions as exc
 from taskflow import states
+from taskflow import storage
 from taskflow.types import failure
 from taskflow.utils import lock_utils
 from taskflow.utils import misc
@@ -89,6 +90,29 @@
         else:
             return None
 
+    @misc.cachedproperty
+    def storage(self):
+        """The storage unit for this engine.
+
+        NOTE(harlowja): the atom argument lookup strategy will change for
+        this storage unit after
+        :py:func:`~taskflow.engines.base.Engine.compile` has
+        completed (since **only** after compilation is the actual structure
+        known). Before :py:func:`~taskflow.engines.base.Engine.compile`
+        has completed the atom argument lookup strategy lookup will be
+        restricted to injected arguments **only** (this will **not** reflect
+        the actual runtime lookup strategy, which typically will be, but is
+        not always different).
+        """
+        def _scope_fetcher(atom_name):
+            if self._compiled:
+                return self._runtime.fetch_scopes_for(atom_name)
+            else:
+                return None
+        return storage.Storage(self._flow_detail,
+                               backend=self._backend,
+                               scope_fetcher=_scope_fetcher)
+
     def run(self):
         with lock_utils.try_lock(self._lock) as was_locked:
             if not was_locked:
@@ -192,9 +216,7 @@
         missing = set()
         fetch = self.storage.fetch_unsatisfied_args
         for node in self._compilation.execution_graph.nodes_iter():
-            scope_walker = self._runtime.fetch_scopes_for(node)
             missing.update(fetch(node.name, node.rebind,
-                                 scope_walker=scope_walker,
                                  optional_args=node.optional))
         if missing:
             raise exc.MissingDependencies(self._flow, sorted(missing))
diff --git a/taskflow/engines/action_engine/runtime.py b/taskflow/engines/action_engine/runtime.py
index d71ff65c7..d8df4705c 100644
--- a/taskflow/engines/action_engine/runtime.py
+++ b/taskflow/engines/action_engine/runtime.py
@@ -66,24 +66,31 @@ class Runtime(object):
 
     @misc.cachedproperty
     def retry_action(self):
-        return ra.RetryAction(self._storage, self._atom_notifier,
-                              self.fetch_scopes_for)
+        return ra.RetryAction(self._storage,
+                              self._atom_notifier)
 
     @misc.cachedproperty
     def task_action(self):
         return ta.TaskAction(self._storage,
-                             self._atom_notifier, self.fetch_scopes_for,
+                             self._atom_notifier,
                              self._task_executor)
 
-    def fetch_scopes_for(self, atom):
+    def fetch_scopes_for(self, atom_name):
         """Fetches a tuple of the visible scopes for the given atom."""
         try:
-            return self._scopes[atom]
+            return self._scopes[atom_name]
         except KeyError:
-            walker = sc.ScopeWalker(self.compilation, atom,
-                                    names_only=True)
-            visible_to = tuple(walker)
-            self._scopes[atom] = visible_to
+            atom = None
+            for node in self.analyzer.iterate_all_nodes():
+                if node.name == atom_name:
+                    atom = node
+                    break
+            if atom is not None:
+                walker = sc.ScopeWalker(self.compilation, atom,
+                                        names_only=True)
+                self._scopes[atom_name] = visible_to = tuple(walker)
+            else:
+                visible_to = tuple([])
             return visible_to
 
     # Various helper methods used by the runtime components; not for public
diff --git a/taskflow/engines/base.py b/taskflow/engines/base.py
index 632f626f3..824d90875 100644
--- a/taskflow/engines/base.py
+++ b/taskflow/engines/base.py
@@ -20,9 +20,7 @@ import abc
 from debtcollector import moves
 import six
 
-from taskflow import storage
 from taskflow.types import notifier
-from taskflow.utils import misc
 
 
 @six.add_metaclass(abc.ABCMeta)
@@ -72,10 +70,9 @@ class Engine(object):
         """The options that were passed to this engine on construction."""
         return self._options
 
-    @misc.cachedproperty
+    @abc.abstractproperty
     def storage(self):
-        """The storage unit for this flow."""
-        return storage.Storage(self._flow_detail, backend=self._backend)
+        """The storage unit for this engine."""
 
     @abc.abstractmethod
     def compile(self):
diff --git a/taskflow/storage.py b/taskflow/storage.py
index 3cf496aaf..1df6fd6d3 100644
--- a/taskflow/storage.py
+++ b/taskflow/storage.py
@@ -131,7 +131,7 @@ class Storage(object):
     with it must be avoided) that are *global* to the flow being executed.
     """
 
-    def __init__(self, flow_detail, backend=None):
+    def __init__(self, flow_detail, backend=None, scope_fetcher=None):
         self._result_mappings = {}
         self._reverse_mapping = {}
         self._backend = backend
@@ -143,6 +143,9 @@
             ((task.BaseTask,), self._ensure_task),
             ((retry.Retry,), self._ensure_retry),
         ]
+        if scope_fetcher is None:
+            scope_fetcher = lambda atom_name: None
+        self._scope_fetcher = scope_fetcher
 
         # NOTE(imelnikov): failure serialization looses information,
         # so we cache failures here, in atom name -> failure mapping.
@@ -698,7 +701,7 @@ class Storage(object):
                     non_default_providers.append(p)
             return default_providers, non_default_providers
 
-        def _locate_providers(name):
+        def _locate_providers(name, scope_walker=None):
             """Finds the accessible *potential* providers."""
             default_providers, non_default_providers = _fetch_providers(name)
             providers = []
@@ -728,6 +731,8 @@
                 return providers
 
         ad = self._atomdetail_by_name(atom_name)
+        if scope_walker is None:
+            scope_walker = self._scope_fetcher(atom_name)
         if optional_args is None:
             optional_args = []
         injected_sources = [
@@ -749,7 +754,8 @@
                     continue
                 if name in source:
                     maybe_providers += 1
-            maybe_providers += len(_locate_providers(name))
+            providers = _locate_providers(name, scope_walker=scope_walker)
+            maybe_providers += len(providers)
             if maybe_providers:
                 LOG.blather("Atom %s will have %s potential providers"
                             " of %r <= %r", atom_name, maybe_providers,
@@ -797,7 +803,8 @@
                     " by %s but was unable to get at that providers"
                     " results" % (looking_for, provider), e)
 
-        def _locate_providers(looking_for, possible_providers):
+        def _locate_providers(looking_for, possible_providers,
+                              scope_walker=None):
             """Finds the accessible providers."""
             default_providers = []
             for p in possible_providers:
@@ -832,6 +839,8 @@
                 self._injected_args.get(atom_name, {}),
                 ad.meta.get(META_INJECTED, {}),
             ]
+            if scope_walker is None:
+                scope_walker = self._scope_fetcher(atom_name)
         else:
             injected_sources = []
         if not args_mapping:
@@ -869,7 +878,8 @@
                                           " produced output by any"
                                           " providers" % name)
             # Reduce the possible providers to one that are allowed.
-            providers = _locate_providers(name, possible_providers)
+            providers = _locate_providers(name, possible_providers,
+                                          scope_walker=scope_walker)
             if not providers:
                 raise exceptions.NotFound(
                     "Mapped argument %r <= %r was not produced"