Basic working implementation of asyncore reactor
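The diff below moves the existing pyev-based reactor out of cassandra/connection.py into a new cassandra/io/pyevreactor.py, adds an asyncore-based reactor in cassandra/io/asyncorereactor.py, and has Cluster select its transport through a new connection_class attribute, which now defaults to AsyncoreConnection. A minimal usage sketch of the pluggable reactor, assuming a Cassandra node reachable on the default 127.0.0.1:9042 (the contact point and defaults are illustrative, not part of this commit):

    from cassandra.cluster import Cluster
    from cassandra.io.asyncorereactor import AsyncoreConnection

    # connection_class is a plain class attribute on Cluster, so another
    # reactor implementation could be swapped in before connecting;
    # AsyncoreConnection is already the default after this commit.
    Cluster.connection_class = AsyncoreConnection

    cluster = Cluster()          # hypothetical: assumes default contact points
    session = cluster.connect()  # starts the (non-daemon) asyncore loop thread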
cassandra/cluster.py

@@ -9,7 +9,7 @@ from functools import partial
from itertools import groupby

from cassandra import ConsistencyLevel
from cassandra.connection import Connection, ConnectionException
from cassandra.connection import ConnectionException
from cassandra.decoder import (QueryMessage, ResultMessage,
                               ErrorMessage, ReadTimeoutErrorMessage,
                               WriteTimeoutErrorMessage,
@@ -19,6 +19,7 @@ from cassandra.decoder import (QueryMessage, ResultMessage,
                               PreparedQueryNotFound,
                               IsBootstrappingErrorMessage, named_tuple_factory,
                               dict_factory)
from cassandra.io.asyncorereactor import AsyncoreConnection
from cassandra.metadata import Metadata
from cassandra.policies import (RoundRobinPolicy, SimpleConvictionPolicy,
                                ExponentialReconnectionPolicy, HostDistance,
@@ -141,6 +142,8 @@ class Cluster(object):
    An instance of :cls:`cassandra.metadata.Metadata`.
    """

    connection_class = AsyncoreConnection

    sessions = None
    control_connection = None
    scheduler = None
@@ -264,7 +267,7 @@ class Cluster(object):
        kwargs['compression'] = self.compression
        kwargs['sockopts'] = self.sockopts

        return Connection.factory(address, *args, **kwargs)
        return self.connection_class.factory(address, *args, **kwargs)

    def _make_connection_factory(self, host, *args, **kwargs):
        if self.auth_provider:
@@ -273,7 +276,7 @@ class Cluster(object):
        kwargs['compression'] = self.compression
        kwargs['sockopts'] = self.sockopts

        return partial(Connection.factory, host.address, *args, **kwargs)
        return partial(self.connection_class.factory, host.address, *args, **kwargs)

    def connect(self, keyspace=None):
        """
@@ -288,6 +291,7 @@ class Cluster(object):
        if self.control_connection:
            try:
                self.control_connection.connect()
                log.debug("Control connection created")
            except:
                log.exception("Control connection failed to connect, "
                              "shutting down Cluster:")
@@ -1036,6 +1040,7 @@ class ControlConnection(object):
        self._cluster.executor.submit(self.refresh_schema, keyspace, table)

    def wait_for_schema_agreement(self, connection=None):
        log.debug("[control connection] Waiting for schema agreement")
        if not connection:
            connection = self._connection

@@ -1069,6 +1074,7 @@ class ControlConnection(object):
            if len(versions) == 1:
                return True

            log.debug("[control connection] Schemas mismatched, trying again")
            self._time.sleep(0.2)
            elapsed = self._time.time() - start

cassandra/connection.py

@@ -1,20 +1,15 @@
from collections import defaultdict, deque
import errno
from functools import partial, wraps
from functools import wraps
import logging
import socket
from threading import RLock, Event, Lock, Thread
import traceback
from Queue import Queue

import pyev
from threading import Event

from cassandra import ConsistencyLevel
from cassandra.marshal import (int8_unpack, int32_unpack)
from cassandra.decoder import (OptionsMessage, ReadyMessage, AuthenticateMessage,
from cassandra.marshal import int8_unpack
from cassandra.decoder import (ReadyMessage, AuthenticateMessage, OptionsMessage,
                               StartupMessage, ErrorMessage, CredentialsMessage,
                               QueryMessage, ResultMessage, decode_response)


log = logging.getLogger(__name__)

locally_supported_compressions = {}
@@ -62,49 +57,6 @@ class ProtocolError(Exception):
    pass


_loop = pyev.default_loop(pyev.EVBACKEND_SELECT)

_loop_notifier = _loop.async(lambda *a, **kw: None)
_loop_notifier.start()

# prevent _loop_notifier from keeping the loop from returning
_loop.unref()

_loop_started = None
_loop_lock = Lock()

def _run_loop():
    while True:
        end_condition = _loop.start()
        # there are still active watchers, no deadlock
        with _loop_lock:
            if end_condition:
                log.debug("Restarting event loop")
                continue
            else:
                # all Connections have been closed, no active watchers
                log.debug("All Connections currently closed, event loop ended")
                global _loop_started
                _loop_started = False
                break

def _start_loop():
    global _loop_started
    should_start = False
    with _loop_lock:
        if not _loop_started:
            log.debug("Starting pyev event loop")
            _loop_started = True
            should_start = True

    if should_start:
        t = Thread(target=_run_loop, name="async_event_loop")
        t.daemon = True
        t.start()

    return should_start


def defunct_on_error(f):

    @wraps(f)
@@ -135,151 +87,38 @@ class Connection(object):
    is_closed = False
    lock = None

    _buf = ""
    _total_reqd_bytes = 0
    _read_watcher = None
    _write_watcher = None
    _socket = None

    @classmethod
    def factory(cls, *args, **kwargs):
        conn = cls(*args, **kwargs)
        conn.connected_event.wait()
        if conn.last_error:
            raise conn.last_error
        else:
            return conn
        raise NotImplementedError()

    def __init__(self, host='127.0.0.1', port=9042, credentials=None, sockopts=None, compression=True, cql_version=None):
        self.host = host
        self.port = port
        self.credentials = credentials
        self.sockopts = sockopts
        self.compression = compression
        self.cql_version = cql_version

        self.connected_event = Event()

        self._id_queue = Queue(MAX_STREAM_PER_CONNECTION)
        for i in range(MAX_STREAM_PER_CONNECTION):
            self._id_queue.put_nowait(i)

        self._callbacks = {}
        self._push_watchers = defaultdict(set)
        self.lock = RLock()
        self.id_lock = Lock()

        self.deque = deque()

        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.connect((host, port))
        self._socket.setblocking(0)

        if sockopts:
            for args in sockopts:
                self._socket.setsockopt(*args)

        self._read_watcher = pyev.Io(self._socket._sock, pyev.EV_READ, _loop, self.handle_read)
        self._write_watcher = pyev.Io(self._socket._sock, pyev.EV_WRITE, _loop, self.handle_write)
        with _loop_lock:
            self._read_watcher.start()
            self._write_watcher.start()

        log.debug("Sending initial options message for new Connection to %s" % (host,))
        self.send_msg(OptionsMessage(), self._handle_options_response)

        # start the global event loop if needed
        if not _start_loop():
            # if the loop was already started, notify it
            with _loop_lock:
                _loop_notifier.send()

    def close(self):
        with self.lock:
            if self.is_closed:
                return
            self.is_closed = True
        raise NotImplementedError()

        log.debug("Closing connection to %s" % (self.host,))
        if self._read_watcher:
            self._read_watcher.stop()
        if self._write_watcher:
            self._write_watcher.stop()
        self._socket.close()
        with _loop_lock:
            _loop_notifier.send()
    def defunct(self):
        raise NotImplementedError()

        # don't leave in-progress operations hanging
        if not self.is_defunct:
            self._error_all_callbacks(
                ConnectionException("Connection to %s was closed" % self.host))
    def send_msg(self, msg, cb):
        raise NotImplementedError()

    def __del__(self):
        self.close()
    def wait_for_response(self, msg):
        raise NotImplementedError()

    def defunct(self, exc):
        log.debug("Defuncting connection to %s: %s\n%s" %
                  (self.host, exc, traceback.format_exc(exc)))
        self.last_error = exc
        self.is_defunct = True
        self._error_all_callbacks(exc)
        self.connected_event.set()
        return exc
    def wait_for_responses(self, *msgs):
        raise NotImplementedError()

    def _error_all_callbacks(self, exc):
        for cb in self._callbacks.values():
            cb(exc)
    def register_watcher(self, event_type, callback):
        raise NotImplementedError()

    def handle_write(self, watcher, revents):
        try:
            next_msg = self.deque.popleft()
        except IndexError:
            self._write_watcher.stop()
            return

        try:
            sent = self._socket.send(next_msg)
        except socket.error, err:
            if (err.args[0] in NONBLOCKING):
                self.deque.appendleft(next_msg)
            else:
                self.defunct(err)
            return
        else:
            if sent < len(next_msg):
                self.deque.appendleft(next_msg[sent:])

        if not self.deque:
            self._write_watcher.stop()

    def handle_read(self, watcher, revents):
        try:
            buf = self._socket.recv(self.in_buffer_size)
        except socket.error, err:
            if err.args[0] not in NONBLOCKING:
                self.defunct(err)
            return

        if buf:
            self._buf += buf
            while True:
                if len(self._buf) < 8:
                    # we don't have a complete header yet
                    break
                elif self._total_reqd_bytes and len(self._buf) < self._total_reqd_bytes:
                    # we already saw a header, but we don't have a complete message yet
                    break
                else:
                    body_len = int32_unpack(self._buf[4:8])
                    if len(self._buf) - 8 >= body_len:
                        msg = self._buf[:8 + body_len]
                        self._buf = self._buf[8 + body_len:]
                        self._total_reqd_bytes = 0
                        self.process_msg(msg, body_len)
                    else:
                        self._total_reqd_bytes = body_len + 8
        else:
            log.debug("connection closed by server")
            self.close()
    def register_watchers(self, type_callback_dict):
        raise NotImplementedError()

    @defunct_on_error
    def process_msg(self, msg, body_len):
@@ -326,62 +165,10 @@ class Connection(object):
        except:
            log.exception("Callback handler errored, ignoring:")

    def handle_pushed(self, response):
        for cb in self._push_watchers[response.type]:
            try:
                cb(response)
            except:
                log.exception("Pushed event handler errored, ignoring:")

    def push(self, data):
        sabs = self.out_buffer_size
        if len(data) > sabs:
            chunks = []
            for i in xrange(0, len(data), sabs):
                chunks.append(data[i:i + sabs])
        else:
            chunks = [data]

        with self.lock:
            self.deque.extend(chunks)

        if not self._write_watcher.active:
            with _loop_lock:
                self._write_watcher.start()
                _loop_notifier.send()

    def send_msg(self, msg, cb):
        if self.is_defunct:
            raise ConnectionException("Connection to %s is defunct" % self.host)
        elif self.is_closed:
            raise ConnectionException("Connection to %s is closed" % self.host)

        try:
            request_id = self._id_queue.get_nowait()
        except Queue.EMPTY:
            raise ConnectionBusy(
                "Connection to %s is at the max number of requests" % self.host)

        self._callbacks[request_id] = cb
        self.push(msg.to_string(request_id, compression=self.compressor))
        return request_id

    def wait_for_response(self, msg):
        return self.wait_for_responses(msg)[0]

    def wait_for_responses(self, *msgs):
        waiter = ResponseWaiter(len(msgs))
        for i, msg in enumerate(msgs):
            self.send_msg(msg, partial(waiter.got_response, index=i))

        return waiter.deliver()

    def register_watcher(self, event_type, callback):
        self._push_watchers[event_type].add(callback)

    def register_watchers(self, type_callback_dict):
        for event_type, callback in type_callback_dict.items():
            self.register_watcher(event_type, callback)
    @defunct_on_error
    def _send_options_message(self):
        log.debug("Sending initial options message for new Connection to %s" % (self.host,))
        self.send_msg(OptionsMessage(), self._handle_options_response)

    @defunct_on_error
    def _handle_options_response(self, options_response):

cassandra/io/__init__.py (new file, 0 lines)

cassandra/io/asyncorereactor.py (new file, 254 lines)
@@ -0,0 +1,254 @@
from collections import defaultdict, deque
from functools import partial
import logging
import socket
import sys
from threading import RLock, Event, Lock, Thread
import traceback
from Queue import Queue

import asyncore

from cassandra.connection import (Connection, ResponseWaiter, ConnectionException,
                                  ConnectionBusy, MAX_STREAM_PER_CONNECTION, NONBLOCKING)
from cassandra.marshal import int32_unpack

log = logging.getLogger(__name__)

_loop_started = None
_loop_lock = Lock()

def _run_loop():
    log.debug("Starting asyncore event loop")
    asyncore.loop(timeout=0.001, use_poll=True, count=None)
    log.debug("Asyncore event loop ended")
    with _loop_lock:
        global _loop_started
        _loop_started = False

def _start_loop():
    global _loop_started
    should_start = False
    with _loop_lock:
        if not _loop_started:
            _loop_started = True
            should_start = True

    if should_start:
        t = Thread(target=_run_loop, name="async_event_loop")
        t.daemon = False
        t.start()

    return should_start


class AsyncoreConnection(Connection, asyncore.dispatcher):

    _buf = ""
    _total_reqd_bytes = 0
    _writable = False
    _readable = False

    @classmethod
    def factory(cls, *args, **kwargs):
        conn = cls(*args, **kwargs)
        conn.connected_event.wait()
        if conn.last_error:
            raise conn.last_error
        else:
            return conn

    def __init__(self, *args, **kwargs):
        Connection.__init__(self, *args, **kwargs)
        asyncore.dispatcher.__init__(self)

        self.connected_event = Event()

        self._id_queue = Queue(MAX_STREAM_PER_CONNECTION)
        for i in range(MAX_STREAM_PER_CONNECTION):
            self._id_queue.put_nowait(i)

        self._callbacks = {}
        self._push_watchers = defaultdict(set)
        self.lock = RLock()
        self.id_lock = Lock()

        self.deque = deque()

        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.connect((self.host, self.port))
        # self.setblocking(0)

        if self.sockopts:
            for args in self.sockopts:
                self.socket.setsockopt(*args)

        self._writable = True

        # start the global event loop if needed
        _start_loop()

    def close(self):
        with self.lock:
            if self.is_closed:
                return
            self.is_closed = True

        log.debug("Closing connection to %s" % (self.host,))
        self._writable = False
        self._readable = False
        asyncore.dispatcher.close(self)

        # don't leave in-progress operations hanging
        self.connected_event.set()
        if not self.is_defunct:
            log.debug("Erroring all callbacks")
            self._error_all_callbacks(
                ConnectionException("Connection to %s was closed" % self.host))

    def __del__(self):
        self.close()

    def defunct(self, exc):
        if self.is_defunct:
            return

        log.debug("Defuncting connection to %s: %s\n%s" %
                  (self.host, exc, traceback.format_exc(exc)))
        self.last_error = exc
        self.is_defunct = True
        self._error_all_callbacks(exc)
        self.connected_event.set()
        return exc

    def _error_all_callbacks(self, exc):
        for cb in self._callbacks.values():
            cb(exc)

    def handle_connect(self):
        self._send_options_message()

    def handle_error(self):
        t, v, tbinfo = sys.exc_info()
        self.defunct(v)

    def handle_close(self):
        log.debug("connection closed by server")
        self.close()

    def handle_write(self):
        try:
            next_msg = self.deque.popleft()
        except IndexError:
            self._writable = False
            return

        try:
            sent = self.send(next_msg)
        except socket.error, err:
            if (err.args[0] in NONBLOCKING):
                self.deque.appendleft(next_msg)
            else:
                self.defunct(err)
            return
        else:
            if sent < len(next_msg):
                self.deque.appendleft(next_msg[sent:])

        if not self.deque:
            self._writable = False

        self._readable = True

    def handle_read(self):
        try:
            buf = self.recv(self.in_buffer_size)
        except socket.error, err:
            if err.args[0] not in NONBLOCKING:
                self.defunct(err)
            return

        if buf:
            self._buf += buf
            while True:
                if len(self._buf) < 8:
                    # we don't have a complete header yet
                    break
                elif self._total_reqd_bytes and len(self._buf) < self._total_reqd_bytes:
                    # we already saw a header, but we don't have a complete message yet
                    break
                else:
                    body_len = int32_unpack(self._buf[4:8])
                    if len(self._buf) - 8 >= body_len:
                        msg = self._buf[:8 + body_len]
                        self._buf = self._buf[8 + body_len:]
                        self._total_reqd_bytes = 0
                        self.process_msg(msg, body_len)
                    else:
                        self._total_reqd_bytes = body_len + 8

            if not self._callbacks:
                self._readable = False
        else:
            log.debug("connection closed by server")
            self.close()

    def handle_pushed(self, response):
        for cb in self._push_watchers[response.type]:
            try:
                cb(response)
            except:
                log.exception("Pushed event handler errored, ignoring:")

    def push(self, data):
        sabs = self.out_buffer_size
        if len(data) > sabs:
            chunks = []
            for i in xrange(0, len(data), sabs):
                chunks.append(data[i:i + sabs])
        else:
            chunks = [data]

        with self.lock:
            self.deque.extend(chunks)

        self._writable = True

    def writable(self):
        return self._writable

    def readable(self):
        return self._readable

    def send_msg(self, msg, cb):
        if self.is_defunct:
            raise ConnectionException("Connection to %s is defunct" % self.host)
        elif self.is_closed:
            raise ConnectionException("Connection to %s is closed" % self.host)

        try:
            request_id = self._id_queue.get_nowait()
        except Queue.EMPTY:
            raise ConnectionBusy(
                "Connection to %s is at the max number of requests" % self.host)

        self._callbacks[request_id] = cb
        self.push(msg.to_string(request_id, compression=self.compressor))
        return request_id

    def wait_for_response(self, msg):
        return self.wait_for_responses(msg)[0]

    def wait_for_responses(self, *msgs):
        waiter = ResponseWaiter(len(msgs))
        for i, msg in enumerate(msgs):
            self.send_msg(msg, partial(waiter.got_response, index=i))

        return waiter.deliver()

    def register_watcher(self, event_type, callback):
        self._push_watchers[event_type].add(callback)

    def register_watchers(self, type_callback_dict):
        for event_type, callback in type_callback_dict.items():
            self.register_watcher(event_type, callback)

cassandra/io/pyevreactor.py (new file, 495 lines)
@@ -0,0 +1,495 @@
from collections import defaultdict, deque
import errno
from functools import partial, wraps
import logging
import socket
from threading import RLock, Event, Lock, Thread
import traceback
from Queue import Queue

import pyev

from cassandra import ConsistencyLevel
from cassandra.marshal import (int8_unpack, int32_unpack)
from cassandra.decoder import (OptionsMessage, ReadyMessage, AuthenticateMessage,
                               StartupMessage, ErrorMessage, CredentialsMessage,
                               QueryMessage, ResultMessage, decode_response)

log = logging.getLogger(__name__)

locally_supported_compressions = {}
try:
    import snappy
except ImportError:
    pass
else:
    # work around apparently buggy snappy decompress
    def decompress(byts):
        if byts == '\x00':
            return ''
        return snappy.decompress(byts)
    locally_supported_compressions['snappy'] = (snappy.compress, decompress)


MAX_STREAM_PER_CONNECTION = 128

PROTOCOL_VERSION = 0x01
PROTOCOL_VERSION_MASK = 0x7f

HEADER_DIRECTION_FROM_CLIENT = 0x00
HEADER_DIRECTION_TO_CLIENT = 0x80
HEADER_DIRECTION_MASK = 0x80

NONBLOCKING = (errno.EAGAIN, errno.EWOULDBLOCK)


class ConnectionException(Exception):

    def __init__(self, message, host=None):
        Exception.__init__(self, message)
        self.host = host


class ConnectionBusy(Exception):
    pass


class ProgrammingError(Exception):
    pass


class ProtocolError(Exception):
    pass


_loop = pyev.default_loop(pyev.EVBACKEND_SELECT)

_loop_notifier = _loop.async(lambda *a, **kw: None)
_loop_notifier.start()

# prevent _loop_notifier from keeping the loop from returning
_loop.unref()

_loop_started = None
_loop_lock = Lock()

def _run_loop():
    while True:
        end_condition = _loop.start()
        # there are still active watchers, no deadlock
        with _loop_lock:
            if end_condition:
                log.debug("Restarting event loop")
                continue
            else:
                # all Connections have been closed, no active watchers
                log.debug("All Connections currently closed, event loop ended")
                global _loop_started
                _loop_started = False
                break

def _start_loop():
    global _loop_started
    should_start = False
    with _loop_lock:
        if not _loop_started:
            log.debug("Starting pyev event loop")
            _loop_started = True
            should_start = True

    if should_start:
        t = Thread(target=_run_loop, name="async_event_loop")
        t.daemon = True
        t.start()

    return should_start


def defunct_on_error(f):

    @wraps(f)
    def wrapper(self, *args, **kwargs):
        try:
            return f(self, *args, **kwargs)
        except Exception, exc:
            self.defunct(exc)

    return wrapper


class Connection(object):

    in_buffer_size = 4096
    out_buffer_size = 4096

    cql_version = None

    keyspace = None
    compression = True
    compressor = None
    decompressor = None

    last_error = None
    in_flight = 0
    is_defunct = False
    is_closed = False
    lock = None

    _buf = ""
    _total_reqd_bytes = 0
    _read_watcher = None
    _write_watcher = None
    _socket = None

    @classmethod
    def factory(cls, *args, **kwargs):
        conn = cls(*args, **kwargs)
        conn.connected_event.wait()
        if conn.last_error:
            raise conn.last_error
        else:
            return conn

    def __init__(self, host='127.0.0.1', port=9042, credentials=None, sockopts=None, compression=True, cql_version=None):
        self.host = host
        self.port = port
        self.credentials = credentials
        self.compression = compression
        self.cql_version = cql_version

        self.connected_event = Event()

        self._id_queue = Queue(MAX_STREAM_PER_CONNECTION)
        for i in range(MAX_STREAM_PER_CONNECTION):
            self._id_queue.put_nowait(i)

        self._callbacks = {}
        self._push_watchers = defaultdict(set)
        self.lock = RLock()
        self.id_lock = Lock()

        self.deque = deque()

        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.connect((host, port))
        self._socket.setblocking(0)

        if sockopts:
            for args in sockopts:
                self._socket.setsockopt(*args)

        self._read_watcher = pyev.Io(self._socket._sock, pyev.EV_READ, _loop, self.handle_read)
        self._write_watcher = pyev.Io(self._socket._sock, pyev.EV_WRITE, _loop, self.handle_write)
        with _loop_lock:
            self._read_watcher.start()
            self._write_watcher.start()

        log.debug("Sending initial options message for new Connection to %s" % (host,))
        self.send_msg(OptionsMessage(), self._handle_options_response)

        # start the global event loop if needed
        if not _start_loop():
            # if the loop was already started, notify it
            with _loop_lock:
                _loop_notifier.send()

    def close(self):
        with self.lock:
            if self.is_closed:
                return
            self.is_closed = True

        log.debug("Closing connection to %s" % (self.host,))
        if self._read_watcher:
            self._read_watcher.stop()
        if self._write_watcher:
            self._write_watcher.stop()
        self._socket.close()
        with _loop_lock:
            _loop_notifier.send()

        # don't leave in-progress operations hanging
        if not self.is_defunct:
            self._error_all_callbacks(
                ConnectionException("Connection to %s was closed" % self.host))

    def __del__(self):
        self.close()

    def defunct(self, exc):
        log.debug("Defuncting connection to %s: %s\n%s" %
                  (self.host, exc, traceback.format_exc(exc)))
        self.last_error = exc
        self.is_defunct = True
        self._error_all_callbacks(exc)
        self.connected_event.set()
        return exc

    def _error_all_callbacks(self, exc):
        for cb in self._callbacks.values():
            cb(exc)

    def handle_write(self, watcher, revents):
        try:
            next_msg = self.deque.popleft()
        except IndexError:
            self._write_watcher.stop()
            return

        try:
            sent = self._socket.send(next_msg)
        except socket.error, err:
            if (err.args[0] in NONBLOCKING):
                self.deque.appendleft(next_msg)
            else:
                self.defunct(err)
            return
        else:
            if sent < len(next_msg):
                self.deque.appendleft(next_msg[sent:])

        if not self.deque:
            self._write_watcher.stop()

    def handle_read(self, watcher, revents):
        try:
            buf = self._socket.recv(self.in_buffer_size)
        except socket.error, err:
            if err.args[0] not in NONBLOCKING:
                self.defunct(err)
            return

        if buf:
            self._buf += buf
            while True:
                if len(self._buf) < 8:
                    # we don't have a complete header yet
                    break
                elif self._total_reqd_bytes and len(self._buf) < self._total_reqd_bytes:
                    # we already saw a header, but we don't have a complete message yet
                    break
                else:
                    body_len = int32_unpack(self._buf[4:8])
                    if len(self._buf) - 8 >= body_len:
                        msg = self._buf[:8 + body_len]
                        self._buf = self._buf[8 + body_len:]
                        self._total_reqd_bytes = 0
                        self.process_msg(msg, body_len)
                    else:
                        self._total_reqd_bytes = body_len + 8
        else:
            log.debug("connection closed by server")
            self.close()

    @defunct_on_error
    def process_msg(self, msg, body_len):
        version, flags, stream_id, opcode = map(int8_unpack, msg[:4])
        if stream_id < 0:
            callback = None
        else:
            callback = self._callbacks.pop(stream_id)
            self._id_queue.put_nowait(stream_id)

        body = None
        try:
            # check that the protocol version is supported
            given_version = version & PROTOCOL_VERSION_MASK
            if given_version != PROTOCOL_VERSION:
                raise ProtocolError("Unsupported CQL protocol version: %d" % given_version)

            # check that the header direction is correct
            if version & HEADER_DIRECTION_MASK != HEADER_DIRECTION_TO_CLIENT:
                raise ProtocolError(
                    "Header direction in response is incorrect; opcode %04x, stream id %r"
                    % (opcode, stream_id))

            if body_len > 0:
                body = msg[8:]
            elif body_len == 0:
                body = ""
            else:
                raise ProtocolError("Got negative body length: %r" % body_len)

            response = decode_response(stream_id, flags, opcode, body, self.decompressor)
        except Exception, exc:
            log.exception("Error decoding response from Cassandra. "
                          "opcode: %04x; message contents: %r" % (opcode, body))
            callback(exc)
            self.defunct(exc)
            return

        try:
            if stream_id < 0:
                self.handle_pushed(response)
            elif callback is not None:
                callback(response)
        except:
            log.exception("Callback handler errored, ignoring:")

    def handle_pushed(self, response):
        for cb in self._push_watchers[response.type]:
            try:
                cb(response)
            except:
                log.exception("Pushed event handler errored, ignoring:")

    def push(self, data):
        sabs = self.out_buffer_size
        if len(data) > sabs:
            chunks = []
            for i in xrange(0, len(data), sabs):
                chunks.append(data[i:i + sabs])
        else:
            chunks = [data]

        with self.lock:
            self.deque.extend(chunks)

        if not self._write_watcher.active:
            with _loop_lock:
                self._write_watcher.start()
                _loop_notifier.send()

    def send_msg(self, msg, cb):
        if self.is_defunct:
            raise ConnectionException("Connection to %s is defunct" % self.host)
        elif self.is_closed:
            raise ConnectionException("Connection to %s is closed" % self.host)

        try:
            request_id = self._id_queue.get_nowait()
        except Queue.EMPTY:
            raise ConnectionBusy(
                "Connection to %s is at the max number of requests" % self.host)

        self._callbacks[request_id] = cb
        self.push(msg.to_string(request_id, compression=self.compressor))
        return request_id

    def wait_for_response(self, msg):
        return self.wait_for_responses(msg)[0]

    def wait_for_responses(self, *msgs):
        waiter = ResponseWaiter(len(msgs))
        for i, msg in enumerate(msgs):
            self.send_msg(msg, partial(waiter.got_response, index=i))

        return waiter.deliver()

    def register_watcher(self, event_type, callback):
        self._push_watchers[event_type].add(callback)

    def register_watchers(self, type_callback_dict):
        for event_type, callback in type_callback_dict.items():
            self.register_watcher(event_type, callback)

    @defunct_on_error
    def _handle_options_response(self, options_response):
        if self.is_defunct:
            return
        log.debug("Received options response on new Connection from %s" % self.host)
        self.supported_cql_versions = options_response.cql_versions
        self.remote_supported_compressions = options_response.options['COMPRESSION']

        if self.cql_version:
            if self.cql_version not in self.supported_cql_versions:
                raise ProtocolError(
                    "cql_version %r is not supported by remote (w/ native "
                    "protocol). Supported versions: %r"
                    % (self.cql_version, self.supported_cql_versions))
        else:
            self.cql_version = self.supported_cql_versions[0]

        opts = {}
        self._compressor = None
        if self.compression:
            overlap = (set(locally_supported_compressions.keys()) &
                       set(self.remote_supported_compressions))
            if len(overlap) == 0:
                log.debug("No available compression types supported on both ends."
                          " locally supported: %r. remotely supported: %r"
                          % (locally_supported_compressions.keys(),
                             self.remote_supported_compressions))
            else:
                compression_type = iter(overlap).next()  # choose any
                opts['COMPRESSION'] = compression_type
                # set the decompressor here, but set the compressor only after
                # a successful Ready message
                self._compressor, self.decompressor = \
                    locally_supported_compressions[compression_type]

        sm = StartupMessage(cqlversion=self.cql_version, options=opts)
        self.send_msg(sm, cb=self._handle_startup_response)

    @defunct_on_error
    def _handle_startup_response(self, startup_response):
        if self.is_defunct:
            return
        if isinstance(startup_response, ReadyMessage):
            log.debug("Got ReadyMessage on new Connection from %s" % self.host)
            if self._compressor:
                self.compressor = self._compressor
            self.connected_event.set()
        elif isinstance(startup_response, AuthenticateMessage):
            log.debug("Got AuthenticateMessage on new Connection from %s" % self.host)

            if self.credentials is None:
                raise ProgrammingError('Remote end requires authentication.')

            self.authenticator = startup_response.authenticator
            cm = CredentialsMessage(creds=self.credentials)
            self.send_msg(cm, cb=self._handle_startup_response)
        elif isinstance(startup_response, ErrorMessage):
            log.debug("Received ErrorMessage on new Connection from %s: %s"
                      % (self.host, startup_response.summary_msg()))
            raise ConnectionException(
                "Failed to initialize new connection to %s: %s"
                % (self.host, startup_response.summary_msg()))
        else:
            msg = "Unexpected response during Connection setup: %r" % (startup_response,)
            log.error(msg)
            raise ProtocolError(msg)

    def set_keyspace(self, keyspace):
        if not keyspace or keyspace == self.keyspace:
            return

        with self.lock:
            query = 'USE "%s"' % (keyspace,)
            try:
                result = self.wait_for_response(
                    QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE))
                if isinstance(result, ResultMessage):
                    self.keyspace = keyspace
                else:
                    raise self.defunct(ConnectionException(
                        "Problem while setting keyspace: %r" % (result,), self.host))
            except Exception, exc:
                raise self.defunct(ConnectionException(
                    "Problem while setting keyspace: %r" % (exc,), self.host))


class ResponseWaiter(object):

    def __init__(self, num_responses):
        self.pending = num_responses
        self.error = None
        self.responses = [None] * num_responses
        self.event = Event()

    def got_response(self, response, index):
        if isinstance(response, Exception):
            self.error = response
            self.event.set()
        else:
            self.responses[index] = response
            self.pending -= 1
            if not self.pending:
                self.event.set()

    def deliver(self):
        self.event.wait()
        if self.error:
            raise self.error
        else:
            return self.responses