Compile lists of retry/task atoms at runtime compile time

Instead of recompiling and rebuilding this list every iteration
of the ``iterate_retries`` function we can just locally cache this
information in the runtime compile function and later just use
it directly.

Change-Id: I70e8409391d655730da61413300db05b25843350
This commit is contained in:
Joshua Harlow
2015-06-22 16:27:51 -07:00
parent 87c12603eb
commit f7e1524815
2 changed files with 24 additions and 7 deletions

View File

@@ -14,12 +14,12 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import functools
import itertools import itertools
from networkx.algorithms import traversal from networkx.algorithms import traversal
import six import six
from taskflow import retry as retry_atom
from taskflow import states as st from taskflow import states as st
@@ -92,6 +92,8 @@ class Analyzer(object):
self._execution_graph = runtime.compilation.execution_graph self._execution_graph = runtime.compilation.execution_graph
self._check_atom_transition = runtime.check_atom_transition self._check_atom_transition = runtime.check_atom_transition
self._fetch_edge_deciders = runtime.fetch_edge_deciders self._fetch_edge_deciders = runtime.fetch_edge_deciders
self._fetch_retries = functools.partial(
runtime.fetch_atoms_by_kind, 'retry')
def get_next_nodes(self, node=None): def get_next_nodes(self, node=None):
"""Get next nodes to run (originating from node or all nodes).""" """Get next nodes to run (originating from node or all nodes)."""
@@ -207,14 +209,13 @@ class Analyzer(object):
yield dst yield dst
def iterate_retries(self, state=None): def iterate_retries(self, state=None):
"""Iterates retry controllers that match the provided state. """Iterates retry atoms that match the provided state.
If no state is provided it will yield back all retry controllers. If no state is provided it will yield back all retry atoms.
""" """
for node in self._execution_graph.nodes_iter(): for atom in self._fetch_retries():
if isinstance(node, retry_atom.Retry): if not state or self.get_state(atom) == state:
if not state or self.get_state(node) == state: yield atom
yield node
def iterate_all_nodes(self): def iterate_all_nodes(self):
"""Yields back all nodes in the execution graph.""" """Yields back all nodes in the execution graph."""

View File

@@ -43,6 +43,7 @@ class Runtime(object):
self._storage = storage self._storage = storage
self._compilation = compilation self._compilation = compilation
self._atom_cache = {} self._atom_cache = {}
self._atoms_by_kind = {}
def compile(self): def compile(self):
"""Compiles & caches frequently used execution helper objects. """Compiles & caches frequently used execution helper objects.
@@ -63,6 +64,8 @@ class Runtime(object):
'task': self.task_scheduler, 'task': self.task_scheduler,
} }
execution_graph = self._compilation.execution_graph execution_graph = self._compilation.execution_graph
all_retry_atoms = []
all_task_atoms = []
for atom in self.analyzer.iterate_all_nodes(): for atom in self.analyzer.iterate_all_nodes():
metadata = {} metadata = {}
walker = sc.ScopeWalker(self.compilation, atom, names_only=True) walker = sc.ScopeWalker(self.compilation, atom, names_only=True)
@@ -70,10 +73,12 @@ class Runtime(object):
check_transition_handler = st.check_task_transition check_transition_handler = st.check_task_transition
change_state_handler = change_state_handlers['task'] change_state_handler = change_state_handlers['task']
scheduler = schedulers['task'] scheduler = schedulers['task']
all_task_atoms.append(atom)
else: else:
check_transition_handler = st.check_retry_transition check_transition_handler = st.check_retry_transition
change_state_handler = change_state_handlers['retry'] change_state_handler = change_state_handlers['retry']
scheduler = schedulers['retry'] scheduler = schedulers['retry']
all_retry_atoms.append(atom)
edge_deciders = {} edge_deciders = {}
for previous_atom in execution_graph.predecessors(atom): for previous_atom in execution_graph.predecessors(atom):
# If there is any link function that says if this connection # If there is any link function that says if this connection
@@ -89,6 +94,8 @@ class Runtime(object):
metadata['scheduler'] = scheduler metadata['scheduler'] = scheduler
metadata['edge_deciders'] = edge_deciders metadata['edge_deciders'] = edge_deciders
self._atom_cache[atom.name] = metadata self._atom_cache[atom.name] = metadata
self._atoms_by_kind['retry'] = all_retry_atoms
self._atoms_by_kind['task'] = all_task_atoms
@property @property
def compilation(self): def compilation(self):
@@ -150,6 +157,15 @@ class Runtime(object):
metadata = self._atom_cache[atom.name] metadata = self._atom_cache[atom.name]
return metadata['edge_deciders'] return metadata['edge_deciders']
def fetch_atoms_by_kind(self, kind):
"""Fetches all the atoms of a given kind.
NOTE(harlowja): Currently only ``task`` or ``retry`` are valid
kinds of atoms (requesting other kinds will just
return empty lists).
"""
return self._atoms_by_kind.get(kind, [])
def fetch_scheduler(self, atom): def fetch_scheduler(self, atom):
"""Fetches the cached specific scheduler for the given atom.""" """Fetches the cached specific scheduler for the given atom."""
# This does not check if the name exists (since this is only used # This does not check if the name exists (since this is only used