Add a flow flattening util
Instead of recursively executing subflows, which causes deadlocks when the parent and subflows share the same executor, we can flatten the parent and subflows into a single graph composed only of tasks and run that instead. This avoids the subflow deadlock issue, since after flattening there is no concept of a subflow. Fixes bug: 1225759 Change-Id: I79b9b194cd81e36ce75ba34a673e3e9d3e96c4cd
This commit is contained in:

committed by
Ivan A. Melnikov

parent
0417ebf956
commit
d736bdbfae
@@ -22,22 +22,16 @@ import threading
|
||||
from concurrent import futures
|
||||
|
||||
from taskflow.engines.action_engine import graph_action
|
||||
from taskflow.engines.action_engine import parallel_action
|
||||
from taskflow.engines.action_engine import seq_action
|
||||
from taskflow.engines.action_engine import task_action
|
||||
|
||||
from taskflow.patterns import graph_flow as gf
|
||||
from taskflow.patterns import linear_flow as lf
|
||||
from taskflow.patterns import unordered_flow as uf
|
||||
|
||||
from taskflow.persistence import utils as p_utils
|
||||
|
||||
from taskflow import decorators
|
||||
from taskflow import exceptions as exc
|
||||
from taskflow import states
|
||||
from taskflow import storage as t_storage
|
||||
from taskflow import task
|
||||
|
||||
from taskflow.utils import flow_utils
|
||||
from taskflow.utils import misc
|
||||
|
||||
|
||||
@@ -105,59 +99,21 @@ class ActionEngine(object):
|
||||
result=result)
|
||||
self.task_notifier.notify(state, details)
|
||||
|
||||
def _translate_flow_to_action(self):
    """Build the root action used to run this engine's flow.

    The flow (tasks plus any subflows) is flattened into a single task
    graph, a sequential graph action is created over that graph, and each
    node of the graph is registered with a task action bound to this
    engine.

    :returns: the graph action wrapping every task of the flattened flow
    """
    flattened = flow_utils.flatten(self._flow)
    root_action = graph_action.SequentialGraphAction(flattened)
    for node in flattened.nodes_iter():
        wrapped = task_action.TaskAction(node, self)
        root_action.add(node, wrapped)
    return root_action
|
||||
|
||||
@decorators.locked
|
||||
def compile(self):
|
||||
if self._root is None:
|
||||
translator = self.translator_cls(self)
|
||||
self._root = translator.translate(self._flow)
|
||||
|
||||
|
||||
class Translator(object):
|
||||
|
||||
def __init__(self, engine):
|
||||
self.engine = engine
|
||||
|
||||
def _factory_map(self):
|
||||
return []
|
||||
|
||||
def translate(self, pattern):
|
||||
"""Translates the pattern into an engine runnable action"""
|
||||
if isinstance(pattern, task.BaseTask):
|
||||
# Wrap the task into something more useful.
|
||||
return task_action.TaskAction(pattern, self.engine)
|
||||
|
||||
# Decompose the flow into something more useful:
|
||||
for cls, factory in self._factory_map():
|
||||
if isinstance(pattern, cls):
|
||||
return factory(pattern)
|
||||
|
||||
raise TypeError('Unknown pattern type: %s (type %s)'
|
||||
% (pattern, type(pattern)))
|
||||
|
||||
|
||||
class SingleThreadedTranslator(Translator):
|
||||
|
||||
def _factory_map(self):
|
||||
return [(lf.Flow, self._translate_sequential),
|
||||
(uf.Flow, self._translate_sequential),
|
||||
(gf.Flow, self._translate_graph)]
|
||||
|
||||
def _translate_sequential(self, pattern):
|
||||
action = seq_action.SequentialAction()
|
||||
for p in pattern:
|
||||
action.add(self.translate(p))
|
||||
return action
|
||||
|
||||
def _translate_graph(self, pattern):
|
||||
action = graph_action.SequentialGraphAction(pattern.graph)
|
||||
for p in pattern:
|
||||
action.add(p, self.translate(p))
|
||||
return action
|
||||
self._root = self._translate_flow_to_action()
|
||||
|
||||
|
||||
class SingleThreadedActionEngine(ActionEngine):
|
||||
translator_cls = SingleThreadedTranslator
|
||||
|
||||
def __init__(self, flow, flow_detail=None, book=None, backend=None):
|
||||
if flow_detail is None:
|
||||
flow_detail = p_utils.create_flow_detail(flow,
|
||||
@@ -167,37 +123,7 @@ class SingleThreadedActionEngine(ActionEngine):
|
||||
storage=t_storage.Storage(flow_detail, backend))
|
||||
|
||||
|
||||
class MultiThreadedTranslator(Translator):
|
||||
|
||||
def _factory_map(self):
|
||||
return [(lf.Flow, self._translate_sequential),
|
||||
# unordered can be run in parallel
|
||||
(uf.Flow, self._translate_parallel),
|
||||
(gf.Flow, self._translate_graph)]
|
||||
|
||||
def _translate_sequential(self, pattern):
|
||||
action = seq_action.SequentialAction()
|
||||
for p in pattern:
|
||||
action.add(self.translate(p))
|
||||
return action
|
||||
|
||||
def _translate_parallel(self, pattern):
|
||||
action = parallel_action.ParallelAction()
|
||||
for p in pattern:
|
||||
action.add(self.translate(p))
|
||||
return action
|
||||
|
||||
def _translate_graph(self, pattern):
|
||||
# TODO(akarpinska): replace with parallel graph later
|
||||
action = graph_action.SequentialGraphAction(pattern.graph)
|
||||
for p in pattern:
|
||||
action.add(p, self.translate(p))
|
||||
return action
|
||||
|
||||
|
||||
class MultiThreadedActionEngine(ActionEngine):
|
||||
translator_cls = MultiThreadedTranslator
|
||||
|
||||
def __init__(self, flow, flow_detail=None, book=None, backend=None,
|
||||
executor=None):
|
||||
if flow_detail is None:
|
||||
|
@@ -1,54 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from taskflow.engines.action_engine import base_action as base
|
||||
from taskflow.utils import misc
|
||||
|
||||
|
||||
class ParallelAction(base.Action):
|
||||
|
||||
def __init__(self):
|
||||
self._actions = []
|
||||
|
||||
def add(self, action):
|
||||
self._actions.append(action)
|
||||
|
||||
def _map(self, engine, fn):
|
||||
executor = engine.executor
|
||||
|
||||
def call_fn(action):
|
||||
try:
|
||||
fn(action)
|
||||
except Exception:
|
||||
return misc.Failure()
|
||||
else:
|
||||
return None
|
||||
|
||||
failures = []
|
||||
result_iter = executor.map(call_fn, self._actions)
|
||||
for result in result_iter:
|
||||
if isinstance(result, misc.Failure):
|
||||
failures.append(result)
|
||||
if failures:
|
||||
failures[0].reraise()
|
||||
|
||||
def execute(self, engine):
|
||||
self._map(engine, lambda action: action.execute(engine))
|
||||
|
||||
def revert(self, engine):
|
||||
self._map(engine, lambda action: action.revert(engine))
|
@@ -1,36 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from taskflow.engines.action_engine import base_action as base
|
||||
|
||||
|
||||
class SequentialAction(base.Action):
|
||||
|
||||
def __init__(self):
|
||||
self._actions = []
|
||||
|
||||
def add(self, action):
|
||||
self._actions.append(action)
|
||||
|
||||
def execute(self, engine):
|
||||
for action in self._actions:
|
||||
action.execute(engine) # raises on failure
|
||||
|
||||
def revert(self, engine):
|
||||
for action in reversed(self._actions):
|
||||
action.revert(engine)
|
@@ -27,3 +27,16 @@ class TestCase(unittest2.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
super(TestCase, self).tearDown()
|
||||
|
||||
def assertIsSubset(self, super_set, sub_set, msg=None):
    """Fail unless every element of ``sub_set`` is also in ``super_set``.

    :param super_set: container that is expected to hold all the elements
    :param sub_set: iterable of elements that must be in ``super_set``
    :param msg: optional failure message used instead of the default one
    """
    # Collect everything that is in the subset but not in the superset.
    absent = set(item for item in sub_set if item not in super_set)
    if not absent:
        return
    if msg is None:
        msg = ("Subset %s has %s elements which are not in the "
               "superset %s." % (sub_set, list(absent),
                                 list(super_set)))
    self.fail(msg)
|
||||
|
@@ -594,12 +594,14 @@ class MultiThreadedEngineTest(EngineTaskTest,
|
||||
with self.assertRaisesRegexp(RuntimeError, '^Woot'):
|
||||
engine.run()
|
||||
result = set(self.values)
|
||||
self.assertEquals(result,
|
||||
set(['task1', 'task2',
|
||||
'task2 reverted(5)', 'task1 reverted(5)']))
|
||||
# NOTE(harlowja): task 1/2 may or may not have executed, even with the
|
||||
# sleeps due to the fact that the above is an unordered flow.
|
||||
possible_result = set(['task1', 'task2',
|
||||
'task2 reverted(5)', 'task1 reverted(5)'])
|
||||
self.assertIsSubset(possible_result, result)
|
||||
|
||||
def test_parallel_revert_exception_is_reraised_(self):
|
||||
flow = uf.Flow('p-r-reraise').add(
|
||||
flow = lf.Flow('p-r-reraise').add(
|
||||
TestTask(self.values, name='task1', sleep=0.01),
|
||||
NastyTask(),
|
||||
FailingTask(sleep=0.01),
|
||||
@@ -609,13 +611,13 @@ class MultiThreadedEngineTest(EngineTaskTest,
|
||||
with self.assertRaisesRegexp(RuntimeError, '^Gotcha'):
|
||||
engine.run()
|
||||
result = set(self.values)
|
||||
self.assertEquals(result, set(['task1', 'task1 reverted(5)']))
|
||||
self.assertEquals(result, set(['task1']))
|
||||
|
||||
def test_nested_parallel_revert_exception_is_reraised(self):
|
||||
flow = uf.Flow('p-root').add(
|
||||
TestTask(self.values, name='task1'),
|
||||
TestTask(self.values, name='task2'),
|
||||
uf.Flow('p-inner').add(
|
||||
lf.Flow('p-inner').add(
|
||||
TestTask(self.values, name='task3', sleep=0.1),
|
||||
NastyTask(),
|
||||
FailingTask(sleep=0.01)
|
||||
@@ -625,9 +627,13 @@ class MultiThreadedEngineTest(EngineTaskTest,
|
||||
with self.assertRaisesRegexp(RuntimeError, '^Gotcha'):
|
||||
engine.run()
|
||||
result = set(self.values)
|
||||
self.assertEquals(result, set(['task1', 'task1 reverted(5)',
|
||||
'task2', 'task2 reverted(5)',
|
||||
'task3', 'task3 reverted(5)']))
|
||||
# Task1, task2 may *not* have executed and also may have *not* reverted
|
||||
# since the above is an unordered flow so take that into account by
|
||||
# ensuring that the superset is matched.
|
||||
possible_result = set(['task1', 'task1 reverted(5)',
|
||||
'task2', 'task2 reverted(5)',
|
||||
'task3', 'task3 reverted(5)'])
|
||||
self.assertIsSubset(possible_result, result)
|
||||
|
||||
def test_parallel_revert_exception_do_not_revert_linear_tasks(self):
|
||||
flow = lf.Flow('l-root').add(
|
||||
@@ -640,11 +646,35 @@ class MultiThreadedEngineTest(EngineTaskTest,
|
||||
)
|
||||
)
|
||||
engine = self._make_engine(flow)
|
||||
with self.assertRaisesRegexp(RuntimeError, '^Gotcha'):
|
||||
# Depending on when (and if failing task) is executed the exception
|
||||
# raised could be either woot or gotcha since the above unordered
|
||||
# sub-flow does not guarantee that the ordering will be maintained,
|
||||
# even with sleeping.
|
||||
was_nasty = False
|
||||
try:
|
||||
engine.run()
|
||||
self.assertTrue(False)
|
||||
except RuntimeError as e:
|
||||
self.assertRegexpMatches(str(e), '^Gotcha|^Woot')
|
||||
if 'Gotcha!' in str(e):
|
||||
was_nasty = True
|
||||
result = set(self.values)
|
||||
self.assertEquals(result, set(['task1', 'task2',
|
||||
'task3', 'task3 reverted(5)']))
|
||||
possible_result = set(['task1', 'task2',
|
||||
'task3', 'task3 reverted(5)'])
|
||||
if not was_nasty:
|
||||
possible_result.update(['task1 reverted(5)', 'task2 reverted(5)'])
|
||||
self.assertIsSubset(possible_result, result)
|
||||
# If the nasty task killed reverting, then task1 and task2 should not
|
||||
# have reverted, but if the failing task stopped execution then task1
|
||||
# and task2 should have reverted.
|
||||
if was_nasty:
|
||||
must_not_have = ['task1 reverted(5)', 'task2 reverted(5)']
|
||||
for r in must_not_have:
|
||||
self.assertNotIn(r, result)
|
||||
else:
|
||||
must_have = ['task1 reverted(5)', 'task2 reverted(5)']
|
||||
for r in must_have:
|
||||
self.assertIn(r, result)
|
||||
|
||||
def test_parallel_nested_to_linear_revert(self):
|
||||
flow = lf.Flow('l-root').add(
|
||||
@@ -659,9 +689,18 @@ class MultiThreadedEngineTest(EngineTaskTest,
|
||||
with self.assertRaisesRegexp(RuntimeError, '^Woot'):
|
||||
engine.run()
|
||||
result = set(self.values)
|
||||
self.assertEquals(result, set(['task1', 'task1 reverted(5)',
|
||||
'task2', 'task2 reverted(5)',
|
||||
'task3', 'task3 reverted(5)']))
|
||||
# Task3 may or may not have executed, depending on scheduling and
|
||||
# task ordering selection, so it may or may not exist in the result set
|
||||
possible_result = set(['task1', 'task1 reverted(5)',
|
||||
'task2', 'task2 reverted(5)',
|
||||
'task3', 'task3 reverted(5)'])
|
||||
self.assertIsSubset(possible_result, result)
|
||||
# These must exist, since the linearity of the linear flow ensures
|
||||
# that they were executed first.
|
||||
must_have = ['task1', 'task1 reverted(5)',
|
||||
'task2', 'task2 reverted(5)']
|
||||
for r in must_have:
|
||||
self.assertIn(r, result)
|
||||
|
||||
def test_linear_nested_to_parallel_revert(self):
|
||||
flow = uf.Flow('p-root').add(
|
||||
@@ -676,11 +715,14 @@ class MultiThreadedEngineTest(EngineTaskTest,
|
||||
with self.assertRaisesRegexp(RuntimeError, '^Woot'):
|
||||
engine.run()
|
||||
result = set(self.values)
|
||||
self.assertEquals(result,
|
||||
set(['task1', 'task1 reverted(5)',
|
||||
# Since this is an unordered flow we can not guarantee that task1 or
|
||||
# task2 will exist and be reverted, although they may exist depending
|
||||
# on how the OS thread scheduling and execution graph algorithm...
|
||||
possible_result = set(['task1', 'task1 reverted(5)',
|
||||
'task2', 'task2 reverted(5)',
|
||||
'task3', 'task3 reverted(5)',
|
||||
'fail reverted(Failure: RuntimeError: Woot!)']))
|
||||
'fail reverted(Failure: RuntimeError: Woot!)'])
|
||||
self.assertIsSubset(possible_result, result)
|
||||
|
||||
def test_linear_nested_to_parallel_revert_exception(self):
|
||||
flow = uf.Flow('p-root').add(
|
||||
@@ -696,6 +738,7 @@ class MultiThreadedEngineTest(EngineTaskTest,
|
||||
with self.assertRaisesRegexp(RuntimeError, '^Gotcha'):
|
||||
engine.run()
|
||||
result = set(self.values)
|
||||
self.assertEquals(result, set(['task1', 'task1 reverted(5)',
|
||||
'task2', 'task2 reverted(5)',
|
||||
'task3']))
|
||||
possible_result = set(['task1', 'task1 reverted(5)',
|
||||
'task2', 'task2 reverted(5)',
|
||||
'task3'])
|
||||
self.assertIsSubset(possible_result, result)
|
||||
|
169
taskflow/tests/unit/test_flattening.py
Normal file
169
taskflow/tests/unit/test_flattening.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import string
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from taskflow.patterns import graph_flow as gf
|
||||
from taskflow.patterns import linear_flow as lf
|
||||
from taskflow.patterns import unordered_flow as uf
|
||||
|
||||
from taskflow import test
|
||||
from taskflow.tests import utils as t_utils
|
||||
from taskflow.utils import flow_utils as f_utils
|
||||
from taskflow.utils import graph_utils as g_utils
|
||||
|
||||
|
||||
def _make_many(amount):
|
||||
assert amount <= len(string.ascii_lowercase), 'Not enough letters'
|
||||
tasks = []
|
||||
for i in range(0, amount):
|
||||
tasks.append(t_utils.DummyTask(name=string.ascii_lowercase[i]))
|
||||
return tasks
|
||||
|
||||
|
||||
class FlattenTest(test.TestCase):
    """Checks the graphs produced by flattening flows and nested flows."""

    def test_linear_flatten(self):
        """A nested linear flow flattens into one ordered chain."""
        a, b, c, d = _make_many(4)
        flo = lf.Flow("test")
        flo.add(a, b, c)
        sflo = lf.Flow("sub-test")
        sflo.add(d)
        flo.add(sflo)

        graph = f_utils.flatten(flo)
        self.assertEqual(4, len(graph))

        order = nx.topological_sort(graph)
        self.assertEqual([a, b, c, d], order)
        self.assertTrue(graph.has_edge(c, d))
        self.assertEqual([d], list(g_utils.get_no_successors(graph)))
        self.assertEqual([a], list(g_utils.get_no_predecessors(graph)))

    def test_invalid_flatten(self):
        """A flow that contains itself can not be flattened."""
        a, b, c = _make_many(3)
        flo = lf.Flow("test")
        flo.add(a, b, c)
        flo.add(flo)
        self.assertRaises(ValueError, f_utils.flatten, flo)

    def test_unordered_flatten(self):
        """An unordered flow flattens into nodes without any edges."""
        a, b, c, d = _make_many(4)
        flo = uf.Flow("test")
        flo.add(a, b, c, d)
        graph = f_utils.flatten(flo)
        self.assertEqual(4, len(graph))
        self.assertEqual(0, graph.number_of_edges())
        self.assertEqual(set([a, b, c, d]),
                         set(g_utils.get_no_successors(graph)))
        self.assertEqual(set([a, b, c, d]),
                         set(g_utils.get_no_predecessors(graph)))

    def test_linear_nested_flatten(self):
        """An unordered flow nested in a linear flow runs after its peers."""
        a, b, c, d = _make_many(4)
        flo = lf.Flow("test")
        flo.add(a, b)
        flo2 = uf.Flow("test2")
        flo2.add(c, d)
        flo.add(flo2)
        graph = f_utils.flatten(flo)
        self.assertEqual(4, len(graph))

        lb = graph.subgraph([a, b])
        self.assertTrue(lb.has_edge(a, b))
        self.assertFalse(lb.has_edge(b, a))

        ub = graph.subgraph([c, d])
        self.assertEqual(0, ub.number_of_edges())

        # This ensures that c and d do not start executing until after b.
        self.assertTrue(graph.has_edge(b, c))
        self.assertTrue(graph.has_edge(b, d))

    def test_unordered_nested_flatten(self):
        """A linear flow nested in an unordered flow only orders itself."""
        a, b, c, d = _make_many(4)
        flo = uf.Flow("test")
        flo.add(a, b)
        flo2 = lf.Flow("test2")
        flo2.add(c, d)
        flo.add(flo2)

        graph = f_utils.flatten(flo)
        self.assertEqual(4, len(graph))
        for n in [a, b]:
            self.assertFalse(graph.has_edge(n, c))
            self.assertFalse(graph.has_edge(n, d))
        self.assertTrue(graph.has_edge(c, d))
        self.assertFalse(graph.has_edge(d, c))

        ub = graph.subgraph([a, b])
        self.assertEqual(0, ub.number_of_edges())
        lb = graph.subgraph([c, d])
        self.assertEqual(1, lb.number_of_edges())

    def test_graph_flatten(self):
        """A graph flow without links flattens into unconnected nodes."""
        a, b, c, d = _make_many(4)
        flo = gf.Flow("test")
        flo.add(a, b, c, d)

        graph = f_utils.flatten(flo)
        self.assertEqual(4, len(graph))
        self.assertEqual(0, graph.number_of_edges())

    def test_graph_flatten_nested(self):
        """A linear flow nested in a graph flow keeps its internal order."""
        a, b, c, d, e, f, g = _make_many(7)
        flo = gf.Flow("test")
        flo.add(a, b, c, d)

        flo2 = lf.Flow('test2')
        flo2.add(e, f, g)
        flo.add(flo2)

        # NOTE: named 'graph' so that it does not shadow the task 'g'.
        graph = f_utils.flatten(flo)
        self.assertEqual(7, len(graph))
        self.assertEqual(2, graph.number_of_edges())

    def test_graph_flatten_nested_graph(self):
        """A link-less graph flow nested in a graph flow adds no edges."""
        a, b, c, d, e, f, g = _make_many(7)
        flo = gf.Flow("test")
        flo.add(a, b, c, d)

        flo2 = gf.Flow('test2')
        flo2.add(e, f, g)
        flo.add(flo2)

        # NOTE: named 'graph' so that it does not shadow the task 'g'.
        graph = f_utils.flatten(flo)
        self.assertEqual(7, len(graph))
        self.assertEqual(0, graph.number_of_edges())

    def test_graph_flatten_links(self):
        """Explicit links of a graph flow become edges when flattened."""
        a, b, c, d = _make_many(4)
        flo = gf.Flow("test")
        flo.add(a, b, c, d)
        flo.link(a, b)
        flo.link(b, c)
        flo.link(c, d)

        graph = f_utils.flatten(flo)
        self.assertEqual(4, len(graph))
        self.assertEqual(3, graph.number_of_edges())
        self.assertEqual(set([a]),
                         set(g_utils.get_no_predecessors(graph)))
        self.assertEqual(set([d]),
                         set(g_utils.get_no_successors(graph)))
|
116
taskflow/utils/flow_utils.py
Normal file
116
taskflow/utils/flow_utils.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright (C) 2013 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from taskflow.patterns import graph_flow as gf
|
||||
from taskflow.patterns import linear_flow as lf
|
||||
from taskflow.patterns import unordered_flow as uf
|
||||
from taskflow import task
|
||||
from taskflow.utils import graph_utils as gu
|
||||
|
||||
|
||||
# Use the 'flatten' reason as the need to add an edge here, which is useful for
|
||||
# doing later analysis of the edges (to determine why the edges were created).
|
||||
FLATTEN_REASON = 'flatten'
|
||||
FLATTEN_EDGE_DATA = {
|
||||
'reason': FLATTEN_REASON,
|
||||
}
|
||||
|
||||
|
||||
def _graph_name(flow):
|
||||
return "F:%s:%s" % (flow.name, flow.uuid)
|
||||
|
||||
|
||||
def _flatten_linear(flow, flattened):
    """Flatten a linear flow into a graph that keeps the item ordering.

    Every item of ``flow`` is flattened and merged into a single graph, and
    edges (tagged with ``FLATTEN_EDGE_DATA``) are added from the nodes that
    ended the previous item to the nodes that start the current one.

    :param flow: the linear flow to flatten
    :param flattened: set of items already flattened (passed to _flatten)
    :returns: a directed graph containing the flattened items
    """
    graph = nx.DiGraph(name=_graph_name(flow))
    previous_nodes = []
    for f in flow:
        subgraph = _flatten(f, flattened)
        graph = gu.merge_graphs([graph, subgraph])
        # Find nodes that have no predecessor, make them have a predecessor of
        # the previous nodes so that the linearity ordering is maintained. Find
        # the ones with no successors and use this list to connect the next
        # subgraph (if any).
        for n in gu.get_no_predecessors(subgraph):
            graph.add_edges_from(((n2, n, FLATTEN_EDGE_DATA)
                                  for n2 in previous_nodes
                                  if not graph.has_edge(n2, n)))
        # An item that flattened into nothing (for example an empty subflow)
        # has no nodes to link from; keep the current previous nodes so that
        # the next item is still ordered after the last non-empty one.
        if len(subgraph):
            # There should always be someone without successors, otherwise
            # we have a cycle A -> B -> A situation, which should not be
            # possible.
            previous_nodes = list(gu.get_no_successors(subgraph))
    return graph
|
||||
|
||||
|
||||
def _flatten_unordered(flow, flattened):
    """Flatten an unordered flow into a graph.

    Each item of ``flow`` is flattened and merged into one graph; no edges
    are added between the items.
    """
    merged = nx.DiGraph(name=_graph_name(flow))
    for child in flow:
        child_graph = _flatten(child, flattened)
        merged = gu.merge_graphs([merged, child_graph])
    return merged
|
||||
|
||||
|
||||
def _flatten_task(task):
    """Return a new directed graph, named ``T:<task>``, holding only task."""
    single = nx.DiGraph(name='T:%s' % (task))
    single.add_node(task)
    return single
|
||||
|
||||
|
||||
def _flatten_graph(flow, flattened):
    """Flatten a graph flow into a graph containing only flattened items.

    Every node of ``flow.graph`` is flattened into its own subgraph and
    merged into the result; every edge of ``flow.graph`` is then recreated
    between the matching subgraphs, tagged with ``FLATTEN_EDGE_DATA``.
    """
    merged = nx.DiGraph(name=_graph_name(flow))
    flattened_children = {}
    # Flatten all nodes, remembering which subgraph each one became.
    for child in flow.graph.nodes_iter():
        child_graph = _flatten(child, flattened)
        flattened_children[child] = child_graph
        merged = gu.merge_graphs([merged, child_graph])
    # Reconnect all nodes to their corresponding subgraphs.
    for (src, dst) in flow.graph.edges_iter():
        tails = list(gu.get_no_successors(flattened_children[src]))
        # Link the nodes with no predecessors in dst to the nodes with no
        # successors in src (thus maintaining the edge dependency).
        for head in gu.get_no_predecessors(flattened_children[dst]):
            new_edges = [(tail, head, FLATTEN_EDGE_DATA)
                         for tail in tails
                         if not merged.has_edge(tail, head)]
            merged.add_edges_from(new_edges)
    return merged
|
||||
|
||||
|
||||
def _flatten(item, flattened):
|
||||
"""Flattens a item (task/flow+subflows) into an execution graph."""
|
||||
if item in flattened:
|
||||
raise ValueError("Already flattened item: %s" % (item))
|
||||
if isinstance(item, lf.Flow):
|
||||
f = _flatten_linear(item, flattened)
|
||||
elif isinstance(item, uf.Flow):
|
||||
f = _flatten_unordered(item, flattened)
|
||||
elif isinstance(item, gf.Flow):
|
||||
f = _flatten_graph(item, flattened)
|
||||
elif isinstance(item, task.BaseTask):
|
||||
f = _flatten_task(item)
|
||||
else:
|
||||
raise TypeError("Unknown item: %r, %s" % (type(item), item))
|
||||
flattened.add(item)
|
||||
return f
|
||||
|
||||
|
||||
def flatten(item, freeze=True):
    """Flatten a task or flow (and its subflows) into one graph.

    :param item: the task or flow to flatten
    :param freeze: when true the resulting graph is frozen before it is
                   returned
    :returns: the flattened graph
    """
    flat_graph = _flatten(item, set())
    # Frozen graph can't be modified...
    return nx.freeze(flat_graph) if freeze else flat_graph
|
@@ -16,65 +16,68 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
import six
|
||||
|
||||
from taskflow import exceptions as exc
|
||||
import networkx as nx
|
||||
from networkx import algorithms
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
def merge_graphs(graphs, allow_overlaps=False):
    """Merge a list of graphs, left to right, into a single graph.

    :param graphs: list of graphs to merge; the first one is the target
    :param allow_overlaps: when false, merging a graph that shares nodes
                           with the target raises an error
    :returns: None when ``graphs`` is empty, the first graph itself when it
              is the only one, otherwise the composed graph (which keeps
              the first graph's name)
    :raises ValueError: if overlaps are not allowed and a graph to be merged
                        has nodes that already exist in the target
    """
    if not graphs:
        return None
    graph = graphs[0]
    for g in graphs[1:]:
        # This should ensure that the nodes to be merged do not already exist
        # in the graph that is to be merged into. This could be problematic if
        # there are duplicates.
        if not allow_overlaps:
            # Attempt to induce a subgraph using the to be merged graphs nodes
            # and see if any graph results.
            overlaps = graph.subgraph(g.nodes_iter())
            if len(overlaps):
                # NOTE: the '%' operator is required here; without it the
                # string is called like a function and a TypeError is raised
                # instead of this ValueError.
                raise ValueError("Can not merge graph %s into %s since there "
                                 "are %s overlapping nodes" % (g, graph,
                                                               len(overlaps)))
        # Keep the target graphs name.
        name = graph.name
        graph = algorithms.compose(graph, g)
        graph.name = name
    return graph
|
||||
|
||||
|
||||
def connect(graph, infer_key='infer', auto_reason='auto', discard_func=None):
|
||||
"""Connects a graphs runners to other runners in the graph which provide
|
||||
outputs for each runners requirements.
|
||||
"""
|
||||
def get_no_successors(graph):
    """Returns an iterator for all nodes with no successors"""
    for node in graph.nodes_iter():
        outgoing = graph.successors(node)
        if len(outgoing) == 0:
            yield node
|
||||
|
||||
if len(graph) == 0:
|
||||
return
|
||||
if discard_func:
|
||||
for (u, v, e_data) in graph.edges(data=True):
|
||||
if discard_func(u, v, e_data):
|
||||
graph.remove_edge(u, v)
|
||||
for (r, r_data) in graph.nodes_iter(data=True):
|
||||
requires = set(r.requires)
|
||||
|
||||
# Find the ones that have already been attached manually.
|
||||
manual_providers = {}
|
||||
if requires:
|
||||
incoming = [e[0] for e in graph.in_edges_iter([r])]
|
||||
for r2 in incoming:
|
||||
fulfills = requires & r2.provides
|
||||
if fulfills:
|
||||
LOG.debug("%s is a manual provider of %s for %s",
|
||||
r2, fulfills, r)
|
||||
for k in fulfills:
|
||||
manual_providers[k] = r2
|
||||
requires.remove(k)
|
||||
def get_no_predecessors(graph):
    """Returns an iterator for all nodes with no predecessors"""
    for node in graph.nodes_iter():
        incoming = graph.predecessors(node)
        if len(incoming) == 0:
            yield node
|
||||
|
||||
# Anything leftover that we must find providers for??
|
||||
auto_providers = {}
|
||||
if requires and r_data.get(infer_key):
|
||||
for r2 in graph.nodes_iter():
|
||||
if r is r2:
|
||||
continue
|
||||
fulfills = requires & r2.provides
|
||||
if fulfills:
|
||||
graph.add_edge(r2, r, reason=auto_reason)
|
||||
LOG.debug("Connecting %s as a automatic provider for"
|
||||
" %s for %s", r2, fulfills, r)
|
||||
for k in fulfills:
|
||||
auto_providers[k] = r2
|
||||
requires.remove(k)
|
||||
if not requires:
|
||||
break
|
||||
|
||||
# Anything still leftover??
|
||||
if requires:
|
||||
# Ensure its in string format, since join will puke on
|
||||
# things that are not strings.
|
||||
missing = ", ".join(sorted([str(s) for s in requires]))
|
||||
raise exc.MissingDependencies(r, missing)
|
||||
else:
|
||||
r.providers = {}
|
||||
r.providers.update(auto_providers)
|
||||
r.providers.update(manual_providers)
|
||||
def pformat(graph):
    """Format a graph as a multi-line, human readable string.

    The result lists the graph's name, type, frozen status, its nodes, its
    edges (with each edge's 'reason', or '??' when it has none) and any
    cycles found in it.
    """
    out = [
        "Name: %s" % graph.name,
        "Type: %s" % type(graph).__name__,
        "Frozen: %s" % nx.is_frozen(graph),
        "Nodes: %s" % graph.number_of_nodes(),
    ]
    out.extend(" - %s" % node for node in graph.nodes_iter())
    out.append("Edges: %s" % graph.number_of_edges())
    for (src, dst, edge_data) in graph.edges_iter(data=True):
        why = edge_data.get('reason', '??')
        out.append(" %s -> %s (%s)" % (src, dst, why))
    found_cycles = list(nx.cycles.recursive_simple_cycles(graph))
    out.append("Cycles: %s" % len(found_cycles))
    for cycle in found_cycles:
        # Render the cycle as 'first --> ... --> first' to show the loop.
        pieces = [str(cycle[0])]
        pieces.extend(" --> %s" % (hop) for hop in cycle[1:])
        pieces.append(" --> %s" % (cycle[0]))
        out.append(" %s" % "".join(pieces))
    return "\n".join(out)
|
||||
|
Reference in New Issue
Block a user