Fix v3 proto header reading w/ gevent, twisted reactors
Fixes PYTHON-87
This commit is contained in:
@@ -12,10 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import # to enable import io from stdlib
|
||||
from collections import defaultdict, deque
|
||||
import errno
|
||||
from functools import wraps, partial
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from threading import Event, RLock
|
||||
import time
|
||||
@@ -29,7 +32,7 @@ import six
|
||||
from six.moves import range
|
||||
|
||||
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
|
||||
from cassandra.marshal import int32_pack, header_unpack, v3_header_unpack
|
||||
from cassandra.marshal import int32_pack, header_unpack, v3_header_unpack, int32_unpack
|
||||
from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage,
|
||||
StartupMessage, ErrorMessage, CredentialsMessage,
|
||||
QueryMessage, ResultMessage, decode_response,
|
||||
@@ -170,6 +173,7 @@ class Connection(object):
|
||||
user_type_map = None
|
||||
|
||||
is_control_connection = False
|
||||
_iobuf = None
|
||||
|
||||
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
|
||||
ssl_options=None, sockopts=None, compression=True,
|
||||
@@ -186,6 +190,7 @@ class Connection(object):
|
||||
self.is_control_connection = is_control_connection
|
||||
self.user_type_map = user_type_map
|
||||
self._push_watchers = defaultdict(set)
|
||||
self._iobuf = io.BytesIO()
|
||||
if protocol_version >= 3:
|
||||
self._header_unpack = v3_header_unpack
|
||||
self._header_length = 5
|
||||
@@ -344,6 +349,39 @@ class Connection(object):
|
||||
self.is_control_connection = False
|
||||
self._push_watchers = {}
|
||||
|
||||
def process_io_buffer(self):
|
||||
while True:
|
||||
pos = self._iobuf.tell()
|
||||
if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
|
||||
# we don't have a complete header yet or we
|
||||
# already saw a header, but we don't have a
|
||||
# complete message yet
|
||||
return
|
||||
else:
|
||||
# have enough for header, read body len from header
|
||||
self._iobuf.seek(self._header_length)
|
||||
body_len = int32_unpack(self._iobuf.read(4))
|
||||
|
||||
# seek to end to get length of current buffer
|
||||
self._iobuf.seek(0, os.SEEK_END)
|
||||
pos = self._iobuf.tell()
|
||||
|
||||
if pos >= body_len + self._full_header_length:
|
||||
# read message header and body
|
||||
self._iobuf.seek(0)
|
||||
msg = self._iobuf.read(self._full_header_length + body_len)
|
||||
|
||||
# leave leftover in current buffer
|
||||
leftover = self._iobuf.read()
|
||||
self._iobuf = io.BytesIO()
|
||||
self._iobuf.write(leftover)
|
||||
|
||||
self._total_reqd_bytes = 0
|
||||
self.process_msg(msg, body_len)
|
||||
else:
|
||||
self._total_reqd_bytes = body_len + self._full_header_length
|
||||
return
|
||||
|
||||
@defunct_on_error
|
||||
def process_msg(self, msg, body_len):
|
||||
version, flags, stream_id, opcode = self._header_unpack(msg[:self._header_length])
|
||||
|
||||
@@ -11,12 +11,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import # to enable import io from stdlib
|
||||
import atexit
|
||||
from collections import deque
|
||||
from functools import partial
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
@@ -43,7 +40,6 @@ from cassandra import OperationTimedOut
|
||||
from cassandra.connection import (Connection, ConnectionShutdown,
|
||||
ConnectionException, NONBLOCKING)
|
||||
from cassandra.protocol import RegisterMessage
|
||||
from cassandra.marshal import int32_unpack
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -178,7 +174,6 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
||||
asyncore.dispatcher.__init__(self)
|
||||
|
||||
self.connected_event = Event()
|
||||
self._iobuf = io.BytesIO()
|
||||
|
||||
self._callbacks = {}
|
||||
self.deque = deque()
|
||||
@@ -226,7 +221,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
||||
self.socket.settimeout(1.0)
|
||||
err = self.socket.connect_ex(address)
|
||||
if err in (EINPROGRESS, EALREADY, EWOULDBLOCK) \
|
||||
or err == EINVAL and os.name in ('nt', 'ce'):
|
||||
or err == EINVAL and os.name in ('nt', 'ce'):
|
||||
raise ConnectionException("Timed out connecting to %s" % (address[0]))
|
||||
if err in (0, EISCONN):
|
||||
self.addr = address
|
||||
@@ -308,38 +303,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
||||
return
|
||||
|
||||
if self._iobuf.tell():
|
||||
while True:
|
||||
pos = self._iobuf.tell()
|
||||
if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
|
||||
# we don't have a complete header yet or we
|
||||
# already saw a header, but we don't have a
|
||||
# complete message yet
|
||||
break
|
||||
else:
|
||||
# have enough for header, read body len from header
|
||||
self._iobuf.seek(self._header_length)
|
||||
body_len = int32_unpack(self._iobuf.read(4))
|
||||
|
||||
# seek to end to get length of current buffer
|
||||
self._iobuf.seek(0, os.SEEK_END)
|
||||
pos = self._iobuf.tell()
|
||||
|
||||
if pos >= body_len + self._full_header_length:
|
||||
# read message header and body
|
||||
self._iobuf.seek(0)
|
||||
msg = self._iobuf.read(self._full_header_length + body_len)
|
||||
|
||||
# leave leftover in current buffer
|
||||
leftover = self._iobuf.read()
|
||||
self._iobuf = io.BytesIO()
|
||||
self._iobuf.write(leftover)
|
||||
|
||||
self._total_reqd_bytes = 0
|
||||
self.process_msg(msg, body_len)
|
||||
else:
|
||||
self._total_reqd_bytes = body_len + self._full_header_length
|
||||
break
|
||||
|
||||
self.process_io_buffer()
|
||||
if not self._callbacks and not self.is_control_connection:
|
||||
self._readable = False
|
||||
|
||||
|
||||
@@ -11,8 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import # to enable import io from stdlib
|
||||
import gevent
|
||||
from gevent import select, socket
|
||||
from gevent.event import Event
|
||||
@@ -20,7 +18,6 @@ from gevent.queue import Queue
|
||||
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
|
||||
@@ -31,7 +28,6 @@ from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL
|
||||
from cassandra import OperationTimedOut
|
||||
from cassandra.connection import Connection, ConnectionShutdown
|
||||
from cassandra.protocol import RegisterMessage
|
||||
from cassandra.marshal import int32_unpack
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -71,7 +67,6 @@ class GeventConnection(Connection):
|
||||
Connection.__init__(self, *args, **kwargs)
|
||||
|
||||
self.connected_event = Event()
|
||||
self._iobuf = io.BytesIO()
|
||||
self._write_queue = Queue()
|
||||
|
||||
self._callbacks = {}
|
||||
@@ -154,39 +149,9 @@ class GeventConnection(Connection):
|
||||
return # leave the read loop
|
||||
|
||||
if self._iobuf.tell():
|
||||
while True:
|
||||
pos = self._iobuf.tell()
|
||||
if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
|
||||
# we don't have a complete header yet or we
|
||||
# already saw a header, but we don't have a
|
||||
# complete message yet
|
||||
break
|
||||
else:
|
||||
# have enough for header, read body len from header
|
||||
self._iobuf.seek(4)
|
||||
body_len = int32_unpack(self._iobuf.read(4))
|
||||
|
||||
# seek to end to get length of current buffer
|
||||
self._iobuf.seek(0, os.SEEK_END)
|
||||
pos = self._iobuf.tell()
|
||||
|
||||
if pos >= body_len + 8:
|
||||
# read message header and body
|
||||
self._iobuf.seek(0)
|
||||
msg = self._iobuf.read(8 + body_len)
|
||||
|
||||
# leave leftover in current buffer
|
||||
leftover = self._iobuf.read()
|
||||
self._iobuf = io.BytesIO()
|
||||
self._iobuf.write(leftover)
|
||||
|
||||
self._total_reqd_bytes = 0
|
||||
self.process_msg(msg, body_len)
|
||||
else:
|
||||
self._total_reqd_bytes = body_len + 8
|
||||
break
|
||||
self.process_io_buffer()
|
||||
else:
|
||||
log.debug("connection closed by server")
|
||||
log.debug("Connection %s closed by server", self)
|
||||
self.close()
|
||||
return
|
||||
|
||||
|
||||
@@ -11,12 +11,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import # to enable import io from stdlib
|
||||
import atexit
|
||||
from collections import deque
|
||||
from functools import partial
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
@@ -28,7 +25,6 @@ from six.moves import xrange
|
||||
from cassandra import OperationTimedOut
|
||||
from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING
|
||||
from cassandra.protocol import RegisterMessage
|
||||
from cassandra.marshal import int32_unpack
|
||||
try:
|
||||
import cassandra.io.libevwrapper as libev
|
||||
except ImportError:
|
||||
@@ -257,7 +253,6 @@ class LibevConnection(Connection):
|
||||
Connection.__init__(self, *args, **kwargs)
|
||||
|
||||
self.connected_event = Event()
|
||||
self._iobuf = io.BytesIO()
|
||||
|
||||
self._callbacks = {}
|
||||
self.deque = deque()
|
||||
@@ -359,37 +354,7 @@ class LibevConnection(Connection):
|
||||
return
|
||||
|
||||
if self._iobuf.tell():
|
||||
while True:
|
||||
pos = self._iobuf.tell()
|
||||
if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
|
||||
# we don't have a complete header yet or we
|
||||
# already saw a header, but we don't have a
|
||||
# complete message yet
|
||||
break
|
||||
else:
|
||||
# have enough for header, read body len from header
|
||||
self._iobuf.seek(self._header_length)
|
||||
body_len = int32_unpack(self._iobuf.read(4))
|
||||
|
||||
# seek to end to get length of current buffer
|
||||
self._iobuf.seek(0, os.SEEK_END)
|
||||
pos = self._iobuf.tell()
|
||||
|
||||
if pos >= body_len + self._full_header_length:
|
||||
# read message header and body
|
||||
self._iobuf.seek(0)
|
||||
msg = self._iobuf.read(self._full_header_length + body_len)
|
||||
|
||||
# leave leftover in current buffer
|
||||
leftover = self._iobuf.read()
|
||||
self._iobuf = io.BytesIO()
|
||||
self._iobuf.write(leftover)
|
||||
|
||||
self._total_reqd_bytes = 0
|
||||
self.process_msg(msg, body_len)
|
||||
else:
|
||||
self._total_reqd_bytes = body_len + self._full_header_length
|
||||
break
|
||||
self.process_io_buffer()
|
||||
else:
|
||||
log.debug("Connection %s closed by server", self)
|
||||
self.close()
|
||||
|
||||
@@ -21,14 +21,10 @@ from functools import partial
|
||||
import logging
|
||||
import weakref
|
||||
import atexit
|
||||
import os
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from cassandra import OperationTimedOut
|
||||
from cassandra.connection import Connection, ConnectionShutdown
|
||||
from cassandra.protocol import RegisterMessage
|
||||
from cassandra.marshal import int32_unpack
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -182,7 +178,6 @@ class TwistedConnection(Connection):
|
||||
Connection.__init__(self, *args, **kwargs)
|
||||
|
||||
self.connected_event = Event()
|
||||
self._iobuf = BytesIO()
|
||||
self.is_closed = True
|
||||
self.connector = None
|
||||
|
||||
@@ -231,38 +226,7 @@ class TwistedConnection(Connection):
|
||||
"""
|
||||
Process the incoming data buffer.
|
||||
"""
|
||||
while True:
|
||||
pos = self._iobuf.tell()
|
||||
if pos < 8 or (self._total_reqd_bytes > 0 and
|
||||
pos < self._total_reqd_bytes):
|
||||
# we don't have a complete header yet or we
|
||||
# already saw a header, but we don't have a
|
||||
# complete message yet
|
||||
return
|
||||
else:
|
||||
# have enough for header, read body len from header
|
||||
self._iobuf.seek(4)
|
||||
body_len = int32_unpack(self._iobuf.read(4))
|
||||
|
||||
# seek to end to get length of current buffer
|
||||
self._iobuf.seek(0, os.SEEK_END)
|
||||
pos = self._iobuf.tell()
|
||||
|
||||
if pos >= body_len + 8:
|
||||
# read message header and body
|
||||
self._iobuf.seek(0)
|
||||
msg = self._iobuf.read(8 + body_len)
|
||||
|
||||
# leave leftover in current buffer
|
||||
leftover = self._iobuf.read()
|
||||
self._iobuf = BytesIO()
|
||||
self._iobuf.write(leftover)
|
||||
|
||||
self._total_reqd_bytes = 0
|
||||
self.process_msg(msg, body_len)
|
||||
else:
|
||||
self._total_reqd_bytes = body_len + 8
|
||||
return
|
||||
self.process_io_buffer()
|
||||
|
||||
def push(self, data):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user