Revamp the symbol lookup mechanism

To complement the future changes in patterns we also want
to allow the execution of patterns to be affected in a similar
manner so that symbol lookup is no longer as confined as it was.

This change adds in the following:

- Symbol lookup by walking through an atoms contained scope/s.
- Better error messaging when symbols are not found.
- Adjusted & new tests (existing ones work).
- Better logging of the symbol lookup mechanism (helpful
  during debugging, although it is very verbose...)

Part of blueprint taskflow-improved-scoping

Change-Id: Id921a4abd9bf2b7b5c5a762337f8e90e8f1fe194
This commit is contained in:
Joshua Harlow
2014-07-21 18:25:49 -07:00
parent e68d72f66e
commit fa077c953f
14 changed files with 735 additions and 228 deletions

View File

@@ -15,53 +15,34 @@
# under the License.
import logging
import threading
from taskflow import exceptions as exc
from taskflow import flow
from taskflow import retry
from taskflow import task
from taskflow.types import graph as gr
from taskflow.types import tree as tr
from taskflow.utils import lock_utils
from taskflow.utils import misc
LOG = logging.getLogger(__name__)
class Compilation(object):
"""The result of a compilers compile() is this *immutable* object.
"""The result of a compilers compile() is this *immutable* object."""
For now it is just a execution graph but in the future it will grow to
include more methods & properties that help the various runtime units
execute in a more optimal & featureful manner.
"""
def __init__(self, execution_graph):
def __init__(self, execution_graph, hierarchy):
self._execution_graph = execution_graph
self._hierarchy = hierarchy
@property
def execution_graph(self):
return self._execution_graph
class PatternCompiler(object):
"""Compiles patterns & atoms into a compilation unit.
NOTE(harlowja): during this pattern translation process any nested flows
will be converted into there equivalent subgraphs. This currently implies
that contained atoms in those nested flows, post-translation will no longer
be associated with there previously containing flow but instead will lose
this identity and what will remain is the logical constraints that there
contained flow mandated. In the future this may be changed so that this
association is not lost via the compilation process (since it can be
useful to retain this relationship).
"""
def compile(self, root):
graph = _Flattener(root).flatten()
if graph.number_of_nodes() == 0:
# Try to get a name attribute, otherwise just use the object
# string representation directly if that attribute does not exist.
name = getattr(root, 'name', root)
raise exc.Empty("Root container '%s' (%s) is empty."
% (name, type(root)))
return Compilation(graph)
@property
def hierarchy(self):
return self._hierarchy
_RETRY_EDGE_DATA = {
@@ -69,14 +50,15 @@ _RETRY_EDGE_DATA = {
}
class _Flattener(object):
"""Flattens a root item (task/flow) into a execution graph."""
class PatternCompiler(object):
"""Compiles a pattern (or task) into a compilation unit."""
def __init__(self, root, freeze=True):
self._root = root
self._graph = None
self._history = set()
self._freeze = bool(freeze)
self._freeze = freeze
self._lock = threading.Lock()
self._compilation = None
def _add_new_edges(self, graph, nodes_from, nodes_to, edge_attrs):
"""Adds new edges from nodes to other nodes in the specified graph.
@@ -93,72 +75,74 @@ class _Flattener(object):
# if it's later modified that the same copy isn't modified.
graph.add_edge(u, v, attr_dict=edge_attrs.copy())
def _flatten(self, item):
functor = self._find_flattener(item)
if not functor:
raise TypeError("Unknown type requested to flatten: %s (%s)"
% (item, type(item)))
def _flatten(self, item, parent):
functor = self._find_flattener(item, parent)
self._pre_item_flatten(item)
graph = functor(item)
self._post_item_flatten(item, graph)
return graph
graph, node = functor(item, parent)
self._post_item_flatten(item, graph, node)
return graph, node
def _find_flattener(self, item):
def _find_flattener(self, item, parent):
"""Locates the flattening function to use to flatten the given item."""
if isinstance(item, flow.Flow):
return self._flatten_flow
elif isinstance(item, task.BaseTask):
return self._flatten_task
elif isinstance(item, retry.Retry):
if len(self._history) == 1:
raise TypeError("Retry controller: %s (%s) must only be used"
if parent is None:
raise TypeError("Retry controller '%s' (%s) must only be used"
" as a flow constructor parameter and not as a"
" root component" % (item, type(item)))
else:
# TODO(harlowja): we should raise this type error earlier
# instead of later since we should do this same check on add()
# calls, this makes the error more visible (instead of waiting
# until compile time).
raise TypeError("Retry controller: %s (%s) must only be used"
raise TypeError("Retry controller '%s' (%s) must only be used"
" as a flow constructor parameter and not as a"
" flow added component" % (item, type(item)))
else:
return None
raise TypeError("Unknown item '%s' (%s) requested to flatten"
% (item, type(item)))
def _connect_retry(self, retry, graph):
graph.add_node(retry)
# All graph nodes that have no predecessors should depend on its retry
nodes_to = [n for n in graph.no_predecessors_iter() if n != retry]
# All nodes that have no predecessors should depend on this retry.
nodes_to = [n for n in graph.no_predecessors_iter() if n is not retry]
self._add_new_edges(graph, [retry], nodes_to, _RETRY_EDGE_DATA)
# Add link to retry for each node of subgraph that hasn't
# a parent retry
# Add association for each node of graph that has no existing retry.
for n in graph.nodes_iter():
if n != retry and 'retry' not in graph.node[n]:
if n is not retry and 'retry' not in graph.node[n]:
graph.node[n]['retry'] = retry
def _flatten_task(self, task):
def _flatten_task(self, task, parent):
"""Flattens a individual task."""
graph = gr.DiGraph(name=task.name)
graph.add_node(task)
return graph
node = tr.Node(task)
if parent is not None:
parent.add(node)
return graph, node
def _flatten_flow(self, flow):
"""Flattens a graph flow."""
def _flatten_flow(self, flow, parent):
"""Flattens a flow."""
graph = gr.DiGraph(name=flow.name)
node = tr.Node(flow)
if parent is not None:
parent.add(node)
if flow.retry is not None:
node.add(tr.Node(flow.retry))
# Flatten all nodes into a single subgraph per node.
subgraph_map = {}
# Flatten all nodes into a single subgraph per item (and track origin
# item to its newly expanded graph).
subgraphs = {}
for item in flow:
subgraph = self._flatten(item)
subgraph_map[item] = subgraph
subgraph = self._flatten(item, node)[0]
subgraphs[item] = subgraph
graph = gr.merge_graphs([graph, subgraph])
# Reconnect all node edges to their corresponding subgraphs.
# Reconnect all items edges to their corresponding subgraphs.
for (u, v, attrs) in flow.iter_links():
u_g = subgraph_map[u]
v_g = subgraph_map[v]
u_g = subgraphs[u]
v_g = subgraphs[v]
if any(attrs.get(k) for k in ('invariant', 'manual', 'retry')):
# Connect nodes with no predecessors in v to nodes with
# no successors in u (thus maintaining the edge dependency).
@@ -177,48 +161,57 @@ class _Flattener(object):
if flow.retry is not None:
self._connect_retry(flow.retry, graph)
return graph
return graph, node
def _pre_item_flatten(self, item):
"""Called before a item is flattened; any pre-flattening actions."""
if id(item) in self._history:
raise ValueError("Already flattened item: %s (%s), recursive"
" flattening not supported" % (item, id(item)))
self._history.add(id(item))
if item in self._history:
raise ValueError("Already flattened item '%s' (%s), recursive"
" flattening is not supported" % (item,
type(item)))
self._history.add(item)
def _post_item_flatten(self, item, graph):
"""Called before a item is flattened; any post-flattening actions."""
def _post_item_flatten(self, item, graph, node):
"""Called after a item is flattened; doing post-flattening actions."""
def _pre_flatten(self):
"""Called before the flattening of the item starts."""
"""Called before the flattening of the root starts."""
self._history.clear()
def _post_flatten(self, graph):
"""Called after the flattening of the item finishes successfully."""
def _post_flatten(self, graph, node):
"""Called after the flattening of the root finishes successfully."""
dup_names = misc.get_duplicate_keys(graph.nodes_iter(),
key=lambda node: node.name)
if dup_names:
dup_names = ', '.join(sorted(dup_names))
raise exc.Duplicate("Atoms with duplicate names "
"found: %s" % (dup_names))
raise exc.Duplicate(
"Atoms with duplicate names found: %s" % (sorted(dup_names)))
if graph.number_of_nodes() == 0:
raise exc.Empty("Root container '%s' (%s) is empty"
% (self._root, type(self._root)))
self._history.clear()
# NOTE(harlowja): this one can be expensive to calculate (especially
# the cycle detection), so only do it if we know debugging is enabled
# and not under all cases.
if LOG.isEnabledFor(logging.DEBUG):
LOG.debug("Translated '%s' into a graph:", self._root)
LOG.debug("Translated '%s'", self._root)
LOG.debug("Graph:")
for line in graph.pformat().splitlines():
# Indent it so that it's slightly offset from the above line.
LOG.debug(" %s", line)
LOG.debug(" %s", line)
LOG.debug("Hierarchy:")
for line in node.pformat().splitlines():
# Indent it so that it's slightly offset from the above line.
LOG.debug(" %s", line)
def flatten(self):
"""Flattens a item (a task or flow) into a single execution graph."""
if self._graph is not None:
return self._graph
self._pre_flatten()
graph = self._flatten(self._root)
self._post_flatten(graph)
self._graph = graph
if self._freeze:
self._graph.freeze()
return self._graph
@lock_utils.locked
def compile(self):
"""Compiles the contained item into a compiled equivalent."""
if self._compilation is None:
self._pre_flatten()
graph, node = self._flatten(self._root, None)
self._post_flatten(graph, node)
if self._freeze:
graph.freeze()
node.freeze()
self._compilation = Compilation(graph, node)
return self._compilation

View File

@@ -210,13 +210,13 @@ class ActionEngine(base.EngineBase):
@misc.cachedproperty
def _compiler(self):
return self._compiler_factory()
return self._compiler_factory(self._flow)
@lock_utils.locked
def compile(self):
if self._compiled:
return
self._compilation = self._compiler.compile(self._flow)
self._compilation = self._compiler.compile()
self._runtime = runtime.Runtime(self._compilation,
self.storage,
self.task_notifier,

View File

@@ -27,13 +27,16 @@ SAVE_RESULT_STATES = (states.SUCCESS, states.FAILURE)
class RetryAction(object):
def __init__(self, storage, notifier):
def __init__(self, storage, notifier, walker_factory):
self._storage = storage
self._notifier = notifier
self._walker_factory = walker_factory
def _get_retry_args(self, retry):
scope_walker = self._walker_factory(retry)
kwargs = self._storage.fetch_mapped_args(retry.rebind,
atom_name=retry.name)
atom_name=retry.name,
scope_walker=scope_walker)
kwargs['history'] = self._storage.get_retry_history(retry.name)
return kwargs

View File

@@ -18,6 +18,7 @@ from taskflow.engines.action_engine import analyzer as ca
from taskflow.engines.action_engine import executor as ex
from taskflow.engines.action_engine import retry_action as ra
from taskflow.engines.action_engine import runner as ru
from taskflow.engines.action_engine import scopes as sc
from taskflow.engines.action_engine import task_action as ta
from taskflow import exceptions as excp
from taskflow import retry as retry_atom
@@ -66,12 +67,18 @@ class Runtime(object):
@misc.cachedproperty
def retry_action(self):
return ra.RetryAction(self.storage, self._task_notifier)
return ra.RetryAction(self._storage, self._task_notifier,
lambda atom: sc.ScopeWalker(self.compilation,
atom,
names_only=True))
@misc.cachedproperty
def task_action(self):
return ta.TaskAction(self.storage, self._task_executor,
self._task_notifier)
return ta.TaskAction(self._storage, self._task_executor,
self._task_notifier,
lambda atom: sc.ScopeWalker(self.compilation,
atom,
names_only=True))
def reset_nodes(self, nodes, state=st.PENDING, intention=st.EXECUTE):
for node in nodes:
@@ -81,7 +88,7 @@ class Runtime(object):
elif isinstance(node, retry_atom.Retry):
self.retry_action.change_state(node, state)
else:
raise TypeError("Unknown how to reset node %s, %s"
raise TypeError("Unknown how to reset atom '%s' (%s)"
% (node, type(node)))
if intention:
self.storage.set_atom_intention(node.name, intention)
@@ -209,7 +216,7 @@ class Scheduler(object):
elif isinstance(node, retry_atom.Retry):
return self._schedule_retry(node)
else:
raise TypeError("Unknown how to schedule node %s, %s"
raise TypeError("Unknown how to schedule atom '%s' (%s)"
% (node, type(node)))
def _schedule_retry(self, retry):

View File

@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
from taskflow import atom as atom_type
from taskflow import flow as flow_type
LOG = logging.getLogger(__name__)
def _extract_atoms(node, idx=-1):
# Always go left to right, since right to left is the pattern order
# and we want to go backwards and not forwards through that ordering...
if idx == -1:
children_iter = node.reverse_iter()
else:
children_iter = reversed(node[0:idx])
atoms = []
for child in children_iter:
if isinstance(child.item, flow_type.Flow):
atoms.extend(_extract_atoms(child))
elif isinstance(child.item, atom_type.Atom):
atoms.append(child.item)
else:
raise TypeError(
"Unknown extraction item '%s' (%s)" % (child.item,
type(child.item)))
return atoms
class ScopeWalker(object):
"""Walks through the scopes of a atom using a engines compilation.
This will walk the visible scopes that are accessible for the given
atom, which can be used by some external entity in some meaningful way,
for example to find dependent values...
"""
def __init__(self, compilation, atom, names_only=False):
self._node = compilation.hierarchy.find(atom)
if self._node is None:
raise ValueError("Unable to find atom '%s' in compilation"
" hierarchy" % atom)
self._atom = atom
self._graph = compilation.execution_graph
self._names_only = names_only
def __iter__(self):
"""Iterates over the visible scopes.
How this works is the following:
We find all the possible predecessors of the given atom, this is useful
since we know they occurred before this atom but it doesn't tell us
the corresponding scope *level* that each predecessor was created in,
so we need to find this information.
For that information we consult the location of the atom ``Y`` in the
node hierarchy. We lookup in a reverse order the parent ``X`` of ``Y``
and traverse backwards from the index in the parent where ``Y``
occurred, all children in ``X`` that we encounter in this backwards
search (if a child is a flow itself, its atom contents will be
expanded) will be assumed to be at the same scope. This is then a
*potential* single scope, to make an *actual* scope we remove the items
from the *potential* scope that are not predecessors of ``Y`` to form
the *actual* scope.
Then for additional scopes we continue up the tree, by finding the
parent of ``X`` (lets call it ``Z``) and perform the same operation,
going through the children in a reverse manner from the index in
parent ``Z`` where ``X`` was located. This forms another *potential*
scope which we provide back as an *actual* scope after reducing the
potential set by the predecessors of ``Y``. We then repeat this process
until we no longer have any parent nodes (aka have reached the top of
the tree) or we run out of predecessors.
"""
predecessors = set(self._graph.bfs_predecessors_iter(self._atom))
last = self._node
for parent in self._node.path_iter(include_self=False):
if not predecessors:
break
last_idx = parent.index(last.item)
visible = []
for a in _extract_atoms(parent, idx=last_idx):
if a in predecessors:
predecessors.remove(a)
if not self._names_only:
visible.append(a)
else:
visible.append(a.name)
if LOG.isEnabledFor(logging.DEBUG):
if not self._names_only:
visible_names = [a.name for a in visible]
else:
visible_names = visible
# TODO(harlowja): we should likely use a created TRACE level
# for this kind of *very* verbose information; otherwise the
# cinder and other folks are going to complain that there
# debug logs are full of not so useful information (it is
# useful to taskflow debugging...).
LOG.debug("Scope visible to '%s' (limited by parent '%s' index"
" < %s) is: %s", self._atom, parent.item.name,
last_idx, visible_names)
yield visible
last = parent

View File

@@ -26,10 +26,11 @@ SAVE_RESULT_STATES = (states.SUCCESS, states.FAILURE)
class TaskAction(object):
def __init__(self, storage, task_executor, notifier):
def __init__(self, storage, task_executor, notifier, walker_factory):
self._storage = storage
self._task_executor = task_executor
self._notifier = notifier
self._walker_factory = walker_factory
def _is_identity_transition(self, state, task, progress):
if state in SAVE_RESULT_STATES:
@@ -81,8 +82,10 @@ class TaskAction(object):
def schedule_execution(self, task):
self.change_state(task, states.RUNNING, progress=0.0)
scope_walker = self._walker_factory(task)
kwargs = self._storage.fetch_mapped_args(task.rebind,
atom_name=task.name)
atom_name=task.name,
scope_walker=scope_walker)
task_uuid = self._storage.get_atom_uuid(task.name)
return self._task_executor.execute_task(task, task_uuid, kwargs,
self._on_update_progress)
@@ -96,8 +99,10 @@ class TaskAction(object):
def schedule_reversion(self, task):
self.change_state(task, states.REVERTING, progress=0.0)
scope_walker = self._walker_factory(task)
kwargs = self._storage.fetch_mapped_args(task.rebind,
atom_name=task.name)
atom_name=task.name,
scope_walker=scope_walker)
task_uuid = self._storage.get_atom_uuid(task.name)
task_result = self._storage.get(task.name)
failures = self._storage.get_failures()

View File

@@ -161,7 +161,7 @@ class Flow(flow.Flow):
return self._get_subgraph().number_of_nodes()
def __iter__(self):
for n in self._get_subgraph().nodes_iter():
for n in self._get_subgraph().topological_sort():
yield n
def iter_links(self):

View File

@@ -31,6 +31,78 @@ from taskflow.utils import reflection
LOG = logging.getLogger(__name__)
STATES_WITH_RESULTS = (states.SUCCESS, states.REVERTING, states.FAILURE)
# TODO(harlowja): do this better (via a singleton or something else...)
_TRANSIENT_PROVIDER = object()
# NOTE(harlowja): Perhaps the container is a dictionary-like object and that
# key does not exist (key error), or the container is a tuple/list and a
# non-numeric key is being requested (index error), or there was no container
# and an attempt to index into none/other unsubscriptable type is being
# requested (type error).
#
# Overall this (along with the item_from* functions) try to handle the vast
# majority of wrong indexing operations on the wrong/invalid types so that we
# can fail extraction during lookup or emit warning on result reception...
_EXTRACTION_EXCEPTIONS = (IndexError, KeyError, ValueError, TypeError)
class _Provider(object):
"""A named symbol provider that produces a output at the given index."""
def __init__(self, name, index):
self.name = name
self.index = index
def __repr__(self):
# TODO(harlowja): clean this up...
if self.name is _TRANSIENT_PROVIDER:
base = "<TransientProvider"
else:
base = "<Provider '%s'" % (self.name)
if self.index is None:
base += ">"
else:
base += " @ index %r>" % (self.index)
return base
def __hash__(self):
return hash((self.name, self.index))
def __eq__(self, other):
return (self.name, self.index) == (other.name, other.index)
def _item_from(container, index):
"""Attempts to fetch a index/key from a given container."""
if index is None:
return container
return container[index]
def _item_from_single(provider, container, looking_for):
"""Returns item from a *single* provider."""
try:
return _item_from(container, provider.index)
except _EXTRACTION_EXCEPTIONS:
raise exceptions.NotFound(
"Unable to find result %r, expected to be able to find it"
" created by %s but was unable to perform successful"
" extraction" % (looking_for, provider))
def _item_from_first_of(providers, looking_for):
"""Returns item from the *first* successful container extraction."""
for (provider, container) in providers:
try:
return (provider, _item_from(container, provider.index))
except _EXTRACTION_EXCEPTIONS:
pass
providers = [p[0] for p in providers]
raise exceptions.NotFound(
"Unable to find result %r, expected to be able to find it"
" created by one of %s but was unable to perform successful"
" extraction" % (looking_for, providers))
@six.add_metaclass(abc.ABCMeta)
class Storage(object):
@@ -313,7 +385,7 @@ class Storage(object):
except KeyError:
return None
def _check_all_results_provided(self, atom_name, data):
def _check_all_results_provided(self, atom_name, container):
"""Warn if an atom did not provide some of its expected results.
This may happen if atom returns shorter tuple or list or dict
@@ -325,8 +397,8 @@ class Storage(object):
return
for name, index in six.iteritems(result_mapping):
try:
misc.item_from(data, index, name=name)
except exceptions.NotFound:
_item_from(container, index)
except _EXTRACTION_EXCEPTIONS:
LOG.warning("Atom %s did not supply result "
"with index %r (name %s)", atom_name, index, name)
@@ -464,94 +536,180 @@ class Storage(object):
def save_transient():
self._transients.update(pairs)
# NOTE(harlowja): none is not a valid atom name, so that means
# we can use it internally to reference all of our transient
# variables.
return (None, six.iterkeys(self._transients))
return (_TRANSIENT_PROVIDER, six.iterkeys(self._transients))
with self._lock.write_lock():
if transient:
(atom_name, names) = save_transient()
provider_name, names = save_transient()
else:
(atom_name, names) = save_persistent()
self._set_result_mapping(atom_name,
provider_name, names = save_persistent()
self._set_result_mapping(provider_name,
dict((name, name) for name in names))
def _set_result_mapping(self, atom_name, mapping):
"""Sets the result mapping for an atom.
def _set_result_mapping(self, provider_name, mapping):
"""Sets the result mapping for a given producer.
The result saved with given name would be accessible by names
defined in mapping. Mapping is a dict name => index. If index
is None, the whole result will have this name; else, only
part of it, result[index].
"""
if not mapping:
return
self._result_mappings[atom_name] = mapping
for name, index in six.iteritems(mapping):
entries = self._reverse_mapping.setdefault(name, [])
provider_mapping = self._result_mappings.setdefault(provider_name, {})
if mapping:
provider_mapping.update(mapping)
# Ensure the reverse mapping/index is updated (for faster lookups).
for name, index in six.iteritems(provider_mapping):
entries = self._reverse_mapping.setdefault(name, [])
provider = _Provider(provider_name, index)
if provider not in entries:
entries.append(provider)
# NOTE(imelnikov): We support setting same result mapping for
# the same atom twice (e.g when we are injecting 'a' and then
# injecting 'a' again), so we should not log warning below in
# that case and we should have only one item for each pair
# (atom_name, index) in entries. It should be put to the end of
# entries list because order matters on fetching.
try:
entries.remove((atom_name, index))
except ValueError:
pass
entries.append((atom_name, index))
if len(entries) > 1:
LOG.warning("Multiple provider mappings being created for %r",
name)
def fetch(self, name):
"""Fetch a named atoms result."""
def fetch(self, name, many_handler=None):
"""Fetch a named result."""
# By default we just return the first of many (unless provided
# a different callback that can translate many results into something
# more meaningful).
if many_handler is None:
many_handler = lambda values: values[0]
with self._lock.read_lock():
try:
indexes = self._reverse_mapping[name]
providers = self._reverse_mapping[name]
except KeyError:
raise exceptions.NotFound("Name %r is not mapped" % name)
# Return the first one that is found.
for (atom_name, index) in reversed(indexes):
if not atom_name:
results = self._transients
raise exceptions.NotFound("Name %r is not mapped as a"
" produced output by any"
" providers" % name)
values = []
for provider in providers:
if provider.name is _TRANSIENT_PROVIDER:
values.append(_item_from_single(provider,
self._transients, name))
else:
results = self._get(atom_name, only_last=True)
try:
return misc.item_from(results, index, name)
except exceptions.NotFound:
pass
raise exceptions.NotFound("Unable to find result %r" % name)
try:
container = self._get(provider.name, only_last=True)
except exceptions.NotFound:
pass
else:
values.append(_item_from_single(provider,
container, name))
if not values:
raise exceptions.NotFound("Unable to find result %r,"
" searched %s" % (name, providers))
else:
return many_handler(values)
def fetch_all(self):
"""Fetch all named atom results known so far.
"""Fetch all named results known so far.
Should be used for debugging and testing purposes mostly.
NOTE(harlowja): should be used for debugging and testing purposes.
"""
def many_handler(values):
if len(values) > 1:
return values
return values[0]
with self._lock.read_lock():
results = {}
for name in self._reverse_mapping:
for name in six.iterkeys(self._reverse_mapping):
try:
results[name] = self.fetch(name)
results[name] = self.fetch(name, many_handler=many_handler)
except exceptions.NotFound:
pass
return results
def fetch_mapped_args(self, args_mapping, atom_name=None):
"""Fetch arguments for an atom using an atoms arguments mapping."""
def fetch_mapped_args(self, args_mapping,
atom_name=None, scope_walker=None):
"""Fetch arguments for an atom using an atoms argument mapping."""
def _get_results(looking_for, provider):
"""Gets the results saved for a given provider."""
try:
return self._get(provider.name, only_last=True)
except exceptions.NotFound as e:
raise exceptions.NotFound(
"Expected to be able to find output %r produced"
" by %s but was unable to get at that providers"
" results" % (looking_for, provider), e)
def _locate_providers(looking_for, possible_providers):
"""Finds the accessible providers."""
default_providers = []
for p in possible_providers:
if p.name is _TRANSIENT_PROVIDER:
default_providers.append((p, self._transients))
if p.name == self.injector_name:
default_providers.append((p, _get_results(looking_for, p)))
if default_providers:
return default_providers
if scope_walker is not None:
scope_iter = iter(scope_walker)
else:
scope_iter = iter([])
for atom_names in scope_iter:
if not atom_names:
continue
providers = []
for p in possible_providers:
if p.name in atom_names:
providers.append((p, _get_results(looking_for, p)))
if providers:
return providers
return []
with self._lock.read_lock():
injected_args = {}
if atom_name and atom_name not in self._atom_name_to_uuid:
raise exceptions.NotFound("Unknown atom name: %s" % atom_name)
if not args_mapping:
return {}
# The order of lookup is the following:
#
# 1. Injected atom specific arguments.
# 2. Transient injected arguments.
# 3. Non-transient injected arguments.
# 4. First scope visited group that produces the named result.
# a). The first of that group that actually provided the name
# result is selected (if group size is greater than one).
#
# Otherwise: blowup! (this will also happen if reading or
# extracting an expected result fails, since it is better to fail
# on lookup then provide invalid data from the wrong provider)
if atom_name:
injected_args = self._injected_args.get(atom_name, {})
else:
injected_args = {}
mapped_args = {}
for key, name in six.iteritems(args_mapping):
for (bound_name, name) in six.iteritems(args_mapping):
# TODO(harlowja): This logging information may be to verbose
# even for DEBUG mode, let's see if we can maybe in the future
# add a TRACE mode or something else if people complain...
if LOG.isEnabledFor(logging.DEBUG):
if atom_name:
LOG.debug("Looking for %r <= %r for atom named: %s",
bound_name, name, atom_name)
else:
LOG.debug("Looking for %r <= %r", bound_name, name)
if name in injected_args:
mapped_args[key] = injected_args[name]
value = injected_args[name]
mapped_args[bound_name] = value
LOG.debug("Matched %r <= %r to %r (from injected values)",
bound_name, name, value)
else:
mapped_args[key] = self.fetch(name)
try:
possible_providers = self._reverse_mapping[name]
except KeyError:
raise exceptions.NotFound("Name %r is not mapped as a"
" produced output by any"
" providers" % name)
# Reduce the possible providers to one that are allowed.
providers = _locate_providers(name, possible_providers)
if not providers:
raise exceptions.NotFound(
"Mapped argument %r <= %r was not produced"
" by any accessible provider (%s possible"
" providers were scanned)"
% (bound_name, name, len(possible_providers)))
provider, value = _item_from_first_of(providers, name)
mapped_args[bound_name] = value
LOG.debug("Matched %r <= %r to %r (from %s)",
bound_name, name, value, provider)
return mapped_args
def set_flow_state(self, state):

View File

@@ -27,21 +27,25 @@ from taskflow.tests import utils as test_utils
class PatternCompileTest(test.TestCase):
def test_task(self):
task = test_utils.DummyTask(name='a')
compilation = compiler.PatternCompiler().compile(task)
compilation = compiler.PatternCompiler(task).compile()
g = compilation.execution_graph
self.assertEqual(list(g.nodes()), [task])
self.assertEqual(list(g.edges()), [])
def test_retry(self):
r = retry.AlwaysRevert('r1')
msg_regex = "^Retry controller: .* must only be used .*"
msg_regex = "^Retry controller .* must only be used .*"
self.assertRaisesRegexp(TypeError, msg_regex,
compiler.PatternCompiler().compile, r)
compiler.PatternCompiler(r).compile)
def test_wrong_object(self):
msg_regex = '^Unknown type requested to flatten'
msg_regex = '^Unknown item .* requested to flatten'
self.assertRaisesRegexp(TypeError, msg_regex,
compiler.PatternCompiler().compile, 42)
compiler.PatternCompiler(42).compile)
def test_empty(self):
flo = lf.Flow("test")
self.assertRaises(exc.Empty, compiler.PatternCompiler(flo).compile)
def test_linear(self):
a, b, c, d = test_utils.make_many(4)
@@ -51,7 +55,7 @@ class PatternCompileTest(test.TestCase):
sflo.add(d)
flo.add(sflo)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(4, len(g))
@@ -69,13 +73,13 @@ class PatternCompileTest(test.TestCase):
flo.add(a, b, c)
flo.add(flo)
self.assertRaises(ValueError,
compiler.PatternCompiler().compile, flo)
compiler.PatternCompiler(flo).compile)
def test_unordered(self):
a, b, c, d = test_utils.make_many(4)
flo = uf.Flow("test")
flo.add(a, b, c, d)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(4, len(g))
self.assertEqual(0, g.number_of_edges())
@@ -92,7 +96,7 @@ class PatternCompileTest(test.TestCase):
flo2.add(c, d)
flo.add(flo2)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(4, len(g))
@@ -116,7 +120,7 @@ class PatternCompileTest(test.TestCase):
flo2.add(c, d)
flo.add(flo2)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(4, len(g))
for n in [a, b]:
@@ -138,7 +142,7 @@ class PatternCompileTest(test.TestCase):
uf.Flow('ut').add(b, c),
d)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(4, len(g))
self.assertItemsEqual(g.edges(), [
@@ -153,7 +157,7 @@ class PatternCompileTest(test.TestCase):
flo = gf.Flow("test")
flo.add(a, b, c, d)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(4, len(g))
self.assertEqual(0, g.number_of_edges())
@@ -167,7 +171,7 @@ class PatternCompileTest(test.TestCase):
flo2.add(e, f, g)
flo.add(flo2)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
graph = compilation.execution_graph
self.assertEqual(7, len(graph))
self.assertItemsEqual(graph.edges(data=True), [
@@ -184,7 +188,7 @@ class PatternCompileTest(test.TestCase):
flo2.add(e, f, g)
flo.add(flo2)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(7, len(g))
self.assertEqual(0, g.number_of_edges())
@@ -197,7 +201,7 @@ class PatternCompileTest(test.TestCase):
flo.link(b, c)
flo.link(c, d)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(4, len(g))
self.assertItemsEqual(g.edges(data=True), [
@@ -213,7 +217,7 @@ class PatternCompileTest(test.TestCase):
b = test_utils.ProvidesRequiresTask('b', provides=[], requires=['x'])
flo = gf.Flow("test").add(a, b)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(2, len(g))
self.assertItemsEqual(g.edges(data=True), [
@@ -231,7 +235,7 @@ class PatternCompileTest(test.TestCase):
lf.Flow("test2").add(b, c)
)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(3, len(g))
self.assertItemsEqual(g.edges(data=True), [
@@ -250,7 +254,7 @@ class PatternCompileTest(test.TestCase):
lf.Flow("test2").add(b, c)
)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(3, len(g))
self.assertItemsEqual(g.edges(data=True), [
@@ -267,7 +271,7 @@ class PatternCompileTest(test.TestCase):
)
self.assertRaisesRegexp(exc.Duplicate,
'^Atoms with duplicate names',
compiler.PatternCompiler().compile, flo)
compiler.PatternCompiler(flo).compile)
def test_checks_for_dups_globally(self):
flo = gf.Flow("test").add(
@@ -275,25 +279,25 @@ class PatternCompileTest(test.TestCase):
gf.Flow("int2").add(test_utils.DummyTask(name="a")))
self.assertRaisesRegexp(exc.Duplicate,
'^Atoms with duplicate names',
compiler.PatternCompiler().compile, flo)
compiler.PatternCompiler(flo).compile)
def test_retry_in_linear_flow(self):
flo = lf.Flow("test", retry.AlwaysRevert("c"))
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(1, len(g))
self.assertEqual(0, g.number_of_edges())
def test_retry_in_unordered_flow(self):
flo = uf.Flow("test", retry.AlwaysRevert("c"))
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(1, len(g))
self.assertEqual(0, g.number_of_edges())
def test_retry_in_graph_flow(self):
flo = gf.Flow("test", retry.AlwaysRevert("c"))
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(1, len(g))
self.assertEqual(0, g.number_of_edges())
@@ -302,7 +306,7 @@ class PatternCompileTest(test.TestCase):
c1 = retry.AlwaysRevert("c1")
c2 = retry.AlwaysRevert("c2")
flo = lf.Flow("test", c1).add(lf.Flow("test2", c2))
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(2, len(g))
@@ -317,7 +321,7 @@ class PatternCompileTest(test.TestCase):
c = retry.AlwaysRevert("c")
a, b = test_utils.make_many(2)
flo = lf.Flow("test", c).add(a, b)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(3, len(g))
@@ -335,7 +339,7 @@ class PatternCompileTest(test.TestCase):
c = retry.AlwaysRevert("c")
a, b = test_utils.make_many(2)
flo = uf.Flow("test", c).add(a, b)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(3, len(g))
@@ -353,7 +357,7 @@ class PatternCompileTest(test.TestCase):
r = retry.AlwaysRevert("cp")
a, b, c = test_utils.make_many(3)
flo = gf.Flow("test", r).add(a, b, c).link(b, c)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(4, len(g))
@@ -377,7 +381,7 @@ class PatternCompileTest(test.TestCase):
a,
lf.Flow("test", c2).add(b, c),
d)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(6, len(g))
@@ -402,7 +406,7 @@ class PatternCompileTest(test.TestCase):
a,
lf.Flow("test").add(b, c),
d)
compilation = compiler.PatternCompiler().compile(flo)
compilation = compiler.PatternCompiler(flo).compile()
g = compilation.execution_graph
self.assertEqual(5, len(g))

View File

@@ -33,7 +33,7 @@ from taskflow.utils import persistence_utils as pu
class _RunnerTestMixin(object):
def _make_runtime(self, flow, initial_state=None):
compilation = compiler.PatternCompiler().compile(flow)
compilation = compiler.PatternCompiler(flow).compile()
flow_detail = pu.create_flow_detail(flow)
store = storage.SingleThreadedStorage(flow_detail)
# This ensures the tasks exist in storage...

View File

@@ -0,0 +1,248 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from taskflow.engines.action_engine import compiler
from taskflow.engines.action_engine import scopes as sc
from taskflow.patterns import graph_flow as gf
from taskflow.patterns import linear_flow as lf
from taskflow.patterns import unordered_flow as uf
from taskflow import test
from taskflow.tests import utils as test_utils
def _get_scopes(compilation, atom, names_only=True):
walker = sc.ScopeWalker(compilation, atom, names_only=names_only)
return list(iter(walker))
class LinearScopingTest(test.TestCase):
def test_unknown(self):
r = lf.Flow("root")
r_1 = test_utils.TaskOneReturn("root.1")
r.add(r_1)
r_2 = test_utils.TaskOneReturn("root.2")
c = compiler.PatternCompiler(r).compile()
self.assertRaises(ValueError, _get_scopes, c, r_2)
def test_empty(self):
r = lf.Flow("root")
r_1 = test_utils.TaskOneReturn("root.1")
r.add(r_1)
c = compiler.PatternCompiler(r).compile()
self.assertIn(r_1, c.execution_graph)
self.assertIsNotNone(c.hierarchy.find(r_1))
walker = sc.ScopeWalker(c, r_1)
scopes = list(walker)
self.assertEqual([], scopes)
def test_single_prior_linear(self):
r = lf.Flow("root")
r_1 = test_utils.TaskOneReturn("root.1")
r_2 = test_utils.TaskOneReturn("root.2")
r.add(r_1, r_2)
c = compiler.PatternCompiler(r).compile()
for a in r:
self.assertIn(a, c.execution_graph)
self.assertIsNotNone(c.hierarchy.find(a))
self.assertEqual([], _get_scopes(c, r_1))
self.assertEqual([['root.1']], _get_scopes(c, r_2))
def test_nested_prior_linear(self):
r = lf.Flow("root")
r.add(test_utils.TaskOneReturn("root.1"),
test_utils.TaskOneReturn("root.2"))
sub_r = lf.Flow("subroot")
sub_r_1 = test_utils.TaskOneReturn("subroot.1")
sub_r.add(sub_r_1)
r.add(sub_r)
c = compiler.PatternCompiler(r).compile()
self.assertEqual([[], ['root.2', 'root.1']], _get_scopes(c, sub_r_1))
def test_nested_prior_linear_begin_middle_end(self):
r = lf.Flow("root")
begin_r = test_utils.TaskOneReturn("root.1")
r.add(begin_r, test_utils.TaskOneReturn("root.2"))
middle_r = test_utils.TaskOneReturn("root.3")
r.add(middle_r)
sub_r = lf.Flow("subroot")
sub_r.add(test_utils.TaskOneReturn("subroot.1"),
test_utils.TaskOneReturn("subroot.2"))
r.add(sub_r)
end_r = test_utils.TaskOneReturn("root.4")
r.add(end_r)
c = compiler.PatternCompiler(r).compile()
self.assertEqual([], _get_scopes(c, begin_r))
self.assertEqual([['root.2', 'root.1']], _get_scopes(c, middle_r))
self.assertEqual([['subroot.2', 'subroot.1', 'root.3', 'root.2',
'root.1']], _get_scopes(c, end_r))
class GraphScopingTest(test.TestCase):
def test_dependent(self):
r = gf.Flow("root")
customer = test_utils.ProvidesRequiresTask("customer",
provides=['dog'],
requires=[])
washer = test_utils.ProvidesRequiresTask("washer",
requires=['dog'],
provides=['wash'])
dryer = test_utils.ProvidesRequiresTask("dryer",
requires=['dog', 'wash'],
provides=['dry_dog'])
shaved = test_utils.ProvidesRequiresTask("shaver",
requires=['dry_dog'],
provides=['shaved_dog'])
happy_customer = test_utils.ProvidesRequiresTask(
"happy_customer", requires=['shaved_dog'], provides=['happiness'])
r.add(customer, washer, dryer, shaved, happy_customer)
c = compiler.PatternCompiler(r).compile()
self.assertEqual([], _get_scopes(c, customer))
self.assertEqual([['washer', 'customer']], _get_scopes(c, dryer))
self.assertEqual([['shaver', 'dryer', 'washer', 'customer']],
_get_scopes(c, happy_customer))
def test_no_visible(self):
r = gf.Flow("root")
atoms = []
for i in range(0, 10):
atoms.append(test_utils.TaskOneReturn("root.%s" % i))
r.add(*atoms)
c = compiler.PatternCompiler(r).compile()
for a in atoms:
self.assertEqual([], _get_scopes(c, a))
def test_nested(self):
r = gf.Flow("root")
r_1 = test_utils.TaskOneReturn("root.1")
r_2 = test_utils.TaskOneReturn("root.2")
r.add(r_1, r_2)
r.link(r_1, r_2)
subroot = gf.Flow("subroot")
subroot_r_1 = test_utils.TaskOneReturn("subroot.1")
subroot_r_2 = test_utils.TaskOneReturn("subroot.2")
subroot.add(subroot_r_1, subroot_r_2)
subroot.link(subroot_r_1, subroot_r_2)
r.add(subroot)
r_3 = test_utils.TaskOneReturn("root.3")
r.add(r_3)
r.link(r_2, r_3)
c = compiler.PatternCompiler(r).compile()
self.assertEqual([], _get_scopes(c, r_1))
self.assertEqual([['root.1']], _get_scopes(c, r_2))
self.assertEqual([['root.2', 'root.1']], _get_scopes(c, r_3))
self.assertEqual([], _get_scopes(c, subroot_r_1))
self.assertEqual([['subroot.1']], _get_scopes(c, subroot_r_2))
class UnorderedScopingTest(test.TestCase):
def test_no_visible(self):
r = uf.Flow("root")
atoms = []
for i in range(0, 10):
atoms.append(test_utils.TaskOneReturn("root.%s" % i))
r.add(*atoms)
c = compiler.PatternCompiler(r).compile()
for a in atoms:
self.assertEqual([], _get_scopes(c, a))
class MixedPatternScopingTest(test.TestCase):
def test_graph_linear_scope(self):
r = gf.Flow("root")
r_1 = test_utils.TaskOneReturn("root.1")
r_2 = test_utils.TaskOneReturn("root.2")
r.add(r_1, r_2)
r.link(r_1, r_2)
s = lf.Flow("subroot")
s_1 = test_utils.TaskOneReturn("subroot.1")
s_2 = test_utils.TaskOneReturn("subroot.2")
s.add(s_1, s_2)
r.add(s)
t = gf.Flow("subroot2")
t_1 = test_utils.TaskOneReturn("subroot2.1")
t_2 = test_utils.TaskOneReturn("subroot2.2")
t.add(t_1, t_2)
t.link(t_1, t_2)
r.add(t)
r.link(s, t)
c = compiler.PatternCompiler(r).compile()
self.assertEqual([], _get_scopes(c, r_1))
self.assertEqual([['root.1']], _get_scopes(c, r_2))
self.assertEqual([], _get_scopes(c, s_1))
self.assertEqual([['subroot.1']], _get_scopes(c, s_2))
self.assertEqual([[], ['subroot.2', 'subroot.1']],
_get_scopes(c, t_1))
self.assertEqual([["subroot2.1"], ['subroot.2', 'subroot.1']],
_get_scopes(c, t_2))
def test_linear_unordered_scope(self):
r = lf.Flow("root")
r_1 = test_utils.TaskOneReturn("root.1")
r_2 = test_utils.TaskOneReturn("root.2")
r.add(r_1, r_2)
u = uf.Flow("subroot")
atoms = []
for i in range(0, 5):
atoms.append(test_utils.TaskOneReturn("subroot.%s" % i))
u.add(*atoms)
r.add(u)
r_3 = test_utils.TaskOneReturn("root.3")
r.add(r_3)
c = compiler.PatternCompiler(r).compile()
self.assertEqual([], _get_scopes(c, r_1))
self.assertEqual([['root.1']], _get_scopes(c, r_2))
for a in atoms:
self.assertEqual([[], ['root.2', 'root.1']], _get_scopes(c, a))
scope = _get_scopes(c, r_3)
self.assertEqual(1, len(scope))
first_root = 0
for i, n in enumerate(scope[0]):
if n.startswith('root.'):
first_root = i
break
first_subroot = 0
for i, n in enumerate(scope[0]):
if n.startswith('subroot.'):
first_subroot = i
break
self.assertGreater(first_subroot, first_root)
self.assertEqual(scope[0][-2:], ['root.2', 'root.1'])

View File

@@ -454,23 +454,6 @@ class StorageTestMixin(object):
self.assertRaisesRegexp(exceptions.NotFound,
'^Unable to find result', s.fetch, 'b')
@mock.patch.object(storage.LOG, 'warning')
def test_multiple_providers_are_checked(self, mocked_warning):
s = self._get_storage()
s.ensure_task('my task', result_mapping={'result': 'key'})
self.assertEqual(mocked_warning.mock_calls, [])
s.ensure_task('my other task', result_mapping={'result': 'key'})
mocked_warning.assert_called_once_with(
mock.ANY, 'result')
@mock.patch.object(storage.LOG, 'warning')
def test_multiple_providers_with_inject_are_checked(self, mocked_warning):
s = self._get_storage()
s.inject({'result': 'DONE'})
self.assertEqual(mocked_warning.mock_calls, [])
s.ensure_task('my other task', result_mapping={'result': 'key'})
mocked_warning.assert_called_once_with(mock.ANY, 'result')
def test_ensure_retry(self):
s = self._get_storage()
s.ensure_retry('my retry')

View File

@@ -166,6 +166,11 @@ class Node(object):
for c in self._children:
yield c
def reverse_iter(self):
"""Iterates over the direct children of this node (left->right)."""
for c in reversed(self._children):
yield c
def index(self, item):
"""Finds the child index of a given item, searchs in added order."""
index_at = None

View File

@@ -257,24 +257,6 @@ def sequence_minus(seq1, seq2):
return result
def item_from(container, index, name=None):
"""Attempts to fetch a index/key from a given container."""
if index is None:
return container
try:
return container[index]
except (IndexError, KeyError, ValueError, TypeError):
# NOTE(harlowja): Perhaps the container is a dictionary-like object
# and that key does not exist (key error), or the container is a
# tuple/list and a non-numeric key is being requested (index error),
# or there was no container and an attempt to index into none/other
# unsubscriptable type is being requested (type error).
if name is None:
name = index
raise exc.NotFound("Unable to find %r in container %s"
% (name, container))
def get_duplicate_keys(iterable, key=None):
if key is not None:
iterable = six.moves.map(key, iterable)
@@ -399,8 +381,8 @@ def ensure_tree(path):
"""
try:
os.makedirs(path)
except OSError as exc:
if exc.errno == errno.EEXIST:
except OSError as e:
if e.errno == errno.EEXIST:
if not os.path.isdir(path):
raise
else: