diff --git a/cassandra/connection.py b/cassandra/connection.py index ee783184..c9f34b95 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1,7 +1,8 @@ import errno from functools import wraps import logging -from threading import Event +from threading import Event, Lock, RLock +from Queue import Queue from cassandra import ConsistencyLevel from cassandra.marshal import int8_unpack @@ -87,10 +88,6 @@ class Connection(object): is_closed = False lock = None - @classmethod - def factory(cls, *args, **kwargs): - raise NotImplementedError() - def __init__(self, host='127.0.0.1', port=9042, credentials=None, sockopts=None, compression=True, cql_version=None): self.host = host self.port = port @@ -99,6 +96,13 @@ class Connection(object): self.compression = compression self.cql_version = cql_version + self._id_queue = Queue(MAX_STREAM_PER_CONNECTION) + for i in range(MAX_STREAM_PER_CONNECTION): + self._id_queue.put_nowait(i) + + self.lock = RLock() + self.id_lock = Lock() + def close(self): raise NotImplementedError() diff --git a/cassandra/io/asyncorereactor.py b/cassandra/io/asyncorereactor.py index 3293b072..b8c3b6cf 100644 --- a/cassandra/io/asyncorereactor.py +++ b/cassandra/io/asyncorereactor.py @@ -3,14 +3,14 @@ from functools import partial import logging import socket import sys -from threading import RLock, Event, Lock, Thread +from threading import Event, Lock, Thread import traceback from Queue import Queue import asyncore from cassandra.connection import (Connection, ResponseWaiter, ConnectionException, - ConnectionBusy, MAX_STREAM_PER_CONNECTION, NONBLOCKING) + ConnectionBusy, NONBLOCKING) from cassandra.marshal import int32_unpack log = logging.getLogger(__name__) @@ -64,20 +64,12 @@ class AsyncoreConnection(Connection, asyncore.dispatcher): self.connected_event = Event() - self._id_queue = Queue(MAX_STREAM_PER_CONNECTION) - for i in range(MAX_STREAM_PER_CONNECTION): - self._id_queue.put_nowait(i) - self._callbacks = {} self._push_watchers = defaultdict(set) - self.lock = RLock() - self.id_lock = Lock() - self.deque = deque() self.create_socket(socket.AF_INET, socket.SOCK_STREAM) self.connect((self.host, self.port)) - # self.setblocking(0) if self.sockopts: for args in self.sockopts: diff --git a/cassandra/io/pyevreactor.py b/cassandra/io/pyevreactor.py index a5faba06..0cabbcb3 100644 --- a/cassandra/io/pyevreactor.py +++ b/cassandra/io/pyevreactor.py @@ -1,66 +1,19 @@ from collections import defaultdict, deque -import errno from functools import partial, wraps import logging import socket -from threading import RLock, Event, Lock, Thread +from threading import Event, Lock, Thread import traceback from Queue import Queue import pyev -from cassandra import ConsistencyLevel -from cassandra.marshal import (int8_unpack, int32_unpack) -from cassandra.decoder import (OptionsMessage, ReadyMessage, AuthenticateMessage, - StartupMessage, ErrorMessage, CredentialsMessage, - QueryMessage, ResultMessage, decode_response) +from cassandra.connection import (Connection, ResponseWaiter, ConnectionException, + ConnectionBusy, NONBLOCKING) +from cassandra.marshal import int32_unpack log = logging.getLogger(__name__) -locally_supported_compressions = {} -try: - import snappy -except ImportError: - pass -else: - # work around apparently buggy snappy decompress - def decompress(byts): - if byts == '\x00': - return '' - return snappy.decompress(byts) - locally_supported_compressions['snappy'] = (snappy.compress, decompress) - - -MAX_STREAM_PER_CONNECTION = 128 - -PROTOCOL_VERSION = 0x01 -PROTOCOL_VERSION_MASK = 0x7f - -HEADER_DIRECTION_FROM_CLIENT = 0x00 -HEADER_DIRECTION_TO_CLIENT = 0x80 -HEADER_DIRECTION_MASK = 0x80 - -NONBLOCKING = (errno.EAGAIN, errno.EWOULDBLOCK) - - -class ConnectionException(Exception): - - def __init__(self, message, host=None): - Exception.__init__(self, message) - self.host = host - - -class ConnectionBusy(Exception): - pass - - -class ProgrammingError(Exception): - pass - - -class ProtocolError(Exception): - pass - _loop = pyev.default_loop(pyev.EVBACKEND_SELECT) @@ -117,23 +70,7 @@ def defunct_on_error(f): return wrapper -class Connection(object): - - in_buffer_size = 4096 - out_buffer_size = 4096 - - cql_version = None - - keyspace = None - compression = True - compressor = None - decompressor = None - - last_error = None - in_flight = 0 - is_defunct = False - is_closed = False - lock = None +class PyevConnection(Connection): _buf = "" _total_reqd_bytes = 0 @@ -150,32 +87,21 @@ class Connection(object): else: return conn - def __init__(self, host='127.0.0.1', port=9042, credentials=None, sockopts=None, compression=True, cql_version=None): - self.host = host - self.port = port - self.credentials = credentials - self.compression = compression - self.cql_version = cql_version + def __init__(self, *args, **kwargs): + Connection.__init__(self, *args, **kwargs) self.connected_event = Event() - self._id_queue = Queue(MAX_STREAM_PER_CONNECTION) - for i in range(MAX_STREAM_PER_CONNECTION): - self._id_queue.put_nowait(i) - self._callbacks = {} self._push_watchers = defaultdict(set) - self.lock = RLock() - self.id_lock = Lock() - self.deque = deque() self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._socket.connect((host, port)) + self._socket.connect((self.host, self.port)) self._socket.setblocking(0) - if sockopts: - for args in sockopts: + if self.sockopts: + for args in self.sockopts: self._socket.setsockopt(*args) self._read_watcher = pyev.Io(self._socket._sock, pyev.EV_READ, _loop, self.handle_read) @@ -184,8 +110,7 @@ class Connection(object): self._read_watcher.start() self._write_watcher.start() - log.debug("Sending initial options message for new Connection to %s" % (host,)) - self.send_msg(OptionsMessage(), self._handle_options_response) + self._send_options_message() # start the global event loop if needed if not _start_loop(): @@ -281,51 +206,6 @@ class Connection(object): log.debug("connection closed by server") self.close() - @defunct_on_error - def process_msg(self, msg, body_len): - version, flags, stream_id, opcode = map(int8_unpack, msg[:4]) - if stream_id < 0: - callback = None - else: - callback = self._callbacks.pop(stream_id) - self._id_queue.put_nowait(stream_id) - - body = None - try: - # check that the protocol version is supported - given_version = version & PROTOCOL_VERSION_MASK - if given_version != PROTOCOL_VERSION: - raise ProtocolError("Unsupported CQL protocol version: %d" % given_version) - - # check that the header direction is correct - if version & HEADER_DIRECTION_MASK != HEADER_DIRECTION_TO_CLIENT: - raise ProtocolError( - "Header direction in response is incorrect; opcode %04x, stream id %r" - % (opcode, stream_id)) - - if body_len > 0: - body = msg[8:] - elif body_len == 0: - body = "" - else: - raise ProtocolError("Got negative body length: %r" % body_len) - - response = decode_response(stream_id, flags, opcode, body, self.decompressor) - except Exception, exc: - log.exception("Error decoding response from Cassandra. " - "opcode: %04x; message contents: %r" % (opcode, body)) - callback(exc) - self.defunct(exc) - return - - try: - if stream_id < 0: - self.handle_pushed(response) - elif callback is not None: - callback(response) - except: - log.exception("Callback handler errored, ignoring:") - def handle_pushed(self, response): for cb in self._push_watchers[response.type]: try: @@ -382,114 +262,3 @@ class Connection(object): def register_watchers(self, type_callback_dict): for event_type, callback in type_callback_dict.items(): self.register_watcher(event_type, callback) - - @defunct_on_error - def _handle_options_response(self, options_response): - if self.is_defunct: - return - log.debug("Received options response on new Connection from %s" % self.host) - self.supported_cql_versions = options_response.cql_versions - self.remote_supported_compressions = options_response.options['COMPRESSION'] - - if self.cql_version: - if self.cql_version not in self.supported_cql_versions: - raise ProtocolError( - "cql_version %r is not supported by remote (w/ native " - "protocol). Supported versions: %r" - % (self.cql_version, self.supported_cql_versions)) - else: - self.cql_version = self.supported_cql_versions[0] - - opts = {} - self._compressor = None - if self.compression: - overlap = (set(locally_supported_compressions.keys()) & - set(self.remote_supported_compressions)) - if len(overlap) == 0: - log.debug("No available compression types supported on both ends." - " locally supported: %r. remotely supported: %r" - % (locally_supported_compressions.keys(), - self.remote_supported_compressions)) - else: - compression_type = iter(overlap).next() # choose any - opts['COMPRESSION'] = compression_type - # set the decompressor here, but set the compressor only after - # a successful Ready message - self._compressor, self.decompressor = \ - locally_supported_compressions[compression_type] - - sm = StartupMessage(cqlversion=self.cql_version, options=opts) - self.send_msg(sm, cb=self._handle_startup_response) - - @defunct_on_error - def _handle_startup_response(self, startup_response): - if self.is_defunct: - return - if isinstance(startup_response, ReadyMessage): - log.debug("Got ReadyMessage on new Connection from %s" % self.host) - if self._compressor: - self.compressor = self._compressor - self.connected_event.set() - elif isinstance(startup_response, AuthenticateMessage): - log.debug("Got AuthenticateMessage on new Connection from %s" % self.host) - - if self.credentials is None: - raise ProgrammingError('Remote end requires authentication.') - - self.authenticator = startup_response.authenticator - cm = CredentialsMessage(creds=self.credentials) - self.send_msg(cm, cb=self._handle_startup_response) - elif isinstance(startup_response, ErrorMessage): - log.debug("Received ErrorMessage on new Connection from %s: %s" - % (self.host, startup_response.summary_msg())) - raise ConnectionException( - "Failed to initialize new connection to %s: %s" - % (self.host, startup_response.summary_msg())) - else: - msg = "Unexpected response during Connection setup: %r" % (startup_response,) - log.error(msg) - raise ProtocolError(msg) - - def set_keyspace(self, keyspace): - if not keyspace or keyspace == self.keyspace: - return - - with self.lock: - query = 'USE "%s"' % (keyspace,) - try: - result = self.wait_for_response( - QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)) - if isinstance(result, ResultMessage): - self.keyspace = keyspace - else: - raise self.defunct(ConnectionException( - "Problem while setting keyspace: %r" % (result,), self.host)) - except Exception, exc: - raise self.defunct(ConnectionException( - "Problem while setting keyspace: %r" % (exc,), self.host)) - - -class ResponseWaiter(object): - - def __init__(self, num_responses): - self.pending = num_responses - self.error = None - self.responses = [None] * num_responses - self.event = Event() - - def got_response(self, response, index): - if isinstance(response, Exception): - self.error = response - self.event.set() - else: - self.responses[index] = response - self.pending -= 1 - if not self.pending: - self.event.set() - - def deliver(self): - self.event.wait() - if self.error: - raise self.error - else: - return self.responses