diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index 40b6e685..a7a205f5 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -18,6 +18,7 @@ import collections +import networkx as nx from networkx.algorithms import dag from networkx.classes import digraph @@ -36,25 +37,47 @@ class Flow(flow.Flow): def __init__(self, name, uuid=None): super(Flow, self).__init__(name, uuid) - self._graph = digraph.DiGraph() + self._graph = nx.freeze(digraph.DiGraph()) + + def _validate(self, graph=None): + if graph is None: + graph = self._graph + # Ensure that there is a valid topological ordering. + if not dag.is_directed_acyclic_graph(graph): + raise exc.DependencyFailure("No path through the items in the" + " graph produces an ordering that" + " will allow for correct dependency" + " resolution") def link(self, u, v): if not self._graph.has_node(u): raise ValueError('Item %s not found to link from' % (u)) if not self._graph.has_node(v): raise ValueError('Item %s not found to link to' % (v)) - self._graph.add_edge(u, v) + if self._graph.has_edge(u, v): + return self - # Ensure that there is a valid topological ordering. - if not dag.is_directed_acyclic_graph(self._graph): - self._graph.remove_edge(u, v) - raise exc.DependencyFailure("No path through the items in the" - " graph produces an ordering that" - " will allow for correct dependency" - " resolution") + # NOTE(harlowja): Add an edge to a temporary copy and only if that + # copy is valid then do we swap with the underlying graph. + tmp_graph = digraph.DiGraph(self._graph) + tmp_graph.add_edge(u, v) + self._swap(tmp_graph) + return self + + def _swap(self, replacement_graph): + """Validates the replacement graph and then swaps the underlying graph + with a frozen version of the replacement graph (this maintains the + invariant that the underlying graph is immutable). + """ + self._validate(replacement_graph) + self._graph = nx.freeze(replacement_graph) def add(self, *items): """Adds a given task/tasks/flow/flows to this flow.""" + items = [i for i in items if not self._graph.has_node(i)] + if not items: + return self + requirements = collections.defaultdict(list) provided = {} @@ -67,42 +90,42 @@ class Flow(flow.Flow): for value in node.provides: provided[value] = node - try: - for item in items: - self._graph.add_node(item) - update_requirements(item) - for value in item.provides: - if value in provided: - raise exc.DependencyFailure( - "%(item)s provides %(value)s but is already being" - " provided by %(flow)s and duplicate producers" - " are disallowed" - % dict(item=item.name, - flow=provided[value].name, - value=value)) - provided[value] = item + # NOTE(harlowja): Add items and edges to a temporary copy of the + # underlying graph and only if that is successful added to do we then + # swap with the underlying graph. + tmp_graph = digraph.DiGraph(self._graph) + for item in items: + tmp_graph.add_node(item) + update_requirements(item) + for value in item.provides: + if value in provided: + raise exc.DependencyFailure( + "%(item)s provides %(value)s but is already being" + " provided by %(flow)s and duplicate producers" + " are disallowed" + % dict(item=item.name, + flow=provided[value].name, + value=value)) + provided[value] = item - for value in item.requires: - if value in provided: - self.link(provided[value], item) + for value in item.requires: + if value in provided: + tmp_graph.add_edge(provided[value], item) - for value in item.provides: - if value in requirements: - for node in requirements[value]: - self.link(item, node) - - except Exception: - self._graph.remove_nodes_from(items) - raise + for value in item.provides: + if value in requirements: + for node in requirements[value]: + tmp_graph.add_edge(item, node) + self._swap(tmp_graph) return self def __len__(self): return self._graph.number_of_nodes() def __iter__(self): - for child in self._graph.nodes_iter(): - yield child + for n in self._graph.nodes_iter(): + yield n @property def provides(self):