From 5ca61f956e588d8a088d78bc9a58b2904acfb416 Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Sat, 12 Apr 2014 21:35:16 -0700 Subject: [PATCH] Add a directed graph type (new types module) Most of the utility graph functions we have can be connected to a directed graph class that itself derives (and adds on to) the networkx base class. Doing this allows for functionality that isn't exposed in networkx to be exposed in our subclass (which is a useful pattern to have). It also makes it possible (if ever needed) to replace the networkx usage in taskflow with something else if this ever becomes a major request. Change-Id: I0a825d5637236d7b5dbdbda0d426adb0183d5ba3 --- doc/source/utils.rst | 8 -- taskflow/patterns/graph_flow.py | 22 ++-- taskflow/tests/unit/test_action_engine.py | 6 +- taskflow/tests/unit/test_flattening.py | 45 ++++---- taskflow/types/__init__.py | 0 taskflow/types/graph.py | 122 ++++++++++++++++++++++ taskflow/utils/flow_utils.py | 30 +++--- taskflow/utils/graph_utils.py | 98 ----------------- tools/state_graph.py | 8 +- 9 files changed, 176 insertions(+), 163 deletions(-) create mode 100644 taskflow/types/__init__.py create mode 100644 taskflow/types/graph.py delete mode 100644 taskflow/utils/graph_utils.py diff --git a/doc/source/utils.rst b/doc/source/utils.rst index b0bd38156..b847e07f8 100644 --- a/doc/source/utils.rst +++ b/doc/source/utils.rst @@ -17,16 +17,8 @@ The following classes and modules are *recommended* for external usage: .. autoclass:: taskflow.utils.eventlet_utils.GreenExecutor :members: -.. autofunction:: taskflow.utils.graph_utils.pformat - -.. autofunction:: taskflow.utils.graph_utils.export_graph_to_dot - .. autofunction:: taskflow.utils.persistence_utils.temporary_log_book .. autofunction:: taskflow.utils.persistence_utils.temporary_flow_detail .. autofunction:: taskflow.utils.persistence_utils.pformat - -.. autofunction:: taskflow.utils.persistence_utils.pformat_flow_detail - -.. 
autofunction:: taskflow.utils.persistence_utils.pformat_atom_detail diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index ccbcf413e..68691996d 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -16,12 +16,11 @@ import collections -import networkx as nx from networkx.algorithms import traversal from taskflow import exceptions as exc from taskflow import flow -from taskflow.utils import graph_utils +from taskflow.types import graph as gr class Flow(flow.Flow): @@ -39,7 +38,8 @@ class Flow(flow.Flow): def __init__(self, name, retry=None): super(Flow, self).__init__(name, retry) - self._graph = nx.freeze(nx.DiGraph()) + self._graph = gr.DiGraph() + self._graph.freeze() def link(self, u, v): """Link existing node u as a runtime dependency of existing node v.""" @@ -57,7 +57,7 @@ class Flow(flow.Flow): mutable_graph = False # NOTE(harlowja): Add an edge to a temporary copy and only if that # copy is valid then do we swap with the underlying graph. - attrs = graph_utils.get_edge_attrs(graph, u, v) + attrs = graph.get_edge_data(u, v) if not attrs: attrs = {} if manual: @@ -67,21 +67,22 @@ class Flow(flow.Flow): attrs['reasons'] = set() attrs['reasons'].add(reason) if not mutable_graph: - graph = nx.DiGraph(graph) + graph = gr.DiGraph(graph) graph.add_edge(u, v, **attrs) return graph - def _swap(self, replacement_graph): + def _swap(self, graph): """Validates the replacement graph and then swaps the underlying graph with a frozen version of the replacement graph (this maintains the invariant that the underlying graph is immutable). 
""" - if not nx.is_directed_acyclic_graph(replacement_graph): + if not graph.is_directed_acyclic(): raise exc.DependencyFailure("No path through the items in the" " graph produces an ordering that" " will allow for correct dependency" " resolution") - self._graph = nx.freeze(replacement_graph) + self._graph = graph + self._graph.freeze() def add(self, *items): """Adds a given task/tasks/flow/flows to this flow.""" @@ -109,7 +110,7 @@ class Flow(flow.Flow): # NOTE(harlowja): Add items and edges to a temporary copy of the # underlying graph and only if that is successful added to do we then # swap with the underlying graph. - tmp_graph = nx.DiGraph(self._graph) + tmp_graph = gr.DiGraph(self._graph) for item in items: tmp_graph.add_node(item) update_requirements(item) @@ -237,5 +238,6 @@ class TargetedFlow(Flow): nodes = [self._target] nodes.extend(dst for _src, dst in traversal.dfs_edges(self._graph.reverse(), self._target)) - self._subgraph = nx.freeze(self._graph.subgraph(nodes)) + self._subgraph = self._graph.subgraph(nodes) + self._subgraph.freeze() return self._subgraph diff --git a/taskflow/tests/unit/test_action_engine.py b/taskflow/tests/unit/test_action_engine.py index d2401f436..d711a1c2f 100644 --- a/taskflow/tests/unit/test_action_engine.py +++ b/taskflow/tests/unit/test_action_engine.py @@ -15,7 +15,6 @@ # under the License. 
import contextlib -import networkx import testtools import threading @@ -36,6 +35,7 @@ from taskflow import states from taskflow import task from taskflow import test from taskflow.tests import utils +from taskflow.types import graph as gr from taskflow.utils import eventlet_utils as eu from taskflow.utils import misc @@ -466,7 +466,7 @@ class EngineGraphFlowTest(utils.EngineTestBase): engine = self._make_engine(flow) engine.compile() graph = engine.execution_graph - self.assertIsInstance(graph, networkx.DiGraph) + self.assertIsInstance(graph, gr.DiGraph) def test_task_graph_property_for_one_task(self): flow = utils.TaskNoRequiresNoReturns(name='task1') @@ -474,7 +474,7 @@ class EngineGraphFlowTest(utils.EngineTestBase): engine = self._make_engine(flow) engine.compile() graph = engine.execution_graph - self.assertIsInstance(graph, networkx.DiGraph) + self.assertIsInstance(graph, gr.DiGraph) class EngineCheckingTaskTest(utils.EngineTestBase): diff --git a/taskflow/tests/unit/test_flattening.py b/taskflow/tests/unit/test_flattening.py index 9d56c1111..600a000a8 100644 --- a/taskflow/tests/unit/test_flattening.py +++ b/taskflow/tests/unit/test_flattening.py @@ -16,8 +16,6 @@ import string -import networkx as nx - from taskflow import exceptions as exc from taskflow.patterns import graph_flow as gf from taskflow.patterns import linear_flow as lf @@ -27,7 +25,6 @@ from taskflow import retry from taskflow import test from taskflow.tests import utils as t_utils from taskflow.utils import flow_utils as f_utils -from taskflow.utils import graph_utils as g_utils def _make_many(amount): @@ -66,13 +63,13 @@ class FlattenTest(test.TestCase): g = f_utils.flatten(flo) self.assertEqual(4, len(g)) - order = nx.topological_sort(g) + order = g.topological_sort() self.assertEqual([a, b, c, d], order) self.assertTrue(g.has_edge(c, d)) self.assertEqual(g.get_edge_data(c, d), {'invariant': True}) - self.assertEqual([d], list(g_utils.get_no_successors(g))) - self.assertEqual([a], 
list(g_utils.get_no_predecessors(g))) + self.assertEqual([d], list(g.no_successors_iter())) + self.assertEqual([a], list(g.no_predecessors_iter())) def test_invalid_flatten(self): a, b, c = _make_many(3) @@ -89,9 +86,9 @@ class FlattenTest(test.TestCase): self.assertEqual(4, len(g)) self.assertEqual(0, g.number_of_edges()) self.assertEqual(set([a, b, c, d]), - set(g_utils.get_no_successors(g))) + set(g.no_successors_iter())) self.assertEqual(set([a, b, c, d]), - set(g_utils.get_no_predecessors(g))) + set(g.no_predecessors_iter())) def test_linear_nested_flatten(self): a, b, c, d = _make_many(4) @@ -206,8 +203,8 @@ class FlattenTest(test.TestCase): (b, c, {'manual': True}), (c, d, {'manual': True}), ]) - self.assertItemsEqual([a], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([d], g_utils.get_no_successors(g)) + self.assertItemsEqual([a], g.no_predecessors_iter()) + self.assertItemsEqual([d], g.no_successors_iter()) def test_graph_flatten_dependencies(self): a = t_utils.ProvidesRequiresTask('a', provides=['x'], requires=[]) @@ -219,8 +216,8 @@ class FlattenTest(test.TestCase): self.assertItemsEqual(g.edges(data=True), [ (a, b, {'reasons': set(['x'])}) ]) - self.assertItemsEqual([a], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([b], g_utils.get_no_successors(g)) + self.assertItemsEqual([a], g.no_predecessors_iter()) + self.assertItemsEqual([b], g.no_successors_iter()) def test_graph_flatten_nested_requires(self): a = t_utils.ProvidesRequiresTask('a', provides=['x'], requires=[]) @@ -237,8 +234,8 @@ class FlattenTest(test.TestCase): (a, c, {'reasons': set(['x'])}), (b, c, {'invariant': True}) ]) - self.assertItemsEqual([a, b], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([c], g_utils.get_no_successors(g)) + self.assertItemsEqual([a, b], g.no_predecessors_iter()) + self.assertItemsEqual([c], g.no_successors_iter()) def test_graph_flatten_nested_provides(self): a = t_utils.ProvidesRequiresTask('a', provides=[], requires=['x']) @@ -255,8 
+252,8 @@ class FlattenTest(test.TestCase): (b, c, {'invariant': True}), (b, a, {'reasons': set(['x'])}) ]) - self.assertItemsEqual([b], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([a, c], g_utils.get_no_successors(g)) + self.assertItemsEqual([b], g.no_predecessors_iter()) + self.assertItemsEqual([a, c], g.no_successors_iter()) def test_flatten_checks_for_dups(self): flo = gf.Flow("test").add( @@ -304,8 +301,8 @@ class FlattenTest(test.TestCase): (c1, c2, {'retry': True}) ]) self.assertIs(c1, g.node[c2]['retry']) - self.assertItemsEqual([c1], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([c2], g_utils.get_no_successors(g)) + self.assertItemsEqual([c1], g.no_predecessors_iter()) + self.assertItemsEqual([c2], g.no_successors_iter()) def test_flatten_retry_in_linear_flow_with_tasks(self): c = retry.AlwaysRevert("c") @@ -318,8 +315,8 @@ class FlattenTest(test.TestCase): (c, a, {'retry': True}) ]) - self.assertItemsEqual([c], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([b], g_utils.get_no_successors(g)) + self.assertItemsEqual([c], g.no_predecessors_iter()) + self.assertItemsEqual([b], g.no_successors_iter()) self.assertIs(c, g.node[a]['retry']) self.assertIs(c, g.node[b]['retry']) @@ -334,8 +331,8 @@ class FlattenTest(test.TestCase): (c, b, {'retry': True}) ]) - self.assertItemsEqual([c], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([a, b], g_utils.get_no_successors(g)) + self.assertItemsEqual([c], g.no_predecessors_iter()) + self.assertItemsEqual([a, b], g.no_successors_iter()) self.assertIs(c, g.node[a]['retry']) self.assertIs(c, g.node[b]['retry']) @@ -352,8 +349,8 @@ class FlattenTest(test.TestCase): (b, c, {'manual': True}) ]) - self.assertItemsEqual([r], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([a, c], g_utils.get_no_successors(g)) + self.assertItemsEqual([r], g.no_predecessors_iter()) + self.assertItemsEqual([a, c], g.no_successors_iter()) self.assertIs(r, g.node[a]['retry']) self.assertIs(r, 
g.node[b]['retry']) self.assertIs(r, g.node[c]['retry']) diff --git a/taskflow/types/__init__.py b/taskflow/types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/taskflow/types/graph.py b/taskflow/types/graph.py new file mode 100644 index 000000000..f67591274 --- /dev/null +++ b/taskflow/types/graph.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- + +# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import networkx as nx +import six + + +class DiGraph(nx.DiGraph): + """A directed graph subclass with useful utility functions.""" + def __init__(self, data=None, name=''): + super(DiGraph, self).__init__(name=name, data=data) + self.frozen = False + + def freeze(self): + """Freezes the graph so that no more mutations can occur.""" + if not self.frozen: + nx.freeze(self) + return self + + def get_edge_data(self, u, v, default=None): + """Returns a *copy* of the attribute dictionary associated with edges + between (u, v). + + NOTE(harlowja): this differs from the networkx get_edge_data() as that + function does not return a copy (but returns a reference to the actual + edge data). 
+ """ + try: + return dict(self.adj[u][v]) + except KeyError: + return default + + def topological_sort(self): + """Return a list of nodes in this graph in topological sort order.""" + return nx.topological_sort(self) + + def pformat(self): + """Pretty formats your graph into a string representation that includes + details about your graph, including; name, type, frozeness, node count, + nodes, edge count, edges, graph density and graph cycles (if any). + """ + lines = [] + lines.append("Name: %s" % self.name) + lines.append("Type: %s" % type(self).__name__) + lines.append("Frozen: %s" % nx.is_frozen(self)) + lines.append("Nodes: %s" % self.number_of_nodes()) + for n in self.nodes_iter(): + lines.append(" - %s" % n) + lines.append("Edges: %s" % self.number_of_edges()) + for (u, v, e_data) in self.edges_iter(data=True): + if e_data: + lines.append(" %s -> %s (%s)" % (u, v, e_data)) + else: + lines.append(" %s -> %s" % (u, v)) + lines.append("Density: %0.3f" % nx.density(self)) + cycles = list(nx.cycles.recursive_simple_cycles(self)) + lines.append("Cycles: %s" % len(cycles)) + for cycle in cycles: + buf = six.StringIO() + buf.write("%s" % (cycle[0])) + for i in range(1, len(cycle)): + buf.write(" --> %s" % (cycle[i])) + buf.write(" --> %s" % (cycle[0])) + lines.append(" %s" % buf.getvalue()) + return "\n".join(lines) + + def export_to_dot(self): + """Exports the graph to a dot format (requires pydot library).""" + return nx.to_pydot(self).to_string() + + def is_directed_acyclic(self): + """Returns if this graph is a DAG or not.""" + return nx.is_directed_acyclic_graph(self) + + def no_successors_iter(self): + """Returns an iterator for all nodes with no successors.""" + for n in self.nodes_iter(): + if not len(self.successors(n)): + yield n + + def no_predecessors_iter(self): + """Returns an iterator for all nodes with no predecessors.""" + for n in self.nodes_iter(): + if not len(self.predecessors(n)): + yield n + + +def merge_graphs(graphs, allow_overlaps=False): 
+ """Merges a bunch of graphs into a single graph.""" + if not graphs: + return None + graph = graphs[0] + for g in graphs[1:]: + # This should ensure that the nodes to be merged do not already exist + # in the graph that is to be merged into. This could be problematic if + # there are duplicates. + if not allow_overlaps: + # Attempt to induce a subgraph using the to be merged graphs nodes + # and see if any graph results. + overlaps = graph.subgraph(g.nodes_iter()) + if len(overlaps): + raise ValueError("Can not merge graph %s into %s since there " + "are %s overlapping nodes" % (g, graph, + len(overlaps))) + # Keep the target graphs name. + name = graph.name + graph = nx.algorithms.compose(graph, g) + graph.name = name + return graph diff --git a/taskflow/utils/flow_utils.py b/taskflow/utils/flow_utils.py index ec365648e..6b54d5639 100644 --- a/taskflow/utils/flow_utils.py +++ b/taskflow/utils/flow_utils.py @@ -16,13 +16,11 @@ import logging -import networkx as nx - from taskflow import exceptions from taskflow import flow from taskflow import retry from taskflow import task -from taskflow.utils import graph_utils as gu +from taskflow.types import graph as gr from taskflow.utils import misc @@ -80,7 +78,7 @@ class Flattener(object): graph.add_node(retry) # All graph nodes that have no predecessors should depend on its retry - nodes_to = [n for n in gu.get_no_predecessors(graph) if n != retry] + nodes_to = [n for n in graph.no_predecessors_iter() if n != retry] self._add_new_edges(graph, [retry], nodes_to, RETRY_EDGE_DATA) # Add link to retry for each node of subgraph that hasn't @@ -91,34 +89,37 @@ class Flattener(object): def _flatten_task(self, task): """Flattens a individual task.""" - graph = nx.DiGraph(name=task.name) + graph = gr.DiGraph(name=task.name) graph.add_node(task) return graph def _flatten_flow(self, flow): """Flattens a graph flow.""" - graph = nx.DiGraph(name=flow.name) + graph = gr.DiGraph(name=flow.name) + # Flatten all nodes into a single
subgraph per node. subgraph_map = {} for item in flow: subgraph = self._flatten(item) subgraph_map[item] = subgraph - graph = gu.merge_graphs([graph, subgraph]) + graph = gr.merge_graphs([graph, subgraph]) # Reconnect all node edges to their corresponding subgraphs. for (u, v, attrs) in flow.iter_links(): + u_g = subgraph_map[u] + v_g = subgraph_map[v] if any(attrs.get(k) for k in ('invariant', 'manual', 'retry')): # Connect nodes with no predecessors in v to nodes with # no successors in u (thus maintaining the edge dependency). self._add_new_edges(graph, - gu.get_no_successors(subgraph_map[u]), - gu.get_no_predecessors(subgraph_map[v]), + u_g.no_successors_iter(), + v_g.no_predecessors_iter(), edge_attrs=attrs) else: # This is dependency-only edge, connect corresponding # providers and consumers. - for provider in subgraph_map[u]: - for consumer in subgraph_map[v]: + for provider in u_g: + for consumer in v_g: reasons = provider.provides & consumer.requires if reasons: graph.add_edge(provider, consumer, reasons=reasons) @@ -143,7 +144,7 @@ class Flattener(object): # and not under all cases. if LOG.isEnabledFor(logging.DEBUG): LOG.debug("Translated '%s' into a graph:", item) - for line in gu.pformat(graph).splitlines(): + for line in graph.pformat().splitlines(): # Indent it so that it's slightly offset from the above line. LOG.debug(" %s", line) @@ -168,10 +169,9 @@ class Flattener(object): self._pre_flatten() graph = self._flatten(self._root) self._post_flatten(graph) + self._graph = graph if self._freeze: - self._graph = nx.freeze(graph) - else: - self._graph = graph + self._graph.freeze() return self._graph diff --git a/taskflow/utils/graph_utils.py b/taskflow/utils/graph_utils.py deleted file mode 100644 index 7f18134cf..000000000 --- a/taskflow/utils/graph_utils.py +++ /dev/null @@ -1,98 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import networkx as nx -import six - - -def get_edge_attrs(graph, u, v): - """Gets the dictionary of edge attributes between u->v (or none).""" - if not graph.has_edge(u, v): - return None - return dict(graph.adj[u][v]) - - -def merge_graphs(graphs, allow_overlaps=False): - if not graphs: - return None - graph = graphs[0] - for g in graphs[1:]: - # This should ensure that the nodes to be merged do not already exist - # in the graph that is to be merged into. This could be problematic if - # there are duplicates. - if not allow_overlaps: - # Attempt to induce a subgraph using the to be merged graphs nodes - # and see if any graph results. - overlaps = graph.subgraph(g.nodes_iter()) - if len(overlaps): - raise ValueError("Can not merge graph %s into %s since there " - "are %s overlapping nodes" (g, graph, - len(overlaps))) - # Keep the target graphs name. 
- name = graph.name - graph = nx.algorithms.compose(graph, g) - graph.name = name - return graph - - -def get_no_successors(graph): - """Returns an iterator for all nodes with no successors.""" - for n in graph.nodes_iter(): - if not len(graph.successors(n)): - yield n - - -def get_no_predecessors(graph): - """Returns an iterator for all nodes with no predecessors.""" - for n in graph.nodes_iter(): - if not len(graph.predecessors(n)): - yield n - - -def pformat(graph): - """Pretty formats your graph into a string representation that includes - details about your graph, including; name, type, frozeness, node count, - nodes, edge count, edges, graph density and graph cycles (if any). - """ - lines = [] - lines.append("Name: %s" % graph.name) - lines.append("Type: %s" % type(graph).__name__) - lines.append("Frozen: %s" % nx.is_frozen(graph)) - lines.append("Nodes: %s" % graph.number_of_nodes()) - for n in graph.nodes_iter(): - lines.append(" - %s" % n) - lines.append("Edges: %s" % graph.number_of_edges()) - for (u, v, e_data) in graph.edges_iter(data=True): - if e_data: - lines.append(" %s -> %s (%s)" % (u, v, e_data)) - else: - lines.append(" %s -> %s" % (u, v)) - lines.append("Density: %0.3f" % nx.density(graph)) - cycles = list(nx.cycles.recursive_simple_cycles(graph)) - lines.append("Cycles: %s" % len(cycles)) - for cycle in cycles: - buf = six.StringIO() - buf.write(str(cycle[0])) - for i in range(1, len(cycle)): - buf.write(" --> %s" % (cycle[i])) - buf.write(" --> %s" % (cycle[0])) - lines.append(" %s" % buf.getvalue()) - return "\n".join(lines) - - -def export_graph_to_dot(graph): - """Exports the graph to a dot format (requires pydot library).""" - return nx.to_pydot(graph).to_string() diff --git a/tools/state_graph.py b/tools/state_graph.py index 4a2587c5e..f6a2057dd 100644 --- a/tools/state_graph.py +++ b/tools/state_graph.py @@ -11,10 +11,8 @@ import optparse import subprocess import tempfile -import networkx as nx - from taskflow import states -from 
taskflow.utils import graph_utils as gu +from taskflow.types import graph as gr def mini_exec(cmd, ok_codes=(0,)): @@ -31,7 +29,7 @@ def mini_exec(cmd, ok_codes=(0,)): def make_svg(graph, output_filename, output_format): # NOTE(harlowja): requires pydot! - gdot = gu.export_graph_to_dot(graph) + gdot = graph.export_to_dot() if output_format == 'dot': output = gdot elif output_format in ('svg', 'svgz', 'png'): @@ -62,7 +60,7 @@ def main(): if options.filename is None: options.filename = 'states.%s' % options.format - g = nx.DiGraph(name="State transitions") + g = gr.DiGraph(name="State transitions") if not options.tasks: source = states._ALLOWED_FLOW_TRANSITIONS else: