diff --git a/taskflow/storage.py b/taskflow/storage.py index 14fd3180..35ba7d0e 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -51,6 +51,7 @@ class Storage(object): self._backend = backend self._flowdetail = flow_detail self._lock = self._lock_cls() + self._transients = {} # NOTE(imelnikov): failure serialization looses information, # so we cache failures here, in atom name -> failure mapping. @@ -99,6 +100,8 @@ class Storage(object): Returns uuid for the task details corresponding to the task with given name. """ + if not task_name: + raise ValueError("Task name must be non-empty") with self._lock.write_lock(): try: task_id = self._atom_name_to_uuid[task_name] @@ -127,6 +130,8 @@ class Storage(object): Returns uuid for the retry details corresponding to the retry with given name. """ + if not retry_name: + raise ValueError("Retry name must be non-empty") with self._lock.write_lock(): try: retry_id = self._atom_name_to_uuid[retry_name] @@ -405,17 +410,21 @@ class Storage(object): if self._reset_atom(ad, state): self._with_connection(self._save_atom_detail, ad) - def inject(self, pairs): + def inject(self, pairs, transient=False): """Add values into storage. This method should be used to put flow parameters (requirements that are not satisfied by any task in the flow) into storage. + + :param: transient save the data in-memory only instead of persisting + the data to backend storage (useful for resource-like objects + or similar objects which should *not* be persisted) """ - with self._lock.write_lock(): + + def save_persistent(): try: - ad = self._atomdetail_by_name( - self.injector_name, - expected_type=logbook.TaskDetail) + ad = self._atomdetail_by_name(self.injector_name, + expected_type=logbook.TaskDetail) except exceptions.NotFound: uuid = uuidutils.generate_uuid() self._create_atom_detail(logbook.TaskDetail, @@ -427,8 +436,21 @@ class Storage(object): else: ad.results.update(pairs) self._with_connection(self._save_atom_detail, ad) - names = six.iterkeys(ad.results) - self._set_result_mapping(self.injector_name, + return (self.injector_name, six.iterkeys(ad.results)) + + def save_transient(): + self._transients.update(pairs) + # NOTE(harlowja): none is not a valid atom name, so that means + # we can use it internally to reference all of our transient + # variables. + return (None, six.iterkeys(self._transients)) + + with self._lock.write_lock(): + if transient: + (atom_name, names) = save_transient() + else: + (atom_name, names) = save_persistent() + self._set_result_mapping(atom_name, dict((name, name) for name in names)) def _set_result_mapping(self, atom_name, mapping): @@ -470,8 +492,11 @@ class Storage(object): raise exceptions.NotFound("Name %r is not mapped" % name) # Return the first one that is found. for (atom_name, index) in reversed(indexes): - try: + if not atom_name: + results = self._transients + else: results = self._get(atom_name, only_last=True) + try: return misc.item_from(results, index, name) except exceptions.NotFound: pass diff --git a/taskflow/tests/unit/test_storage.py b/taskflow/tests/unit/test_storage.py index eb088190..001cba97 100644 --- a/taskflow/tests/unit/test_storage.py +++ b/taskflow/tests/unit/test_storage.py @@ -371,6 +371,35 @@ class StorageTestMixin(object): s.ensure_task('my task') self.assertTrue(uuidutils.is_uuid_like(s.get_atom_uuid('my task'))) + def test_transient_storage_fetch_all(self): + s = self._get_storage() + s.inject([("a", "b")], transient=True) + s.inject([("b", "c")]) + + results = s.fetch_all() + self.assertEqual({"a": "b", "b": "c"}, results) + + def test_transient_storage_fetch_mapped(self): + s = self._get_storage() + s.inject([("a", "b")], transient=True) + s.inject([("b", "c")]) + desired = { + 'y': 'a', + 'z': 'b', + } + args = s.fetch_mapped_args(desired) + self.assertEqual({'y': 'b', 'z': 'c'}, args) + + def test_transient_storage_restore(self): + _lb, flow_detail = p_utils.temporary_flow_detail(self.backend) + s = self._get_storage(flow_detail=flow_detail) + s.inject([("a", "b")], transient=True) + s.inject([("b", "c")]) + + s2 = self._get_storage(flow_detail=flow_detail) + results = s2.fetch_all() + self.assertEqual({"b": "c"}, results) + def test_unknown_task_by_name(self): s = self._get_storage() self.assertRaisesRegexp(exceptions.NotFound,