Add the concept of 'tracked_events' to CoroutinePool and CoroutinePool.wait, the method you use to drain the events

This commit is contained in:
donovan
2008-05-30 10:56:32 -07:00
parent 75853995eb
commit 65c09922e6
2 changed files with 42 additions and 4 deletions

View File

@@ -260,8 +260,13 @@ class CoroutinePool(pools.Pool):
foo 4 foo 4
""" """
def __init__(self, min_size=0, max_size=4): def __init__(self, min_size=0, max_size=4, track_events=False):
self._greenlets = set() self._greenlets = set()
if track_events:
self._tracked_events = []
self._next_event = None
else:
self._tracked_events = None
super(CoroutinePool, self).__init__(min_size, max_size) super(CoroutinePool, self).__init__(min_size, max_size)
def _main_loop(self, sender): def _main_loop(self, sender):
@@ -286,6 +291,13 @@ class CoroutinePool(pools.Pool):
result = func(*args, **kw) result = func(*args, **kw)
if evt is not None: if evt is not None:
evt.send(result) evt.send(result)
if self._tracked_events is not None:
if self._next_event is None:
self._tracked_events.append(result)
else:
ne = self._next_event
self._next_event = None
ne.send(result)
except api.GreenletExit, e: except api.GreenletExit, e:
# we're printing this out to see if it ever happens # we're printing this out to see if it ever happens
# in practice # in practice
@@ -354,6 +366,22 @@ class CoroutinePool(pools.Pool):
""" """
self._execute(None, func, args, kw) self._execute(None, func, args, kw)
def wait(self):
"""Wait for the next execute in the pool to complete,
and return the result.
You must pass track_events=True to the CoroutinePool constructor
in order to use this method.
"""
assert self._tracked_events is not None, (
"Must pass track_events=True to the constructor to use CoroutinePool.wait()")
if self._next_event is None:
result = self._tracked_events.pop(0)
if not self._tracked_events:
self._next_event = event()
return result
return self._next_event.wait()
def killall(self): def killall(self):
for g in self._greenlets: for g in self._greenlets:
api.kill(g) api.kill(g)

View File

@@ -110,6 +110,7 @@ class TestEvent(tests.TestCase):
api.exc_after(0.001, api.TimeoutError) api.exc_after(0.001, api.TimeoutError)
self.assertRaises(api.TimeoutError, evt.wait) self.assertRaises(api.TimeoutError, evt.wait)
class TestCoroutinePool(tests.TestCase): class TestCoroutinePool(tests.TestCase):
mode = 'static' mode = 'static'
def setUp(self): def setUp(self):
@@ -206,14 +207,23 @@ class TestCoroutinePool(tests.TestCase):
t.cancel() t.cancel()
finally: finally:
sys.stderr = normal_err sys.stderr = normal_err
def test_track_events(self):
pool = coros.CoroutinePool(track_events=True)
for x in range(6):
pool.execute(lambda n: n, x)
t = api.exc_after(10, RuntimeError)
for y in range(6):
pool.wait()
t.cancel()
class IncrActor(coros.Actor): class IncrActor(coros.Actor):
def received(self, evt): def received(self, evt):
self.value = getattr(self, 'value', 0) + 1 self.value = getattr(self, 'value', 0) + 1
if evt: evt.send() if evt: evt.send()
class TestActor(tests.TestCase): class TestActor(tests.TestCase):
mode = 'static' mode = 'static'
def setUp(self): def setUp(self):
@@ -277,7 +287,6 @@ class TestActor(tests.TestCase):
for evt in waiters: for evt in waiters:
evt.wait() evt.wait()
self.assertEqual(msgs, [1,2,3,4,5]) self.assertEqual(msgs, [1,2,3,4,5])
def test_raising_received(self): def test_raising_received(self):
msgs = [] msgs = []
@@ -325,5 +334,6 @@ class TestActor(tests.TestCase):
self.assertEqual(total[0], 3) self.assertEqual(total[0], 3)
self.assertEqual(self.actor._pool.free(), 2) self.assertEqual(self.actor._pool.free(), 2)
if __name__ == '__main__': if __name__ == '__main__':
tests.main() tests.main()