Allow concurrent tasks to complete after error

Add an option to allow concurrent tasks in a DependencyTaskGroup some time
to run to completion before being cancelled in the event of an error in
another task.

Change-Id: Ie3d63a0d5e8c9fa5f0fa74285b06f732685afcbc
This commit is contained in:
Zane Bitter 2014-08-26 18:37:10 -04:00
parent 73736991c1
commit 08a349ac02
2 changed files with 86 additions and 25 deletions

View File

@ -177,6 +177,7 @@ class TaskRunner(object):
raised inside the task.
"""
assert self._runner is None, "Task already started"
assert not self._done, "Task already cancelled"
LOG.debug('%s starting' % str(self))
@ -321,7 +322,8 @@ class DependencyTaskGroup(object):
"""
def __init__(self, dependencies, task=lambda o: o(),
reverse=False, name=None, aggregate_exceptions=False):
reverse=False, name=None, error_wait_time=None,
aggregate_exceptions=False):
"""
Initialise with the task dependencies and (optionally) a task to run on
each.
@ -330,12 +332,19 @@ class DependencyTaskGroup(object):
directly in the dependency tree. If a task is supplied, the object
stored in the dependency tree is passed as an argument.
If aggregate_exceptions is set to True, then all the tasks will be run
and any raised exceptions will be stored to be re-raised after all
tasks are done.
If an error_wait_time is specified, tasks that are already running at
the time of an error will continue to run for up to the specified
time before being cancelled. Once all remaining tasks are complete or
have been cancelled, the original exception is raised.
If aggregate_exceptions is True, then execution of parallel operations
will not be cancelled in the event of an error (operations downstream
of the error will be cancelled). Once all chains are complete, any
errors will be rolled up into an ExceptionGroup exception.
"""
self._runners = dict((o, TaskRunner(task, o)) for o in dependencies)
self._graph = dependencies.graph(reverse=reverse)
self.error_wait_time = error_wait_time
self.aggregate_exceptions = aggregate_exceptions
if name is None:
@ -351,29 +360,35 @@ class DependencyTaskGroup(object):
def __call__(self):
"""Return a co-routine which runs the task group."""
raised_exceptions = []
try:
while any(self._runners.itervalues()):
try:
for k, r in self._ready():
r.start()
while any(self._runners.itervalues()):
try:
for k, r in self._ready():
r.start()
yield
yield
for k, r in self._running():
if r.step():
del self._graph[k]
except Exception as e:
for k, r in self._running():
if r.step():
del self._graph[k]
except Exception:
exc_info = sys.exc_info()
if self.aggregate_exceptions:
self._cancel_recursively(k, r)
if not self.aggregate_exceptions:
raise
raised_exceptions.append(e)
except: # noqa
with excutils.save_and_reraise_exception():
for r in self._runners.itervalues():
r.cancel()
else:
for r in self._runners.itervalues():
r.cancel(grace_period=self.error_wait_time)
raised_exceptions.append(exc_info)
except: # noqa
with excutils.save_and_reraise_exception():
for r in self._runners.itervalues():
r.cancel()
if raised_exceptions:
raise ExceptionGroup(raised_exceptions)
if self.aggregate_exceptions:
raise ExceptionGroup(v for t, v, tb in raised_exceptions)
else:
exc_type, exc_val, traceback = raised_exceptions[0]
raise exc_type, exc_val, traceback
def _cancel_recursively(self, key, runner):
runner.cancel()
@ -392,7 +407,7 @@ class DependencyTaskGroup(object):
for k, n in six.iteritems(self._graph):
if not n:
runner = self._runners[k]
if not runner.started():
if runner and not runner.started():
yield k, runner
def _running(self):

View File

@ -195,6 +195,7 @@ class DependencyTaskGroupTest(HeatTestCase):
super(DependencyTaskGroupTest, self).setUp()
self.addCleanup(self.m.VerifyAll)
self.aggregate_exceptions = False
self.error_wait_time = None
self.reverse_order = False
@contextlib.contextmanager
@ -204,8 +205,9 @@ class DependencyTaskGroupTest(HeatTestCase):
deps = dependencies.Dependencies(edges)
tg = scheduler.DependencyTaskGroup(
deps, dummy, aggregate_exceptions=self.aggregate_exceptions,
reverse=self.reverse_order)
deps, dummy, reverse=self.reverse_order,
error_wait_time=self.error_wait_time,
aggregate_exceptions=self.aggregate_exceptions)
self.m.StubOutWithMock(dummy, 'do_step')
@ -399,6 +401,44 @@ class DependencyTaskGroupTest(HeatTestCase):
run_tasks_with_exceptions, e1)
self.assertEqual([e1], exc.exceptions)
def test_exception_grace_period(self):
e1 = Exception('e1')
def run_tasks_with_exceptions():
self.error_wait_time = 5
tasks = (('A', None), ('B', None), ('C', 'A'))
with self._dep_test(*tasks) as dummy:
dummy.do_step(1, 'A').InAnyOrder('1')
dummy.do_step(1, 'B').InAnyOrder('1')
dummy.do_step(2, 'A').InAnyOrder('2').AndRaise(e1)
dummy.do_step(2, 'B').InAnyOrder('2')
dummy.do_step(3, 'B')
exc = self.assertRaises(Exception, run_tasks_with_exceptions)
self.assertEqual(e1, exc)
def test_exception_grace_period_expired(self):
e1 = Exception('e1')
def run_tasks_with_exceptions():
self.steps = 5
self.error_wait_time = 0.05
def sleep():
eventlet.sleep(self.error_wait_time)
tasks = (('A', None), ('B', None), ('C', 'A'))
with self._dep_test(*tasks) as dummy:
dummy.do_step(1, 'A').InAnyOrder('1')
dummy.do_step(1, 'B').InAnyOrder('1')
dummy.do_step(2, 'A').InAnyOrder('2').AndRaise(e1)
dummy.do_step(2, 'B').InAnyOrder('2')
dummy.do_step(3, 'B')
dummy.do_step(4, 'B').WithSideEffects(sleep)
exc = self.assertRaises(Exception, run_tasks_with_exceptions)
self.assertEqual(e1, exc)
class TaskTest(HeatTestCase):
@ -570,6 +610,12 @@ class TaskTest(HeatTestCase):
runner.start()
self.assertRaises(AssertionError, runner.start)
def test_start_cancelled(self):
runner = scheduler.TaskRunner(DummyTask())
runner.cancel()
self.assertRaises(AssertionError, runner.start)
def test_call_double_start(self):
runner = scheduler.TaskRunner(DummyTask())