Refactored the green version of zmq to use a factory function for Context
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
__zmq__ = __import__('zmq')
|
||||
from eventlet import sleep
|
||||
from eventlet.hubs import trampoline
|
||||
from eventlet.hubs import trampoline, get_hub
|
||||
|
||||
__patched__ = ['Context', 'Socket']
|
||||
globals().update(dict([(var, getattr(__zmq__, var))
|
||||
@@ -10,7 +10,18 @@ globals().update(dict([(var, getattr(__zmq__, var))
|
||||
var in __patched__)
|
||||
]))
|
||||
|
||||
class Context(__zmq__.Context):
|
||||
|
||||
def get_hub_name_from_instance(hub):
|
||||
return hub.__class__.__module__.rsplit('.',1)[-1]
|
||||
|
||||
def Context(io_threads=1):
|
||||
hub = get_hub()
|
||||
hub_name = get_hub_name_from_instance(hub)
|
||||
if hub_name != 'zeromq':
|
||||
raise RuntimeError("Hub must be 'zeromq', got '%s'" % hub_name)
|
||||
return hub.get_context(io_threads)
|
||||
|
||||
class _Context(__zmq__.Context):
|
||||
|
||||
def socket(self, socket_type):
|
||||
return Socket(self, socket_type)
|
||||
|
@@ -17,12 +17,11 @@ WRITE_MASK = zmq.POLLOUT
|
||||
class Hub(poll.Hub):
|
||||
|
||||
|
||||
|
||||
def __init__(self, clock=time.time):
|
||||
BaseHub.__init__(self, clock)
|
||||
self.poll = zmq.Poller()
|
||||
|
||||
def get_context(self):
|
||||
def get_context(self, io_threads=1):
|
||||
"""zmq's Context must be unique within a hub
|
||||
|
||||
The zeromq API documentation states:
|
||||
@@ -36,7 +35,7 @@ class Hub(poll.Hub):
|
||||
try:
|
||||
return _threadlocal.context
|
||||
except AttributeError:
|
||||
_threadlocal.context = zmq.Context()
|
||||
_threadlocal.context = zmq._Context(io_threads)
|
||||
return _threadlocal.context
|
||||
|
||||
def register(self, fileno, new=False):
|
||||
|
@@ -24,8 +24,7 @@ from eventlet.hubs import get_hub, use_hub
|
||||
from uuid import uuid1
|
||||
|
||||
use_hub('zeromq')
|
||||
hub = get_hub()
|
||||
ctx = hub.get_context()
|
||||
ctx = zmq.Context()
|
||||
|
||||
class IDName(object):
|
||||
|
||||
|
@@ -1,12 +1,12 @@
|
||||
from eventlet import event, spawn, sleep, patcher
|
||||
from eventlet.hubs import use_hub, get_hub, _threadlocal
|
||||
from eventlet.hubs.hub import READ, WRITE
|
||||
from eventlet.hubs import get_hub, _threadlocal, use_hub
|
||||
from eventlet.green import zmq
|
||||
from nose.tools import *
|
||||
from tests import mock, LimitedTestCase, skip_unless
|
||||
from unittest import TestCase
|
||||
|
||||
from threading import Thread
|
||||
from eventlet.hubs.zeromq import Hub
|
||||
|
||||
def using_zmq(_f):
|
||||
return 'zeromq' in type(get_hub()).__module__
|
||||
@@ -17,13 +17,15 @@ def skip_unless_zmq(func):
|
||||
|
||||
class TestUpstreamDownStream(LimitedTestCase):
|
||||
|
||||
sockets = []
|
||||
|
||||
def tearDown(self):
|
||||
self.clear_up_sockets()
|
||||
super(TestUpstreamDownStream, self).tearDown()
|
||||
|
||||
def create_bound_pair(self, type1, type2, interface='tcp://127.0.0.1'):
|
||||
"""Create a bound socket pair using a random port."""
|
||||
self.context = context = get_hub().get_context()
|
||||
self.context = context = zmq.Context()
|
||||
s1 = context.socket(type1)
|
||||
port = s1.bind_to_random_port(interface)
|
||||
s2 = context.socket(type2)
|
||||
@@ -212,17 +214,25 @@ class TestThreadedContextAccess(TestCase):
|
||||
in the same context
|
||||
"""
|
||||
|
||||
@skip_unless_zmq
|
||||
@mock.patch('eventlet.green.zmq.get_hub_name_from_instance')
|
||||
@mock.patch('eventlet.green.zmq.get_hub', spec=Hub)
|
||||
def test_context_factory_funtion(self, get_hub_mock, hub_name_mock):
|
||||
hub_name_mock.return_value = 'zeromq'
|
||||
ctx = zmq.Context()
|
||||
self.assertTrue(get_hub_mock().get_context.called)
|
||||
|
||||
@skip_unless_zmq
|
||||
def test_threadlocal_context(self):
|
||||
hub = get_hub()
|
||||
context = hub.get_context()
|
||||
context = zmq.Context()
|
||||
self.assertEqual(context, _threadlocal.context)
|
||||
next_context = hub.get_context()
|
||||
self.assertTrue(context is next_context)
|
||||
|
||||
@skip_unless_zmq
|
||||
def test_different_context_in_different_thread(self):
|
||||
context = get_hub().get_context()
|
||||
context = zmq.Context()
|
||||
test_result = []
|
||||
def assert_different(ctx):
|
||||
# assert not hasattr(_threadlocal, 'hub')
|
||||
@@ -230,7 +240,7 @@ class TestThreadedContextAccess(TestCase):
|
||||
# os.environ['EVENTLET_HUB'] = 'zeromq'
|
||||
hub = get_hub()
|
||||
try:
|
||||
this_thread_context = hub.get_context()
|
||||
this_thread_context = zmq.Context()
|
||||
except:
|
||||
test_result.append('fail')
|
||||
raise
|
||||
@@ -240,6 +250,18 @@ class TestThreadedContextAccess(TestCase):
|
||||
sleep(0.1)
|
||||
self.assertFalse(test_result[0])
|
||||
|
||||
class TestCheckingForZMQHub(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.orig_hub = zmq.get_hub_name_from_instance(get_hub())
|
||||
use_hub('poll')
|
||||
|
||||
def tearDown(self):
|
||||
use_hub(self.orig_hub)
|
||||
|
||||
def test_assertionerror_raise_by_context(self):
|
||||
self.assertRaises(RuntimeError, zmq.Context)
|
||||
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user