Allow the storage unit to use the right scoping strategy

Instead of having the fetch arguments functions need to be
provided a scope walker to correctly find the right arguments,
which only the internals of the action engine know about
provide a default scope walker (that is the same one the
action engine internal uses) to the storage unit and have it be
the default strategy used so that users need not know how to
pass it in (which they should not care about).

This allows for users to fetch the same mapped arguments as the
internals of the engine will fetch.

Change-Id: I1beca532b2b7c7ad98b09265a0c4477658052d16
This commit is contained in:
Joshua Harlow
2015-03-11 16:41:15 -07:00
parent f0de22c18a
commit 5996c8f25e
7 changed files with 62 additions and 33 deletions

View File

@@ -32,10 +32,9 @@ SAVE_RESULT_STATES = (states.SUCCESS, states.FAILURE)
class Action(object): class Action(object):
"""An action that handles executing, state changes, ... of atoms.""" """An action that handles executing, state changes, ... of atoms."""
def __init__(self, storage, notifier, walker_factory): def __init__(self, storage, notifier):
self._storage = storage self._storage = storage
self._notifier = notifier self._notifier = notifier
self._walker_factory = walker_factory
@abc.abstractmethod @abc.abstractmethod
def handles(self, atom): def handles(self, atom):

View File

@@ -44,8 +44,8 @@ def _revert_retry(retry, arguments):
class RetryAction(base.Action): class RetryAction(base.Action):
"""An action that handles executing, state changes, ... of retry atoms.""" """An action that handles executing, state changes, ... of retry atoms."""
def __init__(self, storage, notifier, walker_factory): def __init__(self, storage, notifier):
super(RetryAction, self).__init__(storage, notifier, walker_factory) super(RetryAction, self).__init__(storage, notifier)
self._executor = futures.SynchronousExecutor() self._executor = futures.SynchronousExecutor()
@staticmethod @staticmethod
@@ -53,11 +53,9 @@ class RetryAction(base.Action):
return isinstance(atom, retry_atom.Retry) return isinstance(atom, retry_atom.Retry)
def _get_retry_args(self, retry, addons=None): def _get_retry_args(self, retry, addons=None):
scope_walker = self._walker_factory(retry)
arguments = self._storage.fetch_mapped_args( arguments = self._storage.fetch_mapped_args(
retry.rebind, retry.rebind,
atom_name=retry.name, atom_name=retry.name,
scope_walker=scope_walker,
optional_args=retry.optional optional_args=retry.optional
) )
history = self._storage.get_retry_history(retry.name) history = self._storage.get_retry_history(retry.name)

View File

@@ -28,8 +28,8 @@ LOG = logging.getLogger(__name__)
class TaskAction(base.Action): class TaskAction(base.Action):
"""An action that handles scheduling, state changes, ... of task atoms.""" """An action that handles scheduling, state changes, ... of task atoms."""
def __init__(self, storage, notifier, walker_factory, task_executor): def __init__(self, storage, notifier, task_executor):
super(TaskAction, self).__init__(storage, notifier, walker_factory) super(TaskAction, self).__init__(storage, notifier)
self._task_executor = task_executor self._task_executor = task_executor
@staticmethod @staticmethod
@@ -100,11 +100,9 @@ class TaskAction(base.Action):
def schedule_execution(self, task): def schedule_execution(self, task):
self.change_state(task, states.RUNNING, progress=0.0) self.change_state(task, states.RUNNING, progress=0.0)
scope_walker = self._walker_factory(task)
arguments = self._storage.fetch_mapped_args( arguments = self._storage.fetch_mapped_args(
task.rebind, task.rebind,
atom_name=task.name, atom_name=task.name,
scope_walker=scope_walker,
optional_args=task.optional optional_args=task.optional
) )
if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS): if task.notifier.can_be_registered(task_atom.EVENT_UPDATE_PROGRESS):
@@ -126,11 +124,9 @@ class TaskAction(base.Action):
def schedule_reversion(self, task): def schedule_reversion(self, task):
self.change_state(task, states.REVERTING, progress=0.0) self.change_state(task, states.REVERTING, progress=0.0)
scope_walker = self._walker_factory(task)
arguments = self._storage.fetch_mapped_args( arguments = self._storage.fetch_mapped_args(
task.rebind, task.rebind,
atom_name=task.name, atom_name=task.name,
scope_walker=scope_walker,
optional_args=task.optional optional_args=task.optional
) )
task_uuid = self._storage.get_atom_uuid(task.name) task_uuid = self._storage.get_atom_uuid(task.name)

View File

@@ -29,6 +29,7 @@ from taskflow.engines.action_engine import runtime
from taskflow.engines import base from taskflow.engines import base
from taskflow import exceptions as exc from taskflow import exceptions as exc
from taskflow import states from taskflow import states
from taskflow import storage
from taskflow.types import failure from taskflow.types import failure
from taskflow.utils import lock_utils from taskflow.utils import lock_utils
from taskflow.utils import misc from taskflow.utils import misc
@@ -89,6 +90,29 @@ class ActionEngine(base.Engine):
else: else:
return None return None
@misc.cachedproperty
def storage(self):
"""The storage unit for this engine.
NOTE(harlowja): the atom argument lookup strategy will change for
this storage unit after
:py:func:`~taskflow.engines.base.Engine.compile` has
completed (since **only** after compilation is the actual structure
known). Before :py:func:`~taskflow.engines.base.Engine.compile`
has completed the atom argument lookup strategy lookup will be
restricted to injected arguments **only** (this will **not** reflect
the actual runtime lookup strategy, which typically will be, but is
not always different).
"""
def _scope_fetcher(atom_name):
if self._compiled:
return self._runtime.fetch_scopes_for(atom_name)
else:
return None
return storage.Storage(self._flow_detail,
backend=self._backend,
scope_fetcher=_scope_fetcher)
def run(self): def run(self):
with lock_utils.try_lock(self._lock) as was_locked: with lock_utils.try_lock(self._lock) as was_locked:
if not was_locked: if not was_locked:
@@ -192,9 +216,7 @@ class ActionEngine(base.Engine):
missing = set() missing = set()
fetch = self.storage.fetch_unsatisfied_args fetch = self.storage.fetch_unsatisfied_args
for node in self._compilation.execution_graph.nodes_iter(): for node in self._compilation.execution_graph.nodes_iter():
scope_walker = self._runtime.fetch_scopes_for(node)
missing.update(fetch(node.name, node.rebind, missing.update(fetch(node.name, node.rebind,
scope_walker=scope_walker,
optional_args=node.optional)) optional_args=node.optional))
if missing: if missing:
raise exc.MissingDependencies(self._flow, sorted(missing)) raise exc.MissingDependencies(self._flow, sorted(missing))

View File

@@ -66,24 +66,31 @@ class Runtime(object):
@misc.cachedproperty @misc.cachedproperty
def retry_action(self): def retry_action(self):
return ra.RetryAction(self._storage, self._atom_notifier, return ra.RetryAction(self._storage,
self.fetch_scopes_for) self._atom_notifier)
@misc.cachedproperty @misc.cachedproperty
def task_action(self): def task_action(self):
return ta.TaskAction(self._storage, return ta.TaskAction(self._storage,
self._atom_notifier, self.fetch_scopes_for, self._atom_notifier,
self._task_executor) self._task_executor)
def fetch_scopes_for(self, atom): def fetch_scopes_for(self, atom_name):
"""Fetches a tuple of the visible scopes for the given atom.""" """Fetches a tuple of the visible scopes for the given atom."""
try: try:
return self._scopes[atom] return self._scopes[atom_name]
except KeyError: except KeyError:
atom = None
for node in self.analyzer.iterate_all_nodes():
if node.name == atom_name:
atom = node
break
if atom is not None:
walker = sc.ScopeWalker(self.compilation, atom, walker = sc.ScopeWalker(self.compilation, atom,
names_only=True) names_only=True)
visible_to = tuple(walker) self._scopes[atom_name] = visible_to = tuple(walker)
self._scopes[atom] = visible_to else:
visible_to = tuple([])
return visible_to return visible_to
# Various helper methods used by the runtime components; not for public # Various helper methods used by the runtime components; not for public

View File

@@ -20,9 +20,7 @@ import abc
from debtcollector import moves from debtcollector import moves
import six import six
from taskflow import storage
from taskflow.types import notifier from taskflow.types import notifier
from taskflow.utils import misc
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
@@ -72,10 +70,9 @@ class Engine(object):
"""The options that were passed to this engine on construction.""" """The options that were passed to this engine on construction."""
return self._options return self._options
@misc.cachedproperty @abc.abstractproperty
def storage(self): def storage(self):
"""The storage unit for this flow.""" """The storage unit for this engine."""
return storage.Storage(self._flow_detail, backend=self._backend)
@abc.abstractmethod @abc.abstractmethod
def compile(self): def compile(self):

View File

@@ -131,7 +131,7 @@ class Storage(object):
with it must be avoided) that are *global* to the flow being executed. with it must be avoided) that are *global* to the flow being executed.
""" """
def __init__(self, flow_detail, backend=None): def __init__(self, flow_detail, backend=None, scope_fetcher=None):
self._result_mappings = {} self._result_mappings = {}
self._reverse_mapping = {} self._reverse_mapping = {}
self._backend = backend self._backend = backend
@@ -143,6 +143,9 @@ class Storage(object):
((task.BaseTask,), self._ensure_task), ((task.BaseTask,), self._ensure_task),
((retry.Retry,), self._ensure_retry), ((retry.Retry,), self._ensure_retry),
] ]
if scope_fetcher is None:
scope_fetcher = lambda atom_name: None
self._scope_fetcher = scope_fetcher
# NOTE(imelnikov): failure serialization looses information, # NOTE(imelnikov): failure serialization looses information,
# so we cache failures here, in atom name -> failure mapping. # so we cache failures here, in atom name -> failure mapping.
@@ -698,7 +701,7 @@ class Storage(object):
non_default_providers.append(p) non_default_providers.append(p)
return default_providers, non_default_providers return default_providers, non_default_providers
def _locate_providers(name): def _locate_providers(name, scope_walker=None):
"""Finds the accessible *potential* providers.""" """Finds the accessible *potential* providers."""
default_providers, non_default_providers = _fetch_providers(name) default_providers, non_default_providers = _fetch_providers(name)
providers = [] providers = []
@@ -728,6 +731,8 @@ class Storage(object):
return providers return providers
ad = self._atomdetail_by_name(atom_name) ad = self._atomdetail_by_name(atom_name)
if scope_walker is None:
scope_walker = self._scope_fetcher(atom_name)
if optional_args is None: if optional_args is None:
optional_args = [] optional_args = []
injected_sources = [ injected_sources = [
@@ -749,7 +754,8 @@ class Storage(object):
continue continue
if name in source: if name in source:
maybe_providers += 1 maybe_providers += 1
maybe_providers += len(_locate_providers(name)) providers = _locate_providers(name, scope_walker=scope_walker)
maybe_providers += len(providers)
if maybe_providers: if maybe_providers:
LOG.blather("Atom %s will have %s potential providers" LOG.blather("Atom %s will have %s potential providers"
" of %r <= %r", atom_name, maybe_providers, " of %r <= %r", atom_name, maybe_providers,
@@ -797,7 +803,8 @@ class Storage(object):
" by %s but was unable to get at that providers" " by %s but was unable to get at that providers"
" results" % (looking_for, provider), e) " results" % (looking_for, provider), e)
def _locate_providers(looking_for, possible_providers): def _locate_providers(looking_for, possible_providers,
scope_walker=None):
"""Finds the accessible providers.""" """Finds the accessible providers."""
default_providers = [] default_providers = []
for p in possible_providers: for p in possible_providers:
@@ -832,6 +839,8 @@ class Storage(object):
self._injected_args.get(atom_name, {}), self._injected_args.get(atom_name, {}),
ad.meta.get(META_INJECTED, {}), ad.meta.get(META_INJECTED, {}),
] ]
if scope_walker is None:
scope_walker = self._scope_fetcher(atom_name)
else: else:
injected_sources = [] injected_sources = []
if not args_mapping: if not args_mapping:
@@ -869,7 +878,8 @@ class Storage(object):
" produced output by any" " produced output by any"
" providers" % name) " providers" % name)
# Reduce the possible providers to one that are allowed. # Reduce the possible providers to one that are allowed.
providers = _locate_providers(name, possible_providers) providers = _locate_providers(name, possible_providers,
scope_walker=scope_walker)
if not providers: if not providers:
raise exceptions.NotFound( raise exceptions.NotFound(
"Mapped argument %r <= %r was not produced" "Mapped argument %r <= %r was not produced"