diff --git a/eventlet/green/zmq.py b/eventlet/green/zmq.py index 68b8625..726d2af 100644 --- a/eventlet/green/zmq.py +++ b/eventlet/green/zmq.py @@ -2,7 +2,7 @@ """ __zmq__ = __import__('zmq') from eventlet import sleep -from eventlet.hubs import trampoline, get_hub +from eventlet.hubs import trampoline, _threadlocal __patched__ = ['Context', 'Socket'] globals().update(dict([(var, getattr(__zmq__, var)) @@ -13,13 +13,6 @@ globals().update(dict([(var, getattr(__zmq__, var)) ])) -def get_hub_name_from_instance(hub): - """Get the string name the eventlet uses to refer to hub - - :param hub: An eventlet hub - """ - return hub.__class__.__module__.rsplit('.',1)[-1] - def Context(io_threads=1): """Factory function replacement for :class:`zmq.core.context.Context` @@ -31,11 +24,11 @@ def Context(io_threads=1): instance per thread. This is due to the way :class:`zmq.core.poll.Poller` works """ - hub = get_hub() - hub_name = get_hub_name_from_instance(hub) - if hub_name != 'zeromq': - raise RuntimeError("Hub must be 'zeromq', got '%s'" % hub_name) - return hub.get_context(io_threads) + try: + return _threadlocal.context + except AttributeError: + _threadlocal.context = _Context(io_threads) + return _threadlocal.context class _Context(__zmq__.Context): """Internal subclass of :class:`zmq.core.context.Context` @@ -68,62 +61,65 @@ class Socket(__zmq__.Socket): ``zmq.EAGAIN`` (retry) error is raised """ + def _sock_wait(self, read=False, write=False): + """ + First checks if there are events in the socket, to avoid + edge trigger problems with race conditions. Then if there + are none it will trampoline and when coming back check + for the events. + """ + events = self.getsockopt(__zmq__.EVENTS) - def _send_message(self, msg, flags=0): + if read and (events & __zmq__.POLLIN): + return events + elif write and (events & __zmq__.POLLOUT): + return events + else: + # ONLY trampoline on read events for the zmq FD + trampoline(self.getsockopt(__zmq__.FD), read=True) + return self.getsockopt(__zmq__.EVENTS) + + def send(self, msg, flags=0, copy=True, track=False): + """ + Override this instead of the internal _send_* methods + since those change and it's not clear when/how they're + called in real code. + """ if flags & __zmq__.NOBLOCK: - super(Socket, self)._send_message(msg, flags) + super(Socket, self).send(msg, flags=flags, track=track, copy=copy) return + flags |= __zmq__.NOBLOCK + while True: try: - super(Socket, self)._send_message(msg, flags) + self._sock_wait(write=True) + super(Socket, self).send(msg, flags=flags, track=track, + copy=copy) return except __zmq__.ZMQError, e: if e.errno != EAGAIN: raise - trampoline(self, write=True) - def _send_copy(self, msg, flags=0): + def recv(self, flags=0, copy=True, track=False): + """ + Override this instead of the internal _recv_* methods + since those change and it's not clear when/how they're + called in real code. + """ if flags & __zmq__.NOBLOCK: - super(Socket, self)._send_copy(msg, flags) - return + return super(Socket, self).recv(flags=flags, track=track, copy=copy) + flags |= __zmq__.NOBLOCK + while True: try: - super(Socket, self)._send_copy(msg, flags) - return - except __zmq__.ZMQError, e: - if e.errno != EAGAIN: - raise - trampoline(self, write=True) - - def _recv_message(self, flags=0, track=False): - if flags & __zmq__.NOBLOCK: - return super(Socket, self)._recv_message(flags, track) - flags |= __zmq__.NOBLOCK - while True: - try: - m = super(Socket, self)._recv_message(flags, track) + self._sock_wait(read=True) + m = super(Socket, self).recv(flags=flags, track=track, copy=copy) if m is not None: return m except __zmq__.ZMQError, e: if e.errno != EAGAIN: raise - trampoline(self, read=True) - - def _recv_copy(self, flags=0): - if flags & __zmq__.NOBLOCK: - return super(Socket, self)._recv_copy(flags) - flags |= __zmq__.NOBLOCK - while True: - try: - m = super(Socket, self)._recv_copy(flags) - if m is not None: - return m - except __zmq__.ZMQError, e: - if e.errno != EAGAIN: - raise - trampoline(self, read=True) - diff --git a/eventlet/hubs/zeromq.py b/eventlet/hubs/zeromq.py deleted file mode 100644 index 686c974..0000000 --- a/eventlet/hubs/zeromq.py +++ /dev/null @@ -1,110 +0,0 @@ -from eventlet import patcher -from eventlet.green import zmq -from eventlet.hubs import _threadlocal -from eventlet.hubs.hub import BaseHub, READ, WRITE, noop -from eventlet.support import clear_sys_exc_info -import sys - -time = patcher.original('time') -select = patcher.original('select') -sleep = time.sleep - -EXC_MASK = zmq.POLLERR -READ_MASK = zmq.POLLIN -WRITE_MASK = zmq.POLLOUT - -class Hub(BaseHub): - def __init__(self, clock=time.time): - BaseHub.__init__(self, clock) - self.poll = zmq.Poller() - - def get_context(self, io_threads=1): - """zmq's Context must be unique within a hub - - The zeromq API documentation states: - All zmq sockets passed to the zmq_poll() function must share the same - zmq context and must belong to the thread calling zmq_poll() - - As zmq_poll is what's eventually being called then we need to insure - that all sockets that are going to be passed to zmq_poll (via - hub.do_poll) are in the same context - """ - try: - return _threadlocal.context - except AttributeError: - _threadlocal.context = zmq._Context(io_threads) - return _threadlocal.context - - def add(self, evtype, fileno, cb): - listener = super(Hub, self).add(evtype, fileno, cb) - self.register(fileno, new=True) - return listener - - def remove(self, listener): - super(Hub, self).remove(listener) - self.register(listener.fileno) - - def register(self, fileno, new=False): - mask = 0 - if self.listeners[READ].get(fileno): - mask |= READ_MASK - if self.listeners[WRITE].get(fileno): - mask |= WRITE_MASK - if mask: - self.poll.register(fileno, mask) - else: - self.poll.unregister(fileno) - - def remove_descriptor(self, fileno): - super(Hub, self).remove_descriptor(fileno) - try: - self.poll.unregister(fileno) - except (KeyError, ValueError, IOError, OSError): - # raised if we try to remove a fileno that was - # already removed/invalid - pass - - def do_poll(self, seconds): - # zmq.Poller.poll expects milliseconds - return self.poll.poll(seconds * 1000.0) - - def wait(self, seconds=None): - readers = self.listeners[READ] - writers = self.listeners[WRITE] - - if not readers and not writers: - if seconds: - sleep(seconds) - return - try: - presult = self.do_poll(seconds) - except zmq.ZMQError, e: - # In the poll hub this part exists to special case some exceptions - # from socket. There may be some error numbers that wider use of - # this hub will throw up as needing special treatment so leaving - # this block and this comment as a remineder - raise - SYSTEM_EXCEPTIONS = self.SYSTEM_EXCEPTIONS - - if self.debug_blocking: - self.block_detect_pre() - - for fileno, event in presult: - try: - if event & READ_MASK: - readers.get(fileno, noop).cb(fileno) - if event & WRITE_MASK: - writers.get(fileno, noop).cb(fileno) - if event & EXC_MASK: - # zmq.POLLERR is returned for any error condition in the - # underlying fd (as passed through to poll/epoll) - readers.get(fileno, noop).cb(fileno) - writers.get(fileno, noop).cb(fileno) - except SYSTEM_EXCEPTIONS: - raise - except: - self.squelch_exception(fileno, sys.exc_info()) - clear_sys_exc_info() - - if self.debug_blocking: - self.block_detect_post() diff --git a/examples/zmq_simple.py b/examples/zmq_simple.py new file mode 100644 index 0000000..ebc3e71 --- /dev/null +++ b/examples/zmq_simple.py @@ -0,0 +1,31 @@ +from eventlet.green import zmq +import eventlet + +CTX = zmq.Context(1) + +def bob_client(ctx, count): + print "STARTING BOB" + bob = zmq.Socket(CTX, zmq.REQ) + bob.connect("ipc:///tmp/test") + + for i in range(0, count): + print "BOB SENDING" + bob.send("HI") + print "BOB GOT:", bob.recv() + +def alice_server(ctx, count): + print "STARTING ALICE" + alice = zmq.Socket(CTX, zmq.REP) + alice.bind("ipc:///tmp/test") + + print "ALICE READY" + for i in range(0, count): + print "ALICE GOT:", alice.recv() + print "ALIC SENDING" + alice.send("HI BACK") + +alice = eventlet.spawn(alice_server, CTX, 10) +bob = eventlet.spawn(bob_client, CTX, 10) + +bob.wait() +alice.wait() diff --git a/tests/__init__.py b/tests/__init__.py index 789d132..26b79fa 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -94,8 +94,8 @@ def using_zmq(_f): import zmq except ImportError: return False - from eventlet.hubs import get_hub - return zmq and 'zeromq' in type(get_hub()).__module__ + + return True def skip_unless_zmq(func): """ Decorator that skips a test if we're not using the zeromq hub.""" diff --git a/tests/zmq_test.py b/tests/zmq_test.py index 4ca53c6..0eadcb1 100644 --- a/tests/zmq_test.py +++ b/tests/zmq_test.py @@ -5,12 +5,7 @@ from tests import mock, LimitedTestCase, skip_unless_zmq from unittest import TestCase from threading import Thread -try: - from eventlet.green import zmq - from eventlet.hubs.zeromq import Hub -except ImportError: - zmq = None - Hub = None +from eventlet.green import zmq class TestUpstreamDownStream(LimitedTestCase): @@ -47,8 +42,8 @@ got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno))) @skip_unless_zmq def test_recv_spawned_before_send_is_non_blocking(self): req, rep, port = self.create_bound_pair(zmq.PAIR, zmq.PAIR) -# req.connect(ipc) -# rep.bind(ipc) +# req.connect(ipc) +# rep.bind(ipc) sleep() msg = dict(res=None) done = event.Event() @@ -249,19 +244,15 @@ class TestThreadedContextAccess(TestCase): """ if zmq: # don't call decorators if zmq module unavailable @skip_unless_zmq - @mock.patch('eventlet.green.zmq.get_hub_name_from_instance') - @mock.patch('eventlet.green.zmq.get_hub', spec=Hub) - def test_context_factory_funtion(self, get_hub_mock, hub_name_mock): - hub_name_mock.return_value = 'zeromq' + def test_context_factory_function(self): ctx = zmq.Context() - self.assertTrue(get_hub_mock().get_context.called) + self.assertTrue(ctx is not None) @skip_unless_zmq def test_threadlocal_context(self): - hub = get_hub() context = zmq.Context() self.assertEqual(context, _threadlocal.context) - next_context = hub.get_context() + next_context = zmq.Context() self.assertTrue(context is next_context) @skip_unless_zmq @@ -269,33 +260,15 @@ class TestThreadedContextAccess(TestCase): context = zmq.Context() test_result = [] def assert_different(ctx): - hub = get_hub() try: this_thread_context = zmq.Context() except: test_result.append('fail') raise test_result.append(ctx is this_thread_context) + Thread(target=assert_different, args=(context,)).start() while not test_result: sleep(0.1) self.assertFalse(test_result[0]) - -class TestCheckingForZMQHub(TestCase): - - @skip_unless_zmq - def setUp(self): - self.orig_hub = zmq.get_hub_name_from_instance(get_hub()) - use_hub('selects') - - def tearDown(self): - use_hub(self.orig_hub) - - def test_assertionerror_raise_by_context(self): - self.assertRaises(RuntimeError, zmq.Context) - - - - -