Basic working Cluster, Session, Connection, ResponseFuture

This commit is contained in:
Tyler Hobbs
2013-04-04 18:06:57 -05:00
parent 38231f4727
commit 089f05b24f
7 changed files with 294 additions and 180 deletions

View File

@@ -1,15 +1,18 @@
import asyncore
import time import time
from threading import Lock, RLock, Thread from threading import Lock, RLock, Thread, Event
import Queue import Queue
import logging
import traceback
from futures import ThreadPoolExecutor from futures import ThreadPoolExecutor
from connection import Connection from connection import Connection
from decoder import ConsistencyLevel, QueryMessage from decoder import (ConsistencyLevel, QueryMessage, ResultMessage, ErrorMessage,
ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableExceptionErrorMessage,
OverloadedErrorMessage, IsBootstrappingErrorMessage)
from metadata import Metadata from metadata import Metadata
from policies import (RoundRobinPolicy, SimpleConvictionPolicy, from policies import (RoundRobinPolicy, SimpleConvictionPolicy,
ExponentialReconnectionPolicy, HostDistance) ExponentialReconnectionPolicy, HostDistance, RetryPolicy)
from query import SimpleStatement from query import SimpleStatement
from pool import (ConnectionException, BusyConnectionException, from pool import (ConnectionException, BusyConnectionException,
AuthenticationException, _ReconnectionHandler, AuthenticationException, _ReconnectionHandler,
@@ -18,19 +21,27 @@ from pool import (ConnectionException, BusyConnectionException,
# TODO: we might want to make this configurable # TODO: we might want to make this configurable
MAX_SCHEMA_AGREEMENT_WAIT_MS = 10000 MAX_SCHEMA_AGREEMENT_WAIT_MS = 10000
log = logging.getLogger(__name__)
class _RetryingCallback(object):
def __init__(self, session, query): class ResponseFuture(object):
def __init__(self, session, message, query):
self.session = session self.session = session
self.query = query self.message = message
self.query_plan = session._load_balancer.new_query_plan(query)
self.query_plan = session._load_balancer.make_query_plan(query)
self._req_id = None self._req_id = None
self._final_result = None self._final_result = None
self._final_exception = None self._final_exception = None
self._wait_fn = None self._current_host = None
self._current_pool = None
self._connection = None
self._event = Event()
self._errors = {} self._errors = {}
self._query_retries = 0
self._callback = self._errback = None
def send_request(self): def send_request(self):
for host in self.query_plan: for host in self.query_plan:
@@ -42,32 +53,122 @@ class _RetryingCallback(object):
self._final_exception = NoHostAvailable(self._errors) self._final_exception = NoHostAvailable(self._errors)
def _query(self, host): def _query(self, host):
pool = self._session._pools.get(host) pool = self.session._pools.get(host)
if not pool or pool.is_shutdown: if not pool or pool.is_shutdown:
return None return None
connection = None
try: try:
# TODO get connectTimeout from cluster settings # TODO get connectTimeout from cluster settings
connection = pool.borrow_connection(timeout=2.0) connection = pool.borrow_connection(timeout=2.0)
request_id = connection.request_and_callback(self.query, self._set_results) request_id = connection.send_msg(self.message, cb=self._set_result)
self._wait_fn = connection.wait_for_result self._current_host = host
self._current_pool = pool
self._connection = connection
return request_id return request_id
except Exception: except Exception, exc:
return None self._errors[host] = exc
finally:
if connection: if connection:
pool.return_connection(connection) pool.return_connection(connection)
return None
def _set_results(self, response): def _set_result(self, response):
self._current_pool.return_connection(self._connection)
if isinstance(response, ResultMessage):
self._set_final_result(response)
elif isinstance(response, ErrorMessage):
retry_policy = self.query.retry_policy # TODO also check manager.configuration
# for retry policy if None
if isinstance(response, ReadTimeoutErrorMessage):
details = response.recv_error_info()
retry = retry_policy.on_read_timeout(
self.query, attempt_num=self._query_retries, **details)
elif isinstance(response, WriteTimeoutErrorMessage):
details = response.recv_error_info()
retry = retry_policy.on_write_timeout(
self.query, attempt_num=self._query_retries, **details)
elif isinstance(response, UnavailableExceptionErrorMessage):
details = response.recv_error_info()
retry = retry_policy.on_write_timeout(
self.query, attempt_num=self._query_retries, **details)
elif isinstance(response, OverloadedErrorMessage):
# need to retry against a different host here
self._retry(False, None)
elif isinstance(response, IsBootstrappingErrorMessage):
# need to retry against a different host here
self._retry(False, None)
# TODO need to define the PreparedQueryNotFound class
# elif isinstance(response, PreparedQueryNotFound):
# pass
else:
pass
retry_type, consistency = retry
if retry_type == RetryPolicy.RETRY:
self._query_retries += 1
self._retry(True, consistency)
elif retry_type == RetryPolicy.RETHROW:
self._set_final_result(response)
else: # IGNORE
self._set_final_result(None)
else:
# we got some other kind of response message
self._set_final_result(response)
def _set_final_result(self, response):
self._final_result = response self._final_result = response
self._event.set()
if self._callback:
fn, args, kwargs = self._callback
fn(response, *args, **kwargs)
def _set_final_exception(self, response):
self._final_exception = response
self._event.set()
if self._errback:
fn, args, kwargs = self._errback
fn(response, *args, **kwargs)
def _retry(self, reuse_connection, consistency_level):
self.message.consistency_level = consistency_level
# don't retry on the event loop thread
self.session.submit(self._retry_helper, reuse_connection)
def _retry_task(self, reuse_connection):
if reuse_connection and self._query(self._current_host):
return
# otherwise, move onto another host
self.send_request()
def deliver(self): def deliver(self):
if self._final_result: if self._final_exception:
return self._final_result
elif self._final_exception:
raise self._final_exception raise self._final_exception
elif self._final_result:
return self._final_result
else: else:
return self._wait_fn(self._req_id) self._event.wait()
if self._final_exception:
raise self._final_exception
else:
return self._final_result
def addCallback(self, fn, *args, **kwargs):
self._callback = (fn, args, kwargs)
return self
def addErrback(self, fn, *args, **kwargs):
self._errback = (fn, args, kwargs)
return self
def addCallbacks(self,
callback, errback,
callback_args=(), callback_kwargs=None,
errback_args=(), errback_kwargs=None):
self._callback = (callback, callback_args, callback_kwargs | {})
self._errback = (errback, errback_args, errback_kwargs | {})
class Session(object): class Session(object):
@@ -79,17 +180,29 @@ class Session(object):
self._is_shutdown = False self._is_shutdown = False
self._pools = {} self._pools = {}
self._load_balancer = RoundRobinPolicy() self._load_balancer = RoundRobinPolicy()
self._load_balancer.populate(cluster, hosts)
for host in hosts:
self.add_host(host)
def execute(self, query): def execute(self, query):
if isinstance(query, basestring): future = self.execute_async(query)
query = SimpleStatement(query) return future.deliver()
def execute_async(self, query): def execute_async(self, query):
if isinstance(query, basestring): if isinstance(query, basestring):
query = SimpleStatement(query) query = SimpleStatement(query)
qmsg = QueryMessage(query=query.query, consistency_level=query.consistency_level) # TODO bound statements need to be handled differently
return self._execute_query(qmsg, query) message = QueryMessage(query=query.query_string, consistency_level=query.consistency_level)
if query.tracing_enabled:
# TODO enable tracing on the message
pass
future = ResponseFuture(self, message, query)
future.send_request()
return future
def prepare(self, query): def prepare(self, query):
pass pass
@@ -97,25 +210,6 @@ class Session(object):
def shutdown(self): def shutdown(self):
self.cluster.shutdown() self.cluster.shutdown()
def _execute_query(self, message, query):
if query.tracing_enabled:
# TODO enable tracing on the message
pass
errors = {}
for host in self._load_balancer.make_query_plan(query):
try:
result = self._query(host)
if result:
return
except Exception, exc:
errors[host] = exc
def _query(self, host, query):
pool = self._pools.get(host)
if not pool or pool.is_shutdown:
return False
def add_host(self, host): def add_host(self, host):
distance = self._load_balancer.distance(host) distance = self._load_balancer.distance(host)
if distance == HostDistance.IGNORED: if distance == HostDistance.IGNORED:
@@ -174,8 +268,8 @@ class Session(object):
def set_keyspace(self, keyspace): def set_keyspace(self, keyspace):
pass pass
def submit(self, task): def submit(self, fn, *args, **kwargs):
return self.cluster.executor.submit(task) return self.cluster.executor.submit(fn, *args, **kwargs)
DEFAULT_MIN_REQUESTS = 25 DEFAULT_MIN_REQUESTS = 25
DEFAULT_MAX_REQUESTS = 100 DEFAULT_MAX_REQUESTS = 100
@@ -219,18 +313,6 @@ class _Scheduler(object):
time.sleep(0.1) time.sleep(0.1)
_loop_started = False
_loop_lock = Lock()
def _start_loop():
with _loop_lock:
global _loop_started
if not _loop_started:
_loop_started = True
t = Thread(target=asyncore.loop, name="Async event loop")
t.daemon = True
t.start()
class Cluster(object): class Cluster(object):
port = 9042 port = 9042
@@ -282,9 +364,6 @@ class Cluster(object):
self._is_shutdown = False self._is_shutdown = False
self._lock = Lock() self._lock = Lock()
# start the global event loop
_start_loop()
for address in contact_points: for address in contact_points:
self.add_host(address, signal=False) self.add_host(address, signal=False)
@@ -472,13 +551,16 @@ class ControlConnection(object):
except ConnectionException, exc: except ConnectionException, exc:
errors[host.address] = exc errors[host.address] = exc
host.monitor.signal_connection_failure(exc) host.monitor.signal_connection_failure(exc)
log.error("Error reconnecting control connection: %s", traceback.format_exc())
except Exception, exc: except Exception, exc:
errors[host.address] = exc errors[host.address] = exc
log.error("Error reconnecting control connection: %s", traceback.format_exc())
raise NoHostAvailable("Unable to connect to any servers", errors) raise NoHostAvailable("Unable to connect to any servers", errors)
def _try_connect(self, host): def _try_connect(self, host):
connection = self._cluster.connection_factory(host.address) connection = self._cluster.connection_factory(host.address)
connection.register_watchers({ connection.register_watchers({
"TOPOLOGY_CHANGE": self._handle_topology_change, "TOPOLOGY_CHANGE": self._handle_topology_change,
"STATUS_CHANGE": self._handle_status_change, "STATUS_CHANGE": self._handle_status_change,
@@ -486,7 +568,7 @@ class ControlConnection(object):
}) })
self.refresh_node_list_and_token_map() self.refresh_node_list_and_token_map()
self.refresh_schema() self._refresh_schema(connection)
return connection return connection
def reconnect(self): def reconnect(self):
@@ -521,6 +603,9 @@ class ControlConnection(object):
self._connection.close() self._connection.close()
def refresh_schema(self, keyspace=None, table=None): def refresh_schema(self, keyspace=None, table=None):
return self._refresh_schema(self._connection, keyspace, table)
def _refresh_schema(self, connection, keyspace=None, table=None):
where_clause = "" where_clause = ""
if keyspace: if keyspace:
where_clause = " WHERE keyspace_name = '%s'" % (keyspace,) where_clause = " WHERE keyspace_name = '%s'" % (keyspace,)
@@ -531,15 +616,15 @@ class ControlConnection(object):
if table: if table:
ks_query = None ks_query = None
else: else:
ks_query = QueryMessage(query=self.SELECT_KEYSPACES + where_clause, consistency_level=cl) ks_query = QueryMessage(query=self._SELECT_KEYSPACES + where_clause, consistency_level=cl)
cf_query = QueryMessage(query=self.SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl) cf_query = QueryMessage(query=self._SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl)
col_query = QueryMessage(query=self.SELECT_COLUMNS + where_clause, consistency_level=cl) col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl)
if ks_query: if ks_query:
ks_result, cf_result, col_result = self._connection.wait_for_requests(ks_query, cf_query, col_query) ks_result, cf_result, col_result = connection.wait_for_responses(ks_query, cf_query, col_query)
else: else:
ks_result = None ks_result = None
cf_result, col_result = self._connection.wait_for_requests(cf_query, col_query) cf_result, col_result = connection.wait_for_responses(cf_query, col_query)
self._cluster.metadata.rebuild_schema(keyspace, table, ks_result, cf_result, col_result) self._cluster.metadata.rebuild_schema(keyspace, table, ks_result, cf_result, col_result)
@@ -549,10 +634,10 @@ class ControlConnection(object):
return return
cl = ConsistencyLevel.ONE cl = ConsistencyLevel.ONE
peers_query = QueryMessage(query=self.SELECT_PEERS, consistency_level=cl) peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=cl)
local_query = QueryMessage(query=self.SELECT_LOCAL, consistency_level=cl) local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=cl)
try: try:
peers_result, local_result = conn.wait_for_requests(peers_query, local_query) peers_result, local_result = conn.wait_for_responses(peers_query, local_query)
except (ConnectionException, BusyConnectionException): except (ConnectionException, BusyConnectionException):
self.reconnect() self.reconnect()
@@ -649,9 +734,9 @@ class ControlConnection(object):
elapsed = 0 elapsed = 0
cl = ConsistencyLevel.ONE cl = ConsistencyLevel.ONE
while elapsed < MAX_SCHEMA_AGREEMENT_WAIT_MS: while elapsed < MAX_SCHEMA_AGREEMENT_WAIT_MS:
peers_query = QueryMessage(query=self.SELECT_SCHEMA_PEERS, consistency_level=cl) peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl)
local_query = QueryMessage(query=self.SELECT_SCHEMA_LOCAL, consistency_level=cl) local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl)
peers_result, local_result = self._connection.wait_for_requests(peers_query, local_query) peers_result, local_result = self._connection.wait_for_responses(peers_query, local_query)
versions = set() versions = set()
if local_result and local_result.rows: if local_result and local_result.rows:

View File

@@ -1,14 +1,21 @@
import asyncore
import asynchat import asynchat
from collections import defaultdict
from functools import partial
import itertools import itertools
import logging
import socket import socket
from threading import RLock, Event, Lock, Thread
import traceback
from threading import RLock, Event
from cassandra.marshal import (int8_unpack, int32_unpack) from cassandra.marshal import (int8_unpack, int32_unpack)
from cassandra.decoder import (OptionsMessage, ReadyMessage, AuthenticateMessage, from cassandra.decoder import (OptionsMessage, ReadyMessage, AuthenticateMessage,
StartupMessage, ErrorMessage, CredentialsMessage, StartupMessage, ErrorMessage, CredentialsMessage,
decode_response) decode_response)
log = logging.getLogger(__name__)
locally_supported_compressions = {} locally_supported_compressions = {}
try: try:
import snappy import snappy
@@ -45,6 +52,22 @@ class InternalError(Exception):
pass pass
_loop_started = False
_loop_lock = Lock()
def _start_loop():
global _loop_started
should_start = False
with _loop_lock:
if not _loop_started:
_loop_started = True
should_start = True
if should_start:
t = Thread(target=asyncore.loop, name="async_event_loop", kwargs={"timeout": 0.01})
t.daemon = True
t.start()
class Connection(asynchat.async_chat): class Connection(asynchat.async_chat):
@classmethod @classmethod
@@ -67,16 +90,19 @@ class Connection(asynchat.async_chat):
self.decompressor = None self.decompressor = None
self.connected_event = Event() self.connected_event = Event()
self.in_flight = 0
self.create_socket(socket.AF_INET, socket.SOCK_STREAM) self.is_defunct = False
self.connect((host, port))
self.make_request_id = itertools.cycle(xrange(127)).next self.make_request_id = itertools.cycle(xrange(127)).next
self._waiting_callbacks = {} self._callbacks = {}
self._waiting_events = {} self._push_watchers = defaultdict(set)
self._responses = {}
self._lock = RLock() self._lock = RLock()
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
# start the global event loop if needed
_start_loop()
self.connect((host, port))
def handle_read(self): def handle_read(self):
# simpler to do here than collect_incoming_data() # simpler to do here than collect_incoming_data()
header = self.recv(8) header = self.recv(8)
@@ -89,30 +115,62 @@ class Connection(asynchat.async_chat):
assert body_len >= 0, "Invalid CQL protocol body_len %r" % body_len assert body_len >= 0, "Invalid CQL protocol body_len %r" % body_len
body = self.recv(body_len) body = self.recv(body_len)
response = decode_response(stream_id, flags, opcode, body, self.decompressor)
if stream_id < 0: if stream_id < 0:
self.handle_pushed(stream_id, flags, opcode, body) self.handle_pushed(response)
else: else:
try: self._callbacks.pop(stream_id)(response)
cb = self._waiting_callbacks[stream_id]
except KeyError:
# store the response in a location accessible by other threads
self._responses[stream_id] = (flags, opcode, body)
# signal to the waiting thread that the response is ready
self._waiting_events[stream_id].set()
else:
cb(stream_id, flags, opcode, body)
def readable(self): def readable(self):
# TODO only return True if we have pending responses (except for ControlConnections?) # TODO this isn't accurate for control connections
return True return bool(self._callbacks)
def handle_error(self):
log.error(traceback.format_exc())
self.is_defunct = True
def handle_pushed(self, response):
pass
# details = response.recv_body
# for cb in self._push_watchers[response]
def push(self, data):
# overridden to avoid calling initiate_send() at the end of this
# and hold the lock
sabs = self.ac_out_buffer_size
with self._lock:
if len(data) > sabs:
for i in xrange(0, len(data), sabs):
self.producer_fifo.append(data[i:i + sabs])
else:
self.producer_fifo.append(data)
def send_msg(self, msg, cb):
request_id = self.make_request_id()
self._callbacks[request_id] = cb
self.push(msg.to_string(request_id, compression=self.compressor))
return request_id
def wait_for_responses(self, *msgs):
waiter = ResponseWaiter(len(msgs))
for i, msg in enumerate(msgs):
self.send_msg(msg, partial(waiter.got_response, index=i))
waiter.event.wait()
return waiter.responses
def register_watcher(self, event_type, callback):
self._push_watchers[event_type].add(callback)
def register_watchers(self, type_callback_dict):
for event_type, callback in type_callback_dict.items():
self.register_watcher(event_type, callback)
def handle_connect(self): def handle_connect(self):
log.debug("Sending initial message for new Connection to %s" % self.host)
self.send_msg(OptionsMessage(), self._handle_options_response) self.send_msg(OptionsMessage(), self._handle_options_response)
def _handle_options_response(self, stream_id, flags, opcode, body): def _handle_options_response(self, options_response):
options_response = decode_response(stream_id, flags, opcode, body) log.debug("Received options response on new Connection from %s" % self.host)
self.supported_cql_versions = options_response.cql_versions self.supported_cql_versions = options_response.cql_versions
self.remote_supported_compressions = options_response.options['COMPRESSION'] self.remote_supported_compressions = options_response.options['COMPRESSION']
@@ -147,58 +205,39 @@ class Connection(asynchat.async_chat):
sm = StartupMessage(cqlversion=self.cql_version, options=opts) sm = StartupMessage(cqlversion=self.cql_version, options=opts)
self.send_msg(sm, cb=self._handle_startup_response) self.send_msg(sm, cb=self._handle_startup_response)
def _handle_startup_response(self, stream_id, flags, opcode, body): def _handle_startup_response(self, startup_response):
startup_response = decode_response(
stream_id, flags, opcode, body, self.decompressor)
if isinstance(startup_response, ReadyMessage): if isinstance(startup_response, ReadyMessage):
log.debug("Got ReadyMessage on new Connection from %s" % self.host)
if self._compresstype: if self._compresstype:
self.compressor = self._compressor self.compressor = self._compressor
self.connected_event.set() self.connected_event.set()
elif isinstance(startup_response, AuthenticateMessage): elif isinstance(startup_response, AuthenticateMessage):
log.debug("Got AuthenticateMessage on new Connection from %s" % self.host)
self.authenticator = startup_response.authenticator self.authenticator = startup_response.authenticator
if self.credentials is None: if self.credentials is None:
raise ProgrammingError('Remote end requires authentication.') raise ProgrammingError('Remote end requires authentication.')
cm = CredentialsMessage(creds=self.credentials) cm = CredentialsMessage(creds=self.credentials)
self.send_msg(cm, cb=self._handle_startup_response) self.send_msg(cm, cb=self._handle_startup_response)
elif isinstance(startup_response, ErrorMessage): elif isinstance(startup_response, ErrorMessage):
log.debug("Received ErrorMessage on new Connection from %s" % self.host)
raise ProgrammingError("Server did not accept credentials. %s" raise ProgrammingError("Server did not accept credentials. %s"
% startup_response.summary_msg()) % startup_response.summary_msg())
else: else:
raise InternalError("Unexpected response %r during connection setup" raise InternalError("Unexpected response %r during connection setup"
% startup_response) % startup_response)
# def handle_error(self): def set_keyspace(self, keyspace):
# print "got connection error" # TODO
# self.close()
def handle_pushed(self, stream_id, flags, opcode, body):
pass pass
def push(self, data): class ResponseWaiter(object):
# overridden to avoid calling initiate_send() at the end of this
# and hold the lock
sabs = self.ac_out_buffer_size
with self._lock:
if len(data) > sabs:
for i in xrange(0, len(data), sabs):
self.producer_fifo.append(data[i:i + sabs])
else:
self.producer_fifo.append(data)
def send_msg(self, msg, cb=None): def __init__(self, num_responses):
request_id = self.make_request_id() self.pending = num_responses
if cb: self.responses = [None] * num_responses
self._waiting_callbacks[request_id] = cb self.event = Event()
else:
self._waiting_events[request_id] = Event()
self.push(msg.to_string(request_id, compression=self.compressor)) def got_response(self, response, index):
return request_id self.responses[index] = response
self.pending -= 1
def get_response(self, stream_id): if not self.pending:
""" Blocking wait for a response """ self.event.set()
# TODO waiting on the event in the loop thread will deadlock
self._waiting_events.pop(stream_id).wait()
(flags, opcode, body) = self._responses.pop(stream_id)
return decode_response(stream_id, flags, opcode, body, self.decompressor)

View File

@@ -79,24 +79,6 @@ ConsistencyLevel.name_to_value = {
} }
class CqlResult:
def __init__(self, column_metadata, rows):
self.column_metadata = column_metadata
self.rows = rows
def __iter__(self):
return iter(self.rows)
# TODO this definitely needs to be abstracted at some other level
def as_dicts(self):
colnames = [c[2] for c in self.column_metadata]
return [dict(zip(colnames, row)) for row in self.rows]
def __str__(self):
return '<CqlResult: column_metadata=%r, rows=%r>' \
% (self.column_metadata, self.rows)
__repr__ = __str__
class PreparedResult: class PreparedResult:
def __init__(self, queryid, param_metadata): def __init__(self, queryid, param_metadata):
self.queryid = queryid self.queryid = queryid
@@ -253,9 +235,9 @@ class UnavailableExceptionErrorMessage(RequestExecutionException):
@staticmethod @staticmethod
def recv_error_info(f): def recv_error_info(f):
return { return {
'consistency_level': read_consistency_level(f), 'consistency': read_consistency_level(f),
'required': read_int(f), 'required_replicas': read_int(f),
'alive': read_int(f), 'alive_replicas': read_int(f),
} }
class OverloadedErrorMessage(RequestExecutionException): class OverloadedErrorMessage(RequestExecutionException):
@@ -280,10 +262,10 @@ class WriteTimeoutErrorMessage(RequestTimeoutException):
@staticmethod @staticmethod
def recv_error_info(f): def recv_error_info(f):
return { return {
'consistency_level': read_consistency_level(f), 'consistency': read_consistency_level(f),
'received': read_int(f), 'received_responses': read_int(f),
'blockfor': read_int(f), 'required_responses': read_int(f),
'writetype': read_string(f), 'write_type': read_string(f),
} }
class ReadTimeoutErrorMessage(RequestTimeoutException): class ReadTimeoutErrorMessage(RequestTimeoutException):
@@ -293,10 +275,10 @@ class ReadTimeoutErrorMessage(RequestTimeoutException):
@staticmethod @staticmethod
def recv_error_info(f): def recv_error_info(f):
return { return {
'consistency_level': read_consistency_level(f), 'consistency': read_consistency_level(f),
'received': read_int(f), 'received_responses': read_int(f),
'blockfor': read_int(f), 'required_responses': read_int(f),
'data_present': bool(read_byte(f)), 'data_retrieved': bool(read_byte(f)),
} }
class SyntaxException(RequestValidationException): class SyntaxException(RequestValidationException):
@@ -455,7 +437,10 @@ class ResultMessage(_MessageType):
colspecs = cls.recv_results_metadata(f) colspecs = cls.recv_results_metadata(f)
rowcount = read_int(f) rowcount = read_int(f)
rows = [cls.recv_row(f, len(colspecs)) for x in xrange(rowcount)] rows = [cls.recv_row(f, len(colspecs)) for x in xrange(rowcount)]
return CqlResult(column_metadata=colspecs, rows=rows) colnames = [c[2] for c in colspecs]
coltypes = [c[3] for c in colspecs]
return [dict(zip(colnames, [ctype.from_binary(val) for ctype, val in zip(coltypes, row)]))
for row in rows]
@classmethod @classmethod
def recv_results_prepared(cls, f): def recv_results_prepared(cls, f):
@@ -596,7 +581,7 @@ def read_consistency_level(f):
return ConsistencyLevel.value_to_name[read_short(f)] return ConsistencyLevel.value_to_name[read_short(f)]
def write_consistency_level(f, cl): def write_consistency_level(f, cl):
write_short(f, ConsistencyLevel.name_to_value[cl]) write_short(f, cl)
def read_string(f): def read_string(f):
size = read_short(f) size = read_short(f)

View File

@@ -1,3 +1,4 @@
import json
from collections import defaultdict from collections import defaultdict
from threading import RLock from threading import RLock
@@ -20,10 +21,10 @@ class Metadata(object):
cf_def_rows = defaultdict(list) cf_def_rows = defaultdict(list)
col_def_rows = defaultdict(lambda: defaultdict(list)) col_def_rows = defaultdict(lambda: defaultdict(list))
for row in cf_results: for row in cf_results.results:
cf_def_rows[row["keyspace_name"]].append(row) cf_def_rows[row["keyspace_name"]].append(row)
for row in col_results: for row in col_results.results:
ksname = row["keyspace_name"] ksname = row["keyspace_name"]
cfname = row["columnfamily_name"] cfname = row["columnfamily_name"]
col_def_rows[ksname][cfname].append(row) col_def_rows[ksname][cfname].append(row)
@@ -32,13 +33,13 @@ class Metadata(object):
if not table: if not table:
# ks_results is not None # ks_results is not None
added_keyspaces = set() added_keyspaces = set()
for row in ks_results: for row in ks_results.results:
keyspace_meta = self._build_keyspace_metadata(row) keyspace_meta = self._build_keyspace_metadata(row)
ksname = keyspace_meta.name ksname = keyspace_meta.name
if ksname in cf_def_rows: if ksname in cf_def_rows:
for table_row in cf_def_rows[keyspace_meta.name]: for table_row in cf_def_rows[keyspace_meta.name]:
table_meta = self._build_table_metadata( table_meta = self._build_table_metadata(
keyspace_meta, table_row, col_def_rows[keyspace_meta.name]) keyspace_meta, table_row, col_def_rows[keyspace_meta.name])
keyspace_meta.tables[table_meta.name] = table_meta keyspace_meta.tables[table_meta.name] = table_meta
added_keyspaces.add(keyspace_meta.name) added_keyspaces.add(keyspace_meta.name)
@@ -67,7 +68,7 @@ class Metadata(object):
name = row["keyspace_name"] name = row["keyspace_name"]
durable_writes = row["durable_writes"] durable_writes = row["durable_writes"]
strategy_class = row["strategy_class"] strategy_class = row["strategy_class"]
strategy_options = row["strategy_options"] strategy_options = json.loads(row["strategy_options"])
return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options)
def _build_table_metadata(self, keyspace_metadata, row, col_rows): def _build_table_metadata(self, keyspace_metadata, row, col_rows):
@@ -143,7 +144,7 @@ class Metadata(object):
# value alias (if present) # value alias (if present)
if has_value: if has_value:
validator = cqltypes.lookup_casstype(row["default_validator"]) validator = cqltypes.lookup_casstype(row["default_validator"])
if not row.get["key_aliases"]: if not row.get("key_aliases"):
value_alias = "value" value_alias = "value"
else: else:
value_alias = row["value_alias"] value_alias = row["value_alias"]
@@ -156,7 +157,7 @@ class Metadata(object):
column_meta = self._build_column_metadata(table_meta, row) column_meta = self._build_column_metadata(table_meta, row)
table_meta.columns[column_meta.name] = column_meta table_meta.columns[column_meta.name] = column_meta
table_meta.options = self._build_table_options(is_compact) table_meta.options = self._build_table_options(row, is_compact)
return table_meta return table_meta
def _build_table_options(self, row, is_compact_storage): def _build_table_options(self, row, is_compact_storage):

View File

@@ -21,13 +21,15 @@ class LoadBalancingPolicy(object):
class RoundRobinPolicy(LoadBalancingPolicy): class RoundRobinPolicy(LoadBalancingPolicy):
def __init__(self):
self._lock = RLock()
def populate(self, cluster, hosts): def populate(self, cluster, hosts):
self._live_hosts = set(hosts) self._live_hosts = set(hosts)
if len(hosts) == 1: if len(hosts) == 1:
self._position = 0 self._position = 0
else: else:
self._position = randint(0, len(hosts) - 1) self._position = randint(0, len(hosts) - 1)
self._lock = RLock()
def distance(self, host): def distance(self, host):
return HostDistance.LOCAL return HostDistance.LOCAL

View File

@@ -204,7 +204,7 @@ class HostConnectionPool(object):
# TODO potentially use threading.Queue for this # TODO potentially use threading.Queue for this
core_conns = session.cluster.get_core_connections_per_host(host_distance) core_conns = session.cluster.get_core_connections_per_host(host_distance)
self._connections = [session.connection_factory(host) self._connections = [session.cluster.connection_factory(host.address)
for i in range(core_conns)] for i in range(core_conns)]
self._trash = set() self._trash = set()
self._open_count = len(self._connections) self._open_count = len(self._connections)
@@ -249,7 +249,7 @@ class HostConnectionPool(object):
least_busy = self._wait_for_conn(timeout) least_busy = self._wait_for_conn(timeout)
break break
least_busy.set_keyspace() # TODO get keyspace from pool state least_busy.set_keyspace("keyspace") # TODO get keyspace from pool state
return least_busy return least_busy
def _create_new_connection(self): def _create_new_connection(self):
@@ -283,16 +283,16 @@ class HostConnectionPool(object):
return False return False
def _await_available_conn(self, timeout): def _await_available_conn(self, timeout):
with self._available_conn_condition: with self._conn_available_condition:
self._available_conn_condition.wait(timeout) self._conn_available_condition.wait(timeout)
def _signal_available_conn(self): def _signal_available_conn(self):
with self._available_conn_condition: with self._conn_available_condition:
self._available_conn_condition.notify() self._conn_available_condition.notify()
def _signal_all_available_conn(self): def _signal_all_available_conn(self):
with self._available_conn_condition: with self._conn_available_condition:
self._available_conn_condition.notify_all() self._conn_available_condition.notify_all()
def _wait_for_conn(self, timeout): def _wait_for_conn(self, timeout):
start = time.time() start = time.time()

View File

@@ -12,10 +12,8 @@ class Query(object):
class SimpleStatement(Query): class SimpleStatement(Query):
query = None def __init__(self, query_string):
self._query_string = query_string
def __init__(self, query):
self.query = query
self._routing_key = None self._routing_key = None
@property @property
@@ -26,3 +24,7 @@ class SimpleStatement(Query):
def set_routing_key(self, value): def set_routing_key(self, value):
self._routing_key = "".join(struct.pack("HsB", len(component), component, 0) self._routing_key = "".join(struct.pack("HsB", len(component), component, 0)
for component in value) for component in value)
@property
def query_string(self):
return self._query_string