diff --git a/taskflow/engines/action_engine/runtime.py b/taskflow/engines/action_engine/runtime.py index 395ce750..a04808da 100644 --- a/taskflow/engines/action_engine/runtime.py +++ b/taskflow/engines/action_engine/runtime.py @@ -38,7 +38,7 @@ class Runtime(object): self._task_executor = task_executor self._storage = storage self._compilation = compilation - self._scopes = {} + self._walkers_to_names = {} @property def compilation(self): @@ -76,9 +76,9 @@ class Runtime(object): self._task_executor) def fetch_scopes_for(self, atom_name): - """Fetches a tuple of the visible scopes for the given atom.""" + """Fetches a walker of the visible scopes for the given atom.""" try: - return self._scopes[atom_name] + return self._walkers_to_names[atom_name] except KeyError: atom = None for node in self.analyzer.iterate_all_nodes(): @@ -88,10 +88,10 @@ class Runtime(object): if atom is not None: walker = sc.ScopeWalker(self.compilation, atom, names_only=True) - self._scopes[atom_name] = visible_to = tuple(walker) + self._walkers_to_names[atom_name] = walker else: - visible_to = tuple([]) - return visible_to + walker = None + return walker # Various helper methods used by the runtime components; not for public # consumption... diff --git a/taskflow/engines/action_engine/scopes.py b/taskflow/engines/action_engine/scopes.py index c55305d0..99e1578b 100644 --- a/taskflow/engines/action_engine/scopes.py +++ b/taskflow/engines/action_engine/scopes.py @@ -56,9 +56,14 @@ class ScopeWalker(object): if self._node is None: raise ValueError("Unable to find atom '%s' in compilation" " hierarchy" % atom) + self._level_cache = {} self._atom = atom self._graph = compilation.execution_graph self._names_only = names_only + self._predecessors = None + + #: Function that extracts the *associated* atoms of a given tree node. + _extract_atoms = staticmethod(_extract_atoms) def __iter__(self): """Iterates over the visible scopes. @@ -95,27 +100,34 @@ class ScopeWalker(object): nodes (aka we have reached the top of the tree) or we run out of predecessors. """ - predecessors = set(self._graph.bfs_predecessors_iter(self._atom)) + if self._predecessors is None: + pred_iter = self._graph.bfs_predecessors_iter(self._atom) + self._predecessors = set(pred_iter) + predecessors = self._predecessors.copy() last = self._node - for parent in self._node.path_iter(include_self=False): + for lvl, parent in enumerate(self._node.path_iter(include_self=False)): if not predecessors: break last_idx = parent.index(last.item) - visible = [] - for a in _extract_atoms(parent, idx=last_idx): - if a in predecessors: - predecessors.remove(a) - if not self._names_only: + try: + visible, removals = self._level_cache[lvl] + predecessors = predecessors - removals + except KeyError: + visible = [] + removals = set() + for a in self._extract_atoms(parent, idx=last_idx): + if a in predecessors: + predecessors.remove(a) + removals.add(a) visible.append(a) - else: - visible.append(a.name) - if LOG.isEnabledFor(logging.BLATHER): - if not self._names_only: + self._level_cache[lvl] = (visible, removals) + if LOG.isEnabledFor(logging.BLATHER): visible_names = [a.name for a in visible] - else: - visible_names = visible - LOG.blather("Scope visible to '%s' (limited by parent '%s'" - " index < %s) is: %s", self._atom, - parent.item.name, last_idx, visible_names) - yield visible + LOG.blather("Scope visible to '%s' (limited by parent '%s'" + " index < %s) is: %s", self._atom, + parent.item.name, last_idx, visible_names) + if self._names_only: + yield [a.name for a in visible] + else: + yield visible last = parent