Basic working Cluster, Session, Connection, ResponseFuture
This commit is contained in:
@@ -1,15 +1,18 @@
|
||||
import asyncore
|
||||
import time
|
||||
from threading import Lock, RLock, Thread
|
||||
from threading import Lock, RLock, Thread, Event
|
||||
import Queue
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from futures import ThreadPoolExecutor
|
||||
|
||||
from connection import Connection
|
||||
from decoder import ConsistencyLevel, QueryMessage
|
||||
from decoder import (ConsistencyLevel, QueryMessage, ResultMessage, ErrorMessage,
|
||||
ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableExceptionErrorMessage,
|
||||
OverloadedErrorMessage, IsBootstrappingErrorMessage)
|
||||
from metadata import Metadata
|
||||
from policies import (RoundRobinPolicy, SimpleConvictionPolicy,
|
||||
ExponentialReconnectionPolicy, HostDistance)
|
||||
ExponentialReconnectionPolicy, HostDistance, RetryPolicy)
|
||||
from query import SimpleStatement
|
||||
from pool import (ConnectionException, BusyConnectionException,
|
||||
AuthenticationException, _ReconnectionHandler,
|
||||
@@ -18,19 +21,27 @@ from pool import (ConnectionException, BusyConnectionException,
|
||||
# TODO: we might want to make this configurable
|
||||
MAX_SCHEMA_AGREEMENT_WAIT_MS = 10000
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
class _RetryingCallback(object):
|
||||
|
||||
def __init__(self, session, query):
|
||||
class ResponseFuture(object):
|
||||
|
||||
def __init__(self, session, message, query):
|
||||
self.session = session
|
||||
self.query = query
|
||||
self.query_plan = session._load_balancer.new_query_plan(query)
|
||||
self.message = message
|
||||
|
||||
self.query_plan = session._load_balancer.make_query_plan(query)
|
||||
|
||||
self._req_id = None
|
||||
self._final_result = None
|
||||
self._final_exception = None
|
||||
self._wait_fn = None
|
||||
self._current_host = None
|
||||
self._current_pool = None
|
||||
self._connection = None
|
||||
self._event = Event()
|
||||
self._errors = {}
|
||||
self._query_retries = 0
|
||||
self._callback = self._errback = None
|
||||
|
||||
def send_request(self):
|
||||
for host in self.query_plan:
|
||||
@@ -42,32 +53,122 @@ class _RetryingCallback(object):
|
||||
self._final_exception = NoHostAvailable(self._errors)
|
||||
|
||||
def _query(self, host):
|
||||
pool = self._session._pools.get(host)
|
||||
pool = self.session._pools.get(host)
|
||||
if not pool or pool.is_shutdown:
|
||||
return None
|
||||
|
||||
connection = None
|
||||
try:
|
||||
# TODO get connectTimeout from cluster settings
|
||||
connection = pool.borrow_connection(timeout=2.0)
|
||||
request_id = connection.request_and_callback(self.query, self._set_results)
|
||||
self._wait_fn = connection.wait_for_result
|
||||
request_id = connection.send_msg(self.message, cb=self._set_result)
|
||||
self._current_host = host
|
||||
self._current_pool = pool
|
||||
self._connection = connection
|
||||
return request_id
|
||||
except Exception:
|
||||
return None
|
||||
finally:
|
||||
except Exception, exc:
|
||||
self._errors[host] = exc
|
||||
if connection:
|
||||
pool.return_connection(connection)
|
||||
return None
|
||||
|
||||
def _set_results(self, response):
|
||||
def _set_result(self, response):
|
||||
self._current_pool.return_connection(self._connection)
|
||||
|
||||
if isinstance(response, ResultMessage):
|
||||
self._set_final_result(response)
|
||||
elif isinstance(response, ErrorMessage):
|
||||
retry_policy = self.query.retry_policy # TODO also check manager.configuration
|
||||
# for retry policy if None
|
||||
if isinstance(response, ReadTimeoutErrorMessage):
|
||||
details = response.recv_error_info()
|
||||
retry = retry_policy.on_read_timeout(
|
||||
self.query, attempt_num=self._query_retries, **details)
|
||||
elif isinstance(response, WriteTimeoutErrorMessage):
|
||||
details = response.recv_error_info()
|
||||
retry = retry_policy.on_write_timeout(
|
||||
self.query, attempt_num=self._query_retries, **details)
|
||||
elif isinstance(response, UnavailableExceptionErrorMessage):
|
||||
details = response.recv_error_info()
|
||||
retry = retry_policy.on_write_timeout(
|
||||
self.query, attempt_num=self._query_retries, **details)
|
||||
elif isinstance(response, OverloadedErrorMessage):
|
||||
# need to retry against a different host here
|
||||
self._retry(False, None)
|
||||
elif isinstance(response, IsBootstrappingErrorMessage):
|
||||
# need to retry against a different host here
|
||||
self._retry(False, None)
|
||||
# TODO need to define the PreparedQueryNotFound class
|
||||
# elif isinstance(response, PreparedQueryNotFound):
|
||||
# pass
|
||||
else:
|
||||
pass
|
||||
|
||||
retry_type, consistency = retry
|
||||
if retry_type == RetryPolicy.RETRY:
|
||||
self._query_retries += 1
|
||||
self._retry(True, consistency)
|
||||
elif retry_type == RetryPolicy.RETHROW:
|
||||
self._set_final_result(response)
|
||||
else: # IGNORE
|
||||
self._set_final_result(None)
|
||||
else:
|
||||
# we got some other kind of response message
|
||||
self._set_final_result(response)
|
||||
|
||||
def _set_final_result(self, response):
|
||||
self._final_result = response
|
||||
self._event.set()
|
||||
if self._callback:
|
||||
fn, args, kwargs = self._callback
|
||||
fn(response, *args, **kwargs)
|
||||
|
||||
def _set_final_exception(self, response):
|
||||
self._final_exception = response
|
||||
self._event.set()
|
||||
if self._errback:
|
||||
fn, args, kwargs = self._errback
|
||||
fn(response, *args, **kwargs)
|
||||
|
||||
def _retry(self, reuse_connection, consistency_level):
|
||||
self.message.consistency_level = consistency_level
|
||||
# don't retry on the event loop thread
|
||||
self.session.submit(self._retry_helper, reuse_connection)
|
||||
|
||||
def _retry_task(self, reuse_connection):
|
||||
if reuse_connection and self._query(self._current_host):
|
||||
return
|
||||
|
||||
# otherwise, move onto another host
|
||||
self.send_request()
|
||||
|
||||
def deliver(self):
|
||||
if self._final_result:
|
||||
return self._final_result
|
||||
elif self._final_exception:
|
||||
if self._final_exception:
|
||||
raise self._final_exception
|
||||
elif self._final_result:
|
||||
return self._final_result
|
||||
else:
|
||||
return self._wait_fn(self._req_id)
|
||||
self._event.wait()
|
||||
if self._final_exception:
|
||||
raise self._final_exception
|
||||
else:
|
||||
return self._final_result
|
||||
|
||||
def addCallback(self, fn, *args, **kwargs):
|
||||
self._callback = (fn, args, kwargs)
|
||||
return self
|
||||
|
||||
def addErrback(self, fn, *args, **kwargs):
|
||||
self._errback = (fn, args, kwargs)
|
||||
return self
|
||||
|
||||
def addCallbacks(self,
|
||||
callback, errback,
|
||||
callback_args=(), callback_kwargs=None,
|
||||
errback_args=(), errback_kwargs=None):
|
||||
self._callback = (callback, callback_args, callback_kwargs | {})
|
||||
self._errback = (errback, errback_args, errback_kwargs | {})
|
||||
|
||||
|
||||
class Session(object):
|
||||
|
||||
@@ -79,17 +180,29 @@ class Session(object):
|
||||
self._is_shutdown = False
|
||||
self._pools = {}
|
||||
self._load_balancer = RoundRobinPolicy()
|
||||
self._load_balancer.populate(cluster, hosts)
|
||||
|
||||
for host in hosts:
|
||||
self.add_host(host)
|
||||
|
||||
def execute(self, query):
|
||||
if isinstance(query, basestring):
|
||||
query = SimpleStatement(query)
|
||||
future = self.execute_async(query)
|
||||
return future.deliver()
|
||||
|
||||
def execute_async(self, query):
|
||||
if isinstance(query, basestring):
|
||||
query = SimpleStatement(query)
|
||||
|
||||
qmsg = QueryMessage(query=query.query, consistency_level=query.consistency_level)
|
||||
return self._execute_query(qmsg, query)
|
||||
# TODO bound statements need to be handled differently
|
||||
message = QueryMessage(query=query.query_string, consistency_level=query.consistency_level)
|
||||
|
||||
if query.tracing_enabled:
|
||||
# TODO enable tracing on the message
|
||||
pass
|
||||
|
||||
future = ResponseFuture(self, message, query)
|
||||
future.send_request()
|
||||
return future
|
||||
|
||||
def prepare(self, query):
|
||||
pass
|
||||
@@ -97,25 +210,6 @@ class Session(object):
|
||||
def shutdown(self):
|
||||
self.cluster.shutdown()
|
||||
|
||||
def _execute_query(self, message, query):
|
||||
if query.tracing_enabled:
|
||||
# TODO enable tracing on the message
|
||||
pass
|
||||
|
||||
errors = {}
|
||||
for host in self._load_balancer.make_query_plan(query):
|
||||
try:
|
||||
result = self._query(host)
|
||||
if result:
|
||||
return
|
||||
except Exception, exc:
|
||||
errors[host] = exc
|
||||
|
||||
def _query(self, host, query):
|
||||
pool = self._pools.get(host)
|
||||
if not pool or pool.is_shutdown:
|
||||
return False
|
||||
|
||||
def add_host(self, host):
|
||||
distance = self._load_balancer.distance(host)
|
||||
if distance == HostDistance.IGNORED:
|
||||
@@ -174,8 +268,8 @@ class Session(object):
|
||||
def set_keyspace(self, keyspace):
|
||||
pass
|
||||
|
||||
def submit(self, task):
|
||||
return self.cluster.executor.submit(task)
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
return self.cluster.executor.submit(fn, *args, **kwargs)
|
||||
|
||||
DEFAULT_MIN_REQUESTS = 25
|
||||
DEFAULT_MAX_REQUESTS = 100
|
||||
@@ -219,18 +313,6 @@ class _Scheduler(object):
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
_loop_started = False
|
||||
_loop_lock = Lock()
|
||||
def _start_loop():
|
||||
with _loop_lock:
|
||||
global _loop_started
|
||||
if not _loop_started:
|
||||
_loop_started = True
|
||||
t = Thread(target=asyncore.loop, name="Async event loop")
|
||||
t.daemon = True
|
||||
t.start()
|
||||
|
||||
|
||||
class Cluster(object):
|
||||
|
||||
port = 9042
|
||||
@@ -282,9 +364,6 @@ class Cluster(object):
|
||||
self._is_shutdown = False
|
||||
self._lock = Lock()
|
||||
|
||||
# start the global event loop
|
||||
_start_loop()
|
||||
|
||||
for address in contact_points:
|
||||
self.add_host(address, signal=False)
|
||||
|
||||
@@ -472,13 +551,16 @@ class ControlConnection(object):
|
||||
except ConnectionException, exc:
|
||||
errors[host.address] = exc
|
||||
host.monitor.signal_connection_failure(exc)
|
||||
log.error("Error reconnecting control connection: %s", traceback.format_exc())
|
||||
except Exception, exc:
|
||||
errors[host.address] = exc
|
||||
log.error("Error reconnecting control connection: %s", traceback.format_exc())
|
||||
|
||||
raise NoHostAvailable("Unable to connect to any servers", errors)
|
||||
|
||||
def _try_connect(self, host):
|
||||
connection = self._cluster.connection_factory(host.address)
|
||||
|
||||
connection.register_watchers({
|
||||
"TOPOLOGY_CHANGE": self._handle_topology_change,
|
||||
"STATUS_CHANGE": self._handle_status_change,
|
||||
@@ -486,7 +568,7 @@ class ControlConnection(object):
|
||||
})
|
||||
|
||||
self.refresh_node_list_and_token_map()
|
||||
self.refresh_schema()
|
||||
self._refresh_schema(connection)
|
||||
return connection
|
||||
|
||||
def reconnect(self):
|
||||
@@ -521,6 +603,9 @@ class ControlConnection(object):
|
||||
self._connection.close()
|
||||
|
||||
def refresh_schema(self, keyspace=None, table=None):
|
||||
return self._refresh_schema(self._connection, keyspace, table)
|
||||
|
||||
def _refresh_schema(self, connection, keyspace=None, table=None):
|
||||
where_clause = ""
|
||||
if keyspace:
|
||||
where_clause = " WHERE keyspace_name = '%s'" % (keyspace,)
|
||||
@@ -531,15 +616,15 @@ class ControlConnection(object):
|
||||
if table:
|
||||
ks_query = None
|
||||
else:
|
||||
ks_query = QueryMessage(query=self.SELECT_KEYSPACES + where_clause, consistency_level=cl)
|
||||
cf_query = QueryMessage(query=self.SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl)
|
||||
col_query = QueryMessage(query=self.SELECT_COLUMNS + where_clause, consistency_level=cl)
|
||||
ks_query = QueryMessage(query=self._SELECT_KEYSPACES + where_clause, consistency_level=cl)
|
||||
cf_query = QueryMessage(query=self._SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl)
|
||||
col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl)
|
||||
|
||||
if ks_query:
|
||||
ks_result, cf_result, col_result = self._connection.wait_for_requests(ks_query, cf_query, col_query)
|
||||
ks_result, cf_result, col_result = connection.wait_for_responses(ks_query, cf_query, col_query)
|
||||
else:
|
||||
ks_result = None
|
||||
cf_result, col_result = self._connection.wait_for_requests(cf_query, col_query)
|
||||
cf_result, col_result = connection.wait_for_responses(cf_query, col_query)
|
||||
|
||||
self._cluster.metadata.rebuild_schema(keyspace, table, ks_result, cf_result, col_result)
|
||||
|
||||
@@ -549,10 +634,10 @@ class ControlConnection(object):
|
||||
return
|
||||
|
||||
cl = ConsistencyLevel.ONE
|
||||
peers_query = QueryMessage(query=self.SELECT_PEERS, consistency_level=cl)
|
||||
local_query = QueryMessage(query=self.SELECT_LOCAL, consistency_level=cl)
|
||||
peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=cl)
|
||||
local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=cl)
|
||||
try:
|
||||
peers_result, local_result = conn.wait_for_requests(peers_query, local_query)
|
||||
peers_result, local_result = conn.wait_for_responses(peers_query, local_query)
|
||||
except (ConnectionException, BusyConnectionException):
|
||||
self.reconnect()
|
||||
|
||||
@@ -649,9 +734,9 @@ class ControlConnection(object):
|
||||
elapsed = 0
|
||||
cl = ConsistencyLevel.ONE
|
||||
while elapsed < MAX_SCHEMA_AGREEMENT_WAIT_MS:
|
||||
peers_query = QueryMessage(query=self.SELECT_SCHEMA_PEERS, consistency_level=cl)
|
||||
local_query = QueryMessage(query=self.SELECT_SCHEMA_LOCAL, consistency_level=cl)
|
||||
peers_result, local_result = self._connection.wait_for_requests(peers_query, local_query)
|
||||
peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl)
|
||||
local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl)
|
||||
peers_result, local_result = self._connection.wait_for_responses(peers_query, local_query)
|
||||
|
||||
versions = set()
|
||||
if local_result and local_result.rows:
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
import asyncore
|
||||
import asynchat
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
import itertools
|
||||
import logging
|
||||
import socket
|
||||
from threading import RLock, Event, Lock, Thread
|
||||
import traceback
|
||||
|
||||
from threading import RLock, Event
|
||||
|
||||
from cassandra.marshal import (int8_unpack, int32_unpack)
|
||||
from cassandra.decoder import (OptionsMessage, ReadyMessage, AuthenticateMessage,
|
||||
StartupMessage, ErrorMessage, CredentialsMessage,
|
||||
decode_response)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
locally_supported_compressions = {}
|
||||
try:
|
||||
import snappy
|
||||
@@ -45,6 +52,22 @@ class InternalError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
_loop_started = False
|
||||
_loop_lock = Lock()
|
||||
def _start_loop():
|
||||
global _loop_started
|
||||
should_start = False
|
||||
with _loop_lock:
|
||||
if not _loop_started:
|
||||
_loop_started = True
|
||||
should_start = True
|
||||
|
||||
if should_start:
|
||||
t = Thread(target=asyncore.loop, name="async_event_loop", kwargs={"timeout": 0.01})
|
||||
t.daemon = True
|
||||
t.start()
|
||||
|
||||
|
||||
class Connection(asynchat.async_chat):
|
||||
|
||||
@classmethod
|
||||
@@ -67,16 +90,19 @@ class Connection(asynchat.async_chat):
|
||||
self.decompressor = None
|
||||
|
||||
self.connected_event = Event()
|
||||
|
||||
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.connect((host, port))
|
||||
self.in_flight = 0
|
||||
self.is_defunct = False
|
||||
|
||||
self.make_request_id = itertools.cycle(xrange(127)).next
|
||||
self._waiting_callbacks = {}
|
||||
self._waiting_events = {}
|
||||
self._responses = {}
|
||||
self._callbacks = {}
|
||||
self._push_watchers = defaultdict(set)
|
||||
self._lock = RLock()
|
||||
|
||||
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
# start the global event loop if needed
|
||||
_start_loop()
|
||||
self.connect((host, port))
|
||||
|
||||
def handle_read(self):
|
||||
# simpler to do here than collect_incoming_data()
|
||||
header = self.recv(8)
|
||||
@@ -89,30 +115,62 @@ class Connection(asynchat.async_chat):
|
||||
assert body_len >= 0, "Invalid CQL protocol body_len %r" % body_len
|
||||
body = self.recv(body_len)
|
||||
|
||||
response = decode_response(stream_id, flags, opcode, body, self.decompressor)
|
||||
if stream_id < 0:
|
||||
self.handle_pushed(stream_id, flags, opcode, body)
|
||||
self.handle_pushed(response)
|
||||
else:
|
||||
try:
|
||||
cb = self._waiting_callbacks[stream_id]
|
||||
except KeyError:
|
||||
# store the response in a location accessible by other threads
|
||||
self._responses[stream_id] = (flags, opcode, body)
|
||||
|
||||
# signal to the waiting thread that the response is ready
|
||||
self._waiting_events[stream_id].set()
|
||||
else:
|
||||
cb(stream_id, flags, opcode, body)
|
||||
self._callbacks.pop(stream_id)(response)
|
||||
|
||||
def readable(self):
|
||||
# TODO only return True if we have pending responses (except for ControlConnections?)
|
||||
return True
|
||||
# TODO this isn't accurate for control connections
|
||||
return bool(self._callbacks)
|
||||
|
||||
def handle_error(self):
|
||||
log.error(traceback.format_exc())
|
||||
self.is_defunct = True
|
||||
|
||||
def handle_pushed(self, response):
|
||||
pass
|
||||
# details = response.recv_body
|
||||
# for cb in self._push_watchers[response]
|
||||
|
||||
def push(self, data):
|
||||
# overridden to avoid calling initiate_send() at the end of this
|
||||
# and hold the lock
|
||||
sabs = self.ac_out_buffer_size
|
||||
with self._lock:
|
||||
if len(data) > sabs:
|
||||
for i in xrange(0, len(data), sabs):
|
||||
self.producer_fifo.append(data[i:i + sabs])
|
||||
else:
|
||||
self.producer_fifo.append(data)
|
||||
|
||||
def send_msg(self, msg, cb):
|
||||
request_id = self.make_request_id()
|
||||
self._callbacks[request_id] = cb
|
||||
self.push(msg.to_string(request_id, compression=self.compressor))
|
||||
return request_id
|
||||
|
||||
def wait_for_responses(self, *msgs):
|
||||
waiter = ResponseWaiter(len(msgs))
|
||||
for i, msg in enumerate(msgs):
|
||||
self.send_msg(msg, partial(waiter.got_response, index=i))
|
||||
waiter.event.wait()
|
||||
return waiter.responses
|
||||
|
||||
def register_watcher(self, event_type, callback):
|
||||
self._push_watchers[event_type].add(callback)
|
||||
|
||||
def register_watchers(self, type_callback_dict):
|
||||
for event_type, callback in type_callback_dict.items():
|
||||
self.register_watcher(event_type, callback)
|
||||
|
||||
def handle_connect(self):
|
||||
log.debug("Sending initial message for new Connection to %s" % self.host)
|
||||
self.send_msg(OptionsMessage(), self._handle_options_response)
|
||||
|
||||
def _handle_options_response(self, stream_id, flags, opcode, body):
|
||||
options_response = decode_response(stream_id, flags, opcode, body)
|
||||
|
||||
def _handle_options_response(self, options_response):
|
||||
log.debug("Received options response on new Connection from %s" % self.host)
|
||||
self.supported_cql_versions = options_response.cql_versions
|
||||
self.remote_supported_compressions = options_response.options['COMPRESSION']
|
||||
|
||||
@@ -147,58 +205,39 @@ class Connection(asynchat.async_chat):
|
||||
sm = StartupMessage(cqlversion=self.cql_version, options=opts)
|
||||
self.send_msg(sm, cb=self._handle_startup_response)
|
||||
|
||||
def _handle_startup_response(self, stream_id, flags, opcode, body):
|
||||
startup_response = decode_response(
|
||||
stream_id, flags, opcode, body, self.decompressor)
|
||||
|
||||
def _handle_startup_response(self, startup_response):
|
||||
if isinstance(startup_response, ReadyMessage):
|
||||
log.debug("Got ReadyMessage on new Connection from %s" % self.host)
|
||||
if self._compresstype:
|
||||
self.compressor = self._compressor
|
||||
self.connected_event.set()
|
||||
elif isinstance(startup_response, AuthenticateMessage):
|
||||
log.debug("Got AuthenticateMessage on new Connection from %s" % self.host)
|
||||
self.authenticator = startup_response.authenticator
|
||||
if self.credentials is None:
|
||||
raise ProgrammingError('Remote end requires authentication.')
|
||||
cm = CredentialsMessage(creds=self.credentials)
|
||||
self.send_msg(cm, cb=self._handle_startup_response)
|
||||
elif isinstance(startup_response, ErrorMessage):
|
||||
log.debug("Received ErrorMessage on new Connection from %s" % self.host)
|
||||
raise ProgrammingError("Server did not accept credentials. %s"
|
||||
% startup_response.summary_msg())
|
||||
else:
|
||||
raise InternalError("Unexpected response %r during connection setup"
|
||||
% startup_response)
|
||||
|
||||
# def handle_error(self):
|
||||
# print "got connection error" # TODO
|
||||
# self.close()
|
||||
|
||||
def handle_pushed(self, stream_id, flags, opcode, body):
|
||||
def set_keyspace(self, keyspace):
|
||||
pass
|
||||
|
||||
def push(self, data):
|
||||
# overridden to avoid calling initiate_send() at the end of this
|
||||
# and hold the lock
|
||||
sabs = self.ac_out_buffer_size
|
||||
with self._lock:
|
||||
if len(data) > sabs:
|
||||
for i in xrange(0, len(data), sabs):
|
||||
self.producer_fifo.append(data[i:i + sabs])
|
||||
else:
|
||||
self.producer_fifo.append(data)
|
||||
class ResponseWaiter(object):
|
||||
|
||||
def send_msg(self, msg, cb=None):
|
||||
request_id = self.make_request_id()
|
||||
if cb:
|
||||
self._waiting_callbacks[request_id] = cb
|
||||
else:
|
||||
self._waiting_events[request_id] = Event()
|
||||
def __init__(self, num_responses):
|
||||
self.pending = num_responses
|
||||
self.responses = [None] * num_responses
|
||||
self.event = Event()
|
||||
|
||||
self.push(msg.to_string(request_id, compression=self.compressor))
|
||||
return request_id
|
||||
|
||||
def get_response(self, stream_id):
|
||||
""" Blocking wait for a response """
|
||||
# TODO waiting on the event in the loop thread will deadlock
|
||||
self._waiting_events.pop(stream_id).wait()
|
||||
(flags, opcode, body) = self._responses.pop(stream_id)
|
||||
return decode_response(stream_id, flags, opcode, body, self.decompressor)
|
||||
def got_response(self, response, index):
|
||||
self.responses[index] = response
|
||||
self.pending -= 1
|
||||
if not self.pending:
|
||||
self.event.set()
|
||||
|
||||
@@ -79,24 +79,6 @@ ConsistencyLevel.name_to_value = {
|
||||
}
|
||||
|
||||
|
||||
class CqlResult:
|
||||
def __init__(self, column_metadata, rows):
|
||||
self.column_metadata = column_metadata
|
||||
self.rows = rows
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.rows)
|
||||
|
||||
# TODO this definitely needs to be abstracted at some other level
|
||||
def as_dicts(self):
|
||||
colnames = [c[2] for c in self.column_metadata]
|
||||
return [dict(zip(colnames, row)) for row in self.rows]
|
||||
|
||||
def __str__(self):
|
||||
return '<CqlResult: column_metadata=%r, rows=%r>' \
|
||||
% (self.column_metadata, self.rows)
|
||||
__repr__ = __str__
|
||||
|
||||
class PreparedResult:
|
||||
def __init__(self, queryid, param_metadata):
|
||||
self.queryid = queryid
|
||||
@@ -253,9 +235,9 @@ class UnavailableExceptionErrorMessage(RequestExecutionException):
|
||||
@staticmethod
|
||||
def recv_error_info(f):
|
||||
return {
|
||||
'consistency_level': read_consistency_level(f),
|
||||
'required': read_int(f),
|
||||
'alive': read_int(f),
|
||||
'consistency': read_consistency_level(f),
|
||||
'required_replicas': read_int(f),
|
||||
'alive_replicas': read_int(f),
|
||||
}
|
||||
|
||||
class OverloadedErrorMessage(RequestExecutionException):
|
||||
@@ -280,10 +262,10 @@ class WriteTimeoutErrorMessage(RequestTimeoutException):
|
||||
@staticmethod
|
||||
def recv_error_info(f):
|
||||
return {
|
||||
'consistency_level': read_consistency_level(f),
|
||||
'received': read_int(f),
|
||||
'blockfor': read_int(f),
|
||||
'writetype': read_string(f),
|
||||
'consistency': read_consistency_level(f),
|
||||
'received_responses': read_int(f),
|
||||
'required_responses': read_int(f),
|
||||
'write_type': read_string(f),
|
||||
}
|
||||
|
||||
class ReadTimeoutErrorMessage(RequestTimeoutException):
|
||||
@@ -293,10 +275,10 @@ class ReadTimeoutErrorMessage(RequestTimeoutException):
|
||||
@staticmethod
|
||||
def recv_error_info(f):
|
||||
return {
|
||||
'consistency_level': read_consistency_level(f),
|
||||
'received': read_int(f),
|
||||
'blockfor': read_int(f),
|
||||
'data_present': bool(read_byte(f)),
|
||||
'consistency': read_consistency_level(f),
|
||||
'received_responses': read_int(f),
|
||||
'required_responses': read_int(f),
|
||||
'data_retrieved': bool(read_byte(f)),
|
||||
}
|
||||
|
||||
class SyntaxException(RequestValidationException):
|
||||
@@ -455,7 +437,10 @@ class ResultMessage(_MessageType):
|
||||
colspecs = cls.recv_results_metadata(f)
|
||||
rowcount = read_int(f)
|
||||
rows = [cls.recv_row(f, len(colspecs)) for x in xrange(rowcount)]
|
||||
return CqlResult(column_metadata=colspecs, rows=rows)
|
||||
colnames = [c[2] for c in colspecs]
|
||||
coltypes = [c[3] for c in colspecs]
|
||||
return [dict(zip(colnames, [ctype.from_binary(val) for ctype, val in zip(coltypes, row)]))
|
||||
for row in rows]
|
||||
|
||||
@classmethod
|
||||
def recv_results_prepared(cls, f):
|
||||
@@ -596,7 +581,7 @@ def read_consistency_level(f):
|
||||
return ConsistencyLevel.value_to_name[read_short(f)]
|
||||
|
||||
def write_consistency_level(f, cl):
|
||||
write_short(f, ConsistencyLevel.name_to_value[cl])
|
||||
write_short(f, cl)
|
||||
|
||||
def read_string(f):
|
||||
size = read_short(f)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from threading import RLock
|
||||
|
||||
@@ -20,10 +21,10 @@ class Metadata(object):
|
||||
cf_def_rows = defaultdict(list)
|
||||
col_def_rows = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
for row in cf_results:
|
||||
for row in cf_results.results:
|
||||
cf_def_rows[row["keyspace_name"]].append(row)
|
||||
|
||||
for row in col_results:
|
||||
for row in col_results.results:
|
||||
ksname = row["keyspace_name"]
|
||||
cfname = row["columnfamily_name"]
|
||||
col_def_rows[ksname][cfname].append(row)
|
||||
@@ -32,13 +33,13 @@ class Metadata(object):
|
||||
if not table:
|
||||
# ks_results is not None
|
||||
added_keyspaces = set()
|
||||
for row in ks_results:
|
||||
for row in ks_results.results:
|
||||
keyspace_meta = self._build_keyspace_metadata(row)
|
||||
ksname = keyspace_meta.name
|
||||
if ksname in cf_def_rows:
|
||||
for table_row in cf_def_rows[keyspace_meta.name]:
|
||||
table_meta = self._build_table_metadata(
|
||||
keyspace_meta, table_row, col_def_rows[keyspace_meta.name])
|
||||
keyspace_meta, table_row, col_def_rows[keyspace_meta.name])
|
||||
keyspace_meta.tables[table_meta.name] = table_meta
|
||||
|
||||
added_keyspaces.add(keyspace_meta.name)
|
||||
@@ -67,7 +68,7 @@ class Metadata(object):
|
||||
name = row["keyspace_name"]
|
||||
durable_writes = row["durable_writes"]
|
||||
strategy_class = row["strategy_class"]
|
||||
strategy_options = row["strategy_options"]
|
||||
strategy_options = json.loads(row["strategy_options"])
|
||||
return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options)
|
||||
|
||||
def _build_table_metadata(self, keyspace_metadata, row, col_rows):
|
||||
@@ -143,7 +144,7 @@ class Metadata(object):
|
||||
# value alias (if present)
|
||||
if has_value:
|
||||
validator = cqltypes.lookup_casstype(row["default_validator"])
|
||||
if not row.get["key_aliases"]:
|
||||
if not row.get("key_aliases"):
|
||||
value_alias = "value"
|
||||
else:
|
||||
value_alias = row["value_alias"]
|
||||
@@ -156,7 +157,7 @@ class Metadata(object):
|
||||
column_meta = self._build_column_metadata(table_meta, row)
|
||||
table_meta.columns[column_meta.name] = column_meta
|
||||
|
||||
table_meta.options = self._build_table_options(is_compact)
|
||||
table_meta.options = self._build_table_options(row, is_compact)
|
||||
return table_meta
|
||||
|
||||
def _build_table_options(self, row, is_compact_storage):
|
||||
|
||||
@@ -21,13 +21,15 @@ class LoadBalancingPolicy(object):
|
||||
|
||||
class RoundRobinPolicy(LoadBalancingPolicy):
|
||||
|
||||
def __init__(self):
|
||||
self._lock = RLock()
|
||||
|
||||
def populate(self, cluster, hosts):
|
||||
self._live_hosts = set(hosts)
|
||||
if len(hosts) == 1:
|
||||
self._position = 0
|
||||
else:
|
||||
self._position = randint(0, len(hosts) - 1)
|
||||
self._lock = RLock()
|
||||
|
||||
def distance(self, host):
|
||||
return HostDistance.LOCAL
|
||||
|
||||
@@ -204,7 +204,7 @@ class HostConnectionPool(object):
|
||||
|
||||
# TODO potentially use threading.Queue for this
|
||||
core_conns = session.cluster.get_core_connections_per_host(host_distance)
|
||||
self._connections = [session.connection_factory(host)
|
||||
self._connections = [session.cluster.connection_factory(host.address)
|
||||
for i in range(core_conns)]
|
||||
self._trash = set()
|
||||
self._open_count = len(self._connections)
|
||||
@@ -249,7 +249,7 @@ class HostConnectionPool(object):
|
||||
least_busy = self._wait_for_conn(timeout)
|
||||
break
|
||||
|
||||
least_busy.set_keyspace() # TODO get keyspace from pool state
|
||||
least_busy.set_keyspace("keyspace") # TODO get keyspace from pool state
|
||||
return least_busy
|
||||
|
||||
def _create_new_connection(self):
|
||||
@@ -283,16 +283,16 @@ class HostConnectionPool(object):
|
||||
return False
|
||||
|
||||
def _await_available_conn(self, timeout):
|
||||
with self._available_conn_condition:
|
||||
self._available_conn_condition.wait(timeout)
|
||||
with self._conn_available_condition:
|
||||
self._conn_available_condition.wait(timeout)
|
||||
|
||||
def _signal_available_conn(self):
|
||||
with self._available_conn_condition:
|
||||
self._available_conn_condition.notify()
|
||||
with self._conn_available_condition:
|
||||
self._conn_available_condition.notify()
|
||||
|
||||
def _signal_all_available_conn(self):
|
||||
with self._available_conn_condition:
|
||||
self._available_conn_condition.notify_all()
|
||||
with self._conn_available_condition:
|
||||
self._conn_available_condition.notify_all()
|
||||
|
||||
def _wait_for_conn(self, timeout):
|
||||
start = time.time()
|
||||
|
||||
@@ -12,10 +12,8 @@ class Query(object):
|
||||
|
||||
class SimpleStatement(Query):
|
||||
|
||||
query = None
|
||||
|
||||
def __init__(self, query):
|
||||
self.query = query
|
||||
def __init__(self, query_string):
|
||||
self._query_string = query_string
|
||||
self._routing_key = None
|
||||
|
||||
@property
|
||||
@@ -26,3 +24,7 @@ class SimpleStatement(Query):
|
||||
def set_routing_key(self, value):
|
||||
self._routing_key = "".join(struct.pack("HsB", len(component), component, 0)
|
||||
for component in value)
|
||||
|
||||
@property
|
||||
def query_string(self):
|
||||
return self._query_string
|
||||
|
||||
Reference in New Issue
Block a user