Fix v3 proto header reading w/ gevent, twisted reactors
Fixes PYTHON-87
This commit is contained in:
@@ -12,10 +12,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import # to enable import io from stdlib
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
import errno
|
import errno
|
||||||
from functools import wraps, partial
|
from functools import wraps, partial
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
from threading import Event, RLock
|
from threading import Event, RLock
|
||||||
import time
|
import time
|
||||||
@@ -29,7 +32,7 @@ import six
|
|||||||
from six.moves import range
|
from six.moves import range
|
||||||
|
|
||||||
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
|
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
|
||||||
from cassandra.marshal import int32_pack, header_unpack, v3_header_unpack
|
from cassandra.marshal import int32_pack, header_unpack, v3_header_unpack, int32_unpack
|
||||||
from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage,
|
from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage,
|
||||||
StartupMessage, ErrorMessage, CredentialsMessage,
|
StartupMessage, ErrorMessage, CredentialsMessage,
|
||||||
QueryMessage, ResultMessage, decode_response,
|
QueryMessage, ResultMessage, decode_response,
|
||||||
@@ -170,6 +173,7 @@ class Connection(object):
|
|||||||
user_type_map = None
|
user_type_map = None
|
||||||
|
|
||||||
is_control_connection = False
|
is_control_connection = False
|
||||||
|
_iobuf = None
|
||||||
|
|
||||||
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
|
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
|
||||||
ssl_options=None, sockopts=None, compression=True,
|
ssl_options=None, sockopts=None, compression=True,
|
||||||
@@ -186,6 +190,7 @@ class Connection(object):
|
|||||||
self.is_control_connection = is_control_connection
|
self.is_control_connection = is_control_connection
|
||||||
self.user_type_map = user_type_map
|
self.user_type_map = user_type_map
|
||||||
self._push_watchers = defaultdict(set)
|
self._push_watchers = defaultdict(set)
|
||||||
|
self._iobuf = io.BytesIO()
|
||||||
if protocol_version >= 3:
|
if protocol_version >= 3:
|
||||||
self._header_unpack = v3_header_unpack
|
self._header_unpack = v3_header_unpack
|
||||||
self._header_length = 5
|
self._header_length = 5
|
||||||
@@ -344,6 +349,39 @@ class Connection(object):
|
|||||||
self.is_control_connection = False
|
self.is_control_connection = False
|
||||||
self._push_watchers = {}
|
self._push_watchers = {}
|
||||||
|
|
||||||
|
def process_io_buffer(self):
|
||||||
|
while True:
|
||||||
|
pos = self._iobuf.tell()
|
||||||
|
if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
|
||||||
|
# we don't have a complete header yet or we
|
||||||
|
# already saw a header, but we don't have a
|
||||||
|
# complete message yet
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# have enough for header, read body len from header
|
||||||
|
self._iobuf.seek(self._header_length)
|
||||||
|
body_len = int32_unpack(self._iobuf.read(4))
|
||||||
|
|
||||||
|
# seek to end to get length of current buffer
|
||||||
|
self._iobuf.seek(0, os.SEEK_END)
|
||||||
|
pos = self._iobuf.tell()
|
||||||
|
|
||||||
|
if pos >= body_len + self._full_header_length:
|
||||||
|
# read message header and body
|
||||||
|
self._iobuf.seek(0)
|
||||||
|
msg = self._iobuf.read(self._full_header_length + body_len)
|
||||||
|
|
||||||
|
# leave leftover in current buffer
|
||||||
|
leftover = self._iobuf.read()
|
||||||
|
self._iobuf = io.BytesIO()
|
||||||
|
self._iobuf.write(leftover)
|
||||||
|
|
||||||
|
self._total_reqd_bytes = 0
|
||||||
|
self.process_msg(msg, body_len)
|
||||||
|
else:
|
||||||
|
self._total_reqd_bytes = body_len + self._full_header_length
|
||||||
|
return
|
||||||
|
|
||||||
@defunct_on_error
|
@defunct_on_error
|
||||||
def process_msg(self, msg, body_len):
|
def process_msg(self, msg, body_len):
|
||||||
version, flags, stream_id, opcode = self._header_unpack(msg[:self._header_length])
|
version, flags, stream_id, opcode = self._header_unpack(msg[:self._header_length])
|
||||||
|
|||||||
@@ -11,12 +11,9 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from __future__ import absolute_import # to enable import io from stdlib
|
|
||||||
import atexit
|
import atexit
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
@@ -43,7 +40,6 @@ from cassandra import OperationTimedOut
|
|||||||
from cassandra.connection import (Connection, ConnectionShutdown,
|
from cassandra.connection import (Connection, ConnectionShutdown,
|
||||||
ConnectionException, NONBLOCKING)
|
ConnectionException, NONBLOCKING)
|
||||||
from cassandra.protocol import RegisterMessage
|
from cassandra.protocol import RegisterMessage
|
||||||
from cassandra.marshal import int32_unpack
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -178,7 +174,6 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
|||||||
asyncore.dispatcher.__init__(self)
|
asyncore.dispatcher.__init__(self)
|
||||||
|
|
||||||
self.connected_event = Event()
|
self.connected_event = Event()
|
||||||
self._iobuf = io.BytesIO()
|
|
||||||
|
|
||||||
self._callbacks = {}
|
self._callbacks = {}
|
||||||
self.deque = deque()
|
self.deque = deque()
|
||||||
@@ -226,7 +221,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
|||||||
self.socket.settimeout(1.0)
|
self.socket.settimeout(1.0)
|
||||||
err = self.socket.connect_ex(address)
|
err = self.socket.connect_ex(address)
|
||||||
if err in (EINPROGRESS, EALREADY, EWOULDBLOCK) \
|
if err in (EINPROGRESS, EALREADY, EWOULDBLOCK) \
|
||||||
or err == EINVAL and os.name in ('nt', 'ce'):
|
or err == EINVAL and os.name in ('nt', 'ce'):
|
||||||
raise ConnectionException("Timed out connecting to %s" % (address[0]))
|
raise ConnectionException("Timed out connecting to %s" % (address[0]))
|
||||||
if err in (0, EISCONN):
|
if err in (0, EISCONN):
|
||||||
self.addr = address
|
self.addr = address
|
||||||
@@ -308,38 +303,7 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self._iobuf.tell():
|
if self._iobuf.tell():
|
||||||
while True:
|
self.process_io_buffer()
|
||||||
pos = self._iobuf.tell()
|
|
||||||
if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
|
|
||||||
# we don't have a complete header yet or we
|
|
||||||
# already saw a header, but we don't have a
|
|
||||||
# complete message yet
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# have enough for header, read body len from header
|
|
||||||
self._iobuf.seek(self._header_length)
|
|
||||||
body_len = int32_unpack(self._iobuf.read(4))
|
|
||||||
|
|
||||||
# seek to end to get length of current buffer
|
|
||||||
self._iobuf.seek(0, os.SEEK_END)
|
|
||||||
pos = self._iobuf.tell()
|
|
||||||
|
|
||||||
if pos >= body_len + self._full_header_length:
|
|
||||||
# read message header and body
|
|
||||||
self._iobuf.seek(0)
|
|
||||||
msg = self._iobuf.read(self._full_header_length + body_len)
|
|
||||||
|
|
||||||
# leave leftover in current buffer
|
|
||||||
leftover = self._iobuf.read()
|
|
||||||
self._iobuf = io.BytesIO()
|
|
||||||
self._iobuf.write(leftover)
|
|
||||||
|
|
||||||
self._total_reqd_bytes = 0
|
|
||||||
self.process_msg(msg, body_len)
|
|
||||||
else:
|
|
||||||
self._total_reqd_bytes = body_len + self._full_header_length
|
|
||||||
break
|
|
||||||
|
|
||||||
if not self._callbacks and not self.is_control_connection:
|
if not self._callbacks and not self.is_control_connection:
|
||||||
self._readable = False
|
self._readable = False
|
||||||
|
|
||||||
|
|||||||
@@ -11,8 +11,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from __future__ import absolute_import # to enable import io from stdlib
|
|
||||||
import gevent
|
import gevent
|
||||||
from gevent import select, socket
|
from gevent import select, socket
|
||||||
from gevent.event import Event
|
from gevent.event import Event
|
||||||
@@ -20,7 +18,6 @@ from gevent.queue import Queue
|
|||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -31,7 +28,6 @@ from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL
|
|||||||
from cassandra import OperationTimedOut
|
from cassandra import OperationTimedOut
|
||||||
from cassandra.connection import Connection, ConnectionShutdown
|
from cassandra.connection import Connection, ConnectionShutdown
|
||||||
from cassandra.protocol import RegisterMessage
|
from cassandra.protocol import RegisterMessage
|
||||||
from cassandra.marshal import int32_unpack
|
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@@ -71,7 +67,6 @@ class GeventConnection(Connection):
|
|||||||
Connection.__init__(self, *args, **kwargs)
|
Connection.__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
self.connected_event = Event()
|
self.connected_event = Event()
|
||||||
self._iobuf = io.BytesIO()
|
|
||||||
self._write_queue = Queue()
|
self._write_queue = Queue()
|
||||||
|
|
||||||
self._callbacks = {}
|
self._callbacks = {}
|
||||||
@@ -154,39 +149,9 @@ class GeventConnection(Connection):
|
|||||||
return # leave the read loop
|
return # leave the read loop
|
||||||
|
|
||||||
if self._iobuf.tell():
|
if self._iobuf.tell():
|
||||||
while True:
|
self.process_io_buffer()
|
||||||
pos = self._iobuf.tell()
|
|
||||||
if pos < 8 or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
|
|
||||||
# we don't have a complete header yet or we
|
|
||||||
# already saw a header, but we don't have a
|
|
||||||
# complete message yet
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# have enough for header, read body len from header
|
|
||||||
self._iobuf.seek(4)
|
|
||||||
body_len = int32_unpack(self._iobuf.read(4))
|
|
||||||
|
|
||||||
# seek to end to get length of current buffer
|
|
||||||
self._iobuf.seek(0, os.SEEK_END)
|
|
||||||
pos = self._iobuf.tell()
|
|
||||||
|
|
||||||
if pos >= body_len + 8:
|
|
||||||
# read message header and body
|
|
||||||
self._iobuf.seek(0)
|
|
||||||
msg = self._iobuf.read(8 + body_len)
|
|
||||||
|
|
||||||
# leave leftover in current buffer
|
|
||||||
leftover = self._iobuf.read()
|
|
||||||
self._iobuf = io.BytesIO()
|
|
||||||
self._iobuf.write(leftover)
|
|
||||||
|
|
||||||
self._total_reqd_bytes = 0
|
|
||||||
self.process_msg(msg, body_len)
|
|
||||||
else:
|
|
||||||
self._total_reqd_bytes = body_len + 8
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
log.debug("connection closed by server")
|
log.debug("Connection %s closed by server", self)
|
||||||
self.close()
|
self.close()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -11,12 +11,9 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from __future__ import absolute_import # to enable import io from stdlib
|
|
||||||
import atexit
|
import atexit
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
@@ -28,7 +25,6 @@ from six.moves import xrange
|
|||||||
from cassandra import OperationTimedOut
|
from cassandra import OperationTimedOut
|
||||||
from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING
|
from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING
|
||||||
from cassandra.protocol import RegisterMessage
|
from cassandra.protocol import RegisterMessage
|
||||||
from cassandra.marshal import int32_unpack
|
|
||||||
try:
|
try:
|
||||||
import cassandra.io.libevwrapper as libev
|
import cassandra.io.libevwrapper as libev
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -257,7 +253,6 @@ class LibevConnection(Connection):
|
|||||||
Connection.__init__(self, *args, **kwargs)
|
Connection.__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
self.connected_event = Event()
|
self.connected_event = Event()
|
||||||
self._iobuf = io.BytesIO()
|
|
||||||
|
|
||||||
self._callbacks = {}
|
self._callbacks = {}
|
||||||
self.deque = deque()
|
self.deque = deque()
|
||||||
@@ -359,37 +354,7 @@ class LibevConnection(Connection):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self._iobuf.tell():
|
if self._iobuf.tell():
|
||||||
while True:
|
self.process_io_buffer()
|
||||||
pos = self._iobuf.tell()
|
|
||||||
if pos < self._full_header_length or (self._total_reqd_bytes > 0 and pos < self._total_reqd_bytes):
|
|
||||||
# we don't have a complete header yet or we
|
|
||||||
# already saw a header, but we don't have a
|
|
||||||
# complete message yet
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# have enough for header, read body len from header
|
|
||||||
self._iobuf.seek(self._header_length)
|
|
||||||
body_len = int32_unpack(self._iobuf.read(4))
|
|
||||||
|
|
||||||
# seek to end to get length of current buffer
|
|
||||||
self._iobuf.seek(0, os.SEEK_END)
|
|
||||||
pos = self._iobuf.tell()
|
|
||||||
|
|
||||||
if pos >= body_len + self._full_header_length:
|
|
||||||
# read message header and body
|
|
||||||
self._iobuf.seek(0)
|
|
||||||
msg = self._iobuf.read(self._full_header_length + body_len)
|
|
||||||
|
|
||||||
# leave leftover in current buffer
|
|
||||||
leftover = self._iobuf.read()
|
|
||||||
self._iobuf = io.BytesIO()
|
|
||||||
self._iobuf.write(leftover)
|
|
||||||
|
|
||||||
self._total_reqd_bytes = 0
|
|
||||||
self.process_msg(msg, body_len)
|
|
||||||
else:
|
|
||||||
self._total_reqd_bytes = body_len + self._full_header_length
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
log.debug("Connection %s closed by server", self)
|
log.debug("Connection %s closed by server", self)
|
||||||
self.close()
|
self.close()
|
||||||
|
|||||||
@@ -21,14 +21,10 @@ from functools import partial
|
|||||||
import logging
|
import logging
|
||||||
import weakref
|
import weakref
|
||||||
import atexit
|
import atexit
|
||||||
import os
|
|
||||||
|
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
from cassandra import OperationTimedOut
|
from cassandra import OperationTimedOut
|
||||||
from cassandra.connection import Connection, ConnectionShutdown
|
from cassandra.connection import Connection, ConnectionShutdown
|
||||||
from cassandra.protocol import RegisterMessage
|
from cassandra.protocol import RegisterMessage
|
||||||
from cassandra.marshal import int32_unpack
|
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@@ -182,7 +178,6 @@ class TwistedConnection(Connection):
|
|||||||
Connection.__init__(self, *args, **kwargs)
|
Connection.__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
self.connected_event = Event()
|
self.connected_event = Event()
|
||||||
self._iobuf = BytesIO()
|
|
||||||
self.is_closed = True
|
self.is_closed = True
|
||||||
self.connector = None
|
self.connector = None
|
||||||
|
|
||||||
@@ -231,38 +226,7 @@ class TwistedConnection(Connection):
|
|||||||
"""
|
"""
|
||||||
Process the incoming data buffer.
|
Process the incoming data buffer.
|
||||||
"""
|
"""
|
||||||
while True:
|
self.process_io_buffer()
|
||||||
pos = self._iobuf.tell()
|
|
||||||
if pos < 8 or (self._total_reqd_bytes > 0 and
|
|
||||||
pos < self._total_reqd_bytes):
|
|
||||||
# we don't have a complete header yet or we
|
|
||||||
# already saw a header, but we don't have a
|
|
||||||
# complete message yet
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
# have enough for header, read body len from header
|
|
||||||
self._iobuf.seek(4)
|
|
||||||
body_len = int32_unpack(self._iobuf.read(4))
|
|
||||||
|
|
||||||
# seek to end to get length of current buffer
|
|
||||||
self._iobuf.seek(0, os.SEEK_END)
|
|
||||||
pos = self._iobuf.tell()
|
|
||||||
|
|
||||||
if pos >= body_len + 8:
|
|
||||||
# read message header and body
|
|
||||||
self._iobuf.seek(0)
|
|
||||||
msg = self._iobuf.read(8 + body_len)
|
|
||||||
|
|
||||||
# leave leftover in current buffer
|
|
||||||
leftover = self._iobuf.read()
|
|
||||||
self._iobuf = BytesIO()
|
|
||||||
self._iobuf.write(leftover)
|
|
||||||
|
|
||||||
self._total_reqd_bytes = 0
|
|
||||||
self.process_msg(msg, body_len)
|
|
||||||
else:
|
|
||||||
self._total_reqd_bytes = body_len + 8
|
|
||||||
return
|
|
||||||
|
|
||||||
def push(self, data):
|
def push(self, data):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user