Refactored the green version of zmq to use a factory function for Context
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
__zmq__ = __import__('zmq')
|
__zmq__ = __import__('zmq')
|
||||||
from eventlet import sleep
|
from eventlet import sleep
|
||||||
from eventlet.hubs import trampoline
|
from eventlet.hubs import trampoline, get_hub
|
||||||
|
|
||||||
__patched__ = ['Context', 'Socket']
|
__patched__ = ['Context', 'Socket']
|
||||||
globals().update(dict([(var, getattr(__zmq__, var))
|
globals().update(dict([(var, getattr(__zmq__, var))
|
||||||
@@ -10,7 +10,18 @@ globals().update(dict([(var, getattr(__zmq__, var))
|
|||||||
var in __patched__)
|
var in __patched__)
|
||||||
]))
|
]))
|
||||||
|
|
||||||
class Context(__zmq__.Context):
|
|
||||||
|
def get_hub_name_from_instance(hub):
|
||||||
|
return hub.__class__.__module__.rsplit('.',1)[-1]
|
||||||
|
|
||||||
|
def Context(io_threads=1):
|
||||||
|
hub = get_hub()
|
||||||
|
hub_name = get_hub_name_from_instance(hub)
|
||||||
|
if hub_name != 'zeromq':
|
||||||
|
raise RuntimeError("Hub must be 'zeromq', got '%s'" % hub_name)
|
||||||
|
return hub.get_context(io_threads)
|
||||||
|
|
||||||
|
class _Context(__zmq__.Context):
|
||||||
|
|
||||||
def socket(self, socket_type):
|
def socket(self, socket_type):
|
||||||
return Socket(self, socket_type)
|
return Socket(self, socket_type)
|
||||||
|
@@ -17,12 +17,11 @@ WRITE_MASK = zmq.POLLOUT
|
|||||||
class Hub(poll.Hub):
|
class Hub(poll.Hub):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, clock=time.time):
|
def __init__(self, clock=time.time):
|
||||||
BaseHub.__init__(self, clock)
|
BaseHub.__init__(self, clock)
|
||||||
self.poll = zmq.Poller()
|
self.poll = zmq.Poller()
|
||||||
|
|
||||||
def get_context(self):
|
def get_context(self, io_threads=1):
|
||||||
"""zmq's Context must be unique within a hub
|
"""zmq's Context must be unique within a hub
|
||||||
|
|
||||||
The zeromq API documentation states:
|
The zeromq API documentation states:
|
||||||
@@ -36,7 +35,7 @@ class Hub(poll.Hub):
|
|||||||
try:
|
try:
|
||||||
return _threadlocal.context
|
return _threadlocal.context
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
_threadlocal.context = zmq.Context()
|
_threadlocal.context = zmq._Context(io_threads)
|
||||||
return _threadlocal.context
|
return _threadlocal.context
|
||||||
|
|
||||||
def register(self, fileno, new=False):
|
def register(self, fileno, new=False):
|
||||||
|
@@ -24,8 +24,7 @@ from eventlet.hubs import get_hub, use_hub
|
|||||||
from uuid import uuid1
|
from uuid import uuid1
|
||||||
|
|
||||||
use_hub('zeromq')
|
use_hub('zeromq')
|
||||||
hub = get_hub()
|
ctx = zmq.Context()
|
||||||
ctx = hub.get_context()
|
|
||||||
|
|
||||||
class IDName(object):
|
class IDName(object):
|
||||||
|
|
||||||
|
@@ -1,12 +1,12 @@
|
|||||||
from eventlet import event, spawn, sleep, patcher
|
from eventlet import event, spawn, sleep, patcher
|
||||||
from eventlet.hubs import use_hub, get_hub, _threadlocal
|
from eventlet.hubs import get_hub, _threadlocal, use_hub
|
||||||
from eventlet.hubs.hub import READ, WRITE
|
|
||||||
from eventlet.green import zmq
|
from eventlet.green import zmq
|
||||||
from nose.tools import *
|
from nose.tools import *
|
||||||
from tests import mock, LimitedTestCase, skip_unless
|
from tests import mock, LimitedTestCase, skip_unless
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
from eventlet.hubs.zeromq import Hub
|
||||||
|
|
||||||
def using_zmq(_f):
|
def using_zmq(_f):
|
||||||
return 'zeromq' in type(get_hub()).__module__
|
return 'zeromq' in type(get_hub()).__module__
|
||||||
@@ -17,13 +17,15 @@ def skip_unless_zmq(func):
|
|||||||
|
|
||||||
class TestUpstreamDownStream(LimitedTestCase):
|
class TestUpstreamDownStream(LimitedTestCase):
|
||||||
|
|
||||||
|
sockets = []
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.clear_up_sockets()
|
self.clear_up_sockets()
|
||||||
super(TestUpstreamDownStream, self).tearDown()
|
super(TestUpstreamDownStream, self).tearDown()
|
||||||
|
|
||||||
def create_bound_pair(self, type1, type2, interface='tcp://127.0.0.1'):
|
def create_bound_pair(self, type1, type2, interface='tcp://127.0.0.1'):
|
||||||
"""Create a bound socket pair using a random port."""
|
"""Create a bound socket pair using a random port."""
|
||||||
self.context = context = get_hub().get_context()
|
self.context = context = zmq.Context()
|
||||||
s1 = context.socket(type1)
|
s1 = context.socket(type1)
|
||||||
port = s1.bind_to_random_port(interface)
|
port = s1.bind_to_random_port(interface)
|
||||||
s2 = context.socket(type2)
|
s2 = context.socket(type2)
|
||||||
@@ -212,17 +214,25 @@ class TestThreadedContextAccess(TestCase):
|
|||||||
in the same context
|
in the same context
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@skip_unless_zmq
|
||||||
|
@mock.patch('eventlet.green.zmq.get_hub_name_from_instance')
|
||||||
|
@mock.patch('eventlet.green.zmq.get_hub', spec=Hub)
|
||||||
|
def test_context_factory_funtion(self, get_hub_mock, hub_name_mock):
|
||||||
|
hub_name_mock.return_value = 'zeromq'
|
||||||
|
ctx = zmq.Context()
|
||||||
|
self.assertTrue(get_hub_mock().get_context.called)
|
||||||
|
|
||||||
@skip_unless_zmq
|
@skip_unless_zmq
|
||||||
def test_threadlocal_context(self):
|
def test_threadlocal_context(self):
|
||||||
hub = get_hub()
|
hub = get_hub()
|
||||||
context = hub.get_context()
|
context = zmq.Context()
|
||||||
self.assertEqual(context, _threadlocal.context)
|
self.assertEqual(context, _threadlocal.context)
|
||||||
next_context = hub.get_context()
|
next_context = hub.get_context()
|
||||||
self.assertTrue(context is next_context)
|
self.assertTrue(context is next_context)
|
||||||
|
|
||||||
@skip_unless_zmq
|
@skip_unless_zmq
|
||||||
def test_different_context_in_different_thread(self):
|
def test_different_context_in_different_thread(self):
|
||||||
context = get_hub().get_context()
|
context = zmq.Context()
|
||||||
test_result = []
|
test_result = []
|
||||||
def assert_different(ctx):
|
def assert_different(ctx):
|
||||||
# assert not hasattr(_threadlocal, 'hub')
|
# assert not hasattr(_threadlocal, 'hub')
|
||||||
@@ -230,7 +240,7 @@ class TestThreadedContextAccess(TestCase):
|
|||||||
# os.environ['EVENTLET_HUB'] = 'zeromq'
|
# os.environ['EVENTLET_HUB'] = 'zeromq'
|
||||||
hub = get_hub()
|
hub = get_hub()
|
||||||
try:
|
try:
|
||||||
this_thread_context = hub.get_context()
|
this_thread_context = zmq.Context()
|
||||||
except:
|
except:
|
||||||
test_result.append('fail')
|
test_result.append('fail')
|
||||||
raise
|
raise
|
||||||
@@ -240,6 +250,18 @@ class TestThreadedContextAccess(TestCase):
|
|||||||
sleep(0.1)
|
sleep(0.1)
|
||||||
self.assertFalse(test_result[0])
|
self.assertFalse(test_result[0])
|
||||||
|
|
||||||
|
class TestCheckingForZMQHub(TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.orig_hub = zmq.get_hub_name_from_instance(get_hub())
|
||||||
|
use_hub('poll')
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
use_hub(self.orig_hub)
|
||||||
|
|
||||||
|
def test_assertionerror_raise_by_context(self):
|
||||||
|
self.assertRaises(RuntimeError, zmq.Context)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user