diff --git a/eventlet/coros.py b/eventlet/coros.py index 89e69f6..04aad61 100644 --- a/eventlet/coros.py +++ b/eventlet/coros.py @@ -260,8 +260,13 @@ class CoroutinePool(pools.Pool): foo 4 """ - def __init__(self, min_size=0, max_size=4): + def __init__(self, min_size=0, max_size=4, track_events=False): self._greenlets = set() + if track_events: + self._tracked_events = [] + self._next_event = None + else: + self._tracked_events = None super(CoroutinePool, self).__init__(min_size, max_size) def _main_loop(self, sender): @@ -286,6 +291,13 @@ class CoroutinePool(pools.Pool): result = func(*args, **kw) if evt is not None: evt.send(result) + if self._tracked_events is not None: + if self._next_event is None: + self._tracked_events.append(result) + else: + ne = self._next_event + self._next_event = None + ne.send(result) except api.GreenletExit, e: # we're printing this out to see if it ever happens # in practice @@ -354,6 +366,22 @@ class CoroutinePool(pools.Pool): """ self._execute(None, func, args, kw) + def wait(self): + """Wait for the next execute in the pool to complete, + and return the result. + + You must pass track_events=True to the CoroutinePool constructor + in order to use this method. + """ + assert self._tracked_events is not None, ( + "Must pass track_events=True to the constructor to use CoroutinePool.wait()") + if self._next_event is None: + result = self._tracked_events.pop(0) + if not self._tracked_events: + self._next_event = event() + return result + return self._next_event.wait() + def killall(self): for g in self._greenlets: api.kill(g) diff --git a/eventlet/coros_test.py b/eventlet/coros_test.py index ae83724..fa08761 100644 --- a/eventlet/coros_test.py +++ b/eventlet/coros_test.py @@ -110,6 +110,7 @@ class TestEvent(tests.TestCase): api.exc_after(0.001, api.TimeoutError) self.assertRaises(api.TimeoutError, evt.wait) + class TestCoroutinePool(tests.TestCase): mode = 'static' def setUp(self): @@ -206,14 +207,23 @@ class TestCoroutinePool(tests.TestCase): t.cancel() finally: sys.stderr = normal_err - - + + def test_track_events(self): + pool = coros.CoroutinePool(track_events=True) + for x in range(6): + pool.execute(lambda n: n, x) + t = api.exc_after(10, RuntimeError) + for y in range(6): + pool.wait() + t.cancel() + class IncrActor(coros.Actor): def received(self, evt): self.value = getattr(self, 'value', 0) + 1 if evt: evt.send() + class TestActor(tests.TestCase): mode = 'static' def setUp(self): @@ -277,7 +287,6 @@ class TestActor(tests.TestCase): for evt in waiters: evt.wait() self.assertEqual(msgs, [1,2,3,4,5]) - def test_raising_received(self): msgs = [] @@ -325,5 +334,6 @@ class TestActor(tests.TestCase): self.assertEqual(total[0], 3) self.assertEqual(self.actor._pool.free(), 2) + if __name__ == '__main__': tests.main()