diff --git a/heat/engine/scheduler.py b/heat/engine/scheduler.py index 7ad19a6858..b4f750695a 100644 --- a/heat/engine/scheduler.py +++ b/heat/engine/scheduler.py @@ -215,10 +215,18 @@ def wrappertask(task): def wrapper(*args, **kwargs): parent = task(*args, **kwargs) - for subtask in parent: + subtask = next(parent) + + while True: try: if subtask is not None: - for step in subtask: + subtask_running = True + try: + step = next(subtask) + except StopIteration: + subtask_running = False + + while subtask_running: try: yield step except GeneratorExit as exit: @@ -226,19 +234,23 @@ def wrappertask(task): raise exit except: try: - subtask.throw(*sys.exc_info()) + step = subtask.throw(*sys.exc_info()) except StopIteration: - break + subtask_running = False + else: + try: + step = next(subtask) + except StopIteration: + subtask_running = False else: yield except GeneratorExit as exit: parent.close() raise exit except: - try: - parent.throw(*sys.exc_info()) - except StopIteration: - break + subtask = parent.throw(*sys.exc_info()) + else: + subtask = next(parent) return wrapper diff --git a/heat/tests/test_scheduler.py b/heat/tests/test_scheduler.py index f331085b4e..95b0caebb1 100644 --- a/heat/tests/test_scheduler.py +++ b/heat/tests/test_scheduler.py @@ -691,6 +691,143 @@ class WrapperTaskTest(mox.MoxTestBase): task.next() task.next() + def test_child_exception_swallow_next(self): + class MyException(Exception): + pass + + def child_task(): + yield + + raise MyException() + + dummy = DummyTask() + + @scheduler.wrappertask + def parent_task(): + try: + yield child_task() + except MyException: + pass + else: + self.fail('No exception raised in parent_task') + + yield dummy() + + task = parent_task() + task.next() + + self.mox.StubOutWithMock(dummy, 'do_step') + for i in range(1, dummy.num_steps + 1): + dummy.do_step(i).AndReturn(None) + self.mox.ReplayAll() + + for i in range(1, dummy.num_steps + 1): + task.next() + self.assertRaises(StopIteration, task.next) + + def test_thrown_exception_swallow_next(self): + class MyException(Exception): + pass + + dummy = DummyTask() + + @scheduler.wrappertask + def child_task(): + try: + yield + except MyException: + yield dummy() + else: + self.fail('No exception raised in child_task') + + @scheduler.wrappertask + def parent_task(): + yield child_task() + + task = parent_task() + + self.mox.StubOutWithMock(dummy, 'do_step') + for i in range(1, dummy.num_steps + 1): + dummy.do_step(i).AndReturn(None) + self.mox.ReplayAll() + + next(task) + task.throw(MyException) + + for i in range(2, dummy.num_steps + 1): + task.next() + self.assertRaises(StopIteration, task.next) + + def test_thrown_exception_raise(self): + class MyException(Exception): + pass + + dummy = DummyTask() + + @scheduler.wrappertask + def child_task(): + try: + yield + except MyException: + raise + else: + self.fail('No exception raised in child_task') + + @scheduler.wrappertask + def parent_task(): + try: + yield child_task() + except MyException: + yield dummy() + + task = parent_task() + + self.mox.StubOutWithMock(dummy, 'do_step') + for i in range(1, dummy.num_steps + 1): + dummy.do_step(i).AndReturn(None) + self.mox.ReplayAll() + + next(task) + task.throw(MyException) + + for i in range(2, dummy.num_steps + 1): + task.next() + self.assertRaises(StopIteration, task.next) + + def test_thrown_exception_exit(self): + class MyException(Exception): + pass + + dummy = DummyTask() + + @scheduler.wrappertask + def child_task(): + try: + yield + except MyException: + return + else: + self.fail('No exception raised in child_task') + + @scheduler.wrappertask + def parent_task(): + yield child_task() + yield dummy() + + task = parent_task() + + self.mox.StubOutWithMock(dummy, 'do_step') + for i in range(1, dummy.num_steps + 1): + dummy.do_step(i).AndReturn(None) + self.mox.ReplayAll() + + next(task) + task.throw(MyException) + + for i in range(2, dummy.num_steps + 1): + task.next() + self.assertRaises(StopIteration, task.next) + def test_parent_exception(self): class MyException(Exception): pass