1055 lines
39 KiB
Python
1055 lines
39 KiB
Python
# Copyright 2013-2015 DataStax, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import absolute_import # to enable import io from stdlib
|
|
from collections import defaultdict, deque
|
|
import errno
|
|
from functools import wraps, partial
|
|
from heapq import heappush, heappop
|
|
import io
|
|
import logging
|
|
import six
|
|
from six.moves import range
|
|
import socket
|
|
import struct
|
|
import sys
|
|
from threading import Thread, Event, RLock
|
|
import time
|
|
|
|
try:
|
|
import ssl
|
|
except ImportError:
|
|
ssl = None # NOQA
|
|
|
|
if 'gevent.monkey' in sys.modules:
|
|
from gevent.queue import Queue, Empty
|
|
else:
|
|
from six.moves.queue import Queue, Empty # noqa
|
|
|
|
from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut
|
|
from cassandra.marshal import int32_pack
|
|
from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage,
|
|
StartupMessage, ErrorMessage, CredentialsMessage,
|
|
QueryMessage, ResultMessage, ProtocolHandler,
|
|
InvalidRequestException, SupportedMessage,
|
|
AuthResponseMessage, AuthChallengeMessage,
|
|
AuthSuccessMessage, ProtocolException,
|
|
MAX_SUPPORTED_VERSION, RegisterMessage)
|
|
from cassandra.util import OrderedDict
|
|
|
|
|
|
log = logging.getLogger(__name__)

# Registry of compression algorithms usable on this client, mapping the
# protocol-level name of each algorithm to a ``(compress, decompress)`` pair
# of callables. We use an ordered dictionary and specifically add lz4 before
# snappy so that lz4 will be preferred. Changing the order of this
# will change the compression preferences for the driver.
locally_supported_compressions = OrderedDict()

try:
    import lz4
except ImportError:
    pass
else:
    # Newer releases of the lz4 package expose the block-format functions
    # under ``lz4.block``; older releases expose them at the top level of the
    # package. Use whichever is available.
    try:
        from lz4 import block as lz4_block
    except ImportError:
        lz4_block = lz4

    # Cassandra writes the uncompressed message length in big endian order,
    # but the lz4 lib requires little endian order, so we wrap these
    # functions to handle that
    def lz4_compress(byts):
        # Replace the 4-byte little-endian length prefix that lz4 writes with
        # a big-endian one.
        return int32_pack(len(byts)) + lz4_block.compress(byts)[4:]

    def lz4_decompress(byts):
        # Reverse the first 4 bytes to flip the length prefix from
        # big-endian to little-endian before handing it to lz4.
        return lz4_block.decompress(byts[3::-1] + byts[4:])

    locally_supported_compressions['lz4'] = (lz4_compress, lz4_decompress)

try:
    import snappy
except ImportError:
    pass
else:
    # work around apparently buggy snappy decompress: short-circuit the
    # single-zero-byte input (presumably the snappy encoding of an empty
    # payload -- TODO confirm). Compare against a bytes literal so the check
    # also works on Python 3, where ``byts`` is ``bytes`` and can never equal
    # the str ``'\x00'``.
    def decompress(byts):
        if byts == b'\x00':
            return b''
        return snappy.decompress(byts)

    locally_supported_compressions['snappy'] = (snappy.compress, decompress)
|
|
|
|
|
|
# The first byte of every frame holds the protocol version in its low seven
# bits; the high bit carries the message direction (see the
# HEADER_DIRECTION_* values below).
PROTOCOL_VERSION_MASK = 0x7F

# Direction bit values for that first byte.
HEADER_DIRECTION_FROM_CLIENT = 0x00
HEADER_DIRECTION_TO_CLIENT = 0x80
HEADER_DIRECTION_MASK = 0x80

# Layouts of the frame header *after* the leading version byte, in the order
# flags, stream id, opcode, body length. Protocol v1/v2 use a one-byte signed
# stream id; v3 and later use a two-byte signed stream id.
frame_header_v1_v2 = struct.Struct('>' 'B' 'b' 'B' 'i')
frame_header_v3 = struct.Struct('>' 'B' 'h' 'B' 'i')
|
|
|
|
|
|
class _Frame(object):
|
|
def __init__(self, version, flags, stream, opcode, body_offset, end_pos):
|
|
self.version = version
|
|
self.flags = flags
|
|
self.stream = stream
|
|
self.opcode = opcode
|
|
self.body_offset = body_offset
|
|
self.end_pos = end_pos
|
|
|
|
def __eq__(self, other): # facilitates testing
|
|
if isinstance(other, _Frame):
|
|
return (self.version == other.version and
|
|
self.flags == other.flags and
|
|
self.stream == other.stream and
|
|
self.opcode == other.opcode and
|
|
self.body_offset == other.body_offset and
|
|
self.end_pos == other.end_pos)
|
|
return NotImplemented
|
|
|
|
def __str__(self):
|
|
return "ver({0}); flags({1:04b}); stream({2}); op({3}); offset({4}); len({5})".format(self.version, self.flags, self.stream, self.opcode, self.body_offset, self.end_pos - self.body_offset)
|
|
|
|
|
|
|
|
# errno values reported by a non-blocking socket operation that would have
# had to block. Not referenced in this module section; presumably compared
# against socket errors by reactor implementations -- TODO confirm.
NONBLOCKING = (errno.EAGAIN, errno.EWOULDBLOCK)
|
|
|
|
class ConnectionException(Exception):
    """
    Raised for an unrecoverable failure while using a connection, including
    any attempt to use a connection that is already closed or defunct.

    :param message: human-readable description of the failure.
    :param host: the host the connection targets, if known; exposed as
        ``self.host``.
    """

    def __init__(self, message, host=None):
        super(ConnectionException, self).__init__(message)
        self.host = host
|
|
|
|
|
|
class ConnectionShutdown(ConnectionException):
    """
    Signals that the connection has already been closed or marked defunct,
    so the requested operation cannot proceed on it.
    """
|
|
|
|
|
|
class ProtocolVersionUnsupported(ConnectionException):
    """
    Server rejected startup message due to unsupported protocol version.

    :param host: the host that rejected the connection; exposed as
        ``self.host``.
    :param startup_version: the protocol version the connection attempted;
        exposed as ``self.startup_version``.
    """
    def __init__(self, host, startup_version):
        # Format the message here. The base class signature is
        # (message, host), so passing the format arguments as a second
        # positional argument would leave the message unformatted and store a
        # tuple as ``self.host``.
        msg = "Unsupported protocol version on %s: %d" % (host, startup_version)
        super(ProtocolVersionUnsupported, self).__init__(msg, host)
        self.startup_version = startup_version
|
|
|
|
|
|
class ConnectionBusy(Exception):
    """
    Raised when a message cannot be sent through a :class:`.Connection`
    because it already carries the maximum number of in-flight operations.
    """
|
|
|
|
|
|
class ProtocolError(Exception):
    """
    Raised when data exchanged with the server does not follow the protocol
    this driver expects.
    """
|
|
|
|
|
|
def defunct_on_error(f):
    """
    Decorate a connection method so that an ``Exception`` it raises marks the
    connection defunct instead of propagating.

    On success the wrapped method's return value is passed through. On
    failure ``self.defunct(exc)`` is called and the wrapper returns ``None``.
    """
    @wraps(f)
    def wrapper(self, *args, **kwargs):
        try:
            outcome = f(self, *args, **kwargs)
        except Exception as err:
            self.defunct(err)
            return None
        return outcome
    return wrapper
|
|
|
|
|
|
# CQL version requested in the StartupMessage when no cql_version was given
# and the OPTIONS round-trip (which would discover the server's supported
# versions) is skipped.
DEFAULT_CQL_VERSION = '3.0.0'
|
|
|
|
# Indexing a byte string yields an int on Python 3 but a one-character str on
# Python 2; int_from_buf_item normalizes one indexed item to its integer value.
if not six.PY3:
    int_from_buf_item = ord
else:
    def int_from_buf_item(i):
        return i
|
|
|
|
|
|
class Connection(object):
    """
    Base class for one connection to a Cassandra node over the native
    protocol.

    This class implements the connection handshake (OPTIONS, STARTUP and
    authentication), request-id bookkeeping, frame parsing from an in-memory
    read buffer, and dispatch of responses to per-request callbacks.
    Reactor-specific subclasses provide the actual I/O: :meth:`close` and
    :meth:`create_timer` raise ``NotImplementedError`` here, and the ``push``
    method called by :meth:`send_msg` is not defined in this class.
    """

    # When error_all_requests has at least this many outstanding requests,
    # their callbacks are errored from a separate daemon thread.
    CALLBACK_ERR_THREAD_THRESHOLD = 100

    in_buffer_size = 4096
    out_buffer_size = 4096

    cql_version = None
    protocol_version = MAX_SUPPORTED_VERSION

    keyspace = None
    compression = True
    compressor = None
    decompressor = None

    ssl_options = None
    last_error = None

    # The current number of operations that are in flight. More precisely,
    # the number of request IDs that are currently in use.
    in_flight = 0

    # A deque of available request IDs. When using the v3 protocol or higher,
    # this will not initially include all request IDs in order to save memory,
    # but it will grow if it is exhausted (see get_request_id).
    request_ids = None

    # Tracks the highest used request ID in order to help with growing the
    # request_ids deque
    highest_request_id = 0

    is_defunct = False
    is_closed = False
    lock = None
    user_type_map = None

    # Set to True whenever a frame is processed; cleared by reset_idle().
    # is_idle reports the inverse.
    msg_received = False

    # Set when the server answers with a ProtocolException whose message
    # mentions an unsupported protocol version; factory() then raises
    # ProtocolVersionUnsupported.
    is_unsupported_proto_version = False

    is_control_connection = False
    signaled_error = False  # used for flagging at the pool level

    # In-memory read buffer holding bytes received but not yet consumed, and
    # the header of the frame currently being assembled from it.
    _iobuf = None
    _current_frame = None

    _socket = None

    # Modules used by _connect_socket to create sockets and wrap them in SSL;
    # presumably class attributes so subclasses can substitute alternative
    # implementations -- TODO confirm.
    _socket_impl = socket
    _ssl_impl = ssl

    def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
                 ssl_options=None, sockopts=None, compression=True,
                 cql_version=None, protocol_version=MAX_SUPPORTED_VERSION, is_control_connection=False,
                 user_type_map=None):
        """
        :param host: address passed to ``getaddrinfo`` when connecting.
        :param port: port passed to ``getaddrinfo`` when connecting.
        :param authenticator: ``None`` (no auth), a ``dict`` of credentials
            (sent in a CredentialsMessage), or an object providing
            ``initial_response``, ``evaluate_challenge`` and
            ``on_authentication_success`` (SASL-style auth).
        :param ssl_options: keyword arguments for ``wrap_socket``; a falsy
            value disables SSL.
        :param sockopts: iterable of argument tuples passed to
            ``socket.setsockopt`` after connecting.
        :param compression: ``True`` to negotiate any mutually supported
            compression, a compression name (e.g. ``'lz4'``) to require that
            one, or a falsy value to disable compression.
        :param cql_version: CQL version to request; ``None`` lets the driver
            choose.
        :param protocol_version: native protocol version; v3+ allows 2**15
            concurrent request ids, earlier versions 2**7.
        :param is_control_connection: whether this connection is used as the
            control connection.
        :param user_type_map: passed through to the response decoder.
        """
        self.host = host
        self.port = port
        self.authenticator = authenticator
        self.ssl_options = ssl_options
        self.sockopts = sockopts
        self.compression = compression
        self.cql_version = cql_version
        self.protocol_version = protocol_version
        self.is_control_connection = is_control_connection
        self.user_type_map = user_type_map
        self._push_watchers = defaultdict(set)
        # request id -> (callback, decoder) for every request awaiting a response
        self._requests = {}
        self._iobuf = io.BytesIO()

        if protocol_version >= 3:
            self.max_request_id = (2 ** 15) - 1
            # Don't fill the deque with 2**15 items right away. Start with 300 and add
            # more if needed.
            self.request_ids = deque(range(300))
            self.highest_request_id = 299
        else:
            self.max_request_id = (2 ** 7) - 1
            self.request_ids = deque(range(self.max_request_id + 1))
            self.highest_request_id = self.max_request_id

        self.lock = RLock()
        # Set once the handshake completes or the connection is defuncted.
        self.connected_event = Event()

    @classmethod
    def initialize_reactor(cls):
        """
        Called once by Cluster.connect(). This should be used by implementations
        to set up any resources that will be shared across connections.
        """
        pass

    @classmethod
    def handle_fork(cls):
        """
        Called after a forking. This should cleanup any remaining reactor state
        from the parent process.
        """
        pass

    @classmethod
    def create_timer(cls, timeout, callback):
        """
        Reactor-specific; not implemented in the base class. Presumably
        schedules ``callback`` to run after ``timeout`` seconds (compare
        :class:`Timer`) -- TODO confirm against implementations.
        """
        raise NotImplementedError()

    @classmethod
    def factory(cls, host, timeout, *args, **kwargs):
        """
        A factory function which returns connections which have
        succeeded in connecting and are ready for service (or
        raises an exception otherwise).

        :param host: forwarded to the constructor as ``host``.
        :param timeout: seconds to wait for ``connected_event``.
        :raises ProtocolVersionUnsupported: the server rejected the protocol
            version.
        :raises OperationTimedOut: the connection was not ready within
            ``timeout``; the connection is closed first.
        Any other error recorded in ``last_error`` is re-raised as is.
        """
        conn = cls(host, *args, **kwargs)
        conn.connected_event.wait(timeout)
        if conn.last_error:
            if conn.is_unsupported_proto_version:
                raise ProtocolVersionUnsupported(host, conn.protocol_version)
            raise conn.last_error
        elif not conn.connected_event.is_set():
            conn.close()
            raise OperationTimedOut("Timed out creating connection (%s seconds)" % timeout)
        else:
            return conn

    def _connect_socket(self):
        """
        Connect ``self._socket`` to the first reachable address of
        ``self.host``/``self.port``, wrapping it in SSL when ``ssl_options``
        is set, then apply ``sockopts``.

        :raises socket.error: no address could be connected; the message
            lists every address tried and the last error.
        :raises Exception: SSL was requested but the ``ssl`` module is
            unavailable.
        """
        # Check up front: raising inside the loop below would leak the socket
        # just created, since only socket.error triggers its cleanup.
        if self.ssl_options and not self._ssl_impl:
            raise Exception("This version of Python was not compiled with SSL support")

        sockerr = None
        addresses = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM)
        for (af, socktype, proto, canonname, sockaddr) in addresses:
            try:
                self._socket = self._socket_impl.socket(af, socktype, proto)
                if self.ssl_options:
                    self._socket = self._ssl_impl.wrap_socket(self._socket, **self.ssl_options)
                self._socket.settimeout(1.0)
                self._socket.connect(sockaddr)
                sockerr = None
                break
            except socket.error as err:
                if self._socket:
                    self._socket.close()
                    self._socket = None
                sockerr = err

        if sockerr:
            raise socket.error(sockerr.errno, "Tried connecting to %s. Last error: %s" % ([a[4] for a in addresses], sockerr.strerror))

        if self.sockopts:
            for args in self.sockopts:
                self._socket.setsockopt(*args)

    def close(self):
        """Close the connection; must be implemented by subclasses."""
        raise NotImplementedError()

    def defunct(self, exc):
        """
        Mark the connection unusable because of ``exc``.

        The first call records ``exc`` as ``last_error``, closes the
        connection, errors every outstanding request, sets
        ``connected_event`` and returns ``exc``. If the connection is already
        defunct or closed this does nothing and returns ``None``.
        """
        with self.lock:
            if self.is_defunct or self.is_closed:
                return
            self.is_defunct = True

        log.debug("Defuncting connection (%s) to %s:",
                  id(self), self.host, exc_info=exc)

        self.last_error = exc
        self.close()
        self.error_all_requests(exc)
        self.connected_event.set()
        return exc

    def error_all_requests(self, exc):
        """
        Fail every outstanding request by calling its callback with a
        :class:`ConnectionShutdown` built from ``str(exc)``.

        Exceptions raised by callbacks are logged and ignored.
        """
        with self.lock:
            requests = self._requests
            self._requests = {}

        if not requests:
            return

        new_exc = ConnectionShutdown(str(exc))

        def try_callback(cb):
            try:
                cb(new_exc)
            except Exception:
                log.warning("Ignoring unhandled exception while erroring requests for a "
                            "failed connection (%s) to host %s:",
                            id(self), self.host, exc_info=True)

        # run first callback from this thread to ensure pool state before leaving
        cb, _ = requests.popitem()[1]
        try_callback(cb)

        if not requests:
            return

        # additional requests are optionally errored from a separate thread
        # The default callback and retry logic is fairly expensive -- we don't
        # want to tie up the event thread when there are many requests
        def err_all_callbacks():
            for cb, _ in requests.values():
                try_callback(cb)
        if len(requests) < Connection.CALLBACK_ERR_THREAD_THRESHOLD:
            err_all_callbacks()
        else:
            # daemon thread here because we want to stay decoupled from the cluster TPE
            # TODO: would it make sense to just have a driver-global TPE?
            t = Thread(target=err_all_callbacks)
            t.daemon = True
            t.start()

    def get_request_id(self):
        """
        Return an unused request id, growing the id pool by one when it is
        empty.

        This must be called while self.lock is held.
        """
        try:
            return self.request_ids.popleft()
        except IndexError:
            self.highest_request_id += 1
            # in_flight checks should guarantee this
            assert self.highest_request_id <= self.max_request_id
            return self.highest_request_id

    def handle_pushed(self, response):
        """
        Deliver a server-pushed event to every watcher registered for its
        ``event_type``; watcher exceptions are logged and ignored.
        """
        log.debug("Message pushed from server: %r", response)
        for cb in self._push_watchers.get(response.event_type, []):
            try:
                cb(response.event_args)
            except Exception:
                log.exception("Pushed event handler errored, ignoring:")

    def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message):
        """
        Encode ``msg`` under ``request_id`` and hand it to ``self.push``.

        ``cb`` is called later with the decoded response (or an exception),
        and ``decoder`` is used to decode that response.

        :return: ``request_id``.
        :raises ConnectionShutdown: the connection is defunct or closed.
        """
        if self.is_defunct:
            raise ConnectionShutdown("Connection to %s is defunct" % self.host)
        elif self.is_closed:
            raise ConnectionShutdown("Connection to %s is closed" % self.host)

        # queue the decoder function with the request
        # this allows us to inject custom functions per request to encode, decode messages
        self._requests[request_id] = (cb, decoder)
        self.push(encoder(msg, request_id, self.protocol_version, compressor=self.compressor))
        return request_id

    def wait_for_response(self, msg, timeout=None):
        """Send a single message and block for its response (see wait_for_responses)."""
        return self.wait_for_responses(msg, timeout=timeout)[0]

    def wait_for_responses(self, *msgs, **kwargs):
        """
        Send ``msgs`` and block until every response has arrived.

        Keyword arguments: ``timeout`` (seconds, ``None`` to wait without a
        limit) and ``fail_on_error`` (default ``True``).

        If fail_on_error is True, returns the list of responses in message
        order, and the first error received is raised. Otherwise returns a
        list of (success, response) tuples; if success is False, response is
        an Exception.

        :raises ConnectionShutdown: the connection is already closed or
            defunct.
        :raises OperationTimedOut: the timeout expired while waiting for free
            request ids or for the responses.
        Any other error while waiting defuncts the connection and is re-raised.
        """
        if self.is_closed or self.is_defunct:
            raise ConnectionShutdown("Connection %s is already closed" % (self, ))
        timeout = kwargs.get('timeout')
        fail_on_error = kwargs.get('fail_on_error', True)
        waiter = ResponseWaiter(self, len(msgs), fail_on_error)

        # busy wait for sufficient space on the connection
        messages_sent = 0
        while True:
            needed = len(msgs) - messages_sent
            with self.lock:
                # Clamp at zero (defensive): a negative count would otherwise
                # silently decrement in_flight if it ever exceeded the limit.
                available = max(0, min(needed, self.max_request_id - self.in_flight))
                request_ids = [self.get_request_id() for _ in range(available)]
                self.in_flight += available

            for i, request_id in enumerate(request_ids):
                self.send_msg(msgs[messages_sent + i],
                              request_id,
                              partial(waiter.got_response, index=messages_sent + i))
            messages_sent += available

            if messages_sent == len(msgs):
                break
            else:
                if timeout is not None:
                    timeout -= 0.01
                    if timeout <= 0.0:
                        raise OperationTimedOut()
                time.sleep(0.01)

        try:
            return waiter.deliver(timeout)
        except OperationTimedOut:
            raise
        except Exception as exc:
            self.defunct(exc)
            raise

    def register_watcher(self, event_type, callback, register_timeout=None):
        """
        Register a callback for a given event type.

        Sends a blocking RegisterMessage for ``event_type``.
        """
        self._push_watchers[event_type].add(callback)
        self.wait_for_response(
            RegisterMessage(event_list=[event_type]),
            timeout=register_timeout)

    def register_watchers(self, type_callback_dict, register_timeout=None):
        """
        Register multiple callback/event type pairs, expressed as a dict.

        Sends one blocking RegisterMessage covering every event type.
        """
        for event_type, callback in type_callback_dict.items():
            self._push_watchers[event_type].add(callback)
        self.wait_for_response(
            RegisterMessage(event_list=type_callback_dict.keys()),
            timeout=register_timeout)

    def control_conn_disposed(self):
        """Stop acting as the control connection and drop all push watchers."""
        self.is_control_connection = False
        # Reset to the same defaultdict shape __init__ uses so a later
        # register_watcher(s) call does not fail with KeyError.
        self._push_watchers = defaultdict(set)

    @defunct_on_error
    def _read_frame_header(self):
        """
        Parse the frame header at the start of ``self._iobuf`` into
        ``self._current_frame`` once enough bytes are buffered.

        :return: the number of bytes currently buffered.
        :raises ProtocolError: (caught by the decorator, which defuncts the
            connection) unsupported version or negative body length.
        """
        buf = self._iobuf.getvalue()
        pos = len(buf)
        if pos:
            version = int_from_buf_item(buf[0]) & PROTOCOL_VERSION_MASK
            if version > MAX_SUPPORTED_VERSION:
                raise ProtocolError("This version of the driver does not support protocol version %d" % version)
            frame_header = frame_header_v3 if version >= 3 else frame_header_v1_v2
            # this frame header struct is everything after the version byte
            header_size = frame_header.size + 1
            if pos >= header_size:
                flags, stream, op, body_len = frame_header.unpack_from(buf, 1)
                if body_len < 0:
                    raise ProtocolError("Received negative body length: %r" % body_len)
                self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size)
        return pos

    def _reset_frame(self):
        """Drop the consumed frame, keeping only the bytes after the current read position."""
        self._iobuf = io.BytesIO(self._iobuf.read())
        self._current_frame = None

    def process_io_buffer(self):
        """
        Dispatch every complete frame in ``self._iobuf`` to
        :meth:`process_msg`; return once only a partial header or body
        remains.
        """
        while True:
            if not self._current_frame:
                pos = self._read_frame_header()
            else:
                pos = self._iobuf.tell()

            if not self._current_frame or pos < self._current_frame.end_pos:
                # we don't have a complete header yet or we
                # already saw a header, but we don't have a
                # complete message yet
                return
            else:
                frame = self._current_frame
                self._iobuf.seek(frame.body_offset)
                msg = self._iobuf.read(frame.end_pos - frame.body_offset)
                self.process_msg(frame, msg)
                self._reset_frame()

    @defunct_on_error
    def process_msg(self, header, body):
        """
        Decode one frame body and route the result.

        Frames with a negative stream id are server-pushed events and go to
        :meth:`handle_pushed`. Otherwise the matching request's callback is
        invoked and its id is returned to the pool. Decode errors and
        ProtocolException responses defunct the connection.
        """
        stream_id = header.stream
        if stream_id < 0:
            callback = None
            decoder = ProtocolHandler.decode_message
        else:
            # NOTE(review): an unknown stream id yields None here and the
            # unpacking raises TypeError, which the decorator turns into
            # defunct().
            callback, decoder = self._requests.pop(stream_id, None)
            with self.lock:
                self.request_ids.append(stream_id)

        self.msg_received = True

        try:
            response = decoder(header.version, self.user_type_map, stream_id,
                               header.flags, header.opcode, body, self.decompressor)
        except Exception as exc:
            log.exception("Error decoding response from Cassandra. "
                          "%s; buffer: %r", header, self._iobuf.getvalue())
            if callback is not None:
                callback(exc)
            self.defunct(exc)
            return

        try:
            if stream_id >= 0:
                if isinstance(response, ProtocolException):
                    if 'unsupported protocol version' in response.message:
                        self.is_unsupported_proto_version = True

                    log.error("Closing connection %s due to protocol error: %s", self, response.summary_msg())
                    self.defunct(response)
                if callback is not None:
                    callback(response)
            else:
                self.handle_pushed(response)
        except Exception:
            log.exception("Callback handler errored, ignoring:")

    @defunct_on_error
    def _send_options_message(self):
        """
        Start the handshake. An OPTIONS request is sent unless no CQL version
        was requested and compression is off or unavailable; in that case
        DEFAULT_CQL_VERSION is used and STARTUP is sent directly.
        """
        if self.cql_version is None and (not self.compression or not locally_supported_compressions):
            log.debug("Not sending options message for new connection(%s) to %s "
                      "because compression is disabled and a cql version was not "
                      "specified", id(self), self.host)
            self._compressor = None
            self.cql_version = DEFAULT_CQL_VERSION
            self._send_startup_message()
        else:
            log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host)
            self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response)

    @defunct_on_error
    def _handle_options_response(self, options_response):
        """
        Pick the CQL version and compression from the server's SUPPORTED
        response, then send STARTUP.

        :raises ProtocolError: (caught by the decorator, which defuncts the
            connection) the requested CQL version or compression is
            unsupported.
        """
        if self.is_defunct:
            return

        if not isinstance(options_response, SupportedMessage):
            if isinstance(options_response, ConnectionException):
                raise options_response
            else:
                log.error("Did not get expected SupportedMessage response; "
                          "instead, got: %s", options_response)
                raise ConnectionException("Did not get expected SupportedMessage "
                                          "response; instead, got: %s"
                                          % (options_response,))

        log.debug("Received options response on new connection (%s) from %s",
                  id(self), self.host)
        supported_cql_versions = options_response.cql_versions
        remote_supported_compressions = options_response.options['COMPRESSION']

        if self.cql_version:
            if self.cql_version not in supported_cql_versions:
                raise ProtocolError(
                    "cql_version %r is not supported by remote (w/ native "
                    "protocol). Supported versions: %r"
                    % (self.cql_version, supported_cql_versions))
        else:
            self.cql_version = supported_cql_versions[0]

        self._compressor = None
        compression_type = None
        if self.compression:
            overlap = (set(locally_supported_compressions.keys()) &
                       set(remote_supported_compressions))
            if len(overlap) == 0:
                log.debug("No available compression types supported on both ends."
                          " locally supported: %r. remotely supported: %r",
                          locally_supported_compressions.keys(),
                          remote_supported_compressions)
            else:
                if isinstance(self.compression, six.string_types):
                    # the user picked a specific compression type ('snappy' or 'lz4')
                    if self.compression not in remote_supported_compressions:
                        raise ProtocolError(
                            "The requested compression type (%s) is not supported by the Cassandra server at %s"
                            % (self.compression, self.host))
                    # Fail with a clear message rather than a bare KeyError
                    # from the registry lookup below.
                    if self.compression not in locally_supported_compressions:
                        raise ProtocolError(
                            "The requested compression type (%s) is not available locally. "
                            "Locally supported: %r"
                            % (self.compression, list(locally_supported_compressions.keys())))
                    compression_type = self.compression
                else:
                    # our locally supported compressions are ordered to prefer
                    # lz4, if available
                    for k in locally_supported_compressions.keys():
                        if k in overlap:
                            compression_type = k
                            break

                # set the decompressor here, but set the compressor only after
                # a successful Ready message
                self._compressor, self.decompressor = \
                    locally_supported_compressions[compression_type]

        self._send_startup_message(compression_type)

    @defunct_on_error
    def _send_startup_message(self, compression=None):
        """Send STARTUP with the chosen CQL version and optional compression name."""
        log.debug("Sending StartupMessage on %s", self)
        opts = {}
        if compression:
            opts['COMPRESSION'] = compression
        sm = StartupMessage(cqlversion=self.cql_version, options=opts)
        self.send_msg(sm, self.get_request_id(), cb=self._handle_startup_response)
        log.debug("Sent StartupMessage on %s", self)

    @defunct_on_error
    def _handle_startup_response(self, startup_response, did_authenticate=False):
        """
        Handle the reply to STARTUP (or to a CredentialsMessage when
        ``did_authenticate`` is True).

        READY finishes the handshake and enables the negotiated compressor;
        AUTHENTICATE starts credentials or SASL authentication; anything else
        raises, and the decorator defuncts the connection.
        """
        if self.is_defunct:
            return
        if isinstance(startup_response, ReadyMessage):
            log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.host)
            if self._compressor:
                self.compressor = self._compressor
            self.connected_event.set()
        elif isinstance(startup_response, AuthenticateMessage):
            log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s",
                      id(self), self.host, startup_response.authenticator)

            if self.authenticator is None:
                raise AuthenticationFailed('Remote end requires authentication.')

            self.authenticator_class = startup_response.authenticator

            if isinstance(self.authenticator, dict):
                log.debug("Sending credentials-based auth response on %s", self)
                cm = CredentialsMessage(creds=self.authenticator)
                callback = partial(self._handle_startup_response, did_authenticate=True)
                self.send_msg(cm, self.get_request_id(), cb=callback)
            else:
                log.debug("Sending SASL-based auth response on %s", self)
                initial_response = self.authenticator.initial_response()
                initial_response = "" if initial_response is None else initial_response
                self.send_msg(AuthResponseMessage(initial_response), self.get_request_id(), self._handle_auth_response)
        elif isinstance(startup_response, ErrorMessage):
            log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
                      id(self), self.host, startup_response.summary_msg())
            if did_authenticate:
                raise AuthenticationFailed(
                    "Failed to authenticate to %s: %s" %
                    (self.host, startup_response.summary_msg()))
            else:
                raise ConnectionException(
                    "Failed to initialize new connection to %s: %s"
                    % (self.host, startup_response.summary_msg()))
        elif isinstance(startup_response, ConnectionShutdown):
            log.debug("Connection to %s was closed during the startup handshake", (self.host))
            raise startup_response
        else:
            msg = "Unexpected response during Connection setup: %r"
            log.error(msg, startup_response)
            raise ProtocolError(msg % (startup_response,))

    @defunct_on_error
    def _handle_auth_response(self, auth_response):
        """
        Drive a SASL-style authentication exchange: answer challenges, finish
        on success, and raise (the decorator defuncts) on errors.
        """
        if self.is_defunct:
            return

        if isinstance(auth_response, AuthSuccessMessage):
            log.debug("Connection %s successfully authenticated", self)
            self.authenticator.on_authentication_success(auth_response.token)
            if self._compressor:
                self.compressor = self._compressor
            self.connected_event.set()
        elif isinstance(auth_response, AuthChallengeMessage):
            response = self.authenticator.evaluate_challenge(auth_response.challenge)
            msg = AuthResponseMessage("" if response is None else response)
            log.debug("Responding to auth challenge on %s", self)
            self.send_msg(msg, self.get_request_id(), self._handle_auth_response)
        elif isinstance(auth_response, ErrorMessage):
            log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
                      id(self), self.host, auth_response.summary_msg())
            raise AuthenticationFailed(
                "Failed to authenticate to %s: %s" %
                (self.host, auth_response.summary_msg()))
        elif isinstance(auth_response, ConnectionShutdown):
            log.debug("Connection to %s was closed during the authentication process", self.host)
            raise auth_response
        else:
            msg = "Unexpected response during Connection authentication to %s: %r"
            log.error(msg, self.host, auth_response)
            raise ProtocolError(msg % (self.host, auth_response))

    def set_keyspace_blocking(self, keyspace):
        """
        Issue ``USE "<keyspace>"`` and block for the result.

        Does nothing if ``keyspace`` is empty or already current. An
        InvalidRequestException is re-raised via ``to_exception()``; any
        other failure or unexpected result defuncts the connection and raises
        ConnectionException.
        """
        if not keyspace or keyspace == self.keyspace:
            return

        query = QueryMessage(query='USE "%s"' % (keyspace,),
                             consistency_level=ConsistencyLevel.ONE)
        try:
            result = self.wait_for_response(query)
        except InvalidRequestException as ire:
            # the keyspace probably doesn't exist
            raise ire.to_exception()
        except Exception as exc:
            conn_exc = ConnectionException(
                "Problem while setting keyspace: %r" % (exc,), self.host)
            self.defunct(conn_exc)
            raise conn_exc

        if isinstance(result, ResultMessage):
            self.keyspace = keyspace
        else:
            conn_exc = ConnectionException(
                "Problem while setting keyspace: %r" % (result,), self.host)
            self.defunct(conn_exc)
            raise conn_exc

    def set_keyspace_async(self, keyspace, callback):
        """
        Use this in order to avoid deadlocking the event loop thread.
        When the operation completes, `callback` will be called with
        two arguments: this connection and an Exception if an error
        occurred, otherwise :const:`None`.

        If ``keyspace`` is empty or already current, ``callback`` is invoked
        immediately without sending anything. Note that for unexpected
        results the error passed is ``defunct()``'s return value, which is
        ``None`` when the connection was already defunct or closed.
        """
        if not keyspace or keyspace == self.keyspace:
            callback(self, None)
            return

        query = QueryMessage(query='USE "%s"' % (keyspace,),
                             consistency_level=ConsistencyLevel.ONE)

        def process_result(result):
            if isinstance(result, ResultMessage):
                self.keyspace = keyspace
                callback(self, None)
            elif isinstance(result, InvalidRequestException):
                callback(self, result.to_exception())
            else:
                callback(self, self.defunct(ConnectionException(
                    "Problem while setting keyspace: %r" % (result,), self.host)))

        request_id = None
        # we use a busy wait on the lock here because:
        # - we'll only spin if the connection is at max capacity, which is very
        #   unlikely for a set_keyspace call
        # - it allows us to avoid signaling a condition every time a request completes
        while True:
            with self.lock:
                if self.in_flight < self.max_request_id:
                    request_id = self.get_request_id()
                    self.in_flight += 1
                    break

            time.sleep(0.001)

        self.send_msg(query, request_id, process_result)

    @property
    def is_idle(self):
        """True when no frame has been processed since the last reset_idle()."""
        return not self.msg_received

    def reset_idle(self):
        """Clear the received-message flag so is_idle becomes True."""
        self.msg_received = False

    def __str__(self):
        status = ""
        if self.is_defunct:
            status = " (defunct)"
        elif self.is_closed:
            status = " (closed)"

        return "<%s(%r) %s:%d%s>" % (self.__class__.__name__, id(self), self.host, self.port, status)
    __repr__ = __str__
|
|
|
|
|
|
class ResponseWaiter(object):
    """
    Gathers the responses to a batch of requests sent on one connection and
    lets a caller block until all have arrived (or, with ``fail_on_error``,
    until the first error arrives).
    """

    def __init__(self, connection, num_responses, fail_on_error):
        self.connection = connection
        self.fail_on_error = fail_on_error
        # Count of responses still outstanding.
        self.pending = num_responses
        self.responses = [None for _ in range(num_responses)]
        self.error = None
        self.event = Event()

    def got_response(self, response, index):
        """
        Record the response for the request at position ``index``.

        Releases one in-flight slot on the connection. An Exception response
        is converted via ``to_exception()`` when it has one. With
        ``fail_on_error`` the first error wakes the waiter immediately and
        successes are stored bare; otherwise results are stored as
        ``(success, response)`` tuples.
        """
        conn = self.connection
        with conn.lock:
            conn.in_flight -= 1

        if not isinstance(response, Exception):
            self.responses[index] = response if self.fail_on_error else (True, response)
        else:
            if hasattr(response, 'to_exception'):
                response = response.to_exception()
            if not self.fail_on_error:
                self.responses[index] = (False, response)
            else:
                self.error = response
                self.event.set()

        self.pending -= 1
        if self.pending == 0:
            self.event.set()

    def deliver(self, timeout=None):
        """
        If fail_on_error was set to False, a list of (success, response)
        tuples will be returned. If success is False, response will be
        an Exception. Otherwise, response will be the normal query response.

        If fail_on_error was left as True and one of the requests
        failed, the corresponding Exception will be raised. Otherwise,
        the normal response will be returned.

        :raises OperationTimedOut: nothing completed the wait within
            ``timeout`` seconds.
        """
        self.event.wait(timeout)
        if self.error:
            raise self.error
        if self.event.is_set():
            return self.responses
        raise OperationTimedOut()
|
|
|
|
|
|
class HeartbeatFuture(object):
    """
    A single heartbeat: an OPTIONS request sent on an idle connection whose
    outcome can be awaited with :meth:`wait`.

    If the connection already has ``max_request_id`` requests in flight,
    nothing is sent and the future completes immediately with an error.
    """

    def __init__(self, connection, owner):
        self.connection = connection
        self.owner = owner
        self._event = Event()
        self._exception = None
        log.debug("Sending options message heartbeat on idle connection (%s) %s",
                  id(connection), connection.host)
        with connection.lock:
            if connection.in_flight >= connection.max_request_id:
                self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold")
                self._event.set()
            else:
                connection.in_flight += 1
                request = OptionsMessage()
                connection.send_msg(request, connection.get_request_id(), self._options_callback)

    def wait(self, timeout):
        """
        Block up to ``timeout`` seconds for the heartbeat response.

        :raises OperationTimedOut: no response arrived in time.
        :raises Exception: the recorded failure, if the heartbeat failed.
        """
        self._event.wait(timeout)
        if not self._event.is_set():
            raise OperationTimedOut()
        if self._exception:
            raise self._exception

    def _options_callback(self, response):
        # Anything other than SUPPORTED is a failure; connection errors are
        # kept as-is, other responses are wrapped in a ConnectionException.
        if not isinstance(response, SupportedMessage):
            self._exception = (response if isinstance(response, ConnectionException)
                               else ConnectionException("Received unexpected response to OptionsMessage: %s"
                                                        % (response,)))

        log.debug("Received options response on connection (%s) from %s",
                  id(self.connection), self.connection.host)
        self._event.set()
|
|
|
|
|
|
class ConnectionHeartbeat(Thread):
    """
    Daemon thread that periodically heartbeats idle connections.

    Every ``interval_sec`` seconds it walks the connections of every holder
    returned by ``get_connection_holders``:

    - defunct or closed connections are handed back via
      ``owner.return_connection``;
    - busy connections just have their idle flag reset;
    - idle connections get a :class:`HeartbeatFuture`.

    Connections whose heartbeat cannot be sent or does not succeed within
    ``interval_sec`` are defuncted and returned to their owner. The thread
    starts on construction; :meth:`stop` ends it and joins.
    """

    def __init__(self, interval_sec, get_connection_holders):
        Thread.__init__(self, name="Connection heartbeat")
        self._interval = interval_sec
        self._get_connection_holders = get_connection_holders
        self._shutdown_event = Event()
        self.daemon = True
        self.start()

    class ShutdownException(Exception):
        # Internal signal used to abandon the current round once stop() is requested.
        pass

    def run(self):
        self._shutdown_event.wait(self._interval)
        while not self._shutdown_event.is_set():
            round_started = time.time()
            try:
                futures, failed = self._send_heartbeats()
                failed.extend(self._collect_heartbeats(futures))
                self._defunct_failed(failed)
            except self.ShutdownException:
                pass
            except Exception:
                log.error("Failed connection heartbeat", exc_info=True)

            elapsed = time.time() - round_started
            self._shutdown_event.wait(max(self._interval - elapsed, 0.01))

    def _send_heartbeats(self):
        # Start a heartbeat on each idle connection; return the futures and
        # the (connection, owner) pairs whose heartbeat could not be sent.
        futures = []
        failed = []
        holders = [(holder.get_connections(), holder) for holder in self._get_connection_holders()]
        for connections, owner in holders:
            for connection in connections:
                self._raise_if_stopped()
                if connection.is_defunct or connection.is_closed:
                    # make sure the owner sees this defunt/closed connection
                    owner.return_connection(connection)
                elif not connection.is_idle:
                    connection.reset_idle()
                else:
                    try:
                        futures.append(HeartbeatFuture(connection, owner))
                    except Exception:
                        log.warning("Failed sending heartbeat message on connection (%s) to %s",
                                    id(connection), connection.host, exc_info=True)
                        failed.append((connection, owner))
            self._raise_if_stopped()
        return futures, failed

    def _collect_heartbeats(self, futures):
        # Wait on each heartbeat; on success release its in-flight slot,
        # otherwise report the (connection, owner) pair as failed.
        failed = []
        for future in futures:
            self._raise_if_stopped()
            connection = future.connection
            try:
                future.wait(self._interval)
                # TODO: move this, along with connection locks in pool, down into Connection
                with connection.lock:
                    connection.in_flight -= 1
                connection.reset_idle()
            except Exception:
                log.warning("Heartbeat failed for connection (%s) to %s",
                            id(connection), connection.host, exc_info=True)
                failed.append((future.connection, future.owner))
        return failed

    def _defunct_failed(self, failed):
        # Defunct every failed connection and hand it back to its owner.
        for connection, owner in failed:
            self._raise_if_stopped()
            connection.defunct(Exception('Connection heartbeat failure'))
            owner.return_connection(connection)

    def stop(self):
        """Request shutdown and wait for the thread to exit."""
        self._shutdown_event.set()
        self.join()

    def _raise_if_stopped(self):
        if self._shutdown_event.is_set():
            raise self.ShutdownException()
|
|
|
|
|
|
class Timer(object):
    """
    One-shot callback due ``timeout`` seconds after construction.

    Timers are driven by :class:`TimerManager`, which calls :meth:`finish`
    with the current time. A negative ``timeout`` runs the callback
    immediately from the constructor. Such a timer is then already finished,
    and servicing it later does not run the callback a second time.
    """

    canceled = False
    # Set once the callback has run, so it runs at most once.
    _fired = False

    def __init__(self, timeout, callback):
        self.end = time.time() + timeout
        self.callback = callback
        if timeout < 0:
            self._fire()

    def cancel(self):
        """Prevent the callback from running (no effect once it has run)."""
        self.canceled = True

    def finish(self, time_now):
        """
        Run the callback if it is due.

        :param time_now: current time on the ``time.time()`` scale.
        :return: True when the timer is done (canceled, already fired, or
            fired by this call) and can be dropped; False while still pending.
        """
        if self.canceled or self._fired:
            return True

        if time_now >= self.end:
            self._fire()
            return True

        return False

    def _fire(self):
        # Mark as fired before invoking, so a callback that raises is not
        # retried on the next service pass.
        self._fired = True
        self.callback()
|
|
|
|
|
|
class TimerManager(object):
    """
    Min-heap of timers ordered by expiry time.

    Client threads register timers with :meth:`add_timer`. The event thread
    calls :meth:`service_timeouts` to move them into the heap and fire every
    expired one.
    """

    def __init__(self):
        from itertools import count
        # Heap of (end, seq, timer) entries. ``seq`` is an increasing
        # tie-breaker so two timers with equal ``end`` never fall through to
        # comparing timer objects (a TypeError on Python 3); equal-``end``
        # timers therefore fire in the order they were added.
        self._queue = []
        # Timers added since the last service pass; the event thread moves
        # them into the heap.
        self._new_timers = []
        self._counter = count()

    def add_timer(self, timer):
        """
        called from client thread with a Timer object
        """
        self._new_timers.append((timer.end, next(self._counter), timer))

    def service_timeouts(self):
        """
        run callbacks on all expired timers
        Called from the event thread
        :return: next end time, or None
        """
        queue = self._queue
        if self._new_timers:
            new_timers = self._new_timers
            while new_timers:
                heappush(queue, new_timers.pop())

        if queue:
            now = time.time()
            while queue:
                timer = queue[0][-1]
                try:
                    done = timer.finish(now)
                except Exception:
                    log.exception("Exception while servicing timeout callback: ")
                    # Drop the failing timer: leaving it at the head of the
                    # heap would retry it (and fail) forever in this loop.
                    done = True
                if done:
                    heappop(queue)
                else:
                    return timer.end
        return None

    @property
    def next_timeout(self):
        """
        Expiry time of the earliest timer in the heap, or None if it is
        empty. Timers added since the last service_timeouts() call are not
        considered.
        """
        try:
            return self._queue[0][0]
        except IndexError:
            return None
|