Merge "Add a directed graph type (new types module)"

This commit is contained in:
Jenkins
2014-04-22 01:23:27 +00:00
committed by Gerrit Code Review
9 changed files with 176 additions and 163 deletions

View File

@@ -17,16 +17,8 @@ The following classes and modules are *recommended* for external usage:
.. autoclass:: taskflow.utils.eventlet_utils.GreenExecutor
:members:
.. autofunction:: taskflow.utils.graph_utils.pformat
.. autofunction:: taskflow.utils.graph_utils.export_graph_to_dot
.. autofunction:: taskflow.utils.persistence_utils.temporary_log_book
.. autofunction:: taskflow.utils.persistence_utils.temporary_flow_detail
.. autofunction:: taskflow.utils.persistence_utils.pformat
.. autofunction:: taskflow.utils.persistence_utils.pformat_flow_detail
.. autofunction:: taskflow.utils.persistence_utils.pformat_atom_detail

View File

@@ -16,12 +16,11 @@
import collections
import networkx as nx
from networkx.algorithms import traversal
from taskflow import exceptions as exc
from taskflow import flow
from taskflow.utils import graph_utils
from taskflow.types import graph as gr
class Flow(flow.Flow):
@@ -39,7 +38,8 @@ class Flow(flow.Flow):
def __init__(self, name, retry=None):
super(Flow, self).__init__(name, retry)
self._graph = nx.freeze(nx.DiGraph())
self._graph = gr.DiGraph()
self._graph.freeze()
def link(self, u, v):
"""Link existing node u as a runtime dependency of existing node v."""
@@ -57,7 +57,7 @@ class Flow(flow.Flow):
mutable_graph = False
# NOTE(harlowja): Add an edge to a temporary copy and only if that
# copy is valid then do we swap with the underlying graph.
attrs = graph_utils.get_edge_attrs(graph, u, v)
attrs = graph.get_edge_data(u, v)
if not attrs:
attrs = {}
if manual:
@@ -67,21 +67,22 @@ class Flow(flow.Flow):
attrs['reasons'] = set()
attrs['reasons'].add(reason)
if not mutable_graph:
graph = nx.DiGraph(graph)
graph = gr.DiGraph(graph)
graph.add_edge(u, v, **attrs)
return graph
def _swap(self, replacement_graph):
def _swap(self, graph):
"""Validates the replacement graph and then swaps the underlying graph
with a frozen version of the replacement graph (this maintains the
invariant that the underlying graph is immutable).
"""
if not nx.is_directed_acyclic_graph(replacement_graph):
if not graph.is_directed_acyclic():
raise exc.DependencyFailure("No path through the items in the"
" graph produces an ordering that"
" will allow for correct dependency"
" resolution")
self._graph = nx.freeze(replacement_graph)
self._graph = graph
self._graph.freeze()
def add(self, *items):
"""Adds a given task/tasks/flow/flows to this flow."""
@@ -109,7 +110,7 @@ class Flow(flow.Flow):
# NOTE(harlowja): Add items and edges to a temporary copy of the
# underlying graph and only if that is successful added to do we then
# swap with the underlying graph.
tmp_graph = nx.DiGraph(self._graph)
tmp_graph = gr.DiGraph(self._graph)
for item in items:
tmp_graph.add_node(item)
update_requirements(item)
@@ -237,5 +238,6 @@ class TargetedFlow(Flow):
nodes = [self._target]
nodes.extend(dst for _src, dst in
traversal.dfs_edges(self._graph.reverse(), self._target))
self._subgraph = nx.freeze(self._graph.subgraph(nodes))
self._subgraph = self._graph.subgraph(nodes)
self._subgraph.freeze()
return self._subgraph

View File

@@ -15,7 +15,6 @@
# under the License.
import contextlib
import networkx
import testtools
import threading
@@ -36,6 +35,7 @@ from taskflow import states
from taskflow import task
from taskflow import test
from taskflow.tests import utils
from taskflow.types import graph as gr
from taskflow.utils import eventlet_utils as eu
from taskflow.utils import misc
@@ -466,7 +466,7 @@ class EngineGraphFlowTest(utils.EngineTestBase):
engine = self._make_engine(flow)
engine.compile()
graph = engine.execution_graph
self.assertIsInstance(graph, networkx.DiGraph)
self.assertIsInstance(graph, gr.DiGraph)
def test_task_graph_property_for_one_task(self):
flow = utils.TaskNoRequiresNoReturns(name='task1')
@@ -474,7 +474,7 @@ class EngineGraphFlowTest(utils.EngineTestBase):
engine = self._make_engine(flow)
engine.compile()
graph = engine.execution_graph
self.assertIsInstance(graph, networkx.DiGraph)
self.assertIsInstance(graph, gr.DiGraph)
class EngineCheckingTaskTest(utils.EngineTestBase):

View File

@@ -16,8 +16,6 @@
import string
import networkx as nx
from taskflow import exceptions as exc
from taskflow.patterns import graph_flow as gf
from taskflow.patterns import linear_flow as lf
@@ -27,7 +25,6 @@ from taskflow import retry
from taskflow import test
from taskflow.tests import utils as t_utils
from taskflow.utils import flow_utils as f_utils
from taskflow.utils import graph_utils as g_utils
def _make_many(amount):
@@ -66,13 +63,13 @@ class FlattenTest(test.TestCase):
g = f_utils.flatten(flo)
self.assertEqual(4, len(g))
order = nx.topological_sort(g)
order = g.topological_sort()
self.assertEqual([a, b, c, d], order)
self.assertTrue(g.has_edge(c, d))
self.assertEqual(g.get_edge_data(c, d), {'invariant': True})
self.assertEqual([d], list(g_utils.get_no_successors(g)))
self.assertEqual([a], list(g_utils.get_no_predecessors(g)))
self.assertEqual([d], list(g.no_successors_iter()))
self.assertEqual([a], list(g.no_predecessors_iter()))
def test_invalid_flatten(self):
a, b, c = _make_many(3)
@@ -89,9 +86,9 @@ class FlattenTest(test.TestCase):
self.assertEqual(4, len(g))
self.assertEqual(0, g.number_of_edges())
self.assertEqual(set([a, b, c, d]),
set(g_utils.get_no_successors(g)))
set(g.no_successors_iter()))
self.assertEqual(set([a, b, c, d]),
set(g_utils.get_no_predecessors(g)))
set(g.no_predecessors_iter()))
def test_linear_nested_flatten(self):
a, b, c, d = _make_many(4)
@@ -206,8 +203,8 @@ class FlattenTest(test.TestCase):
(b, c, {'manual': True}),
(c, d, {'manual': True}),
])
self.assertItemsEqual([a], g_utils.get_no_predecessors(g))
self.assertItemsEqual([d], g_utils.get_no_successors(g))
self.assertItemsEqual([a], g.no_predecessors_iter())
self.assertItemsEqual([d], g.no_successors_iter())
def test_graph_flatten_dependencies(self):
a = t_utils.ProvidesRequiresTask('a', provides=['x'], requires=[])
@@ -219,8 +216,8 @@ class FlattenTest(test.TestCase):
self.assertItemsEqual(g.edges(data=True), [
(a, b, {'reasons': set(['x'])})
])
self.assertItemsEqual([a], g_utils.get_no_predecessors(g))
self.assertItemsEqual([b], g_utils.get_no_successors(g))
self.assertItemsEqual([a], g.no_predecessors_iter())
self.assertItemsEqual([b], g.no_successors_iter())
def test_graph_flatten_nested_requires(self):
a = t_utils.ProvidesRequiresTask('a', provides=['x'], requires=[])
@@ -237,8 +234,8 @@ class FlattenTest(test.TestCase):
(a, c, {'reasons': set(['x'])}),
(b, c, {'invariant': True})
])
self.assertItemsEqual([a, b], g_utils.get_no_predecessors(g))
self.assertItemsEqual([c], g_utils.get_no_successors(g))
self.assertItemsEqual([a, b], g.no_predecessors_iter())
self.assertItemsEqual([c], g.no_successors_iter())
def test_graph_flatten_nested_provides(self):
a = t_utils.ProvidesRequiresTask('a', provides=[], requires=['x'])
@@ -255,8 +252,8 @@ class FlattenTest(test.TestCase):
(b, c, {'invariant': True}),
(b, a, {'reasons': set(['x'])})
])
self.assertItemsEqual([b], g_utils.get_no_predecessors(g))
self.assertItemsEqual([a, c], g_utils.get_no_successors(g))
self.assertItemsEqual([b], g.no_predecessors_iter())
self.assertItemsEqual([a, c], g.no_successors_iter())
def test_flatten_checks_for_dups(self):
flo = gf.Flow("test").add(
@@ -304,8 +301,8 @@ class FlattenTest(test.TestCase):
(c1, c2, {'retry': True})
])
self.assertIs(c1, g.node[c2]['retry'])
self.assertItemsEqual([c1], g_utils.get_no_predecessors(g))
self.assertItemsEqual([c2], g_utils.get_no_successors(g))
self.assertItemsEqual([c1], g.no_predecessors_iter())
self.assertItemsEqual([c2], g.no_successors_iter())
def test_flatten_retry_in_linear_flow_with_tasks(self):
c = retry.AlwaysRevert("c")
@@ -318,8 +315,8 @@ class FlattenTest(test.TestCase):
(c, a, {'retry': True})
])
self.assertItemsEqual([c], g_utils.get_no_predecessors(g))
self.assertItemsEqual([b], g_utils.get_no_successors(g))
self.assertItemsEqual([c], g.no_predecessors_iter())
self.assertItemsEqual([b], g.no_successors_iter())
self.assertIs(c, g.node[a]['retry'])
self.assertIs(c, g.node[b]['retry'])
@@ -334,8 +331,8 @@ class FlattenTest(test.TestCase):
(c, b, {'retry': True})
])
self.assertItemsEqual([c], g_utils.get_no_predecessors(g))
self.assertItemsEqual([a, b], g_utils.get_no_successors(g))
self.assertItemsEqual([c], g.no_predecessors_iter())
self.assertItemsEqual([a, b], g.no_successors_iter())
self.assertIs(c, g.node[a]['retry'])
self.assertIs(c, g.node[b]['retry'])
@@ -352,8 +349,8 @@ class FlattenTest(test.TestCase):
(b, c, {'manual': True})
])
self.assertItemsEqual([r], g_utils.get_no_predecessors(g))
self.assertItemsEqual([a, c], g_utils.get_no_successors(g))
self.assertItemsEqual([r], g.no_predecessors_iter())
self.assertItemsEqual([a, c], g.no_successors_iter())
self.assertIs(r, g.node[a]['retry'])
self.assertIs(r, g.node[b]['retry'])
self.assertIs(r, g.node[c]['retry'])

View File

122
taskflow/types/graph.py Normal file
View File

@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import networkx as nx
import six
class DiGraph(nx.DiGraph):
"""A directed graph subclass with useful utility functions."""
def __init__(self, data=None, name=''):
super(DiGraph, self).__init__(name=name, data=data)
self.frozen = False
def freeze(self):
"""Freezes the graph so that no more mutations can occur."""
if not self.frozen:
nx.freeze(self)
return self
def get_edge_data(self, u, v, default=None):
"""Returns a *copy* of the attribute dictionary associated with edges
between (u, v).
NOTE(harlowja): this differs from the networkx get_edge_data() as that
function does not return a copy (but returns a reference to the actual
edge data).
"""
try:
return dict(self.adj[u][v])
except KeyError:
return default
def topological_sort(self):
"""Return a list of nodes in this graph in topological sort order."""
return nx.topological_sort(self)
def pformat(self):
"""Pretty formats your graph into a string representation that includes
details about your graph, including; name, type, frozeness, node count,
nodes, edge count, edges, graph density and graph cycles (if any).
"""
lines = []
lines.append("Name: %s" % self.name)
lines.append("Type: %s" % type(self).__name__)
lines.append("Frozen: %s" % nx.is_frozen(self))
lines.append("Nodes: %s" % self.number_of_nodes())
for n in self.nodes_iter():
lines.append(" - %s" % n)
lines.append("Edges: %s" % self.number_of_edges())
for (u, v, e_data) in self.edges_iter(data=True):
if e_data:
lines.append(" %s -> %s (%s)" % (u, v, e_data))
else:
lines.append(" %s -> %s" % (u, v))
lines.append("Density: %0.3f" % nx.density(self))
cycles = list(nx.cycles.recursive_simple_cycles(self))
lines.append("Cycles: %s" % len(cycles))
for cycle in cycles:
buf = six.StringIO()
buf.write("%s" % (cycle[0]))
for i in range(1, len(cycle)):
buf.write(" --> %s" % (cycle[i]))
buf.write(" --> %s" % (cycle[0]))
lines.append(" %s" % buf.getvalue())
return "\n".join(lines)
def export_to_dot(self):
"""Exports the graph to a dot format (requires pydot library)."""
return nx.to_pydot(self).to_string()
def is_directed_acyclic(self):
"""Returns if this graph is a DAG or not."""
return nx.is_directed_acyclic_graph(self)
def no_successors_iter(self):
"""Returns an iterator for all nodes with no successors."""
for n in self.nodes_iter():
if not len(self.successors(n)):
yield n
def no_predecessors_iter(self):
"""Returns an iterator for all nodes with no predecessors."""
for n in self.nodes_iter():
if not len(self.predecessors(n)):
yield n
def merge_graphs(graphs, allow_overlaps=False):
"""Merges a bunch of graphs into a single graph."""
if not graphs:
return None
graph = graphs[0]
for g in graphs[1:]:
# This should ensure that the nodes to be merged do not already exist
# in the graph that is to be merged into. This could be problematic if
# there are duplicates.
if not allow_overlaps:
# Attempt to induce a subgraph using the to be merged graphs nodes
# and see if any graph results.
overlaps = graph.subgraph(g.nodes_iter())
if len(overlaps):
raise ValueError("Can not merge graph %s into %s since there "
"are %s overlapping nodes" (g, graph,
len(overlaps)))
# Keep the target graphs name.
name = graph.name
graph = nx.algorithms.compose(graph, g)
graph.name = name
return graph

View File

@@ -16,13 +16,11 @@
import logging
import networkx as nx
from taskflow import exceptions
from taskflow import flow
from taskflow import retry
from taskflow import task
from taskflow.utils import graph_utils as gu
from taskflow.types import graph as gr
from taskflow.utils import misc
@@ -80,7 +78,7 @@ class Flattener(object):
graph.add_node(retry)
# All graph nodes that have no predecessors should depend on its retry
nodes_to = [n for n in gu.get_no_predecessors(graph) if n != retry]
nodes_to = [n for n in graph.no_predecessors_iter() if n != retry]
self._add_new_edges(graph, [retry], nodes_to, RETRY_EDGE_DATA)
# Add link to retry for each node of subgraph that hasn't
@@ -91,34 +89,37 @@ class Flattener(object):
def _flatten_task(self, task):
"""Flattens a individual task."""
graph = nx.DiGraph(name=task.name)
graph = gr.DiGraph(name=task.name)
graph.add_node(task)
return graph
def _flatten_flow(self, flow):
"""Flattens a graph flow."""
graph = nx.DiGraph(name=flow.name)
graph = gr.DiGraph(name=flow.name)
# Flatten all nodes into a single subgraph per node.
subgraph_map = {}
for item in flow:
subgraph = self._flatten(item)
subgraph_map[item] = subgraph
graph = gu.merge_graphs([graph, subgraph])
graph = gr.merge_graphs([graph, subgraph])
# Reconnect all node edges to their corresponding subgraphs.
for (u, v, attrs) in flow.iter_links():
u_g = subgraph_map[u]
v_g = subgraph_map[v]
if any(attrs.get(k) for k in ('invariant', 'manual', 'retry')):
# Connect nodes with no predecessors in v to nodes with
# no successors in u (thus maintaining the edge dependency).
self._add_new_edges(graph,
gu.get_no_successors(subgraph_map[u]),
gu.get_no_predecessors(subgraph_map[v]),
u_g.no_successors_iter(),
v_g.no_predecessors_iter(),
edge_attrs=attrs)
else:
# This is dependency-only edge, connect corresponding
# providers and consumers.
for provider in subgraph_map[u]:
for consumer in subgraph_map[v]:
for provider in u_g:
for consumer in v_g:
reasons = provider.provides & consumer.requires
if reasons:
graph.add_edge(provider, consumer, reasons=reasons)
@@ -143,7 +144,7 @@ class Flattener(object):
# and not under all cases.
if LOG.isEnabledFor(logging.DEBUG):
LOG.debug("Translated '%s' into a graph:", item)
for line in gu.pformat(graph).splitlines():
for line in graph.pformat().splitlines():
# Indent it so that it's slightly offset from the above line.
LOG.debug(" %s", line)
@@ -168,10 +169,9 @@ class Flattener(object):
self._pre_flatten()
graph = self._flatten(self._root)
self._post_flatten(graph)
self._graph = graph
if self._freeze:
self._graph = nx.freeze(graph)
else:
self._graph = graph
self._graph.freeze()
return self._graph

View File

@@ -1,98 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import networkx as nx
import six
def get_edge_attrs(graph, u, v):
"""Gets the dictionary of edge attributes between u->v (or none)."""
if not graph.has_edge(u, v):
return None
return dict(graph.adj[u][v])
def merge_graphs(graphs, allow_overlaps=False):
if not graphs:
return None
graph = graphs[0]
for g in graphs[1:]:
# This should ensure that the nodes to be merged do not already exist
# in the graph that is to be merged into. This could be problematic if
# there are duplicates.
if not allow_overlaps:
# Attempt to induce a subgraph using the to be merged graphs nodes
# and see if any graph results.
overlaps = graph.subgraph(g.nodes_iter())
if len(overlaps):
raise ValueError("Can not merge graph %s into %s since there "
"are %s overlapping nodes" (g, graph,
len(overlaps)))
# Keep the target graphs name.
name = graph.name
graph = nx.algorithms.compose(graph, g)
graph.name = name
return graph
def get_no_successors(graph):
"""Returns an iterator for all nodes with no successors."""
for n in graph.nodes_iter():
if not len(graph.successors(n)):
yield n
def get_no_predecessors(graph):
"""Returns an iterator for all nodes with no predecessors."""
for n in graph.nodes_iter():
if not len(graph.predecessors(n)):
yield n
def pformat(graph):
"""Pretty formats your graph into a string representation that includes
details about your graph, including; name, type, frozeness, node count,
nodes, edge count, edges, graph density and graph cycles (if any).
"""
lines = []
lines.append("Name: %s" % graph.name)
lines.append("Type: %s" % type(graph).__name__)
lines.append("Frozen: %s" % nx.is_frozen(graph))
lines.append("Nodes: %s" % graph.number_of_nodes())
for n in graph.nodes_iter():
lines.append(" - %s" % n)
lines.append("Edges: %s" % graph.number_of_edges())
for (u, v, e_data) in graph.edges_iter(data=True):
if e_data:
lines.append(" %s -> %s (%s)" % (u, v, e_data))
else:
lines.append(" %s -> %s" % (u, v))
lines.append("Density: %0.3f" % nx.density(graph))
cycles = list(nx.cycles.recursive_simple_cycles(graph))
lines.append("Cycles: %s" % len(cycles))
for cycle in cycles:
buf = six.StringIO()
buf.write(str(cycle[0]))
for i in range(1, len(cycle)):
buf.write(" --> %s" % (cycle[i]))
buf.write(" --> %s" % (cycle[0]))
lines.append(" %s" % buf.getvalue())
return "\n".join(lines)
def export_graph_to_dot(graph):
"""Exports the graph to a dot format (requires pydot library)."""
return nx.to_pydot(graph).to_string()

View File

@@ -11,10 +11,8 @@ import optparse
import subprocess
import tempfile
import networkx as nx
from taskflow import states
from taskflow.utils import graph_utils as gu
from taskflow.types import graph as gr
def mini_exec(cmd, ok_codes=(0,)):
@@ -31,7 +29,7 @@ def mini_exec(cmd, ok_codes=(0,)):
def make_svg(graph, output_filename, output_format):
# NOTE(harlowja): requires pydot!
gdot = gu.export_graph_to_dot(graph)
gdot = graph.export_to_dot()
if output_format == 'dot':
output = gdot
elif output_format in ('svg', 'svgz', 'png'):
@@ -62,7 +60,7 @@ def main():
if options.filename is None:
options.filename = 'states.%s' % options.format
g = nx.DiGraph(name="State transitions")
g = gr.DiGraph(name="State transitions")
if not options.tasks:
source = states._ALLOWED_FLOW_TRANSITIONS
else: