diff --git a/taskflow/atom.py b/taskflow/atom.py index 60a6c427..a1e9b28f 100644 --- a/taskflow/atom.py +++ b/taskflow/atom.py @@ -31,6 +31,10 @@ from taskflow.utils import misc _sequence_types = (list, tuple, collections.Sequence) _set_types = (set, collections.Set) +# the default list of revert arguments to ignore when deriving +# revert argument mapping from the revert method signature +_default_revert_args = ('result', 'flow_failures') + def _save_as_to_mapping(save_as): """Convert save_as to mapping name => index. @@ -176,6 +180,17 @@ class Atom(object): injected into the atoms scope before the atom execution commences (this allows for providing atom *local* values that do not need to be provided by other atoms/dependents). + :param rebind: A dict of key/value pairs used to define argument + name conversions for inputs to this atom's ``execute`` + method. + :param revert_rebind: The same as ``rebind`` but for the ``revert`` + method. If unpassed, ``rebind`` will be used + instead. + :param requires: A set or list of required inputs for this atom's + ``execute`` method. + :param revert_requires: A set or list of required inputs for this atom's + ``revert`` method. If unpassed, ```requires`` will + be used. :ivar version: An *immutable* version that associates version information with this atom. It can be useful in resuming older versions of atoms. Standard major, minor versioning concepts @@ -191,12 +206,17 @@ class Atom(object): the names that this atom expects (in a way this is like remapping a namespace of another atom into the namespace of this atom). + :ivar revert_rebind: The same as ``rebind`` but for the revert method. This + should only differ from ``rebind`` if the ``revert`` + method has a different signature from ``execute`` or + a different ``revert_rebind`` value was received. :ivar inject: See parameter ``inject``. :ivar name: See parameter ``name``. :ivar requires: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs this atom requires to function. :ivar optional: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs - that are optional for this atom to function. + that are optional for this atom to ``execute``. + :ivar revert_optional: The ``revert`` version of ``optional``. :ivar provides: A :py:class:`~taskflow.types.sets.OrderedSet` of outputs this atom produces. """ @@ -232,7 +252,7 @@ class Atom(object): def __init__(self, name=None, provides=None, requires=None, auto_extract=True, rebind=None, inject=None, - ignore_list=None): + ignore_list=None, revert_rebind=None, revert_requires=None): if provides is None: provides = self.default_provides @@ -241,14 +261,30 @@ class Atom(object): self.version = (1, 0) self.inject = inject self.save_as = _save_as_to_mapping(provides) - self.requires = sets.OrderedSet() - self.optional = sets.OrderedSet() self.provides = sets.OrderedSet(self.save_as) - self.rebind = collections.OrderedDict() - self._build_arg_mapping(self.execute, requires=requires, - rebind=rebind, auto_extract=auto_extract, - ignore_list=ignore_list) + if ignore_list is None: + ignore_list = [] + + self.rebind, exec_requires, self.optional = self._build_arg_mapping( + self.execute, + requires=requires, + rebind=rebind, auto_extract=auto_extract, + ignore_list=ignore_list + ) + + revert_ignore = ignore_list + list(_default_revert_args) + revert_mapping = self._build_arg_mapping( + self.revert, + requires=revert_requires or requires, + rebind=revert_rebind or rebind, + auto_extract=auto_extract, + ignore_list=revert_ignore + ) + (self.revert_rebind, addl_requires, + self.revert_optional) = revert_mapping + + self.requires = exec_requires.union(addl_requires) def _build_arg_mapping(self, executor, requires=None, rebind=None, auto_extract=True, ignore_list=None): @@ -263,13 +299,13 @@ class Atom(object): for (arg_name, bound_name) in itertools.chain(six.iteritems(required), six.iteritems(optional)): rebind.setdefault(arg_name, bound_name) - self.rebind = rebind - self.requires = sets.OrderedSet(six.itervalues(required)) - self.optional = sets.OrderedSet(six.itervalues(optional)) + requires = sets.OrderedSet(six.itervalues(required)) + optional = sets.OrderedSet(six.itervalues(optional)) if self.inject: inject_keys = frozenset(six.iterkeys(self.inject)) - self.requires -= inject_keys - self.optional -= inject_keys + requires -= inject_keys + optional -= inject_keys + return rebind, requires, optional def pre_execute(self): """Code to be run prior to executing the atom. diff --git a/taskflow/engines/action_engine/actions/retry.py b/taskflow/engines/action_engine/actions/retry.py index 126b9038..49dd94ac 100644 --- a/taskflow/engines/action_engine/actions/retry.py +++ b/taskflow/engines/action_engine/actions/retry.py @@ -30,12 +30,19 @@ class RetryAction(base.Action): super(RetryAction, self).__init__(storage, notifier) self._retry_executor = retry_executor - def _get_retry_args(self, retry, addons=None): - arguments = self._storage.fetch_mapped_args( - retry.rebind, - atom_name=retry.name, - optional_args=retry.optional - ) + def _get_retry_args(self, retry, revert=False, addons=None): + if revert: + arguments = self._storage.fetch_mapped_args( + retry.revert_rebind, + atom_name=retry.name, + optional_args=retry.revert_optional + ) + else: + arguments = self._storage.fetch_mapped_args( + retry.rebind, + atom_name=retry.name, + optional_args=retry.optional + ) history = self._storage.get_retry_history(retry.name) arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history if addons: @@ -92,7 +99,7 @@ class RetryAction(base.Action): retry_atom.REVERT_FLOW_FAILURES: self._storage.get_failures(), } return self._retry_executor.revert_retry( - retry, self._get_retry_args(retry, addons=arg_addons)) + retry, self._get_retry_args(retry, addons=arg_addons, revert=True)) def on_failure(self, retry, atom, last_failure): self._storage.save_retry_failure(retry.name, atom.name, last_failure) diff --git a/taskflow/engines/action_engine/actions/task.py b/taskflow/engines/action_engine/actions/task.py index ac117e1c..6d0981d8 100644 --- a/taskflow/engines/action_engine/actions/task.py +++ b/taskflow/engines/action_engine/actions/task.py @@ -122,9 +122,9 @@ class TaskAction(base.Action): def schedule_reversion(self, task): self.change_state(task, states.REVERTING, progress=0.0) arguments = self._storage.fetch_mapped_args( - task.rebind, + task.revert_rebind, atom_name=task.name, - optional_args=task.optional + optional_args=task.revert_optional ) task_uuid = self._storage.get_atom_uuid(task.name) task_result = self._storage.get(task.name) diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index 8973c1c6..479a04b5 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -370,8 +370,12 @@ class ActionEngine(base.Engine): last_node = None missing_nodes = 0 for atom in self._runtime.analyzer.iterate_nodes(compiler.ATOMS): - atom_missing = self.storage.fetch_unsatisfied_args( + exec_missing = self.storage.fetch_unsatisfied_args( atom.name, atom.rebind, optional_args=atom.optional) + revert_missing = self.storage.fetch_unsatisfied_args( + atom.name, atom.revert_rebind, + optional_args=atom.revert_optional) + atom_missing = exec_missing.union(revert_missing) if atom_missing: cause = exc.MissingDependencies(atom, sorted(atom_missing), diff --git a/taskflow/task.py b/taskflow/task.py index 1a83afe7..ba4fb328 100644 --- a/taskflow/task.py +++ b/taskflow/task.py @@ -60,12 +60,14 @@ class Task(atom.Atom): TASK_EVENTS = (EVENT_UPDATE_PROGRESS,) def __init__(self, name=None, provides=None, requires=None, - auto_extract=True, rebind=None, inject=None): + auto_extract=True, rebind=None, inject=None, + ignore_list=None, revert_rebind=None, revert_requires=None): if name is None: name = reflection.get_class_name(self) super(Task, self).__init__(name, provides=provides, requires=requires, auto_extract=auto_extract, rebind=rebind, - inject=inject) + inject=inject, revert_rebind=revert_rebind, + revert_requires=revert_requires) self._notifier = notifier.RestrictedNotifier(self.TASK_EVENTS) @property @@ -137,7 +139,18 @@ class FunctorTask(Task): self._revert = revert if version is not None: self.version = version - self._build_arg_mapping(execute, requires, rebind, auto_extract) + mapping = self._build_arg_mapping(execute, requires, rebind, + auto_extract) + self.rebind, exec_requires, self.optional = mapping + + if revert: + revert_mapping = self._build_arg_mapping(revert, requires, rebind, + auto_extract) + else: + revert_mapping = (self.rebind, exec_requires, self.optional) + (self.revert_rebind, revert_requires, + self.revert_optional) = revert_mapping + self.requires = exec_requires.union(revert_requires) def execute(self, *args, **kwargs): return self._execute(*args, **kwargs) diff --git a/taskflow/tests/unit/test_arguments_passing.py b/taskflow/tests/unit/test_arguments_passing.py index 8676412d..0db93009 100644 --- a/taskflow/tests/unit/test_arguments_passing.py +++ b/taskflow/tests/unit/test_arguments_passing.py @@ -163,6 +163,26 @@ class ArgumentsPassingTest(utils.EngineTestBase): 'long_arg_name': 1, 'result': 1 }, engine.storage.fetch_all()) + def test_revert_rebound_args_required(self): + flow = utils.TaskMultiArg(revert_rebind={'z': 'b'}) + engine = self._make_engine(flow) + engine.storage.inject({'a': 1, 'y': 4, 'c': 9, 'x': 17}) + self.assertRaises(exc.MissingDependencies, engine.run) + + def test_revert_required_args_required(self): + flow = utils.TaskMultiArg(revert_requires=['a']) + engine = self._make_engine(flow) + engine.storage.inject({'y': 4, 'z': 9, 'x': 17}) + self.assertRaises(exc.MissingDependencies, engine.run) + + def test_derived_revert_args_required(self): + flow = utils.TaskRevertExtraArgs() + engine = self._make_engine(flow) + engine.storage.inject({'y': 4, 'z': 9, 'x': 17}) + self.assertRaises(exc.MissingDependencies, engine.run) + engine.storage.inject({'revert_arg': None}) + self.assertRaises(exc.ExecutionFailure, engine.run) + class SerialEngineTest(ArgumentsPassingTest, test.TestCase): diff --git a/taskflow/tests/unit/test_task.py b/taskflow/tests/unit/test_task.py index 30474af7..f7668059 100644 --- a/taskflow/tests/unit/test_task.py +++ b/taskflow/tests/unit/test_task.py @@ -49,6 +49,22 @@ class ProgressTask(task.Task): self.update_progress(value) +class SeparateRevertTask(task.Task): + def execute(self, execute_arg): + pass + + def revert(self, revert_arg, result, flow_failures): + pass + + +class SeparateRevertOptionalTask(task.Task): + def execute(self, execute_arg=None): + pass + + def revert(self, result, flow_failures, revert_arg=None): + pass + + class TaskTest(test.TestCase): def test_passed_name(self): @@ -338,6 +354,26 @@ class TaskTest(test.TestCase): self.assertEqual(2, len(listeners[task.EVENT_UPDATE_PROGRESS])) self.assertEqual(0, len(a_task.notifier)) + def test_separate_revert_args(self): + task = SeparateRevertTask(rebind=('a',), revert_rebind=('b',)) + self.assertEqual({'execute_arg': 'a'}, task.rebind) + self.assertEqual({'revert_arg': 'b'}, task.revert_rebind) + self.assertEqual(set(['a', 'b']), + task.requires) + + task = SeparateRevertTask(requires='execute_arg', + revert_requires='revert_arg') + + self.assertEqual({'execute_arg': 'execute_arg'}, task.rebind) + self.assertEqual({'revert_arg': 'revert_arg'}, task.revert_rebind) + self.assertEqual(set(['execute_arg', 'revert_arg']), + task.requires) + + def test_separate_revert_optional_args(self): + task = SeparateRevertOptionalTask() + self.assertEqual(set(['execute_arg']), task.optional) + self.assertEqual(set(['revert_arg']), task.revert_optional) + class FunctorTaskTest(test.TestCase): diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py index 10521485..da26fa3b 100644 --- a/taskflow/tests/unit/worker_based/test_worker.py +++ b/taskflow/tests/unit/worker_based/test_worker.py @@ -33,7 +33,7 @@ class TestWorker(test.MockTestCase): self.broker_url = 'test-url' self.exchange = 'test-exchange' self.topic = 'test-topic' - self.endpoint_count = 26 + self.endpoint_count = 27 # patch classes self.executor_mock, self.executor_inst_mock = self.patchClass( diff --git a/taskflow/tests/utils.py b/taskflow/tests/utils.py index f4654676..23eeeb6e 100644 --- a/taskflow/tests/utils.py +++ b/taskflow/tests/utils.py @@ -332,6 +332,14 @@ class NeverRunningTask(task.Task): assert False, 'This method should not be called' +class TaskRevertExtraArgs(task.Task): + def execute(self, **kwargs): + raise exceptions.ExecutionFailure("We want to force a revert here") + + def revert(self, revert_arg, flow_failures, result, **kwargs): + pass + + class EngineTestBase(object): def setUp(self): super(EngineTestBase, self).setUp()