diff --git a/taskflow/patterns/graph_workflow.py b/taskflow/patterns/graph_workflow.py index d25c2cd4..8f8f05ed 100644 --- a/taskflow/patterns/graph_workflow.py +++ b/taskflow/patterns/graph_workflow.py @@ -25,7 +25,7 @@ from networkx.algorithms import dag from networkx.classes import digraph from taskflow import exceptions as exc -from taskflow import ordered_workflow +from taskflow.patterns import ordered_workflow LOG = logging.getLogger(__name__) diff --git a/taskflow/patterns/ordered_workflow.py b/taskflow/patterns/ordered_workflow.py index ddc22a81..24813d8b 100644 --- a/taskflow/patterns/ordered_workflow.py +++ b/taskflow/patterns/ordered_workflow.py @@ -73,6 +73,7 @@ class Workflow(object): @abc.abstractmethod def add(self, task): + """Adds a given task to this workflow.""" raise NotImplementedError() def __str__(self): @@ -80,10 +81,46 @@ class Workflow(object): @abc.abstractmethod def order(self): + """Returns the order in which the tasks should be ran + as a iterable list.""" raise NotImplementedError() - def _fetch_inputs(self, task): - return {} + def _fetch_task_inputs(self, task): + """Retrieves and additional kwargs inputs to provide to the task when + said task is being applied.""" + return None + + def _perform_reconcilation(self, context, task, excp): + # Attempt to reconcile the given exception that occured while applying + # the given task and either reconcile said task and its associated + # failure, so that the workflow can continue or abort and perform + # some type of undo of the tasks already completed. + try: + self._change_state(context, states.REVERTING) + except Exception: + LOG.exception("Dropping exception catched when" + " changing state to reverting while performing" + " reconcilation on a tasks exception.") + cause = exc.TaskException(task, self, excp) + with excutils.save_and_reraise_exception(): + try: + self._on_task_error(context, task) + except Exception: + LOG.exception("Dropping exception catched when" + " notifying about existing task" + " exception.") + # The default strategy will be to rollback all the contained + # tasks by calling there reverting methods, and then calling + # any parent workflows rollbacks (and so-on). + try: + self.rollback(context, cause) + finally: + try: + self._change_state(context, states.FAILURE) + except Exception: + LOG.exception("Dropping exception catched when" + " changing state to failure while performing" + " reconcilation on a tasks exception.") def run(self, context, *args, **kwargs): if self.state != states.PENDING: @@ -96,26 +133,15 @@ class Workflow(object): result_fetcher = None self._change_state(context, states.STARTED) - - # TODO(harlowja): we can likely add in custom reconcilation strategies - # here or around here... - def do_rollback_for(task, ex): - self._change_state(context, states.REVERTING) - with excutils.save_and_reraise_exception(): - try: - self._on_task_error(context, task) - except Exception: - LOG.exception("Dropping exception catched when" - " notifying about existing task" - " exception.") - self.rollback(context, exc.TaskException(task, self, ex)) - self._change_state(context, states.FAILURE) - task_order = self.order() last_task = 0 + was_interrupted = False if result_fetcher: self._change_state(context, states.RESUMING) for (i, task) in enumerate(task_order): + if self.state == states.INTERRUPTED: + was_interrupted = True + break (has_result, result) = result_fetcher(self, task) if not has_result: break @@ -132,11 +158,13 @@ class Workflow(object): # result it returned and not a modified one. self.results.append((task, copy.deepcopy(result))) self._on_task_finish(context, task, result) - except Exception as ex: - do_rollback_for(task, ex) + except Exception as e: + self._perform_reconcilation(context, task, e) + + if was_interrupted: + return self._change_state(context, states.RUNNING) - was_interrupted = False for task in task_order[last_task:]: if self.state == states.INTERRUPTED: was_interrupted = True @@ -148,7 +176,9 @@ class Workflow(object): (has_result, result) = result_fetcher(self, task) self._on_task_start(context, task) if not has_result: - inputs = self._fetch_inputs(task) + inputs = self._fetch_task_inputs(task) + if inputs is None: + inputs = {} inputs.update(kwargs) result = task.apply(context, *args, **inputs) # Keep a pristine copy of the result @@ -158,8 +188,8 @@ class Workflow(object): # and not a modified one. self.results.append((task, copy.deepcopy(result))) self._on_task_finish(context, task, result) - except Exception as ex: - do_rollback_for(task, ex) + except Exception as e: + self._perform_reconcilation(context, task, e) if not was_interrupted: # Only gets here if everything went successfully. @@ -201,16 +231,30 @@ class Workflow(object): f(context, states.SUCCESS, self, task, result=result) def rollback(self, context, cause): + # Performs basic task by task rollback by going through the reverse + # order that tasks have finished and asking said task to undo whatever + # it has done. If this workflow has any parent workflows then they will + # also be called to rollback any tasks said parents contain. + # + # Note(harlowja): if a workflow can more simply revert a whole set of + # tasks via a simpler command then it can override this method to + # accomplish that. + # + # For example, if each task was creating a file in a directory, then + # it's easier to just remove the directory than to ask each task to + # delete its file individually. for (i, (task, result)) in enumerate(reversed(self._reversions)): try: task.revert(context, result, cause) except Exception: # Ex: WARN: Failed rolling back stage 1 (validate_request) of # chain validation due to Y exception. + log_f = LOG.warn + if not self.tolerant: + log_f = LOG.exception msg = ("Failed rolling back stage %(index)s (%(task)s)" " of workflow %(workflow)s, due to inner exception.") - LOG.warn(msg % {'index': (i + 1), 'task': task, - 'workflow': self}) + log_f(msg % {'index': (i + 1), 'task': task, 'workflow': self}) if not self.tolerant: # NOTE(harlowja): LOG a msg AND re-raise the exception if # the chain does not tolerate exceptions happening in the diff --git a/taskflow/tests/unit/test_linear_workflow.py b/taskflow/tests/unit/test_linear_workflow.py new file mode 100644 index 00000000..4ad37d72 --- /dev/null +++ b/taskflow/tests/unit/test_linear_workflow.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- + +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import unittest + +from taskflow import states +from taskflow import task +from taskflow import wrappers + +from taskflow.patterns import linear_workflow as lw + + +def null_functor(*args, **kwargs): + return None + + +class LinearWorkflowTest(unittest.TestCase): + def makeRevertingTask(self, token, blowup=False): + + def do_apply(token, context, *args, **kwargs): + context[token] = 'passed' + + def do_revert(token, context, *args, **kwargs): + context[token] = 'reverted' + + def blow_up(context, *args, **kwargs): + raise Exception("I blew up") + + if blowup: + return wrappers.FunctorTask('task-%s' % (token), + functools.partial(blow_up, token), + null_functor) + else: + return wrappers.FunctorTask('task-%s' % (token), + functools.partial(do_apply, token), + functools.partial(do_revert, token)) + + def makeInterruptTask(self, token, wf): + + def do_interrupt(token, context, *args, **kwargs): + wf.interrupt() + + return wrappers.FunctorTask('task-%s' % (token), + functools.partial(do_interrupt, token), + null_functor) + + def testHappyPath(self): + wf = lw.Workflow("the-test-action") + + for i in range(0, 10): + wf.add(self.makeRevertingTask(i)) + + run_context = {} + wf.run(run_context) + + self.assertEquals(10, len(run_context)) + for _k, v in run_context.items(): + self.assertEquals('passed', v) + + def testRevertingPath(self): + wf = lw.Workflow("the-test-action") + wf.add(self.makeRevertingTask(1)) + wf.add(self.makeRevertingTask(2, True)) + + run_context = {} + self.assertRaises(Exception, wf.run, run_context) + self.assertEquals('reverted', run_context[1]) + self.assertEquals(1, len(run_context)) + + def testInterruptPath(self): + wf = lw.Workflow("the-int-action") + + result_storage = {} + + # If we interrupt we need to know how to resume so attach the needed + # parts to do that... + + def result_fetcher(ctx, wf, task): + if task.name in result_storage: + return (True, result_storage.get(task.name)) + return (False, None) + + def task_listener(ctx, state, wf, task, result=None): + if state not in (states.SUCCESS, states.FAILURE,): + return + if task.name not in result_storage: + result_storage[task.name] = result + + wf.result_fetcher = result_fetcher + wf.task_listeners.append(task_listener) + + wf.add(self.makeRevertingTask(1)) + wf.add(self.makeInterruptTask(2, wf)) + wf.add(self.makeRevertingTask(3)) + + self.assertEquals(states.PENDING, wf.state) + context = {} + wf.run(context) + + # Interrupt should have been triggered after task 1 + self.assertEquals(1, len(context)) + self.assertEquals(states.INTERRUPTED, wf.state) + + # And now reset and resume. + wf.reset() + self.assertEquals(states.PENDING, wf.state) + wf.run(context) + self.assertEquals(2, len(context)) + + def testParentRevertingPath(self): + happy_wf = lw.Workflow("the-happy-action") + for i in range(0, 10): + happy_wf.add(self.makeRevertingTask(i)) + context = {} + happy_wf.run(context) + + for (_k, v) in context.items(): + self.assertEquals('passed', v) + + baddy_wf = lw.Workflow("the-bad-action", parents=[happy_wf]) + baddy_wf.add(self.makeRevertingTask(i + 1)) + baddy_wf.add(self.makeRevertingTask(i + 2, True)) + self.assertRaises(Exception, baddy_wf.run, context) + + for (_k, v) in context.items(): + self.assertEquals('reverted', v)