diff --git a/eventlet/api.py b/eventlet/api.py
index 2985c9e..f3e79fe 100644
--- a/eventlet/api.py
+++ b/eventlet/api.py
@@ -129,43 +129,67 @@ def _spawn_startup(cb, args, kw, cancel=None):
     return cb(*args, **kw)
 
-class ResultGreenlet(Greenlet):
-    def __init__(self):
-        Greenlet.__init__(self, self.main)
+class GreenThread(Greenlet):
+    def __init__(self, parent):
+        Greenlet.__init__(self, self.main, parent)
         from eventlet import coros
-        self._exit_event = coros.event()
+        self._exit_event = coros.Event()
 
     def wait(self):
         return self._exit_event.wait()
 
-    def link(self, func):
-        self._exit_funcs = getattr(self, '_exit_funcs', [])
-        self._exit_funcs.append(func)
+    def link(self, func, *curried_args, **curried_kwargs):
+        """ Set up a function to be called with the results of the GreenThread.
 
-    def main(self, *a):
-        function, args, kwargs = a
+        The function must have the following signature:
+            def f(result=None, exc=None, [curried args/kwargs]):
+        """
+        self._exit_funcs = getattr(self, '_exit_funcs', [])
+        self._exit_funcs.append((func, curried_args, curried_kwargs))
+
+    def main(self, function, args, kwargs):
         try:
             result = function(*args, **kwargs)
         except:
             self._exit_event.send_exception(*sys.exc_info())
-            for f in getattr(self, '_exit_funcs', []):
-                f(self, exc=sys.exc_info())
+            # ca and ckw are the curried function arguments
+            for f, ca, ckw in getattr(self, '_exit_funcs', []):
+                f(exc=sys.exc_info(), *ca, **ckw)
+            raise
         else:
             self._exit_event.send(result)
-            for f in getattr(self, '_exit_funcs', []):
-                f(self, result)
+            for f, ca, ckw in getattr(self, '_exit_funcs', []):
+                f(result, *ca, **ckw)
 
 def spawn(func, *args, **kwargs):
-    """ Create a coroutine to run func(*args, **kwargs) without any
-    way to retrieve the results.  Returns the greenlet object.
+    """Create a green thread to run func(*args, **kwargs).  Returns a GreenThread
+    object which you can use to get the results of the call.
     """
-    g = ResultGreenlet()
-    hub = get_hub()
-    g.parent = hub.greenlet
+    hub = get_hub_()
+    g = GreenThread(hub.greenlet)
     hub.schedule_call_global(0, g.switch, func, args, kwargs)
     return g
+
+def _main_wrapper(func, args, kwargs):
+    # function that gets around the fact that greenlet.switch
+    # doesn't accept keyword arguments
+    return func(*args, **kwargs)
+
+def spawn_n(func, *args, **kwargs):
+    """Same as spawn, but returns a greenlet object from which it is not possible
+    to retrieve the results.  This is slightly faster than spawn; it is fastest
+    if there are no keyword arguments."""
+    hub = get_hub_()
+    if kwargs:
+        g = Greenlet(_main_wrapper, parent=hub.greenlet)
+        hub.schedule_call_global(0, g.switch, func, args, kwargs)
+    else:
+        g = Greenlet(func, parent=hub.greenlet)
+        hub.schedule_call_global(0, g.switch, *args)
+    return g
+
 def kill(g, *throw_args):
     get_hub_().schedule_call_global(0, g.throw, *throw_args)
diff --git a/eventlet/parallel.py b/eventlet/parallel.py
index 2247e92..180177e 100644
--- a/eventlet/parallel.py
+++ b/eventlet/parallel.py
@@ -1,4 +1,4 @@
-from eventlet.coros import Semaphore, Queue
+from eventlet.coros import Semaphore, Queue, Event
 from eventlet.api import spawn, getcurrent
 import sys
 
@@ -11,8 +11,9 @@ class Parallel(object):
         self.max_size = max_size
         self.coroutines_running = set()
         self.sem = Semaphore(max_size)
+        self.no_coros_running = Event()
         self._results = Queue()
-    
+
     def resize(self, new_max_size):
         """ Change the max number of coroutines doing work at any given time.
@@ -36,19 +37,27 @@ class Parallel(object):
         """ Returns the number of coroutines available for use."""
         return self.sem.counter
 
-    def _coro_done(self, coro, result, exc=None):
-        self.sem.release()
-        self.coroutines_running.remove(coro)
-        self._results.send(result)
-        # if done processing (no more work is being done),
-        # send StopIteration so that the queue knows it's done
-        if self.sem.balance == self.max_size:
-            self._results.send_exception(StopIteration)
-
     def spawn(self, func, *args, **kwargs):
-        """ Create a coroutine to run func(*args, **kwargs).  Returns a
-        Coro object that can be used to retrieve the results of the function.
+        """Run func(*args, **kwargs) in its own green thread.
         """
+        return self._spawn(False, func, *args, **kwargs)
+
+    def spawn_q(self, func, *args, **kwargs):
+        """Run func(*args, **kwargs) in its own green thread.
+
+        The results of func are put into the results() iterator.
+        """
+        self._spawn(True, func, *args, **kwargs)
+
+    def spawn_n(self, func, *args, **kwargs):
+        """Run func(*args, **kwargs) in its own green thread.
+
+        Returns None; the results of the function are not retrievable.
+        The results are not put into the results() iterator.
+        """
+        self._spawn(False, func, *args, **kwargs)
+
+    def _spawn(self, send_result, func, *args, **kwargs):
         # if reentering an empty pool, don't try to wait on a coroutine freeing
         # itself -- instead, just execute in the current coroutine
         current = getcurrent()
@@ -57,11 +66,28 @@ class Parallel(object):
         else:
             self.sem.acquire()
             p = spawn(func, *args, **kwargs)
+            if not self.coroutines_running:
+                self.no_coros_running = Event()
             self.coroutines_running.add(p)
-            p.link(self._coro_done)
-
+            p.link(self._spawn_done, send_result=send_result, coro=p)
         return p
-
+
+    def waitall(self):
+        """Waits until all coroutines in the pool are finished working."""
+        self.no_coros_running.wait()
+
+    def _spawn_done(self, result=None, exc=None, send_result=False, coro=None):
+        self.sem.release()
+        self.coroutines_running.remove(coro)
+        if send_result:
+            self._results.send(result)
+        # if done processing (no more work is waiting for processing),
+        # send StopIteration so that the queue knows it's done
+        if self.sem.balance == self.max_size:
+            if send_result:
+                self._results.send_exception(StopIteration)
+            self.no_coros_running.send(None)
+
     def wait(self):
         """Wait for the next execute in the pool to complete, and return
         the result."""
@@ -73,10 +99,11 @@ def _do_spawn_all(self, func, iterable):
         for i in iterable:
-            if not isinstance(i, tuple):
-                self.spawn(func, i)
+            # if the list is composed of single arguments, use those
+            if not isinstance(i, (tuple, list)):
+                self.spawn_q(func, i)
             else:
-                self.spawn(func, *i)
+                self.spawn_q(func, *i)
 
     def spawn_all(self, func, iterable):
         """ Applies *func* over every item in *iterable* using the concurrency
diff --git a/tests/parallel_test.py b/tests/parallel_test.py
index 0096288..497bcf3 100644
--- a/tests/parallel_test.py
+++ b/tests/parallel_test.py
@@ -17,7 +17,7 @@ class Parallel(unittest.TestCase):
     def test_parallel(self):
         p = parallel.Parallel(4)
         for i in xrange(10):
-            p.spawn(passthru, i)
+            p.spawn_q(passthru, i)
         result_list = list(p.results())
         self.assertEquals(result_list, range(10))
 
@@ -25,3 +25,15 @@
         p = parallel.Parallel(4)
         result_list = list(p.spawn_all(passthru, xrange(10)))
         self.assertEquals(result_list, range(10))
+
+    def test_spawn_n(self):
+        p = parallel.Parallel(4)
+        results_closure = []
+        def do_something(a):
+            api.sleep(0.01)
+            results_closure.append(a)
+        for i in xrange(10):
+            p.spawn_n(do_something, i)
+        p.waitall()
+        self.assertEquals(results_closure, range(10))
+
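
The following is an illustrative sketch, not part of the patch, of how the new api-level interface above is meant to be used: spawn() returning a GreenThread, link() with curried arguments, and the fire-and-forget spawn_n(). The helper functions add and completed are made up for the example, and it assumes the usual "from eventlet import api" import.

    from eventlet import api

    def add(a, b):
        return a + b

    def completed(result=None, exc=None, tag=None):
        # link() callbacks receive the result (or the exc_info tuple on failure)
        # plus any curried positional/keyword arguments
        print 'green thread %s finished with %r' % (tag, result)

    gt = api.spawn(add, 2, 3)         # returns a GreenThread
    gt.link(completed, tag='adder')   # curried kwarg is passed through to the callback
    assert gt.wait() == 5             # wait() blocks until the function returns

    # spawn_n is slightly faster but gives no way to retrieve the result
    api.spawn_n(add, 2, 3)
    api.sleep(0)                      # yield to the hub so the green thread runs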
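
Likewise, a sketch of driving the reworked Parallel pool with the methods added above: spawn_q() feeding the results() iterator, and spawn_n() plus waitall() for work whose results are discarded. The function double and the list seen are illustrative; the pool size and sleep mirror the tests.

    from eventlet import api, parallel

    pool = parallel.Parallel(4)       # at most 4 green threads work at once

    def double(x):
        api.sleep(0.01)
        return x * 2

    # spawn_q sends each return value to the results() iterator;
    # the iterator ends once the pool has drained
    for i in xrange(10):
        pool.spawn_q(double, i)
    assert sorted(pool.results()) == [i * 2 for i in xrange(10)]

    # spawn_n discards results; waitall() blocks until no coroutines are running
    seen = []
    for i in xrange(10):
        pool.spawn_n(seen.append, i)
    pool.waitall()
    assert sorted(seen) == range(10)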