Correctly handle multi-message reads

Data at the end of the first message was accidentally
being discarded (by overwriting self._iobuf prior to
reading the remainder).  There's also some other minor
cleanup and reorg.
This commit is contained in:
Tyler Hobbs
2013-09-17 17:28:53 -05:00
parent de23a89b43
commit ca039729d9

View File

@@ -1,6 +1,7 @@
from collections import defaultdict, deque from collections import defaultdict, deque
from functools import partial, wraps from functools import partial, wraps
import logging import logging
import os
import socket import socket
from threading import Event, Lock, Thread from threading import Event, Lock, Thread
import traceback import traceback
@@ -12,7 +13,10 @@ from cassandra.decoder import RegisterMessage
from cassandra.marshal import int32_unpack from cassandra.marshal import int32_unpack
import cassandra.io.libevwrapper as libev import cassandra.io.libevwrapper as libev
import cStringIO try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO # ignore flake8 warning: # NOQA
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -93,7 +97,7 @@ class LibevConnection(Connection):
Connection.__init__(self, *args, **kwargs) Connection.__init__(self, *args, **kwargs)
self.connected_event = Event() self.connected_event = Event()
self._iobuf = cStringIO.StringIO() self._iobuf = StringIO()
self._callbacks = {} self._callbacks = {}
self._push_watchers = defaultdict(set) self._push_watchers = defaultdict(set)
@@ -201,11 +205,10 @@ class LibevConnection(Connection):
self._iobuf.write(buf) self._iobuf.write(buf)
while True: while True:
pos = self._iobuf.tell() pos = self._iobuf.tell()
if pos < 8: if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
# we don't have a complete header yet # we don't have a complete header yet or we
break # already saw a header, but we don't have a
elif self._total_reqd_bytes and self._iobuf.tell() < self._total_reqd_bytes: # complete message yet
# we already saw a header, but we don't have a complete message yet
break break
else: else:
# have enough for header, read body len from header # have enough for header, read body len from header
@@ -214,24 +217,24 @@ class LibevConnection(Connection):
body_len = int32_unpack(body_len_bytes) body_len = int32_unpack(body_len_bytes)
# seek to end to get length of current buffer # seek to end to get length of current buffer
self._iobuf.seek(0, 2) self._iobuf.seek(0, os.SEEK_END)
pos = self._iobuf.tell() pos = self._iobuf.tell()
if pos - 8 >= body_len: if pos - 8 >= body_len:
# read message header/body # read message header and body
self._iobuf.seek(0) self._iobuf.seek(0)
msg = self._iobuf.read(8 + body_len) msg = self._iobuf.read(8 + body_len)
# leave leftover in current buffer # leave leftover in current buffer
self._iobuf = cStringIO.StringIO() leftover = self._iobuf.read()
self._iobuf.write(self._iobuf.read()) self._iobuf = StringIO()
self._iobuf.write(leftover)
self._total_reqd_bytes = 0 self._total_reqd_bytes = 0
self.process_msg(msg, body_len) self.process_msg(msg, body_len)
else: else:
self._total_reqd_bytes = body_len + 8 self._total_reqd_bytes = body_len + 8
break
self._iobuf.seek(0, 2) # seek to end
else: else:
log.debug("connection closed by server") log.debug("connection closed by server")
self.close() self.close()