From 60ded354a9d0dbf39eeb8a549f1f05adcf97be7a Mon Sep 17 00:00:00 2001 From: Tyler Hobbs Date: Tue, 8 Jul 2014 14:03:11 -0500 Subject: [PATCH] Fix v3 proto header reading w/ gevent, twisted reactors Fixes PYTHON-87 --- cassandra/connection.py | 40 ++++++++++++++++++++++++++++++++- cassandra/io/asyncorereactor.py | 40 ++------------------------------- cassandra/io/geventreactor.py | 39 ++------------------------------ cassandra/io/libevreactor.py | 37 +----------------------------- cassandra/io/twistedreactor.py | 38 +------------------------------ 5 files changed, 45 insertions(+), 149 deletions(-) diff --git a/cassandra/connection.py b/cassandra/connection.py index ead99430..6aa47481 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import absolute_import # to enable import io from stdlib from collections import defaultdict, deque import errno from functools import wraps, partial +import io import logging +import os import sys from threading import Event, RLock import time @@ -29,7 +32,7 @@ import six from six.moves import range from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut -from cassandra.marshal import int32_pack, header_unpack, v3_header_unpack +from cassandra.marshal import int32_pack, header_unpack, v3_header_unpack, int32_unpack from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage, StartupMessage, ErrorMessage, CredentialsMessage, QueryMessage, ResultMessage, decode_response, @@ -170,6 +173,7 @@ class Connection(object): user_type_map = None is_control_connection = False + _iobuf = None def __init__(self, host='127.0.0.1', port=9042, authenticator=None, ssl_options=None, sockopts=None, compression=True, @@ -186,6 +190,7 @@ class Connection(object): self.is_control_connection = is_control_connection self.user_type_map = user_type_map self._push_watchers = defaultdict(set) + self._iobuf = io.BytesIO() if protocol_version >= 3: self._header_unpack = v3_header_unpack self._header_length = 5 @@ -344,6 +349,39 @@ class Connection(object): self.is_control_connection = False self._push_watchers = {} + def process_io_buffer(self): + while True: + pos = self._iobuf.tell() + if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes): + # we don't have a complete header yet or we + # already saw a header, but we don't have a + # complete message yet + return + else: + # have enough for header, read body len from header + self._iobuf.seek(self._header_length) + body_len = int32_unpack(self._iobuf.read(4)) + + # seek to end to get length of current buffer + self._iobuf.seek(0, os.SEEK_END) + pos = self._iobuf.tell() + + if pos >= body_len + self._full_header_length: + # read message header and body + self._iobuf.seek(0) + msg = self._iobuf.read(self._full_header_length + body_len) + + # leave leftover in current buffer + leftover = self._iobuf.read() + self._iobuf = io.BytesIO() + self._iobuf.write(leftover) + + self._total_reqd_bytes = 0 + self.process_msg(msg, body_len) + else: + self._total_reqd_bytes = body_len + self._full_header_length + return + @defunct_on_error def process_msg(self, msg, body_len): version, flags, stream_id, opcode = self._header_unpack(msg[:self._header_length]) diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index 7f7e5226..c5dd44fe 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -11,12 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import absolute_import # to enable import io from stdlib import atexit from collections import deque from functools import partial -import io import logging import os import socket @@ -43,7 +40,6 @@ from cassandra import OperationTimedOut from cassandra.connection import (Connection, ConnectionShutdown, ConnectionException, NONBLOCKING) from cassandra.protocol import RegisterMessage -from cassandra.marshal import int32_unpack log = logging.getLogger(__name__) @@ -178,7 +174,6 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): asyncore.dispatcher.__init__(self) self.connected_event = Event() - self._iobuf = io.BytesIO() self._callbacks = {} self.deque = deque() @@ -226,7 +221,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): self.socket.settimeout(1.0) err = self.socket.connect_ex(address) if err in (EINPROGRESS, EALREADY, EWOULDBLOCK) \ - or err == EINVAL and os.name in ('nt', 'ce'): + or err == EINVAL and os.name in ('nt', 'ce'): raise ConnectionException("Timed out connecting to %s" % (address[0])) if err in (0, EISCONN): self.addr = address @@ -308,38 +303,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): return if self._iobuf.tell(): - while True: - pos = self._iobuf.tell() - if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes): - # we don't have a complete header yet or we - # already saw a header, but we don't have a - # complete message yet - break - else: - # have enough for header, read body len from header - self._iobuf.seek(self._header_length) - body_len = int32_unpack(self._iobuf.read(4)) - - # seek to end to get length of current buffer - self._iobuf.seek(0, os.SEEK_END) - pos = self._iobuf.tell() - - if pos >= body_len + self._full_header_length: - # read message header and body - self._iobuf.seek(0) - msg = self._iobuf.read(self._full_header_length + body_len) - - # leave leftover in current buffer - leftover = self._iobuf.read() - self._iobuf = io.BytesIO() - self._iobuf.write(leftover) - - self._total_reqd_bytes = 0 - self.process_msg(msg, body_len) - else: - self._total_reqd_bytes = body_len + self._full_header_length - break - + self.process_io_buffer() if not self._callbacks and not self.is_control_connection: self._readable = False diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py index 93ebab82..0caaf014 100644 --- a/cassandra/io/geventreactor.py +++ b/cassandra/io/geventreactor.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import absolute_import # to enable import io from stdlib import gevent from gevent import select, socket from gevent.event import Event @@ -20,7 +18,6 @@ from gevent.queue import Queue from collections import defaultdict from functools import partial -import io import logging import os @@ -31,7 +28,6 @@ from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL from cassandra import OperationTimedOut from cassandra.connection import Connection, ConnectionShutdown from cassandra.protocol import RegisterMessage -from cassandra.marshal import int32_unpack log = logging.getLogger(__name__) @@ -71,7 +67,6 @@ class GeventConnection(Connection): Connection.__init__(self, *args, **kwargs) self.connected_event = Event() - self._iobuf = io.BytesIO() self._write_queue = Queue() self._callbacks = {} @@ -154,39 +149,9 @@ class GeventConnection(Connection): return # leave the read loop if self._iobuf.tell(): - while True: - pos = self._iobuf.tell() - if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes): - # we don't have a complete header yet or we - # already saw a header, but we don't have a - # complete message yet - break - else: - # have enough for header, read body len from header - self._iobuf.seek(4) - body_len = int32_unpack(self._iobuf.read(4)) - - # seek to end to get length of current buffer - self._iobuf.seek(0, os.SEEK_END) - pos = self._iobuf.tell() - - if pos >= body_len + 8: - # read message header and body - self._iobuf.seek(0) - msg = self._iobuf.read(8 + body_len) - - # leave leftover in current buffer - leftover = self._iobuf.read() - self._iobuf = io.BytesIO() - self._iobuf.write(leftover) - - self._total_reqd_bytes = 0 - self.process_msg(msg, body_len) - else: - self._total_reqd_bytes = body_len + 8 - break + self.process_io_buffer() else: - log.debug("connection closed by server") + log.debug("Connection %s closed by server", self) self.close() return diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py index eb95f5bc..5adf008f 100644 --- a/cassandra/io/libevreactor.py +++ b/cassandra/io/libevreactor.py @@ -11,12 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import absolute_import # to enable import io from stdlib import atexit from collections import deque from functools import partial -import io import logging import os import socket @@ -28,7 +25,6 @@ from six.moves import xrange from cassandra import OperationTimedOut from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING from cassandra.protocol import RegisterMessage -from cassandra.marshal import int32_unpack try: import cassandra.io.libevwrapper as libev except ImportError: @@ -257,7 +253,6 @@ class LibevConnection(Connection): Connection.__init__(self, *args, **kwargs) self.connected_event = Event() - self._iobuf = io.BytesIO() self._callbacks = {} self.deque = deque() @@ -359,37 +354,7 @@ class LibevConnection(Connection): return if self._iobuf.tell(): - while True: - pos = self._iobuf.tell() - if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes): - # we don't have a complete header yet or we - # already saw a header, but we don't have a - # complete message yet - break - else: - # have enough for header, read body len from header - self._iobuf.seek(self._header_length) - body_len = int32_unpack(self._iobuf.read(4)) - - # seek to end to get length of current buffer - self._iobuf.seek(0, os.SEEK_END) - pos = self._iobuf.tell() - - if pos >= body_len + self._full_header_length: - # read message header and body - self._iobuf.seek(0) - msg = self._iobuf.read(self._full_header_length + body_len) - - # leave leftover in current buffer - leftover = self._iobuf.read() - self._iobuf = io.BytesIO() - self._iobuf.write(leftover) - - self._total_reqd_bytes = 0 - self.process_msg(msg, body_len) - else: - self._total_reqd_bytes = body_len + self._full_header_length - break + self.process_io_buffer() else: log.debug("Connection %s closed by server", self) self.close() diff --git a/cassandra/io/twistedreactor.py b/cassandra/io/twistedreactor.py index 9ca401dd..ac7f97e4 100644 --- a/cassandra/io/twistedreactor.py +++ b/cassandra/io/twistedreactor.py @@ -21,14 +21,10 @@ from functools import partial import logging import weakref import atexit -import os - -from io import BytesIO from cassandra import OperationTimedOut from cassandra.connection import Connection, ConnectionShutdown from cassandra.protocol import RegisterMessage -from cassandra.marshal import int32_unpack log = logging.getLogger(__name__) @@ -182,7 +178,6 @@ class TwistedConnection(Connection): Connection.__init__(self, *args, **kwargs) self.connected_event = Event() - self._iobuf = BytesIO() self.is_closed = True self.connector = None @@ -231,38 +226,7 @@ class TwistedConnection(Connection): """ Process the incoming data buffer. """ - while True: - pos = self._iobuf.tell() - if pos < 8 or (self._total_reqd_bytes > 0 and - pos < self._total_reqd_bytes): - # we don't have a complete header yet or we - # already saw a header, but we don't have a - # complete message yet - return - else: - # have enough for header, read body len from header - self._iobuf.seek(4) - body_len = int32_unpack(self._iobuf.read(4)) - - # seek to end to get length of current buffer - self._iobuf.seek(0, os.SEEK_END) - pos = self._iobuf.tell() - - if pos >= body_len + 8: - # read message header and body - self._iobuf.seek(0) - msg = self._iobuf.read(8 + body_len) - - # leave leftover in current buffer - leftover = self._iobuf.read() - self._iobuf = BytesIO() - self._iobuf.write(leftover) - - self._total_reqd_bytes = 0 - self.process_msg(msg, body_len) - else: - self._total_reqd_bytes = body_len + 8 - return + self.process_io_buffer() def push(self, data): """