Fix v3 proto header reading w/ gevent, twisted reactors

Fixes PYTHON-87
This commit is contained in:
Tyler Hobbs
2014-07-08 14:03:11 -05:00
parent 1fca5e0a02
commit 60ded354a9
5 changed files with 45 additions and 149 deletions

View File

@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import # to enable import io from stdlib
from collections import defaultdict, deque from collections import defaultdict, deque
import errno import errno
from functools import wraps, partial from functools import wraps, partial
import io
import logging import logging
import os
import sys import sys
from threading import Event, RLock from threading import Event, RLock
import time import time
@@ -29,7 +32,7 @@ import six
from six.moves import range from six.moves import range
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
from cassandra.marshal import int32_pack, header_unpack, v3_header_unpack from cassandra.marshal import int32_pack, header_unpack, v3_header_unpack, int32_unpack
from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage, from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage,
StartupMessage, ErrorMessage, CredentialsMessage, StartupMessage, ErrorMessage, CredentialsMessage,
QueryMessage, ResultMessage, decode_response, QueryMessage, ResultMessage, decode_response,
@@ -170,6 +173,7 @@ class Connection(object):
user_type_map = None user_type_map = None
is_control_connection = False is_control_connection = False
_iobuf = None
def __init__(self, host='127.0.0.1', port=9042, authenticator=None, def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
ssl_options=None, sockopts=None, compression=True, ssl_options=None, sockopts=None, compression=True,
@@ -186,6 +190,7 @@ class Connection(object):
self.is_control_connection = is_control_connection self.is_control_connection = is_control_connection
self.user_type_map = user_type_map self.user_type_map = user_type_map
self._push_watchers = defaultdict(set) self._push_watchers = defaultdict(set)
self._iobuf = io.BytesIO()
if protocol_version >= 3: if protocol_version >= 3:
self._header_unpack = v3_header_unpack self._header_unpack = v3_header_unpack
self._header_length = 5 self._header_length = 5
@@ -344,6 +349,39 @@ class Connection(object):
self.is_control_connection = False self.is_control_connection = False
self._push_watchers = {} self._push_watchers = {}
def process_io_buffer(self):
while True:
pos = self._iobuf.tell()
if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
# we don't have a complete header yet or we
# already saw a header, but we don't have a
# complete message yet
return
else:
# have enough for header, read body len from header
self._iobuf.seek(self._header_length)
body_len = int32_unpack(self._iobuf.read(4))
# seek to end to get length of current buffer
self._iobuf.seek(0, os.SEEK_END)
pos = self._iobuf.tell()
if pos >= body_len + self._full_header_length:
# read message header and body
self._iobuf.seek(0)
msg = self._iobuf.read(self._full_header_length + body_len)
# leave leftover in current buffer
leftover = self._iobuf.read()
self._iobuf = io.BytesIO()
self._iobuf.write(leftover)
self._total_reqd_bytes = 0
self.process_msg(msg, body_len)
else:
self._total_reqd_bytes = body_len + self._full_header_length
return
@defunct_on_error @defunct_on_error
def process_msg(self, msg, body_len): def process_msg(self, msg, body_len):
version, flags, stream_id, opcode = self._header_unpack(msg[:self._header_length]) version, flags, stream_id, opcode = self._header_unpack(msg[:self._header_length])

View File

@@ -11,12 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import # to enable import io from stdlib
import atexit import atexit
from collections import deque from collections import deque
from functools import partial from functools import partial
import io
import logging import logging
import os import os
import socket import socket
@@ -43,7 +40,6 @@ from cassandra import OperationTimedOut
from cassandra.connection import (Connection, ConnectionShutdown, from cassandra.connection import (Connection, ConnectionShutdown,
ConnectionException, NONBLOCKING) ConnectionException, NONBLOCKING)
from cassandra.protocol import RegisterMessage from cassandra.protocol import RegisterMessage
from cassandra.marshal import int32_unpack
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -178,7 +174,6 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
asyncore.dispatcher.__init__(self) asyncore.dispatcher.__init__(self)
self.connected_event = Event() self.connected_event = Event()
self._iobuf = io.BytesIO()
self._callbacks = {} self._callbacks = {}
self.deque = deque() self.deque = deque()
@@ -226,7 +221,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
self.socket.settimeout(1.0) self.socket.settimeout(1.0)
err = self.socket.connect_ex(address) err = self.socket.connect_ex(address)
if err in (EINPROGRESS, EALREADY, EWOULDBLOCK) \ if err in (EINPROGRESS, EALREADY, EWOULDBLOCK) \
or err == EINVAL and os.name in ('nt', 'ce'): or err == EINVAL and os.name in ('nt', 'ce'):
raise ConnectionException("Timed out connecting to %s" % (address[0])) raise ConnectionException("Timed out connecting to %s" % (address[0]))
if err in (0, EISCONN): if err in (0, EISCONN):
self.addr = address self.addr = address
@@ -308,38 +303,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
return return
if self._iobuf.tell(): if self._iobuf.tell():
while True: self.process_io_buffer()
pos = self._iobuf.tell()
if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
# we don't have a complete header yet or we
# already saw a header, but we don't have a
# complete message yet
break
else:
# have enough for header, read body len from header
self._iobuf.seek(self._header_length)
body_len = int32_unpack(self._iobuf.read(4))
# seek to end to get length of current buffer
self._iobuf.seek(0, os.SEEK_END)
pos = self._iobuf.tell()
if pos >= body_len + self._full_header_length:
# read message header and body
self._iobuf.seek(0)
msg = self._iobuf.read(self._full_header_length + body_len)
# leave leftover in current buffer
leftover = self._iobuf.read()
self._iobuf = io.BytesIO()
self._iobuf.write(leftover)
self._total_reqd_bytes = 0
self.process_msg(msg, body_len)
else:
self._total_reqd_bytes = body_len + self._full_header_length
break
if not self._callbacks and not self.is_control_connection: if not self._callbacks and not self.is_control_connection:
self._readable = False self._readable = False

View File

@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import # to enable import io from stdlib
import gevent import gevent
from gevent import select, socket from gevent import select, socket
from gevent.event import Event from gevent.event import Event
@@ -20,7 +18,6 @@ from gevent.queue import Queue
from collections import defaultdict from collections import defaultdict
from functools import partial from functools import partial
import io
import logging import logging
import os import os
@@ -31,7 +28,6 @@ from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL
from cassandra import OperationTimedOut from cassandra import OperationTimedOut
from cassandra.connection import Connection, ConnectionShutdown from cassandra.connection import Connection, ConnectionShutdown
from cassandra.protocol import RegisterMessage from cassandra.protocol import RegisterMessage
from cassandra.marshal import int32_unpack
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -71,7 +67,6 @@ class GeventConnection(Connection):
Connection.__init__(self, *args, **kwargs) Connection.__init__(self, *args, **kwargs)
self.connected_event = Event() self.connected_event = Event()
self._iobuf = io.BytesIO()
self._write_queue = Queue() self._write_queue = Queue()
self._callbacks = {} self._callbacks = {}
@@ -154,39 +149,9 @@ class GeventConnection(Connection):
return # leave the read loop return # leave the read loop
if self._iobuf.tell(): if self._iobuf.tell():
while True: self.process_io_buffer()
pos = self._iobuf.tell()
if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
# we don't have a complete header yet or we
# already saw a header, but we don't have a
# complete message yet
break
else:
# have enough for header, read body len from header
self._iobuf.seek(4)
body_len = int32_unpack(self._iobuf.read(4))
# seek to end to get length of current buffer
self._iobuf.seek(0, os.SEEK_END)
pos = self._iobuf.tell()
if pos >= body_len + 8:
# read message header and body
self._iobuf.seek(0)
msg = self._iobuf.read(8 + body_len)
# leave leftover in current buffer
leftover = self._iobuf.read()
self._iobuf = io.BytesIO()
self._iobuf.write(leftover)
self._total_reqd_bytes = 0
self.process_msg(msg, body_len)
else:
self._total_reqd_bytes = body_len + 8
break
else: else:
log.debug("connection closed by server") log.debug("Connection %s closed by server", self)
self.close() self.close()
return return

View File

@@ -11,12 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import # to enable import io from stdlib
import atexit import atexit
from collections import deque from collections import deque
from functools import partial from functools import partial
import io
import logging import logging
import os import os
import socket import socket
@@ -28,7 +25,6 @@ from six.moves import xrange
from cassandra import OperationTimedOut from cassandra import OperationTimedOut
from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING
from cassandra.protocol import RegisterMessage from cassandra.protocol import RegisterMessage
from cassandra.marshal import int32_unpack
try: try:
import cassandra.io.libevwrapper as libev import cassandra.io.libevwrapper as libev
except ImportError: except ImportError:
@@ -257,7 +253,6 @@ class LibevConnection(Connection):
Connection.__init__(self, *args, **kwargs) Connection.__init__(self, *args, **kwargs)
self.connected_event = Event() self.connected_event = Event()
self._iobuf = io.BytesIO()
self._callbacks = {} self._callbacks = {}
self.deque = deque() self.deque = deque()
@@ -359,37 +354,7 @@ class LibevConnection(Connection):
return return
if self._iobuf.tell(): if self._iobuf.tell():
while True: self.process_io_buffer()
pos = self._iobuf.tell()
if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
# we don't have a complete header yet or we
# already saw a header, but we don't have a
# complete message yet
break
else:
# have enough for header, read body len from header
self._iobuf.seek(self._header_length)
body_len = int32_unpack(self._iobuf.read(4))
# seek to end to get length of current buffer
self._iobuf.seek(0, os.SEEK_END)
pos = self._iobuf.tell()
if pos >= body_len + self._full_header_length:
# read message header and body
self._iobuf.seek(0)
msg = self._iobuf.read(self._full_header_length + body_len)
# leave leftover in current buffer
leftover = self._iobuf.read()
self._iobuf = io.BytesIO()
self._iobuf.write(leftover)
self._total_reqd_bytes = 0
self.process_msg(msg, body_len)
else:
self._total_reqd_bytes = body_len + self._full_header_length
break
else: else:
log.debug("Connection %s closed by server", self) log.debug("Connection %s closed by server", self)
self.close() self.close()

View File

@@ -21,14 +21,10 @@ from functools import partial
import logging import logging
import weakref import weakref
import atexit import atexit
import os
from io import BytesIO
from cassandra import OperationTimedOut from cassandra import OperationTimedOut
from cassandra.connection import Connection, ConnectionShutdown from cassandra.connection import Connection, ConnectionShutdown
from cassandra.protocol import RegisterMessage from cassandra.protocol import RegisterMessage
from cassandra.marshal import int32_unpack
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -182,7 +178,6 @@ class TwistedConnection(Connection):
Connection.__init__(self, *args, **kwargs) Connection.__init__(self, *args, **kwargs)
self.connected_event = Event() self.connected_event = Event()
self._iobuf = BytesIO()
self.is_closed = True self.is_closed = True
self.connector = None self.connector = None
@@ -231,38 +226,7 @@ class TwistedConnection(Connection):
""" """
Process the incoming data buffer. Process the incoming data buffer.
""" """
while True: self.process_io_buffer()
pos = self._iobuf.tell()
if pos < 8 or (self._total_reqd_bytes > 0 and
pos < self._total_reqd_bytes):
# we don't have a complete header yet or we
# already saw a header, but we don't have a
# complete message yet
return
else:
# have enough for header, read body len from header
self._iobuf.seek(4)
body_len = int32_unpack(self._iobuf.read(4))
# seek to end to get length of current buffer
self._iobuf.seek(0, os.SEEK_END)
pos = self._iobuf.tell()
if pos >= body_len + 8:
# read message header and body
self._iobuf.seek(0)
msg = self._iobuf.read(8 + body_len)
# leave leftover in current buffer
leftover = self._iobuf.read()
self._iobuf = BytesIO()
self._iobuf.write(leftover)
self._total_reqd_bytes = 0
self.process_msg(msg, body_len)
else:
self._total_reqd_bytes = body_len + 8
return
def push(self, data): def push(self, data):
""" """