Allow (or disallow) multiple providers of items.
Incases where a item is produced by 2+ tasks we should allow the task that depends on that item to get the 2+ items produced as a list of items. We can also disallow this type of production via a new boolean option.
This commit is contained in:
@@ -34,10 +34,12 @@ class Flow(ordered_flow.Flow):
|
||||
determine who provides said input and order the task so that said providing
|
||||
task will be ran before."""
|
||||
|
||||
def __init__(self, name, tolerant=False, parents=None):
|
||||
def __init__(self, name, tolerant=False, parents=None,
|
||||
allow_same_inputs=True):
|
||||
super(Flow, self).__init__(name, tolerant, parents)
|
||||
self._graph = digraph.DiGraph()
|
||||
self._connected = False
|
||||
self._allow_same_inputs = allow_same_inputs
|
||||
|
||||
def add(self, task):
|
||||
# Do something with the task, either store it for later
|
||||
@@ -49,18 +51,25 @@ class Flow(ordered_flow.Flow):
|
||||
self._connected = False
|
||||
|
||||
def _fetch_task_inputs(self, task):
|
||||
inputs = {}
|
||||
inputs = collections.defaultdict(list)
|
||||
|
||||
for n in task.requires():
|
||||
for (them, there_result) in self.results:
|
||||
if (not self._graph.has_edge(them, task) or
|
||||
not n in them.provides() or not there_result):
|
||||
not n in them.provides()):
|
||||
continue
|
||||
if n in there_result:
|
||||
# NOTE(harlowja): later results overwrite
|
||||
# prior results for the same keys, which is
|
||||
# typically desired.
|
||||
inputs[n] = there_result[n]
|
||||
return inputs
|
||||
if there_result and n in there_result:
|
||||
inputs[n].append(there_result[n])
|
||||
else:
|
||||
inputs[n].append(None)
|
||||
|
||||
def collapse_functor(k_v):
|
||||
(k, v) = k_v
|
||||
if len(v) == 1:
|
||||
v = v[0]
|
||||
return (k, v)
|
||||
|
||||
return dict(map(collapse_functor, inputs.iteritems()))
|
||||
|
||||
def order(self):
|
||||
self.connect()
|
||||
@@ -86,6 +95,14 @@ class Flow(ordered_flow.Flow):
|
||||
for p in t.provides():
|
||||
provides_what[p].append(t)
|
||||
|
||||
def get_providers(node, want_what):
|
||||
providers = []
|
||||
for (producer, me) in self._graph.in_edges_iter(node):
|
||||
providing_what = self._graph.get_edge_data(producer, me)
|
||||
if want_what in providing_what:
|
||||
providers.append(producer)
|
||||
return providers
|
||||
|
||||
# Link providers to consumers of items.
|
||||
for (want_what, who_wants) in requires_what.iteritems():
|
||||
who_provided = 0
|
||||
@@ -95,10 +112,13 @@ class Flow(ordered_flow.Flow):
|
||||
if p is n:
|
||||
# No self-referencing allowed.
|
||||
continue
|
||||
why = {
|
||||
if (len(get_providers(n, want_what)) and not
|
||||
self._allow_same_inputs):
|
||||
msg = "Multiple providers of %s not allowed."
|
||||
raise exc.InvalidStateException(msg % (want_what))
|
||||
self._graph.add_edge(p, n, attr_dict={
|
||||
want_what: True,
|
||||
}
|
||||
self._graph.add_edge(p, n, why)
|
||||
})
|
||||
who_provided += 1
|
||||
if not who_provided:
|
||||
who_wants = ", ".join([str(a) for a in who_wants])
|
||||
|
||||
@@ -57,6 +57,39 @@ class GraphFlowTest(unittest.TestCase):
|
||||
self.assertEquals(states.FAILURE, flo.state)
|
||||
self.assertEquals(['run1'], reverted)
|
||||
|
||||
def test_multi_provider_disallowed(self):
|
||||
flo = gw.Flow("test-flow", allow_same_inputs=False)
|
||||
flo.add(utils.ProvidesRequiresTask('test6',
|
||||
provides=['y'],
|
||||
requires=[]))
|
||||
flo.add(utils.ProvidesRequiresTask('test7',
|
||||
provides=['y'],
|
||||
requires=[]))
|
||||
flo.add(utils.ProvidesRequiresTask('test8',
|
||||
provides=[],
|
||||
requires=['y']))
|
||||
self.assertEquals(states.PENDING, flo.state)
|
||||
self.assertRaises(excp.InvalidStateException, flo.run, {})
|
||||
self.assertEquals(states.FAILURE, flo.state)
|
||||
|
||||
def test_multi_provider_allowed(self):
|
||||
flo = gw.Flow("test-flow", allow_same_inputs=True)
|
||||
flo.add(utils.ProvidesRequiresTask('test6',
|
||||
provides=['y', 'z'],
|
||||
requires=[]))
|
||||
flo.add(utils.ProvidesRequiresTask('test7',
|
||||
provides=['y'],
|
||||
requires=['z']))
|
||||
flo.add(utils.ProvidesRequiresTask('test8',
|
||||
provides=[],
|
||||
requires=['y', 'z']))
|
||||
ctx = {}
|
||||
flo.run(ctx)
|
||||
self.assertEquals(['test6', 'test7', 'test8'], ctx[utils.ORDER_KEY])
|
||||
(_task, results) = flo.results[2]
|
||||
self.assertEquals([True, True], results[utils.KWARGS_KEY]['y'])
|
||||
self.assertEquals(True, results[utils.KWARGS_KEY]['z'])
|
||||
|
||||
def test_no_requires_provider(self):
|
||||
flo = gw.Flow("test-flow")
|
||||
flo.add(utils.ProvidesRequiresTask('test1',
|
||||
|
||||
Reference in New Issue
Block a user