Merge branch 'pr_20_merge'

This commit is contained in:
Tyler Hobbs
2013-09-17 17:30:28 -05:00
2 changed files with 36 additions and 15 deletions

View File

@@ -1,6 +1,7 @@
from collections import defaultdict, deque
from functools import partial, wraps
import logging
import os
import socket
from threading import Event, Lock, Thread
import traceback
@@ -12,6 +13,10 @@ from cassandra.decoder import RegisterMessage
from cassandra.marshal import int32_unpack
import cassandra.io.libevwrapper as libev
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO # ignore flake8 warning: # NOQA
log = logging.getLogger(__name__)
@@ -74,7 +79,6 @@ class LibevConnection(Connection):
An implementation of :class:`.Connection` that utilizes libev.
"""
_buf = ""
_total_reqd_bytes = 0
_read_watcher = None
_write_watcher = None
@@ -93,6 +97,7 @@ class LibevConnection(Connection):
Connection.__init__(self, *args, **kwargs)
self.connected_event = Event()
self._iobuf = StringIO()
self._callbacks = {}
self._push_watchers = defaultdict(set)
@@ -197,23 +202,39 @@ class LibevConnection(Connection):
return
if buf:
self._buf += buf
self._iobuf.write(buf)
while True:
if len(self._buf) < 8:
# we don't have a complete header yet
break
elif self._total_reqd_bytes and len(self._buf) < self._total_reqd_bytes:
# we already saw a header, but we don't have a complete message yet
pos = self._iobuf.tell()
if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
# we don't have a complete header yet or we
# already saw a header, but we don't have a
# complete message yet
break
else:
body_len = int32_unpack(self._buf[4:8])
if len(self._buf) - 8 >= body_len:
msg = self._buf[:8 + body_len]
self._buf = self._buf[8 + body_len:]
# have enough for header, read body len from header
self._iobuf.seek(4)
body_len_bytes = self._iobuf.read(4)
body_len = int32_unpack(body_len_bytes)
# seek to end to get length of current buffer
self._iobuf.seek(0, os.SEEK_END)
pos = self._iobuf.tell()
if pos - 8 >= body_len:
# read message header and body
self._iobuf.seek(0)
msg = self._iobuf.read(8 + body_len)
# leave leftover in current buffer
leftover = self._iobuf.read()
self._iobuf = StringIO()
self._iobuf.write(leftover)
self._total_reqd_bytes = 0
self.process_msg(msg, body_len)
else:
self._total_reqd_bytes = body_len + 8
break
else:
log.debug("connection closed by server")
self.close()

View File

@@ -191,11 +191,11 @@ class LibevConnectionTest(unittest.TestCase):
# read in the first byte
c._socket.recv.return_value = message[0]
c.handle_read(None, None)
self.assertEquals(c._buf, message[0])
self.assertEquals(c._iobuf.getvalue(), message[0])
c._socket.recv.return_value = message[1:]
c.handle_read(None, None)
self.assertEquals("", c._buf)
self.assertEquals("", c._iobuf.getvalue())
# let it write out a StartupMessage
c.handle_write(None, None)
@@ -217,12 +217,12 @@ class LibevConnectionTest(unittest.TestCase):
# read in the first nine bytes
c._socket.recv.return_value = message[:9]
c.handle_read(None, None)
self.assertEquals(c._buf, message[:9])
self.assertEquals(c._iobuf.getvalue(), message[:9])
# ... then read in the rest
c._socket.recv.return_value = message[9:]
c.handle_read(None, None)
self.assertEquals("", c._buf)
self.assertEquals("", c._iobuf.getvalue())
# let it write out a StartupMessage
c.handle_write(None, None)