diff --git a/taskflow/patterns/linear_flow.py b/taskflow/patterns/linear_flow.py index 14f69f0b..5f7e9a80 100644 --- a/taskflow/patterns/linear_flow.py +++ b/taskflow/patterns/linear_flow.py @@ -172,7 +172,7 @@ class Flow(flow.Flow): # Add the task to be rolled back *immediately* so that even if # the task fails while producing results it will be given a # chance to rollback. - rb = utils.RollbackTask(context, runner.task, result=None) + rb = utils.Rollback(context, runner, self, self.task_notifier) self._accumulator.add(rb) self.task_notifier.notify(states.STARTED, details={ 'context': context, @@ -198,13 +198,7 @@ class Flow(flow.Flow): " object: %s", result) result = exc.InvalidStateException() raise result - # Adjust the task result in the accumulator before - # notifying others that the task has finished to - # avoid the case where a listener might throw an - # exception. - rb.result = result - runner.result = result - self.results[runner.uuid] = result + self.results[runner.uuid] = runner.result self.task_notifier.notify(states.SUCCESS, details={ 'context': context, 'flow': self, diff --git a/taskflow/patterns/threaded_flow.py b/taskflow/patterns/threaded_flow.py index 38267388..0e4de140 100644 --- a/taskflow/patterns/threaded_flow.py +++ b/taskflow/patterns/threaded_flow.py @@ -423,7 +423,8 @@ class Flow(flow.Flow): accum = utils.RollbackAccumulator() for r in self._graph.nodes_iter(): if r.has_ran(): - accum.add(utils.RollbackTask(context, r.task, r.result)) + accum.add(utils.Rollback(context, r, + self, self.task_notifier)) try: self._change_state(context, states.REVERTING) accum.rollback(cause) diff --git a/taskflow/states.py b/taskflow/states.py index b3ff9293..aeb094e9 100644 --- a/taskflow/states.py +++ b/taskflow/states.py @@ -30,6 +30,7 @@ INTERRUPTED = 'INTERRUPTED' PENDING = 'PENDING' RESUMING = 'RESUMING' REVERTING = 'REVERTING' +REVERTED = 'REVERTED' RUNNING = RUNNING STARTED = 'STARTED' SUCCESS = SUCCESS @@ -42,3 +43,5 @@ STARTED = STARTED SUCCESS = SUCCESS TIMED_OUT = 'TIMED_OUT' CANCELLED = CANCELLED +REVERTED = REVERTED +REVERTING = REVERTING diff --git a/taskflow/tests/unit/test_linear_flow.py b/taskflow/tests/unit/test_linear_flow.py index ffe9339e..b14f4d46 100644 --- a/taskflow/tests/unit/test_linear_flow.py +++ b/taskflow/tests/unit/test_linear_flow.py @@ -16,6 +16,8 @@ # License for the specific language governing permissions and limitations # under the License. +import collections + from taskflow import decorators from taskflow import exceptions as exc from taskflow import states @@ -135,21 +137,43 @@ class LinearFlowTest(test.TestCase): wf.add(self.make_reverting_task(i)) run_context = {} + capture_func, captured = self._capture_states() + wf.task_notifier.register('*', capture_func) wf.run(run_context) self.assertEquals(10, len(run_context)) + self.assertEquals(10, len(captured)) for _k, v in run_context.items(): self.assertEquals('passed', v) + for _uuid, u_states in captured.items(): + self.assertEquals([states.STARTED, states.SUCCESS], u_states) + + def _capture_states(self): + capture_where = collections.defaultdict(list) + + def do_capture(state, details): + runner = details.get('runner') + if not runner: + return + capture_where[runner.uuid].append(state) + + return (do_capture, capture_where) def test_reverting_flow(self): wf = lw.Flow("the-test-action") - wf.add(self.make_reverting_task(1)) - wf.add(self.make_reverting_task(2, True)) + ok_uuid = wf.add(self.make_reverting_task(1)) + broke_uuid = wf.add(self.make_reverting_task(2, True)) + capture_func, captured = self._capture_states() + wf.task_notifier.register('*', capture_func) run_context = {} self.assertRaises(Exception, wf.run, run_context) self.assertEquals('reverted', run_context[1]) self.assertEquals(1, len(run_context)) + self.assertEquals([states.STARTED, states.SUCCESS, states.REVERTING, + states.REVERTED], captured[ok_uuid]) + self.assertEquals([states.STARTED, states.FAILURE, states.REVERTING, + states.REVERTED], captured[broke_uuid]) def test_not_satisfied_inputs_previous(self): wf = lw.Flow("the-test-action") diff --git a/taskflow/tests/unit/test_threaded_flow.py b/taskflow/tests/unit/test_threaded_flow.py index a6a70285..7bcd85d0 100644 --- a/taskflow/tests/unit/test_threaded_flow.py +++ b/taskflow/tests/unit/test_threaded_flow.py @@ -327,7 +327,7 @@ class ThreadedFlowTest(test.TestCase): context = {} self.assertRaises(IOError, flo.run, context) self.assertEquals(states.FAILURE, flo.state) - self.assertEquals(states.FAILURE, history[f_uuid][-1]) + self.assertEquals(states.REVERTED, history[f_uuid][-1]) self.assertTrue(context.get('reverted')) def test_failure_cancel_successors(self): @@ -351,7 +351,7 @@ class ThreadedFlowTest(test.TestCase): context = {} self.assertRaises(IOError, flo.run, context) self.assertEquals(states.FAILURE, flo.state) - self.assertEquals(states.FAILURE, history[fq][-1]) + self.assertEquals(states.REVERTED, history[fq][-1]) self.assertEquals(states.CANCELLED, history[af][-1]) self.assertEquals(states.CANCELLED, history[af2][-1]) diff --git a/taskflow/utils.py b/taskflow/utils.py index f82bdd47..6aa8d623 100644 --- a/taskflow/utils.py +++ b/taskflow/utils.py @@ -25,12 +25,14 @@ import logging import threading import time import types +import weakref import threading2 from distutils import version from taskflow.openstack.common import uuidutils +from taskflow import states TASK_FACTORY_ATTRIBUTE = '_TaskFlow_task_factory' LOG = logging.getLogger(__name__) @@ -263,32 +265,10 @@ class FlowFailure(object): return self.runner.exc_info[1] -class RollbackTask(object): - """A helper task that on being called will call the underlying callable - tasks revert method (if said method exists). - """ - - def __init__(self, context, task, result): - self.task = task - self.result = result - self.context = context - - def __str__(self): - return str(self.task) - - def __call__(self, cause): - if ((hasattr(self.task, "revert") and - isinstance(self.task.revert, collections.Callable))): - self.task.revert(self.context, self.result, cause) - - class Runner(object): """A helper class that wraps a task and can find the needed inputs for the task to run, as well as providing a uuid and other useful functionality for users of the task. - - TODO(harlowja): replace with the task details object or a subclass of - that??? """ def __init__(self, task, uuid=None): @@ -332,7 +312,9 @@ class Runner(object): @property def name(self): - return self.task.name + if hasattr(self.task, 'name'): + return self.task.name + return '?' def reset(self): self.result = None @@ -452,6 +434,43 @@ class TransitionNotifier(object): break +class Rollback(object): + """A helper functor object that on being called will call the underlying + runners tasks revert method (if said method exists) and do the appropriate + notification to signal to others that the reverting is underway. + """ + + def __init__(self, context, runner, flow, notifier): + self.runner = runner + self.context = context + self.notifier = notifier + # Use weak references to give the GC a break. + self.flow = weakref.proxy(flow) + + def __str__(self): + return "Rollback: %s" % (self.runner) + + def _fire_notify(self, has_reverted): + if self.notifier: + if has_reverted: + state = states.REVERTED + else: + state = states.REVERTING + self.notifier.notify(state, details={ + 'context': self.context, + 'flow': self.flow, + 'runner': self.runner, + }) + + def __call__(self, cause): + self._fire_notify(False) + task = self.runner.task + if ((hasattr(task, "revert") and + isinstance(task.revert, collections.Callable))): + task.revert(self.context, self.runner.result, cause) + self._fire_notify(True) + + class RollbackAccumulator(object): """A utility class that can help in organizing 'undo' like code so that said code be rolled back on failure (automatically or manually)