diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index 3dda9125..7c275ba6 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -215,8 +215,9 @@ class ActionEngine(base.Engine): """Ensure all contained atoms exist in the storage unit.""" transient = strutils.bool_from_string( self._options.get('inject_transient', True)) + self.storage.ensure_atoms( + self._compilation.execution_graph.nodes_iter()) for node in self._compilation.execution_graph.nodes_iter(): - self.storage.ensure_atom(node) if node.inject: self.storage.inject_atom_args(node.name, node.inject, diff --git a/taskflow/storage.py b/taskflow/storage.py index d50d8771..8eb19c09 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -186,10 +186,20 @@ class Storage(object): with contextlib.closing(self._backend.get_connection()) as conn: return functor(conn, *args, **kwargs) - def ensure_atom(self, atom): - """Ensure that there is an atomdetail in storage for the given atom. + def ensure_atoms(self, atoms_iter): + """Ensure there is an atomdetail for **each** of the given atoms. - Returns uuid for the atomdetail that is/was created. + Returns list of atomdetail uuids for each atom processed. + """ + atom_ids = [] + for atom in atoms_iter: + atom_ids.append(self.ensure_atom(atom)) + return atom_ids + + def ensure_atom(self, atom): + """Ensure there is an atomdetail for the **given** atom. + + Returns the uuid for the atomdetail that corresponds to the given atom. """ match = misc.match_type(atom, self._ensure_matchers) if not match: