diff --git a/doc/source/utils.rst b/doc/source/utils.rst index b0bd38156..b847e07f8 100644 --- a/doc/source/utils.rst +++ b/doc/source/utils.rst @@ -17,16 +17,8 @@ The following classes and modules are *recommended* for external usage: .. autoclass:: taskflow.utils.eventlet_utils.GreenExecutor :members: -.. autofunction:: taskflow.utils.graph_utils.pformat - -.. autofunction:: taskflow.utils.graph_utils.export_graph_to_dot - .. autofunction:: taskflow.utils.persistence_utils.temporary_log_book .. autofunction:: taskflow.utils.persistence_utils.temporary_flow_detail .. autofunction:: taskflow.utils.persistence_utils.pformat - -.. autofunction:: taskflow.utils.persistence_utils.pformat_flow_detail - -.. autofunction:: taskflow.utils.persistence_utils.pformat_atom_detail diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index ccbcf413e..68691996d 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -16,12 +16,11 @@ import collections -import networkx as nx from networkx.algorithms import traversal from taskflow import exceptions as exc from taskflow import flow -from taskflow.utils import graph_utils +from taskflow.types import graph as gr class Flow(flow.Flow): @@ -39,7 +38,8 @@ class Flow(flow.Flow): def __init__(self, name, retry=None): super(Flow, self).__init__(name, retry) - self._graph = nx.freeze(nx.DiGraph()) + self._graph = gr.DiGraph() + self._graph.freeze() def link(self, u, v): """Link existing node u as a runtime dependency of existing node v.""" @@ -57,7 +57,7 @@ class Flow(flow.Flow): mutable_graph = False # NOTE(harlowja): Add an edge to a temporary copy and only if that # copy is valid then do we swap with the underlying graph. 
- attrs = graph_utils.get_edge_attrs(graph, u, v) + attrs = graph.get_edge_data(u, v) if not attrs: attrs = {} if manual: @@ -67,21 +67,22 @@ class Flow(flow.Flow): attrs['reasons'] = set() attrs['reasons'].add(reason) if not mutable_graph: - graph = nx.DiGraph(graph) + graph = gr.DiGraph(graph) graph.add_edge(u, v, **attrs) return graph - def _swap(self, replacement_graph): + def _swap(self, graph): """Validates the replacement graph and then swaps the underlying graph with a frozen version of the replacement graph (this maintains the invariant that the underlying graph is immutable). """ - if not nx.is_directed_acyclic_graph(replacement_graph): + if not graph.is_directed_acyclic(): raise exc.DependencyFailure("No path through the items in the" " graph produces an ordering that" " will allow for correct dependency" " resolution") - self._graph = nx.freeze(replacement_graph) + self._graph = graph + self._graph.freeze() def add(self, *items): """Adds a given task/tasks/flow/flows to this flow.""" @@ -109,7 +110,7 @@ class Flow(flow.Flow): # NOTE(harlowja): Add items and edges to a temporary copy of the # underlying graph and only if that is successful added to do we then # swap with the underlying graph. - tmp_graph = nx.DiGraph(self._graph) + tmp_graph = gr.DiGraph(self._graph) for item in items: tmp_graph.add_node(item) update_requirements(item) @@ -237,5 +238,6 @@ class TargetedFlow(Flow): nodes = [self._target] nodes.extend(dst for _src, dst in traversal.dfs_edges(self._graph.reverse(), self._target)) - self._subgraph = nx.freeze(self._graph.subgraph(nodes)) + self._subgraph = self._graph.subgraph(nodes) + self._subgraph.freeze() return self._subgraph diff --git a/taskflow/tests/unit/test_action_engine.py b/taskflow/tests/unit/test_action_engine.py index d2401f436..d711a1c2f 100644 --- a/taskflow/tests/unit/test_action_engine.py +++ b/taskflow/tests/unit/test_action_engine.py @@ -15,7 +15,6 @@ # under the License. 
import contextlib -import networkx import testtools import threading @@ -36,6 +35,7 @@ from taskflow import states from taskflow import task from taskflow import test from taskflow.tests import utils +from taskflow.types import graph as gr from taskflow.utils import eventlet_utils as eu from taskflow.utils import misc @@ -466,7 +466,7 @@ class EngineGraphFlowTest(utils.EngineTestBase): engine = self._make_engine(flow) engine.compile() graph = engine.execution_graph - self.assertIsInstance(graph, networkx.DiGraph) + self.assertIsInstance(graph, gr.DiGraph) def test_task_graph_property_for_one_task(self): flow = utils.TaskNoRequiresNoReturns(name='task1') @@ -474,7 +474,7 @@ class EngineGraphFlowTest(utils.EngineTestBase): engine = self._make_engine(flow) engine.compile() graph = engine.execution_graph - self.assertIsInstance(graph, networkx.DiGraph) + self.assertIsInstance(graph, gr.DiGraph) class EngineCheckingTaskTest(utils.EngineTestBase): diff --git a/taskflow/tests/unit/test_flattening.py b/taskflow/tests/unit/test_flattening.py index 9d56c1111..600a000a8 100644 --- a/taskflow/tests/unit/test_flattening.py +++ b/taskflow/tests/unit/test_flattening.py @@ -16,8 +16,6 @@ import string -import networkx as nx - from taskflow import exceptions as exc from taskflow.patterns import graph_flow as gf from taskflow.patterns import linear_flow as lf @@ -27,7 +25,6 @@ from taskflow import retry from taskflow import test from taskflow.tests import utils as t_utils from taskflow.utils import flow_utils as f_utils -from taskflow.utils import graph_utils as g_utils def _make_many(amount): @@ -66,13 +63,13 @@ class FlattenTest(test.TestCase): g = f_utils.flatten(flo) self.assertEqual(4, len(g)) - order = nx.topological_sort(g) + order = g.topological_sort() self.assertEqual([a, b, c, d], order) self.assertTrue(g.has_edge(c, d)) self.assertEqual(g.get_edge_data(c, d), {'invariant': True}) - self.assertEqual([d], list(g_utils.get_no_successors(g))) - self.assertEqual([a], 
list(g_utils.get_no_predecessors(g))) + self.assertEqual([d], list(g.no_successors_iter())) + self.assertEqual([a], list(g.no_predecessors_iter())) def test_invalid_flatten(self): a, b, c = _make_many(3) @@ -89,9 +86,9 @@ class FlattenTest(test.TestCase): self.assertEqual(4, len(g)) self.assertEqual(0, g.number_of_edges()) self.assertEqual(set([a, b, c, d]), - set(g_utils.get_no_successors(g))) + set(g.no_successors_iter())) self.assertEqual(set([a, b, c, d]), - set(g_utils.get_no_predecessors(g))) + set(g.no_predecessors_iter())) def test_linear_nested_flatten(self): a, b, c, d = _make_many(4) @@ -206,8 +203,8 @@ class FlattenTest(test.TestCase): (b, c, {'manual': True}), (c, d, {'manual': True}), ]) - self.assertItemsEqual([a], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([d], g_utils.get_no_successors(g)) + self.assertItemsEqual([a], g.no_predecessors_iter()) + self.assertItemsEqual([d], g.no_successors_iter()) def test_graph_flatten_dependencies(self): a = t_utils.ProvidesRequiresTask('a', provides=['x'], requires=[]) @@ -219,8 +216,8 @@ class FlattenTest(test.TestCase): self.assertItemsEqual(g.edges(data=True), [ (a, b, {'reasons': set(['x'])}) ]) - self.assertItemsEqual([a], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([b], g_utils.get_no_successors(g)) + self.assertItemsEqual([a], g.no_predecessors_iter()) + self.assertItemsEqual([b], g.no_successors_iter()) def test_graph_flatten_nested_requires(self): a = t_utils.ProvidesRequiresTask('a', provides=['x'], requires=[]) @@ -237,8 +234,8 @@ class FlattenTest(test.TestCase): (a, c, {'reasons': set(['x'])}), (b, c, {'invariant': True}) ]) - self.assertItemsEqual([a, b], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([c], g_utils.get_no_successors(g)) + self.assertItemsEqual([a, b], g.no_predecessors_iter()) + self.assertItemsEqual([c], g.no_successors_iter()) def test_graph_flatten_nested_provides(self): a = t_utils.ProvidesRequiresTask('a', provides=[], requires=['x']) @@ -255,8 
+252,8 @@ class FlattenTest(test.TestCase): (b, c, {'invariant': True}), (b, a, {'reasons': set(['x'])}) ]) - self.assertItemsEqual([b], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([a, c], g_utils.get_no_successors(g)) + self.assertItemsEqual([b], g.no_predecessors_iter()) + self.assertItemsEqual([a, c], g.no_successors_iter()) def test_flatten_checks_for_dups(self): flo = gf.Flow("test").add( @@ -304,8 +301,8 @@ class FlattenTest(test.TestCase): (c1, c2, {'retry': True}) ]) self.assertIs(c1, g.node[c2]['retry']) - self.assertItemsEqual([c1], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([c2], g_utils.get_no_successors(g)) + self.assertItemsEqual([c1], g.no_predecessors_iter()) + self.assertItemsEqual([c2], g.no_successors_iter()) def test_flatten_retry_in_linear_flow_with_tasks(self): c = retry.AlwaysRevert("c") @@ -318,8 +315,8 @@ class FlattenTest(test.TestCase): (c, a, {'retry': True}) ]) - self.assertItemsEqual([c], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([b], g_utils.get_no_successors(g)) + self.assertItemsEqual([c], g.no_predecessors_iter()) + self.assertItemsEqual([b], g.no_successors_iter()) self.assertIs(c, g.node[a]['retry']) self.assertIs(c, g.node[b]['retry']) @@ -334,8 +331,8 @@ class FlattenTest(test.TestCase): (c, b, {'retry': True}) ]) - self.assertItemsEqual([c], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([a, b], g_utils.get_no_successors(g)) + self.assertItemsEqual([c], g.no_predecessors_iter()) + self.assertItemsEqual([a, b], g.no_successors_iter()) self.assertIs(c, g.node[a]['retry']) self.assertIs(c, g.node[b]['retry']) @@ -352,8 +349,8 @@ class FlattenTest(test.TestCase): (b, c, {'manual': True}) ]) - self.assertItemsEqual([r], g_utils.get_no_predecessors(g)) - self.assertItemsEqual([a, c], g_utils.get_no_successors(g)) + self.assertItemsEqual([r], g.no_predecessors_iter()) + self.assertItemsEqual([a, c], g.no_successors_iter()) self.assertIs(r, g.node[a]['retry']) self.assertIs(r, 
g.node[b]['retry']) self.assertIs(r, g.node[c]['retry']) diff --git a/taskflow/types/__init__.py b/taskflow/types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/taskflow/types/graph.py b/taskflow/types/graph.py new file mode 100644 index 000000000..f67591274 --- /dev/null +++ b/taskflow/types/graph.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- + +# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import networkx as nx +import six + + +class DiGraph(nx.DiGraph): + """A directed graph subclass with useful utility functions.""" + def __init__(self, data=None, name=''): + super(DiGraph, self).__init__(name=name, data=data) + self.frozen = False + + def freeze(self): + """Freezes the graph so that no more mutations can occur.""" + if not self.frozen: + nx.freeze(self) + return self + + def get_edge_data(self, u, v, default=None): + """Returns a *copy* of the attribute dictionary associated with edges + between (u, v). + + NOTE(harlowja): this differs from the networkx get_edge_data() as that + function does not return a copy (but returns a reference to the actual + edge data). 
+ """ + try: + return dict(self.adj[u][v]) + except KeyError: + return default + + def topological_sort(self): + """Return a list of nodes in this graph in topological sort order.""" + return nx.topological_sort(self) + + def pformat(self): + """Pretty formats your graph into a string representation that includes + details about your graph, including; name, type, frozeness, node count, + nodes, edge count, edges, graph density and graph cycles (if any). + """ + lines = [] + lines.append("Name: %s" % self.name) + lines.append("Type: %s" % type(self).__name__) + lines.append("Frozen: %s" % nx.is_frozen(self)) + lines.append("Nodes: %s" % self.number_of_nodes()) + for n in self.nodes_iter(): + lines.append(" - %s" % n) + lines.append("Edges: %s" % self.number_of_edges()) + for (u, v, e_data) in self.edges_iter(data=True): + if e_data: + lines.append(" %s -> %s (%s)" % (u, v, e_data)) + else: + lines.append(" %s -> %s" % (u, v)) + lines.append("Density: %0.3f" % nx.density(self)) + cycles = list(nx.cycles.recursive_simple_cycles(self)) + lines.append("Cycles: %s" % len(cycles)) + for cycle in cycles: + buf = six.StringIO() + buf.write("%s" % (cycle[0])) + for i in range(1, len(cycle)): + buf.write(" --> %s" % (cycle[i])) + buf.write(" --> %s" % (cycle[0])) + lines.append(" %s" % buf.getvalue()) + return "\n".join(lines) + + def export_to_dot(self): + """Exports the graph to a dot format (requires pydot library).""" + return nx.to_pydot(self).to_string() + + def is_directed_acyclic(self): + """Returns if this graph is a DAG or not.""" + return nx.is_directed_acyclic_graph(self) + + def no_successors_iter(self): + """Returns an iterator for all nodes with no successors.""" + for n in self.nodes_iter(): + if not len(self.successors(n)): + yield n + + def no_predecessors_iter(self): + """Returns an iterator for all nodes with no predecessors.""" + for n in self.nodes_iter(): + if not len(self.predecessors(n)): + yield n + + +def merge_graphs(graphs, allow_overlaps=False): 
+ """Merges a bunch of graphs into a single graph.""" + if not graphs: + return None + graph = graphs[0] + for g in graphs[1:]: + # This should ensure that the nodes to be merged do not already exist + # in the graph that is to be merged into. This could be problematic if + # there are duplicates. + if not allow_overlaps: + # Attempt to induce a subgraph using the to be merged graphs nodes + # and see if any graph results. + overlaps = graph.subgraph(g.nodes_iter()) + if len(overlaps): + raise ValueError("Can not merge graph %s into %s since there " + "are %s overlapping nodes" % (g, graph, + len(overlaps))) + # Keep the target graphs name. + name = graph.name + graph = nx.algorithms.compose(graph, g) + graph.name = name + return graph diff --git a/taskflow/utils/flow_utils.py b/taskflow/utils/flow_utils.py index ec365648e..6b54d5639 100644 --- a/taskflow/utils/flow_utils.py +++ b/taskflow/utils/flow_utils.py @@ -16,13 +16,11 @@ import logging -import networkx as nx - from taskflow import exceptions from taskflow import flow from taskflow import retry from taskflow import task -from taskflow.utils import graph_utils as gu +from taskflow.types import graph as gr from taskflow.utils import misc @@ -80,7 +78,7 @@ class Flattener(object): graph.add_node(retry) # All graph nodes that have no predecessors should depend on its retry - nodes_to = [n for n in gu.get_no_predecessors(graph) if n != retry] + nodes_to = [n for n in graph.no_predecessors_iter() if n != retry] self._add_new_edges(graph, [retry], nodes_to, RETRY_EDGE_DATA) # Add link to retry for each node of subgraph that hasn't @@ -91,34 +89,37 @@ class Flattener(object): def _flatten_task(self, task): """Flattens a individual task.""" - graph = nx.DiGraph(name=task.name) + graph = gr.DiGraph(name=task.name) graph.add_node(task) return graph def _flatten_flow(self, flow): """Flattens a graph flow.""" - graph = nx.DiGraph(name=flow.name) + graph = gr.DiGraph(name=flow.name) + # Flatten all nodes into a single 
subgraph per node. subgraph_map = {} for item in flow: subgraph = self._flatten(item) subgraph_map[item] = subgraph - graph = gu.merge_graphs([graph, subgraph]) + graph = gr.merge_graphs([graph, subgraph]) # Reconnect all node edges to their corresponding subgraphs. for (u, v, attrs) in flow.iter_links(): + u_g = subgraph_map[u] + v_g = subgraph_map[v] if any(attrs.get(k) for k in ('invariant', 'manual', 'retry')): # Connect nodes with no predecessors in v to nodes with # no successors in u (thus maintaining the edge dependency). self._add_new_edges(graph, - gu.get_no_successors(subgraph_map[u]), - gu.get_no_predecessors(subgraph_map[v]), + u_g.no_successors_iter(), + v_g.no_predecessors_iter(), edge_attrs=attrs) else: # This is dependency-only edge, connect corresponding # providers and consumers. - for provider in subgraph_map[u]: - for consumer in subgraph_map[v]: + for provider in u_g: + for consumer in v_g: reasons = provider.provides & consumer.requires if reasons: graph.add_edge(provider, consumer, reasons=reasons) @@ -143,7 +144,7 @@ class Flattener(object): # and not under all cases. if LOG.isEnabledFor(logging.DEBUG): LOG.debug("Translated '%s' into a graph:", item) - for line in gu.pformat(graph).splitlines(): + for line in graph.pformat().splitlines(): # Indent it so that it's slightly offset from the above line. LOG.debug(" %s", line) @@ -168,10 +169,9 @@ class Flattener(object): self._pre_flatten() graph = self._flatten(self._root) self._post_flatten(graph) + self._graph = graph if self._freeze: - self._graph = nx.freeze(graph) - else: - self._graph = graph + self._graph.freeze() return self._graph diff --git a/taskflow/utils/graph_utils.py b/taskflow/utils/graph_utils.py deleted file mode 100644 index 7f18134cf..000000000 --- a/taskflow/utils/graph_utils.py +++ /dev/null @@ -1,98 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import networkx as nx -import six - - -def get_edge_attrs(graph, u, v): - """Gets the dictionary of edge attributes between u->v (or none).""" - if not graph.has_edge(u, v): - return None - return dict(graph.adj[u][v]) - - -def merge_graphs(graphs, allow_overlaps=False): - if not graphs: - return None - graph = graphs[0] - for g in graphs[1:]: - # This should ensure that the nodes to be merged do not already exist - # in the graph that is to be merged into. This could be problematic if - # there are duplicates. - if not allow_overlaps: - # Attempt to induce a subgraph using the to be merged graphs nodes - # and see if any graph results. - overlaps = graph.subgraph(g.nodes_iter()) - if len(overlaps): - raise ValueError("Can not merge graph %s into %s since there " - "are %s overlapping nodes" (g, graph, - len(overlaps))) - # Keep the target graphs name. 
- name = graph.name - graph = nx.algorithms.compose(graph, g) - graph.name = name - return graph - - -def get_no_successors(graph): - """Returns an iterator for all nodes with no successors.""" - for n in graph.nodes_iter(): - if not len(graph.successors(n)): - yield n - - -def get_no_predecessors(graph): - """Returns an iterator for all nodes with no predecessors.""" - for n in graph.nodes_iter(): - if not len(graph.predecessors(n)): - yield n - - -def pformat(graph): - """Pretty formats your graph into a string representation that includes - details about your graph, including; name, type, frozeness, node count, - nodes, edge count, edges, graph density and graph cycles (if any). - """ - lines = [] - lines.append("Name: %s" % graph.name) - lines.append("Type: %s" % type(graph).__name__) - lines.append("Frozen: %s" % nx.is_frozen(graph)) - lines.append("Nodes: %s" % graph.number_of_nodes()) - for n in graph.nodes_iter(): - lines.append(" - %s" % n) - lines.append("Edges: %s" % graph.number_of_edges()) - for (u, v, e_data) in graph.edges_iter(data=True): - if e_data: - lines.append(" %s -> %s (%s)" % (u, v, e_data)) - else: - lines.append(" %s -> %s" % (u, v)) - lines.append("Density: %0.3f" % nx.density(graph)) - cycles = list(nx.cycles.recursive_simple_cycles(graph)) - lines.append("Cycles: %s" % len(cycles)) - for cycle in cycles: - buf = six.StringIO() - buf.write(str(cycle[0])) - for i in range(1, len(cycle)): - buf.write(" --> %s" % (cycle[i])) - buf.write(" --> %s" % (cycle[0])) - lines.append(" %s" % buf.getvalue()) - return "\n".join(lines) - - -def export_graph_to_dot(graph): - """Exports the graph to a dot format (requires pydot library).""" - return nx.to_pydot(graph).to_string() diff --git a/tools/state_graph.py b/tools/state_graph.py index 4a2587c5e..f6a2057dd 100644 --- a/tools/state_graph.py +++ b/tools/state_graph.py @@ -11,10 +11,8 @@ import optparse import subprocess import tempfile -import networkx as nx - from taskflow import states -from 
taskflow.utils import graph_utils as gu +from taskflow.types import graph as gr def mini_exec(cmd, ok_codes=(0,)): @@ -31,7 +29,7 @@ def mini_exec(cmd, ok_codes=(0,)): def make_svg(graph, output_filename, output_format): # NOTE(harlowja): requires pydot! - gdot = gu.export_graph_to_dot(graph) + gdot = graph.export_to_dot() if output_format == 'dot': output = gdot elif output_format in ('svg', 'svgz', 'png'): @@ -62,7 +60,7 @@ def main(): if options.filename is None: options.filename = 'states.%s' % options.format - g = nx.DiGraph(name="State transitions") + g = gr.DiGraph(name="State transitions") if not options.tasks: source = states._ALLOWED_FLOW_TRANSITIONS else: