diff --git a/taskflow/engines/action_engine/completer.py b/taskflow/engines/action_engine/completer.py index 140785ad..d5c43a58 100644 --- a/taskflow/engines/action_engine/completer.py +++ b/taskflow/engines/action_engine/completer.py @@ -26,7 +26,6 @@ from taskflow.engines.action_engine import executor as ex from taskflow import logging from taskflow import retry as retry_atom from taskflow import states as st -from taskflow import task as task_atom from taskflow.types import failure LOG = logging.getLogger(__name__) @@ -110,26 +109,10 @@ class Completer(object): self._runtime = weakref.proxy(runtime) self._analyzer = runtime.analyzer self._storage = runtime.storage - self._task_action = runtime.task_action - self._retry_action = runtime.retry_action self._undefined_resolver = RevertAll(self._runtime) self._defer_reverts = strutils.bool_from_string( self._runtime.options.get('defer_reverts', False)) - def _complete_task(self, task, outcome, result): - """Completes the given task, processes task failure.""" - if outcome == ex.EXECUTED: - self._task_action.complete_execution(task, result) - else: - self._task_action.complete_reversion(task, result) - - def _complete_retry(self, retry, outcome, result): - """Completes the given retry, processes retry failure.""" - if outcome == ex.EXECUTED: - self._retry_action.complete_execution(retry, result) - else: - self._retry_action.complete_reversion(retry, result) - def resume(self): """Resumes atoms in the contained graph. @@ -165,10 +148,11 @@ class Completer(object): Returns whether the result should be saved into an accumulator of failures or whether this should not be done. """ - if isinstance(node, task_atom.BaseTask): - self._complete_task(node, outcome, result) + handler = self._runtime.fetch_action(node) + if outcome == ex.EXECUTED: + handler.complete_execution(node, result) else: - self._complete_retry(node, outcome, result) + handler.complete_reversion(node, result) if isinstance(result, failure.Failure): if outcome == ex.EXECUTED: self._process_atom_failure(node, result) @@ -182,7 +166,8 @@ class Completer(object): retry = self._analyzer.find_retry(atom) if retry is not None: # Ask retry controller what to do in case of failure. - strategy = self._retry_action.on_failure(retry, atom, failure) + handler = self._runtime.fetch_action(retry) + strategy = handler.on_failure(retry, atom, failure) if strategy == retry_atom.RETRY: return RevertAndRetry(self._runtime, retry) elif strategy == retry_atom.REVERT: diff --git a/taskflow/engines/action_engine/runtime.py b/taskflow/engines/action_engine/runtime.py index a79d9143..41dbd779 100644 --- a/taskflow/engines/action_engine/runtime.py +++ b/taskflow/engines/action_engine/runtime.py @@ -101,6 +101,10 @@ class Runtime(object): com.TASK: st.check_task_transition, com.RETRY: st.check_retry_transition, } + actions = { + com.TASK: self.task_action, + com.RETRY: self.retry_action, + } graph = self._compilation.execution_graph for node, node_data in graph.nodes_iter(data=True): node_kind = node_data['kind'] @@ -110,6 +114,7 @@ class Runtime(object): check_transition_handler = check_transition_handlers[node_kind] change_state_handler = change_state_handlers[node_kind] scheduler = schedulers[node_kind] + action = actions[node_kind] else: raise exc.CompilationFailure("Unknown node kind '%s'" " encountered" % node_kind) @@ -121,6 +126,7 @@ class Runtime(object): metadata['change_state_handler'] = change_state_handler metadata['scheduler'] = scheduler metadata['edge_deciders'] = tuple(deciders_it) + metadata['action'] = action self._atom_cache[node.name] = metadata @property @@ -197,6 +203,11 @@ class Runtime(object): # not exist and therefore doesn't need to handle that case). return self._fetch_atom_metadata_entry(atom.name, 'scheduler') + def fetch_action(self, atom): + """Fetches the cached action handler for the given atom.""" + metadata = self._atom_cache[atom.name] + return metadata['action'] + def fetch_scopes_for(self, atom_name): """Fetches a walker of the visible scopes for the given atom.""" try: