Initial implementation of hubless zeromq support using ZMQ_FD and ZMQ_EVENTS.
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
"""
|
||||
__zmq__ = __import__('zmq')
|
||||
from eventlet import sleep
|
||||
from eventlet.hubs import trampoline, get_hub
|
||||
from eventlet.hubs import trampoline, _threadlocal
|
||||
|
||||
__patched__ = ['Context', 'Socket']
|
||||
globals().update(dict([(var, getattr(__zmq__, var))
|
||||
@@ -13,13 +13,6 @@ globals().update(dict([(var, getattr(__zmq__, var))
|
||||
]))
|
||||
|
||||
|
||||
def get_hub_name_from_instance(hub):
|
||||
"""Get the string name the eventlet uses to refer to hub
|
||||
|
||||
:param hub: An eventlet hub
|
||||
"""
|
||||
return hub.__class__.__module__.rsplit('.',1)[-1]
|
||||
|
||||
def Context(io_threads=1):
|
||||
"""Factory function replacement for :class:`zmq.core.context.Context`
|
||||
|
||||
@@ -31,11 +24,11 @@ def Context(io_threads=1):
|
||||
instance per thread. This is due to the way :class:`zmq.core.poll.Poller`
|
||||
works
|
||||
"""
|
||||
hub = get_hub()
|
||||
hub_name = get_hub_name_from_instance(hub)
|
||||
if hub_name != 'zeromq':
|
||||
raise RuntimeError("Hub must be 'zeromq', got '%s'" % hub_name)
|
||||
return hub.get_context(io_threads)
|
||||
try:
|
||||
return _threadlocal.context
|
||||
except AttributeError:
|
||||
_threadlocal.context = _Context(io_threads)
|
||||
return _threadlocal.context
|
||||
|
||||
class _Context(__zmq__.Context):
|
||||
"""Internal subclass of :class:`zmq.core.context.Context`
|
||||
@@ -68,62 +61,65 @@ class Socket(__zmq__.Socket):
|
||||
``zmq.EAGAIN`` (retry) error is raised
|
||||
"""
|
||||
|
||||
def _sock_wait(self, read=False, write=False):
|
||||
"""
|
||||
First checks if there are events in the socket, to avoid
|
||||
edge trigger problems with race conditions. Then if there
|
||||
are none it will trampoline and when coming back check
|
||||
for the events.
|
||||
"""
|
||||
events = self.getsockopt(__zmq__.EVENTS)
|
||||
|
||||
def _send_message(self, msg, flags=0):
|
||||
if read and (events & __zmq__.POLLIN):
|
||||
return events
|
||||
elif write and (events & __zmq__.POLLOUT):
|
||||
return events
|
||||
else:
|
||||
# ONLY trampoline on read events for the zmq FD
|
||||
trampoline(self.getsockopt(__zmq__.FD), read=True)
|
||||
return self.getsockopt(__zmq__.EVENTS)
|
||||
|
||||
def send(self, msg, flags=0, copy=True, track=False):
|
||||
"""
|
||||
Override this instead of the internal _send_* methods
|
||||
since those change and it's not clear when/how they're
|
||||
called in real code.
|
||||
"""
|
||||
if flags & __zmq__.NOBLOCK:
|
||||
super(Socket, self)._send_message(msg, flags)
|
||||
super(Socket, self).send(msg, flags=flags, track=track, copy=copy)
|
||||
return
|
||||
|
||||
flags |= __zmq__.NOBLOCK
|
||||
|
||||
while True:
|
||||
try:
|
||||
super(Socket, self)._send_message(msg, flags)
|
||||
self._sock_wait(write=True)
|
||||
super(Socket, self).send(msg, flags=flags, track=track,
|
||||
copy=copy)
|
||||
return
|
||||
except __zmq__.ZMQError, e:
|
||||
if e.errno != EAGAIN:
|
||||
raise
|
||||
trampoline(self, write=True)
|
||||
|
||||
def _send_copy(self, msg, flags=0):
|
||||
def recv(self, flags=0, copy=True, track=False):
|
||||
"""
|
||||
Override this instead of the internal _recv_* methods
|
||||
since those change and it's not clear when/how they're
|
||||
called in real code.
|
||||
"""
|
||||
if flags & __zmq__.NOBLOCK:
|
||||
super(Socket, self)._send_copy(msg, flags)
|
||||
return
|
||||
return super(Socket, self).recv(flags=flags, track=track, copy=copy)
|
||||
|
||||
flags |= __zmq__.NOBLOCK
|
||||
|
||||
while True:
|
||||
try:
|
||||
super(Socket, self)._send_copy(msg, flags)
|
||||
return
|
||||
except __zmq__.ZMQError, e:
|
||||
if e.errno != EAGAIN:
|
||||
raise
|
||||
trampoline(self, write=True)
|
||||
|
||||
def _recv_message(self, flags=0, track=False):
|
||||
if flags & __zmq__.NOBLOCK:
|
||||
return super(Socket, self)._recv_message(flags, track)
|
||||
flags |= __zmq__.NOBLOCK
|
||||
while True:
|
||||
try:
|
||||
m = super(Socket, self)._recv_message(flags, track)
|
||||
self._sock_wait(read=True)
|
||||
m = super(Socket, self).recv(flags=flags, track=track, copy=copy)
|
||||
if m is not None:
|
||||
return m
|
||||
except __zmq__.ZMQError, e:
|
||||
if e.errno != EAGAIN:
|
||||
raise
|
||||
trampoline(self, read=True)
|
||||
|
||||
def _recv_copy(self, flags=0):
|
||||
if flags & __zmq__.NOBLOCK:
|
||||
return super(Socket, self)._recv_copy(flags)
|
||||
flags |= __zmq__.NOBLOCK
|
||||
while True:
|
||||
try:
|
||||
m = super(Socket, self)._recv_copy(flags)
|
||||
if m is not None:
|
||||
return m
|
||||
except __zmq__.ZMQError, e:
|
||||
if e.errno != EAGAIN:
|
||||
raise
|
||||
trampoline(self, read=True)
|
||||
|
||||
|
||||
|
||||
|
@@ -1,110 +0,0 @@
|
||||
from eventlet import patcher
|
||||
from eventlet.green import zmq
|
||||
from eventlet.hubs import _threadlocal
|
||||
from eventlet.hubs.hub import BaseHub, READ, WRITE, noop
|
||||
from eventlet.support import clear_sys_exc_info
|
||||
import sys
|
||||
|
||||
time = patcher.original('time')
|
||||
select = patcher.original('select')
|
||||
sleep = time.sleep
|
||||
|
||||
EXC_MASK = zmq.POLLERR
|
||||
READ_MASK = zmq.POLLIN
|
||||
WRITE_MASK = zmq.POLLOUT
|
||||
|
||||
class Hub(BaseHub):
|
||||
def __init__(self, clock=time.time):
|
||||
BaseHub.__init__(self, clock)
|
||||
self.poll = zmq.Poller()
|
||||
|
||||
def get_context(self, io_threads=1):
|
||||
"""zmq's Context must be unique within a hub
|
||||
|
||||
The zeromq API documentation states:
|
||||
All zmq sockets passed to the zmq_poll() function must share the same
|
||||
zmq context and must belong to the thread calling zmq_poll()
|
||||
|
||||
As zmq_poll is what's eventually being called then we need to insure
|
||||
that all sockets that are going to be passed to zmq_poll (via
|
||||
hub.do_poll) are in the same context
|
||||
"""
|
||||
try:
|
||||
return _threadlocal.context
|
||||
except AttributeError:
|
||||
_threadlocal.context = zmq._Context(io_threads)
|
||||
return _threadlocal.context
|
||||
|
||||
def add(self, evtype, fileno, cb):
|
||||
listener = super(Hub, self).add(evtype, fileno, cb)
|
||||
self.register(fileno, new=True)
|
||||
return listener
|
||||
|
||||
def remove(self, listener):
|
||||
super(Hub, self).remove(listener)
|
||||
self.register(listener.fileno)
|
||||
|
||||
def register(self, fileno, new=False):
|
||||
mask = 0
|
||||
if self.listeners[READ].get(fileno):
|
||||
mask |= READ_MASK
|
||||
if self.listeners[WRITE].get(fileno):
|
||||
mask |= WRITE_MASK
|
||||
if mask:
|
||||
self.poll.register(fileno, mask)
|
||||
else:
|
||||
self.poll.unregister(fileno)
|
||||
|
||||
def remove_descriptor(self, fileno):
|
||||
super(Hub, self).remove_descriptor(fileno)
|
||||
try:
|
||||
self.poll.unregister(fileno)
|
||||
except (KeyError, ValueError, IOError, OSError):
|
||||
# raised if we try to remove a fileno that was
|
||||
# already removed/invalid
|
||||
pass
|
||||
|
||||
def do_poll(self, seconds):
|
||||
# zmq.Poller.poll expects milliseconds
|
||||
return self.poll.poll(seconds * 1000.0)
|
||||
|
||||
def wait(self, seconds=None):
|
||||
readers = self.listeners[READ]
|
||||
writers = self.listeners[WRITE]
|
||||
|
||||
if not readers and not writers:
|
||||
if seconds:
|
||||
sleep(seconds)
|
||||
return
|
||||
try:
|
||||
presult = self.do_poll(seconds)
|
||||
except zmq.ZMQError, e:
|
||||
# In the poll hub this part exists to special case some exceptions
|
||||
# from socket. There may be some error numbers that wider use of
|
||||
# this hub will throw up as needing special treatment so leaving
|
||||
# this block and this comment as a remineder
|
||||
raise
|
||||
SYSTEM_EXCEPTIONS = self.SYSTEM_EXCEPTIONS
|
||||
|
||||
if self.debug_blocking:
|
||||
self.block_detect_pre()
|
||||
|
||||
for fileno, event in presult:
|
||||
try:
|
||||
if event & READ_MASK:
|
||||
readers.get(fileno, noop).cb(fileno)
|
||||
if event & WRITE_MASK:
|
||||
writers.get(fileno, noop).cb(fileno)
|
||||
if event & EXC_MASK:
|
||||
# zmq.POLLERR is returned for any error condition in the
|
||||
# underlying fd (as passed through to poll/epoll)
|
||||
readers.get(fileno, noop).cb(fileno)
|
||||
writers.get(fileno, noop).cb(fileno)
|
||||
except SYSTEM_EXCEPTIONS:
|
||||
raise
|
||||
except:
|
||||
self.squelch_exception(fileno, sys.exc_info())
|
||||
clear_sys_exc_info()
|
||||
|
||||
if self.debug_blocking:
|
||||
self.block_detect_post()
|
31
examples/zmq_simple.py
Normal file
31
examples/zmq_simple.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from eventlet.green import zmq
|
||||
import eventlet
|
||||
|
||||
CTX = zmq.Context(1)
|
||||
|
||||
def bob_client(ctx, count):
|
||||
print "STARTING BOB"
|
||||
bob = zmq.Socket(CTX, zmq.REQ)
|
||||
bob.connect("ipc:///tmp/test")
|
||||
|
||||
for i in range(0, count):
|
||||
print "BOB SENDING"
|
||||
bob.send("HI")
|
||||
print "BOB GOT:", bob.recv()
|
||||
|
||||
def alice_server(ctx, count):
|
||||
print "STARTING ALICE"
|
||||
alice = zmq.Socket(CTX, zmq.REP)
|
||||
alice.bind("ipc:///tmp/test")
|
||||
|
||||
print "ALICE READY"
|
||||
for i in range(0, count):
|
||||
print "ALICE GOT:", alice.recv()
|
||||
print "ALIC SENDING"
|
||||
alice.send("HI BACK")
|
||||
|
||||
alice = eventlet.spawn(alice_server, CTX, 10)
|
||||
bob = eventlet.spawn(bob_client, CTX, 10)
|
||||
|
||||
bob.wait()
|
||||
alice.wait()
|
@@ -94,8 +94,8 @@ def using_zmq(_f):
|
||||
import zmq
|
||||
except ImportError:
|
||||
return False
|
||||
from eventlet.hubs import get_hub
|
||||
return zmq and 'zeromq' in type(get_hub()).__module__
|
||||
|
||||
return True
|
||||
|
||||
def skip_unless_zmq(func):
|
||||
""" Decorator that skips a test if we're not using the zeromq hub."""
|
||||
|
@@ -5,12 +5,7 @@ from tests import mock, LimitedTestCase, skip_unless_zmq
|
||||
from unittest import TestCase
|
||||
|
||||
from threading import Thread
|
||||
try:
|
||||
from eventlet.green import zmq
|
||||
from eventlet.hubs.zeromq import Hub
|
||||
except ImportError:
|
||||
zmq = None
|
||||
Hub = None
|
||||
|
||||
|
||||
class TestUpstreamDownStream(LimitedTestCase):
|
||||
@@ -249,19 +244,15 @@ class TestThreadedContextAccess(TestCase):
|
||||
"""
|
||||
if zmq: # don't call decorators if zmq module unavailable
|
||||
@skip_unless_zmq
|
||||
@mock.patch('eventlet.green.zmq.get_hub_name_from_instance')
|
||||
@mock.patch('eventlet.green.zmq.get_hub', spec=Hub)
|
||||
def test_context_factory_funtion(self, get_hub_mock, hub_name_mock):
|
||||
hub_name_mock.return_value = 'zeromq'
|
||||
def test_context_factory_function(self):
|
||||
ctx = zmq.Context()
|
||||
self.assertTrue(get_hub_mock().get_context.called)
|
||||
self.assertTrue(ctx is not None)
|
||||
|
||||
@skip_unless_zmq
|
||||
def test_threadlocal_context(self):
|
||||
hub = get_hub()
|
||||
context = zmq.Context()
|
||||
self.assertEqual(context, _threadlocal.context)
|
||||
next_context = hub.get_context()
|
||||
next_context = zmq.Context()
|
||||
self.assertTrue(context is next_context)
|
||||
|
||||
@skip_unless_zmq
|
||||
@@ -269,33 +260,15 @@ class TestThreadedContextAccess(TestCase):
|
||||
context = zmq.Context()
|
||||
test_result = []
|
||||
def assert_different(ctx):
|
||||
hub = get_hub()
|
||||
try:
|
||||
this_thread_context = zmq.Context()
|
||||
except:
|
||||
test_result.append('fail')
|
||||
raise
|
||||
test_result.append(ctx is this_thread_context)
|
||||
|
||||
Thread(target=assert_different, args=(context,)).start()
|
||||
while not test_result:
|
||||
sleep(0.1)
|
||||
self.assertFalse(test_result[0])
|
||||
|
||||
|
||||
class TestCheckingForZMQHub(TestCase):
|
||||
|
||||
@skip_unless_zmq
|
||||
def setUp(self):
|
||||
self.orig_hub = zmq.get_hub_name_from_instance(get_hub())
|
||||
use_hub('selects')
|
||||
|
||||
def tearDown(self):
|
||||
use_hub(self.orig_hub)
|
||||
|
||||
def test_assertionerror_raise_by_context(self):
|
||||
self.assertRaises(RuntimeError, zmq.Context)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user