Refactored the green version of zmq to use a factory function for Context

This commit is contained in:
Ben Ford
2010-10-12 07:50:26 +01:00
parent 69a73a843e
commit e1cfcbb574
4 changed files with 44 additions and 13 deletions

View File

@@ -1,6 +1,6 @@
__zmq__ = __import__('zmq')
from eventlet import sleep
from eventlet.hubs import trampoline
from eventlet.hubs import trampoline, get_hub
__patched__ = ['Context', 'Socket']
globals().update(dict([(var, getattr(__zmq__, var))
@@ -10,7 +10,18 @@ globals().update(dict([(var, getattr(__zmq__, var))
var in __patched__)
]))
class Context(__zmq__.Context):
def get_hub_name_from_instance(hub):
return hub.__class__.__module__.rsplit('.',1)[-1]
def Context(io_threads=1):
hub = get_hub()
hub_name = get_hub_name_from_instance(hub)
if hub_name != 'zeromq':
raise RuntimeError("Hub must be 'zeromq', got '%s'" % hub_name)
return hub.get_context(io_threads)
class _Context(__zmq__.Context):
def socket(self, socket_type):
return Socket(self, socket_type)

View File

@@ -17,12 +17,11 @@ WRITE_MASK = zmq.POLLOUT
class Hub(poll.Hub):
def __init__(self, clock=time.time):
BaseHub.__init__(self, clock)
self.poll = zmq.Poller()
def get_context(self):
def get_context(self, io_threads=1):
"""zmq's Context must be unique within a hub
The zeromq API documentation states:
@@ -36,7 +35,7 @@ class Hub(poll.Hub):
try:
return _threadlocal.context
except AttributeError:
_threadlocal.context = zmq.Context()
_threadlocal.context = zmq._Context(io_threads)
return _threadlocal.context
def register(self, fileno, new=False):

View File

@@ -24,8 +24,7 @@ from eventlet.hubs import get_hub, use_hub
from uuid import uuid1
use_hub('zeromq')
hub = get_hub()
ctx = hub.get_context()
ctx = zmq.Context()
class IDName(object):

View File

@@ -1,12 +1,12 @@
from eventlet import event, spawn, sleep, patcher
from eventlet.hubs import use_hub, get_hub, _threadlocal
from eventlet.hubs.hub import READ, WRITE
from eventlet.hubs import get_hub, _threadlocal, use_hub
from eventlet.green import zmq
from nose.tools import *
from tests import mock, LimitedTestCase, skip_unless
from unittest import TestCase
from threading import Thread
from eventlet.hubs.zeromq import Hub
def using_zmq(_f):
return 'zeromq' in type(get_hub()).__module__
@@ -17,13 +17,15 @@ def skip_unless_zmq(func):
class TestUpstreamDownStream(LimitedTestCase):
sockets = []
def tearDown(self):
self.clear_up_sockets()
super(TestUpstreamDownStream, self).tearDown()
def create_bound_pair(self, type1, type2, interface='tcp://127.0.0.1'):
"""Create a bound socket pair using a random port."""
self.context = context = get_hub().get_context()
self.context = context = zmq.Context()
s1 = context.socket(type1)
port = s1.bind_to_random_port(interface)
s2 = context.socket(type2)
@@ -212,17 +214,25 @@ class TestThreadedContextAccess(TestCase):
in the same context
"""
@skip_unless_zmq
@mock.patch('eventlet.green.zmq.get_hub_name_from_instance')
@mock.patch('eventlet.green.zmq.get_hub', spec=Hub)
def test_context_factory_funtion(self, get_hub_mock, hub_name_mock):
hub_name_mock.return_value = 'zeromq'
ctx = zmq.Context()
self.assertTrue(get_hub_mock().get_context.called)
@skip_unless_zmq
def test_threadlocal_context(self):
hub = get_hub()
context = hub.get_context()
context = zmq.Context()
self.assertEqual(context, _threadlocal.context)
next_context = hub.get_context()
self.assertTrue(context is next_context)
@skip_unless_zmq
def test_different_context_in_different_thread(self):
context = get_hub().get_context()
context = zmq.Context()
test_result = []
def assert_different(ctx):
# assert not hasattr(_threadlocal, 'hub')
@@ -230,7 +240,7 @@ class TestThreadedContextAccess(TestCase):
# os.environ['EVENTLET_HUB'] = 'zeromq'
hub = get_hub()
try:
this_thread_context = hub.get_context()
this_thread_context = zmq.Context()
except:
test_result.append('fail')
raise
@@ -240,6 +250,18 @@ class TestThreadedContextAccess(TestCase):
sleep(0.1)
self.assertFalse(test_result[0])
class TestCheckingForZMQHub(TestCase):
def setUp(self):
self.orig_hub = zmq.get_hub_name_from_instance(get_hub())
use_hub('poll')
def tearDown(self):
use_hub(self.orig_hub)
def test_assertionerror_raise_by_context(self):
self.assertRaises(RuntimeError, zmq.Context)