diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index 3bdbfd74..bd95e98d 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -173,7 +173,7 @@ class ActionEngine(base.EngineBase): else: self.storage.ensure_task(node.name, version, node.save_as) if node.inject: - self.storage.inject_task_args(node.name, node.inject) + self.storage.inject_atom_args(node.name, node.inject) self._change_state(states.SUSPENDED) # does nothing in PENDING state @lock_utils.locked diff --git a/taskflow/engines/action_engine/retry_action.py b/taskflow/engines/action_engine/retry_action.py index eaedf04b..a1ca3abb 100644 --- a/taskflow/engines/action_engine/retry_action.py +++ b/taskflow/engines/action_engine/retry_action.py @@ -34,7 +34,7 @@ class RetryAction(object): def _get_retry_args(self, retry): kwargs = self._storage.fetch_mapped_args(retry.rebind, - task_name=retry.name) + atom_name=retry.name) kwargs['history'] = self._storage.get_retry_history(retry.name) return kwargs diff --git a/taskflow/engines/action_engine/task_action.py b/taskflow/engines/action_engine/task_action.py index 9ab8c460..c0d1daa5 100644 --- a/taskflow/engines/action_engine/task_action.py +++ b/taskflow/engines/action_engine/task_action.py @@ -66,7 +66,7 @@ class TaskAction(object): raise exceptions.InvalidState("Task %s is in invalid state and" " can't be executed" % task.name) kwargs = self._storage.fetch_mapped_args(task.rebind, - task_name=task.name) + atom_name=task.name) task_uuid = self._storage.get_atom_uuid(task.name) return self._task_executor.execute_task(task, task_uuid, kwargs, self._on_update_progress) @@ -83,7 +83,7 @@ class TaskAction(object): raise exceptions.InvalidState("Task %s is in invalid state and" " can't be reverted" % task.name) kwargs = self._storage.fetch_mapped_args(task.rebind, - task_name=task.name) + atom_name=task.name) task_uuid = self._storage.get_atom_uuid(task.name) task_result = self._storage.get(task.name) failures = self._storage.get_failures() diff --git a/taskflow/storage.py b/taskflow/storage.py index e3a208a4..353d44f3 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -411,9 +411,22 @@ class Storage(object): if self._reset_atom(ad, state): self._with_connection(self._save_atom_detail, ad) - def inject_task_args(self, task_name, injected_args): - self._injected_args.setdefault(task_name, {}) - self._injected_args[task_name].update(injected_args) + def inject_atom_args(self, atom_name, pairs): + """Add *transient* values into storage for a specific atom only. + + This method injects a dictionary/pairs of arguments for an atom so that + when that atom is scheduled for execution it will have immediate access + to these arguments. + + NOTE(harlowja): injected atom arguments take precedence over arguments + provided by predecessor atoms or arguments provided by injecting into + the flow scope (using the inject() method). + """ + if atom_name not in self._atom_name_to_uuid: + raise exceptions.NotFound("Unknown atom name: %s" % atom_name) + with self._lock.write_lock(): + self._injected_args.setdefault(atom_name, {}) + self._injected_args[atom_name].update(pairs) def inject(self, pairs, transient=False): """Add values into storage. @@ -521,12 +534,12 @@ class Storage(object): pass return results - def fetch_mapped_args(self, args_mapping, task_name=None): + def fetch_mapped_args(self, args_mapping, atom_name=None): """Fetch arguments for an atom using an atoms arguments mapping.""" with self._lock.read_lock(): injected_args = {} - if task_name: - injected_args = self._injected_args.get(task_name, {}) + if atom_name: + injected_args = self._injected_args.get(atom_name, {}) mapped_args = {} for key, name in six.iteritems(args_mapping): if name in injected_args: