diff --git a/eventlet/db_pool.py b/eventlet/db_pool.py index 9874c2f..e68513f 100644 --- a/eventlet/db_pool.py +++ b/eventlet/db_pool.py @@ -4,7 +4,9 @@ import time from eventlet.pools import Pool from eventlet import timeout -from eventlet import greenthread +from eventlet import hubs +from eventlet.hubs.timer import Timer +from eventlet.greenthread import GreenThread class ConnectTimeout(Exception): @@ -88,8 +90,9 @@ class BaseConnectionPool(Pool): if next_delay > 0: # set up a continuous self-calling loop - self._expiration_timer = greenthread.spawn_after(next_delay, - self._schedule_expiration) + self._expiration_timer = Timer(next_delay, GreenThread(hubs.get_hub().greenlet).switch, + self._schedule_expiration, [], {}) + self._expiration_timer.schedule() def _expire_old_connections(self, now): """ Iterates through the open connections contained in the pool, closing @@ -103,8 +106,6 @@ class BaseConnectionPool(Pool): conn for last_used, created_at, conn in self.free_items if self._is_expired(now, last_used, created_at)] - for conn in expired: - self._safe_close(conn, quiet=True) new_free = [ (last_used, created_at, conn) @@ -117,6 +118,9 @@ class BaseConnectionPool(Pool): # connections self.current_size -= original_count - len(self.free_items) + for conn in expired: + self._safe_close(conn, quiet=True) + def _is_expired(self, now, last_used, created_at): """ Returns true and closes the connection if it's expired.""" if ( self.max_idle <= 0 @@ -228,7 +232,9 @@ class BaseConnectionPool(Pool): if self._expiration_timer: self._expiration_timer.cancel() free_items, self.free_items = self.free_items, deque() - for _last_used, _created_at, conn in free_items: + for item in free_items: + # Free items created using min_size>0 are not tuples. + conn = item[2] if isinstance(item, tuple) else item self._safe_close(conn, quiet=True) def __del__(self): @@ -296,6 +302,7 @@ class GenericConnectionWrapper(object): def errno(self,*args, **kwargs): return self._base.errno(*args, **kwargs) def error(self,*args, **kwargs): return self._base.error(*args, **kwargs) def errorhandler(self, *args, **kwargs): return self._base.errorhandler(*args, **kwargs) + def insert_id(self, *args, **kwargs): return self._base.insert_id(*args, **kwargs) def literal(self, *args, **kwargs): return self._base.literal(*args, **kwargs) def set_character_set(self, *args, **kwargs): return self._base.set_character_set(*args, **kwargs) def set_sql_mode(self, *args, **kwargs): return self._base.set_sql_mode(*args, **kwargs) diff --git a/eventlet/green/zmq.py b/eventlet/green/zmq.py index 187e0c4..5f379e7 100644 --- a/eventlet/green/zmq.py +++ b/eventlet/green/zmq.py @@ -1,36 +1,103 @@ """The :mod:`zmq` module wraps the :class:`Socket` and :class:`Context` found in :mod:`pyzmq ` to be non blocking """ __zmq__ = __import__('zmq') -from eventlet import sleep -from eventlet.hubs import trampoline, _threadlocal +from eventlet import hubs from eventlet.patcher import slurp_properties +from eventlet.support import greenlets as greenlet __patched__ = ['Context', 'Socket'] slurp_properties(__zmq__, globals(), ignore=__patched__) +from collections import deque -def Context(io_threads=1): - """Factory function replacement for :class:`zmq.core.context.Context` +class _QueueLock(object): + """A Lock that can be acquired by at most one thread. Any other + thread calling acquire will be blocked in a queue. When release + is called, the threads are awoken in the order they blocked, + one at a time. This lock can be required recursively by the same + thread.""" + def __init__(self): + self._waiters = deque() + self._count = 0 + self._holder = None + self._hub = hubs.get_hub() - This factory ensures the :class:`zeromq hub ` - is the active hub, and defers creation (or retreival) of the ``Context`` - to the hub's :meth:`~eventlet.hubs.zeromq.Hub.get_context` method - - It's a factory function due to the fact that there can only be one :class:`_Context` - instance per thread. This is due to the way :class:`zmq.core.poll.Poller` - works - """ - try: - return _threadlocal.context - except AttributeError: - _threadlocal.context = _Context(io_threads) - return _threadlocal.context + def __nonzero__(self): + return self._count -class _Context(__zmq__.Context): - """Internal subclass of :class:`zmq.core.context.Context` + def __enter__(self): + self.acquire() + + def __exit__(self, type, value, traceback): + self.release() - .. warning:: Do not grab one of these yourself, use the factory function - :func:`eventlet.green.zmq.Context` + def acquire(self): + current = greenlet.getcurrent() + if (self._waiters or self._count > 0) and self._holder is not current: + # block until lock is free + self._waiters.append(current) + self._hub.switch() + w = self._waiters.popleft() + + assert w is current, 'Waiting threads woken out of order' + assert self._count == 0, 'After waking a thread, the lock must be unacquired' + + self._holder = current + self._count += 1 + + def release(self): + if self._count <= 0: + raise Exception("Cannot release unacquired lock") + + self._count -= 1 + if self._count == 0: + self._holder = None + if self._waiters: + # wake next + self._hub.schedule_call_global(0, self._waiters[0].switch) + +class _BlockedThread(object): + """Is either empty, or represents a single blocked thread that + blocked itself by calling the block() method. The thread can be + awoken by calling wake(). Wake() can be called multiple times and + all but the first call will have no effect.""" + + def __init__(self): + self._blocked_thread = None + self._wakeupper = None + self._hub = hubs.get_hub() + + def __nonzero__(self): + return self._blocked_thread is not None + + def block(self): + if self._blocked_thread is not None: + raise Exception("Cannot block more than one thread on one BlockedThread") + self._blocked_thread = greenlet.getcurrent() + + try: + self._hub.switch() + finally: + self._blocked_thread = None + # cleanup the wakeup task + if self._wakeupper is not None: + # Important to cancel the wakeup task so it doesn't + # spuriously wake this greenthread later on. + self._wakeupper.cancel() + self._wakeupper = None + + def wake(self): + """Schedules the blocked thread to be awoken and return + True. If wake has already been called or if there is no + blocked thread, then this call has no effect and returns + False.""" + if self._blocked_thread is not None and self._wakeupper is None: + self._wakeupper = self._hub.schedule_call_global(0, self._blocked_thread.switch) + return True + return False + +class Context(__zmq__.Context): + """Subclass of :class:`zmq.core.context.Context` """ def socket(self, socket_type): @@ -40,82 +107,216 @@ class _Context(__zmq__.Context): that a :class:`Socket` with all of its send and recv methods set to be non-blocking is returned """ + if self.closed: + raise ZMQError(ENOTSUP) return Socket(self, socket_type) -class Socket(__zmq__.Socket): +def _wraps(source_fn): + """A decorator that copies the __name__ and __doc__ from the given + function + """ + def wrapper(dest_fn): + dest_fn.__name__ = source_fn.__name__ + dest_fn.__doc__ = source_fn.__doc__ + return dest_fn + return wrapper + +# Implementation notes: Each socket in 0mq contains a pipe that the +# background IO threads use to communicate with the socket. These +# events are important because they tell the socket when it is able to +# send and when it has messages waiting to be received. The read end +# of the events pipe is the same FD that getsockopt(zmq.FD) returns. +# +# Events are read from the socket's event pipe only on the thread that +# the 0mq context is associated with, which is the native thread the +# greenthreads are running on, and the only operations that cause the +# events to be read and processed are send(), recv() and +# getsockopt(zmq.EVENTS). This means that after doing any of these +# three operations, the ability of the socket to send or receive a +# message without blocking may have changed, but after the events are +# read the FD is no longer readable so the hub may not signal our +# listener. +# +# If we understand that after calling send() a message might be ready +# to be received and that after calling recv() a message might be able +# to be sent, what should we do next? There are two approaches: +# +# 1. Always wake the other thread if there is one waiting. This +# wakeup may be spurious because the socket might not actually be +# ready for a send() or recv(). However, if a thread is in a +# tight-loop successfully calling send() or recv() then the wakeups +# are naturally batched and there's very little cost added to each +# send/recv call. +# +# or +# +# 2. Call getsockopt(zmq.EVENTS) and explicitly check if the other +# thread should be woken up. This avoids spurious wake-ups but may +# add overhead because getsockopt will cause all events to be +# processed, whereas send and recv throttle processing +# events. Admittedly, all of the events will need to be processed +# eventually, but it is likely faster to batch the processing. +# +# Which approach is better? I have no idea. +# +# TODO: +# - Support MessageTrackers and make MessageTracker.wait green + +_Socket = __zmq__.Socket +_Socket_recv = _Socket.recv +_Socket_send = _Socket.send +_Socket_send_multipart = _Socket.send_multipart +_Socket_recv_multipart = _Socket.recv_multipart +_Socket_getsockopt = _Socket.getsockopt + +class Socket(_Socket): """Green version of :class:`zmq.core.socket.Socket - The following four methods are overridden: - - * _send_message - * _send_copy - * _recv_message - * _recv_copy - + The following three methods are always overridden: + * send + * recv + * getsockopt To ensure that the ``zmq.NOBLOCK`` flag is set and that sending or recieving is deferred to the hub (using :func:`eventlet.hubs.trampoline`) if a ``zmq.EAGAIN`` (retry) error is raised + + For some socket types, the following methods are also overridden: + * send_multipart + * recv_multipart """ + def __init__(self, context, socket_type): + super(Socket, self).__init__(context, socket_type) - def _sock_wait(self, read=False, write=False): - """ - First checks if there are events in the socket, to avoid - edge trigger problems with race conditions. Then if there - are none it will trampoline and when coming back check - for the events. - """ - events = self.getsockopt(__zmq__.EVENTS) + self._eventlet_send_event = _BlockedThread() + self._eventlet_recv_event = _BlockedThread() + self._eventlet_send_lock = _QueueLock() + self._eventlet_recv_lock = _QueueLock() - if read and (events & __zmq__.POLLIN): - return events - elif write and (events & __zmq__.POLLOUT): - return events - else: - # ONLY trampoline on read events for the zmq FD - trampoline(self.getsockopt(__zmq__.FD), read=True) - return self.getsockopt(__zmq__.EVENTS) + def event(fd): + # Some events arrived at the zmq socket. This may mean + # there's a message that can be read or there's space for + # a message to be written. + self._eventlet_send_event.wake() + self._eventlet_recv_event.wake() + hub = hubs.get_hub() + self._eventlet_listener = hub.add(hub.READ, self.getsockopt(FD), event) + + @_wraps(_Socket.close) + def close(self): + _Socket.close(self) + if self._eventlet_listener is not None: + hubs.get_hub().remove(self._eventlet_listener) + self._eventlet_listener = None + # wake any blocked threads + self._eventlet_send_event.wake() + self._eventlet_recv_event.wake() + + @_wraps(_Socket.getsockopt) + def getsockopt(self, option): + result = _Socket_getsockopt(self, option) + if option == EVENTS: + # Getting the events causes the zmq socket to process + # events which may mean a msg can be sent or received. If + # there is a greenthread blocked and waiting for events, + # it will miss the edge-triggered read event, so wake it + # up. + if (result & POLLOUT): + self._send_evt.wake() + if (result & POLLIN): + self._recv_evt.wake() + return result + + @_wraps(_Socket.send) def send(self, msg, flags=0, copy=True, track=False): + """A send method that's safe to use when multiple greenthreads + are calling send, send_multipart, recv and recv_multipart on + the same socket. """ - Override this instead of the internal _send_* methods - since those change and it's not clear when/how they're - called in real code. + if flags & NOBLOCK: + result = _Socket_send(self, msg, flags, copy, track) + # Instead of calling both wake methods, could call + # self.getsockopt(EVENTS) which would trigger wakeups if + # needed. + self._eventlet_send_event.wake() + self._eventlet_recv_event.wake() + return result + + # TODO: pyzmq will copy the message buffer and create Message + # objects under some circumstances. We could do that work here + # once to avoid doing it every time the send is retried. + flags |= NOBLOCK + with self._eventlet_send_lock: + while True: + try: + return _Socket_send(self, msg, flags, copy, track) + except ZMQError, e: + if e.errno == EAGAIN: + self._eventlet_send_event.block() + else: + raise + finally: + # The call to send processes 0mq events and may + # make the socket ready to recv. Wake the next + # receiver. (Could check EVENTS for POLLIN here) + self._eventlet_recv_event.wake() + + + @_wraps(_Socket.send_multipart) + def send_multipart(self, msg_parts, flags=0, copy=True, track=False): + """A send_multipart method that's safe to use when multiple + greenthreads are calling send, send_multipart, recv and + recv_multipart on the same socket. """ - if flags & __zmq__.NOBLOCK: - super(Socket, self).send(msg, flags=flags, track=track, copy=copy) - return + if flags & NOBLOCK: + return _Socket_send_multipart(self, msg_parts, flags, copy, track) - flags |= __zmq__.NOBLOCK - - while True: - try: - self._sock_wait(write=True) - super(Socket, self).send(msg, flags=flags, track=track, - copy=copy) - return - except __zmq__.ZMQError, e: - if e.errno != EAGAIN: - raise + # acquire lock here so the subsequent calls to send for the + # message parts after the first don't block + with self._eventlet_send_lock: + return _Socket_send_multipart(self, msg_parts, flags, copy, track) + @_wraps(_Socket.recv) def recv(self, flags=0, copy=True, track=False): + """A recv method that's safe to use when multiple greenthreads + are calling send, send_multipart, recv and recv_multipart on + the same socket. """ - Override this instead of the internal _recv_* methods - since those change and it's not clear when/how they're - called in real code. + if flags & NOBLOCK: + msg = _Socket_recv(self, flags, copy, track) + # Instead of calling both wake methods, could call + # self.getsockopt(EVENTS) which would trigger wakeups if + # needed. + self._eventlet_send_event.wake() + self._eventlet_recv_event.wake() + return msg + + flags |= NOBLOCK + with self._eventlet_recv_lock: + while True: + try: + return _Socket_recv(self, flags, copy, track) + except ZMQError, e: + if e.errno == EAGAIN: + self._eventlet_recv_event.block() + else: + raise + finally: + # The call to recv processes 0mq events and may + # make the socket ready to send. Wake the next + # receiver. (Could check EVENTS for POLLOUT here) + self._eventlet_send_event.wake() + + @_wraps(_Socket.recv_multipart) + def recv_multipart(self, flags=0, copy=True, track=False): + """A recv_multipart method that's safe to use when multiple + greenthreads are calling send, send_multipart, recv and + recv_multipart on the same socket. """ - if flags & __zmq__.NOBLOCK: - return super(Socket, self).recv(flags=flags, track=track, copy=copy) - - flags |= __zmq__.NOBLOCK - - while True: - try: - self._sock_wait(read=True) - m = super(Socket, self).recv(flags=flags, track=track, copy=copy) - if m is not None: - return m - except __zmq__.ZMQError, e: - if e.errno != EAGAIN: - raise - + if flags & NOBLOCK: + return _Socket_recv_multipart(self, flags, copy, track) + # acquire lock here so the subsequent calls to recv for the + # message parts after the first don't block + with self._eventlet_recv_lock: + return _Socket_recv_multipart(self, flags, copy, track) diff --git a/eventlet/wsgi.py b/eventlet/wsgi.py index 6003664..3e04cb9 100644 --- a/eventlet/wsgi.py +++ b/eventlet/wsgi.py @@ -429,14 +429,16 @@ class HttpProtocol(BaseHTTPServer.BaseHTTPRequestHandler): for hook, args, kwargs in self.environ['eventlet.posthooks']: hook(self.environ, *args, **kwargs) - - self.server.log_message(self.server.log_format % dict( - client_ip=self.get_client_ip(), - date_time=self.log_date_time_string(), - request_line=self.requestline, - status_code=status_code[0], - body_length=length[0], - wall_seconds=finish - start)) + + if self.server.log_output: + + self.server.log_message(self.server.log_format % dict( + client_ip=self.get_client_ip(), + date_time=self.log_date_time_string(), + request_line=self.requestline, + status_code=status_code[0], + body_length=length[0], + wall_seconds=finish - start)) def get_client_ip(self): client_ip = self.client_address[0] @@ -517,6 +519,7 @@ class Server(BaseHTTPServer.HTTPServer): minimum_chunk_size=None, log_x_forwarded_for=True, keepalive=True, + log_output=True, log_format=DEFAULT_LOG_FORMAT, debug=True): @@ -536,6 +539,7 @@ class Server(BaseHTTPServer.HTTPServer): if minimum_chunk_size is not None: protocol.minimum_chunk_size = minimum_chunk_size self.log_x_forwarded_for = log_x_forwarded_for + self.log_output = log_output self.log_format = log_format self.debug = debug @@ -582,7 +586,8 @@ def server(sock, site, minimum_chunk_size=None, log_x_forwarded_for=True, custom_pool=None, - keepalive=True, + keepalive=True, + log_output=True, log_format=DEFAULT_LOG_FORMAT, debug=True): """ Start up a wsgi server handling requests from the supplied server @@ -602,6 +607,7 @@ def server(sock, site, :param log_x_forwarded_for: If True (the default), logs the contents of the x-forwarded-for header in addition to the actual client ip address in the 'client_ip' field of the log line. :param custom_pool: A custom GreenPool instance which is used to spawn client green threads. If this is supplied, max_size is ignored. :param keepalive: If set to False, disables keepalives on the server; all connections will be closed after serving one request. + :param log_output: A Boolean indicating if the server will log data or not. :param log_format: A python format string that is used as the template to generate log lines. The following values can be formatted into it: client_ip, date_time, request_line, status_code, body_length, wall_seconds. The default is a good example of how to use it. :param debug: True if the server should send exception tracebacks to the clients on 500 errors. If False, the server will respond with empty bodies. """ @@ -613,6 +619,7 @@ def server(sock, site, minimum_chunk_size=minimum_chunk_size, log_x_forwarded_for=log_x_forwarded_for, keepalive=keepalive, + log_output=log_output, log_format=log_format, debug=debug) if server_event is not None: diff --git a/tests/db_pool_test.py b/tests/db_pool_test.py index ed6f09b..2a316b0 100644 --- a/tests/db_pool_test.py +++ b/tests/db_pool_test.py @@ -247,6 +247,12 @@ class DBConnectionPool(DBTester): self.pool.clear() self.assertEqual(len(self.pool.free_items), 0) + def test_clear_warmup(self): + """Clear implicitly created connections (min_size > 0)""" + self.pool = self.create_pool(min_size=1) + self.pool.clear() + self.assertEqual(len(self.pool.free_items), 0) + def test_unwrap_connection(self): self.assert_(isinstance(self.connection, db_pool.GenericConnectionWrapper)) @@ -438,12 +444,12 @@ class RaisingDBModule(object): class TpoolConnectionPool(DBConnectionPool): __test__ = False # so that nose doesn't try to execute this directly - def create_pool(self, max_size=1, max_idle=10, max_age=10, + def create_pool(self, min_size=0, max_size=1, max_idle=10, max_age=10, connect_timeout=0.5, module=None): if module is None: module = self._dbmodule return db_pool.TpooledConnectionPool(module, - min_size=0, max_size=max_size, + min_size=min_size, max_size=max_size, max_idle=max_idle, max_age=max_age, connect_timeout = connect_timeout, **self._auth) @@ -462,12 +468,12 @@ class TpoolConnectionPool(DBConnectionPool): class RawConnectionPool(DBConnectionPool): __test__ = False # so that nose doesn't try to execute this directly - def create_pool(self, max_size=1, max_idle=10, max_age=10, + def create_pool(self, min_size=0, max_size=1, max_idle=10, max_age=10, connect_timeout=0.5, module=None): if module is None: module = self._dbmodule return db_pool.RawConnectionPool(module, - min_size=0, max_size=max_size, + min_size=min_size, max_size=max_size, max_idle=max_idle, max_age=max_age, connect_timeout=connect_timeout, **self._auth) diff --git a/tests/zmq_test.py b/tests/zmq_test.py index 26d0bda..107a339 100644 --- a/tests/zmq_test.py +++ b/tests/zmq_test.py @@ -1,4 +1,4 @@ -from eventlet import event, spawn, sleep, patcher +from eventlet import event, spawn, sleep, patcher, semaphore from eventlet.hubs import get_hub, _threadlocal, use_hub from nose.tools import * from tests import mock, LimitedTestCase, using_pyevent, skip_unless @@ -19,7 +19,9 @@ def zmq_supported(_): class TestUpstreamDownStream(LimitedTestCase): - sockets = [] + def setUp(self): + super(TestUpstreamDownStream, self).setUp() + self.sockets = [] def tearDown(self): self.clear_up_sockets() @@ -32,12 +34,14 @@ class TestUpstreamDownStream(LimitedTestCase): port = s1.bind_to_random_port(interface) s2 = context.socket(type2) s2.connect('%s:%s' % (interface, port)) - self.sockets = [s1, s2] + self.sockets.append(s1) + self.sockets.append(s2) return s1, s2, port def clear_up_sockets(self): for sock in self.sockets: sock.close() + self.sockets = None def assertRaisesErrno(self, errno, func, *args): try: @@ -75,6 +79,15 @@ got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno))) self.assertRaisesErrno(zmq.ENOTSUP, rep.recv) self.assertRaisesErrno(zmq.ENOTSUP, req.send, 'test') + @skip_unless(zmq_supported) + def test_close_xsocket_raises_enotsup(self): + req, rep, port = self.create_bound_pair(zmq.XREQ, zmq.XREP) + + rep.close() + req.close() + self.assertRaisesErrno(zmq.ENOTSUP, rep.recv) + self.assertRaisesErrno(zmq.ENOTSUP, req.send, 'test') + @skip_unless(zmq_supported) def test_send_1k_req_rep(self): req, rep, port = self.create_bound_pair(zmq.REQ, zmq.REP) @@ -87,14 +100,13 @@ got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno))) while req.recv() != 'done': tx_i += 1 req.send(str(tx_i)) + done.send(0) def rx(): while True: rx_i = rep.recv() if rx_i == "1000": rep.send('done') - sleep() - done.send(0) break rep.send('i') spawn(tx) @@ -238,46 +250,218 @@ got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno))) self.assertRaisesErrno(zmq.EAGAIN, rep.recv, zmq.NOBLOCK) self.assertRaisesErrno(zmq.EAGAIN, rep.recv, zmq.NOBLOCK, True) + @skip_unless(zmq_supported) + def test_send_during_recv(self): + sender, receiver, port = self.create_bound_pair(zmq.XREQ, zmq.XREQ) + sleep() + + num_recvs = 30 + done_evts = [event.Event() for _ in range(num_recvs)] + + def slow_rx(done, msg): + self.assertEqual(sender.recv(), msg) + done.send(0) + + def tx(): + tx_i = 0 + while tx_i <= 1000: + sender.send(str(tx_i)) + tx_i += 1 + + def rx(): + while True: + rx_i = receiver.recv() + if rx_i == "1000": + for i in range(num_recvs): + receiver.send('done%d' % i) + sleep() + return + + for i in range(num_recvs): + spawn(slow_rx, done_evts[i], "done%d" % i) + + spawn(tx) + spawn(rx) + for evt in done_evts: + self.assertEqual(evt.wait(), 0) -class TestThreadedContextAccess(TestCase): - """zmq's Context must be unique within a hub + @skip_unless(zmq_supported) + def test_send_during_recv_multipart(self): + sender, receiver, port = self.create_bound_pair(zmq.XREQ, zmq.XREQ) + sleep() - The zeromq API documentation states: - All zmq sockets passed to the zmq_poll() function must share the same zmq - context and must belong to the thread calling zmq_poll() + num_recvs = 30 + done_evts = [event.Event() for _ in range(num_recvs)] - As zmq_poll is what's eventually being called then we need to ensure that - all sockets that are going to be passed to zmq_poll (via hub.do_poll) are - in the same context - """ - if zmq: # don't call decorators if zmq module unavailable - @skip_unless(zmq_supported) - def test_context_factory_function(self): - ctx = zmq.Context() - self.assertTrue(ctx is not None) + def slow_rx(done, msg): + self.assertEqual(sender.recv_multipart(), msg) + done.send(0) - @skip_unless(zmq_supported) - def test_threadlocal_context(self): - context = zmq.Context() - self.assertEqual(context, _threadlocal.context) - next_context = zmq.Context() - self.assertTrue(context is next_context) + def tx(): + tx_i = 0 + while tx_i <= 1000: + sender.send_multipart([str(tx_i), '1', '2', '3']) + tx_i += 1 - @skip_unless(zmq_supported) - def test_different_context_in_different_thread(self): - context = zmq.Context() - test_result = [] - def assert_different(ctx): - try: - this_thread_context = zmq.Context() - except: - test_result.append('fail') - raise - test_result.append(ctx is this_thread_context) + def rx(): + while True: + rx_i = receiver.recv_multipart() + if rx_i == ["1000", '1', '2', '3']: + for i in range(num_recvs): + receiver.send_multipart(['done%d' % i, 'a', 'b', 'c']) + sleep() + return - Thread(target=assert_different, args=(context,)).start() - while not test_result: - sleep(0.1) - self.assertFalse(test_result[0]) + for i in range(num_recvs): + spawn(slow_rx, done_evts[i], ["done%d" % i, 'a', 'b', 'c']) + spawn(tx) + spawn(rx) + for i in range(num_recvs): + final_i = done_evts[i].wait() + self.assertEqual(final_i, 0) + + + # Need someway to ensure a thread is blocked on send... This isn't working + @skip_unless(zmq_supported) + def test_recv_during_send(self): + sender, receiver, port = self.create_bound_pair(zmq.XREQ, zmq.XREQ) + sleep() + + num_recvs = 30 + done = event.Event() + + sender.setsockopt(zmq.HWM, 10) + sender.setsockopt(zmq.SNDBUF, 10) + + receiver.setsockopt(zmq.RCVBUF, 10) + + def tx(): + tx_i = 0 + while tx_i <= 1000: + sender.send(str(tx_i)) + tx_i += 1 + done.send(0) + + spawn(tx) + final_i = done.wait() + self.assertEqual(final_i, 0) + + @skip_unless(zmq_supported) + def test_close_during_recv(self): + sender, receiver, port = self.create_bound_pair(zmq.XREQ, zmq.XREQ) + sleep() + done1 = event.Event() + done2 = event.Event() + + def rx(e): + self.assertRaisesErrno(zmq.ENOTSUP, receiver.recv) + e.send() + + spawn(rx, done1) + spawn(rx, done2) + + sleep() + receiver.close() + + done1.wait() + done2.wait() + +class TestQueueLock(LimitedTestCase): + @skip_unless(zmq_supported) + def test_queue_lock_order(self): + q = zmq._QueueLock() + s = semaphore.Semaphore(0) + results = [] + + def lock(x): + with q: + results.append(x) + s.release() + + q.acquire() + + spawn(lock, 1) + sleep() + spawn(lock, 2) + sleep() + spawn(lock, 3) + sleep() + + self.assertEquals(results, []) + q.release() + s.acquire() + s.acquire() + s.acquire() + self.assertEquals(results, [1,2,3]) + + @skip_unless(zmq_supported) + def test_count(self): + q = zmq._QueueLock() + self.assertFalse(q) + q.acquire() + self.assertTrue(q) + q.release() + self.assertFalse(q) + + with q: + self.assertTrue(q) + self.assertFalse(q) + + @skip_unless(zmq_supported) + def test_errors(self): + q = zmq._QueueLock() + + with self.assertRaises(Exception): + q.release() + + q.acquire() + q.release() + + with self.assertRaises(Exception): + q.release() + + @skip_unless(zmq_supported) + def test_nested_acquire(self): + q = zmq._QueueLock() + self.assertFalse(q) + q.acquire() + q.acquire() + + s = semaphore.Semaphore(0) + results = [] + def lock(x): + with q: + results.append(x) + s.release() + + spawn(lock, 1) + sleep() + self.assertEquals(results, []) + q.release() + sleep() + self.assertEquals(results, []) + self.assertTrue(q) + q.release() + + s.acquire() + self.assertEquals(results, [1]) + +class TestBlockedThread(LimitedTestCase): + @skip_unless(zmq_supported) + def test_block(self): + e = zmq._BlockedThread() + done = event.Event() + self.assertFalse(e) + + def block(): + e.block() + done.send(1) + + spawn(block) + sleep() + + self.assertFalse(done.has_result()) + e.wake() + done.wait()