diff --git a/taskflow/engines/action_engine/compiler.py b/taskflow/engines/action_engine/compiler.py index 32cb58c8..41fe5f96 100644 --- a/taskflow/engines/action_engine/compiler.py +++ b/taskflow/engines/action_engine/compiler.py @@ -15,53 +15,34 @@ # under the License. import logging +import threading from taskflow import exceptions as exc from taskflow import flow from taskflow import retry from taskflow import task from taskflow.types import graph as gr +from taskflow.types import tree as tr +from taskflow.utils import lock_utils from taskflow.utils import misc LOG = logging.getLogger(__name__) class Compilation(object): - """The result of a compilers compile() is this *immutable* object. + """The result of a compiler's compile() is this *immutable* object.""" - For now it is just a execution graph but in the future it will grow to - include more methods & properties that help the various runtime units - execute in a more optimal & featureful manner. - """ - def __init__(self, execution_graph): + def __init__(self, execution_graph, hierarchy): self._execution_graph = execution_graph + self._hierarchy = hierarchy @property def execution_graph(self): return self._execution_graph - -class PatternCompiler(object): - """Compiles patterns & atoms into a compilation unit. - - NOTE(harlowja): during this pattern translation process any nested flows - will be converted into there equivalent subgraphs. This currently implies - that contained atoms in those nested flows, post-translation will no longer - be associated with there previously containing flow but instead will lose - this identity and what will remain is the logical constraints that there - contained flow mandated. In the future this may be changed so that this - association is not lost via the compilation process (since it can be - useful to retain this relationship). - """ - def compile(self, root): - graph = _Flattener(root).flatten() - if graph.number_of_nodes() == 0: - # Try to get a name attribute, otherwise just use the object - # string representation directly if that attribute does not exist. - name = getattr(root, 'name', root) - raise exc.Empty("Root container '%s' (%s) is empty." - % (name, type(root))) - return Compilation(graph) + @property + def hierarchy(self): + return self._hierarchy _RETRY_EDGE_DATA = { @@ -69,14 +50,15 @@ _RETRY_EDGE_DATA = { } -class _Flattener(object): - """Flattens a root item (task/flow) into a execution graph.""" +class PatternCompiler(object): + """Compiles a pattern (or task) into a compilation unit.""" def __init__(self, root, freeze=True): self._root = root - self._graph = None self._history = set() - self._freeze = bool(freeze) + self._freeze = freeze + self._lock = threading.Lock() + self._compilation = None def _add_new_edges(self, graph, nodes_from, nodes_to, edge_attrs): """Adds new edges from nodes to other nodes in the specified graph. @@ -93,72 +75,74 @@ class _Flattener(object): # if it's later modified that the same copy isn't modified.
graph.add_edge(u, v, attr_dict=edge_attrs.copy()) - def _flatten(self, item): - functor = self._find_flattener(item) - if not functor: - raise TypeError("Unknown type requested to flatten: %s (%s)" - % (item, type(item))) + def _flatten(self, item, parent): + functor = self._find_flattener(item, parent) self._pre_item_flatten(item) - graph = functor(item) - self._post_item_flatten(item, graph) - return graph + graph, node = functor(item, parent) + self._post_item_flatten(item, graph, node) + return graph, node - def _find_flattener(self, item): + def _find_flattener(self, item, parent): """Locates the flattening function to use to flatten the given item.""" if isinstance(item, flow.Flow): return self._flatten_flow elif isinstance(item, task.BaseTask): return self._flatten_task elif isinstance(item, retry.Retry): - if len(self._history) == 1: - raise TypeError("Retry controller: %s (%s) must only be used" + if parent is None: + raise TypeError("Retry controller '%s' (%s) must only be used" " as a flow constructor parameter and not as a" " root component" % (item, type(item))) else: - # TODO(harlowja): we should raise this type error earlier - # instead of later since we should do this same check on add() - # calls, this makes the error more visible (instead of waiting - # until compile time). - raise TypeError("Retry controller: %s (%s) must only be used" + raise TypeError("Retry controller '%s' (%s) must only be used" " as a flow constructor parameter and not as a" " flow added component" % (item, type(item))) else: - return None + raise TypeError("Unknown item '%s' (%s) requested to flatten" + % (item, type(item))) def _connect_retry(self, retry, graph): graph.add_node(retry) - # All graph nodes that have no predecessors should depend on its retry - nodes_to = [n for n in graph.no_predecessors_iter() if n != retry] + # All nodes that have no predecessors should depend on this retry. + nodes_to = [n for n in graph.no_predecessors_iter() if n is not retry] self._add_new_edges(graph, [retry], nodes_to, _RETRY_EDGE_DATA) - # Add link to retry for each node of subgraph that hasn't - # a parent retry + # Add an association for each node of the graph that has no existing retry. for n in graph.nodes_iter(): - if n != retry and 'retry' not in graph.node[n]: + if n is not retry and 'retry' not in graph.node[n]: graph.node[n]['retry'] = retry - def _flatten_task(self, task): + def _flatten_task(self, task, parent): """Flattens a individual task.""" graph = gr.DiGraph(name=task.name) graph.add_node(task) - return graph + node = tr.Node(task) + if parent is not None: + parent.add(node) + return graph, node - def _flatten_flow(self, flow): - """Flattens a graph flow.""" + def _flatten_flow(self, flow, parent): + """Flattens a flow.""" + graph = gr.DiGraph(name=flow.name) + node = tr.Node(flow) + if parent is not None: + parent.add(node) + if flow.retry is not None: + node.add(tr.Node(flow.retry)) - # Flatten all nodes into a single subgraph per node. - subgraph_map = {} + # Flatten all nodes into a single subgraph per item (and map each + # origin item to its newly expanded graph). + subgraphs = {} for item in flow: - subgraph = self._flatten(item) - subgraph_map[item] = subgraph + subgraph = self._flatten(item, node)[0] + subgraphs[item] = subgraph graph = gr.merge_graphs([graph, subgraph]) - # Reconnect all node edges to their corresponding subgraphs. + # Reconnect all items' edges to their corresponding subgraphs.
for (u, v, attrs) in flow.iter_links(): - u_g = subgraph_map[u] - v_g = subgraph_map[v] + u_g = subgraphs[u] + v_g = subgraphs[v] if any(attrs.get(k) for k in ('invariant', 'manual', 'retry')): # Connect nodes with no predecessors in v to nodes with # no successors in u (thus maintaining the edge dependency). @@ -177,48 +161,57 @@ class _Flattener(object): if flow.retry is not None: self._connect_retry(flow.retry, graph) - return graph + return graph, node def _pre_item_flatten(self, item): """Called before a item is flattened; any pre-flattening actions.""" - if id(item) in self._history: - raise ValueError("Already flattened item: %s (%s), recursive" - " flattening not supported" % (item, id(item))) - self._history.add(id(item)) + if item in self._history: + raise ValueError("Already flattened item '%s' (%s), recursive" - " flattening is not supported" % (item, + type(item))) + self._history.add(item) - def _post_item_flatten(self, item, graph): - """Called before a item is flattened; any post-flattening actions.""" + def _post_item_flatten(self, item, graph, node): + """Called after an item is flattened; any post-flattening actions.""" def _pre_flatten(self): - """Called before the flattening of the item starts.""" self._history.clear() - def _post_flatten(self, graph): - """Called after the flattening of the item finishes successfully.""" + def _post_flatten(self, graph, node): + """Called after the flattening of the root finishes successfully.""" dup_names = misc.get_duplicate_keys(graph.nodes_iter(), key=lambda node: node.name) if dup_names: - dup_names = ', '.join(sorted(dup_names)) - raise exc.Duplicate("Atoms with duplicate names " - "found: %s" % (dup_names)) + raise exc.Duplicate( + "Atoms with duplicate names found: %s" % (sorted(dup_names))) + if graph.number_of_nodes() == 0: + raise exc.Empty("Root container '%s' (%s) is empty" + % (self._root, type(self._root))) self._history.clear() # NOTE(harlowja): this one can be expensive to calculate (especially # the cycle detection), so only do it if we know debugging is enabled # and not under all cases. if LOG.isEnabledFor(logging.DEBUG): - LOG.debug("Translated '%s' into a graph:", self._root) + LOG.debug("Translated '%s'", self._root) + LOG.debug("Graph:") for line in graph.pformat().splitlines(): # Indent it so that it's slightly offset from the above line. - LOG.debug(" %s", line) + LOG.debug(" %s", line) + LOG.debug("Hierarchy:") + for line in node.pformat().splitlines(): + # Indent it so that it's slightly offset from the above line.
+ LOG.debug(" %s", line) - def flatten(self): - """Flattens a item (a task or flow) into a single execution graph.""" - if self._graph is not None: - return self._graph - self._pre_flatten() - graph = self._flatten(self._root) - self._post_flatten(graph) - self._graph = graph - if self._freeze: - self._graph.freeze() - return self._graph + @lock_utils.locked + def compile(self): + """Compiles the contained item into a compiled equivalent.""" + if self._compilation is None: + self._pre_flatten() + graph, node = self._flatten(self._root, None) + self._post_flatten(graph, node) + if self._freeze: + graph.freeze() + node.freeze() + self._compilation = Compilation(graph, node) + return self._compilation diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index 9bf62429..3aaa3226 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -211,13 +211,13 @@ class ActionEngine(base.EngineBase): @misc.cachedproperty def _compiler(self): - return self._compiler_factory() + return self._compiler_factory(self._flow) @lock_utils.locked def compile(self): if self._compiled: return - self._compilation = self._compiler.compile(self._flow) + self._compilation = self._compiler.compile() self._runtime = runtime.Runtime(self._compilation, self.storage, self.task_notifier, diff --git a/taskflow/engines/action_engine/retry_action.py b/taskflow/engines/action_engine/retry_action.py index afdfb456..e4df5afa 100644 --- a/taskflow/engines/action_engine/retry_action.py +++ b/taskflow/engines/action_engine/retry_action.py @@ -27,13 +27,16 @@ SAVE_RESULT_STATES = (states.SUCCESS, states.FAILURE) class RetryAction(object): - def __init__(self, storage, notifier): + def __init__(self, storage, notifier, walker_factory): self._storage = storage self._notifier = notifier + self._walker_factory = walker_factory def _get_retry_args(self, retry): + scope_walker = self._walker_factory(retry) kwargs = self._storage.fetch_mapped_args(retry.rebind, - atom_name=retry.name) + atom_name=retry.name, + scope_walker=scope_walker) kwargs['history'] = self._storage.get_retry_history(retry.name) return kwargs diff --git a/taskflow/engines/action_engine/runtime.py b/taskflow/engines/action_engine/runtime.py index 90913b99..c0c58367 100644 --- a/taskflow/engines/action_engine/runtime.py +++ b/taskflow/engines/action_engine/runtime.py @@ -18,6 +18,7 @@ from taskflow.engines.action_engine import analyzer as ca from taskflow.engines.action_engine import executor as ex from taskflow.engines.action_engine import retry_action as ra from taskflow.engines.action_engine import runner as ru +from taskflow.engines.action_engine import scopes as sc from taskflow.engines.action_engine import task_action as ta from taskflow import exceptions as excp from taskflow import retry as retry_atom @@ -66,12 +67,18 @@ class Runtime(object): @misc.cachedproperty def retry_action(self): - return ra.RetryAction(self.storage, self._task_notifier) + return ra.RetryAction(self._storage, self._task_notifier, + lambda atom: sc.ScopeWalker(self.compilation, + atom, + names_only=True)) @misc.cachedproperty def task_action(self): - return ta.TaskAction(self.storage, self._task_executor, - self._task_notifier) + return ta.TaskAction(self._storage, self._task_executor, + self._task_notifier, + lambda atom: sc.ScopeWalker(self.compilation, + atom, + names_only=True)) def reset_nodes(self, nodes, state=st.PENDING, intention=st.EXECUTE): for node in nodes: @@ -81,7 +88,7 @@ class 
Runtime(object): elif isinstance(node, retry_atom.Retry): self.retry_action.change_state(node, state) else: - raise TypeError("Unknown how to reset node %s, %s" + raise TypeError("Unknown how to reset atom '%s' (%s)" % (node, type(node))) if intention: self.storage.set_atom_intention(node.name, intention) @@ -209,7 +216,7 @@ class Scheduler(object): elif isinstance(node, retry_atom.Retry): return self._schedule_retry(node) else: - raise TypeError("Unknown how to schedule node %s, %s" + raise TypeError("Unknown how to schedule atom '%s' (%s)" % (node, type(node))) def _schedule_retry(self, retry): diff --git a/taskflow/engines/action_engine/scopes.py b/taskflow/engines/action_engine/scopes.py new file mode 100644 index 00000000..f1dd49d1 --- /dev/null +++ b/taskflow/engines/action_engine/scopes.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- + +# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import logging + +from taskflow import atom as atom_type +from taskflow import flow as flow_type + +LOG = logging.getLogger(__name__) + + +def _extract_atoms(node, idx=-1): + # Always go right to left, since left to right is the pattern order + # and we want to go backwards and not forwards through that ordering... + if idx == -1: + children_iter = node.reverse_iter() + else: + children_iter = reversed(node[0:idx]) + atoms = [] + for child in children_iter: + if isinstance(child.item, flow_type.Flow): + atoms.extend(_extract_atoms(child)) + elif isinstance(child.item, atom_type.Atom): + atoms.append(child.item) + else: + raise TypeError( + "Unknown extraction item '%s' (%s)" % (child.item, + type(child.item))) + return atoms + + +class ScopeWalker(object): + """Walks through the scopes of an atom using an engine's compilation. + + This will walk the visible scopes that are accessible for the given + atom, which can be used by some external entity in some meaningful way, + for example to find dependent values... + """ + + def __init__(self, compilation, atom, names_only=False): + self._node = compilation.hierarchy.find(atom) + if self._node is None: + raise ValueError("Unable to find atom '%s' in compilation" + " hierarchy" % atom) + self._atom = atom + self._graph = compilation.execution_graph + self._names_only = names_only + + def __iter__(self): + """Iterates over the visible scopes. + + How this works is the following: + + We find all the possible predecessors of the given atom; this is useful + since we know they occurred before this atom, but it doesn't tell us + the corresponding scope *level* that each predecessor was created in, + so we need to find this information. + + For that information we consult the location of the atom ``Y`` in the + node hierarchy.
We look up, in reverse order, the parent ``X`` of ``Y`` + and traverse backwards from the index in the parent where ``Y`` + occurred; all children in ``X`` that we encounter in this backwards + search (if a child is a flow itself, its atom contents will be + expanded) will be assumed to be at the same scope. This is then a + single *potential* scope; to make an *actual* scope we remove the items + from the *potential* scope that are not predecessors of ``Y``. + + Then for additional scopes we continue up the tree, by finding the + parent of ``X`` (let's call it ``Z``) and performing the same operation, + going through the children in a reverse manner from the index in + parent ``Z`` where ``X`` was located. This forms another *potential* + scope which we provide back as an *actual* scope after reducing the + potential set by the predecessors of ``Y``. We then repeat this process + until we no longer have any parent nodes (aka have reached the top of + the tree) or we run out of predecessors. + """ + predecessors = set(self._graph.bfs_predecessors_iter(self._atom)) + last = self._node + for parent in self._node.path_iter(include_self=False): + if not predecessors: + break + last_idx = parent.index(last.item) + visible = [] + for a in _extract_atoms(parent, idx=last_idx): + if a in predecessors: + predecessors.remove(a) + if not self._names_only: + visible.append(a) + else: + visible.append(a.name) + if LOG.isEnabledFor(logging.DEBUG): + if not self._names_only: + visible_names = [a.name for a in visible] + else: + visible_names = visible + # TODO(harlowja): we should likely use a created TRACE level + # for this kind of *very* verbose information; otherwise the + # cinder and other folks are going to complain that their + # debug logs are full of not-so-useful information (it is + # useful to taskflow debugging...).
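To make the walk described in the docstring above concrete, here is a minimal sketch of iterating a `ScopeWalker` over an atom nested one flow deep; it mirrors the `test_nested_prior_linear` case added by this patch's new scoping tests, and `TaskOneReturn` comes from taskflow's own test utilities:

```python
from taskflow.engines.action_engine import compiler
from taskflow.engines.action_engine import scopes as sc
from taskflow.patterns import linear_flow as lf
from taskflow.tests import utils as test_utils

# Build root(root.1, root.2, subroot(subroot.1)); 'subroot.1' plays
# the role of atom Y in the docstring above.
r = lf.Flow("root")
r.add(test_utils.TaskOneReturn("root.1"),
      test_utils.TaskOneReturn("root.2"))
sub_r = lf.Flow("subroot")
sub_r_1 = test_utils.TaskOneReturn("subroot.1")
sub_r.add(sub_r_1)
r.add(sub_r)

c = compiler.PatternCompiler(r).compile()
walker = sc.ScopeWalker(c, sub_r_1, names_only=True)
# The first scope level ('subroot', where Y has no prior siblings) is
# empty; the next level up walks the parent flow's children backwards
# from where 'subroot' sits: [[], ['root.2', 'root.1']].
print(list(walker))
```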
+ LOG.debug("Scope visible to '%s' (limited by parent '%s' index" + " < %s) is: %s", self._atom, parent.item.name, + last_idx, visible_names) + yield visible + last = parent diff --git a/taskflow/engines/action_engine/task_action.py b/taskflow/engines/action_engine/task_action.py index a07ded79..3503df7c 100644 --- a/taskflow/engines/action_engine/task_action.py +++ b/taskflow/engines/action_engine/task_action.py @@ -26,10 +26,11 @@ SAVE_RESULT_STATES = (states.SUCCESS, states.FAILURE) class TaskAction(object): - def __init__(self, storage, task_executor, notifier): + def __init__(self, storage, task_executor, notifier, walker_factory): self._storage = storage self._task_executor = task_executor self._notifier = notifier + self._walker_factory = walker_factory def _is_identity_transition(self, state, task, progress): if state in SAVE_RESULT_STATES: @@ -81,8 +82,10 @@ class TaskAction(object): def schedule_execution(self, task): self.change_state(task, states.RUNNING, progress=0.0) + scope_walker = self._walker_factory(task) kwargs = self._storage.fetch_mapped_args(task.rebind, - atom_name=task.name) + atom_name=task.name, + scope_walker=scope_walker) task_uuid = self._storage.get_atom_uuid(task.name) return self._task_executor.execute_task(task, task_uuid, kwargs, self._on_update_progress) @@ -96,8 +99,10 @@ class TaskAction(object): def schedule_reversion(self, task): self.change_state(task, states.REVERTING, progress=0.0) + scope_walker = self._walker_factory(task) kwargs = self._storage.fetch_mapped_args(task.rebind, - atom_name=task.name) + atom_name=task.name, + scope_walker=scope_walker) task_uuid = self._storage.get_atom_uuid(task.name) task_result = self._storage.get(task.name) failures = self._storage.get_failures() diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index 9717b211..658e051c 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -161,7 +161,7 @@ class Flow(flow.Flow): return self._get_subgraph().number_of_nodes() def __iter__(self): - for n in self._get_subgraph().nodes_iter(): + for n in self._get_subgraph().topological_sort(): yield n def iter_links(self): diff --git a/taskflow/storage.py b/taskflow/storage.py index 31a8868f..bcc2b157 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -31,6 +31,78 @@ from taskflow.utils import reflection LOG = logging.getLogger(__name__) STATES_WITH_RESULTS = (states.SUCCESS, states.REVERTING, states.FAILURE) +# TODO(harlowja): do this better (via a singleton or something else...) +_TRANSIENT_PROVIDER = object() + +# NOTE(harlowja): Perhaps the container is a dictionary-like object and that +# key does not exist (key error), or the container is a tuple/list and a +# non-numeric key is being requested (index error), or there was no container +# and an attempt to index into none/other unsubscriptable type is being +# requested (type error). +# +# Overall this (along with the item_from* functions) try to handle the vast +# majority of wrong indexing operations on the wrong/invalid types so that we +# can fail extraction during lookup or emit warning on result reception... +_EXTRACTION_EXCEPTIONS = (IndexError, KeyError, ValueError, TypeError) + + +class _Provider(object): + """A named symbol provider that produces a output at the given index.""" + + def __init__(self, name, index): + self.name = name + self.index = index + + def __repr__(self): + # TODO(harlowja): clean this up... + if self.name is _TRANSIENT_PROVIDER: + base = " index. 
If index is None, the whole result will have this name; else, only part of it, result[index]. """ - if not mapping: - return - self._result_mappings[atom_name] = mapping - for name, index in six.iteritems(mapping): - entries = self._reverse_mapping.setdefault(name, []) + provider_mapping = self._result_mappings.setdefault(provider_name, {}) + if mapping: + provider_mapping.update(mapping) + # Ensure the reverse mapping/index is updated (for faster lookups). + for name, index in six.iteritems(provider_mapping): + entries = self._reverse_mapping.setdefault(name, []) + provider = _Provider(provider_name, index) + if provider not in entries: + entries.append(provider) - # NOTE(imelnikov): We support setting same result mapping for - # the same atom twice (e.g when we are injecting 'a' and then - # injecting 'a' again), so we should not log warning below in - # that case and we should have only one item for each pair - # (atom_name, index) in entries. It should be put to the end of - # entries list because order matters on fetching. - try: - entries.remove((atom_name, index)) - except ValueError: - pass - - entries.append((atom_name, index)) - if len(entries) > 1: - LOG.warning("Multiple provider mappings being created for %r", - name) - - def fetch(self, name): - """Fetch a named atoms result.""" + def fetch(self, name, many_handler=None): + """Fetch a named result.""" + # By default we just return the first of many (unless provided + # a different callback that can translate many results into something + # more meaningful). + if many_handler is None: + many_handler = lambda values: values[0] with self._lock.read_lock(): try: - indexes = self._reverse_mapping[name] + providers = self._reverse_mapping[name] except KeyError: - raise exceptions.NotFound("Name %r is not mapped" % name) - # Return the first one that is found. - for (atom_name, index) in reversed(indexes): - if not atom_name: - results = self._transients + raise exceptions.NotFound("Name %r is not mapped as a" + " produced output by any" + " providers" % name) + values = [] + for provider in providers: + if provider.name is _TRANSIENT_PROVIDER: + values.append(_item_from_single(provider, + self._transients, name)) else: - results = self._get(atom_name, only_last=True) - try: - return misc.item_from(results, index, name) - except exceptions.NotFound: - pass - raise exceptions.NotFound("Unable to find result %r" % name) + try: + container = self._get(provider.name, only_last=True) + except exceptions.NotFound: + pass + else: + values.append(_item_from_single(provider, + container, name)) + if not values: + raise exceptions.NotFound("Unable to find result %r," + " searched %s" % (name, providers)) + else: + return many_handler(values) def fetch_all(self): - """Fetch all named atom results known so far. + """Fetch all named results known so far. - Should be used for debugging and testing purposes mostly. + NOTE(harlowja): should be used for debugging and testing purposes. 
""" + def many_handler(values): + if len(values) > 1: + return values + return values[0] with self._lock.read_lock(): results = {} - for name in self._reverse_mapping: + for name in six.iterkeys(self._reverse_mapping): try: - results[name] = self.fetch(name) + results[name] = self.fetch(name, many_handler=many_handler) except exceptions.NotFound: pass return results - def fetch_mapped_args(self, args_mapping, atom_name=None): - """Fetch arguments for an atom using an atoms arguments mapping.""" + def fetch_mapped_args(self, args_mapping, + atom_name=None, scope_walker=None): + """Fetch arguments for an atom using an atoms argument mapping.""" + + def _get_results(looking_for, provider): + """Gets the results saved for a given provider.""" + try: + return self._get(provider.name, only_last=True) + except exceptions.NotFound as e: + raise exceptions.NotFound( + "Expected to be able to find output %r produced" + " by %s but was unable to get at that providers" + " results" % (looking_for, provider), e) + + def _locate_providers(looking_for, possible_providers): + """Finds the accessible providers.""" + default_providers = [] + for p in possible_providers: + if p.name is _TRANSIENT_PROVIDER: + default_providers.append((p, self._transients)) + if p.name == self.injector_name: + default_providers.append((p, _get_results(looking_for, p))) + if default_providers: + return default_providers + if scope_walker is not None: + scope_iter = iter(scope_walker) + else: + scope_iter = iter([]) + for atom_names in scope_iter: + if not atom_names: + continue + providers = [] + for p in possible_providers: + if p.name in atom_names: + providers.append((p, _get_results(looking_for, p))) + if providers: + return providers + return [] + with self._lock.read_lock(): - injected_args = {} + if atom_name and atom_name not in self._atom_name_to_uuid: + raise exceptions.NotFound("Unknown atom name: %s" % atom_name) + if not args_mapping: + return {} + # The order of lookup is the following: + # + # 1. Injected atom specific arguments. + # 2. Transient injected arguments. + # 3. Non-transient injected arguments. + # 4. First scope visited group that produces the named result. + # a). The first of that group that actually provided the name + # result is selected (if group size is greater than one). + # + # Otherwise: blowup! (this will also happen if reading or + # extracting an expected result fails, since it is better to fail + # on lookup then provide invalid data from the wrong provider) if atom_name: injected_args = self._injected_args.get(atom_name, {}) + else: + injected_args = {} mapped_args = {} - for key, name in six.iteritems(args_mapping): + for (bound_name, name) in six.iteritems(args_mapping): + # TODO(harlowja): This logging information may be to verbose + # even for DEBUG mode, let's see if we can maybe in the future + # add a TRACE mode or something else if people complain... 
+ if LOG.isEnabledFor(logging.DEBUG): + if atom_name: + LOG.debug("Looking for %r <= %r for atom named: %s", + bound_name, name, atom_name) + else: + LOG.debug("Looking for %r <= %r", bound_name, name) if name in injected_args: - mapped_args[key] = injected_args[name] + value = injected_args[name] + mapped_args[bound_name] = value + LOG.debug("Matched %r <= %r to %r (from injected values)", + bound_name, name, value) else: - mapped_args[key] = self.fetch(name) + try: + possible_providers = self._reverse_mapping[name] + except KeyError: + raise exceptions.NotFound("Name %r is not mapped as a" + " produced output by any" + " providers" % name) + # Reduce the possible providers to the ones that are allowed. + providers = _locate_providers(name, possible_providers) + if not providers: + raise exceptions.NotFound( + "Mapped argument %r <= %r was not produced" + " by any accessible provider (%s possible" + " providers were scanned)" + % (bound_name, name, len(possible_providers))) + provider, value = _item_from_first_of(providers, name) + mapped_args[bound_name] = value + LOG.debug("Matched %r <= %r to %r (from %s)", + bound_name, name, value, provider) return mapped_args def set_flow_state(self, state): diff --git a/taskflow/tests/unit/action_engine/test_compile.py b/taskflow/tests/unit/action_engine/test_compile.py index 7207468e..63b3c0b0 100644 --- a/taskflow/tests/unit/action_engine/test_compile.py +++ b/taskflow/tests/unit/action_engine/test_compile.py @@ -27,21 +27,25 @@ from taskflow.tests import utils as test_utils class PatternCompileTest(test.TestCase): def test_task(self): task = test_utils.DummyTask(name='a') - compilation = compiler.PatternCompiler().compile(task) + compilation = compiler.PatternCompiler(task).compile() g = compilation.execution_graph self.assertEqual(list(g.nodes()), [task]) self.assertEqual(list(g.edges()), []) def test_retry(self): r = retry.AlwaysRevert('r1') - msg_regex = "^Retry controller: .* must only be used .*" + msg_regex = "^Retry controller .* must only be used .*" self.assertRaisesRegexp(TypeError, msg_regex, - compiler.PatternCompiler().compile, r) + compiler.PatternCompiler(r).compile) def test_wrong_object(self): - msg_regex = '^Unknown type requested to flatten' + msg_regex = '^Unknown item .* requested to flatten' self.assertRaisesRegexp(TypeError, msg_regex, - compiler.PatternCompiler().compile, 42) + compiler.PatternCompiler(42).compile) + + def test_empty(self): + flo = lf.Flow("test") + self.assertRaises(exc.Empty, compiler.PatternCompiler(flo).compile) def test_linear(self): a, b, c, d = test_utils.make_many(4) @@ -51,7 +55,7 @@ class PatternCompileTest(test.TestCase): sflo.add(d) flo.add(sflo) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(4, len(g)) @@ -69,13 +73,13 @@ class PatternCompileTest(test.TestCase): flo.add(a, b, c) flo.add(flo) self.assertRaises(ValueError, - compiler.PatternCompiler().compile, flo) + compiler.PatternCompiler(flo).compile) def test_unordered(self): a, b, c, d = test_utils.make_many(4) flo = uf.Flow("test") flo.add(a, b, c, d) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(4, len(g)) self.assertEqual(0, g.number_of_edges()) @@ -92,7 +96,7 @@ class PatternCompileTest(test.TestCase): flo2.add(c, d) flo.add(flo2) - compilation = compiler.PatternCompiler().compile(flo) + compilation =
compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(4, len(g)) @@ -116,7 +120,7 @@ class PatternCompileTest(test.TestCase): flo2.add(c, d) flo.add(flo2) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(4, len(g)) for n in [a, b]: @@ -138,7 +142,7 @@ class PatternCompileTest(test.TestCase): uf.Flow('ut').add(b, c), d) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(4, len(g)) self.assertItemsEqual(g.edges(), [ @@ -153,7 +157,7 @@ class PatternCompileTest(test.TestCase): flo = gf.Flow("test") flo.add(a, b, c, d) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(4, len(g)) self.assertEqual(0, g.number_of_edges()) @@ -167,7 +171,7 @@ class PatternCompileTest(test.TestCase): flo2.add(e, f, g) flo.add(flo2) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() graph = compilation.execution_graph self.assertEqual(7, len(graph)) self.assertItemsEqual(graph.edges(data=True), [ @@ -184,7 +188,7 @@ class PatternCompileTest(test.TestCase): flo2.add(e, f, g) flo.add(flo2) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(7, len(g)) self.assertEqual(0, g.number_of_edges()) @@ -197,7 +201,7 @@ class PatternCompileTest(test.TestCase): flo.link(b, c) flo.link(c, d) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(4, len(g)) self.assertItemsEqual(g.edges(data=True), [ @@ -213,7 +217,7 @@ class PatternCompileTest(test.TestCase): b = test_utils.ProvidesRequiresTask('b', provides=[], requires=['x']) flo = gf.Flow("test").add(a, b) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(2, len(g)) self.assertItemsEqual(g.edges(data=True), [ @@ -231,7 +235,7 @@ class PatternCompileTest(test.TestCase): lf.Flow("test2").add(b, c) ) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(3, len(g)) self.assertItemsEqual(g.edges(data=True), [ @@ -250,7 +254,7 @@ class PatternCompileTest(test.TestCase): lf.Flow("test2").add(b, c) ) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(3, len(g)) self.assertItemsEqual(g.edges(data=True), [ @@ -267,7 +271,7 @@ class PatternCompileTest(test.TestCase): ) self.assertRaisesRegexp(exc.Duplicate, '^Atoms with duplicate names', - compiler.PatternCompiler().compile, flo) + compiler.PatternCompiler(flo).compile) def test_checks_for_dups_globally(self): flo = gf.Flow("test").add( @@ -275,25 +279,25 @@ class PatternCompileTest(test.TestCase): gf.Flow("int2").add(test_utils.DummyTask(name="a"))) self.assertRaisesRegexp(exc.Duplicate, '^Atoms with duplicate names', - compiler.PatternCompiler().compile, flo) + compiler.PatternCompiler(flo).compile) def test_retry_in_linear_flow(self): flo = lf.Flow("test", retry.AlwaysRevert("c")) 
- compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(1, len(g)) self.assertEqual(0, g.number_of_edges()) def test_retry_in_unordered_flow(self): flo = uf.Flow("test", retry.AlwaysRevert("c")) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(1, len(g)) self.assertEqual(0, g.number_of_edges()) def test_retry_in_graph_flow(self): flo = gf.Flow("test", retry.AlwaysRevert("c")) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(1, len(g)) self.assertEqual(0, g.number_of_edges()) @@ -302,7 +306,7 @@ class PatternCompileTest(test.TestCase): c1 = retry.AlwaysRevert("c1") c2 = retry.AlwaysRevert("c2") flo = lf.Flow("test", c1).add(lf.Flow("test2", c2)) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(2, len(g)) @@ -317,7 +321,7 @@ class PatternCompileTest(test.TestCase): c = retry.AlwaysRevert("c") a, b = test_utils.make_many(2) flo = lf.Flow("test", c).add(a, b) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(3, len(g)) @@ -335,7 +339,7 @@ class PatternCompileTest(test.TestCase): c = retry.AlwaysRevert("c") a, b = test_utils.make_many(2) flo = uf.Flow("test", c).add(a, b) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(3, len(g)) @@ -353,7 +357,7 @@ class PatternCompileTest(test.TestCase): r = retry.AlwaysRevert("cp") a, b, c = test_utils.make_many(3) flo = gf.Flow("test", r).add(a, b, c).link(b, c) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(4, len(g)) @@ -377,7 +381,7 @@ class PatternCompileTest(test.TestCase): a, lf.Flow("test", c2).add(b, c), d) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(6, len(g)) @@ -402,7 +406,7 @@ class PatternCompileTest(test.TestCase): a, lf.Flow("test").add(b, c), d) - compilation = compiler.PatternCompiler().compile(flo) + compilation = compiler.PatternCompiler(flo).compile() g = compilation.execution_graph self.assertEqual(5, len(g)) diff --git a/taskflow/tests/unit/action_engine/test_runner.py b/taskflow/tests/unit/action_engine/test_runner.py index 2e18f6b6..82440fc5 100644 --- a/taskflow/tests/unit/action_engine/test_runner.py +++ b/taskflow/tests/unit/action_engine/test_runner.py @@ -33,7 +33,7 @@ from taskflow.utils import persistence_utils as pu class _RunnerTestMixin(object): def _make_runtime(self, flow, initial_state=None): - compilation = compiler.PatternCompiler().compile(flow) + compilation = compiler.PatternCompiler(flow).compile() flow_detail = pu.create_flow_detail(flow) store = storage.SingleThreadedStorage(flow_detail) # This ensures the tasks exist in storage... 
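The new scoping tests that follow lean on the fact that a compilation now retains two views of the same flow: the flattened execution graph and the flow/atom containment hierarchy. A small sketch of inspecting both, reusing the `DummyTask` helper exercised in `test_compile.py` above (the flow names here are illustrative):

```python
from taskflow.engines.action_engine import compiler
from taskflow.patterns import linear_flow as lf
from taskflow.tests import utils as test_utils

flo = lf.Flow("root")
flo.add(test_utils.DummyTask(name='a'),
        lf.Flow("sub").add(test_utils.DummyTask(name='b')))

c = compiler.PatternCompiler(flo).compile()
# The flattened dependency view (atoms plus ordering edges)...
print([a.name for a in c.execution_graph.topological_sort()])
# ...and the retained containment view (flows still wrapping atoms),
# pretty-printed the same way the compiler's debug logging does.
print(c.hierarchy.pformat())
```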
diff --git a/taskflow/tests/unit/test_action_engine_scoping.py b/taskflow/tests/unit/test_action_engine_scoping.py new file mode 100644 index 00000000..e2de763f --- /dev/null +++ b/taskflow/tests/unit/test_action_engine_scoping.py @@ -0,0 +1,248 @@ +# -*- coding: utf-8 -*- + +# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from taskflow.engines.action_engine import compiler +from taskflow.engines.action_engine import scopes as sc +from taskflow.patterns import graph_flow as gf +from taskflow.patterns import linear_flow as lf +from taskflow.patterns import unordered_flow as uf +from taskflow import test +from taskflow.tests import utils as test_utils + + +def _get_scopes(compilation, atom, names_only=True): + walker = sc.ScopeWalker(compilation, atom, names_only=names_only) + return list(iter(walker)) + + +class LinearScopingTest(test.TestCase): + def test_unknown(self): + r = lf.Flow("root") + r_1 = test_utils.TaskOneReturn("root.1") + r.add(r_1) + + r_2 = test_utils.TaskOneReturn("root.2") + c = compiler.PatternCompiler(r).compile() + self.assertRaises(ValueError, _get_scopes, c, r_2) + + def test_empty(self): + r = lf.Flow("root") + r_1 = test_utils.TaskOneReturn("root.1") + r.add(r_1) + + c = compiler.PatternCompiler(r).compile() + self.assertIn(r_1, c.execution_graph) + self.assertIsNotNone(c.hierarchy.find(r_1)) + + walker = sc.ScopeWalker(c, r_1) + scopes = list(walker) + self.assertEqual([], scopes) + + def test_single_prior_linear(self): + r = lf.Flow("root") + r_1 = test_utils.TaskOneReturn("root.1") + r_2 = test_utils.TaskOneReturn("root.2") + r.add(r_1, r_2) + + c = compiler.PatternCompiler(r).compile() + for a in r: + self.assertIn(a, c.execution_graph) + self.assertIsNotNone(c.hierarchy.find(a)) + + self.assertEqual([], _get_scopes(c, r_1)) + self.assertEqual([['root.1']], _get_scopes(c, r_2)) + + def test_nested_prior_linear(self): + r = lf.Flow("root") + r.add(test_utils.TaskOneReturn("root.1"), + test_utils.TaskOneReturn("root.2")) + sub_r = lf.Flow("subroot") + sub_r_1 = test_utils.TaskOneReturn("subroot.1") + sub_r.add(sub_r_1) + r.add(sub_r) + + c = compiler.PatternCompiler(r).compile() + self.assertEqual([[], ['root.2', 'root.1']], _get_scopes(c, sub_r_1)) + + def test_nested_prior_linear_begin_middle_end(self): + r = lf.Flow("root") + begin_r = test_utils.TaskOneReturn("root.1") + r.add(begin_r, test_utils.TaskOneReturn("root.2")) + middle_r = test_utils.TaskOneReturn("root.3") + r.add(middle_r) + sub_r = lf.Flow("subroot") + sub_r.add(test_utils.TaskOneReturn("subroot.1"), + test_utils.TaskOneReturn("subroot.2")) + r.add(sub_r) + end_r = test_utils.TaskOneReturn("root.4") + r.add(end_r) + + c = compiler.PatternCompiler(r).compile() + + self.assertEqual([], _get_scopes(c, begin_r)) + self.assertEqual([['root.2', 'root.1']], _get_scopes(c, middle_r)) + self.assertEqual([['subroot.2', 'subroot.1', 'root.3', 'root.2', + 'root.1']], _get_scopes(c, end_r)) + + +class GraphScopingTest(test.TestCase): + def 
test_dependent(self): + r = gf.Flow("root") + + customer = test_utils.ProvidesRequiresTask("customer", + provides=['dog'], + requires=[]) + washer = test_utils.ProvidesRequiresTask("washer", + requires=['dog'], + provides=['wash']) + dryer = test_utils.ProvidesRequiresTask("dryer", + requires=['dog', 'wash'], + provides=['dry_dog']) + shaved = test_utils.ProvidesRequiresTask("shaver", + requires=['dry_dog'], + provides=['shaved_dog']) + happy_customer = test_utils.ProvidesRequiresTask( + "happy_customer", requires=['shaved_dog'], provides=['happiness']) + + r.add(customer, washer, dryer, shaved, happy_customer) + + c = compiler.PatternCompiler(r).compile() + + self.assertEqual([], _get_scopes(c, customer)) + self.assertEqual([['washer', 'customer']], _get_scopes(c, dryer)) + self.assertEqual([['shaver', 'dryer', 'washer', 'customer']], + _get_scopes(c, happy_customer)) + + def test_no_visible(self): + r = gf.Flow("root") + atoms = [] + for i in range(0, 10): + atoms.append(test_utils.TaskOneReturn("root.%s" % i)) + r.add(*atoms) + + c = compiler.PatternCompiler(r).compile() + for a in atoms: + self.assertEqual([], _get_scopes(c, a)) + + def test_nested(self): + r = gf.Flow("root") + + r_1 = test_utils.TaskOneReturn("root.1") + r_2 = test_utils.TaskOneReturn("root.2") + r.add(r_1, r_2) + r.link(r_1, r_2) + + subroot = gf.Flow("subroot") + subroot_r_1 = test_utils.TaskOneReturn("subroot.1") + subroot_r_2 = test_utils.TaskOneReturn("subroot.2") + subroot.add(subroot_r_1, subroot_r_2) + subroot.link(subroot_r_1, subroot_r_2) + + r.add(subroot) + r_3 = test_utils.TaskOneReturn("root.3") + r.add(r_3) + r.link(r_2, r_3) + + c = compiler.PatternCompiler(r).compile() + self.assertEqual([], _get_scopes(c, r_1)) + self.assertEqual([['root.1']], _get_scopes(c, r_2)) + self.assertEqual([['root.2', 'root.1']], _get_scopes(c, r_3)) + + self.assertEqual([], _get_scopes(c, subroot_r_1)) + self.assertEqual([['subroot.1']], _get_scopes(c, subroot_r_2)) + + +class UnorderedScopingTest(test.TestCase): + def test_no_visible(self): + r = uf.Flow("root") + atoms = [] + for i in range(0, 10): + atoms.append(test_utils.TaskOneReturn("root.%s" % i)) + r.add(*atoms) + c = compiler.PatternCompiler(r).compile() + for a in atoms: + self.assertEqual([], _get_scopes(c, a)) + + +class MixedPatternScopingTest(test.TestCase): + def test_graph_linear_scope(self): + r = gf.Flow("root") + r_1 = test_utils.TaskOneReturn("root.1") + r_2 = test_utils.TaskOneReturn("root.2") + r.add(r_1, r_2) + r.link(r_1, r_2) + + s = lf.Flow("subroot") + s_1 = test_utils.TaskOneReturn("subroot.1") + s_2 = test_utils.TaskOneReturn("subroot.2") + s.add(s_1, s_2) + r.add(s) + + t = gf.Flow("subroot2") + t_1 = test_utils.TaskOneReturn("subroot2.1") + t_2 = test_utils.TaskOneReturn("subroot2.2") + t.add(t_1, t_2) + t.link(t_1, t_2) + r.add(t) + r.link(s, t) + + c = compiler.PatternCompiler(r).compile() + self.assertEqual([], _get_scopes(c, r_1)) + self.assertEqual([['root.1']], _get_scopes(c, r_2)) + self.assertEqual([], _get_scopes(c, s_1)) + self.assertEqual([['subroot.1']], _get_scopes(c, s_2)) + self.assertEqual([[], ['subroot.2', 'subroot.1']], + _get_scopes(c, t_1)) + self.assertEqual([["subroot2.1"], ['subroot.2', 'subroot.1']], + _get_scopes(c, t_2)) + + def test_linear_unordered_scope(self): + r = lf.Flow("root") + r_1 = test_utils.TaskOneReturn("root.1") + r_2 = test_utils.TaskOneReturn("root.2") + r.add(r_1, r_2) + + u = uf.Flow("subroot") + atoms = [] + for i in range(0, 5): + atoms.append(test_utils.TaskOneReturn("subroot.%s" % i)) + 
u.add(*atoms) + r.add(u) + + r_3 = test_utils.TaskOneReturn("root.3") + r.add(r_3) + + c = compiler.PatternCompiler(r).compile() + + self.assertEqual([], _get_scopes(c, r_1)) + self.assertEqual([['root.1']], _get_scopes(c, r_2)) + for a in atoms: + self.assertEqual([[], ['root.2', 'root.1']], _get_scopes(c, a)) + + scope = _get_scopes(c, r_3) + self.assertEqual(1, len(scope)) + first_root = 0 + for i, n in enumerate(scope[0]): + if n.startswith('root.'): + first_root = i + break + first_subroot = 0 + for i, n in enumerate(scope[0]): + if n.startswith('subroot.'): + first_subroot = i + break + self.assertGreater(first_subroot, first_root) + self.assertEqual(scope[0][-2:], ['root.2', 'root.1']) diff --git a/taskflow/tests/unit/test_storage.py b/taskflow/tests/unit/test_storage.py index 94d73012..3c74bebe 100644 --- a/taskflow/tests/unit/test_storage.py +++ b/taskflow/tests/unit/test_storage.py @@ -453,23 +453,6 @@ class StorageTestMixin(object): self.assertRaisesRegexp(exceptions.NotFound, '^Unable to find result', s.fetch, 'b') - @mock.patch.object(storage.LOG, 'warning') - def test_multiple_providers_are_checked(self, mocked_warning): - s = self._get_storage() - s.ensure_task('my task', result_mapping={'result': 'key'}) - self.assertEqual(mocked_warning.mock_calls, []) - s.ensure_task('my other task', result_mapping={'result': 'key'}) - mocked_warning.assert_called_once_with( - mock.ANY, 'result') - - @mock.patch.object(storage.LOG, 'warning') - def test_multiple_providers_with_inject_are_checked(self, mocked_warning): - s = self._get_storage() - s.inject({'result': 'DONE'}) - self.assertEqual(mocked_warning.mock_calls, []) - s.ensure_task('my other task', result_mapping={'result': 'key'}) - mocked_warning.assert_called_once_with(mock.ANY, 'result') - def test_ensure_retry(self): s = self._get_storage() s.ensure_retry('my retry') diff --git a/taskflow/types/tree.py b/taskflow/types/tree.py index 759e0fda..1c520966 100644 --- a/taskflow/types/tree.py +++ b/taskflow/types/tree.py @@ -167,6 +167,11 @@ class Node(object): for c in self._children: yield c + def reverse_iter(self): + """Iterates over the direct children of this node (right->left).""" + for c in reversed(self._children): + yield c + def index(self, item): """Finds the child index of a given item, searches in added order.""" index_at = None diff --git a/taskflow/utils/misc.py b/taskflow/utils/misc.py index d05e68db..fd3fc5e7 100644 --- a/taskflow/utils/misc.py +++ b/taskflow/utils/misc.py @@ -268,24 +268,6 @@ def sequence_minus(seq1, seq2): return result -def item_from(container, index, name=None): - """Attempts to fetch a index/key from a given container.""" - if index is None: - return container - try: - return container[index] - except (IndexError, KeyError, ValueError, TypeError): - # NOTE(harlowja): Perhaps the container is a dictionary-like object - # and that key does not exist (key error), or the container is a - # tuple/list and a non-numeric key is being requested (index error), - # or there was no container and an attempt to index into none/other - # unsubscriptable type is being requested (type error).
- if name is None: - name = index - raise exc.NotFound("Unable to find %r in container %s" - % (name, container)) - - def get_duplicate_keys(iterable, key=None): if key is not None: iterable = compat_map(key, iterable) @@ -410,8 +392,8 @@ def ensure_tree(path): """ try: os.makedirs(path) - except OSError as exc: - if exc.errno == errno.EEXIST: + except OSError as e: + if e.errno == errno.EEXIST: if not os.path.isdir(path): raise else:
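Taken together, the storage changes above make multiple providers of one name legal: the scope walker, rather than a warning at mapping time, now decides which value wins (hence the removal of the warning tests earlier in this patch). A rough sketch of the resulting `fetch()` behavior, following the storage setup pattern used in `test_runner.py` above (the task names here are illustrative):

```python
from taskflow.patterns import linear_flow as lf
from taskflow import storage
from taskflow.utils import persistence_utils as pu

flow_detail = pu.create_flow_detail(lf.Flow("demo"))
store = storage.SingleThreadedStorage(flow_detail)

# Two different tasks both mapped to produce a result named 'x'
# (index None means the whole saved result carries that name).
store.ensure_task('t1', result_mapping={'x': None})
store.ensure_task('t2', result_mapping={'x': None})
store.save('t1', 1)
store.save('t2', 2)

print(store.fetch('x'))                     # default: first value found
print(store.fetch('x', many_handler=list))  # all candidate values
```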