From b4f592c5248fd8eda03a5183085ca634abde8b4f Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Wed, 22 May 2013 21:06:37 -0700 Subject: [PATCH] Allow (or disallow) multiple providers of items. Incases where a item is produced by 2+ tasks we should allow the task that depends on that item to get the 2+ items produced as a list of items. We can also disallow this type of production via a new boolean option. --- taskflow/patterns/graph_flow.py | 44 +++++++++++++++++++------- taskflow/tests/unit/test_graph_flow.py | 33 +++++++++++++++++++ 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index c280170a..e6992159 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -34,10 +34,12 @@ class Flow(ordered_flow.Flow): determine who provides said input and order the task so that said providing task will be ran before.""" - def __init__(self, name, tolerant=False, parents=None): + def __init__(self, name, tolerant=False, parents=None, + allow_same_inputs=True): super(Flow, self).__init__(name, tolerant, parents) self._graph = digraph.DiGraph() self._connected = False + self._allow_same_inputs = allow_same_inputs def add(self, task): # Do something with the task, either store it for later @@ -49,18 +51,25 @@ class Flow(ordered_flow.Flow): self._connected = False def _fetch_task_inputs(self, task): - inputs = {} + inputs = collections.defaultdict(list) + for n in task.requires(): for (them, there_result) in self.results: if (not self._graph.has_edge(them, task) or - not n in them.provides() or not there_result): + not n in them.provides()): continue - if n in there_result: - # NOTE(harlowja): later results overwrite - # prior results for the same keys, which is - # typically desired. - inputs[n] = there_result[n] - return inputs + if there_result and n in there_result: + inputs[n].append(there_result[n]) + else: + inputs[n].append(None) + + def collapse_functor(k_v): + (k, v) = k_v + if len(v) == 1: + v = v[0] + return (k, v) + + return dict(map(collapse_functor, inputs.iteritems())) def order(self): self.connect() @@ -86,6 +95,14 @@ class Flow(ordered_flow.Flow): for p in t.provides(): provides_what[p].append(t) + def get_providers(node, want_what): + providers = [] + for (producer, me) in self._graph.in_edges_iter(node): + providing_what = self._graph.get_edge_data(producer, me) + if want_what in providing_what: + providers.append(producer) + return providers + # Link providers to consumers of items. for (want_what, who_wants) in requires_what.iteritems(): who_provided = 0 @@ -95,10 +112,13 @@ class Flow(ordered_flow.Flow): if p is n: # No self-referencing allowed. continue - why = { + if (len(get_providers(n, want_what)) and not + self._allow_same_inputs): + msg = "Multiple providers of %s not allowed." + raise exc.InvalidStateException(msg % (want_what)) + self._graph.add_edge(p, n, attr_dict={ want_what: True, - } - self._graph.add_edge(p, n, why) + }) who_provided += 1 if not who_provided: who_wants = ", ".join([str(a) for a in who_wants]) diff --git a/taskflow/tests/unit/test_graph_flow.py b/taskflow/tests/unit/test_graph_flow.py index 015344f2..13cc6265 100644 --- a/taskflow/tests/unit/test_graph_flow.py +++ b/taskflow/tests/unit/test_graph_flow.py @@ -57,6 +57,39 @@ class GraphFlowTest(unittest.TestCase): self.assertEquals(states.FAILURE, flo.state) self.assertEquals(['run1'], reverted) + def test_multi_provider_disallowed(self): + flo = gw.Flow("test-flow", allow_same_inputs=False) + flo.add(utils.ProvidesRequiresTask('test6', + provides=['y'], + requires=[])) + flo.add(utils.ProvidesRequiresTask('test7', + provides=['y'], + requires=[])) + flo.add(utils.ProvidesRequiresTask('test8', + provides=[], + requires=['y'])) + self.assertEquals(states.PENDING, flo.state) + self.assertRaises(excp.InvalidStateException, flo.run, {}) + self.assertEquals(states.FAILURE, flo.state) + + def test_multi_provider_allowed(self): + flo = gw.Flow("test-flow", allow_same_inputs=True) + flo.add(utils.ProvidesRequiresTask('test6', + provides=['y', 'z'], + requires=[])) + flo.add(utils.ProvidesRequiresTask('test7', + provides=['y'], + requires=['z'])) + flo.add(utils.ProvidesRequiresTask('test8', + provides=[], + requires=['y', 'z'])) + ctx = {} + flo.run(ctx) + self.assertEquals(['test6', 'test7', 'test8'], ctx[utils.ORDER_KEY]) + (_task, results) = flo.results[2] + self.assertEquals([True, True], results[utils.KWARGS_KEY]['y']) + self.assertEquals(True, results[utils.KWARGS_KEY]['z']) + def test_no_requires_provider(self): flo = gw.Flow("test-flow") flo.add(utils.ProvidesRequiresTask('test1',