From d736bdbfaef6a0b907213baf2d197ef47999d073 Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Sun, 15 Sep 2013 22:16:02 -0700 Subject: [PATCH] Add a flow flattening util Instead of recursively executing subflows which causes dead locks when they parent and subflows share the same executor we can instead flatten the parent and subflows into a single graph, composed with only tasks and run this instead, which will not have the issue of subflows dead locking, since after flattening there is no concept of a subflow. Fixes bug: 1225759 Change-Id: I79b9b194cd81e36ce75ba34a673e3e9d3e96c4cd --- taskflow/engines/action_engine/engine.py | 94 ++-------- .../engines/action_engine/parallel_action.py | 54 ------ taskflow/engines/action_engine/seq_action.py | 36 ---- taskflow/test.py | 13 ++ taskflow/tests/unit/test_action_engine.py | 85 ++++++--- taskflow/tests/unit/test_flattening.py | 169 ++++++++++++++++++ taskflow/utils/flow_utils.py | 116 ++++++++++++ taskflow/utils/graph_utils.py | 109 +++++------ 8 files changed, 428 insertions(+), 248 deletions(-) delete mode 100644 taskflow/engines/action_engine/parallel_action.py delete mode 100644 taskflow/engines/action_engine/seq_action.py create mode 100644 taskflow/tests/unit/test_flattening.py create mode 100644 taskflow/utils/flow_utils.py diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index d60c974c..bc0e544c 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -22,22 +22,16 @@ import threading from concurrent import futures from taskflow.engines.action_engine import graph_action -from taskflow.engines.action_engine import parallel_action -from taskflow.engines.action_engine import seq_action from taskflow.engines.action_engine import task_action -from taskflow.patterns import graph_flow as gf -from taskflow.patterns import linear_flow as lf -from taskflow.patterns import unordered_flow as uf - from taskflow.persistence 
import utils as p_utils from taskflow import decorators from taskflow import exceptions as exc from taskflow import states from taskflow import storage as t_storage -from taskflow import task +from taskflow.utils import flow_utils from taskflow.utils import misc @@ -105,59 +99,21 @@ class ActionEngine(object): result=result) self.task_notifier.notify(state, details) + def _translate_flow_to_action(self): + # Flatten the flow into just 1 graph. + task_graph = flow_utils.flatten(self._flow) + ga = graph_action.SequentialGraphAction(task_graph) + for n in task_graph.nodes_iter(): + ga.add(n, task_action.TaskAction(n, self)) + return ga + @decorators.locked def compile(self): if self._root is None: - translator = self.translator_cls(self) - self._root = translator.translate(self._flow) - - -class Translator(object): - - def __init__(self, engine): - self.engine = engine - - def _factory_map(self): - return [] - - def translate(self, pattern): - """Translates the pattern into an engine runnable action""" - if isinstance(pattern, task.BaseTask): - # Wrap the task into something more useful. 
- return task_action.TaskAction(pattern, self.engine) - - # Decompose the flow into something more useful: - for cls, factory in self._factory_map(): - if isinstance(pattern, cls): - return factory(pattern) - - raise TypeError('Unknown pattern type: %s (type %s)' - % (pattern, type(pattern))) - - -class SingleThreadedTranslator(Translator): - - def _factory_map(self): - return [(lf.Flow, self._translate_sequential), - (uf.Flow, self._translate_sequential), - (gf.Flow, self._translate_graph)] - - def _translate_sequential(self, pattern): - action = seq_action.SequentialAction() - for p in pattern: - action.add(self.translate(p)) - return action - - def _translate_graph(self, pattern): - action = graph_action.SequentialGraphAction(pattern.graph) - for p in pattern: - action.add(p, self.translate(p)) - return action + self._root = self._translate_flow_to_action() class SingleThreadedActionEngine(ActionEngine): - translator_cls = SingleThreadedTranslator - def __init__(self, flow, flow_detail=None, book=None, backend=None): if flow_detail is None: flow_detail = p_utils.create_flow_detail(flow, @@ -167,37 +123,7 @@ class SingleThreadedActionEngine(ActionEngine): storage=t_storage.Storage(flow_detail, backend)) -class MultiThreadedTranslator(Translator): - - def _factory_map(self): - return [(lf.Flow, self._translate_sequential), - # unordered can be run in parallel - (uf.Flow, self._translate_parallel), - (gf.Flow, self._translate_graph)] - - def _translate_sequential(self, pattern): - action = seq_action.SequentialAction() - for p in pattern: - action.add(self.translate(p)) - return action - - def _translate_parallel(self, pattern): - action = parallel_action.ParallelAction() - for p in pattern: - action.add(self.translate(p)) - return action - - def _translate_graph(self, pattern): - # TODO(akarpinska): replace with parallel graph later - action = graph_action.SequentialGraphAction(pattern.graph) - for p in pattern: - action.add(p, self.translate(p)) - return action - 
- class MultiThreadedActionEngine(ActionEngine): - translator_cls = MultiThreadedTranslator - def __init__(self, flow, flow_detail=None, book=None, backend=None, executor=None): if flow_detail is None: diff --git a/taskflow/engines/action_engine/parallel_action.py b/taskflow/engines/action_engine/parallel_action.py deleted file mode 100644 index 5c4762d8..00000000 --- a/taskflow/engines/action_engine/parallel_action.py +++ /dev/null @@ -1,54 +0,0 @@ -# -*- coding: utf-8 -*- - -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. 
- -from taskflow.engines.action_engine import base_action as base -from taskflow.utils import misc - - -class ParallelAction(base.Action): - - def __init__(self): - self._actions = [] - - def add(self, action): - self._actions.append(action) - - def _map(self, engine, fn): - executor = engine.executor - - def call_fn(action): - try: - fn(action) - except Exception: - return misc.Failure() - else: - return None - - failures = [] - result_iter = executor.map(call_fn, self._actions) - for result in result_iter: - if isinstance(result, misc.Failure): - failures.append(result) - if failures: - failures[0].reraise() - - def execute(self, engine): - self._map(engine, lambda action: action.execute(engine)) - - def revert(self, engine): - self._map(engine, lambda action: action.revert(engine)) diff --git a/taskflow/engines/action_engine/seq_action.py b/taskflow/engines/action_engine/seq_action.py deleted file mode 100644 index 782176b1..00000000 --- a/taskflow/engines/action_engine/seq_action.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- - -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. 
- -from taskflow.engines.action_engine import base_action as base - - -class SequentialAction(base.Action): - - def __init__(self): - self._actions = [] - - def add(self, action): - self._actions.append(action) - - def execute(self, engine): - for action in self._actions: - action.execute(engine) # raises on failure - - def revert(self, engine): - for action in reversed(self._actions): - action.revert(engine) diff --git a/taskflow/test.py b/taskflow/test.py index 2f1ec49b..b1ea8013 100644 --- a/taskflow/test.py +++ b/taskflow/test.py @@ -27,3 +27,16 @@ class TestCase(unittest2.TestCase): def tearDown(self): super(TestCase, self).tearDown() + + def assertIsSubset(self, super_set, sub_set, msg=None): + missing_set = set() + for e in sub_set: + if e not in super_set: + missing_set.add(e) + if len(missing_set): + if msg is not None: + self.fail(msg) + else: + self.fail("Subset %s has %s elements which are not in the " + "superset %s." % (sub_set, list(missing_set), + list(super_set))) diff --git a/taskflow/tests/unit/test_action_engine.py b/taskflow/tests/unit/test_action_engine.py index 9afa2c10..95dee871 100644 --- a/taskflow/tests/unit/test_action_engine.py +++ b/taskflow/tests/unit/test_action_engine.py @@ -594,12 +594,14 @@ class MultiThreadedEngineTest(EngineTaskTest, with self.assertRaisesRegexp(RuntimeError, '^Woot'): engine.run() result = set(self.values) - self.assertEquals(result, - set(['task1', 'task2', - 'task2 reverted(5)', 'task1 reverted(5)'])) + # NOTE(harlowja): task 1/2 may or may not have executed, even with the + # sleeps due to the fact that the above is an unordered flow. 
+ possible_result = set(['task1', 'task2', + 'task2 reverted(5)', 'task1 reverted(5)']) + self.assertIsSubset(possible_result, result) def test_parallel_revert_exception_is_reraised_(self): - flow = uf.Flow('p-r-reraise').add( + flow = lf.Flow('p-r-reraise').add( TestTask(self.values, name='task1', sleep=0.01), NastyTask(), FailingTask(sleep=0.01), @@ -609,13 +611,13 @@ class MultiThreadedEngineTest(EngineTaskTest, with self.assertRaisesRegexp(RuntimeError, '^Gotcha'): engine.run() result = set(self.values) - self.assertEquals(result, set(['task1', 'task1 reverted(5)'])) + self.assertEquals(result, set(['task1'])) def test_nested_parallel_revert_exception_is_reraised(self): flow = uf.Flow('p-root').add( TestTask(self.values, name='task1'), TestTask(self.values, name='task2'), - uf.Flow('p-inner').add( + lf.Flow('p-inner').add( TestTask(self.values, name='task3', sleep=0.1), NastyTask(), FailingTask(sleep=0.01) @@ -625,9 +627,13 @@ class MultiThreadedEngineTest(EngineTaskTest, with self.assertRaisesRegexp(RuntimeError, '^Gotcha'): engine.run() result = set(self.values) - self.assertEquals(result, set(['task1', 'task1 reverted(5)', - 'task2', 'task2 reverted(5)', - 'task3', 'task3 reverted(5)'])) + # Task1, task2 may *not* have executed and also may have *not* reverted + # since the above is an unordered flow so take that into account by + # ensuring that the superset is matched. 
+ possible_result = set(['task1', 'task1 reverted(5)', + 'task2', 'task2 reverted(5)', + 'task3', 'task3 reverted(5)']) + self.assertIsSubset(possible_result, result) def test_parallel_revert_exception_do_not_revert_linear_tasks(self): flow = lf.Flow('l-root').add( @@ -640,11 +646,35 @@ class MultiThreadedEngineTest(EngineTaskTest, ) ) engine = self._make_engine(flow) - with self.assertRaisesRegexp(RuntimeError, '^Gotcha'): + # Depending on when (and if failing task) is executed the exception + # raised could be either woot or gotcha since the above unordered + # sub-flow does not guarantee that the ordering will be maintained, + # even with sleeping. + was_nasty = False + try: engine.run() + self.assertTrue(False) + except RuntimeError as e: + self.assertRegexpMatches(str(e), '^Gotcha|^Woot') + if 'Gotcha!' in str(e): + was_nasty = True result = set(self.values) - self.assertEquals(result, set(['task1', 'task2', - 'task3', 'task3 reverted(5)'])) + possible_result = set(['task1', 'task2', + 'task3', 'task3 reverted(5)']) + if not was_nasty: + possible_result.update(['task1 reverted(5)', 'task2 reverted(5)']) + self.assertIsSubset(possible_result, result) + # If the nasty task killed reverting, then task1 and task2 should not + # have reverted, but if the failing task stopped execution then task1 + # and task2 should have reverted. 
+ if was_nasty: + must_not_have = ['task1 reverted(5)', 'task2 reverted(5)'] + for r in must_not_have: + self.assertNotIn(r, result) + else: + must_have = ['task1 reverted(5)', 'task2 reverted(5)'] + for r in must_have: + self.assertIn(r, result) def test_parallel_nested_to_linear_revert(self): flow = lf.Flow('l-root').add( @@ -659,9 +689,18 @@ class MultiThreadedEngineTest(EngineTaskTest, with self.assertRaisesRegexp(RuntimeError, '^Woot'): engine.run() result = set(self.values) - self.assertEquals(result, set(['task1', 'task1 reverted(5)', - 'task2', 'task2 reverted(5)', - 'task3', 'task3 reverted(5)'])) + # Task3 may or may not have executed, depending on scheduling and + # task ordering selection, so it may or may not exist in the result set + possible_result = set(['task1', 'task1 reverted(5)', + 'task2', 'task2 reverted(5)', + 'task3', 'task3 reverted(5)']) + self.assertIsSubset(possible_result, result) + # These must exist, since the linearity of the linear flow ensures + # that they were executed first. + must_have = ['task1', 'task1 reverted(5)', + 'task2', 'task2 reverted(5)'] + for r in must_have: + self.assertIn(r, result) def test_linear_nested_to_parallel_revert(self): flow = uf.Flow('p-root').add( @@ -676,11 +715,14 @@ class MultiThreadedEngineTest(EngineTaskTest, with self.assertRaisesRegexp(RuntimeError, '^Woot'): engine.run() result = set(self.values) - self.assertEquals(result, - set(['task1', 'task1 reverted(5)', + # Since this is an unordered flow we can not guarantee that task1 or + # task2 will exist and be reverted, although they may exist depending + # on how the OS thread scheduling and execution graph algorithm... 
+ possible_result = set(['task1', 'task1 reverted(5)', 'task2', 'task2 reverted(5)', 'task3', 'task3 reverted(5)', - 'fail reverted(Failure: RuntimeError: Woot!)'])) + 'fail reverted(Failure: RuntimeError: Woot!)']) + self.assertIsSubset(possible_result, result) def test_linear_nested_to_parallel_revert_exception(self): flow = uf.Flow('p-root').add( @@ -696,6 +738,7 @@ class MultiThreadedEngineTest(EngineTaskTest, with self.assertRaisesRegexp(RuntimeError, '^Gotcha'): engine.run() result = set(self.values) - self.assertEquals(result, set(['task1', 'task1 reverted(5)', - 'task2', 'task2 reverted(5)', - 'task3'])) + possible_result = set(['task1', 'task1 reverted(5)', + 'task2', 'task2 reverted(5)', + 'task3']) + self.assertIsSubset(possible_result, result) diff --git a/taskflow/tests/unit/test_flattening.py b/taskflow/tests/unit/test_flattening.py new file mode 100644 index 00000000..c7ed613b --- /dev/null +++ b/taskflow/tests/unit/test_flattening.py @@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- + +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +import string + +import networkx as nx + +from taskflow.patterns import graph_flow as gf +from taskflow.patterns import linear_flow as lf +from taskflow.patterns import unordered_flow as uf + +from taskflow import test +from taskflow.tests import utils as t_utils +from taskflow.utils import flow_utils as f_utils +from taskflow.utils import graph_utils as g_utils + + +def _make_many(amount): + assert amount <= len(string.ascii_lowercase), 'Not enough letters' + tasks = [] + for i in range(0, amount): + tasks.append(t_utils.DummyTask(name=string.ascii_lowercase[i])) + return tasks + + +class FlattenTest(test.TestCase): + def test_linear_flatten(self): + a, b, c, d = _make_many(4) + flo = lf.Flow("test") + flo.add(a, b, c) + sflo = lf.Flow("sub-test") + sflo.add(d) + flo.add(sflo) + + g = f_utils.flatten(flo) + self.assertEquals(4, len(g)) + + order = nx.topological_sort(g) + self.assertEquals([a, b, c, d], order) + self.assertTrue(g.has_edge(c, d)) + self.assertEquals([d], list(g_utils.get_no_successors(g))) + self.assertEquals([a], list(g_utils.get_no_predecessors(g))) + + def test_invalid_flatten(self): + a, b, c, d = _make_many(4) + flo = lf.Flow("test") + flo.add(a, b, c) + flo.add(flo) + self.assertRaises(ValueError, f_utils.flatten, flo) + + def test_unordered_flatten(self): + a, b, c, d = _make_many(4) + flo = uf.Flow("test") + flo.add(a, b, c, d) + g = f_utils.flatten(flo) + self.assertEquals(4, len(g)) + self.assertEquals(0, g.number_of_edges()) + self.assertEquals(set([a, b, c, d]), + set(g_utils.get_no_successors(g))) + self.assertEquals(set([a, b, c, d]), + set(g_utils.get_no_predecessors(g))) + + def test_linear_nested_flatten(self): + a, b, c, d = _make_many(4) + flo = lf.Flow("test") + flo.add(a, b) + flo2 = uf.Flow("test2") + flo2.add(c, d) + flo.add(flo2) + g = f_utils.flatten(flo) + self.assertEquals(4, len(g)) + + lb = g.subgraph([a, b]) + self.assertTrue(lb.has_edge(a, b)) + self.assertFalse(lb.has_edge(b, a)) + + ub = g.subgraph([c, d]) + 
self.assertEquals(0, ub.number_of_edges()) + + # This ensures that c and d do not start executing until after b. + self.assertTrue(g.has_edge(b, c)) + self.assertTrue(g.has_edge(b, d)) + + def test_unordered_nested_flatten(self): + a, b, c, d = _make_many(4) + flo = uf.Flow("test") + flo.add(a, b) + flo2 = lf.Flow("test2") + flo2.add(c, d) + flo.add(flo2) + + g = f_utils.flatten(flo) + self.assertEquals(4, len(g)) + for n in [a, b]: + self.assertFalse(g.has_edge(n, c)) + self.assertFalse(g.has_edge(n, d)) + self.assertTrue(g.has_edge(c, d)) + self.assertFalse(g.has_edge(d, c)) + + ub = g.subgraph([a, b]) + self.assertEquals(0, ub.number_of_edges()) + lb = g.subgraph([c, d]) + self.assertEquals(1, lb.number_of_edges()) + + def test_graph_flatten(self): + a, b, c, d = _make_many(4) + flo = gf.Flow("test") + flo.add(a, b, c, d) + + g = f_utils.flatten(flo) + self.assertEquals(4, len(g)) + self.assertEquals(0, g.number_of_edges()) + + def test_graph_flatten_nested(self): + a, b, c, d, e, f, g = _make_many(7) + flo = gf.Flow("test") + flo.add(a, b, c, d) + + flo2 = lf.Flow('test2') + flo2.add(e, f, g) + flo.add(flo2) + + g = f_utils.flatten(flo) + self.assertEquals(7, len(g)) + self.assertEquals(2, g.number_of_edges()) + + def test_graph_flatten_nested_graph(self): + a, b, c, d, e, f, g = _make_many(7) + flo = gf.Flow("test") + flo.add(a, b, c, d) + + flo2 = gf.Flow('test2') + flo2.add(e, f, g) + flo.add(flo2) + + g = f_utils.flatten(flo) + self.assertEquals(7, len(g)) + self.assertEquals(0, g.number_of_edges()) + + def test_graph_flatten_links(self): + a, b, c, d = _make_many(4) + flo = gf.Flow("test") + flo.add(a, b, c, d) + flo.link(a, b) + flo.link(b, c) + flo.link(c, d) + + g = f_utils.flatten(flo) + self.assertEquals(4, len(g)) + self.assertEquals(3, g.number_of_edges()) + self.assertEquals(set([a]), + set(g_utils.get_no_predecessors(g))) + self.assertEquals(set([d]), + set(g_utils.get_no_successors(g))) diff --git a/taskflow/utils/flow_utils.py 
b/taskflow/utils/flow_utils.py new file mode 100644 index 00000000..18bab12a --- /dev/null +++ b/taskflow/utils/flow_utils.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- + +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (C) 2013 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import networkx as nx + +from taskflow.patterns import graph_flow as gf +from taskflow.patterns import linear_flow as lf +from taskflow.patterns import unordered_flow as uf +from taskflow import task +from taskflow.utils import graph_utils as gu + + +# Use the 'flatten' reason as the need to add an edge here, which is useful for +# doing later analysis of the edges (to determine why the edges were created). +FLATTEN_REASON = 'flatten' +FLATTEN_EDGE_DATA = { + 'reason': FLATTEN_REASON, +} + + +def _graph_name(flow): + return "F:%s:%s" % (flow.name, flow.uuid) + + +def _flatten_linear(flow, flattened): + graph = nx.DiGraph(name=_graph_name(flow)) + previous_nodes = [] + for f in flow: + subgraph = _flatten(f, flattened) + graph = gu.merge_graphs([graph, subgraph]) + # Find nodes that have no predecessor, make them have a predecessor of + # the previous nodes so that the linearity ordering is maintained. Find + # the ones with no successors and use this list to connect the next + # subgraph (if any). 
+ for n in gu.get_no_predecessors(subgraph): + graph.add_edges_from(((n2, n, FLATTEN_EDGE_DATA) + for n2 in previous_nodes + if not graph.has_edge(n2, n))) + # There should always be someone without successors, otherwise we have + # a cycle A -> B -> A situation, which should not be possible. + previous_nodes = list(gu.get_no_successors(subgraph)) + return graph + + +def _flatten_unordered(flow, flattened): + graph = nx.DiGraph(name=_graph_name(flow)) + for f in flow: + graph = gu.merge_graphs([graph, _flatten(f, flattened)]) + return graph + + +def _flatten_task(task): + graph = nx.DiGraph(name='T:%s' % (task)) + graph.add_node(task) + return graph + + +def _flatten_graph(flow, flattened): + graph = nx.DiGraph(name=_graph_name(flow)) + subgraph_map = {} + # Flatten all nodes + for n in flow.graph.nodes_iter(): + subgraph = _flatten(n, flattened) + subgraph_map[n] = subgraph + graph = gu.merge_graphs([graph, subgraph]) + # Reconnect all nodes to their corresponding subgraphs + for (u, v) in flow.graph.edges_iter(): + u_no_succ = list(gu.get_no_successors(subgraph_map[u])) + # Connect the ones with no predecessors in v to the ones with no + # successors in u (thus maintaining the edge dependency).
+ for n in gu.get_no_predecessors(subgraph_map[v]): + graph.add_edges_from(((n2, n, FLATTEN_EDGE_DATA) + for n2 in u_no_succ + if not graph.has_edge(n2, n))) + return graph + + +def _flatten(item, flattened): + """Flattens a item (task/flow+subflows) into an execution graph.""" + if item in flattened: + raise ValueError("Already flattened item: %s" % (item)) + if isinstance(item, lf.Flow): + f = _flatten_linear(item, flattened) + elif isinstance(item, uf.Flow): + f = _flatten_unordered(item, flattened) + elif isinstance(item, gf.Flow): + f = _flatten_graph(item, flattened) + elif isinstance(item, task.BaseTask): + f = _flatten_task(item) + else: + raise TypeError("Unknown item: %r, %s" % (type(item), item)) + flattened.add(item) + return f + + +def flatten(item, freeze=True): + graph = _flatten(item, set()) + if freeze: + # Frozen graph can't be modified... + return nx.freeze(graph) + return graph diff --git a/taskflow/utils/graph_utils.py b/taskflow/utils/graph_utils.py index b56559bd..4390803a 100644 --- a/taskflow/utils/graph_utils.py +++ b/taskflow/utils/graph_utils.py @@ -16,65 +16,68 @@ # License for the specific language governing permissions and limitations # under the License. -import logging +import six -from taskflow import exceptions as exc +import networkx as nx +from networkx import algorithms -LOG = logging.getLogger(__name__) +def merge_graphs(graphs, allow_overlaps=False): + if not graphs: + return None + graph = graphs[0] + for g in graphs[1:]: + # This should ensure that the nodes to be merged do not already exist + # in the graph that is to be merged into. This could be problematic if + # there are duplicates. + if not allow_overlaps: + # Attempt to induce a subgraph using the to be merged graphs nodes + # and see if any graph results. 
+ overlaps = graph.subgraph(g.nodes_iter()) + if len(overlaps): + raise ValueError("Can not merge graph %s into %s since there " "are %s overlapping nodes" % (g, graph, + len(overlaps))) + # Keep the target graphs name. + name = graph.name + graph = algorithms.compose(graph, g) + graph.name = name + return graph -def connect(graph, infer_key='infer', auto_reason='auto', discard_func=None): - """Connects a graphs runners to other runners in the graph which provide - outputs for each runners requirements. - """ +def get_no_successors(graph): + """Returns an iterator for all nodes with no successors""" + for n in graph.nodes_iter(): + if not len(graph.successors(n)): + yield n - if len(graph) == 0: - return - if discard_func: - for (u, v, e_data) in graph.edges(data=True): - if discard_func(u, v, e_data): - graph.remove_edge(u, v) - for (r, r_data) in graph.nodes_iter(data=True): - requires = set(r.requires) - # Find the ones that have already been attached manually. - manual_providers = {} - if requires: - incoming = [e[0] for e in graph.in_edges_iter([r])] - for r2 in incoming: - fulfills = requires & r2.provides - if fulfills: - LOG.debug("%s is a manual provider of %s for %s", - r2, fulfills, r) - for k in fulfills: - manual_providers[k] = r2 - requires.remove(k) +def get_no_predecessors(graph): + """Returns an iterator for all nodes with no predecessors""" + for n in graph.nodes_iter(): + if not len(graph.predecessors(n)): + yield n - # Anything leftover that we must find providers for?? - auto_providers = {} - if requires and r_data.get(infer_key): - for r2 in graph.nodes_iter(): - if r is r2: - continue - fulfills = requires & r2.provides - if fulfills: - graph.add_edge(r2, r, reason=auto_reason) - LOG.debug("Connecting %s as a automatic provider for" - " %s for %s", r2, fulfills, r) - for k in fulfills: - auto_providers[k] = r2 - requires.remove(k) - if not requires: - break - # Anything still leftover??
- if requires: - # Ensure its in string format, since join will puke on - # things that are not strings. - missing = ", ".join(sorted([str(s) for s in requires])) - raise exc.MissingDependencies(r, missing) - else: - r.providers = {} - r.providers.update(auto_providers) - r.providers.update(manual_providers) +def pformat(graph): + lines = [] + lines.append("Name: %s" % graph.name) + lines.append("Type: %s" % type(graph).__name__) + lines.append("Frozen: %s" % nx.is_frozen(graph)) + lines.append("Nodes: %s" % graph.number_of_nodes()) + for n in graph.nodes_iter(): + lines.append(" - %s" % n) + lines.append("Edges: %s" % graph.number_of_edges()) + for (u, v, e_data) in graph.edges_iter(data=True): + reason = e_data.get('reason', '??') + lines.append(" %s -> %s (%s)" % (u, v, reason)) + cycles = list(nx.cycles.recursive_simple_cycles(graph)) + lines.append("Cycles: %s" % len(cycles)) + for cycle in cycles: + buf = six.StringIO() + buf.write(str(cycle[0])) + for i in range(1, len(cycle)): + buf.write(" --> %s" % (cycle[i])) + buf.write(" --> %s" % (cycle[0])) + lines.append(" %s" % buf.getvalue()) + return "\n".join(lines)