fixed missed read events on zmq socked FD

This commit is contained in:
Geoff Salmon
2011-09-05 14:38:04 -04:00
parent 77157811d4
commit 291282e785
2 changed files with 152 additions and 123 deletions

View File

@@ -46,8 +46,15 @@ class _Context(__zmq__.Context):
# see http://api.zeromq.org/2-1:zmq-socket for explanation of socket types
_multi_reader_types = set([__zmq__.XREP, __zmq__.XREQ, __zmq__.SUB, __zmq__.PULL, __zmq__.PAIR])
_multi_writer_types = set([__zmq__.XREP, __zmq__.XREQ, __zmq__.PUB, __zmq__.PUSH, __zmq__.PAIR])
_multi_reader_types = set([__zmq__.SUB, __zmq__.PULL, __zmq__.PAIR])
_multi_writer_types = set([__zmq__.PUB, __zmq__.PUSH, __zmq__.PAIR])
try:
_multi_reader_types.update([__zmq__.XREP, __zmq__.XREQ])
_multi_writer_types.update([__zmq__.XREP, __zmq__.XREQ])
except AttributeError:
# XREP and XREQ are being renamed ROUTER and DEALER
_multi_reader_types.update([__zmq__.ROUTER, __zmq__.DEALER])
_multi_writer_types.update([__zmq__.ROUTER, __zmq__.DEALER])
_disable_send_types = set([__zmq__.SUB, __zmq__.PULL])
_disable_recv_types = set([__zmq__.PUB, __zmq__.PUSH])
@@ -82,6 +89,12 @@ class Socket(__zmq__.Socket):
self._writers = None
self._readers = None
self._blocked_thread = None
self._wakeup_timer = None
self._super_getsockopt = super(Socket, self).getsockopt
self._fd = self._super_getsockopt(__zmq__.FD)
# customize send and recv functions based on socket type
if socket_type in _multi_writer_types:
# support multiple greenthreads writing at the same time
@@ -106,7 +119,7 @@ class Socket(__zmq__.Socket):
are none it will trampoline and when coming back check
for the events.
"""
events = self.getsockopt(__zmq__.EVENTS)
events = self._super_getsockopt(__zmq__.EVENTS)
if read and (events & __zmq__.POLLIN):
return events
@@ -114,8 +127,8 @@ class Socket(__zmq__.Socket):
return events
else:
# ONLY trampoline on read events for the zmq FD
trampoline(self.getsockopt(__zmq__.FD), read=True)
return self.getsockopt(__zmq__.EVENTS)
trampoline(self._fd, read=True)
return self._super_getsockopt(__zmq__.EVENTS)
def send(self, msg, flags=0, copy=True, track=False):
"""
@@ -156,32 +169,27 @@ class Socket(__zmq__.Socket):
if e.errno != EAGAIN:
raise
def getsockopt(self, option):
result = self._super_getsockopt(option)
if option == __zmq__.EVENTS:
# Getting the events causes the zmq socket to process
# events which may mean a msg can be sent or received. If
# there is a greenthread blocked and waiting for events,
# it will miss the edge-triggered read event, so wake it
# up.
if self._blocked_thread is not None:
if (self._readers and (result & __zmq__.POLLIN)) or \
(self._writers and (result & __zmq__.POLLOUT)):
self._wake_listener()
return result
def _send_not_supported(self, msg, flags=0, copy=True, track=False):
raise __zmq__.ZMQError(__zmq__.ENOTSUP)
def _recv_not_supported(self, flags=0, copy=True, track=False):
raise __zmq__.ZMQError(__zmq__.ENOTSUP)
def _xsafe_sock_wait(self, read=False, write=False):
"""
First checks if there are events in the socket, to avoid
edge trigger problems with race conditions. Then if there
are none it will trampoline and when coming back check
for the events.
"""
events = self.getsockopt(__zmq__.EVENTS)
if read and (events & __zmq__.POLLIN):
return events
elif write and (events & __zmq__.POLLOUT):
return events
else:
# ONLY trampoline on read events for the zmq FD
trampoline(self.getsockopt(__zmq__.FD), read=True)
return self.getsockopt(__zmq__.EVENTS)
def _xsafe_send(self, msg, flags=0, copy=True, track=False):
"""
A send method that's safe to use when multiple greenthreads
@@ -189,9 +197,12 @@ class Socket(__zmq__.Socket):
the same socket.
"""
if flags & __zmq__.NOBLOCK:
return super(Socket, self).send(msg, flags=flags, copy=copy, track=track)
raise __zmq__.ZMQError(__zmq__.ENOTSUP)
result = super(Socket, self).send(msg, flags=flags, copy=copy, track=track)
self._wake_listener()
return result
self._xsafe_inner_send(msg, flags, copy, track)
return self._xsafe_inner_send(msg, flags, copy, track)
def _xsafe_send_multipart(self, msg_parts, flags=0, copy=True, track=False):
"""
@@ -203,35 +214,25 @@ class Socket(__zmq__.Socket):
"""
if flags & __zmq__.NOBLOCK:
return super(Socket, self).send_multipart(msg_parts, flags=flags, copy=copy, track=track)
raise __zmq__.ZMQError(__zmq__.ENOTSUP)
result = super(Socket, self).send_multipart(msg_parts, flags=flags, copy=copy, track=track)
self._wake_listener()
return result
self._xsafe_inner_send(list(msg_parts), flags, copy, track)
return self._xsafe_inner_send(list(msg_parts), flags, copy, track)
def _xsafe_inner_send(self, msg, flags, copy, track):
is_listening = bool(self._writers or self._readers)
self._writers.append((greenlet.getcurrent(), msg, flags | __zmq__.NOBLOCK, copy, track))
if is_listening:
# Other readers or writers are blocked. If this is the
# first writer, it may be possible to send immediately
if len(self._writers) == 1:
# TODO: Check EVENTS first?
result = self._send_queued()
if not self._writers:
# success!
return result
# other readers or writers are blocked so this greenthread must wait its turn
result = hubs.get_hub().switch()
if result is False:
# msg was not yet sent, but this thread was woken up
# so that it could process the queues
return self._process_queues()
else:
# msg was sent by another greenthread
if len(self._writers) == 1:
# no other waiting writers, may be able to send immediately
result = self._send_queued()
if not self._writers:
# received message
self._wake_listener()
return result
else:
return self._process_queues()
return self._inner_send_recv()
def _xsafe_recv(self, flags=0, copy=True, track=False):
"""
@@ -241,7 +242,10 @@ class Socket(__zmq__.Socket):
"""
if flags & __zmq__.NOBLOCK:
return super(Socket, self).recv(flags=flags, copy=copy, track=track)
raise __zmq__.ZMQError(__zmq__.ENOTSUP)
msg = super(Socket, self).recv(flags=flags, copy=copy, track=track)
self._wake_listener()
return msg
return self._xsafe_inner_recv(False, flags, copy, track)
@@ -252,37 +256,46 @@ class Socket(__zmq__.Socket):
the same socket.
"""
if flags & __zmq__.NOBLOCK:
return super(Socket, self).recv_multipart(flags=flags, copy=copy, track=track)
raise __zmq__.ZMQError(__zmq__.ENOTSUP)
msg = super(Socket, self).recv_multipart(flags=flags, copy=copy, track=track)
self._wake_listener()
return msg
return self._xsafe_inner_recv(True, flags, copy, track)
def _xsafe_inner_recv(self, multi, flags, copy, track):
is_listening = bool(self._writers or self._readers)
self._readers.append((greenlet.getcurrent(), multi, flags | __zmq__.NOBLOCK, copy, track))
if is_listening:
# Other readers or writers are blocked. If this is the
# first reader, it may be possible to recv immediately
if len(self._readers) == 1:
# TODO: Check EVENTS first?
result = self._recv_queued()
if not self._readers:
# success!
return result
# other readers or writers are blocked so this greenthread must wait its turn
result = hubs.get_hub().switch()
if result is False:
# msg was not yet received, but this thread was woken up
# so that it could process the queues
return self._process_queues()
else:
# msg was received
if len(self._readers) == 1:
# no other waiting readers, may be able to recv immediately
result = self._recv_queued()
if result is not None:
# received message
self._wake_listener()
return result
else:
return self._process_queues()
return self._inner_send_recv()
def _inner_send_recv(self):
if self._wake_listener():
# Another greenthread is listening on the FD. Block this one.
result = hubs.get_hub().switch()
if result is not False:
# msg was sent or received
return result
# Send or recv has not been done, but this thread was
# woken up so that it could process the queues
return self._process_queues()
def _wake_listener(self):
is_listener = self._blocked_thread is not None
if is_listener and self._wakeup_timer is None:
self._wakeup_timer = hubs.get_hub().schedule_call_global(0, self._blocked_thread.switch)
return True
return is_listener
def _process_queues(self):
""" If there are readers or writers queued, this method tries
@@ -290,17 +303,15 @@ class Socket(__zmq__.Socket):
either in this greenthread or in another one. """
readers = self._readers
writers = self._writers
current = greenlet.getcurrent()
send_result = None
recv_result = None
result = None
while True:
events = self.getsockopt(__zmq__.EVENTS)
try:
if readers and (events & __zmq__.POLLIN):
recv_result = self._recv_queued()
if writers and (events & __zmq__.POLLOUT):
send_result = self._send_queued()
if readers:
result = self._recv_queued() or result
if writers:
result = self._send_queued() or result
except (SystemExit, KeyboardInterrupt):
raise
except:
@@ -312,20 +323,34 @@ class Socket(__zmq__.Socket):
hubs.get_hub().schedule_call_global(0, writers[0][0].switch, False)
raise
events = self._super_getsockopt(__zmq__.EVENTS)
if (readers and (events & __zmq__.POLLIN)) or \
(writers and (events & __zmq__.POLLOUT)):
# more work to do
continue
# send and recv cannot continue right now. If there are
# more readers or writers queued, either trampoline or
# wake another greenthread.
current = greenlet.getcurrent()
if (readers and readers[0][0] is current) or (writers and writers[0][0] is current):
if (readers and readers[0][0] is current) or \
(writers and writers[0][0] is current):
# Only trampoline if this thread is the next reader or writer,
# and ONLY trampoline on read events for zmq FDs.
trampoline(self.getsockopt(__zmq__.FD), read=True)
try:
self._blocked_thread = current
trampoline(self._fd, read=True)
finally:
if self._wakeup_timer is not None:
self._wakeup_timer.cancel()
self._wakeup_timer = None
self._blocked_thread = None
else:
if readers:
hubs.get_hub().schedule_call_global(0, readers[0][0].switch, False)
elif writers:
hubs.get_hub().schedule_call_global(0, writers[0][0].switch, False)
return send_result or recv_result
return result
def _send_queued(self):

View File

@@ -242,10 +242,12 @@ got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)))
def test_send_during_recv(self):
sender, receiver, port = self.create_bound_pair(zmq.XREQ, zmq.XREQ)
sleep()
done = event.Event()
def slow_rx():
self.assertEqual(sender.recv(), "done")
num_recvs = 30
done_evts = [event.Event() for _ in range(num_recvs)]
def slow_rx(done, msg):
self.assertEqual(sender.recv(), msg)
done.send(0)
def tx():
@@ -253,49 +255,51 @@ got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)))
while tx_i <= 1000:
sender.send(str(tx_i))
tx_i += 1
def rx():
while True:
rx_i = receiver.recv()
if rx_i == "1000":
receiver.send('done')
for i in range(num_recvs):
receiver.send('done%d' % i)
sleep()
return
spawn(slow_rx)
for i in range(num_recvs):
spawn(slow_rx, done_evts[i], "done%d" % i)
spawn(tx)
spawn(rx)
for i in range(num_recvs):
final_i = done_evts[i].wait()
self.assertEqual(final_i, 0)
# Need someway to ensure a thread is blocked on send... This isn't working
@skip_unless(zmq_supported)
def test_recv_during_send(self):
sender, receiver, port = self.create_bound_pair(zmq.XREQ, zmq.XREQ)
sleep()
num_recvs = 30
done = event.Event()
sender.setsockopt(zmq.HWM, 10)
sender.setsockopt(zmq.SNDBUF, 10)
receiver.setsockopt(zmq.RCVBUF, 10)
def tx():
tx_i = 0
while tx_i <= 1000:
sender.send(str(tx_i))
tx_i += 1
done.send(0)
spawn(tx)
final_i = done.wait()
self.assertEqual(final_i, 0)
# Need someway to ensure a thread is blocked on send. This method
# below uses too much memory. Try adjust watermarks or other
# socket opts?
# @skip_unless(zmq_supported)
# def test_recv_during_send(self):
# sender, receiver, port = self.create_bound_pair(zmq.XREQ, zmq.XREQ)
# sleep()
# done = event.Event()
# def tx():
# msg = "0" * 1024
# while True:
# sender.send(msg)
# def rx():
# self.assertEqual(sender.recv(), "done")
# sender_thread.kill()
# done.send(0)
# def single_tx():
# receiver.send("done")
# sender_thread = spawn(tx)
# sleep()
# spawn(rx)
# spawn(single_tx)
# final_i = done.wait()
# self.assertEqual(final_i, 0)
class TestThreadedContextAccess(TestCase):
"""zmq's Context must be unique within a hub