Make PyevConnection a proper subclass of Connection
This commit is contained in:
@@ -1,7 +1,8 @@
|
|||||||
import errno
|
import errno
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
import logging
|
import logging
|
||||||
from threading import Event
|
from threading import Event, Lock, RLock
|
||||||
|
from Queue import Queue
|
||||||
|
|
||||||
from cassandra import ConsistencyLevel
|
from cassandra import ConsistencyLevel
|
||||||
from cassandra.marshal import int8_unpack
|
from cassandra.marshal import int8_unpack
|
||||||
@@ -87,10 +88,6 @@ class Connection(object):
|
|||||||
is_closed = False
|
is_closed = False
|
||||||
lock = None
|
lock = None
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def factory(cls, *args, **kwargs):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def __init__(self, host='127.0.0.1', port=9042, credentials=None, sockopts=None, compression=True, cql_version=None):
|
def __init__(self, host='127.0.0.1', port=9042, credentials=None, sockopts=None, compression=True, cql_version=None):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
@@ -99,6 +96,13 @@ class Connection(object):
|
|||||||
self.compression = compression
|
self.compression = compression
|
||||||
self.cql_version = cql_version
|
self.cql_version = cql_version
|
||||||
|
|
||||||
|
self._id_queue = Queue(MAX_STREAM_PER_CONNECTION)
|
||||||
|
for i in range(MAX_STREAM_PER_CONNECTION):
|
||||||
|
self._id_queue.put_nowait(i)
|
||||||
|
|
||||||
|
self.lock = RLock()
|
||||||
|
self.id_lock = Lock()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@@ -3,14 +3,14 @@ from functools import partial
|
|||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
from threading import RLock, Event, Lock, Thread
|
from threading import Event, Lock, Thread
|
||||||
import traceback
|
import traceback
|
||||||
from Queue import Queue
|
from Queue import Queue
|
||||||
|
|
||||||
import asyncore
|
import asyncore
|
||||||
|
|
||||||
from cassandra.connection import (Connection, ResponseWaiter, ConnectionException,
|
from cassandra.connection import (Connection, ResponseWaiter, ConnectionException,
|
||||||
ConnectionBusy, MAX_STREAM_PER_CONNECTION, NONBLOCKING)
|
ConnectionBusy, NONBLOCKING)
|
||||||
from cassandra.marshal import int32_unpack
|
from cassandra.marshal import int32_unpack
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@@ -64,20 +64,12 @@ class AsyncoreConnection(Connection, asyncore.dispatcher):
|
|||||||
|
|
||||||
self.connected_event = Event()
|
self.connected_event = Event()
|
||||||
|
|
||||||
self._id_queue = Queue(MAX_STREAM_PER_CONNECTION)
|
|
||||||
for i in range(MAX_STREAM_PER_CONNECTION):
|
|
||||||
self._id_queue.put_nowait(i)
|
|
||||||
|
|
||||||
self._callbacks = {}
|
self._callbacks = {}
|
||||||
self._push_watchers = defaultdict(set)
|
self._push_watchers = defaultdict(set)
|
||||||
self.lock = RLock()
|
|
||||||
self.id_lock = Lock()
|
|
||||||
|
|
||||||
self.deque = deque()
|
self.deque = deque()
|
||||||
|
|
||||||
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
|
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
self.connect((self.host, self.port))
|
self.connect((self.host, self.port))
|
||||||
# self.setblocking(0)
|
|
||||||
|
|
||||||
if self.sockopts:
|
if self.sockopts:
|
||||||
for args in self.sockopts:
|
for args in self.sockopts:
|
||||||
|
@@ -1,66 +1,19 @@
|
|||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
import errno
|
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from threading import RLock, Event, Lock, Thread
|
from threading import Event, Lock, Thread
|
||||||
import traceback
|
import traceback
|
||||||
from Queue import Queue
|
from Queue import Queue
|
||||||
|
|
||||||
import pyev
|
import pyev
|
||||||
|
|
||||||
from cassandra import ConsistencyLevel
|
from cassandra.connection import (Connection, ResponseWaiter, ConnectionException,
|
||||||
from cassandra.marshal import (int8_unpack, int32_unpack)
|
ConnectionBusy, NONBLOCKING)
|
||||||
from cassandra.decoder import (OptionsMessage, ReadyMessage, AuthenticateMessage,
|
from cassandra.marshal import int32_unpack
|
||||||
StartupMessage, ErrorMessage, CredentialsMessage,
|
|
||||||
QueryMessage, ResultMessage, decode_response)
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
locally_supported_compressions = {}
|
|
||||||
try:
|
|
||||||
import snappy
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# work around apparently buggy snappy decompress
|
|
||||||
def decompress(byts):
|
|
||||||
if byts == '\x00':
|
|
||||||
return ''
|
|
||||||
return snappy.decompress(byts)
|
|
||||||
locally_supported_compressions['snappy'] = (snappy.compress, decompress)
|
|
||||||
|
|
||||||
|
|
||||||
MAX_STREAM_PER_CONNECTION = 128
|
|
||||||
|
|
||||||
PROTOCOL_VERSION = 0x01
|
|
||||||
PROTOCOL_VERSION_MASK = 0x7f
|
|
||||||
|
|
||||||
HEADER_DIRECTION_FROM_CLIENT = 0x00
|
|
||||||
HEADER_DIRECTION_TO_CLIENT = 0x80
|
|
||||||
HEADER_DIRECTION_MASK = 0x80
|
|
||||||
|
|
||||||
NONBLOCKING = (errno.EAGAIN, errno.EWOULDBLOCK)
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionException(Exception):
|
|
||||||
|
|
||||||
def __init__(self, message, host=None):
|
|
||||||
Exception.__init__(self, message)
|
|
||||||
self.host = host
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionBusy(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ProgrammingError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ProtocolError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
_loop = pyev.default_loop(pyev.EVBACKEND_SELECT)
|
_loop = pyev.default_loop(pyev.EVBACKEND_SELECT)
|
||||||
|
|
||||||
@@ -117,23 +70,7 @@ def defunct_on_error(f):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class Connection(object):
|
class PyevConnection(Connection):
|
||||||
|
|
||||||
in_buffer_size = 4096
|
|
||||||
out_buffer_size = 4096
|
|
||||||
|
|
||||||
cql_version = None
|
|
||||||
|
|
||||||
keyspace = None
|
|
||||||
compression = True
|
|
||||||
compressor = None
|
|
||||||
decompressor = None
|
|
||||||
|
|
||||||
last_error = None
|
|
||||||
in_flight = 0
|
|
||||||
is_defunct = False
|
|
||||||
is_closed = False
|
|
||||||
lock = None
|
|
||||||
|
|
||||||
_buf = ""
|
_buf = ""
|
||||||
_total_reqd_bytes = 0
|
_total_reqd_bytes = 0
|
||||||
@@ -150,32 +87,21 @@ class Connection(object):
|
|||||||
else:
|
else:
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
def __init__(self, host='127.0.0.1', port=9042, credentials=None, sockopts=None, compression=True, cql_version=None):
|
def __init__(self, *args, **kwargs):
|
||||||
self.host = host
|
Connection.__init__(self, *args, **kwargs)
|
||||||
self.port = port
|
|
||||||
self.credentials = credentials
|
|
||||||
self.compression = compression
|
|
||||||
self.cql_version = cql_version
|
|
||||||
|
|
||||||
self.connected_event = Event()
|
self.connected_event = Event()
|
||||||
|
|
||||||
self._id_queue = Queue(MAX_STREAM_PER_CONNECTION)
|
|
||||||
for i in range(MAX_STREAM_PER_CONNECTION):
|
|
||||||
self._id_queue.put_nowait(i)
|
|
||||||
|
|
||||||
self._callbacks = {}
|
self._callbacks = {}
|
||||||
self._push_watchers = defaultdict(set)
|
self._push_watchers = defaultdict(set)
|
||||||
self.lock = RLock()
|
|
||||||
self.id_lock = Lock()
|
|
||||||
|
|
||||||
self.deque = deque()
|
self.deque = deque()
|
||||||
|
|
||||||
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
self._socket.connect((host, port))
|
self._socket.connect((self.host, self.port))
|
||||||
self._socket.setblocking(0)
|
self._socket.setblocking(0)
|
||||||
|
|
||||||
if sockopts:
|
if self.sockopts:
|
||||||
for args in sockopts:
|
for args in self.sockopts:
|
||||||
self._socket.setsockopt(*args)
|
self._socket.setsockopt(*args)
|
||||||
|
|
||||||
self._read_watcher = pyev.Io(self._socket._sock, pyev.EV_READ, _loop, self.handle_read)
|
self._read_watcher = pyev.Io(self._socket._sock, pyev.EV_READ, _loop, self.handle_read)
|
||||||
@@ -184,8 +110,7 @@ class Connection(object):
|
|||||||
self._read_watcher.start()
|
self._read_watcher.start()
|
||||||
self._write_watcher.start()
|
self._write_watcher.start()
|
||||||
|
|
||||||
log.debug("Sending initial options message for new Connection to %s" % (host,))
|
self._send_options_message()
|
||||||
self.send_msg(OptionsMessage(), self._handle_options_response)
|
|
||||||
|
|
||||||
# start the global event loop if needed
|
# start the global event loop if needed
|
||||||
if not _start_loop():
|
if not _start_loop():
|
||||||
@@ -281,51 +206,6 @@ class Connection(object):
|
|||||||
log.debug("connection closed by server")
|
log.debug("connection closed by server")
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
@defunct_on_error
|
|
||||||
def process_msg(self, msg, body_len):
|
|
||||||
version, flags, stream_id, opcode = map(int8_unpack, msg[:4])
|
|
||||||
if stream_id < 0:
|
|
||||||
callback = None
|
|
||||||
else:
|
|
||||||
callback = self._callbacks.pop(stream_id)
|
|
||||||
self._id_queue.put_nowait(stream_id)
|
|
||||||
|
|
||||||
body = None
|
|
||||||
try:
|
|
||||||
# check that the protocol version is supported
|
|
||||||
given_version = version & PROTOCOL_VERSION_MASK
|
|
||||||
if given_version != PROTOCOL_VERSION:
|
|
||||||
raise ProtocolError("Unsupported CQL protocol version: %d" % given_version)
|
|
||||||
|
|
||||||
# check that the header direction is correct
|
|
||||||
if version & HEADER_DIRECTION_MASK != HEADER_DIRECTION_TO_CLIENT:
|
|
||||||
raise ProtocolError(
|
|
||||||
"Header direction in response is incorrect; opcode %04x, stream id %r"
|
|
||||||
% (opcode, stream_id))
|
|
||||||
|
|
||||||
if body_len > 0:
|
|
||||||
body = msg[8:]
|
|
||||||
elif body_len == 0:
|
|
||||||
body = ""
|
|
||||||
else:
|
|
||||||
raise ProtocolError("Got negative body length: %r" % body_len)
|
|
||||||
|
|
||||||
response = decode_response(stream_id, flags, opcode, body, self.decompressor)
|
|
||||||
except Exception, exc:
|
|
||||||
log.exception("Error decoding response from Cassandra. "
|
|
||||||
"opcode: %04x; message contents: %r" % (opcode, body))
|
|
||||||
callback(exc)
|
|
||||||
self.defunct(exc)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
if stream_id < 0:
|
|
||||||
self.handle_pushed(response)
|
|
||||||
elif callback is not None:
|
|
||||||
callback(response)
|
|
||||||
except:
|
|
||||||
log.exception("Callback handler errored, ignoring:")
|
|
||||||
|
|
||||||
def handle_pushed(self, response):
|
def handle_pushed(self, response):
|
||||||
for cb in self._push_watchers[response.type]:
|
for cb in self._push_watchers[response.type]:
|
||||||
try:
|
try:
|
||||||
@@ -382,114 +262,3 @@ class Connection(object):
|
|||||||
def register_watchers(self, type_callback_dict):
|
def register_watchers(self, type_callback_dict):
|
||||||
for event_type, callback in type_callback_dict.items():
|
for event_type, callback in type_callback_dict.items():
|
||||||
self.register_watcher(event_type, callback)
|
self.register_watcher(event_type, callback)
|
||||||
|
|
||||||
@defunct_on_error
|
|
||||||
def _handle_options_response(self, options_response):
|
|
||||||
if self.is_defunct:
|
|
||||||
return
|
|
||||||
log.debug("Received options response on new Connection from %s" % self.host)
|
|
||||||
self.supported_cql_versions = options_response.cql_versions
|
|
||||||
self.remote_supported_compressions = options_response.options['COMPRESSION']
|
|
||||||
|
|
||||||
if self.cql_version:
|
|
||||||
if self.cql_version not in self.supported_cql_versions:
|
|
||||||
raise ProtocolError(
|
|
||||||
"cql_version %r is not supported by remote (w/ native "
|
|
||||||
"protocol). Supported versions: %r"
|
|
||||||
% (self.cql_version, self.supported_cql_versions))
|
|
||||||
else:
|
|
||||||
self.cql_version = self.supported_cql_versions[0]
|
|
||||||
|
|
||||||
opts = {}
|
|
||||||
self._compressor = None
|
|
||||||
if self.compression:
|
|
||||||
overlap = (set(locally_supported_compressions.keys()) &
|
|
||||||
set(self.remote_supported_compressions))
|
|
||||||
if len(overlap) == 0:
|
|
||||||
log.debug("No available compression types supported on both ends."
|
|
||||||
" locally supported: %r. remotely supported: %r"
|
|
||||||
% (locally_supported_compressions.keys(),
|
|
||||||
self.remote_supported_compressions))
|
|
||||||
else:
|
|
||||||
compression_type = iter(overlap).next() # choose any
|
|
||||||
opts['COMPRESSION'] = compression_type
|
|
||||||
# set the decompressor here, but set the compressor only after
|
|
||||||
# a successful Ready message
|
|
||||||
self._compressor, self.decompressor = \
|
|
||||||
locally_supported_compressions[compression_type]
|
|
||||||
|
|
||||||
sm = StartupMessage(cqlversion=self.cql_version, options=opts)
|
|
||||||
self.send_msg(sm, cb=self._handle_startup_response)
|
|
||||||
|
|
||||||
@defunct_on_error
|
|
||||||
def _handle_startup_response(self, startup_response):
|
|
||||||
if self.is_defunct:
|
|
||||||
return
|
|
||||||
if isinstance(startup_response, ReadyMessage):
|
|
||||||
log.debug("Got ReadyMessage on new Connection from %s" % self.host)
|
|
||||||
if self._compressor:
|
|
||||||
self.compressor = self._compressor
|
|
||||||
self.connected_event.set()
|
|
||||||
elif isinstance(startup_response, AuthenticateMessage):
|
|
||||||
log.debug("Got AuthenticateMessage on new Connection from %s" % self.host)
|
|
||||||
|
|
||||||
if self.credentials is None:
|
|
||||||
raise ProgrammingError('Remote end requires authentication.')
|
|
||||||
|
|
||||||
self.authenticator = startup_response.authenticator
|
|
||||||
cm = CredentialsMessage(creds=self.credentials)
|
|
||||||
self.send_msg(cm, cb=self._handle_startup_response)
|
|
||||||
elif isinstance(startup_response, ErrorMessage):
|
|
||||||
log.debug("Received ErrorMessage on new Connection from %s: %s"
|
|
||||||
% (self.host, startup_response.summary_msg()))
|
|
||||||
raise ConnectionException(
|
|
||||||
"Failed to initialize new connection to %s: %s"
|
|
||||||
% (self.host, startup_response.summary_msg()))
|
|
||||||
else:
|
|
||||||
msg = "Unexpected response during Connection setup: %r" % (startup_response,)
|
|
||||||
log.error(msg)
|
|
||||||
raise ProtocolError(msg)
|
|
||||||
|
|
||||||
def set_keyspace(self, keyspace):
|
|
||||||
if not keyspace or keyspace == self.keyspace:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self.lock:
|
|
||||||
query = 'USE "%s"' % (keyspace,)
|
|
||||||
try:
|
|
||||||
result = self.wait_for_response(
|
|
||||||
QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE))
|
|
||||||
if isinstance(result, ResultMessage):
|
|
||||||
self.keyspace = keyspace
|
|
||||||
else:
|
|
||||||
raise self.defunct(ConnectionException(
|
|
||||||
"Problem while setting keyspace: %r" % (result,), self.host))
|
|
||||||
except Exception, exc:
|
|
||||||
raise self.defunct(ConnectionException(
|
|
||||||
"Problem while setting keyspace: %r" % (exc,), self.host))
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseWaiter(object):
|
|
||||||
|
|
||||||
def __init__(self, num_responses):
|
|
||||||
self.pending = num_responses
|
|
||||||
self.error = None
|
|
||||||
self.responses = [None] * num_responses
|
|
||||||
self.event = Event()
|
|
||||||
|
|
||||||
def got_response(self, response, index):
|
|
||||||
if isinstance(response, Exception):
|
|
||||||
self.error = response
|
|
||||||
self.event.set()
|
|
||||||
else:
|
|
||||||
self.responses[index] = response
|
|
||||||
self.pending -= 1
|
|
||||||
if not self.pending:
|
|
||||||
self.event.set()
|
|
||||||
|
|
||||||
def deliver(self):
|
|
||||||
self.event.wait()
|
|
||||||
if self.error:
|
|
||||||
raise self.error
|
|
||||||
else:
|
|
||||||
return self.responses
|
|
||||||
|
Reference in New Issue
Block a user