diff --git a/eventlet/green/zmq.py b/eventlet/green/zmq.py index 2da0a6b..8f3c824 100644 --- a/eventlet/green/zmq.py +++ b/eventlet/green/zmq.py @@ -1,6 +1,6 @@ __zmq__ = __import__('zmq') from eventlet import sleep -from eventlet.hubs import trampoline +from eventlet.hubs import trampoline, get_hub __patched__ = ['Context', 'Socket'] globals().update(dict([(var, getattr(__zmq__, var)) @@ -10,7 +10,18 @@ globals().update(dict([(var, getattr(__zmq__, var)) var in __patched__) ])) -class Context(__zmq__.Context): + +def get_hub_name_from_instance(hub): + return hub.__class__.__module__.rsplit('.',1)[-1] + +def Context(io_threads=1): + hub = get_hub() + hub_name = get_hub_name_from_instance(hub) + if hub_name != 'zeromq': + raise RuntimeError("Hub must be 'zeromq', got '%s'" % hub_name) + return hub.get_context(io_threads) + +class _Context(__zmq__.Context): def socket(self, socket_type): return Socket(self, socket_type) diff --git a/eventlet/hubs/zeromq.py b/eventlet/hubs/zeromq.py index 9ce6503..2433bb9 100644 --- a/eventlet/hubs/zeromq.py +++ b/eventlet/hubs/zeromq.py @@ -17,12 +17,11 @@ WRITE_MASK = zmq.POLLOUT class Hub(poll.Hub): - def __init__(self, clock=time.time): BaseHub.__init__(self, clock) self.poll = zmq.Poller() - def get_context(self): + def get_context(self, io_threads=1): """zmq's Context must be unique within a hub The zeromq API documentation states: @@ -36,7 +35,7 @@ class Hub(poll.Hub): try: return _threadlocal.context except AttributeError: - _threadlocal.context = zmq.Context() + _threadlocal.context = zmq._Context(io_threads) return _threadlocal.context def register(self, fileno, new=False): diff --git a/examples/distributed_websocket_chat.py b/examples/distributed_websocket_chat.py index aa2ed85..7ad0483 100644 --- a/examples/distributed_websocket_chat.py +++ b/examples/distributed_websocket_chat.py @@ -24,8 +24,7 @@ from eventlet.hubs import get_hub, use_hub from uuid import uuid1 use_hub('zeromq') -hub = get_hub() -ctx = hub.get_context() +ctx = zmq.Context() class IDName(object): diff --git a/tests/zmq_test.py b/tests/zmq_test.py index ecf3c4c..04b54cd 100644 --- a/tests/zmq_test.py +++ b/tests/zmq_test.py @@ -1,12 +1,12 @@ from eventlet import event, spawn, sleep, patcher -from eventlet.hubs import use_hub, get_hub, _threadlocal -from eventlet.hubs.hub import READ, WRITE +from eventlet.hubs import get_hub, _threadlocal, use_hub from eventlet.green import zmq from nose.tools import * from tests import mock, LimitedTestCase, skip_unless from unittest import TestCase from threading import Thread +from eventlet.hubs.zeromq import Hub def using_zmq(_f): return 'zeromq' in type(get_hub()).__module__ @@ -17,13 +17,15 @@ def skip_unless_zmq(func): class TestUpstreamDownStream(LimitedTestCase): + sockets = [] + def tearDown(self): self.clear_up_sockets() super(TestUpstreamDownStream, self).tearDown() def create_bound_pair(self, type1, type2, interface='tcp://127.0.0.1'): """Create a bound socket pair using a random port.""" - self.context = context = get_hub().get_context() + self.context = context = zmq.Context() s1 = context.socket(type1) port = s1.bind_to_random_port(interface) s2 = context.socket(type2) @@ -212,17 +214,25 @@ class TestThreadedContextAccess(TestCase): in the same context """ + @skip_unless_zmq + @mock.patch('eventlet.green.zmq.get_hub_name_from_instance') + @mock.patch('eventlet.green.zmq.get_hub', spec=Hub) + def test_context_factory_funtion(self, get_hub_mock, hub_name_mock): + hub_name_mock.return_value = 'zeromq' + ctx = zmq.Context() + self.assertTrue(get_hub_mock().get_context.called) + @skip_unless_zmq def test_threadlocal_context(self): hub = get_hub() - context = hub.get_context() + context = zmq.Context() self.assertEqual(context, _threadlocal.context) next_context = hub.get_context() self.assertTrue(context is next_context) @skip_unless_zmq def test_different_context_in_different_thread(self): - context = get_hub().get_context() + context = zmq.Context() test_result = [] def assert_different(ctx): # assert not hasattr(_threadlocal, 'hub') @@ -230,7 +240,7 @@ class TestThreadedContextAccess(TestCase): # os.environ['EVENTLET_HUB'] = 'zeromq' hub = get_hub() try: - this_thread_context = hub.get_context() + this_thread_context = zmq.Context() except: test_result.append('fail') raise @@ -240,6 +250,18 @@ class TestThreadedContextAccess(TestCase): sleep(0.1) self.assertFalse(test_result[0]) +class TestCheckingForZMQHub(TestCase): + + def setUp(self): + self.orig_hub = zmq.get_hub_name_from_instance(get_hub()) + use_hub('poll') + + def tearDown(self): + use_hub(self.orig_hub) + + def test_assertionerror_raise_by_context(self): + self.assertRaises(RuntimeError, zmq.Context) +