Allow for revert to have a different argument list from execute
Also allows for people to create Atom's with a different rebind or requires structure for the revert method, if desired. Implements blueprint: seperate-revert-args Change-Id: Ie7d13c8000ef08ff303481d486d1ba1cfbdeea44
This commit is contained in:
@@ -31,6 +31,10 @@ from taskflow.utils import misc
|
|||||||
_sequence_types = (list, tuple, collections.Sequence)
|
_sequence_types = (list, tuple, collections.Sequence)
|
||||||
_set_types = (set, collections.Set)
|
_set_types = (set, collections.Set)
|
||||||
|
|
||||||
|
# the default list of revert arguments to ignore when deriving
|
||||||
|
# revert argument mapping from the revert method signature
|
||||||
|
_default_revert_args = ('result', 'flow_failures')
|
||||||
|
|
||||||
|
|
||||||
def _save_as_to_mapping(save_as):
|
def _save_as_to_mapping(save_as):
|
||||||
"""Convert save_as to mapping name => index.
|
"""Convert save_as to mapping name => index.
|
||||||
@@ -176,6 +180,17 @@ class Atom(object):
|
|||||||
injected into the atoms scope before the atom execution
|
injected into the atoms scope before the atom execution
|
||||||
commences (this allows for providing atom *local* values
|
commences (this allows for providing atom *local* values
|
||||||
that do not need to be provided by other atoms/dependents).
|
that do not need to be provided by other atoms/dependents).
|
||||||
|
:param rebind: A dict of key/value pairs used to define argument
|
||||||
|
name conversions for inputs to this atom's ``execute``
|
||||||
|
method.
|
||||||
|
:param revert_rebind: The same as ``rebind`` but for the ``revert``
|
||||||
|
method. If unpassed, ``rebind`` will be used
|
||||||
|
instead.
|
||||||
|
:param requires: A set or list of required inputs for this atom's
|
||||||
|
``execute`` method.
|
||||||
|
:param revert_requires: A set or list of required inputs for this atom's
|
||||||
|
``revert`` method. If unpassed, ```requires`` will
|
||||||
|
be used.
|
||||||
:ivar version: An *immutable* version that associates version information
|
:ivar version: An *immutable* version that associates version information
|
||||||
with this atom. It can be useful in resuming older versions
|
with this atom. It can be useful in resuming older versions
|
||||||
of atoms. Standard major, minor versioning concepts
|
of atoms. Standard major, minor versioning concepts
|
||||||
@@ -191,12 +206,17 @@ class Atom(object):
|
|||||||
the names that this atom expects (in a way this is like
|
the names that this atom expects (in a way this is like
|
||||||
remapping a namespace of another atom into the namespace
|
remapping a namespace of another atom into the namespace
|
||||||
of this atom).
|
of this atom).
|
||||||
|
:ivar revert_rebind: The same as ``rebind`` but for the revert method. This
|
||||||
|
should only differ from ``rebind`` if the ``revert``
|
||||||
|
method has a different signature from ``execute`` or
|
||||||
|
a different ``revert_rebind`` value was received.
|
||||||
:ivar inject: See parameter ``inject``.
|
:ivar inject: See parameter ``inject``.
|
||||||
:ivar name: See parameter ``name``.
|
:ivar name: See parameter ``name``.
|
||||||
:ivar requires: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs
|
:ivar requires: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs
|
||||||
this atom requires to function.
|
this atom requires to function.
|
||||||
:ivar optional: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs
|
:ivar optional: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs
|
||||||
that are optional for this atom to function.
|
that are optional for this atom to ``execute``.
|
||||||
|
:ivar revert_optional: The ``revert`` version of ``optional``.
|
||||||
:ivar provides: A :py:class:`~taskflow.types.sets.OrderedSet` of outputs
|
:ivar provides: A :py:class:`~taskflow.types.sets.OrderedSet` of outputs
|
||||||
this atom produces.
|
this atom produces.
|
||||||
"""
|
"""
|
||||||
@@ -232,7 +252,7 @@ class Atom(object):
|
|||||||
|
|
||||||
def __init__(self, name=None, provides=None, requires=None,
|
def __init__(self, name=None, provides=None, requires=None,
|
||||||
auto_extract=True, rebind=None, inject=None,
|
auto_extract=True, rebind=None, inject=None,
|
||||||
ignore_list=None):
|
ignore_list=None, revert_rebind=None, revert_requires=None):
|
||||||
|
|
||||||
if provides is None:
|
if provides is None:
|
||||||
provides = self.default_provides
|
provides = self.default_provides
|
||||||
@@ -241,14 +261,30 @@ class Atom(object):
|
|||||||
self.version = (1, 0)
|
self.version = (1, 0)
|
||||||
self.inject = inject
|
self.inject = inject
|
||||||
self.save_as = _save_as_to_mapping(provides)
|
self.save_as = _save_as_to_mapping(provides)
|
||||||
self.requires = sets.OrderedSet()
|
|
||||||
self.optional = sets.OrderedSet()
|
|
||||||
self.provides = sets.OrderedSet(self.save_as)
|
self.provides = sets.OrderedSet(self.save_as)
|
||||||
self.rebind = collections.OrderedDict()
|
|
||||||
|
|
||||||
self._build_arg_mapping(self.execute, requires=requires,
|
if ignore_list is None:
|
||||||
|
ignore_list = []
|
||||||
|
|
||||||
|
self.rebind, exec_requires, self.optional = self._build_arg_mapping(
|
||||||
|
self.execute,
|
||||||
|
requires=requires,
|
||||||
rebind=rebind, auto_extract=auto_extract,
|
rebind=rebind, auto_extract=auto_extract,
|
||||||
ignore_list=ignore_list)
|
ignore_list=ignore_list
|
||||||
|
)
|
||||||
|
|
||||||
|
revert_ignore = ignore_list + list(_default_revert_args)
|
||||||
|
revert_mapping = self._build_arg_mapping(
|
||||||
|
self.revert,
|
||||||
|
requires=revert_requires or requires,
|
||||||
|
rebind=revert_rebind or rebind,
|
||||||
|
auto_extract=auto_extract,
|
||||||
|
ignore_list=revert_ignore
|
||||||
|
)
|
||||||
|
(self.revert_rebind, addl_requires,
|
||||||
|
self.revert_optional) = revert_mapping
|
||||||
|
|
||||||
|
self.requires = exec_requires.union(addl_requires)
|
||||||
|
|
||||||
def _build_arg_mapping(self, executor, requires=None, rebind=None,
|
def _build_arg_mapping(self, executor, requires=None, rebind=None,
|
||||||
auto_extract=True, ignore_list=None):
|
auto_extract=True, ignore_list=None):
|
||||||
@@ -263,13 +299,13 @@ class Atom(object):
|
|||||||
for (arg_name, bound_name) in itertools.chain(six.iteritems(required),
|
for (arg_name, bound_name) in itertools.chain(six.iteritems(required),
|
||||||
six.iteritems(optional)):
|
six.iteritems(optional)):
|
||||||
rebind.setdefault(arg_name, bound_name)
|
rebind.setdefault(arg_name, bound_name)
|
||||||
self.rebind = rebind
|
requires = sets.OrderedSet(six.itervalues(required))
|
||||||
self.requires = sets.OrderedSet(six.itervalues(required))
|
optional = sets.OrderedSet(six.itervalues(optional))
|
||||||
self.optional = sets.OrderedSet(six.itervalues(optional))
|
|
||||||
if self.inject:
|
if self.inject:
|
||||||
inject_keys = frozenset(six.iterkeys(self.inject))
|
inject_keys = frozenset(six.iterkeys(self.inject))
|
||||||
self.requires -= inject_keys
|
requires -= inject_keys
|
||||||
self.optional -= inject_keys
|
optional -= inject_keys
|
||||||
|
return rebind, requires, optional
|
||||||
|
|
||||||
def pre_execute(self):
|
def pre_execute(self):
|
||||||
"""Code to be run prior to executing the atom.
|
"""Code to be run prior to executing the atom.
|
||||||
|
|||||||
@@ -30,7 +30,14 @@ class RetryAction(base.Action):
|
|||||||
super(RetryAction, self).__init__(storage, notifier)
|
super(RetryAction, self).__init__(storage, notifier)
|
||||||
self._retry_executor = retry_executor
|
self._retry_executor = retry_executor
|
||||||
|
|
||||||
def _get_retry_args(self, retry, addons=None):
|
def _get_retry_args(self, retry, revert=False, addons=None):
|
||||||
|
if revert:
|
||||||
|
arguments = self._storage.fetch_mapped_args(
|
||||||
|
retry.revert_rebind,
|
||||||
|
atom_name=retry.name,
|
||||||
|
optional_args=retry.revert_optional
|
||||||
|
)
|
||||||
|
else:
|
||||||
arguments = self._storage.fetch_mapped_args(
|
arguments = self._storage.fetch_mapped_args(
|
||||||
retry.rebind,
|
retry.rebind,
|
||||||
atom_name=retry.name,
|
atom_name=retry.name,
|
||||||
@@ -92,7 +99,7 @@ class RetryAction(base.Action):
|
|||||||
retry_atom.REVERT_FLOW_FAILURES: self._storage.get_failures(),
|
retry_atom.REVERT_FLOW_FAILURES: self._storage.get_failures(),
|
||||||
}
|
}
|
||||||
return self._retry_executor.revert_retry(
|
return self._retry_executor.revert_retry(
|
||||||
retry, self._get_retry_args(retry, addons=arg_addons))
|
retry, self._get_retry_args(retry, addons=arg_addons, revert=True))
|
||||||
|
|
||||||
def on_failure(self, retry, atom, last_failure):
|
def on_failure(self, retry, atom, last_failure):
|
||||||
self._storage.save_retry_failure(retry.name, atom.name, last_failure)
|
self._storage.save_retry_failure(retry.name, atom.name, last_failure)
|
||||||
|
|||||||
@@ -122,9 +122,9 @@ class TaskAction(base.Action):
|
|||||||
def schedule_reversion(self, task):
|
def schedule_reversion(self, task):
|
||||||
self.change_state(task, states.REVERTING, progress=0.0)
|
self.change_state(task, states.REVERTING, progress=0.0)
|
||||||
arguments = self._storage.fetch_mapped_args(
|
arguments = self._storage.fetch_mapped_args(
|
||||||
task.rebind,
|
task.revert_rebind,
|
||||||
atom_name=task.name,
|
atom_name=task.name,
|
||||||
optional_args=task.optional
|
optional_args=task.revert_optional
|
||||||
)
|
)
|
||||||
task_uuid = self._storage.get_atom_uuid(task.name)
|
task_uuid = self._storage.get_atom_uuid(task.name)
|
||||||
task_result = self._storage.get(task.name)
|
task_result = self._storage.get(task.name)
|
||||||
|
|||||||
@@ -370,8 +370,12 @@ class ActionEngine(base.Engine):
|
|||||||
last_node = None
|
last_node = None
|
||||||
missing_nodes = 0
|
missing_nodes = 0
|
||||||
for atom in self._runtime.analyzer.iterate_nodes(compiler.ATOMS):
|
for atom in self._runtime.analyzer.iterate_nodes(compiler.ATOMS):
|
||||||
atom_missing = self.storage.fetch_unsatisfied_args(
|
exec_missing = self.storage.fetch_unsatisfied_args(
|
||||||
atom.name, atom.rebind, optional_args=atom.optional)
|
atom.name, atom.rebind, optional_args=atom.optional)
|
||||||
|
revert_missing = self.storage.fetch_unsatisfied_args(
|
||||||
|
atom.name, atom.revert_rebind,
|
||||||
|
optional_args=atom.revert_optional)
|
||||||
|
atom_missing = exec_missing.union(revert_missing)
|
||||||
if atom_missing:
|
if atom_missing:
|
||||||
cause = exc.MissingDependencies(atom,
|
cause = exc.MissingDependencies(atom,
|
||||||
sorted(atom_missing),
|
sorted(atom_missing),
|
||||||
|
|||||||
@@ -60,12 +60,14 @@ class Task(atom.Atom):
|
|||||||
TASK_EVENTS = (EVENT_UPDATE_PROGRESS,)
|
TASK_EVENTS = (EVENT_UPDATE_PROGRESS,)
|
||||||
|
|
||||||
def __init__(self, name=None, provides=None, requires=None,
|
def __init__(self, name=None, provides=None, requires=None,
|
||||||
auto_extract=True, rebind=None, inject=None):
|
auto_extract=True, rebind=None, inject=None,
|
||||||
|
ignore_list=None, revert_rebind=None, revert_requires=None):
|
||||||
if name is None:
|
if name is None:
|
||||||
name = reflection.get_class_name(self)
|
name = reflection.get_class_name(self)
|
||||||
super(Task, self).__init__(name, provides=provides, requires=requires,
|
super(Task, self).__init__(name, provides=provides, requires=requires,
|
||||||
auto_extract=auto_extract, rebind=rebind,
|
auto_extract=auto_extract, rebind=rebind,
|
||||||
inject=inject)
|
inject=inject, revert_rebind=revert_rebind,
|
||||||
|
revert_requires=revert_requires)
|
||||||
self._notifier = notifier.RestrictedNotifier(self.TASK_EVENTS)
|
self._notifier = notifier.RestrictedNotifier(self.TASK_EVENTS)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -137,7 +139,18 @@ class FunctorTask(Task):
|
|||||||
self._revert = revert
|
self._revert = revert
|
||||||
if version is not None:
|
if version is not None:
|
||||||
self.version = version
|
self.version = version
|
||||||
self._build_arg_mapping(execute, requires, rebind, auto_extract)
|
mapping = self._build_arg_mapping(execute, requires, rebind,
|
||||||
|
auto_extract)
|
||||||
|
self.rebind, exec_requires, self.optional = mapping
|
||||||
|
|
||||||
|
if revert:
|
||||||
|
revert_mapping = self._build_arg_mapping(revert, requires, rebind,
|
||||||
|
auto_extract)
|
||||||
|
else:
|
||||||
|
revert_mapping = (self.rebind, exec_requires, self.optional)
|
||||||
|
(self.revert_rebind, revert_requires,
|
||||||
|
self.revert_optional) = revert_mapping
|
||||||
|
self.requires = exec_requires.union(revert_requires)
|
||||||
|
|
||||||
def execute(self, *args, **kwargs):
|
def execute(self, *args, **kwargs):
|
||||||
return self._execute(*args, **kwargs)
|
return self._execute(*args, **kwargs)
|
||||||
|
|||||||
@@ -163,6 +163,26 @@ class ArgumentsPassingTest(utils.EngineTestBase):
|
|||||||
'long_arg_name': 1, 'result': 1
|
'long_arg_name': 1, 'result': 1
|
||||||
}, engine.storage.fetch_all())
|
}, engine.storage.fetch_all())
|
||||||
|
|
||||||
|
def test_revert_rebound_args_required(self):
|
||||||
|
flow = utils.TaskMultiArg(revert_rebind={'z': 'b'})
|
||||||
|
engine = self._make_engine(flow)
|
||||||
|
engine.storage.inject({'a': 1, 'y': 4, 'c': 9, 'x': 17})
|
||||||
|
self.assertRaises(exc.MissingDependencies, engine.run)
|
||||||
|
|
||||||
|
def test_revert_required_args_required(self):
|
||||||
|
flow = utils.TaskMultiArg(revert_requires=['a'])
|
||||||
|
engine = self._make_engine(flow)
|
||||||
|
engine.storage.inject({'y': 4, 'z': 9, 'x': 17})
|
||||||
|
self.assertRaises(exc.MissingDependencies, engine.run)
|
||||||
|
|
||||||
|
def test_derived_revert_args_required(self):
|
||||||
|
flow = utils.TaskRevertExtraArgs()
|
||||||
|
engine = self._make_engine(flow)
|
||||||
|
engine.storage.inject({'y': 4, 'z': 9, 'x': 17})
|
||||||
|
self.assertRaises(exc.MissingDependencies, engine.run)
|
||||||
|
engine.storage.inject({'revert_arg': None})
|
||||||
|
self.assertRaises(exc.ExecutionFailure, engine.run)
|
||||||
|
|
||||||
|
|
||||||
class SerialEngineTest(ArgumentsPassingTest, test.TestCase):
|
class SerialEngineTest(ArgumentsPassingTest, test.TestCase):
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,22 @@ class ProgressTask(task.Task):
|
|||||||
self.update_progress(value)
|
self.update_progress(value)
|
||||||
|
|
||||||
|
|
||||||
|
class SeparateRevertTask(task.Task):
|
||||||
|
def execute(self, execute_arg):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def revert(self, revert_arg, result, flow_failures):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SeparateRevertOptionalTask(task.Task):
|
||||||
|
def execute(self, execute_arg=None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def revert(self, result, flow_failures, revert_arg=None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TaskTest(test.TestCase):
|
class TaskTest(test.TestCase):
|
||||||
|
|
||||||
def test_passed_name(self):
|
def test_passed_name(self):
|
||||||
@@ -338,6 +354,26 @@ class TaskTest(test.TestCase):
|
|||||||
self.assertEqual(2, len(listeners[task.EVENT_UPDATE_PROGRESS]))
|
self.assertEqual(2, len(listeners[task.EVENT_UPDATE_PROGRESS]))
|
||||||
self.assertEqual(0, len(a_task.notifier))
|
self.assertEqual(0, len(a_task.notifier))
|
||||||
|
|
||||||
|
def test_separate_revert_args(self):
|
||||||
|
task = SeparateRevertTask(rebind=('a',), revert_rebind=('b',))
|
||||||
|
self.assertEqual({'execute_arg': 'a'}, task.rebind)
|
||||||
|
self.assertEqual({'revert_arg': 'b'}, task.revert_rebind)
|
||||||
|
self.assertEqual(set(['a', 'b']),
|
||||||
|
task.requires)
|
||||||
|
|
||||||
|
task = SeparateRevertTask(requires='execute_arg',
|
||||||
|
revert_requires='revert_arg')
|
||||||
|
|
||||||
|
self.assertEqual({'execute_arg': 'execute_arg'}, task.rebind)
|
||||||
|
self.assertEqual({'revert_arg': 'revert_arg'}, task.revert_rebind)
|
||||||
|
self.assertEqual(set(['execute_arg', 'revert_arg']),
|
||||||
|
task.requires)
|
||||||
|
|
||||||
|
def test_separate_revert_optional_args(self):
|
||||||
|
task = SeparateRevertOptionalTask()
|
||||||
|
self.assertEqual(set(['execute_arg']), task.optional)
|
||||||
|
self.assertEqual(set(['revert_arg']), task.revert_optional)
|
||||||
|
|
||||||
|
|
||||||
class FunctorTaskTest(test.TestCase):
|
class FunctorTaskTest(test.TestCase):
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class TestWorker(test.MockTestCase):
|
|||||||
self.broker_url = 'test-url'
|
self.broker_url = 'test-url'
|
||||||
self.exchange = 'test-exchange'
|
self.exchange = 'test-exchange'
|
||||||
self.topic = 'test-topic'
|
self.topic = 'test-topic'
|
||||||
self.endpoint_count = 26
|
self.endpoint_count = 27
|
||||||
|
|
||||||
# patch classes
|
# patch classes
|
||||||
self.executor_mock, self.executor_inst_mock = self.patchClass(
|
self.executor_mock, self.executor_inst_mock = self.patchClass(
|
||||||
|
|||||||
@@ -332,6 +332,14 @@ class NeverRunningTask(task.Task):
|
|||||||
assert False, 'This method should not be called'
|
assert False, 'This method should not be called'
|
||||||
|
|
||||||
|
|
||||||
|
class TaskRevertExtraArgs(task.Task):
|
||||||
|
def execute(self, **kwargs):
|
||||||
|
raise exceptions.ExecutionFailure("We want to force a revert here")
|
||||||
|
|
||||||
|
def revert(self, revert_arg, flow_failures, result, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EngineTestBase(object):
|
class EngineTestBase(object):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(EngineTestBase, self).setUp()
|
super(EngineTestBase, self).setUp()
|
||||||
|
|||||||
Reference in New Issue
Block a user