diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index a7a205f5..880bd9b2 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -19,11 +19,10 @@ import collections import networkx as nx -from networkx.algorithms import dag -from networkx.classes import digraph from taskflow import exceptions as exc from taskflow import flow +from taskflow.utils import graph_utils class Flow(flow.Flow): @@ -37,13 +36,13 @@ class Flow(flow.Flow): def __init__(self, name, uuid=None): super(Flow, self).__init__(name, uuid) - self._graph = nx.freeze(digraph.DiGraph()) + self._graph = nx.freeze(nx.DiGraph()) def _validate(self, graph=None): if graph is None: graph = self._graph # Ensure that there is a valid topological ordering. - if not dag.is_directed_acyclic_graph(graph): + if not nx.is_directed_acyclic_graph(graph): raise exc.DependencyFailure("No path through the items in the" " graph produces an ordering that" " will allow for correct dependency" @@ -54,15 +53,29 @@ class Flow(flow.Flow): raise ValueError('Item %s not found to link from' % (u)) if not self._graph.has_node(v): raise ValueError('Item %s not found to link to' % (v)) - if self._graph.has_edge(u, v): - return self + self._swap(self._link(u, v, manual=True)) + return self + def _link(self, u, v, graph=None, reason=None, manual=False): + mutable_graph = True + if graph is None: + graph = self._graph + mutable_graph = False # NOTE(harlowja): Add an edge to a temporary copy and only if that # copy is valid then do we swap with the underlying graph. - tmp_graph = digraph.DiGraph(self._graph) - tmp_graph.add_edge(u, v) - self._swap(tmp_graph) - return self + attrs = graph_utils.get_edge_attrs(graph, u, v) + if not attrs: + attrs = {} + if manual: + attrs['manual'] = True + if reason is not None: + if 'reasons' not in attrs: + attrs['reasons'] = set() + attrs['reasons'].add(reason) + if not mutable_graph: + graph = nx.DiGraph(graph) + graph.add_edge(u, v, **attrs) + return graph def _swap(self, replacement_graph): """Validates the replacement graph and then swaps the underlying graph @@ -93,7 +106,7 @@ class Flow(flow.Flow): # NOTE(harlowja): Add items and edges to a temporary copy of the # underlying graph and only if that is successful added to do we then # swap with the underlying graph. - tmp_graph = digraph.DiGraph(self._graph) + tmp_graph = nx.DiGraph(self._graph) for item in items: tmp_graph.add_node(item) update_requirements(item) @@ -110,12 +123,14 @@ class Flow(flow.Flow): for value in item.requires: if value in provided: - tmp_graph.add_edge(provided[value], item) + self._link(provided[value], item, + graph=tmp_graph, reason=value) for value in item.provides: if value in requirements: for node in requirements[value]: - tmp_graph.add_edge(item, node) + self._link(item, node, + graph=tmp_graph, reason=value) self._swap(tmp_graph) return self diff --git a/taskflow/tests/unit/test_graph_flow.py b/taskflow/tests/unit/test_graph_flow.py new file mode 100644 index 00000000..04d3fba2 --- /dev/null +++ b/taskflow/tests/unit/test_graph_flow.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- + +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import collections + +import taskflow.engines + +from taskflow import exceptions as exc +from taskflow.patterns import graph_flow as gw +from taskflow import states +from taskflow.utils import flow_utils as fu +from taskflow.utils import graph_utils as gu + +from taskflow import test +from taskflow.tests import utils + + +class GraphFlowTest(test.TestCase): + def _make_engine(self, flow): + return taskflow.engines.load(flow, store={'context': {}}) + + def _capture_states(self): + # TODO(harlowja): move function to shared helper + capture_where = collections.defaultdict(list) + + def do_capture(state, details): + task_uuid = details.get('task_uuid') + if not task_uuid: + return + capture_where[task_uuid].append(state) + + return (do_capture, capture_where) + + def test_ordering(self): + wf = gw.Flow("the-test-action") + test_1 = utils.ProvidesRequiresTask('test-1', + requires=[], + provides=set(['a', 'b'])) + test_2 = utils.ProvidesRequiresTask('test-2', + provides=['c'], + requires=['a', 'b']) + test_3 = utils.ProvidesRequiresTask('test-3', + provides=[], + requires=['c']) + wf.add(test_1, test_2, test_3) + self.assertTrue(wf.graph.has_edge(test_1, test_2)) + self.assertTrue(wf.graph.has_edge(test_2, test_3)) + self.assertEquals(3, len(wf.graph)) + self.assertEquals([test_1], list(gu.get_no_predecessors(wf.graph))) + self.assertEquals([test_3], list(gu.get_no_successors(wf.graph))) + + def test_invalid_add_simple(self): + wf = gw.Flow("the-test-action") + test_1 = utils.ProvidesRequiresTask('test-1', + requires=['a'], + provides=set(['a', 'b'])) + self.assertRaises(exc.DependencyFailure, wf.add, test_1) + self.assertEquals(0, len(wf)) + + def test_invalid_add_loop(self): + wf = gw.Flow("the-test-action") + test_1 = utils.ProvidesRequiresTask('test-1', + requires=['c'], + provides=set(['a', 'b'])) + test_2 = utils.ProvidesRequiresTask('test-2', + requires=['a', 'b'], + provides=set(['c'])) + wf.add(test_1) + self.assertRaises(exc.DependencyFailure, wf.add, test_2) + self.assertEquals(1, len(wf)) + + def test_basic_edge_reasons(self): + wf = gw.Flow("the-test-action") + test_1 = utils.ProvidesRequiresTask('test-1', + requires=[], + provides=set(['a', 'b'])) + test_2 = utils.ProvidesRequiresTask('test-2', + provides=['c'], + requires=['a', 'b']) + wf.add(test_1, test_2) + self.assertTrue(wf.graph.has_edge(test_1, test_2)) + + edge_attrs = gu.get_edge_attrs(wf.graph, test_1, test_2) + self.assertTrue(len(edge_attrs) > 0) + self.assertIn('reasons', edge_attrs) + self.assertEquals(set(['a', 'b']), edge_attrs['reasons']) + + # 2 -> 1 should not be linked, and therefore have no attrs + no_edge_attrs = gu.get_edge_attrs(wf.graph, test_2, test_1) + self.assertFalse(no_edge_attrs) + + def test_linked_edge_reasons(self): + wf = gw.Flow("the-test-action") + test_1 = utils.ProvidesRequiresTask('test-1', + requires=[], + provides=[]) + test_2 = utils.ProvidesRequiresTask('test-2', + provides=[], + requires=[]) + wf.add(test_1, test_2) + self.assertFalse(wf.graph.has_edge(test_1, test_2)) + wf.link(test_1, test_2) + self.assertTrue(wf.graph.has_edge(test_1, test_2)) + + edge_attrs = gu.get_edge_attrs(wf.graph, test_1, test_2) + self.assertTrue(len(edge_attrs) > 0) + self.assertTrue(edge_attrs.get('manual')) + + def test_flatten_attribute(self): + wf = gw.Flow("the-test-action") + test_1 = utils.ProvidesRequiresTask('test-1', + requires=[], + provides=[]) + test_2 = utils.ProvidesRequiresTask('test-2', + provides=[], + requires=[]) + wf.add(test_1, test_2) + wf.link(test_1, test_2) + g = fu.flatten(wf) + self.assertEquals(2, len(g)) + edge_attrs = gu.get_edge_attrs(g, test_1, test_2) + self.assertTrue(edge_attrs.get('manual')) + self.assertTrue(edge_attrs.get('flatten')) + + def test_graph_run(self): + wf = gw.Flow("the-test-action") + test_1 = utils.ProvidesRequiresTask('test-1', + requires=[], + provides=[]) + test_2 = utils.ProvidesRequiresTask('test-2', + provides=[], + requires=[]) + wf.add(test_1, test_2) + wf.link(test_1, test_2) + self.assertEquals(2, len(wf)) + + e = self._make_engine(wf) + capture_func, captured = self._capture_states() + e.task_notifier.register('*', capture_func) + e.run() + + self.assertEquals(2, len(captured)) + for (_uuid, t_states) in captured.items(): + self.assertEquals([states.RUNNING, states.SUCCESS], t_states) + + run_context = e.storage.fetch('context') + ordering = [o['name'] for o in run_context[utils.ORDER_KEY]] + self.assertEquals(['test-1', 'test-2'], ordering) diff --git a/taskflow/utils/flow_utils.py b/taskflow/utils/flow_utils.py index 2abdb5a6..688a790b 100644 --- a/taskflow/utils/flow_utils.py +++ b/taskflow/utils/flow_utils.py @@ -16,6 +16,8 @@ # License for the specific language governing permissions and limitations # under the License. +import copy + import networkx as nx from taskflow import exceptions @@ -27,11 +29,11 @@ from taskflow.utils import graph_utils as gu from taskflow.utils import misc -# Use the 'flatten' reason as the need to add an edge here, which is useful for -# doing later analysis of the edges (to determine why the edges were created). -FLATTEN_REASON = 'flatten' +# Use the 'flatten' attribute as the need to add an edge here, which is useful +# for doing later analysis of the edges (to determine why the edges were +# created). FLATTEN_EDGE_DATA = { - 'reason': FLATTEN_REASON, + 'flatten': True, } @@ -50,7 +52,9 @@ def _flatten_linear(flow, flattened): # the ones with no successors and use this list to connect the next # subgraph (if any). for n in gu.get_no_predecessors(subgraph): - graph.add_edges_from(((n2, n, FLATTEN_EDGE_DATA) + # NOTE(harlowja): give each edge its own copy so that if its later + # modified that the same copy isn't modified. + graph.add_edges_from(((n2, n, FLATTEN_EDGE_DATA.copy()) for n2 in previous_nodes if not graph.has_edge(n2, n))) # There should always be someone without successors, otherwise we have @@ -82,11 +86,19 @@ def _flatten_graph(flow, flattened): graph = gu.merge_graphs([graph, subgraph]) # Reconnect all nodes to there corresponding subgraphs for (u, v) in flow.graph.edges_iter(): + # Retain and update the original edge attributes. + u_v_attrs = gu.get_edge_attrs(flow.graph, u, v) + if not u_v_attrs: + u_v_attrs = FLATTEN_EDGE_DATA.copy() + else: + u_v_attrs.update(FLATTEN_EDGE_DATA) u_no_succ = list(gu.get_no_successors(subgraph_map[u])) # Connect the ones with no predecessors in v to the ones with no # successors in u (thus maintaining the edge dependency). for n in gu.get_no_predecessors(subgraph_map[v]): - graph.add_edges_from(((n2, n, FLATTEN_EDGE_DATA) + # NOTE(harlowja): give each edge its own copy so that if its later + # modified that the same copy isn't modified. + graph.add_edges_from(((n2, n, copy.deepcopy(u_v_attrs)) for n2 in u_no_succ if not graph.has_edge(n2, n))) return graph diff --git a/taskflow/utils/graph_utils.py b/taskflow/utils/graph_utils.py index 856e08a7..40c151f5 100644 --- a/taskflow/utils/graph_utils.py +++ b/taskflow/utils/graph_utils.py @@ -22,6 +22,13 @@ import networkx as nx from networkx import algorithms +def get_edge_attrs(graph, u, v): + """Gets the dictionary of edge attributes between u->v (or none).""" + if not graph.has_edge(u, v): + return None + return dict(graph.adj[u][v]) + + def merge_graphs(graphs, allow_overlaps=False): if not graphs: return None