diff --git a/taskflow/patterns/graph_flow.py b/taskflow/patterns/graph_flow.py index 50a4d61d..9e255871 100644 --- a/taskflow/patterns/graph_flow.py +++ b/taskflow/patterns/graph_flow.py @@ -69,7 +69,20 @@ class Flow(flow.Flow): _unsatisfied_requires = staticmethod(_unsatisfied_requires) def link(self, u, v, decider=None): - """Link existing node u as a runtime dependency of existing node v.""" + """Link existing node u as a runtime dependency of existing node v. + + :param u: task or flow to create a link from (must exist already) + :param v: task or flow to create a link to (must exist already) + :param decider: A callback function that will be expected to decide + at runtime whether ``v`` should be allowed to + execute (or whether the execution of ``v`` should be + ignored, and therefore not executed). It is expected + to take as single keyword argument ``history`` which + will be the execution results of all ``u`` decideable + links that have ``v`` as a target. It is expected to + return a single boolean (``True`` to allow ``v`` + execution or ``False`` to not). + """ if not self._graph.has_node(u): raise ValueError("Node '%s' not found to link from" % (u)) if not self._graph.has_node(v): @@ -251,6 +264,18 @@ class Flow(flow.Flow): return frozenset(requires) +def _reset_cached_subgraph(func): + """Resets cached subgraph after execution, in case it was affected.""" + + @six.wraps(func) + def wrapper(self, *args, **kwargs): + result = func(self, *args, **kwargs) + self._subgraph = None + return result + + return wrapper + + class TargetedFlow(Flow): """Graph flow with a target. @@ -282,19 +307,9 @@ class TargetedFlow(Flow): self._target = None self._subgraph = None - def add(self, *nodes): - """Adds a given task/tasks/flow/flows to this flow.""" - super(TargetedFlow, self).add(*nodes) - # reset cached subgraph, in case it was affected - self._subgraph = None - return self + add = _reset_cached_subgraph(Flow.add) - def link(self, u, v, decider=None): - """Link existing node u as a runtime dependency of existing node v.""" - super(TargetedFlow, self).link(u, v, decider=decider) - # reset cached subgraph, in case it was affected - self._subgraph = None - return self + link = _reset_cached_subgraph(Flow.link) def _get_subgraph(self): if self._subgraph is not None: