Allow for revert to have a different argument list from execute

Also allows for people to create Atom's with a different rebind
or requires structure for the revert method, if desired.

Implements blueprint: seperate-revert-args

Change-Id: Ie7d13c8000ef08ff303481d486d1ba1cfbdeea44
This commit is contained in:
Greg Hill
2016-01-21 08:20:58 -06:00
parent 3c2896aac0
commit e8d78e7aab
9 changed files with 151 additions and 27 deletions

View File

@@ -31,6 +31,10 @@ from taskflow.utils import misc
_sequence_types = (list, tuple, collections.Sequence)
_set_types = (set, collections.Set)
# the default list of revert arguments to ignore when deriving
# revert argument mapping from the revert method signature
_default_revert_args = ('result', 'flow_failures')
def _save_as_to_mapping(save_as):
"""Convert save_as to mapping name => index.
@@ -176,6 +180,17 @@ class Atom(object):
injected into the atoms scope before the atom execution
commences (this allows for providing atom *local* values
that do not need to be provided by other atoms/dependents).
:param rebind: A dict of key/value pairs used to define argument
name conversions for inputs to this atom's ``execute``
method.
:param revert_rebind: The same as ``rebind`` but for the ``revert``
method. If unpassed, ``rebind`` will be used
instead.
:param requires: A set or list of required inputs for this atom's
``execute`` method.
:param revert_requires: A set or list of required inputs for this atom's
``revert`` method. If unpassed, ```requires`` will
be used.
:ivar version: An *immutable* version that associates version information
with this atom. It can be useful in resuming older versions
of atoms. Standard major, minor versioning concepts
@@ -191,12 +206,17 @@ class Atom(object):
the names that this atom expects (in a way this is like
remapping a namespace of another atom into the namespace
of this atom).
:ivar revert_rebind: The same as ``rebind`` but for the revert method. This
should only differ from ``rebind`` if the ``revert``
method has a different signature from ``execute`` or
a different ``revert_rebind`` value was received.
:ivar inject: See parameter ``inject``.
:ivar name: See parameter ``name``.
:ivar requires: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs
this atom requires to function.
:ivar optional: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs
that are optional for this atom to function.
that are optional for this atom to ``execute``.
:ivar revert_optional: The ``revert`` version of ``optional``.
:ivar provides: A :py:class:`~taskflow.types.sets.OrderedSet` of outputs
this atom produces.
"""
@@ -232,7 +252,7 @@ class Atom(object):
def __init__(self, name=None, provides=None, requires=None,
auto_extract=True, rebind=None, inject=None,
ignore_list=None):
ignore_list=None, revert_rebind=None, revert_requires=None):
if provides is None:
provides = self.default_provides
@@ -241,14 +261,30 @@ class Atom(object):
self.version = (1, 0)
self.inject = inject
self.save_as = _save_as_to_mapping(provides)
self.requires = sets.OrderedSet()
self.optional = sets.OrderedSet()
self.provides = sets.OrderedSet(self.save_as)
self.rebind = collections.OrderedDict()
self._build_arg_mapping(self.execute, requires=requires,
rebind=rebind, auto_extract=auto_extract,
ignore_list=ignore_list)
if ignore_list is None:
ignore_list = []
self.rebind, exec_requires, self.optional = self._build_arg_mapping(
self.execute,
requires=requires,
rebind=rebind, auto_extract=auto_extract,
ignore_list=ignore_list
)
revert_ignore = ignore_list + list(_default_revert_args)
revert_mapping = self._build_arg_mapping(
self.revert,
requires=revert_requires or requires,
rebind=revert_rebind or rebind,
auto_extract=auto_extract,
ignore_list=revert_ignore
)
(self.revert_rebind, addl_requires,
self.revert_optional) = revert_mapping
self.requires = exec_requires.union(addl_requires)
def _build_arg_mapping(self, executor, requires=None, rebind=None,
auto_extract=True, ignore_list=None):
@@ -263,13 +299,13 @@ class Atom(object):
for (arg_name, bound_name) in itertools.chain(six.iteritems(required),
six.iteritems(optional)):
rebind.setdefault(arg_name, bound_name)
self.rebind = rebind
self.requires = sets.OrderedSet(six.itervalues(required))
self.optional = sets.OrderedSet(six.itervalues(optional))
requires = sets.OrderedSet(six.itervalues(required))
optional = sets.OrderedSet(six.itervalues(optional))
if self.inject:
inject_keys = frozenset(six.iterkeys(self.inject))
self.requires -= inject_keys
self.optional -= inject_keys
requires -= inject_keys
optional -= inject_keys
return rebind, requires, optional
def pre_execute(self):
"""Code to be run prior to executing the atom.

View File

@@ -30,12 +30,19 @@ class RetryAction(base.Action):
super(RetryAction, self).__init__(storage, notifier)
self._retry_executor = retry_executor
def _get_retry_args(self, retry, addons=None):
arguments = self._storage.fetch_mapped_args(
retry.rebind,
atom_name=retry.name,
optional_args=retry.optional
)
def _get_retry_args(self, retry, revert=False, addons=None):
if revert:
arguments = self._storage.fetch_mapped_args(
retry.revert_rebind,
atom_name=retry.name,
optional_args=retry.revert_optional
)
else:
arguments = self._storage.fetch_mapped_args(
retry.rebind,
atom_name=retry.name,
optional_args=retry.optional
)
history = self._storage.get_retry_history(retry.name)
arguments[retry_atom.EXECUTE_REVERT_HISTORY] = history
if addons:
@@ -92,7 +99,7 @@ class RetryAction(base.Action):
retry_atom.REVERT_FLOW_FAILURES: self._storage.get_failures(),
}
return self._retry_executor.revert_retry(
retry, self._get_retry_args(retry, addons=arg_addons))
retry, self._get_retry_args(retry, addons=arg_addons, revert=True))
def on_failure(self, retry, atom, last_failure):
self._storage.save_retry_failure(retry.name, atom.name, last_failure)

View File

@@ -122,9 +122,9 @@ class TaskAction(base.Action):
def schedule_reversion(self, task):
self.change_state(task, states.REVERTING, progress=0.0)
arguments = self._storage.fetch_mapped_args(
task.rebind,
task.revert_rebind,
atom_name=task.name,
optional_args=task.optional
optional_args=task.revert_optional
)
task_uuid = self._storage.get_atom_uuid(task.name)
task_result = self._storage.get(task.name)

View File

@@ -370,8 +370,12 @@ class ActionEngine(base.Engine):
last_node = None
missing_nodes = 0
for atom in self._runtime.analyzer.iterate_nodes(compiler.ATOMS):
atom_missing = self.storage.fetch_unsatisfied_args(
exec_missing = self.storage.fetch_unsatisfied_args(
atom.name, atom.rebind, optional_args=atom.optional)
revert_missing = self.storage.fetch_unsatisfied_args(
atom.name, atom.revert_rebind,
optional_args=atom.revert_optional)
atom_missing = exec_missing.union(revert_missing)
if atom_missing:
cause = exc.MissingDependencies(atom,
sorted(atom_missing),

View File

@@ -60,12 +60,14 @@ class Task(atom.Atom):
TASK_EVENTS = (EVENT_UPDATE_PROGRESS,)
def __init__(self, name=None, provides=None, requires=None,
auto_extract=True, rebind=None, inject=None):
auto_extract=True, rebind=None, inject=None,
ignore_list=None, revert_rebind=None, revert_requires=None):
if name is None:
name = reflection.get_class_name(self)
super(Task, self).__init__(name, provides=provides, requires=requires,
auto_extract=auto_extract, rebind=rebind,
inject=inject)
inject=inject, revert_rebind=revert_rebind,
revert_requires=revert_requires)
self._notifier = notifier.RestrictedNotifier(self.TASK_EVENTS)
@property
@@ -137,7 +139,18 @@ class FunctorTask(Task):
self._revert = revert
if version is not None:
self.version = version
self._build_arg_mapping(execute, requires, rebind, auto_extract)
mapping = self._build_arg_mapping(execute, requires, rebind,
auto_extract)
self.rebind, exec_requires, self.optional = mapping
if revert:
revert_mapping = self._build_arg_mapping(revert, requires, rebind,
auto_extract)
else:
revert_mapping = (self.rebind, exec_requires, self.optional)
(self.revert_rebind, revert_requires,
self.revert_optional) = revert_mapping
self.requires = exec_requires.union(revert_requires)
def execute(self, *args, **kwargs):
return self._execute(*args, **kwargs)

View File

@@ -163,6 +163,26 @@ class ArgumentsPassingTest(utils.EngineTestBase):
'long_arg_name': 1, 'result': 1
}, engine.storage.fetch_all())
def test_revert_rebound_args_required(self):
flow = utils.TaskMultiArg(revert_rebind={'z': 'b'})
engine = self._make_engine(flow)
engine.storage.inject({'a': 1, 'y': 4, 'c': 9, 'x': 17})
self.assertRaises(exc.MissingDependencies, engine.run)
def test_revert_required_args_required(self):
flow = utils.TaskMultiArg(revert_requires=['a'])
engine = self._make_engine(flow)
engine.storage.inject({'y': 4, 'z': 9, 'x': 17})
self.assertRaises(exc.MissingDependencies, engine.run)
def test_derived_revert_args_required(self):
flow = utils.TaskRevertExtraArgs()
engine = self._make_engine(flow)
engine.storage.inject({'y': 4, 'z': 9, 'x': 17})
self.assertRaises(exc.MissingDependencies, engine.run)
engine.storage.inject({'revert_arg': None})
self.assertRaises(exc.ExecutionFailure, engine.run)
class SerialEngineTest(ArgumentsPassingTest, test.TestCase):

View File

@@ -49,6 +49,22 @@ class ProgressTask(task.Task):
self.update_progress(value)
class SeparateRevertTask(task.Task):
def execute(self, execute_arg):
pass
def revert(self, revert_arg, result, flow_failures):
pass
class SeparateRevertOptionalTask(task.Task):
def execute(self, execute_arg=None):
pass
def revert(self, result, flow_failures, revert_arg=None):
pass
class TaskTest(test.TestCase):
def test_passed_name(self):
@@ -338,6 +354,26 @@ class TaskTest(test.TestCase):
self.assertEqual(2, len(listeners[task.EVENT_UPDATE_PROGRESS]))
self.assertEqual(0, len(a_task.notifier))
def test_separate_revert_args(self):
task = SeparateRevertTask(rebind=('a',), revert_rebind=('b',))
self.assertEqual({'execute_arg': 'a'}, task.rebind)
self.assertEqual({'revert_arg': 'b'}, task.revert_rebind)
self.assertEqual(set(['a', 'b']),
task.requires)
task = SeparateRevertTask(requires='execute_arg',
revert_requires='revert_arg')
self.assertEqual({'execute_arg': 'execute_arg'}, task.rebind)
self.assertEqual({'revert_arg': 'revert_arg'}, task.revert_rebind)
self.assertEqual(set(['execute_arg', 'revert_arg']),
task.requires)
def test_separate_revert_optional_args(self):
task = SeparateRevertOptionalTask()
self.assertEqual(set(['execute_arg']), task.optional)
self.assertEqual(set(['revert_arg']), task.revert_optional)
class FunctorTaskTest(test.TestCase):

View File

@@ -33,7 +33,7 @@ class TestWorker(test.MockTestCase):
self.broker_url = 'test-url'
self.exchange = 'test-exchange'
self.topic = 'test-topic'
self.endpoint_count = 26
self.endpoint_count = 27
# patch classes
self.executor_mock, self.executor_inst_mock = self.patchClass(

View File

@@ -332,6 +332,14 @@ class NeverRunningTask(task.Task):
assert False, 'This method should not be called'
class TaskRevertExtraArgs(task.Task):
def execute(self, **kwargs):
raise exceptions.ExecutionFailure("We want to force a revert here")
def revert(self, revert_arg, flow_failures, result, **kwargs):
pass
class EngineTestBase(object):
def setUp(self):
super(EngineTestBase, self).setUp()