Allow for revert to have a different argument list from execute
Also allows for people to create Atom's with a different rebind or requires structure for the revert method, if desired. Implements blueprint: seperate-revert-args Change-Id: Ie7d13c8000ef08ff303481d486d1ba1cfbdeea44
This commit is contained in:
@@ -31,6 +31,10 @@ from taskflow.utils import misc
|
||||
_sequence_types = (list, tuple, collections.Sequence)
|
||||
_set_types = (set, collections.Set)
|
||||
|
||||
# the default list of revert arguments to ignore when deriving
|
||||
# revert argument mapping from the revert method signature
|
||||
_default_revert_args = ('result', 'flow_failures')
|
||||
|
||||
|
||||
def _save_as_to_mapping(save_as):
|
||||
"""Convert save_as to mapping name => index.
|
||||
@@ -176,6 +180,17 @@ class Atom(object):
|
||||
injected into the atoms scope before the atom execution
|
||||
commences (this allows for providing atom *local* values
|
||||
that do not need to be provided by other atoms/dependents).
|
||||
:param rebind: A dict of key/value pairs used to define argument
|
||||
name conversions for inputs to this atom's ``execute``
|
||||
method.
|
||||
:param revert_rebind: The same as ``rebind`` but for the ``revert``
|
||||
method. If unpassed, ``rebind`` will be used
|
||||
instead.
|
||||
:param requires: A set or list of required inputs for this atom's
|
||||
``execute`` method.
|
||||
:param revert_requires: A set or list of required inputs for this atom's
|
||||
``revert`` method. If unpassed, ```requires`` will
|
||||
be used.
|
||||
:ivar version: An *immutable* version that associates version information
|
||||
with this atom. It can be useful in resuming older versions
|
||||
of atoms. Standard major, minor versioning concepts
|
||||
@@ -191,12 +206,17 @@ class Atom(object):
|
||||
the names that this atom expects (in a way this is like
|
||||
remapping a namespace of another atom into the namespace
|
||||
of this atom).
|
||||
:ivar revert_rebind: The same as ``rebind`` but for the revert method. This
|
||||
should only differ from ``rebind`` if the ``revert``
|
||||
method has a different signature from ``execute`` or
|
||||
a different ``revert_rebind`` value was received.
|
||||
:ivar inject: See parameter ``inject``.
|
||||
:ivar name: See parameter ``name``.
|
||||
:ivar requires: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs
|
||||
this atom requires to function.
|
||||
:ivar optional: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs
|
||||
that are optional for this atom to function.
|
||||
that are optional for this atom to ``execute``.
|
||||
:ivar revert_optional: The ``revert`` version of ``optional``.
|
||||
:ivar provides: A :py:class:`~taskflow.types.sets.OrderedSet` of outputs
|
||||
this atom produces.
|
||||
"""
|
||||
@@ -232,7 +252,7 @@ class Atom(object):
|
||||
|
||||
def __init__(self, name=None, provides=None, requires=None,
|
||||
auto_extract=True, rebind=None, inject=None,
|
||||
ignore_list=None):
|
||||
ignore_list=None, revert_rebind=None, revert_requires=None):
|
||||
|
||||
if provides is None:
|
||||
provides = self.default_provides
|
||||
@@ -241,14 +261,30 @@ class Atom(object):
|
||||
self.version = (1, 0)
|
||||
self.inject = inject
|
||||
self.save_as = _save_as_to_mapping(provides)
|
||||
self.requires = sets.OrderedSet()
|
||||
self.optional = sets.OrderedSet()
|
||||
self.provides = sets.OrderedSet(self.save_as)
|
||||
self.rebind = collections.OrderedDict()
|
||||
|
||||
self._build_arg_mapping(self.execute, requires=requires,
|
||||
rebind=rebind, auto_extract=auto_extract,
|
||||
ignore_list=ignore_list)
|
||||
if ignore_list is None:
|
||||
ignore_list = []
|
||||
|
||||
self.rebind, exec_requires, self.optional = self._build_arg_mapping(
|
||||
self.execute,
|
||||
requires=requires,
|
||||
rebind=rebind, auto_extract=auto_extract,
|
||||
ignore_list=ignore_list
|
||||
)
|
||||
|
||||
revert_ignore = ignore_list + list(_default_revert_args)
|
||||
revert_mapping = self._build_arg_mapping(
|
||||
self.revert,
|
||||
requires=revert_requires or requires,
|
||||
rebind=revert_rebind or rebind,
|
||||
auto_extract=auto_extract,
|
||||
ignore_list=revert_ignore
|
||||
)
|
||||
(self.revert_rebind, addl_requires,
|
||||
self.revert_optional) = revert_mapping
|
||||
|
||||
self.requires = exec_requires.union(addl_requires)
|
||||
|
||||
def _build_arg_mapping(self, executor, requires=None, rebind=None,
|
||||
auto_extract=True, ignore_list=None):
|
||||
@@ -263,13 +299,13 @@ class Atom(object):
|
||||
for (arg_name, bound_name) in itertools.chain(six.iteritems(required),
|
||||
six.iteritems(optional)):
|
||||
rebind.setdefault(arg_name, bound_name)
|
||||
self.rebind = rebind
|
||||
self.requires = sets.OrderedSet(six.itervalues(required))
|
||||
self.optional = sets.OrderedSet(six.itervalues(optional))
|
||||
requires = sets.OrderedSet(six.itervalues(required))
|
||||
optional = sets.OrderedSet(six.itervalues(optional))
|
||||
if self.inject:
|
||||
inject_keys = frozenset(six.iterkeys(self.inject))
|
||||
self.requires -= inject_keys
|
||||
self.optional -= inject_keys
|
||||
requires -= inject_keys
|
||||
optional -= inject_keys
|
||||
return rebind, requires, optional
|
||||
|
||||
def pre_execute(self):
|
||||
"""Code to be run prior to executing the atom.
|
||||
|
||||
@@ -30,12 +30,19 @@ class RetryAction(base.Action):
|
||||
super(RetryAction, self).__init__(storage, notifier)
|
||||
self._retry_executor = retry_executor
|
||||
|
||||
def _get_retry_args(self, retry, addons=None):
|
||||
arguments = self._storage.fetch_mapped_args(
|
||||
retry.rebind,
|
||||
atom_name=retry.name,
|
||||
optional_args=retry.optional
|
||||
)
|
||||
def _get_retry_args(self, retry, revert=False, addons=None):
|
||||
if revert:
|
||||
arguments = self._storage.fetch_mapped_args(
|
||||
retry.revert_rebind,
|
||||
atom_name=retry.name,
|
||||
optional_args=retry.revert_optional
|
||||
)
|
||||
else:
|
||||
arguments = self._storage.fetch_mapped_args(
|
||||
retry.rebind,
|
||||
atom_name=retry.name,
|
||||
optional_args=retry.optional
|
||||
)
|
||||
history = self._storage.get_retry_history(retry.name)
|
||||
arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history
|
||||
if addons:
|
||||
@@ -92,7 +99,7 @@ class RetryAction(base.Action):
|
||||
retry_atom.REVERT_FLOW_FAILURES: self._storage.get_failures(),
|
||||
}
|
||||
return self._retry_executor.revert_retry(
|
||||
retry, self._get_retry_args(retry, addons=arg_addons))
|
||||
retry, self._get_retry_args(retry, addons=arg_addons, revert=True))
|
||||
|
||||
def on_failure(self, retry, atom, last_failure):
|
||||
self._storage.save_retry_failure(retry.name, atom.name, last_failure)
|
||||
|
||||
@@ -122,9 +122,9 @@ class TaskAction(base.Action):
|
||||
def schedule_reversion(self, task):
|
||||
self.change_state(task, states.REVERTING, progress=0.0)
|
||||
arguments = self._storage.fetch_mapped_args(
|
||||
task.rebind,
|
||||
task.revert_rebind,
|
||||
atom_name=task.name,
|
||||
optional_args=task.optional
|
||||
optional_args=task.revert_optional
|
||||
)
|
||||
task_uuid = self._storage.get_atom_uuid(task.name)
|
||||
task_result = self._storage.get(task.name)
|
||||
|
||||
@@ -370,8 +370,12 @@ class ActionEngine(base.Engine):
|
||||
last_node = None
|
||||
missing_nodes = 0
|
||||
for atom in self._runtime.analyzer.iterate_nodes(compiler.ATOMS):
|
||||
atom_missing = self.storage.fetch_unsatisfied_args(
|
||||
exec_missing = self.storage.fetch_unsatisfied_args(
|
||||
atom.name, atom.rebind, optional_args=atom.optional)
|
||||
revert_missing = self.storage.fetch_unsatisfied_args(
|
||||
atom.name, atom.revert_rebind,
|
||||
optional_args=atom.revert_optional)
|
||||
atom_missing = exec_missing.union(revert_missing)
|
||||
if atom_missing:
|
||||
cause = exc.MissingDependencies(atom,
|
||||
sorted(atom_missing),
|
||||
|
||||
@@ -60,12 +60,14 @@ class Task(atom.Atom):
|
||||
TASK_EVENTS = (EVENT_UPDATE_PROGRESS,)
|
||||
|
||||
def __init__(self, name=None, provides=None, requires=None,
|
||||
auto_extract=True, rebind=None, inject=None):
|
||||
auto_extract=True, rebind=None, inject=None,
|
||||
ignore_list=None, revert_rebind=None, revert_requires=None):
|
||||
if name is None:
|
||||
name = reflection.get_class_name(self)
|
||||
super(Task, self).__init__(name, provides=provides, requires=requires,
|
||||
auto_extract=auto_extract, rebind=rebind,
|
||||
inject=inject)
|
||||
inject=inject, revert_rebind=revert_rebind,
|
||||
revert_requires=revert_requires)
|
||||
self._notifier = notifier.RestrictedNotifier(self.TASK_EVENTS)
|
||||
|
||||
@property
|
||||
@@ -137,7 +139,18 @@ class FunctorTask(Task):
|
||||
self._revert = revert
|
||||
if version is not None:
|
||||
self.version = version
|
||||
self._build_arg_mapping(execute, requires, rebind, auto_extract)
|
||||
mapping = self._build_arg_mapping(execute, requires, rebind,
|
||||
auto_extract)
|
||||
self.rebind, exec_requires, self.optional = mapping
|
||||
|
||||
if revert:
|
||||
revert_mapping = self._build_arg_mapping(revert, requires, rebind,
|
||||
auto_extract)
|
||||
else:
|
||||
revert_mapping = (self.rebind, exec_requires, self.optional)
|
||||
(self.revert_rebind, revert_requires,
|
||||
self.revert_optional) = revert_mapping
|
||||
self.requires = exec_requires.union(revert_requires)
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
return self._execute(*args, **kwargs)
|
||||
|
||||
@@ -163,6 +163,26 @@ class ArgumentsPassingTest(utils.EngineTestBase):
|
||||
'long_arg_name': 1, 'result': 1
|
||||
}, engine.storage.fetch_all())
|
||||
|
||||
def test_revert_rebound_args_required(self):
|
||||
flow = utils.TaskMultiArg(revert_rebind={'z': 'b'})
|
||||
engine = self._make_engine(flow)
|
||||
engine.storage.inject({'a': 1, 'y': 4, 'c': 9, 'x': 17})
|
||||
self.assertRaises(exc.MissingDependencies, engine.run)
|
||||
|
||||
def test_revert_required_args_required(self):
|
||||
flow = utils.TaskMultiArg(revert_requires=['a'])
|
||||
engine = self._make_engine(flow)
|
||||
engine.storage.inject({'y': 4, 'z': 9, 'x': 17})
|
||||
self.assertRaises(exc.MissingDependencies, engine.run)
|
||||
|
||||
def test_derived_revert_args_required(self):
|
||||
flow = utils.TaskRevertExtraArgs()
|
||||
engine = self._make_engine(flow)
|
||||
engine.storage.inject({'y': 4, 'z': 9, 'x': 17})
|
||||
self.assertRaises(exc.MissingDependencies, engine.run)
|
||||
engine.storage.inject({'revert_arg': None})
|
||||
self.assertRaises(exc.ExecutionFailure, engine.run)
|
||||
|
||||
|
||||
class SerialEngineTest(ArgumentsPassingTest, test.TestCase):
|
||||
|
||||
|
||||
@@ -49,6 +49,22 @@ class ProgressTask(task.Task):
|
||||
self.update_progress(value)
|
||||
|
||||
|
||||
class SeparateRevertTask(task.Task):
|
||||
def execute(self, execute_arg):
|
||||
pass
|
||||
|
||||
def revert(self, revert_arg, result, flow_failures):
|
||||
pass
|
||||
|
||||
|
||||
class SeparateRevertOptionalTask(task.Task):
|
||||
def execute(self, execute_arg=None):
|
||||
pass
|
||||
|
||||
def revert(self, result, flow_failures, revert_arg=None):
|
||||
pass
|
||||
|
||||
|
||||
class TaskTest(test.TestCase):
|
||||
|
||||
def test_passed_name(self):
|
||||
@@ -338,6 +354,26 @@ class TaskTest(test.TestCase):
|
||||
self.assertEqual(2, len(listeners[task.EVENT_UPDATE_PROGRESS]))
|
||||
self.assertEqual(0, len(a_task.notifier))
|
||||
|
||||
def test_separate_revert_args(self):
|
||||
task = SeparateRevertTask(rebind=('a',), revert_rebind=('b',))
|
||||
self.assertEqual({'execute_arg': 'a'}, task.rebind)
|
||||
self.assertEqual({'revert_arg': 'b'}, task.revert_rebind)
|
||||
self.assertEqual(set(['a', 'b']),
|
||||
task.requires)
|
||||
|
||||
task = SeparateRevertTask(requires='execute_arg',
|
||||
revert_requires='revert_arg')
|
||||
|
||||
self.assertEqual({'execute_arg': 'execute_arg'}, task.rebind)
|
||||
self.assertEqual({'revert_arg': 'revert_arg'}, task.revert_rebind)
|
||||
self.assertEqual(set(['execute_arg', 'revert_arg']),
|
||||
task.requires)
|
||||
|
||||
def test_separate_revert_optional_args(self):
|
||||
task = SeparateRevertOptionalTask()
|
||||
self.assertEqual(set(['execute_arg']), task.optional)
|
||||
self.assertEqual(set(['revert_arg']), task.revert_optional)
|
||||
|
||||
|
||||
class FunctorTaskTest(test.TestCase):
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class TestWorker(test.MockTestCase):
|
||||
self.broker_url = 'test-url'
|
||||
self.exchange = 'test-exchange'
|
||||
self.topic = 'test-topic'
|
||||
self.endpoint_count = 26
|
||||
self.endpoint_count = 27
|
||||
|
||||
# patch classes
|
||||
self.executor_mock, self.executor_inst_mock = self.patchClass(
|
||||
|
||||
@@ -332,6 +332,14 @@ class NeverRunningTask(task.Task):
|
||||
assert False, 'This method should not be called'
|
||||
|
||||
|
||||
class TaskRevertExtraArgs(task.Task):
|
||||
def execute(self, **kwargs):
|
||||
raise exceptions.ExecutionFailure("We want to force a revert here")
|
||||
|
||||
def revert(self, revert_arg, flow_failures, result, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class EngineTestBase(object):
|
||||
def setUp(self):
|
||||
super(EngineTestBase, self).setUp()
|
||||
|
||||
Reference in New Issue
Block a user