Revamp the symbol lookup mechanism
To complement the future changes in patterns we also want to allow the execution of patterns to be affected in a similar manner so that symbol lookup is no longer as confined as it was. This change adds in the following: - Symbol lookup by walking through an atoms contained scope/s. - Better error messaging when symbols are not found. - Adjusted & new tests (existing ones work). - Better logging of the symbol lookup mechanism (helpful during debugging, although it is very verbose...) Part of blueprint taskflow-improved-scoping Change-Id: Id921a4abd9bf2b7b5c5a762337f8e90e8f1fe194
This commit is contained in:
@@ -15,53 +15,34 @@
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from taskflow import exceptions as exc
|
||||
from taskflow import flow
|
||||
from taskflow import retry
|
||||
from taskflow import task
|
||||
from taskflow.types import graph as gr
|
||||
from taskflow.types import tree as tr
|
||||
from taskflow.utils import lock_utils
|
||||
from taskflow.utils import misc
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Compilation(object):
|
||||
"""The result of a compilers compile() is this *immutable* object.
|
||||
"""The result of a compilers compile() is this *immutable* object."""
|
||||
|
||||
For now it is just a execution graph but in the future it will grow to
|
||||
include more methods & properties that help the various runtime units
|
||||
execute in a more optimal & featureful manner.
|
||||
"""
|
||||
def __init__(self, execution_graph):
|
||||
def __init__(self, execution_graph, hierarchy):
|
||||
self._execution_graph = execution_graph
|
||||
self._hierarchy = hierarchy
|
||||
|
||||
@property
|
||||
def execution_graph(self):
|
||||
return self._execution_graph
|
||||
|
||||
|
||||
class PatternCompiler(object):
|
||||
"""Compiles patterns & atoms into a compilation unit.
|
||||
|
||||
NOTE(harlowja): during this pattern translation process any nested flows
|
||||
will be converted into there equivalent subgraphs. This currently implies
|
||||
that contained atoms in those nested flows, post-translation will no longer
|
||||
be associated with there previously containing flow but instead will lose
|
||||
this identity and what will remain is the logical constraints that there
|
||||
contained flow mandated. In the future this may be changed so that this
|
||||
association is not lost via the compilation process (since it can be
|
||||
useful to retain this relationship).
|
||||
"""
|
||||
def compile(self, root):
|
||||
graph = _Flattener(root).flatten()
|
||||
if graph.number_of_nodes() == 0:
|
||||
# Try to get a name attribute, otherwise just use the object
|
||||
# string representation directly if that attribute does not exist.
|
||||
name = getattr(root, 'name', root)
|
||||
raise exc.Empty("Root container '%s' (%s) is empty."
|
||||
% (name, type(root)))
|
||||
return Compilation(graph)
|
||||
@property
|
||||
def hierarchy(self):
|
||||
return self._hierarchy
|
||||
|
||||
|
||||
_RETRY_EDGE_DATA = {
|
||||
@@ -69,14 +50,15 @@ _RETRY_EDGE_DATA = {
|
||||
}
|
||||
|
||||
|
||||
class _Flattener(object):
|
||||
"""Flattens a root item (task/flow) into a execution graph."""
|
||||
class PatternCompiler(object):
|
||||
"""Compiles a pattern (or task) into a compilation unit."""
|
||||
|
||||
def __init__(self, root, freeze=True):
|
||||
self._root = root
|
||||
self._graph = None
|
||||
self._history = set()
|
||||
self._freeze = bool(freeze)
|
||||
self._freeze = freeze
|
||||
self._lock = threading.Lock()
|
||||
self._compilation = None
|
||||
|
||||
def _add_new_edges(self, graph, nodes_from, nodes_to, edge_attrs):
|
||||
"""Adds new edges from nodes to other nodes in the specified graph.
|
||||
@@ -93,72 +75,74 @@ class _Flattener(object):
|
||||
# if it's later modified that the same copy isn't modified.
|
||||
graph.add_edge(u, v, attr_dict=edge_attrs.copy())
|
||||
|
||||
def _flatten(self, item):
|
||||
functor = self._find_flattener(item)
|
||||
if not functor:
|
||||
raise TypeError("Unknown type requested to flatten: %s (%s)"
|
||||
% (item, type(item)))
|
||||
def _flatten(self, item, parent):
|
||||
functor = self._find_flattener(item, parent)
|
||||
self._pre_item_flatten(item)
|
||||
graph = functor(item)
|
||||
self._post_item_flatten(item, graph)
|
||||
return graph
|
||||
graph, node = functor(item, parent)
|
||||
self._post_item_flatten(item, graph, node)
|
||||
return graph, node
|
||||
|
||||
def _find_flattener(self, item):
|
||||
def _find_flattener(self, item, parent):
|
||||
"""Locates the flattening function to use to flatten the given item."""
|
||||
if isinstance(item, flow.Flow):
|
||||
return self._flatten_flow
|
||||
elif isinstance(item, task.BaseTask):
|
||||
return self._flatten_task
|
||||
elif isinstance(item, retry.Retry):
|
||||
if len(self._history) == 1:
|
||||
raise TypeError("Retry controller: %s (%s) must only be used"
|
||||
if parent is None:
|
||||
raise TypeError("Retry controller '%s' (%s) must only be used"
|
||||
" as a flow constructor parameter and not as a"
|
||||
" root component" % (item, type(item)))
|
||||
else:
|
||||
# TODO(harlowja): we should raise this type error earlier
|
||||
# instead of later since we should do this same check on add()
|
||||
# calls, this makes the error more visible (instead of waiting
|
||||
# until compile time).
|
||||
raise TypeError("Retry controller: %s (%s) must only be used"
|
||||
raise TypeError("Retry controller '%s' (%s) must only be used"
|
||||
" as a flow constructor parameter and not as a"
|
||||
" flow added component" % (item, type(item)))
|
||||
else:
|
||||
return None
|
||||
raise TypeError("Unknown item '%s' (%s) requested to flatten"
|
||||
% (item, type(item)))
|
||||
|
||||
def _connect_retry(self, retry, graph):
|
||||
graph.add_node(retry)
|
||||
|
||||
# All graph nodes that have no predecessors should depend on its retry
|
||||
nodes_to = [n for n in graph.no_predecessors_iter() if n != retry]
|
||||
# All nodes that have no predecessors should depend on this retry.
|
||||
nodes_to = [n for n in graph.no_predecessors_iter() if n is not retry]
|
||||
self._add_new_edges(graph, [retry], nodes_to, _RETRY_EDGE_DATA)
|
||||
|
||||
# Add link to retry for each node of subgraph that hasn't
|
||||
# a parent retry
|
||||
# Add association for each node of graph that has no existing retry.
|
||||
for n in graph.nodes_iter():
|
||||
if n != retry and 'retry' not in graph.node[n]:
|
||||
if n is not retry and 'retry' not in graph.node[n]:
|
||||
graph.node[n]['retry'] = retry
|
||||
|
||||
def _flatten_task(self, task):
|
||||
def _flatten_task(self, task, parent):
|
||||
"""Flattens a individual task."""
|
||||
graph = gr.DiGraph(name=task.name)
|
||||
graph.add_node(task)
|
||||
return graph
|
||||
node = tr.Node(task)
|
||||
if parent is not None:
|
||||
parent.add(node)
|
||||
return graph, node
|
||||
|
||||
def _flatten_flow(self, flow):
|
||||
"""Flattens a graph flow."""
|
||||
def _flatten_flow(self, flow, parent):
|
||||
"""Flattens a flow."""
|
||||
graph = gr.DiGraph(name=flow.name)
|
||||
node = tr.Node(flow)
|
||||
if parent is not None:
|
||||
parent.add(node)
|
||||
if flow.retry is not None:
|
||||
node.add(tr.Node(flow.retry))
|
||||
|
||||
# Flatten all nodes into a single subgraph per node.
|
||||
subgraph_map = {}
|
||||
# Flatten all nodes into a single subgraph per item (and track origin
|
||||
# item to its newly expanded graph).
|
||||
subgraphs = {}
|
||||
for item in flow:
|
||||
subgraph = self._flatten(item)
|
||||
subgraph_map[item] = subgraph
|
||||
subgraph = self._flatten(item, node)[0]
|
||||
subgraphs[item] = subgraph
|
||||
graph = gr.merge_graphs([graph, subgraph])
|
||||
|
||||
# Reconnect all node edges to their corresponding subgraphs.
|
||||
# Reconnect all items edges to their corresponding subgraphs.
|
||||
for (u, v, attrs) in flow.iter_links():
|
||||
u_g = subgraph_map[u]
|
||||
v_g = subgraph_map[v]
|
||||
u_g = subgraphs[u]
|
||||
v_g = subgraphs[v]
|
||||
if any(attrs.get(k) for k in ('invariant', 'manual', 'retry')):
|
||||
# Connect nodes with no predecessors in v to nodes with
|
||||
# no successors in u (thus maintaining the edge dependency).
|
||||
@@ -177,48 +161,57 @@ class _Flattener(object):
|
||||
|
||||
if flow.retry is not None:
|
||||
self._connect_retry(flow.retry, graph)
|
||||
return graph
|
||||
return graph, node
|
||||
|
||||
def _pre_item_flatten(self, item):
|
||||
"""Called before a item is flattened; any pre-flattening actions."""
|
||||
if id(item) in self._history:
|
||||
raise ValueError("Already flattened item: %s (%s), recursive"
|
||||
" flattening not supported" % (item, id(item)))
|
||||
self._history.add(id(item))
|
||||
if item in self._history:
|
||||
raise ValueError("Already flattened item '%s' (%s), recursive"
|
||||
" flattening is not supported" % (item,
|
||||
type(item)))
|
||||
self._history.add(item)
|
||||
|
||||
def _post_item_flatten(self, item, graph):
|
||||
"""Called before a item is flattened; any post-flattening actions."""
|
||||
def _post_item_flatten(self, item, graph, node):
|
||||
"""Called after a item is flattened; doing post-flattening actions."""
|
||||
|
||||
def _pre_flatten(self):
|
||||
"""Called before the flattening of the item starts."""
|
||||
"""Called before the flattening of the root starts."""
|
||||
self._history.clear()
|
||||
|
||||
def _post_flatten(self, graph):
|
||||
"""Called after the flattening of the item finishes successfully."""
|
||||
def _post_flatten(self, graph, node):
|
||||
"""Called after the flattening of the root finishes successfully."""
|
||||
dup_names = misc.get_duplicate_keys(graph.nodes_iter(),
|
||||
key=lambda node: node.name)
|
||||
if dup_names:
|
||||
dup_names = ', '.join(sorted(dup_names))
|
||||
raise exc.Duplicate("Atoms with duplicate names "
|
||||
"found: %s" % (dup_names))
|
||||
raise exc.Duplicate(
|
||||
"Atoms with duplicate names found: %s" % (sorted(dup_names)))
|
||||
if graph.number_of_nodes() == 0:
|
||||
raise exc.Empty("Root container '%s' (%s) is empty"
|
||||
% (self._root, type(self._root)))
|
||||
self._history.clear()
|
||||
# NOTE(harlowja): this one can be expensive to calculate (especially
|
||||
# the cycle detection), so only do it if we know debugging is enabled
|
||||
# and not under all cases.
|
||||
if LOG.isEnabledFor(logging.DEBUG):
|
||||
LOG.debug("Translated '%s' into a graph:", self._root)
|
||||
LOG.debug("Translated '%s'", self._root)
|
||||
LOG.debug("Graph:")
|
||||
for line in graph.pformat().splitlines():
|
||||
# Indent it so that it's slightly offset from the above line.
|
||||
LOG.debug(" %s", line)
|
||||
LOG.debug(" %s", line)
|
||||
LOG.debug("Hierarchy:")
|
||||
for line in node.pformat().splitlines():
|
||||
# Indent it so that it's slightly offset from the above line.
|
||||
LOG.debug(" %s", line)
|
||||
|
||||
def flatten(self):
|
||||
"""Flattens a item (a task or flow) into a single execution graph."""
|
||||
if self._graph is not None:
|
||||
return self._graph
|
||||
self._pre_flatten()
|
||||
graph = self._flatten(self._root)
|
||||
self._post_flatten(graph)
|
||||
self._graph = graph
|
||||
if self._freeze:
|
||||
self._graph.freeze()
|
||||
return self._graph
|
||||
@lock_utils.locked
|
||||
def compile(self):
|
||||
"""Compiles the contained item into a compiled equivalent."""
|
||||
if self._compilation is None:
|
||||
self._pre_flatten()
|
||||
graph, node = self._flatten(self._root, None)
|
||||
self._post_flatten(graph, node)
|
||||
if self._freeze:
|
||||
graph.freeze()
|
||||
node.freeze()
|
||||
self._compilation = Compilation(graph, node)
|
||||
return self._compilation
|
||||
|
||||
@@ -210,13 +210,13 @@ class ActionEngine(base.EngineBase):
|
||||
|
||||
@misc.cachedproperty
|
||||
def _compiler(self):
|
||||
return self._compiler_factory()
|
||||
return self._compiler_factory(self._flow)
|
||||
|
||||
@lock_utils.locked
|
||||
def compile(self):
|
||||
if self._compiled:
|
||||
return
|
||||
self._compilation = self._compiler.compile(self._flow)
|
||||
self._compilation = self._compiler.compile()
|
||||
self._runtime = runtime.Runtime(self._compilation,
|
||||
self.storage,
|
||||
self.task_notifier,
|
||||
|
||||
@@ -27,13 +27,16 @@ SAVE_RESULT_STATES = (states.SUCCESS, states.FAILURE)
|
||||
|
||||
|
||||
class RetryAction(object):
|
||||
def __init__(self, storage, notifier):
|
||||
def __init__(self, storage, notifier, walker_factory):
|
||||
self._storage = storage
|
||||
self._notifier = notifier
|
||||
self._walker_factory = walker_factory
|
||||
|
||||
def _get_retry_args(self, retry):
|
||||
scope_walker = self._walker_factory(retry)
|
||||
kwargs = self._storage.fetch_mapped_args(retry.rebind,
|
||||
atom_name=retry.name)
|
||||
atom_name=retry.name,
|
||||
scope_walker=scope_walker)
|
||||
kwargs['history'] = self._storage.get_retry_history(retry.name)
|
||||
return kwargs
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from taskflow.engines.action_engine import analyzer as ca
|
||||
from taskflow.engines.action_engine import executor as ex
|
||||
from taskflow.engines.action_engine import retry_action as ra
|
||||
from taskflow.engines.action_engine import runner as ru
|
||||
from taskflow.engines.action_engine import scopes as sc
|
||||
from taskflow.engines.action_engine import task_action as ta
|
||||
from taskflow import exceptions as excp
|
||||
from taskflow import retry as retry_atom
|
||||
@@ -66,12 +67,18 @@ class Runtime(object):
|
||||
|
||||
@misc.cachedproperty
|
||||
def retry_action(self):
|
||||
return ra.RetryAction(self.storage, self._task_notifier)
|
||||
return ra.RetryAction(self._storage, self._task_notifier,
|
||||
lambda atom: sc.ScopeWalker(self.compilation,
|
||||
atom,
|
||||
names_only=True))
|
||||
|
||||
@misc.cachedproperty
|
||||
def task_action(self):
|
||||
return ta.TaskAction(self.storage, self._task_executor,
|
||||
self._task_notifier)
|
||||
return ta.TaskAction(self._storage, self._task_executor,
|
||||
self._task_notifier,
|
||||
lambda atom: sc.ScopeWalker(self.compilation,
|
||||
atom,
|
||||
names_only=True))
|
||||
|
||||
def reset_nodes(self, nodes, state=st.PENDING, intention=st.EXECUTE):
|
||||
for node in nodes:
|
||||
@@ -81,7 +88,7 @@ class Runtime(object):
|
||||
elif isinstance(node, retry_atom.Retry):
|
||||
self.retry_action.change_state(node, state)
|
||||
else:
|
||||
raise TypeError("Unknown how to reset node %s, %s"
|
||||
raise TypeError("Unknown how to reset atom '%s' (%s)"
|
||||
% (node, type(node)))
|
||||
if intention:
|
||||
self.storage.set_atom_intention(node.name, intention)
|
||||
@@ -209,7 +216,7 @@ class Scheduler(object):
|
||||
elif isinstance(node, retry_atom.Retry):
|
||||
return self._schedule_retry(node)
|
||||
else:
|
||||
raise TypeError("Unknown how to schedule node %s, %s"
|
||||
raise TypeError("Unknown how to schedule atom '%s' (%s)"
|
||||
% (node, type(node)))
|
||||
|
||||
def _schedule_retry(self, retry):
|
||||
|
||||
119
taskflow/engines/action_engine/scopes.py
Normal file
119
taskflow/engines/action_engine/scopes.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from taskflow import atom as atom_type
|
||||
from taskflow import flow as flow_type
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_atoms(node, idx=-1):
|
||||
# Always go left to right, since right to left is the pattern order
|
||||
# and we want to go backwards and not forwards through that ordering...
|
||||
if idx == -1:
|
||||
children_iter = node.reverse_iter()
|
||||
else:
|
||||
children_iter = reversed(node[0:idx])
|
||||
atoms = []
|
||||
for child in children_iter:
|
||||
if isinstance(child.item, flow_type.Flow):
|
||||
atoms.extend(_extract_atoms(child))
|
||||
elif isinstance(child.item, atom_type.Atom):
|
||||
atoms.append(child.item)
|
||||
else:
|
||||
raise TypeError(
|
||||
"Unknown extraction item '%s' (%s)" % (child.item,
|
||||
type(child.item)))
|
||||
return atoms
|
||||
|
||||
|
||||
class ScopeWalker(object):
|
||||
"""Walks through the scopes of a atom using a engines compilation.
|
||||
|
||||
This will walk the visible scopes that are accessible for the given
|
||||
atom, which can be used by some external entity in some meaningful way,
|
||||
for example to find dependent values...
|
||||
"""
|
||||
|
||||
def __init__(self, compilation, atom, names_only=False):
|
||||
self._node = compilation.hierarchy.find(atom)
|
||||
if self._node is None:
|
||||
raise ValueError("Unable to find atom '%s' in compilation"
|
||||
" hierarchy" % atom)
|
||||
self._atom = atom
|
||||
self._graph = compilation.execution_graph
|
||||
self._names_only = names_only
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterates over the visible scopes.
|
||||
|
||||
How this works is the following:
|
||||
|
||||
We find all the possible predecessors of the given atom, this is useful
|
||||
since we know they occurred before this atom but it doesn't tell us
|
||||
the corresponding scope *level* that each predecessor was created in,
|
||||
so we need to find this information.
|
||||
|
||||
For that information we consult the location of the atom ``Y`` in the
|
||||
node hierarchy. We lookup in a reverse order the parent ``X`` of ``Y``
|
||||
and traverse backwards from the index in the parent where ``Y``
|
||||
occurred, all children in ``X`` that we encounter in this backwards
|
||||
search (if a child is a flow itself, its atom contents will be
|
||||
expanded) will be assumed to be at the same scope. This is then a
|
||||
*potential* single scope, to make an *actual* scope we remove the items
|
||||
from the *potential* scope that are not predecessors of ``Y`` to form
|
||||
the *actual* scope.
|
||||
|
||||
Then for additional scopes we continue up the tree, by finding the
|
||||
parent of ``X`` (lets call it ``Z``) and perform the same operation,
|
||||
going through the children in a reverse manner from the index in
|
||||
parent ``Z`` where ``X`` was located. This forms another *potential*
|
||||
scope which we provide back as an *actual* scope after reducing the
|
||||
potential set by the predecessors of ``Y``. We then repeat this process
|
||||
until we no longer have any parent nodes (aka have reached the top of
|
||||
the tree) or we run out of predecessors.
|
||||
"""
|
||||
predecessors = set(self._graph.bfs_predecessors_iter(self._atom))
|
||||
last = self._node
|
||||
for parent in self._node.path_iter(include_self=False):
|
||||
if not predecessors:
|
||||
break
|
||||
last_idx = parent.index(last.item)
|
||||
visible = []
|
||||
for a in _extract_atoms(parent, idx=last_idx):
|
||||
if a in predecessors:
|
||||
predecessors.remove(a)
|
||||
if not self._names_only:
|
||||
visible.append(a)
|
||||
else:
|
||||
visible.append(a.name)
|
||||
if LOG.isEnabledFor(logging.DEBUG):
|
||||
if not self._names_only:
|
||||
visible_names = [a.name for a in visible]
|
||||
else:
|
||||
visible_names = visible
|
||||
# TODO(harlowja): we should likely use a created TRACE level
|
||||
# for this kind of *very* verbose information; otherwise the
|
||||
# cinder and other folks are going to complain that there
|
||||
# debug logs are full of not so useful information (it is
|
||||
# useful to taskflow debugging...).
|
||||
LOG.debug("Scope visible to '%s' (limited by parent '%s' index"
|
||||
" < %s) is: %s", self._atom, parent.item.name,
|
||||
last_idx, visible_names)
|
||||
yield visible
|
||||
last = parent
|
||||
@@ -26,10 +26,11 @@ SAVE_RESULT_STATES = (states.SUCCESS, states.FAILURE)
|
||||
|
||||
class TaskAction(object):
|
||||
|
||||
def __init__(self, storage, task_executor, notifier):
|
||||
def __init__(self, storage, task_executor, notifier, walker_factory):
|
||||
self._storage = storage
|
||||
self._task_executor = task_executor
|
||||
self._notifier = notifier
|
||||
self._walker_factory = walker_factory
|
||||
|
||||
def _is_identity_transition(self, state, task, progress):
|
||||
if state in SAVE_RESULT_STATES:
|
||||
@@ -81,8 +82,10 @@ class TaskAction(object):
|
||||
|
||||
def schedule_execution(self, task):
|
||||
self.change_state(task, states.RUNNING, progress=0.0)
|
||||
scope_walker = self._walker_factory(task)
|
||||
kwargs = self._storage.fetch_mapped_args(task.rebind,
|
||||
atom_name=task.name)
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker)
|
||||
task_uuid = self._storage.get_atom_uuid(task.name)
|
||||
return self._task_executor.execute_task(task, task_uuid, kwargs,
|
||||
self._on_update_progress)
|
||||
@@ -96,8 +99,10 @@ class TaskAction(object):
|
||||
|
||||
def schedule_reversion(self, task):
|
||||
self.change_state(task, states.REVERTING, progress=0.0)
|
||||
scope_walker = self._walker_factory(task)
|
||||
kwargs = self._storage.fetch_mapped_args(task.rebind,
|
||||
atom_name=task.name)
|
||||
atom_name=task.name,
|
||||
scope_walker=scope_walker)
|
||||
task_uuid = self._storage.get_atom_uuid(task.name)
|
||||
task_result = self._storage.get(task.name)
|
||||
failures = self._storage.get_failures()
|
||||
|
||||
@@ -161,7 +161,7 @@ class Flow(flow.Flow):
|
||||
return self._get_subgraph().number_of_nodes()
|
||||
|
||||
def __iter__(self):
|
||||
for n in self._get_subgraph().nodes_iter():
|
||||
for n in self._get_subgraph().topological_sort():
|
||||
yield n
|
||||
|
||||
def iter_links(self):
|
||||
|
||||
@@ -31,6 +31,78 @@ from taskflow.utils import reflection
|
||||
LOG = logging.getLogger(__name__)
|
||||
STATES_WITH_RESULTS = (states.SUCCESS, states.REVERTING, states.FAILURE)
|
||||
|
||||
# TODO(harlowja): do this better (via a singleton or something else...)
|
||||
_TRANSIENT_PROVIDER = object()
|
||||
|
||||
# NOTE(harlowja): Perhaps the container is a dictionary-like object and that
|
||||
# key does not exist (key error), or the container is a tuple/list and a
|
||||
# non-numeric key is being requested (index error), or there was no container
|
||||
# and an attempt to index into none/other unsubscriptable type is being
|
||||
# requested (type error).
|
||||
#
|
||||
# Overall this (along with the item_from* functions) try to handle the vast
|
||||
# majority of wrong indexing operations on the wrong/invalid types so that we
|
||||
# can fail extraction during lookup or emit warning on result reception...
|
||||
_EXTRACTION_EXCEPTIONS = (IndexError, KeyError, ValueError, TypeError)
|
||||
|
||||
|
||||
class _Provider(object):
|
||||
"""A named symbol provider that produces a output at the given index."""
|
||||
|
||||
def __init__(self, name, index):
|
||||
self.name = name
|
||||
self.index = index
|
||||
|
||||
def __repr__(self):
|
||||
# TODO(harlowja): clean this up...
|
||||
if self.name is _TRANSIENT_PROVIDER:
|
||||
base = "<TransientProvider"
|
||||
else:
|
||||
base = "<Provider '%s'" % (self.name)
|
||||
if self.index is None:
|
||||
base += ">"
|
||||
else:
|
||||
base += " @ index %r>" % (self.index)
|
||||
return base
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.name, self.index))
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.name, self.index) == (other.name, other.index)
|
||||
|
||||
|
||||
def _item_from(container, index):
|
||||
"""Attempts to fetch a index/key from a given container."""
|
||||
if index is None:
|
||||
return container
|
||||
return container[index]
|
||||
|
||||
|
||||
def _item_from_single(provider, container, looking_for):
|
||||
"""Returns item from a *single* provider."""
|
||||
try:
|
||||
return _item_from(container, provider.index)
|
||||
except _EXTRACTION_EXCEPTIONS:
|
||||
raise exceptions.NotFound(
|
||||
"Unable to find result %r, expected to be able to find it"
|
||||
" created by %s but was unable to perform successful"
|
||||
" extraction" % (looking_for, provider))
|
||||
|
||||
|
||||
def _item_from_first_of(providers, looking_for):
|
||||
"""Returns item from the *first* successful container extraction."""
|
||||
for (provider, container) in providers:
|
||||
try:
|
||||
return (provider, _item_from(container, provider.index))
|
||||
except _EXTRACTION_EXCEPTIONS:
|
||||
pass
|
||||
providers = [p[0] for p in providers]
|
||||
raise exceptions.NotFound(
|
||||
"Unable to find result %r, expected to be able to find it"
|
||||
" created by one of %s but was unable to perform successful"
|
||||
" extraction" % (looking_for, providers))
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Storage(object):
|
||||
@@ -313,7 +385,7 @@ class Storage(object):
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def _check_all_results_provided(self, atom_name, data):
|
||||
def _check_all_results_provided(self, atom_name, container):
|
||||
"""Warn if an atom did not provide some of its expected results.
|
||||
|
||||
This may happen if atom returns shorter tuple or list or dict
|
||||
@@ -325,8 +397,8 @@ class Storage(object):
|
||||
return
|
||||
for name, index in six.iteritems(result_mapping):
|
||||
try:
|
||||
misc.item_from(data, index, name=name)
|
||||
except exceptions.NotFound:
|
||||
_item_from(container, index)
|
||||
except _EXTRACTION_EXCEPTIONS:
|
||||
LOG.warning("Atom %s did not supply result "
|
||||
"with index %r (name %s)", atom_name, index, name)
|
||||
|
||||
@@ -464,94 +536,180 @@ class Storage(object):
|
||||
|
||||
def save_transient():
|
||||
self._transients.update(pairs)
|
||||
# NOTE(harlowja): none is not a valid atom name, so that means
|
||||
# we can use it internally to reference all of our transient
|
||||
# variables.
|
||||
return (None, six.iterkeys(self._transients))
|
||||
return (_TRANSIENT_PROVIDER, six.iterkeys(self._transients))
|
||||
|
||||
with self._lock.write_lock():
|
||||
if transient:
|
||||
(atom_name, names) = save_transient()
|
||||
provider_name, names = save_transient()
|
||||
else:
|
||||
(atom_name, names) = save_persistent()
|
||||
self._set_result_mapping(atom_name,
|
||||
provider_name, names = save_persistent()
|
||||
self._set_result_mapping(provider_name,
|
||||
dict((name, name) for name in names))
|
||||
|
||||
def _set_result_mapping(self, atom_name, mapping):
|
||||
"""Sets the result mapping for an atom.
|
||||
def _set_result_mapping(self, provider_name, mapping):
|
||||
"""Sets the result mapping for a given producer.
|
||||
|
||||
The result saved with given name would be accessible by names
|
||||
defined in mapping. Mapping is a dict name => index. If index
|
||||
is None, the whole result will have this name; else, only
|
||||
part of it, result[index].
|
||||
"""
|
||||
if not mapping:
|
||||
return
|
||||
self._result_mappings[atom_name] = mapping
|
||||
for name, index in six.iteritems(mapping):
|
||||
entries = self._reverse_mapping.setdefault(name, [])
|
||||
provider_mapping = self._result_mappings.setdefault(provider_name, {})
|
||||
if mapping:
|
||||
provider_mapping.update(mapping)
|
||||
# Ensure the reverse mapping/index is updated (for faster lookups).
|
||||
for name, index in six.iteritems(provider_mapping):
|
||||
entries = self._reverse_mapping.setdefault(name, [])
|
||||
provider = _Provider(provider_name, index)
|
||||
if provider not in entries:
|
||||
entries.append(provider)
|
||||
|
||||
# NOTE(imelnikov): We support setting same result mapping for
|
||||
# the same atom twice (e.g when we are injecting 'a' and then
|
||||
# injecting 'a' again), so we should not log warning below in
|
||||
# that case and we should have only one item for each pair
|
||||
# (atom_name, index) in entries. It should be put to the end of
|
||||
# entries list because order matters on fetching.
|
||||
try:
|
||||
entries.remove((atom_name, index))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
entries.append((atom_name, index))
|
||||
if len(entries) > 1:
|
||||
LOG.warning("Multiple provider mappings being created for %r",
|
||||
name)
|
||||
|
||||
def fetch(self, name):
|
||||
"""Fetch a named atoms result."""
|
||||
def fetch(self, name, many_handler=None):
|
||||
"""Fetch a named result."""
|
||||
# By default we just return the first of many (unless provided
|
||||
# a different callback that can translate many results into something
|
||||
# more meaningful).
|
||||
if many_handler is None:
|
||||
many_handler = lambda values: values[0]
|
||||
with self._lock.read_lock():
|
||||
try:
|
||||
indexes = self._reverse_mapping[name]
|
||||
providers = self._reverse_mapping[name]
|
||||
except KeyError:
|
||||
raise exceptions.NotFound("Name %r is not mapped" % name)
|
||||
# Return the first one that is found.
|
||||
for (atom_name, index) in reversed(indexes):
|
||||
if not atom_name:
|
||||
results = self._transients
|
||||
raise exceptions.NotFound("Name %r is not mapped as a"
|
||||
" produced output by any"
|
||||
" providers" % name)
|
||||
values = []
|
||||
for provider in providers:
|
||||
if provider.name is _TRANSIENT_PROVIDER:
|
||||
values.append(_item_from_single(provider,
|
||||
self._transients, name))
|
||||
else:
|
||||
results = self._get(atom_name, only_last=True)
|
||||
try:
|
||||
return misc.item_from(results, index, name)
|
||||
except exceptions.NotFound:
|
||||
pass
|
||||
raise exceptions.NotFound("Unable to find result %r" % name)
|
||||
try:
|
||||
container = self._get(provider.name, only_last=True)
|
||||
except exceptions.NotFound:
|
||||
pass
|
||||
else:
|
||||
values.append(_item_from_single(provider,
|
||||
container, name))
|
||||
if not values:
|
||||
raise exceptions.NotFound("Unable to find result %r,"
|
||||
" searched %s" % (name, providers))
|
||||
else:
|
||||
return many_handler(values)
|
||||
|
||||
def fetch_all(self):
|
||||
"""Fetch all named atom results known so far.
|
||||
"""Fetch all named results known so far.
|
||||
|
||||
Should be used for debugging and testing purposes mostly.
|
||||
NOTE(harlowja): should be used for debugging and testing purposes.
|
||||
"""
|
||||
def many_handler(values):
|
||||
if len(values) > 1:
|
||||
return values
|
||||
return values[0]
|
||||
with self._lock.read_lock():
|
||||
results = {}
|
||||
for name in self._reverse_mapping:
|
||||
for name in six.iterkeys(self._reverse_mapping):
|
||||
try:
|
||||
results[name] = self.fetch(name)
|
||||
results[name] = self.fetch(name, many_handler=many_handler)
|
||||
except exceptions.NotFound:
|
||||
pass
|
||||
return results
|
||||
|
||||
def fetch_mapped_args(self, args_mapping, atom_name=None):
|
||||
"""Fetch arguments for an atom using an atoms arguments mapping."""
|
||||
def fetch_mapped_args(self, args_mapping,
|
||||
atom_name=None, scope_walker=None):
|
||||
"""Fetch arguments for an atom using an atoms argument mapping."""
|
||||
|
||||
def _get_results(looking_for, provider):
|
||||
"""Gets the results saved for a given provider."""
|
||||
try:
|
||||
return self._get(provider.name, only_last=True)
|
||||
except exceptions.NotFound as e:
|
||||
raise exceptions.NotFound(
|
||||
"Expected to be able to find output %r produced"
|
||||
" by %s but was unable to get at that providers"
|
||||
" results" % (looking_for, provider), e)
|
||||
|
||||
def _locate_providers(looking_for, possible_providers):
|
||||
"""Finds the accessible providers."""
|
||||
default_providers = []
|
||||
for p in possible_providers:
|
||||
if p.name is _TRANSIENT_PROVIDER:
|
||||
default_providers.append((p, self._transients))
|
||||
if p.name == self.injector_name:
|
||||
default_providers.append((p, _get_results(looking_for, p)))
|
||||
if default_providers:
|
||||
return default_providers
|
||||
if scope_walker is not None:
|
||||
scope_iter = iter(scope_walker)
|
||||
else:
|
||||
scope_iter = iter([])
|
||||
for atom_names in scope_iter:
|
||||
if not atom_names:
|
||||
continue
|
||||
providers = []
|
||||
for p in possible_providers:
|
||||
if p.name in atom_names:
|
||||
providers.append((p, _get_results(looking_for, p)))
|
||||
if providers:
|
||||
return providers
|
||||
return []
|
||||
|
||||
with self._lock.read_lock():
|
||||
injected_args = {}
|
||||
if atom_name and atom_name not in self._atom_name_to_uuid:
|
||||
raise exceptions.NotFound("Unknown atom name: %s" % atom_name)
|
||||
if not args_mapping:
|
||||
return {}
|
||||
# The order of lookup is the following:
|
||||
#
|
||||
# 1. Injected atom specific arguments.
|
||||
# 2. Transient injected arguments.
|
||||
# 3. Non-transient injected arguments.
|
||||
# 4. First scope visited group that produces the named result.
|
||||
# a). The first of that group that actually provided the name
|
||||
# result is selected (if group size is greater than one).
|
||||
#
|
||||
# Otherwise: blowup! (this will also happen if reading or
|
||||
# extracting an expected result fails, since it is better to fail
|
||||
# on lookup then provide invalid data from the wrong provider)
|
||||
if atom_name:
|
||||
injected_args = self._injected_args.get(atom_name, {})
|
||||
else:
|
||||
injected_args = {}
|
||||
mapped_args = {}
|
||||
for key, name in six.iteritems(args_mapping):
|
||||
for (bound_name, name) in six.iteritems(args_mapping):
|
||||
# TODO(harlowja): This logging information may be to verbose
|
||||
# even for DEBUG mode, let's see if we can maybe in the future
|
||||
# add a TRACE mode or something else if people complain...
|
||||
if LOG.isEnabledFor(logging.DEBUG):
|
||||
if atom_name:
|
||||
LOG.debug("Looking for %r <= %r for atom named: %s",
|
||||
bound_name, name, atom_name)
|
||||
else:
|
||||
LOG.debug("Looking for %r <= %r", bound_name, name)
|
||||
if name in injected_args:
|
||||
mapped_args[key] = injected_args[name]
|
||||
value = injected_args[name]
|
||||
mapped_args[bound_name] = value
|
||||
LOG.debug("Matched %r <= %r to %r (from injected values)",
|
||||
bound_name, name, value)
|
||||
else:
|
||||
mapped_args[key] = self.fetch(name)
|
||||
try:
|
||||
possible_providers = self._reverse_mapping[name]
|
||||
except KeyError:
|
||||
raise exceptions.NotFound("Name %r is not mapped as a"
|
||||
" produced output by any"
|
||||
" providers" % name)
|
||||
# Reduce the possible providers to one that are allowed.
|
||||
providers = _locate_providers(name, possible_providers)
|
||||
if not providers:
|
||||
raise exceptions.NotFound(
|
||||
"Mapped argument %r <= %r was not produced"
|
||||
" by any accessible provider (%s possible"
|
||||
" providers were scanned)"
|
||||
% (bound_name, name, len(possible_providers)))
|
||||
provider, value = _item_from_first_of(providers, name)
|
||||
mapped_args[bound_name] = value
|
||||
LOG.debug("Matched %r <= %r to %r (from %s)",
|
||||
bound_name, name, value, provider)
|
||||
return mapped_args
|
||||
|
||||
def set_flow_state(self, state):
|
||||
|
||||
@@ -27,21 +27,25 @@ from taskflow.tests import utils as test_utils
|
||||
class PatternCompileTest(test.TestCase):
|
||||
def test_task(self):
|
||||
task = test_utils.DummyTask(name='a')
|
||||
compilation = compiler.PatternCompiler().compile(task)
|
||||
compilation = compiler.PatternCompiler(task).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(list(g.nodes()), [task])
|
||||
self.assertEqual(list(g.edges()), [])
|
||||
|
||||
def test_retry(self):
|
||||
r = retry.AlwaysRevert('r1')
|
||||
msg_regex = "^Retry controller: .* must only be used .*"
|
||||
msg_regex = "^Retry controller .* must only be used .*"
|
||||
self.assertRaisesRegexp(TypeError, msg_regex,
|
||||
compiler.PatternCompiler().compile, r)
|
||||
compiler.PatternCompiler(r).compile)
|
||||
|
||||
def test_wrong_object(self):
|
||||
msg_regex = '^Unknown type requested to flatten'
|
||||
msg_regex = '^Unknown item .* requested to flatten'
|
||||
self.assertRaisesRegexp(TypeError, msg_regex,
|
||||
compiler.PatternCompiler().compile, 42)
|
||||
compiler.PatternCompiler(42).compile)
|
||||
|
||||
def test_empty(self):
|
||||
flo = lf.Flow("test")
|
||||
self.assertRaises(exc.Empty, compiler.PatternCompiler(flo).compile)
|
||||
|
||||
def test_linear(self):
|
||||
a, b, c, d = test_utils.make_many(4)
|
||||
@@ -51,7 +55,7 @@ class PatternCompileTest(test.TestCase):
|
||||
sflo.add(d)
|
||||
flo.add(sflo)
|
||||
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(4, len(g))
|
||||
|
||||
@@ -69,13 +73,13 @@ class PatternCompileTest(test.TestCase):
|
||||
flo.add(a, b, c)
|
||||
flo.add(flo)
|
||||
self.assertRaises(ValueError,
|
||||
compiler.PatternCompiler().compile, flo)
|
||||
compiler.PatternCompiler(flo).compile)
|
||||
|
||||
def test_unordered(self):
|
||||
a, b, c, d = test_utils.make_many(4)
|
||||
flo = uf.Flow("test")
|
||||
flo.add(a, b, c, d)
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(4, len(g))
|
||||
self.assertEqual(0, g.number_of_edges())
|
||||
@@ -92,7 +96,7 @@ class PatternCompileTest(test.TestCase):
|
||||
flo2.add(c, d)
|
||||
flo.add(flo2)
|
||||
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(4, len(g))
|
||||
|
||||
@@ -116,7 +120,7 @@ class PatternCompileTest(test.TestCase):
|
||||
flo2.add(c, d)
|
||||
flo.add(flo2)
|
||||
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(4, len(g))
|
||||
for n in [a, b]:
|
||||
@@ -138,7 +142,7 @@ class PatternCompileTest(test.TestCase):
|
||||
uf.Flow('ut').add(b, c),
|
||||
d)
|
||||
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(4, len(g))
|
||||
self.assertItemsEqual(g.edges(), [
|
||||
@@ -153,7 +157,7 @@ class PatternCompileTest(test.TestCase):
|
||||
flo = gf.Flow("test")
|
||||
flo.add(a, b, c, d)
|
||||
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(4, len(g))
|
||||
self.assertEqual(0, g.number_of_edges())
|
||||
@@ -167,7 +171,7 @@ class PatternCompileTest(test.TestCase):
|
||||
flo2.add(e, f, g)
|
||||
flo.add(flo2)
|
||||
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
graph = compilation.execution_graph
|
||||
self.assertEqual(7, len(graph))
|
||||
self.assertItemsEqual(graph.edges(data=True), [
|
||||
@@ -184,7 +188,7 @@ class PatternCompileTest(test.TestCase):
|
||||
flo2.add(e, f, g)
|
||||
flo.add(flo2)
|
||||
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(7, len(g))
|
||||
self.assertEqual(0, g.number_of_edges())
|
||||
@@ -197,7 +201,7 @@ class PatternCompileTest(test.TestCase):
|
||||
flo.link(b, c)
|
||||
flo.link(c, d)
|
||||
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(4, len(g))
|
||||
self.assertItemsEqual(g.edges(data=True), [
|
||||
@@ -213,7 +217,7 @@ class PatternCompileTest(test.TestCase):
|
||||
b = test_utils.ProvidesRequiresTask('b', provides=[], requires=['x'])
|
||||
flo = gf.Flow("test").add(a, b)
|
||||
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(2, len(g))
|
||||
self.assertItemsEqual(g.edges(data=True), [
|
||||
@@ -231,7 +235,7 @@ class PatternCompileTest(test.TestCase):
|
||||
lf.Flow("test2").add(b, c)
|
||||
)
|
||||
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(3, len(g))
|
||||
self.assertItemsEqual(g.edges(data=True), [
|
||||
@@ -250,7 +254,7 @@ class PatternCompileTest(test.TestCase):
|
||||
lf.Flow("test2").add(b, c)
|
||||
)
|
||||
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(3, len(g))
|
||||
self.assertItemsEqual(g.edges(data=True), [
|
||||
@@ -267,7 +271,7 @@ class PatternCompileTest(test.TestCase):
|
||||
)
|
||||
self.assertRaisesRegexp(exc.Duplicate,
|
||||
'^Atoms with duplicate names',
|
||||
compiler.PatternCompiler().compile, flo)
|
||||
compiler.PatternCompiler(flo).compile)
|
||||
|
||||
def test_checks_for_dups_globally(self):
|
||||
flo = gf.Flow("test").add(
|
||||
@@ -275,25 +279,25 @@ class PatternCompileTest(test.TestCase):
|
||||
gf.Flow("int2").add(test_utils.DummyTask(name="a")))
|
||||
self.assertRaisesRegexp(exc.Duplicate,
|
||||
'^Atoms with duplicate names',
|
||||
compiler.PatternCompiler().compile, flo)
|
||||
compiler.PatternCompiler(flo).compile)
|
||||
|
||||
def test_retry_in_linear_flow(self):
|
||||
flo = lf.Flow("test", retry.AlwaysRevert("c"))
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(1, len(g))
|
||||
self.assertEqual(0, g.number_of_edges())
|
||||
|
||||
def test_retry_in_unordered_flow(self):
|
||||
flo = uf.Flow("test", retry.AlwaysRevert("c"))
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(1, len(g))
|
||||
self.assertEqual(0, g.number_of_edges())
|
||||
|
||||
def test_retry_in_graph_flow(self):
|
||||
flo = gf.Flow("test", retry.AlwaysRevert("c"))
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(1, len(g))
|
||||
self.assertEqual(0, g.number_of_edges())
|
||||
@@ -302,7 +306,7 @@ class PatternCompileTest(test.TestCase):
|
||||
c1 = retry.AlwaysRevert("c1")
|
||||
c2 = retry.AlwaysRevert("c2")
|
||||
flo = lf.Flow("test", c1).add(lf.Flow("test2", c2))
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
|
||||
self.assertEqual(2, len(g))
|
||||
@@ -317,7 +321,7 @@ class PatternCompileTest(test.TestCase):
|
||||
c = retry.AlwaysRevert("c")
|
||||
a, b = test_utils.make_many(2)
|
||||
flo = lf.Flow("test", c).add(a, b)
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
|
||||
self.assertEqual(3, len(g))
|
||||
@@ -335,7 +339,7 @@ class PatternCompileTest(test.TestCase):
|
||||
c = retry.AlwaysRevert("c")
|
||||
a, b = test_utils.make_many(2)
|
||||
flo = uf.Flow("test", c).add(a, b)
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
|
||||
self.assertEqual(3, len(g))
|
||||
@@ -353,7 +357,7 @@ class PatternCompileTest(test.TestCase):
|
||||
r = retry.AlwaysRevert("cp")
|
||||
a, b, c = test_utils.make_many(3)
|
||||
flo = gf.Flow("test", r).add(a, b, c).link(b, c)
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
self.assertEqual(4, len(g))
|
||||
|
||||
@@ -377,7 +381,7 @@ class PatternCompileTest(test.TestCase):
|
||||
a,
|
||||
lf.Flow("test", c2).add(b, c),
|
||||
d)
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
|
||||
self.assertEqual(6, len(g))
|
||||
@@ -402,7 +406,7 @@ class PatternCompileTest(test.TestCase):
|
||||
a,
|
||||
lf.Flow("test").add(b, c),
|
||||
d)
|
||||
compilation = compiler.PatternCompiler().compile(flo)
|
||||
compilation = compiler.PatternCompiler(flo).compile()
|
||||
g = compilation.execution_graph
|
||||
|
||||
self.assertEqual(5, len(g))
|
||||
|
||||
@@ -33,7 +33,7 @@ from taskflow.utils import persistence_utils as pu
|
||||
|
||||
class _RunnerTestMixin(object):
|
||||
def _make_runtime(self, flow, initial_state=None):
|
||||
compilation = compiler.PatternCompiler().compile(flow)
|
||||
compilation = compiler.PatternCompiler(flow).compile()
|
||||
flow_detail = pu.create_flow_detail(flow)
|
||||
store = storage.SingleThreadedStorage(flow_detail)
|
||||
# This ensures the tasks exist in storage...
|
||||
|
||||
248
taskflow/tests/unit/test_action_engine_scoping.py
Normal file
248
taskflow/tests/unit/test_action_engine_scoping.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from taskflow.engines.action_engine import compiler
|
||||
from taskflow.engines.action_engine import scopes as sc
|
||||
from taskflow.patterns import graph_flow as gf
|
||||
from taskflow.patterns import linear_flow as lf
|
||||
from taskflow.patterns import unordered_flow as uf
|
||||
from taskflow import test
|
||||
from taskflow.tests import utils as test_utils
|
||||
|
||||
|
||||
def _get_scopes(compilation, atom, names_only=True):
|
||||
walker = sc.ScopeWalker(compilation, atom, names_only=names_only)
|
||||
return list(iter(walker))
|
||||
|
||||
|
||||
class LinearScopingTest(test.TestCase):
|
||||
def test_unknown(self):
|
||||
r = lf.Flow("root")
|
||||
r_1 = test_utils.TaskOneReturn("root.1")
|
||||
r.add(r_1)
|
||||
|
||||
r_2 = test_utils.TaskOneReturn("root.2")
|
||||
c = compiler.PatternCompiler(r).compile()
|
||||
self.assertRaises(ValueError, _get_scopes, c, r_2)
|
||||
|
||||
def test_empty(self):
|
||||
r = lf.Flow("root")
|
||||
r_1 = test_utils.TaskOneReturn("root.1")
|
||||
r.add(r_1)
|
||||
|
||||
c = compiler.PatternCompiler(r).compile()
|
||||
self.assertIn(r_1, c.execution_graph)
|
||||
self.assertIsNotNone(c.hierarchy.find(r_1))
|
||||
|
||||
walker = sc.ScopeWalker(c, r_1)
|
||||
scopes = list(walker)
|
||||
self.assertEqual([], scopes)
|
||||
|
||||
def test_single_prior_linear(self):
|
||||
r = lf.Flow("root")
|
||||
r_1 = test_utils.TaskOneReturn("root.1")
|
||||
r_2 = test_utils.TaskOneReturn("root.2")
|
||||
r.add(r_1, r_2)
|
||||
|
||||
c = compiler.PatternCompiler(r).compile()
|
||||
for a in r:
|
||||
self.assertIn(a, c.execution_graph)
|
||||
self.assertIsNotNone(c.hierarchy.find(a))
|
||||
|
||||
self.assertEqual([], _get_scopes(c, r_1))
|
||||
self.assertEqual([['root.1']], _get_scopes(c, r_2))
|
||||
|
||||
def test_nested_prior_linear(self):
|
||||
r = lf.Flow("root")
|
||||
r.add(test_utils.TaskOneReturn("root.1"),
|
||||
test_utils.TaskOneReturn("root.2"))
|
||||
sub_r = lf.Flow("subroot")
|
||||
sub_r_1 = test_utils.TaskOneReturn("subroot.1")
|
||||
sub_r.add(sub_r_1)
|
||||
r.add(sub_r)
|
||||
|
||||
c = compiler.PatternCompiler(r).compile()
|
||||
self.assertEqual([[], ['root.2', 'root.1']], _get_scopes(c, sub_r_1))
|
||||
|
||||
def test_nested_prior_linear_begin_middle_end(self):
|
||||
r = lf.Flow("root")
|
||||
begin_r = test_utils.TaskOneReturn("root.1")
|
||||
r.add(begin_r, test_utils.TaskOneReturn("root.2"))
|
||||
middle_r = test_utils.TaskOneReturn("root.3")
|
||||
r.add(middle_r)
|
||||
sub_r = lf.Flow("subroot")
|
||||
sub_r.add(test_utils.TaskOneReturn("subroot.1"),
|
||||
test_utils.TaskOneReturn("subroot.2"))
|
||||
r.add(sub_r)
|
||||
end_r = test_utils.TaskOneReturn("root.4")
|
||||
r.add(end_r)
|
||||
|
||||
c = compiler.PatternCompiler(r).compile()
|
||||
|
||||
self.assertEqual([], _get_scopes(c, begin_r))
|
||||
self.assertEqual([['root.2', 'root.1']], _get_scopes(c, middle_r))
|
||||
self.assertEqual([['subroot.2', 'subroot.1', 'root.3', 'root.2',
|
||||
'root.1']], _get_scopes(c, end_r))
|
||||
|
||||
|
||||
class GraphScopingTest(test.TestCase):
|
||||
def test_dependent(self):
|
||||
r = gf.Flow("root")
|
||||
|
||||
customer = test_utils.ProvidesRequiresTask("customer",
|
||||
provides=['dog'],
|
||||
requires=[])
|
||||
washer = test_utils.ProvidesRequiresTask("washer",
|
||||
requires=['dog'],
|
||||
provides=['wash'])
|
||||
dryer = test_utils.ProvidesRequiresTask("dryer",
|
||||
requires=['dog', 'wash'],
|
||||
provides=['dry_dog'])
|
||||
shaved = test_utils.ProvidesRequiresTask("shaver",
|
||||
requires=['dry_dog'],
|
||||
provides=['shaved_dog'])
|
||||
happy_customer = test_utils.ProvidesRequiresTask(
|
||||
"happy_customer", requires=['shaved_dog'], provides=['happiness'])
|
||||
|
||||
r.add(customer, washer, dryer, shaved, happy_customer)
|
||||
|
||||
c = compiler.PatternCompiler(r).compile()
|
||||
|
||||
self.assertEqual([], _get_scopes(c, customer))
|
||||
self.assertEqual([['washer', 'customer']], _get_scopes(c, dryer))
|
||||
self.assertEqual([['shaver', 'dryer', 'washer', 'customer']],
|
||||
_get_scopes(c, happy_customer))
|
||||
|
||||
def test_no_visible(self):
|
||||
r = gf.Flow("root")
|
||||
atoms = []
|
||||
for i in range(0, 10):
|
||||
atoms.append(test_utils.TaskOneReturn("root.%s" % i))
|
||||
r.add(*atoms)
|
||||
|
||||
c = compiler.PatternCompiler(r).compile()
|
||||
for a in atoms:
|
||||
self.assertEqual([], _get_scopes(c, a))
|
||||
|
||||
def test_nested(self):
|
||||
r = gf.Flow("root")
|
||||
|
||||
r_1 = test_utils.TaskOneReturn("root.1")
|
||||
r_2 = test_utils.TaskOneReturn("root.2")
|
||||
r.add(r_1, r_2)
|
||||
r.link(r_1, r_2)
|
||||
|
||||
subroot = gf.Flow("subroot")
|
||||
subroot_r_1 = test_utils.TaskOneReturn("subroot.1")
|
||||
subroot_r_2 = test_utils.TaskOneReturn("subroot.2")
|
||||
subroot.add(subroot_r_1, subroot_r_2)
|
||||
subroot.link(subroot_r_1, subroot_r_2)
|
||||
|
||||
r.add(subroot)
|
||||
r_3 = test_utils.TaskOneReturn("root.3")
|
||||
r.add(r_3)
|
||||
r.link(r_2, r_3)
|
||||
|
||||
c = compiler.PatternCompiler(r).compile()
|
||||
self.assertEqual([], _get_scopes(c, r_1))
|
||||
self.assertEqual([['root.1']], _get_scopes(c, r_2))
|
||||
self.assertEqual([['root.2', 'root.1']], _get_scopes(c, r_3))
|
||||
|
||||
self.assertEqual([], _get_scopes(c, subroot_r_1))
|
||||
self.assertEqual([['subroot.1']], _get_scopes(c, subroot_r_2))
|
||||
|
||||
|
||||
class UnorderedScopingTest(test.TestCase):
|
||||
def test_no_visible(self):
|
||||
r = uf.Flow("root")
|
||||
atoms = []
|
||||
for i in range(0, 10):
|
||||
atoms.append(test_utils.TaskOneReturn("root.%s" % i))
|
||||
r.add(*atoms)
|
||||
c = compiler.PatternCompiler(r).compile()
|
||||
for a in atoms:
|
||||
self.assertEqual([], _get_scopes(c, a))
|
||||
|
||||
|
||||
class MixedPatternScopingTest(test.TestCase):
|
||||
def test_graph_linear_scope(self):
|
||||
r = gf.Flow("root")
|
||||
r_1 = test_utils.TaskOneReturn("root.1")
|
||||
r_2 = test_utils.TaskOneReturn("root.2")
|
||||
r.add(r_1, r_2)
|
||||
r.link(r_1, r_2)
|
||||
|
||||
s = lf.Flow("subroot")
|
||||
s_1 = test_utils.TaskOneReturn("subroot.1")
|
||||
s_2 = test_utils.TaskOneReturn("subroot.2")
|
||||
s.add(s_1, s_2)
|
||||
r.add(s)
|
||||
|
||||
t = gf.Flow("subroot2")
|
||||
t_1 = test_utils.TaskOneReturn("subroot2.1")
|
||||
t_2 = test_utils.TaskOneReturn("subroot2.2")
|
||||
t.add(t_1, t_2)
|
||||
t.link(t_1, t_2)
|
||||
r.add(t)
|
||||
r.link(s, t)
|
||||
|
||||
c = compiler.PatternCompiler(r).compile()
|
||||
self.assertEqual([], _get_scopes(c, r_1))
|
||||
self.assertEqual([['root.1']], _get_scopes(c, r_2))
|
||||
self.assertEqual([], _get_scopes(c, s_1))
|
||||
self.assertEqual([['subroot.1']], _get_scopes(c, s_2))
|
||||
self.assertEqual([[], ['subroot.2', 'subroot.1']],
|
||||
_get_scopes(c, t_1))
|
||||
self.assertEqual([["subroot2.1"], ['subroot.2', 'subroot.1']],
|
||||
_get_scopes(c, t_2))
|
||||
|
||||
def test_linear_unordered_scope(self):
|
||||
r = lf.Flow("root")
|
||||
r_1 = test_utils.TaskOneReturn("root.1")
|
||||
r_2 = test_utils.TaskOneReturn("root.2")
|
||||
r.add(r_1, r_2)
|
||||
|
||||
u = uf.Flow("subroot")
|
||||
atoms = []
|
||||
for i in range(0, 5):
|
||||
atoms.append(test_utils.TaskOneReturn("subroot.%s" % i))
|
||||
u.add(*atoms)
|
||||
r.add(u)
|
||||
|
||||
r_3 = test_utils.TaskOneReturn("root.3")
|
||||
r.add(r_3)
|
||||
|
||||
c = compiler.PatternCompiler(r).compile()
|
||||
|
||||
self.assertEqual([], _get_scopes(c, r_1))
|
||||
self.assertEqual([['root.1']], _get_scopes(c, r_2))
|
||||
for a in atoms:
|
||||
self.assertEqual([[], ['root.2', 'root.1']], _get_scopes(c, a))
|
||||
|
||||
scope = _get_scopes(c, r_3)
|
||||
self.assertEqual(1, len(scope))
|
||||
first_root = 0
|
||||
for i, n in enumerate(scope[0]):
|
||||
if n.startswith('root.'):
|
||||
first_root = i
|
||||
break
|
||||
first_subroot = 0
|
||||
for i, n in enumerate(scope[0]):
|
||||
if n.startswith('subroot.'):
|
||||
first_subroot = i
|
||||
break
|
||||
self.assertGreater(first_subroot, first_root)
|
||||
self.assertEqual(scope[0][-2:], ['root.2', 'root.1'])
|
||||
@@ -454,23 +454,6 @@ class StorageTestMixin(object):
|
||||
self.assertRaisesRegexp(exceptions.NotFound,
|
||||
'^Unable to find result', s.fetch, 'b')
|
||||
|
||||
@mock.patch.object(storage.LOG, 'warning')
|
||||
def test_multiple_providers_are_checked(self, mocked_warning):
|
||||
s = self._get_storage()
|
||||
s.ensure_task('my task', result_mapping={'result': 'key'})
|
||||
self.assertEqual(mocked_warning.mock_calls, [])
|
||||
s.ensure_task('my other task', result_mapping={'result': 'key'})
|
||||
mocked_warning.assert_called_once_with(
|
||||
mock.ANY, 'result')
|
||||
|
||||
@mock.patch.object(storage.LOG, 'warning')
|
||||
def test_multiple_providers_with_inject_are_checked(self, mocked_warning):
|
||||
s = self._get_storage()
|
||||
s.inject({'result': 'DONE'})
|
||||
self.assertEqual(mocked_warning.mock_calls, [])
|
||||
s.ensure_task('my other task', result_mapping={'result': 'key'})
|
||||
mocked_warning.assert_called_once_with(mock.ANY, 'result')
|
||||
|
||||
def test_ensure_retry(self):
|
||||
s = self._get_storage()
|
||||
s.ensure_retry('my retry')
|
||||
|
||||
@@ -166,6 +166,11 @@ class Node(object):
|
||||
for c in self._children:
|
||||
yield c
|
||||
|
||||
def reverse_iter(self):
|
||||
"""Iterates over the direct children of this node (left->right)."""
|
||||
for c in reversed(self._children):
|
||||
yield c
|
||||
|
||||
def index(self, item):
|
||||
"""Finds the child index of a given item, searchs in added order."""
|
||||
index_at = None
|
||||
|
||||
@@ -257,24 +257,6 @@ def sequence_minus(seq1, seq2):
|
||||
return result
|
||||
|
||||
|
||||
def item_from(container, index, name=None):
|
||||
"""Attempts to fetch a index/key from a given container."""
|
||||
if index is None:
|
||||
return container
|
||||
try:
|
||||
return container[index]
|
||||
except (IndexError, KeyError, ValueError, TypeError):
|
||||
# NOTE(harlowja): Perhaps the container is a dictionary-like object
|
||||
# and that key does not exist (key error), or the container is a
|
||||
# tuple/list and a non-numeric key is being requested (index error),
|
||||
# or there was no container and an attempt to index into none/other
|
||||
# unsubscriptable type is being requested (type error).
|
||||
if name is None:
|
||||
name = index
|
||||
raise exc.NotFound("Unable to find %r in container %s"
|
||||
% (name, container))
|
||||
|
||||
|
||||
def get_duplicate_keys(iterable, key=None):
|
||||
if key is not None:
|
||||
iterable = six.moves.map(key, iterable)
|
||||
@@ -399,8 +381,8 @@ def ensure_tree(path):
|
||||
"""
|
||||
try:
|
||||
os.makedirs(path)
|
||||
except OSError as exc:
|
||||
if exc.errno == errno.EEXIST:
|
||||
except OSError as e:
|
||||
if e.errno == errno.EEXIST:
|
||||
if not os.path.isdir(path):
|
||||
raise
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user