Add a directed graph type (new types module)

Most of the utility graph functions we have can
be connected to a directed graph class that itself
derives (and adds on to) the networkx base class.

Doing this allows for functionality that isn't exposed
in networkx to be exposed in our subclass (which is a
useful pattern to have).

It also makes it possible (if ever needed) to replace
the networkx usage in taskflow with something else if
this ever becomes a major request.

Change-Id: I0a825d5637236d7b5dbdbda0d426adb0183d5ba3
This commit is contained in:
Joshua Harlow 2014-04-12 21:35:16 -07:00
parent 963330242f
commit 5ca61f956e
9 changed files with 176 additions and 163 deletions

View File

@ -17,16 +17,8 @@ The following classes and modules are *recommended* for external usage:
.. autoclass:: taskflow.utils.eventlet_utils.GreenExecutor
:members:
.. autofunction:: taskflow.utils.graph_utils.pformat
.. autofunction:: taskflow.utils.graph_utils.export_graph_to_dot
.. autofunction:: taskflow.utils.persistence_utils.temporary_log_book
.. autofunction:: taskflow.utils.persistence_utils.temporary_flow_detail
.. autofunction:: taskflow.utils.persistence_utils.pformat
.. autofunction:: taskflow.utils.persistence_utils.pformat_flow_detail
.. autofunction:: taskflow.utils.persistence_utils.pformat_atom_detail

View File

@ -16,12 +16,11 @@
import collections
import networkx as nx
from networkx.algorithms import traversal
from taskflow import exceptions as exc
from taskflow import flow
from taskflow.utils import graph_utils
from taskflow.types import graph as gr
class Flow(flow.Flow):
@ -39,7 +38,8 @@ class Flow(flow.Flow):
def __init__(self, name, retry=None):
super(Flow, self).__init__(name, retry)
self._graph = nx.freeze(nx.DiGraph())
self._graph = gr.DiGraph()
self._graph.freeze()
def link(self, u, v):
"""Link existing node u as a runtime dependency of existing node v."""
@ -57,7 +57,7 @@ class Flow(flow.Flow):
mutable_graph = False
# NOTE(harlowja): Add an edge to a temporary copy and only if that
# copy is valid then do we swap with the underlying graph.
attrs = graph_utils.get_edge_attrs(graph, u, v)
attrs = graph.get_edge_data(u, v)
if not attrs:
attrs = {}
if manual:
@ -67,21 +67,22 @@ class Flow(flow.Flow):
attrs['reasons'] = set()
attrs['reasons'].add(reason)
if not mutable_graph:
graph = nx.DiGraph(graph)
graph = gr.DiGraph(graph)
graph.add_edge(u, v, **attrs)
return graph
def _swap(self, replacement_graph):
def _swap(self, graph):
"""Validates the replacement graph and then swaps the underlying graph
with a frozen version of the replacement graph (this maintains the
invariant that the underlying graph is immutable).
"""
if not nx.is_directed_acyclic_graph(replacement_graph):
if not graph.is_directed_acyclic():
raise exc.DependencyFailure("No path through the items in the"
" graph produces an ordering that"
" will allow for correct dependency"
" resolution")
self._graph = nx.freeze(replacement_graph)
self._graph = graph
self._graph.freeze()
def add(self, *items):
"""Adds a given task/tasks/flow/flows to this flow."""
@ -109,7 +110,7 @@ class Flow(flow.Flow):
# NOTE(harlowja): Add items and edges to a temporary copy of the
# underlying graph and only if that is successful added to do we then
# swap with the underlying graph.
tmp_graph = nx.DiGraph(self._graph)
tmp_graph = gr.DiGraph(self._graph)
for item in items:
tmp_graph.add_node(item)
update_requirements(item)
@ -237,5 +238,6 @@ class TargetedFlow(Flow):
nodes = [self._target]
nodes.extend(dst for _src, dst in
traversal.dfs_edges(self._graph.reverse(), self._target))
self._subgraph = nx.freeze(self._graph.subgraph(nodes))
self._subgraph = self._graph.subgraph(nodes)
self._subgraph.freeze()
return self._subgraph

View File

@ -15,7 +15,6 @@
# under the License.
import contextlib
import networkx
import testtools
import threading
@ -36,6 +35,7 @@ from taskflow import states
from taskflow import task
from taskflow import test
from taskflow.tests import utils
from taskflow.types import graph as gr
from taskflow.utils import eventlet_utils as eu
from taskflow.utils import misc
@ -466,7 +466,7 @@ class EngineGraphFlowTest(utils.EngineTestBase):
engine = self._make_engine(flow)
engine.compile()
graph = engine.execution_graph
self.assertIsInstance(graph, networkx.DiGraph)
self.assertIsInstance(graph, gr.DiGraph)
def test_task_graph_property_for_one_task(self):
flow = utils.TaskNoRequiresNoReturns(name='task1')
@ -474,7 +474,7 @@ class EngineGraphFlowTest(utils.EngineTestBase):
engine = self._make_engine(flow)
engine.compile()
graph = engine.execution_graph
self.assertIsInstance(graph, networkx.DiGraph)
self.assertIsInstance(graph, gr.DiGraph)
class EngineCheckingTaskTest(utils.EngineTestBase):

View File

@ -16,8 +16,6 @@
import string
import networkx as nx
from taskflow import exceptions as exc
from taskflow.patterns import graph_flow as gf
from taskflow.patterns import linear_flow as lf
@ -27,7 +25,6 @@ from taskflow import retry
from taskflow import test
from taskflow.tests import utils as t_utils
from taskflow.utils import flow_utils as f_utils
from taskflow.utils import graph_utils as g_utils
def _make_many(amount):
@ -66,13 +63,13 @@ class FlattenTest(test.TestCase):
g = f_utils.flatten(flo)
self.assertEqual(4, len(g))
order = nx.topological_sort(g)
order = g.topological_sort()
self.assertEqual([a, b, c, d], order)
self.assertTrue(g.has_edge(c, d))
self.assertEqual(g.get_edge_data(c, d), {'invariant': True})
self.assertEqual([d], list(g_utils.get_no_successors(g)))
self.assertEqual([a], list(g_utils.get_no_predecessors(g)))
self.assertEqual([d], list(g.no_successors_iter()))
self.assertEqual([a], list(g.no_predecessors_iter()))
def test_invalid_flatten(self):
a, b, c = _make_many(3)
@ -89,9 +86,9 @@ class FlattenTest(test.TestCase):
self.assertEqual(4, len(g))
self.assertEqual(0, g.number_of_edges())
self.assertEqual(set([a, b, c, d]),
set(g_utils.get_no_successors(g)))
set(g.no_successors_iter()))
self.assertEqual(set([a, b, c, d]),
set(g_utils.get_no_predecessors(g)))
set(g.no_predecessors_iter()))
def test_linear_nested_flatten(self):
a, b, c, d = _make_many(4)
@ -206,8 +203,8 @@ class FlattenTest(test.TestCase):
(b, c, {'manual': True}),
(c, d, {'manual': True}),
])
self.assertItemsEqual([a], g_utils.get_no_predecessors(g))
self.assertItemsEqual([d], g_utils.get_no_successors(g))
self.assertItemsEqual([a], g.no_predecessors_iter())
self.assertItemsEqual([d], g.no_successors_iter())
def test_graph_flatten_dependencies(self):
a = t_utils.ProvidesRequiresTask('a', provides=['x'], requires=[])
@ -219,8 +216,8 @@ class FlattenTest(test.TestCase):
self.assertItemsEqual(g.edges(data=True), [
(a, b, {'reasons': set(['x'])})
])
self.assertItemsEqual([a], g_utils.get_no_predecessors(g))
self.assertItemsEqual([b], g_utils.get_no_successors(g))
self.assertItemsEqual([a], g.no_predecessors_iter())
self.assertItemsEqual([b], g.no_successors_iter())
def test_graph_flatten_nested_requires(self):
a = t_utils.ProvidesRequiresTask('a', provides=['x'], requires=[])
@ -237,8 +234,8 @@ class FlattenTest(test.TestCase):
(a, c, {'reasons': set(['x'])}),
(b, c, {'invariant': True})
])
self.assertItemsEqual([a, b], g_utils.get_no_predecessors(g))
self.assertItemsEqual([c], g_utils.get_no_successors(g))
self.assertItemsEqual([a, b], g.no_predecessors_iter())
self.assertItemsEqual([c], g.no_successors_iter())
def test_graph_flatten_nested_provides(self):
a = t_utils.ProvidesRequiresTask('a', provides=[], requires=['x'])
@ -255,8 +252,8 @@ class FlattenTest(test.TestCase):
(b, c, {'invariant': True}),
(b, a, {'reasons': set(['x'])})
])
self.assertItemsEqual([b], g_utils.get_no_predecessors(g))
self.assertItemsEqual([a, c], g_utils.get_no_successors(g))
self.assertItemsEqual([b], g.no_predecessors_iter())
self.assertItemsEqual([a, c], g.no_successors_iter())
def test_flatten_checks_for_dups(self):
flo = gf.Flow("test").add(
@ -304,8 +301,8 @@ class FlattenTest(test.TestCase):
(c1, c2, {'retry': True})
])
self.assertIs(c1, g.node[c2]['retry'])
self.assertItemsEqual([c1], g_utils.get_no_predecessors(g))
self.assertItemsEqual([c2], g_utils.get_no_successors(g))
self.assertItemsEqual([c1], g.no_predecessors_iter())
self.assertItemsEqual([c2], g.no_successors_iter())
def test_flatten_retry_in_linear_flow_with_tasks(self):
c = retry.AlwaysRevert("c")
@ -318,8 +315,8 @@ class FlattenTest(test.TestCase):
(c, a, {'retry': True})
])
self.assertItemsEqual([c], g_utils.get_no_predecessors(g))
self.assertItemsEqual([b], g_utils.get_no_successors(g))
self.assertItemsEqual([c], g.no_predecessors_iter())
self.assertItemsEqual([b], g.no_successors_iter())
self.assertIs(c, g.node[a]['retry'])
self.assertIs(c, g.node[b]['retry'])
@ -334,8 +331,8 @@ class FlattenTest(test.TestCase):
(c, b, {'retry': True})
])
self.assertItemsEqual([c], g_utils.get_no_predecessors(g))
self.assertItemsEqual([a, b], g_utils.get_no_successors(g))
self.assertItemsEqual([c], g.no_predecessors_iter())
self.assertItemsEqual([a, b], g.no_successors_iter())
self.assertIs(c, g.node[a]['retry'])
self.assertIs(c, g.node[b]['retry'])
@ -352,8 +349,8 @@ class FlattenTest(test.TestCase):
(b, c, {'manual': True})
])
self.assertItemsEqual([r], g_utils.get_no_predecessors(g))
self.assertItemsEqual([a, c], g_utils.get_no_successors(g))
self.assertItemsEqual([r], g.no_predecessors_iter())
self.assertItemsEqual([a, c], g.no_successors_iter())
self.assertIs(r, g.node[a]['retry'])
self.assertIs(r, g.node[b]['retry'])
self.assertIs(r, g.node[c]['retry'])

View File

122
taskflow/types/graph.py Normal file
View File

@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import networkx as nx
import six
class DiGraph(nx.DiGraph):
"""A directed graph subclass with useful utility functions."""
def __init__(self, data=None, name=''):
super(DiGraph, self).__init__(name=name, data=data)
self.frozen = False
def freeze(self):
"""Freezes the graph so that no more mutations can occur."""
if not self.frozen:
nx.freeze(self)
return self
def get_edge_data(self, u, v, default=None):
"""Returns a *copy* of the attribute dictionary associated with edges
between (u, v).
NOTE(harlowja): this differs from the networkx get_edge_data() as that
function does not return a copy (but returns a reference to the actual
edge data).
"""
try:
return dict(self.adj[u][v])
except KeyError:
return default
def topological_sort(self):
"""Return a list of nodes in this graph in topological sort order."""
return nx.topological_sort(self)
def pformat(self):
"""Pretty formats your graph into a string representation that includes
details about your graph, including; name, type, frozeness, node count,
nodes, edge count, edges, graph density and graph cycles (if any).
"""
lines = []
lines.append("Name: %s" % self.name)
lines.append("Type: %s" % type(self).__name__)
lines.append("Frozen: %s" % nx.is_frozen(self))
lines.append("Nodes: %s" % self.number_of_nodes())
for n in self.nodes_iter():
lines.append(" - %s" % n)
lines.append("Edges: %s" % self.number_of_edges())
for (u, v, e_data) in self.edges_iter(data=True):
if e_data:
lines.append(" %s -> %s (%s)" % (u, v, e_data))
else:
lines.append(" %s -> %s" % (u, v))
lines.append("Density: %0.3f" % nx.density(self))
cycles = list(nx.cycles.recursive_simple_cycles(self))
lines.append("Cycles: %s" % len(cycles))
for cycle in cycles:
buf = six.StringIO()
buf.write("%s" % (cycle[0]))
for i in range(1, len(cycle)):
buf.write(" --> %s" % (cycle[i]))
buf.write(" --> %s" % (cycle[0]))
lines.append(" %s" % buf.getvalue())
return "\n".join(lines)
def export_to_dot(self):
"""Exports the graph to a dot format (requires pydot library)."""
return nx.to_pydot(self).to_string()
def is_directed_acyclic(self):
"""Returns if this graph is a DAG or not."""
return nx.is_directed_acyclic_graph(self)
def no_successors_iter(self):
"""Returns an iterator for all nodes with no successors."""
for n in self.nodes_iter():
if not len(self.successors(n)):
yield n
def no_predecessors_iter(self):
"""Returns an iterator for all nodes with no predecessors."""
for n in self.nodes_iter():
if not len(self.predecessors(n)):
yield n
def merge_graphs(graphs, allow_overlaps=False):
"""Merges a bunch of graphs into a single graph."""
if not graphs:
return None
graph = graphs[0]
for g in graphs[1:]:
# This should ensure that the nodes to be merged do not already exist
# in the graph that is to be merged into. This could be problematic if
# there are duplicates.
if not allow_overlaps:
# Attempt to induce a subgraph using the to be merged graphs nodes
# and see if any graph results.
overlaps = graph.subgraph(g.nodes_iter())
if len(overlaps):
raise ValueError("Can not merge graph %s into %s since there "
"are %s overlapping nodes" (g, graph,
len(overlaps)))
# Keep the target graphs name.
name = graph.name
graph = nx.algorithms.compose(graph, g)
graph.name = name
return graph

View File

@ -16,13 +16,11 @@
import logging
import networkx as nx
from taskflow import exceptions
from taskflow import flow
from taskflow import retry
from taskflow import task
from taskflow.utils import graph_utils as gu
from taskflow.types import graph as gr
from taskflow.utils import misc
@ -80,7 +78,7 @@ class Flattener(object):
graph.add_node(retry)
# All graph nodes that have no predecessors should depend on its retry
nodes_to = [n for n in gu.get_no_predecessors(graph) if n != retry]
nodes_to = [n for n in graph.no_predecessors_iter() if n != retry]
self._add_new_edges(graph, [retry], nodes_to, RETRY_EDGE_DATA)
# Add link to retry for each node of subgraph that hasn't
@ -91,34 +89,37 @@ class Flattener(object):
def _flatten_task(self, task):
"""Flattens a individual task."""
graph = nx.DiGraph(name=task.name)
graph = gr.DiGraph(name=task.name)
graph.add_node(task)
return graph
def _flatten_flow(self, flow):
"""Flattens a graph flow."""
graph = nx.DiGraph(name=flow.name)
graph = gr.DiGraph(name=flow.name)
# Flatten all nodes into a single subgraph per node.
subgraph_map = {}
for item in flow:
subgraph = self._flatten(item)
subgraph_map[item] = subgraph
graph = gu.merge_graphs([graph, subgraph])
graph = gr.merge_graphs([graph, subgraph])
# Reconnect all node edges to their corresponding subgraphs.
for (u, v, attrs) in flow.iter_links():
u_g = subgraph_map[u]
v_g = subgraph_map[v]
if any(attrs.get(k) for k in ('invariant', 'manual', 'retry')):
# Connect nodes with no predecessors in v to nodes with
# no successors in u (thus maintaining the edge dependency).
self._add_new_edges(graph,
gu.get_no_successors(subgraph_map[u]),
gu.get_no_predecessors(subgraph_map[v]),
u_g.no_successors_iter(),
v_g.no_predecessors_iter(),
edge_attrs=attrs)
else:
# This is dependency-only edge, connect corresponding
# providers and consumers.
for provider in subgraph_map[u]:
for consumer in subgraph_map[v]:
for provider in u_g:
for consumer in v_g:
reasons = provider.provides & consumer.requires
if reasons:
graph.add_edge(provider, consumer, reasons=reasons)
@ -143,7 +144,7 @@ class Flattener(object):
# and not under all cases.
if LOG.isEnabledFor(logging.DEBUG):
LOG.debug("Translated '%s' into a graph:", item)
for line in gu.pformat(graph).splitlines():
for line in graph.pformat().splitlines():
# Indent it so that it's slightly offset from the above line.
LOG.debug(" %s", line)
@ -168,10 +169,9 @@ class Flattener(object):
self._pre_flatten()
graph = self._flatten(self._root)
self._post_flatten(graph)
if self._freeze:
self._graph = nx.freeze(graph)
else:
self._graph = graph
if self._freeze:
self._graph.freeze()
return self._graph

View File

@ -1,98 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import networkx as nx
import six
def get_edge_attrs(graph, u, v):
"""Gets the dictionary of edge attributes between u->v (or none)."""
if not graph.has_edge(u, v):
return None
return dict(graph.adj[u][v])
def merge_graphs(graphs, allow_overlaps=False):
if not graphs:
return None
graph = graphs[0]
for g in graphs[1:]:
# This should ensure that the nodes to be merged do not already exist
# in the graph that is to be merged into. This could be problematic if
# there are duplicates.
if not allow_overlaps:
# Attempt to induce a subgraph using the to be merged graphs nodes
# and see if any graph results.
overlaps = graph.subgraph(g.nodes_iter())
if len(overlaps):
raise ValueError("Can not merge graph %s into %s since there "
"are %s overlapping nodes" (g, graph,
len(overlaps)))
# Keep the target graphs name.
name = graph.name
graph = nx.algorithms.compose(graph, g)
graph.name = name
return graph
def get_no_successors(graph):
"""Returns an iterator for all nodes with no successors."""
for n in graph.nodes_iter():
if not len(graph.successors(n)):
yield n
def get_no_predecessors(graph):
"""Returns an iterator for all nodes with no predecessors."""
for n in graph.nodes_iter():
if not len(graph.predecessors(n)):
yield n
def pformat(graph):
"""Pretty formats your graph into a string representation that includes
details about your graph, including; name, type, frozeness, node count,
nodes, edge count, edges, graph density and graph cycles (if any).
"""
lines = []
lines.append("Name: %s" % graph.name)
lines.append("Type: %s" % type(graph).__name__)
lines.append("Frozen: %s" % nx.is_frozen(graph))
lines.append("Nodes: %s" % graph.number_of_nodes())
for n in graph.nodes_iter():
lines.append(" - %s" % n)
lines.append("Edges: %s" % graph.number_of_edges())
for (u, v, e_data) in graph.edges_iter(data=True):
if e_data:
lines.append(" %s -> %s (%s)" % (u, v, e_data))
else:
lines.append(" %s -> %s" % (u, v))
lines.append("Density: %0.3f" % nx.density(graph))
cycles = list(nx.cycles.recursive_simple_cycles(graph))
lines.append("Cycles: %s" % len(cycles))
for cycle in cycles:
buf = six.StringIO()
buf.write(str(cycle[0]))
for i in range(1, len(cycle)):
buf.write(" --> %s" % (cycle[i]))
buf.write(" --> %s" % (cycle[0]))
lines.append(" %s" % buf.getvalue())
return "\n".join(lines)
def export_graph_to_dot(graph):
"""Exports the graph to a dot format (requires pydot library)."""
return nx.to_pydot(graph).to_string()

View File

@ -11,10 +11,8 @@ import optparse
import subprocess
import tempfile
import networkx as nx
from taskflow import states
from taskflow.utils import graph_utils as gu
from taskflow.types import graph as gr
def mini_exec(cmd, ok_codes=(0,)):
@ -31,7 +29,7 @@ def mini_exec(cmd, ok_codes=(0,)):
def make_svg(graph, output_filename, output_format):
# NOTE(harlowja): requires pydot!
gdot = gu.export_graph_to_dot(graph)
gdot = graph.export_to_dot()
if output_format == 'dot':
output = gdot
elif output_format in ('svg', 'svgz', 'png'):
@ -62,7 +60,7 @@ def main():
if options.filename is None:
options.filename = 'states.%s' % options.format
g = nx.DiGraph(name="State transitions")
g = gr.DiGraph(name="State transitions")
if not options.tasks:
source = states._ALLOWED_FLOW_TRANSITIONS
else: