diff --git a/heat/engine/dependencies.py b/heat/engine/dependencies.py index 9ac3f038bc..819b4dbe6b 100644 --- a/heat/engine/dependencies.py +++ b/heat/engine/dependencies.py @@ -128,7 +128,9 @@ class Graph(collections.defaultdict): node = self[key] for src in node.required_by(): - self[src] -= key + src_node = self[src] + if key in src_node: + src_node -= key return super(Graph, self).__delitem__(key) diff --git a/heat/engine/scheduler.py b/heat/engine/scheduler.py index c609fe7214..111a47e976 100644 --- a/heat/engine/scheduler.py +++ b/heat/engine/scheduler.py @@ -75,6 +75,29 @@ class Timeout(BaseException): return wallclock() > self._endtime +class ExceptionGroup(Exception): + ''' + Container for multiple exceptions. + + This exception is used by DependencyTaskGroup when the flag + aggregate_exceptions is set to True and it's re-raised again when all tasks + are finished. This way it can be caught later on so that the individual + exceptions can be acted upon. + ''' + + def __init__(self, exceptions=None): + if exceptions is None: + exceptions = list() + + self.exceptions = list(exceptions) + + def __str__(self): + return str(map(str, self.exceptions)) + + def __unicode__(self): + return unicode(map(str, self.exceptions)) + + class TaskRunner(object): """ Wrapper for a resumable task (co-routine). @@ -182,11 +205,14 @@ class TaskRunner(object): self._sleep(wait_time) def cancel(self): - """Cancel the task if it is running.""" - if self.started() and not self.done(): + """Cancel the task and mark it as done.""" + if not self.done(): logger.debug(_('%s cancelled') % str(self)) - self._runner.close() - self._done = True + try: + if self.started(): + self._runner.close() + finally: + self._done = True def started(self): """Return True if the task has been started.""" @@ -268,7 +294,7 @@ class DependencyTaskGroup(object): """ def __init__(self, dependencies, task=lambda o: o(), - reverse=False, name=None): + reverse=False, name=None, aggregate_exceptions=False): """ Initialise with the task dependencies and (optionally) a task to run on each. @@ -276,9 +302,14 @@ class DependencyTaskGroup(object): If no task is supplied, it is assumed that the tasks are stored directly in the dependency tree. If a task is supplied, the object stored in the dependency tree is passed as an argument. + + If aggregate_exceptions is set to True, then all the tasks will be run + and any raised exceptions will be stored to be re-raised after all + tasks are done. """ self._runners = dict((o, TaskRunner(task, o)) for o in dependencies) self._graph = dependencies.graph(reverse=reverse) + self.aggregate_exceptions = aggregate_exceptions if name is None: name = '(%s) %s' % (getattr(task, '__name__', @@ -292,21 +323,40 @@ class DependencyTaskGroup(object): def __call__(self): """Return a co-routine which runs the task group.""" + raised_exceptions = [] try: while any(self._runners.itervalues()): - for k, r in self._ready(): - r.start() + try: + for k, r in self._ready(): + r.start() - yield + yield - for k, r in self._running(): - if r.step(): - del self._graph[k] + for k, r in self._running(): + if r.step(): + del self._graph[k] + except Exception as e: + self._cancel_recursively(k, r) + if not self.aggregate_exceptions: + raise + raised_exceptions.append(e) except: with excutils.save_and_reraise_exception(): for r in self._runners.itervalues(): r.cancel() + if raised_exceptions: + raise ExceptionGroup(raised_exceptions) + + def _cancel_recursively(self, key, runner): + runner.cancel() + node = self._graph[key] + for dependent_node in node.required_by(): + node_runner = self._runners[dependent_node] + self._cancel_recursively(dependent_node, node_runner) + + del self._graph[key] + def _ready(self): """ Iterate over all subtasks that are ready to start - i.e. all their diff --git a/heat/tests/test_scheduler.py b/heat/tests/test_scheduler.py index c146fa396b..a6a0c0cf6d 100644 --- a/heat/tests/test_scheduler.py +++ b/heat/tests/test_scheduler.py @@ -163,11 +163,41 @@ class PollingTaskGroupTest(HeatTestCase): dummy.do_step(1, i, i * i) -class DependencyTaskGroupTest(HeatTestCase): +class ExceptionGroupTest(HeatTestCase): + def test_contains_exceptions(self): + exception_group = scheduler.ExceptionGroup() + self.assertIsInstance(exception_group.exceptions, list) + + def test_can_be_initialized_with_a_list_of_exceptions(self): + ex1 = Exception("ex 1") + ex2 = Exception("ex 2") + + exception_group = scheduler.ExceptionGroup([ex1, ex2]) + self.assertIn(ex1, exception_group.exceptions) + self.assertIn(ex2, exception_group.exceptions) + + def test_can_add_exceptions_after_init(self): + ex = Exception() + exception_group = scheduler.ExceptionGroup() + + exception_group.exceptions.append(ex) + self.assertIn(ex, exception_group.exceptions) + + def test_str_representation_aggregates_all_exceptions(self): + ex1 = Exception("ex 1") + ex2 = Exception("ex 2") + + exception_group = scheduler.ExceptionGroup([ex1, ex2]) + self.assertEqual("['ex 1', 'ex 2']", str(exception_group)) + + +class DependencyTaskGroupTest(HeatTestCase): def setUp(self): super(DependencyTaskGroupTest, self).setUp() self.addCleanup(self.m.VerifyAll) + self.aggregate_exceptions = False + self.reverse_order = False @contextlib.contextmanager def _dep_test(self, *edges): @@ -175,7 +205,9 @@ class DependencyTaskGroupTest(HeatTestCase): deps = dependencies.Dependencies(edges) - tg = scheduler.DependencyTaskGroup(deps, dummy) + tg = scheduler.DependencyTaskGroup( + deps, dummy, aggregate_exceptions=self.aggregate_exceptions, + reverse=self.reverse_order) self.m.StubOutWithMock(dummy, 'do_step') @@ -320,6 +352,54 @@ class DependencyTaskGroupTest(HeatTestCase): self.assertRaises(dependencies.CircularDependencyException, scheduler.DependencyTaskGroup, d) + def test_aggregate_exceptions_raises_all_at_the_end(self): + def run_tasks_with_exceptions(e1=None, e2=None): + self.aggregate_exceptions = True + tasks = (('A', None), ('B', None), ('C', None)) + with self._dep_test(*tasks) as dummy: + dummy.do_step(1, 'A').InAnyOrder('1') + dummy.do_step(1, 'B').InAnyOrder('1') + dummy.do_step(1, 'C').InAnyOrder('1').AndRaise(e1) + + dummy.do_step(2, 'A').InAnyOrder('2') + dummy.do_step(2, 'B').InAnyOrder('2').AndRaise(e2) + + dummy.do_step(3, 'A').InAnyOrder('3') + + e1 = Exception('e1') + e2 = Exception('e2') + + exc = self.assertRaises(scheduler.ExceptionGroup, + run_tasks_with_exceptions, e1, e2) + self.assertEqual(set([e1, e2]), set(exc.exceptions)) + + def test_aggregate_exceptions_cancels_dependent_tasks_recursively(self): + def run_tasks_with_exceptions(e1=None, e2=None): + self.aggregate_exceptions = True + tasks = (('A', None), ('B', 'A'), ('C', 'B')) + with self._dep_test(*tasks) as dummy: + dummy.do_step(1, 'A').AndRaise(e1) + + e1 = Exception('e1') + + exc = self.assertRaises(scheduler.ExceptionGroup, + run_tasks_with_exceptions, e1) + self.assertEqual([e1], exc.exceptions) + + def test_aggregate_exceptions_cancels_tasks_in_reverse_order(self): + def run_tasks_with_exceptions(e1=None, e2=None): + self.reverse_order = True + self.aggregate_exceptions = True + tasks = (('A', None), ('B', 'A'), ('C', 'B')) + with self._dep_test(*tasks) as dummy: + dummy.do_step(1, 'C').AndRaise(e1) + + e1 = Exception('e1') + + exc = self.assertRaises(scheduler.ExceptionGroup, + run_tasks_with_exceptions, e1) + self.assertEqual([e1], exc.exceptions) + class TaskTest(HeatTestCase):