def connect(graph, infer_key='infer', auto_reason='auto', discard_func=None):
    """Connects each runner in the graph to the runners providing its needs.

    For every node the incoming edges are inspected first: any requirement
    that a predecessor provides is recorded as being manually provided. If
    requirements are still outstanding and the node's data holds a truthy
    value under ``infer_key``, the other nodes are scanned and an edge
    (carrying ``reason=auto_reason``) is added from each node that provides
    something still missing. When everything is satisfied the runner gets a
    ``providers`` attribute mapping requirement name to providing runner.

    :param graph: directed graph whose nodes expose ``requires`` and
                  ``provides`` attributes.
    :param infer_key: node data key that enables automatic provider lookup.
    :param auto_reason: value stored as the ``reason`` of inferred edges.
    :param discard_func: optional callable ``(u, v, edge_data)``; edges for
                         which it returns a truthy value are removed before
                         any connecting happens.
    :raises MissingDependencies: when a runner is left with requirements
                                 that no provider could be found for.
    """
    if len(graph) == 0:
        return
    if discard_func:
        # Take a snapshot of the edges before removing any of them, since
        # mutating the graph while walking a live edge view/iterator is
        # not safe.
        for (u, v, e_data) in list(graph.edges(data=True)):
            if discard_func(u, v, e_data):
                graph.remove_edge(u, v)
    for (r, r_data) in graph.nodes_iter(data=True):
        requires = set(r.requires)

        # Find the ones that have already been attached manually.
        manual_providers = {}
        if requires:
            incoming = [e[0] for e in graph.in_edges_iter([r])]
            for r2 in incoming:
                fulfills = requires & r2.provides
                if fulfills:
                    LOG.debug("%s is a manual provider of %s for %s",
                              r2, fulfills, r)
                    for k in fulfills:
                        manual_providers[k] = r2
                        requires.remove(k)

        # Anything leftover that we must find providers for??
        auto_providers = {}
        if requires and r_data.get(infer_key):
            for r2 in graph.nodes_iter():
                if r is r2:
                    continue
                fulfills = requires & r2.provides
                if fulfills:
                    graph.add_edge(r2, r, reason=auto_reason)
                    LOG.debug("Connecting %s as a automatic provider for"
                              " %s for %s", r2, fulfills, r)
                    for k in fulfills:
                        auto_providers[k] = r2
                        requires.remove(k)
                # Stop scanning as soon as nothing is left to satisfy.
                if not requires:
                    break

        # Anything still leftover??
        if requires:
            # Ensure its in string format, since join will puke on
            # things that are not strings.
            missing = ", ".join(sorted(str(s) for s in requires))
            raise exc.MissingDependencies(r, missing)

        # The two mappings never share a key (a requirement is removed from
        # ``requires`` as soon as it is satisfied), so merge order is moot.
        providers = dict(auto_providers)
        providers.update(manual_providers)
        r.providers = providers
""" - raise NotImplementedError() + uuids = [] + for t in tasks: + uuids.append(self.add(t)) + return uuids def interrupt(self): """Attempts to interrupt the current flow and any tasks that are diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index 9bdcf091..46c7394a 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -25,6 +25,7 @@ from networkx import exception as g_exc from taskflow import decorators from taskflow import exceptions as exc +from taskflow import graph_utils from taskflow.patterns import linear_flow from taskflow import utils @@ -41,20 +42,42 @@ class Flow(linear_flow.Flow): self._graph = digraph.DiGraph() @decorators.locked - def add(self, task): + def add(self, task, infer=True): # Only insert the node to start, connect all the edges # together later after all nodes have been added since if we try # to infer the edges at this stage we likely will fail finding # dependencies from nodes that don't exist. assert isinstance(task, collections.Callable) r = utils.Runner(task) - self._graph.add_node(r, uuid=r.uuid) + self._graph.add_node(r, uuid=r.uuid, infer=infer) self._reset_internals() return r.uuid - def _add_dependency(self, provider, requirer): - if not self._graph.has_edge(provider, requirer): - self._graph.add_edge(provider, requirer) + def _find_uuid(self, uuid): + runner = None + for r in self._graph.nodes_iter(): + if r.uuid == uuid: + runner = r + break + return runner + + @decorators.locked + def add_dependency(self, provider_uuid, requirer_uuid): + """Connects provider to requirer where provider will now be required + to run before requirer does.""" + if provider_uuid == requirer_uuid: + raise ValueError("Unable to link %s to itself" % provider_uuid) + provider = self._find_uuid(provider_uuid) + if not provider: + raise ValueError("No provider found with uuid %s" % provider_uuid) + requirer = self._find_uuid(requirer_uuid) + if not requirer: + raise ValueError("No requirer 
@decorators.locked
def remove(self, uuid):
    """Removes the runner with the given uuid from this flow's graph.

    Raises ValueError when no runner with that uuid can be found.
    """
    target = self._find_uuid(uuid)
    if not target:
        raise ValueError("No uuid %s found" % (uuid))
    self._graph.remove_node(target)
    # The graph changed, so reset the flow's internal state as well.
    self._reset_internals()
- for n in self._graph.nodes_iter(): - n_providers = {} - n_requires = n.requires - if n_requires: - LOG.debug("Finding providers of %s for %s", n_requires, n) - for p in self._graph.nodes_iter(): - if n is p: - continue - p_provides = p.provides - p_satisfies = n_requires & p_provides - if p_satisfies: - # P produces for N so thats why we link P->N - # and not N->P - self._add_dependency(p, n) - for k in p_satisfies: - n_providers[k] = p - LOG.debug("Found provider of %s from %s", - p_satisfies, p) - n_requires = n_requires - p_satisfies - if n_requires: - raise exc.MissingDependencies(n, sorted(n_requires)) - n.providers = n_providers + graph_utils.connect(self._graph, discard_func=discard_edge_func) # Now figure out the order so that we can give the runners there # optional item providers as well as figure out the topological run diff --git a/taskflow/patterns/linear_flow.py b/taskflow/patterns/linear_flow.py index 6f622bb4..c64635d8 100644 --- a/taskflow/patterns/linear_flow.py +++ b/taskflow/patterns/linear_flow.py @@ -57,13 +57,6 @@ class Flow(base.Flow): # The resumption strategy to use. 
def test_manual_dependencies(self):
    """Checks that manually added dependencies decide the run order."""
    flo = gw.Flow("test-flow")
    run_order = []

    def run1(context):  # pylint: disable=W0613,C0103
        run_order.append('ran1')

    def run2(context):  # pylint: disable=W0613,C0103
        run_order.append('ran2')

    def run3(context):  # pylint: disable=W0613,C0103
        run_order.append('ran3')

    (uuid1, uuid2, uuid3) = flo.add_many([run1, run2, run3])
    # add_dependency(provider, requirer): the provider has to run before
    # the requirer, so this chains run3 -> run2 -> run1.
    flo.add_dependency(uuid3, uuid2)
    flo.add_dependency(uuid2, uuid1)
    # Linking a task to itself is rejected.
    self.assertRaises(ValueError, flo.add_dependency, uuid2, uuid2)
    # A uuid that (presumably) matches no added task is rejected as well.
    self.assertRaises(ValueError, flo.add_dependency,
                      uuid2 + "blah", uuid3)

    flo.run({})
    # assertEqual instead of the deprecated assertEquals alias.
    self.assertEqual(['ran3', 'ran2', 'ran1'], run_order)