diff --git a/eventlet/green/zmq.py b/eventlet/green/zmq.py index 32681ba..97cf788 100644 --- a/eventlet/green/zmq.py +++ b/eventlet/green/zmq.py @@ -320,46 +320,46 @@ class Socket(_Socket): return _Socket_send_multipart(self, msg_parts, flags, copy, track) @_wraps(_Socket.send_string) - def send_string(self, msg_parts, flags=0, copy=True, track=False): + def send_string(self, u, flags=0, copy=True, encoding='utf-8'): """A send_string method that's safe to use when multiple greenthreads are calling send, send_string, recv and recv_string on the same socket. """ if flags & NOBLOCK: - return _Socket_send_string(self, msg_parts, flags, copy, track) + return _Socket_send_string(self, u, flags, copy, encoding) # acquire lock here so the subsequent calls to send for the # message parts after the first don't block with self._eventlet_send_lock: - return _Socket_send_string(self, msg_parts, flags, copy, track) + return _Socket_send_string(self, u, flags, copy, encoding) @_wraps(_Socket.send_pyobj) - def send_pyobj(self, msg_parts, flags=0, copy=True, track=False): + def send_pyobj(self, obj, flags=0, protocol=2): """A send_pyobj method that's safe to use when multiple greenthreads are calling send, send_pyobj, recv and recv_pyobj on the same socket. """ if flags & NOBLOCK: - return _Socket_send_pyobj(self, msg_parts, flags, copy, track) + return _Socket_send_pyobj(self, obj, flags, protocol) # acquire lock here so the subsequent calls to send for the # message parts after the first don't block with self._eventlet_send_lock: - return _Socket_send_pyobj(self, msg_parts, flags, copy, track) + return _Socket_send_pyobj(self, obj, flags, protocol) @_wraps(_Socket.send_json) - def send_json(self, msg_parts, flags=0, copy=True, track=False): + def send_json(self, obj, flags=0, **kwargs): """A send_json method that's safe to use when multiple greenthreads are calling send, send_json, recv and recv_json on the same socket. """ if flags & NOBLOCK: - return _Socket_send_json(self, msg_parts, flags, copy, track) + return _Socket_send_json(self, obj, flags, **kwargs) # acquire lock here so the subsequent calls to send for the # message parts after the first don't block with self._eventlet_send_lock: - return _Socket_send_json(self, msg_parts, flags, copy, track) + return _Socket_send_json(self, obj, flags, **kwargs) @_wraps(_Socket.recv) def recv(self, flags=0, copy=True, track=False): @@ -407,43 +407,43 @@ class Socket(_Socket): return _Socket_recv_multipart(self, flags, copy, track) @_wraps(_Socket.recv_string) - def recv_string(self, flags=0, copy=True, track=False): + def recv_string(self, flags=0, encoding='utf-8'): """A recv_string method that's safe to use when multiple greenthreads are calling send, send_string, recv and recv_string on the same socket. """ if flags & NOBLOCK: - return _Socket_recv_string(self, flags, copy, track) + return _Socket_recv_string(self, flags, encoding) # acquire lock here so the subsequent calls to recv for the # message parts after the first don't block with self._eventlet_recv_lock: - return _Socket_recv_string(self, flags, copy, track) - - @_wraps(_Socket.recv_pyobj) - def recv_pyobj(self, flags=0, copy=True, track=False): - """A recv_pyobj method that's safe to use when multiple - greenthreads are calling send, send_pyobj, recv and - recv_pyobj on the same socket. - """ - if flags & NOBLOCK: - return _Socket_recv_pyobj(self, flags, copy, track) - - # acquire lock here so the subsequent calls to recv for the - # message parts after the first don't block - with self._eventlet_recv_lock: - return _Socket_recv_pyobj(self, flags, copy, track) + return _Socket_recv_string(self, flags, encoding) @_wraps(_Socket.recv_json) - def recv_json(self, flags=0, copy=True, track=False): + def recv_json(self, flags=0, **kwargs): """A recv_json method that's safe to use when multiple greenthreads are calling send, send_json, recv and recv_json on the same socket. """ if flags & NOBLOCK: - return _Socket_recv_json(self, flags, copy, track) + return _Socket_recv_json(self, flags, **kwargs) # acquire lock here so the subsequent calls to recv for the # message parts after the first don't block with self._eventlet_recv_lock: - return _Socket_recv_json(self, flags, copy, track) + return _Socket_recv_json(self, flags, **kwargs) + + @_wraps(_Socket.recv_pyobj) + def recv_pyobj(self, flags=0): + """A recv_pyobj method that's safe to use when multiple + greenthreads are calling send, send_pyobj, recv and + recv_pyobj on the same socket. + """ + if flags & NOBLOCK: + return _Socket_recv_pyobj(self, flags) + + # acquire lock here so the subsequent calls to recv for the + # message parts after the first don't block + with self._eventlet_recv_lock: + return _Socket_recv_pyobj(self, flags) diff --git a/tests/zmq_test.py b/tests/zmq_test.py index 2910c11..5c1a25f 100644 --- a/tests/zmq_test.py +++ b/tests/zmq_test.py @@ -1,3 +1,5 @@ +import contextlib + try: from eventlet.green import zmq except ImportError: @@ -6,7 +8,6 @@ else: RECV_ON_CLOSED_SOCKET_ERRNOS = (zmq.ENOTSUP, zmq.ENOTSOCK) import eventlet -from eventlet import event, spawn, sleep, semaphore import tests @@ -80,15 +81,15 @@ class TestUpstreamDownStream(tests.LimitedTestCase): req, rep, port = self.create_bound_pair(zmq.PAIR, zmq.PAIR) # req.connect(ipc) # rep.bind(ipc) - sleep() + eventlet.sleep() msg = dict(res=None) - done = event.Event() + done = eventlet.Event() def rx(): msg['res'] = rep.recv() done.send('done') - spawn(rx) + eventlet.spawn(rx) req.send(b'test') done.wait() self.assertEqual(msg['res'], b'test') @@ -114,8 +115,8 @@ class TestUpstreamDownStream(tests.LimitedTestCase): @tests.skip_unless(zmq_supported) def test_send_1k_req_rep(self): req, rep, port = self.create_bound_pair(zmq.REQ, zmq.REP) - sleep() - done = event.Event() + eventlet.sleep() + done = eventlet.Event() def tx(): tx_i = 0 @@ -132,17 +133,17 @@ class TestUpstreamDownStream(tests.LimitedTestCase): rep.send(b'done') break rep.send(b'i') - spawn(tx) - spawn(rx) + eventlet.spawn(tx) + eventlet.spawn(rx) final_i = done.wait() self.assertEqual(final_i, 0) @tests.skip_unless(zmq_supported) def test_send_1k_push_pull(self): down, up, port = self.create_bound_pair(zmq.PUSH, zmq.PULL) - sleep() + eventlet.sleep() - done = event.Event() + done = eventlet.Event() def tx(): tx_i = 0 @@ -156,8 +157,8 @@ class TestUpstreamDownStream(tests.LimitedTestCase): if rx_i == b"1000": done.send(0) break - spawn(tx) - spawn(rx) + eventlet.spawn(tx) + eventlet.spawn(rx) final_i = done.wait() self.assertEqual(final_i, 0) @@ -174,17 +175,17 @@ class TestUpstreamDownStream(tests.LimitedTestCase): sub1.setsockopt(zmq.SUBSCRIBE, b'sub1') sub2.setsockopt(zmq.SUBSCRIBE, b'sub2') - sub_all_done = event.Event() - sub1_done = event.Event() - sub2_done = event.Event() + sub_all_done = eventlet.Event() + sub1_done = eventlet.Event() + sub2_done = eventlet.Event() - sleep(0.2) + eventlet.sleep(0.2) def rx(sock, done_evt, msg_count=10000): count = 0 while count < msg_count: msg = sock.recv() - sleep() + eventlet.sleep() if b'LAST' in msg: break count += 1 @@ -195,14 +196,14 @@ class TestUpstreamDownStream(tests.LimitedTestCase): for i in range(1, 1001): msg = ("sub%s %s" % ([2, 1][i % 2], i)).encode() sock.send(msg) - sleep() + eventlet.sleep() sock.send(b'sub1 LAST') sock.send(b'sub2 LAST') - spawn(rx, sub_all, sub_all_done) - spawn(rx, sub1, sub1_done) - spawn(rx, sub2, sub2_done) - spawn(tx, pub) + eventlet.spawn(rx, sub_all, sub_all_done) + eventlet.spawn(rx, sub1, sub1_done) + eventlet.spawn(rx, sub2, sub2_done) + eventlet.spawn(tx, pub) sub1_count = sub1_done.wait() sub2_count = sub2_done.wait() sub_all_count = sub_all_done.wait() @@ -216,14 +217,14 @@ class TestUpstreamDownStream(tests.LimitedTestCase): # of sporadic failures on Travis. pub, sub, port = self.create_bound_pair(zmq.PUB, zmq.SUB) sub.setsockopt(zmq.SUBSCRIBE, b'test') - sleep(0) - sub_ready = event.Event() - sub_last = event.Event() - sub_done = event.Event() + eventlet.sleep(0) + sub_ready = eventlet.Event() + sub_last = eventlet.Event() + sub_done = eventlet.Event() def rx(): while sub.recv() != b'test BEGIN': - sleep(0) + eventlet.sleep(0) sub_ready.send() count = 0 while True: @@ -234,7 +235,7 @@ class TestUpstreamDownStream(tests.LimitedTestCase): if msg == b'test LAST': sub.setsockopt(zmq.SUBSCRIBE, b'done') sub.setsockopt(zmq.UNSUBSCRIBE, b'test') - sleep(0) + eventlet.sleep(0) # In real application you should either sync # or tolerate loss of messages. sub_last.send() @@ -247,7 +248,7 @@ class TestUpstreamDownStream(tests.LimitedTestCase): # Sync receiver ready to avoid loss of first packets while not sub_ready.ready(): pub.send(b'test BEGIN') - sleep(0.005) + eventlet.sleep(0.005) for i in range(1, 101): msg = 'test {0}'.format(i).encode() if i != 50: @@ -256,8 +257,8 @@ class TestUpstreamDownStream(tests.LimitedTestCase): pub.send(b'test LAST') sub_last.wait() # XXX: putting a real delay of 1ms here fixes sporadic failures on Travis - # just yield sleep(0) doesn't cut it - sleep(0.001) + # just yield eventlet.sleep(0) doesn't cut it + eventlet.sleep(0.001) pub.send(b'done DONE') eventlet.spawn(rx) @@ -292,10 +293,10 @@ class TestUpstreamDownStream(tests.LimitedTestCase): @tests.skip_unless(zmq_supported) def test_send_during_recv(self): sender, receiver, port = self.create_bound_pair(zmq.XREQ, zmq.XREQ) - sleep() + eventlet.sleep() num_recvs = 30 - done_evts = [event.Event() for _ in range(num_recvs)] + done_evts = [eventlet.Event() for _ in range(num_recvs)] def slow_rx(done, msg): self.assertEqual(sender.recv(), msg) @@ -313,24 +314,24 @@ class TestUpstreamDownStream(tests.LimitedTestCase): if rx_i == b"1000": for i in range(num_recvs): receiver.send(('done%d' % i).encode()) - sleep() + eventlet.sleep() return for i in range(num_recvs): - spawn(slow_rx, done_evts[i], ("done%d" % i).encode()) + eventlet.spawn(slow_rx, done_evts[i], ("done%d" % i).encode()) - spawn(tx) - spawn(rx) + eventlet.spawn(tx) + eventlet.spawn(rx) for evt in done_evts: self.assertEqual(evt.wait(), 0) @tests.skip_unless(zmq_supported) def test_send_during_recv_multipart(self): sender, receiver, port = self.create_bound_pair(zmq.XREQ, zmq.XREQ) - sleep() + eventlet.sleep() num_recvs = 30 - done_evts = [event.Event() for _ in range(num_recvs)] + done_evts = [eventlet.Event() for _ in range(num_recvs)] def slow_rx(done, msg): self.assertEqual(sender.recv_multipart(), msg) @@ -349,15 +350,15 @@ class TestUpstreamDownStream(tests.LimitedTestCase): for i in range(num_recvs): receiver.send_multipart([ ('done%d' % i).encode(), b'a', b'b', b'c']) - sleep() + eventlet.sleep() return for i in range(num_recvs): - spawn(slow_rx, done_evts[i], [ + eventlet.spawn(slow_rx, done_evts[i], [ ("done%d" % i).encode(), b'a', b'b', b'c']) - spawn(tx) - spawn(rx) + eventlet.spawn(tx) + eventlet.spawn(rx) for i in range(num_recvs): final_i = done_evts[i].wait() self.assertEqual(final_i, 0) @@ -366,9 +367,9 @@ class TestUpstreamDownStream(tests.LimitedTestCase): @tests.skip_unless(zmq_supported) def test_recv_during_send(self): sender, receiver, port = self.create_bound_pair(zmq.XREQ, zmq.XREQ) - sleep() + eventlet.sleep() - done = event.Event() + done = eventlet.Event() try: SNDHWM = zmq.SNDHWM @@ -388,25 +389,25 @@ class TestUpstreamDownStream(tests.LimitedTestCase): tx_i += 1 done.send(0) - spawn(tx) + eventlet.spawn(tx) final_i = done.wait() self.assertEqual(final_i, 0) @tests.skip_unless(zmq_supported) def test_close_during_recv(self): sender, receiver, port = self.create_bound_pair(zmq.XREQ, zmq.XREQ) - sleep() - done1 = event.Event() - done2 = event.Event() + eventlet.sleep() + done1 = eventlet.Event() + done2 = eventlet.Event() def rx(e): self.assertRaisesErrno(RECV_ON_CLOSED_SOCKET_ERRNOS, receiver.recv) e.send() - spawn(rx, done1) - spawn(rx, done2) + eventlet.spawn(rx, done1) + eventlet.spawn(rx, done2) - sleep() + eventlet.sleep() receiver.close() done1.wait() @@ -415,7 +416,7 @@ class TestUpstreamDownStream(tests.LimitedTestCase): @tests.skip_unless(zmq_supported) def test_getsockopt_events(self): sock1, sock2, _port = self.create_bound_pair(zmq.DEALER, zmq.DEALER) - sleep() + eventlet.sleep() poll_out = zmq.Poller() poll_out.register(sock1, zmq.POLLOUT) sock_map = poll_out.poll(100) @@ -447,7 +448,7 @@ class TestUpstreamDownStream(tests.LimitedTestCase): sock = self.context.socket(zmq.PUB) self.sockets.append(sock) sock.bind_to_random_port("tcp://127.0.0.1") - sleep() + eventlet.sleep() tests.check_idle_cpu_usage(0.2, 0.1) @tests.skip_unless(zmq_supported) @@ -458,12 +459,12 @@ class TestUpstreamDownStream(tests.LimitedTestCase): """ pub, sub, _port = self.create_bound_pair(zmq.PUB, zmq.SUB) sub.setsockopt(zmq.SUBSCRIBE, b"") - sleep() + eventlet.sleep() pub.send(b'test_send') tests.check_idle_cpu_usage(0.2, 0.1) sender, receiver, _port = self.create_bound_pair(zmq.DEALER, zmq.DEALER) - sleep() + eventlet.sleep() sender.send(b'test_recv') msg = receiver.recv() self.assertEqual(msg, b'test_recv') @@ -474,7 +475,7 @@ class TestQueueLock(tests.LimitedTestCase): @tests.skip_unless(zmq_supported) def test_queue_lock_order(self): q = zmq._QueueLock() - s = semaphore.Semaphore(0) + s = eventlet.Semaphore(0) results = [] def lock(x): @@ -484,12 +485,12 @@ class TestQueueLock(tests.LimitedTestCase): q.acquire() - spawn(lock, 1) - sleep() - spawn(lock, 2) - sleep() - spawn(lock, 3) - sleep() + eventlet.spawn(lock, 1) + eventlet.sleep() + eventlet.spawn(lock, 2) + eventlet.sleep() + eventlet.spawn(lock, 3) + eventlet.sleep() self.assertEqual(results, []) q.release() @@ -529,7 +530,7 @@ class TestQueueLock(tests.LimitedTestCase): q.acquire() q.acquire() - s = semaphore.Semaphore(0) + s = eventlet.Semaphore(0) results = [] def lock(x): @@ -537,11 +538,11 @@ class TestQueueLock(tests.LimitedTestCase): results.append(x) s.release() - spawn(lock, 1) - sleep() + eventlet.spawn(lock, 1) + eventlet.sleep() self.assertEqual(results, []) q.release() - sleep() + eventlet.sleep() self.assertEqual(results, []) self.assertTrue(q) q.release() @@ -554,16 +555,45 @@ class TestBlockedThread(tests.LimitedTestCase): @tests.skip_unless(zmq_supported) def test_block(self): e = zmq._BlockedThread() - done = event.Event() + done = eventlet.Event() self.assertFalse(e) def block(): e.block() done.send(1) - spawn(block) - sleep() + eventlet.spawn(block) + eventlet.sleep() self.assertFalse(done.has_result()) e.wake() done.wait() + + +@contextlib.contextmanager +def clean_context(): + ctx = zmq.Context() + eventlet.sleep() + yield ctx + ctx.destroy() + + +@contextlib.contextmanager +def clean_pair(type1, type2, interface='tcp://127.0.0.1'): + with clean_context() as ctx: + s1 = ctx.socket(type1) + port = s1.bind_to_random_port(interface) + s2 = ctx.socket(type2) + s2.connect('{0}:{1}'.format(interface, port)) + eventlet.sleep() + yield (s1, s2, port) + s1.close() + s2.close() + + +@tests.skip_unless(zmq_supported) +def test_recv_json_no_args(): + # https://github.com/eventlet/eventlet/issues/376 + with clean_pair(zmq.REQ, zmq.REP) as (s1, s2, _): + eventlet.spawn(s1.send_json, {}) + s2.recv_json()