Refactored the green version of zmq to use a factory function for Context

This commit is contained in:
Ben Ford
2010-10-12 07:50:26 +01:00
parent 69a73a843e
commit e1cfcbb574
4 changed files with 44 additions and 13 deletions

View File

@@ -1,6 +1,6 @@
__zmq__ = __import__('zmq') __zmq__ = __import__('zmq')
from eventlet import sleep from eventlet import sleep
from eventlet.hubs import trampoline from eventlet.hubs import trampoline, get_hub
__patched__ = ['Context', 'Socket'] __patched__ = ['Context', 'Socket']
globals().update(dict([(var, getattr(__zmq__, var)) globals().update(dict([(var, getattr(__zmq__, var))
@@ -10,7 +10,18 @@ globals().update(dict([(var, getattr(__zmq__, var))
var in __patched__) var in __patched__)
])) ]))
class Context(__zmq__.Context):
def get_hub_name_from_instance(hub):
return hub.__class__.__module__.rsplit('.',1)[-1]
def Context(io_threads=1):
hub = get_hub()
hub_name = get_hub_name_from_instance(hub)
if hub_name != 'zeromq':
raise RuntimeError("Hub must be 'zeromq', got '%s'" % hub_name)
return hub.get_context(io_threads)
class _Context(__zmq__.Context):
def socket(self, socket_type): def socket(self, socket_type):
return Socket(self, socket_type) return Socket(self, socket_type)

View File

@@ -17,12 +17,11 @@ WRITE_MASK = zmq.POLLOUT
class Hub(poll.Hub): class Hub(poll.Hub):
def __init__(self, clock=time.time): def __init__(self, clock=time.time):
BaseHub.__init__(self, clock) BaseHub.__init__(self, clock)
self.poll = zmq.Poller() self.poll = zmq.Poller()
def get_context(self): def get_context(self, io_threads=1):
"""zmq's Context must be unique within a hub """zmq's Context must be unique within a hub
The zeromq API documentation states: The zeromq API documentation states:
@@ -36,7 +35,7 @@ class Hub(poll.Hub):
try: try:
return _threadlocal.context return _threadlocal.context
except AttributeError: except AttributeError:
_threadlocal.context = zmq.Context() _threadlocal.context = zmq._Context(io_threads)
return _threadlocal.context return _threadlocal.context
def register(self, fileno, new=False): def register(self, fileno, new=False):

View File

@@ -24,8 +24,7 @@ from eventlet.hubs import get_hub, use_hub
from uuid import uuid1 from uuid import uuid1
use_hub('zeromq') use_hub('zeromq')
hub = get_hub() ctx = zmq.Context()
ctx = hub.get_context()
class IDName(object): class IDName(object):

View File

@@ -1,12 +1,12 @@
from eventlet import event, spawn, sleep, patcher from eventlet import event, spawn, sleep, patcher
from eventlet.hubs import use_hub, get_hub, _threadlocal from eventlet.hubs import get_hub, _threadlocal, use_hub
from eventlet.hubs.hub import READ, WRITE
from eventlet.green import zmq from eventlet.green import zmq
from nose.tools import * from nose.tools import *
from tests import mock, LimitedTestCase, skip_unless from tests import mock, LimitedTestCase, skip_unless
from unittest import TestCase from unittest import TestCase
from threading import Thread from threading import Thread
from eventlet.hubs.zeromq import Hub
def using_zmq(_f): def using_zmq(_f):
return 'zeromq' in type(get_hub()).__module__ return 'zeromq' in type(get_hub()).__module__
@@ -17,13 +17,15 @@ def skip_unless_zmq(func):
class TestUpstreamDownStream(LimitedTestCase): class TestUpstreamDownStream(LimitedTestCase):
sockets = []
def tearDown(self): def tearDown(self):
self.clear_up_sockets() self.clear_up_sockets()
super(TestUpstreamDownStream, self).tearDown() super(TestUpstreamDownStream, self).tearDown()
def create_bound_pair(self, type1, type2, interface='tcp://127.0.0.1'): def create_bound_pair(self, type1, type2, interface='tcp://127.0.0.1'):
"""Create a bound socket pair using a random port.""" """Create a bound socket pair using a random port."""
self.context = context = get_hub().get_context() self.context = context = zmq.Context()
s1 = context.socket(type1) s1 = context.socket(type1)
port = s1.bind_to_random_port(interface) port = s1.bind_to_random_port(interface)
s2 = context.socket(type2) s2 = context.socket(type2)
@@ -212,17 +214,25 @@ class TestThreadedContextAccess(TestCase):
in the same context in the same context
""" """
@skip_unless_zmq
@mock.patch('eventlet.green.zmq.get_hub_name_from_instance')
@mock.patch('eventlet.green.zmq.get_hub', spec=Hub)
def test_context_factory_funtion(self, get_hub_mock, hub_name_mock):
hub_name_mock.return_value = 'zeromq'
ctx = zmq.Context()
self.assertTrue(get_hub_mock().get_context.called)
@skip_unless_zmq @skip_unless_zmq
def test_threadlocal_context(self): def test_threadlocal_context(self):
hub = get_hub() hub = get_hub()
context = hub.get_context() context = zmq.Context()
self.assertEqual(context, _threadlocal.context) self.assertEqual(context, _threadlocal.context)
next_context = hub.get_context() next_context = hub.get_context()
self.assertTrue(context is next_context) self.assertTrue(context is next_context)
@skip_unless_zmq @skip_unless_zmq
def test_different_context_in_different_thread(self): def test_different_context_in_different_thread(self):
context = get_hub().get_context() context = zmq.Context()
test_result = [] test_result = []
def assert_different(ctx): def assert_different(ctx):
# assert not hasattr(_threadlocal, 'hub') # assert not hasattr(_threadlocal, 'hub')
@@ -230,7 +240,7 @@ class TestThreadedContextAccess(TestCase):
# os.environ['EVENTLET_HUB'] = 'zeromq' # os.environ['EVENTLET_HUB'] = 'zeromq'
hub = get_hub() hub = get_hub()
try: try:
this_thread_context = hub.get_context() this_thread_context = zmq.Context()
except: except:
test_result.append('fail') test_result.append('fail')
raise raise
@@ -240,6 +250,18 @@ class TestThreadedContextAccess(TestCase):
sleep(0.1) sleep(0.1)
self.assertFalse(test_result[0]) self.assertFalse(test_result[0])
class TestCheckingForZMQHub(TestCase):
def setUp(self):
self.orig_hub = zmq.get_hub_name_from_instance(get_hub())
use_hub('poll')
def tearDown(self):
use_hub(self.orig_hub)
def test_assertionerror_raise_by_context(self):
self.assertRaises(RuntimeError, zmq.Context)