Allow concurrent tasks to complete after error

Add an error_wait_time option that gives concurrent tasks in a
DependencyTaskGroup some time to run to completion before they are
cancelled when another task in the group raises an error.

Change-Id: Ie3d63a0d5e8c9fa5f0fa74285b06f732685afcbc
This commit is contained in:
Zane Bitter 2014-08-26 18:37:10 -04:00
parent 73736991c1
commit 08a349ac02
2 changed files with 86 additions and 25 deletions

View File

@ -177,6 +177,7 @@ class TaskRunner(object):
raised inside the task. raised inside the task.
""" """
assert self._runner is None, "Task already started" assert self._runner is None, "Task already started"
assert not self._done, "Task already cancelled"
LOG.debug('%s starting' % str(self)) LOG.debug('%s starting' % str(self))
@ -321,7 +322,8 @@ class DependencyTaskGroup(object):
""" """
def __init__(self, dependencies, task=lambda o: o(), def __init__(self, dependencies, task=lambda o: o(),
reverse=False, name=None, aggregate_exceptions=False): reverse=False, name=None, error_wait_time=None,
aggregate_exceptions=False):
""" """
Initialise with the task dependencies and (optionally) a task to run on Initialise with the task dependencies and (optionally) a task to run on
each. each.
@ -330,12 +332,19 @@ class DependencyTaskGroup(object):
directly in the dependency tree. If a task is supplied, the object directly in the dependency tree. If a task is supplied, the object
stored in the dependency tree is passed as an argument. stored in the dependency tree is passed as an argument.
If aggregate_exceptions is set to True, then all the tasks will be run If an error_wait_time is specified, tasks that are already running at
and any raised exceptions will be stored to be re-raised after all the time of an error will continue to run for up to the specified
tasks are done. time before being cancelled. Once all remaining tasks are complete or
have been cancelled, the original exception is raised.
If aggregate_exceptions is True, then execution of parallel operations
will not be cancelled in the event of an error (operations downstream
of the error will be cancelled). Once all chains are complete, any
errors will be rolled up into an ExceptionGroup exception.
""" """
self._runners = dict((o, TaskRunner(task, o)) for o in dependencies) self._runners = dict((o, TaskRunner(task, o)) for o in dependencies)
self._graph = dependencies.graph(reverse=reverse) self._graph = dependencies.graph(reverse=reverse)
self.error_wait_time = error_wait_time
self.aggregate_exceptions = aggregate_exceptions self.aggregate_exceptions = aggregate_exceptions
if name is None: if name is None:
@ -351,29 +360,35 @@ class DependencyTaskGroup(object):
def __call__(self): def __call__(self):
"""Return a co-routine which runs the task group.""" """Return a co-routine which runs the task group."""
raised_exceptions = [] raised_exceptions = []
try: while any(self._runners.itervalues()):
while any(self._runners.itervalues()): try:
try: for k, r in self._ready():
for k, r in self._ready(): r.start()
r.start()
yield yield
for k, r in self._running(): for k, r in self._running():
if r.step(): if r.step():
del self._graph[k] del self._graph[k]
except Exception as e: except Exception:
exc_info = sys.exc_info()
if self.aggregate_exceptions:
self._cancel_recursively(k, r) self._cancel_recursively(k, r)
if not self.aggregate_exceptions: else:
raise for r in self._runners.itervalues():
raised_exceptions.append(e) r.cancel(grace_period=self.error_wait_time)
except: # noqa raised_exceptions.append(exc_info)
with excutils.save_and_reraise_exception(): except: # noqa
for r in self._runners.itervalues(): with excutils.save_and_reraise_exception():
r.cancel() for r in self._runners.itervalues():
r.cancel()
if raised_exceptions: if raised_exceptions:
raise ExceptionGroup(raised_exceptions) if self.aggregate_exceptions:
raise ExceptionGroup(v for t, v, tb in raised_exceptions)
else:
exc_type, exc_val, traceback = raised_exceptions[0]
raise exc_type, exc_val, traceback
def _cancel_recursively(self, key, runner): def _cancel_recursively(self, key, runner):
runner.cancel() runner.cancel()
@ -392,7 +407,7 @@ class DependencyTaskGroup(object):
for k, n in six.iteritems(self._graph): for k, n in six.iteritems(self._graph):
if not n: if not n:
runner = self._runners[k] runner = self._runners[k]
if not runner.started(): if runner and not runner.started():
yield k, runner yield k, runner
def _running(self): def _running(self):

View File

@ -195,6 +195,7 @@ class DependencyTaskGroupTest(HeatTestCase):
super(DependencyTaskGroupTest, self).setUp() super(DependencyTaskGroupTest, self).setUp()
self.addCleanup(self.m.VerifyAll) self.addCleanup(self.m.VerifyAll)
self.aggregate_exceptions = False self.aggregate_exceptions = False
self.error_wait_time = None
self.reverse_order = False self.reverse_order = False
@contextlib.contextmanager @contextlib.contextmanager
@ -204,8 +205,9 @@ class DependencyTaskGroupTest(HeatTestCase):
deps = dependencies.Dependencies(edges) deps = dependencies.Dependencies(edges)
tg = scheduler.DependencyTaskGroup( tg = scheduler.DependencyTaskGroup(
deps, dummy, aggregate_exceptions=self.aggregate_exceptions, deps, dummy, reverse=self.reverse_order,
reverse=self.reverse_order) error_wait_time=self.error_wait_time,
aggregate_exceptions=self.aggregate_exceptions)
self.m.StubOutWithMock(dummy, 'do_step') self.m.StubOutWithMock(dummy, 'do_step')
@ -399,6 +401,44 @@ class DependencyTaskGroupTest(HeatTestCase):
run_tasks_with_exceptions, e1) run_tasks_with_exceptions, e1)
self.assertEqual([e1], exc.exceptions) self.assertEqual([e1], exc.exceptions)
def test_exception_grace_period(self):
e1 = Exception('e1')
def run_tasks_with_exceptions():
self.error_wait_time = 5
tasks = (('A', None), ('B', None), ('C', 'A'))
with self._dep_test(*tasks) as dummy:
dummy.do_step(1, 'A').InAnyOrder('1')
dummy.do_step(1, 'B').InAnyOrder('1')
dummy.do_step(2, 'A').InAnyOrder('2').AndRaise(e1)
dummy.do_step(2, 'B').InAnyOrder('2')
dummy.do_step(3, 'B')
exc = self.assertRaises(Exception, run_tasks_with_exceptions)
self.assertEqual(e1, exc)
def test_exception_grace_period_expired(self):
e1 = Exception('e1')
def run_tasks_with_exceptions():
self.steps = 5
self.error_wait_time = 0.05
def sleep():
eventlet.sleep(self.error_wait_time)
tasks = (('A', None), ('B', None), ('C', 'A'))
with self._dep_test(*tasks) as dummy:
dummy.do_step(1, 'A').InAnyOrder('1')
dummy.do_step(1, 'B').InAnyOrder('1')
dummy.do_step(2, 'A').InAnyOrder('2').AndRaise(e1)
dummy.do_step(2, 'B').InAnyOrder('2')
dummy.do_step(3, 'B')
dummy.do_step(4, 'B').WithSideEffects(sleep)
exc = self.assertRaises(Exception, run_tasks_with_exceptions)
self.assertEqual(e1, exc)
class TaskTest(HeatTestCase): class TaskTest(HeatTestCase):
@ -570,6 +610,12 @@ class TaskTest(HeatTestCase):
runner.start() runner.start()
self.assertRaises(AssertionError, runner.start) self.assertRaises(AssertionError, runner.start)
def test_start_cancelled(self):
runner = scheduler.TaskRunner(DummyTask())
runner.cancel()
self.assertRaises(AssertionError, runner.start)
def test_call_double_start(self): def test_call_double_start(self):
runner = scheduler.TaskRunner(DummyTask()) runner = scheduler.TaskRunner(DummyTask())