154 lines
4.6 KiB
Python
154 lines
4.6 KiB
Python
from eventlet import event, spawn, sleep, patcher
|
|
from eventlet.hubs import use_hub, get_hub, _threadlocal
|
|
from eventlet.hubs.hub import READ, WRITE
|
|
from eventlet.green import zmq
|
|
from nose.tools import *
|
|
from tests import mock, LimitedTestCase, skip_unless
|
|
from unittest import TestCase
|
|
|
|
from threading import Thread
|
|
|
|
def using_zmq(_f):
    """Report whether the hub for the current thread is the zeromq hub.

    The argument is ignored; presumably ``skip_unless`` hands the decorated
    test function to its condition callable (TODO confirm in tests/__init__).
    """
    hub_module_name = type(get_hub()).__module__
    return 'zeromq' in hub_module_name
|
|
|
|
def skip_unless_zmq(func):
    """Decorator that skips a test unless the zeromq hub is the active hub.

    The check is delegated to ``using_zmq``, which inspects the module of the
    current hub's class.
    """
    return skip_unless(using_zmq)(func)
|
|
|
|
class TestUpstreamDownStream(LimitedTestCase):
    """Exercise eventlet green zmq sockets connected over loopback TCP."""

    def create_bound_pair(self, type1, type2, interface='tcp://127.0.0.1'):
        """Create a bound socket pair using a random port.

        A socket of ``type1`` is bound to a random port on ``interface`` and a
        socket of ``type2`` is connected to it.  Both come from the current
        hub's context, which is also stored on ``self.context``.

        Returns a ``(bound_socket, connected_socket)`` tuple.
        """
        self.context = context = get_hub().get_context()
        s1 = context.socket(type1)
        port = s1.bind_to_random_port(interface)
        s2 = context.socket(type2)
        s2.connect('%s:%s' % (interface, port))
        return s1, s2

    def assertRaisesErrno(self, errno, func, *args):
        """Assert that ``func(*args)`` raises a ZMQError carrying ``errno``.

        Fails if nothing is raised or if the ZMQError has a different errno.
        Exceptions that are not ZMQError propagate unchanged.
        """
        try:
            func(*args)
        except zmq.ZMQError as e:
            # Both errnos are wrapped in ZMQError so the failure message
            # shows ZMQError's rendering of each, not just the raw numbers.
            self.assertEqual(
                e.errno, errno,
                "wrong error raised, expected '%s' got '%s'"
                % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)))
        else:
            self.fail("Function did not raise any error")

    @skip_unless_zmq
    def test_recv_spawned_before_send_is_non_blocking(self):
        """A recv() started in a green thread before send() still completes."""
        req, rep = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
        # Yield once so the connection can be established before traffic.
        sleep()
        msg = dict(res=None)
        done = event.Event()

        def rx():
            msg['res'] = rep.recv()
            done.send('done')

        spawn(rx)
        req.send('test')
        done.wait()
        self.assertEqual(msg['res'], 'test')

    @skip_unless_zmq
    def test_close_socket_raises_enotsup(self):
        """Using a socket after close() raises ZMQError with ENOTSUP."""
        req, rep = self.create_bound_pair(zmq.PAIR, zmq.PAIR)

        rep.close()
        req.close()
        self.assertRaisesErrno(zmq.ENOTSUP, rep.recv)
        self.assertRaisesErrno(zmq.ENOTSUP, req.send, 'test')

    @skip_unless_zmq
    def test_send_1k_req_rep(self):
        """Ping-pong 1000 request/reply round trips between green threads."""
        req, rep = self.create_bound_pair(zmq.REQ, zmq.REP)
        sleep()
        done = event.Event()

        def tx():
            # Keep sending increasing counters until the replier says 'done'.
            tx_i = 0
            req.send(str(tx_i))
            while req.recv() != 'done':
                tx_i += 1
                req.send(str(tx_i))

        def rx():
            while True:
                rx_i = rep.recv()
                if rx_i == "1000":
                    rep.send('done')
                    # Let tx observe the final reply before signalling.
                    sleep()
                    done.send(0)
                    break
                rep.send('i')

        spawn(tx)
        spawn(rx)
        final_i = done.wait()
        self.assertEqual(final_i, 0)

    @skip_unless_zmq
    def test_send_1k_up_down(self):
        """Push exactly 1000 messages downstream and receive them upstream."""
        down, up = self.create_bound_pair(zmq.DOWNSTREAM, zmq.UPSTREAM)
        sleep()
        done = event.Event()

        def tx():
            # Send "1".."1000"; the receiver stops once it sees "1000".
            for tx_i in range(1, 1001):
                down.send(str(tx_i))

        def rx():
            while True:
                rx_i = up.recv()
                if rx_i == "1000":
                    done.send(0)
                    break

        spawn(tx)
        spawn(rx)
        final_i = done.wait()
        self.assertEqual(final_i, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestThreadedContextAccess(TestCase):
    """zmq's Context must be unique within a hub.

    The zeromq API documentation states:
    All zmq sockets passed to the zmq_poll() function must share the same zmq
    context and must belong to the thread calling zmq_poll()

    Because zmq_poll is what is eventually called, we need to ensure that
    all sockets that are going to be passed to zmq_poll (via hub.do_poll) are
    in the same context.
    """

    @skip_unless_zmq
    def test_threadlocal_context(self):
        """The hub returns the same thread-local context on every call."""
        hub = get_hub()
        context = hub.get_context()
        self.assertEqual(context, _threadlocal.context)
        next_context = hub.get_context()
        self.assertTrue(context is next_context)

    @skip_unless_zmq
    def test_different_context_in_different_thread(self):
        """A hub obtained in another OS thread hands out a different context."""
        context = get_hub().get_context()
        test_result = []

        def assert_different(ctx):
            # Record failures instead of raising: an exception escaping a
            # Thread target is only printed, and the main thread would then
            # never receive a result.
            try:
                assert not hasattr(_threadlocal, 'hub')
                hub = get_hub()
                this_thread_context = hub.get_context()
            except Exception:
                test_result.append('fail')
            else:
                test_result.append(ctx is this_thread_context)

        worker = Thread(target=assert_different, args=(context,))
        worker.start()
        # Block until the worker finishes rather than busy-spinning on
        # test_result (which hung forever if the worker raised early).
        worker.join()
        self.assertEqual(len(test_result), 1)
        self.assertNotEqual(test_result[0], 'fail',
                            "worker thread raised while fetching its context")
        self.assertFalse(test_result[0])
|
|
|
|
|
|
|
|
|
|
|