Allow for revert to have a different argument list from execute

Also allows for people to create Atom's with a different rebind
or requires structure for the revert method, if desired.

Implements blueprint: seperate-revert-args

Change-Id: Ie7d13c8000ef08ff303481d486d1ba1cfbdeea44
This commit is contained in:
Greg Hill
2016-01-21 08:20:58 -06:00
parent 3c2896aac0
commit e8d78e7aab
9 changed files with 151 additions and 27 deletions

View File

@@ -31,6 +31,10 @@ from taskflow.utils import misc
_sequence_types = (list, tuple, collections.Sequence) _sequence_types = (list, tuple, collections.Sequence)
_set_types = (set, collections.Set) _set_types = (set, collections.Set)
# the default list of revert arguments to ignore when deriving
# revert argument mapping from the revert method signature
_default_revert_args = ('result', 'flow_failures')
def _save_as_to_mapping(save_as): def _save_as_to_mapping(save_as):
"""Convert save_as to mapping name => index. """Convert save_as to mapping name => index.
@@ -176,6 +180,17 @@ class Atom(object):
injected into the atoms scope before the atom execution injected into the atoms scope before the atom execution
commences (this allows for providing atom *local* values commences (this allows for providing atom *local* values
that do not need to be provided by other atoms/dependents). that do not need to be provided by other atoms/dependents).
:param rebind: A dict of key/value pairs used to define argument
name conversions for inputs to this atom's ``execute``
method.
:param revert_rebind: The same as ``rebind`` but for the ``revert``
method. If unpassed, ``rebind`` will be used
instead.
:param requires: A set or list of required inputs for this atom's
``execute`` method.
:param revert_requires: A set or list of required inputs for this atom's
``revert`` method. If unpassed, ```requires`` will
be used.
:ivar version: An *immutable* version that associates version information :ivar version: An *immutable* version that associates version information
with this atom. It can be useful in resuming older versions with this atom. It can be useful in resuming older versions
of atoms. Standard major, minor versioning concepts of atoms. Standard major, minor versioning concepts
@@ -191,12 +206,17 @@ class Atom(object):
the names that this atom expects (in a way this is like the names that this atom expects (in a way this is like
remapping a namespace of another atom into the namespace remapping a namespace of another atom into the namespace
of this atom). of this atom).
:ivar revert_rebind: The same as ``rebind`` but for the revert method. This
should only differ from ``rebind`` if the ``revert``
method has a different signature from ``execute`` or
a different ``revert_rebind`` value was received.
:ivar inject: See parameter ``inject``. :ivar inject: See parameter ``inject``.
:ivar name: See parameter ``name``. :ivar name: See parameter ``name``.
:ivar requires: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs :ivar requires: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs
this atom requires to function. this atom requires to function.
:ivar optional: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs :ivar optional: A :py:class:`~taskflow.types.sets.OrderedSet` of inputs
that are optional for this atom to function. that are optional for this atom to ``execute``.
:ivar revert_optional: The ``revert`` version of ``optional``.
:ivar provides: A :py:class:`~taskflow.types.sets.OrderedSet` of outputs :ivar provides: A :py:class:`~taskflow.types.sets.OrderedSet` of outputs
this atom produces. this atom produces.
""" """
@@ -232,7 +252,7 @@ class Atom(object):
def __init__(self, name=None, provides=None, requires=None, def __init__(self, name=None, provides=None, requires=None,
auto_extract=True, rebind=None, inject=None, auto_extract=True, rebind=None, inject=None,
ignore_list=None): ignore_list=None, revert_rebind=None, revert_requires=None):
if provides is None: if provides is None:
provides = self.default_provides provides = self.default_provides
@@ -241,14 +261,30 @@ class Atom(object):
self.version = (1, 0) self.version = (1, 0)
self.inject = inject self.inject = inject
self.save_as = _save_as_to_mapping(provides) self.save_as = _save_as_to_mapping(provides)
self.requires = sets.OrderedSet()
self.optional = sets.OrderedSet()
self.provides = sets.OrderedSet(self.save_as) self.provides = sets.OrderedSet(self.save_as)
self.rebind = collections.OrderedDict()
self._build_arg_mapping(self.execute, requires=requires, if ignore_list is None:
ignore_list = []
self.rebind, exec_requires, self.optional = self._build_arg_mapping(
self.execute,
requires=requires,
rebind=rebind, auto_extract=auto_extract, rebind=rebind, auto_extract=auto_extract,
ignore_list=ignore_list) ignore_list=ignore_list
)
revert_ignore = ignore_list + list(_default_revert_args)
revert_mapping = self._build_arg_mapping(
self.revert,
requires=revert_requires or requires,
rebind=revert_rebind or rebind,
auto_extract=auto_extract,
ignore_list=revert_ignore
)
(self.revert_rebind, addl_requires,
self.revert_optional) = revert_mapping
self.requires = exec_requires.union(addl_requires)
def _build_arg_mapping(self, executor, requires=None, rebind=None, def _build_arg_mapping(self, executor, requires=None, rebind=None,
auto_extract=True, ignore_list=None): auto_extract=True, ignore_list=None):
@@ -263,13 +299,13 @@ class Atom(object):
for (arg_name, bound_name) in itertools.chain(six.iteritems(required), for (arg_name, bound_name) in itertools.chain(six.iteritems(required),
six.iteritems(optional)): six.iteritems(optional)):
rebind.setdefault(arg_name, bound_name) rebind.setdefault(arg_name, bound_name)
self.rebind = rebind requires = sets.OrderedSet(six.itervalues(required))
self.requires = sets.OrderedSet(six.itervalues(required)) optional = sets.OrderedSet(six.itervalues(optional))
self.optional = sets.OrderedSet(six.itervalues(optional))
if self.inject: if self.inject:
inject_keys = frozenset(six.iterkeys(self.inject)) inject_keys = frozenset(six.iterkeys(self.inject))
self.requires -= inject_keys requires -= inject_keys
self.optional -= inject_keys optional -= inject_keys
return rebind, requires, optional
def pre_execute(self): def pre_execute(self):
"""Code to be run prior to executing the atom. """Code to be run prior to executing the atom.

View File

@@ -30,7 +30,14 @@ class RetryAction(base.Action):
super(RetryAction, self).__init__(storage, notifier) super(RetryAction, self).__init__(storage, notifier)
self._retry_executor = retry_executor self._retry_executor = retry_executor
def _get_retry_args(self, retry, addons=None): def _get_retry_args(self, retry, revert=False, addons=None):
if revert:
arguments = self._storage.fetch_mapped_args(
retry.revert_rebind,
atom_name=retry.name,
optional_args=retry.revert_optional
)
else:
arguments = self._storage.fetch_mapped_args( arguments = self._storage.fetch_mapped_args(
retry.rebind, retry.rebind,
atom_name=retry.name, atom_name=retry.name,
@@ -92,7 +99,7 @@ class RetryAction(base.Action):
retry_atom.REVERT_FLOW_FAILURES: self._storage.get_failures(), retry_atom.REVERT_FLOW_FAILURES: self._storage.get_failures(),
} }
return self._retry_executor.revert_retry( return self._retry_executor.revert_retry(
retry, self._get_retry_args(retry, addons=arg_addons)) retry, self._get_retry_args(retry, addons=arg_addons, revert=True))
def on_failure(self, retry, atom, last_failure): def on_failure(self, retry, atom, last_failure):
self._storage.save_retry_failure(retry.name, atom.name, last_failure) self._storage.save_retry_failure(retry.name, atom.name, last_failure)

View File

@@ -122,9 +122,9 @@ class TaskAction(base.Action):
def schedule_reversion(self, task): def schedule_reversion(self, task):
self.change_state(task, states.REVERTING, progress=0.0) self.change_state(task, states.REVERTING, progress=0.0)
arguments = self._storage.fetch_mapped_args( arguments = self._storage.fetch_mapped_args(
task.rebind, task.revert_rebind,
atom_name=task.name, atom_name=task.name,
optional_args=task.optional optional_args=task.revert_optional
) )
task_uuid = self._storage.get_atom_uuid(task.name) task_uuid = self._storage.get_atom_uuid(task.name)
task_result = self._storage.get(task.name) task_result = self._storage.get(task.name)

View File

@@ -370,8 +370,12 @@ class ActionEngine(base.Engine):
last_node = None last_node = None
missing_nodes = 0 missing_nodes = 0
for atom in self._runtime.analyzer.iterate_nodes(compiler.ATOMS): for atom in self._runtime.analyzer.iterate_nodes(compiler.ATOMS):
atom_missing = self.storage.fetch_unsatisfied_args( exec_missing = self.storage.fetch_unsatisfied_args(
atom.name, atom.rebind, optional_args=atom.optional) atom.name, atom.rebind, optional_args=atom.optional)
revert_missing = self.storage.fetch_unsatisfied_args(
atom.name, atom.revert_rebind,
optional_args=atom.revert_optional)
atom_missing = exec_missing.union(revert_missing)
if atom_missing: if atom_missing:
cause = exc.MissingDependencies(atom, cause = exc.MissingDependencies(atom,
sorted(atom_missing), sorted(atom_missing),

View File

@@ -60,12 +60,14 @@ class Task(atom.Atom):
TASK_EVENTS = (EVENT_UPDATE_PROGRESS,) TASK_EVENTS = (EVENT_UPDATE_PROGRESS,)
def __init__(self, name=None, provides=None, requires=None, def __init__(self, name=None, provides=None, requires=None,
auto_extract=True, rebind=None, inject=None): auto_extract=True, rebind=None, inject=None,
ignore_list=None, revert_rebind=None, revert_requires=None):
if name is None: if name is None:
name = reflection.get_class_name(self) name = reflection.get_class_name(self)
super(Task, self).__init__(name, provides=provides, requires=requires, super(Task, self).__init__(name, provides=provides, requires=requires,
auto_extract=auto_extract, rebind=rebind, auto_extract=auto_extract, rebind=rebind,
inject=inject) inject=inject, revert_rebind=revert_rebind,
revert_requires=revert_requires)
self._notifier = notifier.RestrictedNotifier(self.TASK_EVENTS) self._notifier = notifier.RestrictedNotifier(self.TASK_EVENTS)
@property @property
@@ -137,7 +139,18 @@ class FunctorTask(Task):
self._revert = revert self._revert = revert
if version is not None: if version is not None:
self.version = version self.version = version
self._build_arg_mapping(execute, requires, rebind, auto_extract) mapping = self._build_arg_mapping(execute, requires, rebind,
auto_extract)
self.rebind, exec_requires, self.optional = mapping
if revert:
revert_mapping = self._build_arg_mapping(revert, requires, rebind,
auto_extract)
else:
revert_mapping = (self.rebind, exec_requires, self.optional)
(self.revert_rebind, revert_requires,
self.revert_optional) = revert_mapping
self.requires = exec_requires.union(revert_requires)
def execute(self, *args, **kwargs): def execute(self, *args, **kwargs):
return self._execute(*args, **kwargs) return self._execute(*args, **kwargs)

View File

@@ -163,6 +163,26 @@ class ArgumentsPassingTest(utils.EngineTestBase):
'long_arg_name': 1, 'result': 1 'long_arg_name': 1, 'result': 1
}, engine.storage.fetch_all()) }, engine.storage.fetch_all())
def test_revert_rebound_args_required(self):
flow = utils.TaskMultiArg(revert_rebind={'z': 'b'})
engine = self._make_engine(flow)
engine.storage.inject({'a': 1, 'y': 4, 'c': 9, 'x': 17})
self.assertRaises(exc.MissingDependencies, engine.run)
def test_revert_required_args_required(self):
flow = utils.TaskMultiArg(revert_requires=['a'])
engine = self._make_engine(flow)
engine.storage.inject({'y': 4, 'z': 9, 'x': 17})
self.assertRaises(exc.MissingDependencies, engine.run)
def test_derived_revert_args_required(self):
flow = utils.TaskRevertExtraArgs()
engine = self._make_engine(flow)
engine.storage.inject({'y': 4, 'z': 9, 'x': 17})
self.assertRaises(exc.MissingDependencies, engine.run)
engine.storage.inject({'revert_arg': None})
self.assertRaises(exc.ExecutionFailure, engine.run)
class SerialEngineTest(ArgumentsPassingTest, test.TestCase): class SerialEngineTest(ArgumentsPassingTest, test.TestCase):

View File

@@ -49,6 +49,22 @@ class ProgressTask(task.Task):
self.update_progress(value) self.update_progress(value)
class SeparateRevertTask(task.Task):
def execute(self, execute_arg):
pass
def revert(self, revert_arg, result, flow_failures):
pass
class SeparateRevertOptionalTask(task.Task):
def execute(self, execute_arg=None):
pass
def revert(self, result, flow_failures, revert_arg=None):
pass
class TaskTest(test.TestCase): class TaskTest(test.TestCase):
def test_passed_name(self): def test_passed_name(self):
@@ -338,6 +354,26 @@ class TaskTest(test.TestCase):
self.assertEqual(2, len(listeners[task.EVENT_UPDATE_PROGRESS])) self.assertEqual(2, len(listeners[task.EVENT_UPDATE_PROGRESS]))
self.assertEqual(0, len(a_task.notifier)) self.assertEqual(0, len(a_task.notifier))
def test_separate_revert_args(self):
task = SeparateRevertTask(rebind=('a',), revert_rebind=('b',))
self.assertEqual({'execute_arg': 'a'}, task.rebind)
self.assertEqual({'revert_arg': 'b'}, task.revert_rebind)
self.assertEqual(set(['a', 'b']),
task.requires)
task = SeparateRevertTask(requires='execute_arg',
revert_requires='revert_arg')
self.assertEqual({'execute_arg': 'execute_arg'}, task.rebind)
self.assertEqual({'revert_arg': 'revert_arg'}, task.revert_rebind)
self.assertEqual(set(['execute_arg', 'revert_arg']),
task.requires)
def test_separate_revert_optional_args(self):
task = SeparateRevertOptionalTask()
self.assertEqual(set(['execute_arg']), task.optional)
self.assertEqual(set(['revert_arg']), task.revert_optional)
class FunctorTaskTest(test.TestCase): class FunctorTaskTest(test.TestCase):

View File

@@ -33,7 +33,7 @@ class TestWorker(test.MockTestCase):
self.broker_url = 'test-url' self.broker_url = 'test-url'
self.exchange = 'test-exchange' self.exchange = 'test-exchange'
self.topic = 'test-topic' self.topic = 'test-topic'
self.endpoint_count = 26 self.endpoint_count = 27
# patch classes # patch classes
self.executor_mock, self.executor_inst_mock = self.patchClass( self.executor_mock, self.executor_inst_mock = self.patchClass(

View File

@@ -332,6 +332,14 @@ class NeverRunningTask(task.Task):
assert False, 'This method should not be called' assert False, 'This method should not be called'
class TaskRevertExtraArgs(task.Task):
def execute(self, **kwargs):
raise exceptions.ExecutionFailure("We want to force a revert here")
def revert(self, revert_arg, flow_failures, result, **kwargs):
pass
class EngineTestBase(object): class EngineTestBase(object):
def setUp(self): def setUp(self):
super(EngineTestBase, self).setUp() super(EngineTestBase, self).setUp()