diff --git a/eventlet/coros.py b/eventlet/coros.py index 8b89281..cd05163 100644 --- a/eventlet/coros.py +++ b/eventlet/coros.py @@ -33,6 +33,13 @@ from eventlet import channel from eventlet import pools from eventlet import greenlib + +try: + set +except NameError: # python 2.3 compatibility + from sets import Set as set + + class Cancelled(RuntimeError): pass @@ -245,27 +252,48 @@ class CoroutinePool(pools.Pool): foo 4 """ + def __init__(self, min_size=0, max_size=4): + self._greenlets = set() + super(CoroutinePool, self).__init__(min_size, max_size) + def _main_loop(self, sender): + """ Private, infinite loop run by a pooled coroutine. """ while True: recvd = sender.wait() sender.reset() (evt, func, args, kw) = recvd - try: - result = func(*args, **kw) - if evt is not None: - evt.send(result) - except api.GreenletExit, e: - # we're printing this out to see if it ever happens - # in practice - print "GreenletExit raised in coroutine pool", e - if evt is not None: - evt.send(e) # sent as a return value, not an exception - except Exception, e: - traceback.print_exc() - if evt is not None: - evt.send(exc=e) + self._safe_apply(evt, func, args, kw) api.get_hub().runloop.cancel_timers(api.getcurrent()) self.put(sender) + + def _safe_apply(self, evt, func, args, kw): + """ Private method that runs the function, catches exceptions, and + passes back the return value in the event.""" + try: + result = func(*args, **kw) + if evt is not None: + evt.send(result) + except api.GreenletExit, e: + # we're printing this out to see if it ever happens + # in practice + print "GreenletExit raised in coroutine pool", e + if evt is not None: + evt.send(e) # sent as a return value, not an exception + except Exception, e: + traceback.print_exc() + if evt is not None: + evt.send(exc=e) + + def _execute(self, evt, func, args, kw): + """ Private implementation of the execute methods. + """ + # if reentering an empty pool, don't try to wait on a coroutine freeing + # itself -- instead, just execute in the current coroutine + if self.free() == 0 and api.getcurrent() in self._greenlets: + self._safe_apply(evt, func, args, kw) + else: + sender = self.get() + sender.send((evt, func, args, kw)) def create(self): """Private implementation of eventlet.pools.Pool @@ -275,9 +303,9 @@ class CoroutinePool(pools.Pool): new coroutine, to be executed. """ sender = event() - api.spawn(self._main_loop, sender) + self._greenlets.add(api.spawn(self._main_loop, sender)) return sender - + def execute(self, func, *args, **kw): """Execute func in one of the coroutines maintained by the pool, when one is free. @@ -291,9 +319,8 @@ class CoroutinePool(pools.Pool): >>> evt.wait() ('foo', 1) """ - sender = self.get() receiver = event() - sender.send((receiver, func, args, kw)) + self._execute(receiver, func, args, kw) return receiver def execute_async(self, func, *args, **kw): @@ -310,8 +337,7 @@ class CoroutinePool(pools.Pool): >>> api.sleep(0) foo 1 """ - sender = self.get() - sender.send((None, func, args, kw)) + self._execute(None, func, args, kw) class pipe(object): diff --git a/eventlet/coros_test.py b/eventlet/coros_test.py index 7a28d3e..30e6cfc 100644 --- a/eventlet/coros_test.py +++ b/eventlet/coros_test.py @@ -112,7 +112,7 @@ class TestCoroutinePool(tests.TestCase): mode = 'static' def setUp(self): # raise an exception if we're waiting forever - self._cancel_timeout = api.exc_after(1, RuntimeError()) + self._cancel_timeout = api.exc_after(1, api.TimeoutError) def tearDown(self): self._cancel_timeout.cancel() @@ -161,6 +161,24 @@ class TestCoroutinePool(tests.TestCase): t = worker.wait() api.sleep(0) self.assertEquals(t.cancelled, True) + + def test_reentrant(self): + pool = coros.CoroutinePool(0,1) + def reenter(): + waiter = pool.execute(lambda a: a, 'reenter') + self.assertEqual('reenter', waiter.wait()) + + outer_waiter = pool.execute(reenter) + outer_waiter.wait() + + evt = coros.event() + def reenter_async(): + pool.execute_async(lambda a: a, 'reenter') + evt.send('done') + + pool.execute_async(reenter_async) + evt.wait() + class IncrActor(coros.Actor): def received(self, message): @@ -229,10 +247,10 @@ class TestActor(tests.TestCase): if message == 'fail': raise RuntimeError() else: - print "appending" msgs.append(message) self.actor.received = received + self.actor.excepted = lambda x: None self.actor.cast('fail') api.sleep(0)