diff --git a/taskflow/exceptions.py b/taskflow/exceptions.py index 29307bbb..9f1b4fa0 100644 --- a/taskflow/exceptions.py +++ b/taskflow/exceptions.py @@ -98,6 +98,10 @@ class DependencyFailure(TaskFlowException): """Raised when some type of dependency problem occurs.""" +class AmbiguousDependency(DependencyFailure): + """Raised when some type of ambiguous dependency problem occurs.""" + + class MissingDependencies(DependencyFailure): """Raised when a entity has dependencies that can not be satisfied.""" MESSAGE_TPL = ("%(who)s requires %(requirements)s but no other entity" diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index 658e051c..f07e7435 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -21,6 +21,22 @@ from taskflow import flow from taskflow.types import graph as gr +def _unsatisfied_requires(node, graph, *additional_provided): + """Extracts the unsatisfied symbol requirements of a single node.""" + requires = set(node.requires) + if not requires: + return requires + for provided in additional_provided: + requires = requires - provided + if not requires: + return requires + for pred in graph.bfs_predecessors_iter(node): + requires = requires - pred.provides + if not requires: + return requires + return requires + + class Flow(flow.Flow): """Graph flow pattern. @@ -80,33 +96,59 @@ class Flow(flow.Flow): if not graph.is_directed_acyclic(): raise exc.DependencyFailure("No path through the items in the" " graph produces an ordering that" - " will allow for correct dependency" - " resolution") - self._graph = graph - self._graph.freeze() + " will allow for logical" + " edge traversal") + self._graph = graph.freeze() - def add(self, *items): - """Adds a given task/tasks/flow/flows to this flow.""" + def add(self, *items, **kwargs): + """Adds a given task/tasks/flow/flows to this flow. 
+ + :param items: items to add to the flow + :param kwargs: keyword arguments, the two keyword arguments + currently processed are: + + * ``resolve_requires`` a boolean that when true (the + default) implies that when items are added their + symbol requirements will be matched to existing items + and links will be automatically made to those + providers. If multiple possible providers exist + then an AmbiguousDependency exception will be raised. + * ``resolve_existing``, a boolean that when true (the + default) implies that on addition of a new item that + existing items will have their requirements scanned + for symbols that this newly added item can provide. + If a match is found a link is automatically created + from the newly added item to the requiree. + """ items = [i for i in items if not self._graph.has_node(i)] if not items: return self - requirements = collections.defaultdict(list) - provided = {} + # This syntax will *hopefully* be better in future versions of python. + # + # See: http://legacy.python.org/dev/peps/pep-3102/ (python 3.0+) + resolve_requires = bool(kwargs.get('resolve_requires', True)) + resolve_existing = bool(kwargs.get('resolve_existing', True)) - def update_requirements(node): - for value in node.requires: - requirements[value].append(node) + # Figure out what the existing nodes *still* require and what they + # provide so we can do this lookup later when inferring. 
+ required = collections.defaultdict(list) + provided = collections.defaultdict(list) - for node in self: - update_requirements(node) - for value in node.provides: - provided[value] = node + retry_provides = set() + if self._retry is not None: + for value in self._retry.requires: + required[value].append(self._retry) + for value in self._retry.provides: + retry_provides.add(value) + provided[value].append(self._retry) - if self.retry: - update_requirements(self.retry) - provided.update(dict((k, self.retry) - for k in self.retry.provides)) + for item in self._graph.nodes_iter(): + for value in _unsatisfied_requires(item, self._graph, + retry_provides): + required[value].append(item) + for value in item.provides: + provided[value].append(item) # NOTE(harlowja): Add items and edges to a temporary copy of the # underlying graph and only if that is successful added to do we then @@ -114,37 +156,41 @@ class Flow(flow.Flow): tmp_graph = gr.DiGraph(self._graph) for item in items: tmp_graph.add_node(item) - update_requirements(item) - for value in item.provides: - if value in provided: - raise exc.DependencyFailure( - "%(item)s provides %(value)s but is already being" - " provided by %(flow)s and duplicate producers" - " are disallowed" - % dict(item=item.name, - flow=provided[value].name, - value=value)) - if self.retry and value in self.retry.requires: - raise exc.DependencyFailure( - "Flows retry controller %(retry)s requires %(value)s " - "but item %(item)s being added to the flow produces " - "that item, this creates a cyclic dependency and is " - "disallowed" - % dict(item=item.name, - retry=self.retry.name, - value=value)) - provided[value] = item - for value in item.requires: - if value in provided: - self._link(provided[value], item, - graph=tmp_graph, reason=value) + # Try to find a valid provider. 
+ if resolve_requires: + for value in _unsatisfied_requires(item, tmp_graph, + retry_provides): + if value in provided: + providers = provided[value] + if len(providers) > 1: + provider_names = [n.name for n in providers] + raise exc.AmbiguousDependency( + "Resolution error detected when" + " adding %(item)s, multiple" + " providers %(providers)s found for" + " required symbol '%(value)s'" + % dict(item=item.name, + providers=sorted(provider_names), + value=value)) + else: + self._link(providers[0], item, + graph=tmp_graph, reason=value) + else: + required[value].append(item) for value in item.provides: - if value in requirements: - for node in requirements[value]: - self._link(item, node, - graph=tmp_graph, reason=value) + provided[value].append(item) + + # See if what we provide fulfills any existing requiree. + if resolve_existing: + for value in item.provides: + if value in required: + for requiree in list(required[value]): + if requiree is not item: + self._link(item, requiree, + graph=tmp_graph, reason=value) + required[value].remove(requiree) self._swap(tmp_graph) return self @@ -177,15 +223,7 @@ class Flow(flow.Flow): retry_provides.update(self._retry.provides) g = self._get_subgraph() for item in g.nodes_iter(): - item_requires = item.requires - retry_provides - # Now scan predecessors to see if they provide what we want. 
- if item_requires: - for pred_item in g.bfs_predecessors_iter(item): - item_requires = item_requires - pred_item.provides - if not item_requires: - break - if item_requires: - requires.update(item_requires) + requires.update(_unsatisfied_requires(item, g, retry_provides)) return frozenset(requires) diff --git a/taskflow/tests/unit/patterns/test_graph_flow.py b/taskflow/tests/unit/patterns/test_graph_flow.py index c7dad38e..62dbc287 100644 --- a/taskflow/tests/unit/patterns/test_graph_flow.py +++ b/taskflow/tests/unit/patterns/test_graph_flow.py @@ -97,7 +97,40 @@ class GraphFlowTest(test.TestCase): task1 = _task(name='task1', provides=['a', 'b']) task2 = _task(name='task2', provides=['a', 'c']) f = gf.Flow('test') - self.assertRaises(exc.DependencyFailure, f.add, task2, task1) + f.add(task2, task1) + self.assertEqual(set(['a', 'b', 'c']), f.provides) + + def test_graph_flow_ambiguous_provides(self): + task1 = _task(name='task1', provides=['a', 'b']) + task2 = _task(name='task2', provides=['a']) + f = gf.Flow('test') + f.add(task1, task2) + self.assertEqual(set(['a', 'b']), f.provides) + task3 = _task(name='task3', requires=['a']) + self.assertRaises(exc.AmbiguousDependency, f.add, task3) + + def test_graph_flow_no_resolve_requires(self): + task1 = _task(name='task1', provides=['a', 'b', 'c']) + task2 = _task(name='task2', requires=['a', 'b']) + f = gf.Flow('test') + f.add(task1, task2, resolve_requires=False) + self.assertEqual(set(['a', 'b']), f.requires) + + def test_graph_flow_no_resolve_existing(self): + task1 = _task(name='task1', requires=['a', 'b']) + task2 = _task(name='task2', provides=['a', 'b']) + f = gf.Flow('test') + f.add(task1) + f.add(task2, resolve_existing=False) + self.assertEqual(set(['a', 'b']), f.requires) + + def test_graph_flow_resolve_existing(self): + task1 = _task(name='task1', requires=['a', 'b']) + task2 = _task(name='task2', provides=['a', 'b']) + f = gf.Flow('test') + f.add(task1) + f.add(task2, resolve_existing=True) + 
self.assertEqual(set([]), f.requires) def test_graph_flow_with_retry(self): ret = retry.AlwaysRevert(requires=['a'], provides=['b']) diff --git a/taskflow/tests/unit/test_flow_dependencies.py b/taskflow/tests/unit/test_flow_dependencies.py index 47604e81..69f4a8fe 100644 --- a/taskflow/tests/unit/test_flow_dependencies.py +++ b/taskflow/tests/unit/test_flow_dependencies.py @@ -218,9 +218,8 @@ class FlowDependenciesTest(test.TestCase): def test_graph_flow_provides_provided_value_other_call(self): flow = gf.Flow('gf') flow.add(utils.TaskOneReturn('task1', provides='x')) - self.assertRaises(exceptions.DependencyFailure, - flow.add, - utils.TaskOneReturn('task2', provides='x')) + flow.add(utils.TaskOneReturn('task2', provides='x')) + self.assertEqual(set(['x']), flow.provides) def test_graph_flow_multi_provides_and_requires_values(self): flow = gf.Flow('gf').add( @@ -367,17 +366,16 @@ class FlowDependenciesTest(test.TestCase): self.assertEqual(flow.requires, set(['x', 'y', 'c'])) self.assertEqual(flow.provides, set(['a', 'b', 'z'])) - def test_graph_flow_retry_and_task_dependency_conflict(self): + def test_graph_flow_retry_and_task_dependency_provide_require(self): flow = gf.Flow('gf', retry.AlwaysRevert('rt', requires=['x'])) - self.assertRaises(exceptions.DependencyFailure, - flow.add, - utils.TaskOneReturn(provides=['x'])) + flow.add(utils.TaskOneReturn(provides=['x'])) + self.assertEqual(set(['x']), flow.provides) + self.assertEqual(set(['x']), flow.requires) def test_graph_flow_retry_and_task_provide_same_value(self): flow = gf.Flow('gf', retry.AlwaysRevert('rt', provides=['x'])) - self.assertRaises(exceptions.DependencyFailure, - flow.add, - utils.TaskOneReturn('t1', provides=['x'])) + flow.add(utils.TaskOneReturn('t1', provides=['x'])) + self.assertEqual(set(['x']), flow.provides) def test_builtin_retry_args(self):