From 089f05b24f3bb4b8b5e9c0142818f16dc9e8dcac Mon Sep 17 00:00:00 2001 From: Tyler Hobbs Date: Thu, 4 Apr 2013 18:06:57 -0500 Subject: [PATCH] Basic working Cluster, Session, Connection, ResponseFuture --- cassandra/cluster.py | 229 +++++++++++++++++++++++++++------------- cassandra/connection.py | 153 +++++++++++++++++---------- cassandra/decoder.py | 47 +++------ cassandra/metadata.py | 15 +-- cassandra/policies.py | 4 +- cassandra/pool.py | 16 +-- cassandra/query.py | 10 +- 7 files changed, 294 insertions(+), 180 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index baacaf32..57be5135 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1,15 +1,18 @@ -import asyncore import time -from threading import Lock, RLock, Thread +from threading import Lock, RLock, Thread, Event import Queue +import logging +import traceback from futures import ThreadPoolExecutor from connection import Connection -from decoder import ConsistencyLevel, QueryMessage +from decoder import (ConsistencyLevel, QueryMessage, ResultMessage, ErrorMessage, + ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableExceptionErrorMessage, + OverloadedErrorMessage, IsBootstrappingErrorMessage) from metadata import Metadata from policies import (RoundRobinPolicy, SimpleConvictionPolicy, - ExponentialReconnectionPolicy, HostDistance) + ExponentialReconnectionPolicy, HostDistance, RetryPolicy) from query import SimpleStatement from pool import (ConnectionException, BusyConnectionException, AuthenticationException, _ReconnectionHandler, @@ -18,19 +21,27 @@ from pool import (ConnectionException, BusyConnectionException, # TODO: we might want to make this configurable MAX_SCHEMA_AGREEMENT_WAIT_MS = 10000 +log = logging.getLogger(__name__) -class _RetryingCallback(object): - def __init__(self, session, query): +class ResponseFuture(object): + + def __init__(self, session, message, query): self.session = session - self.query = query - self.query_plan = session._load_balancer.new_query_plan(query) + self.message = message + + self.query_plan = session._load_balancer.make_query_plan(query) self._req_id = None self._final_result = None self._final_exception = None - self._wait_fn = None + self._current_host = None + self._current_pool = None + self._connection = None + self._event = Event() self._errors = {} + self._query_retries = 0 + self._callback = self._errback = None def send_request(self): for host in self.query_plan: @@ -42,32 +53,122 @@ class _RetryingCallback(object): self._final_exception = NoHostAvailable(self._errors) def _query(self, host): - pool = self._session._pools.get(host) + pool = self.session._pools.get(host) if not pool or pool.is_shutdown: return None + connection = None try: # TODO get connectTimeout from cluster settings connection = pool.borrow_connection(timeout=2.0) - request_id = connection.request_and_callback(self.query, self._set_results) - self._wait_fn = connection.wait_for_result + request_id = connection.send_msg(self.message, cb=self._set_result) + self._current_host = host + self._current_pool = pool + self._connection = connection return request_id - except Exception: - return None - finally: + except Exception, exc: + self._errors[host] = exc if connection: pool.return_connection(connection) + return None - def _set_results(self, response): + def _set_result(self, response): + self._current_pool.return_connection(self._connection) + + if isinstance(response, ResultMessage): + self._set_final_result(response) + elif isinstance(response, ErrorMessage): + retry_policy = self.query.retry_policy # TODO also check manager.configuration + # for retry policy if None + if isinstance(response, ReadTimeoutErrorMessage): + details = response.recv_error_info() + retry = retry_policy.on_read_timeout( + self.query, attempt_num=self._query_retries, **details) + elif isinstance(response, WriteTimeoutErrorMessage): + details = response.recv_error_info() + retry = retry_policy.on_write_timeout( + self.query, attempt_num=self._query_retries, **details) + elif isinstance(response, UnavailableExceptionErrorMessage): + details = response.recv_error_info() + retry = retry_policy.on_write_timeout( + self.query, attempt_num=self._query_retries, **details) + elif isinstance(response, OverloadedErrorMessage): + # need to retry against a different host here + self._retry(False, None) + elif isinstance(response, IsBootstrappingErrorMessage): + # need to retry against a different host here + self._retry(False, None) + # TODO need to define the PreparedQueryNotFound class + # elif isinstance(response, PreparedQueryNotFound): + # pass + else: + pass + + retry_type, consistency = retry + if retry_type == RetryPolicy.RETRY: + self._query_retries += 1 + self._retry(True, consistency) + elif retry_type == RetryPolicy.RETHROW: + self._set_final_result(response) + else: # IGNORE + self._set_final_result(None) + else: + # we got some other kind of response message + self._set_final_result(response) + + def _set_final_result(self, response): self._final_result = response + self._event.set() + if self._callback: + fn, args, kwargs = self._callback + fn(response, *args, **kwargs) + + def _set_final_exception(self, response): + self._final_exception = response + self._event.set() + if self._errback: + fn, args, kwargs = self._errback + fn(response, *args, **kwargs) + + def _retry(self, reuse_connection, consistency_level): + self.message.consistency_level = consistency_level + # don't retry on the event loop thread + self.session.submit(self._retry_helper, reuse_connection) + + def _retry_task(self, reuse_connection): + if reuse_connection and self._query(self._current_host): + return + + # otherwise, move onto another host + self.send_request() def deliver(self): - if self._final_result: - return self._final_result - elif self._final_exception: + if self._final_exception: raise self._final_exception + elif self._final_result: + return self._final_result else: - return self._wait_fn(self._req_id) + self._event.wait() + if self._final_exception: + raise self._final_exception + else: + return self._final_result + + def addCallback(self, fn, *args, **kwargs): + self._callback = (fn, args, kwargs) + return self + + def addErrback(self, fn, *args, **kwargs): + self._errback = (fn, args, kwargs) + return self + + def addCallbacks(self, + callback, errback, + callback_args=(), callback_kwargs=None, + errback_args=(), errback_kwargs=None): + self._callback = (callback, callback_args, callback_kwargs | {}) + self._errback = (errback, errback_args, errback_kwargs | {}) + class Session(object): @@ -79,17 +180,29 @@ class Session(object): self._is_shutdown = False self._pools = {} self._load_balancer = RoundRobinPolicy() + self._load_balancer.populate(cluster, hosts) + + for host in hosts: + self.add_host(host) def execute(self, query): - if isinstance(query, basestring): - query = SimpleStatement(query) + future = self.execute_async(query) + return future.deliver() def execute_async(self, query): if isinstance(query, basestring): query = SimpleStatement(query) - qmsg = QueryMessage(query=query.query, consistency_level=query.consistency_level) - return self._execute_query(qmsg, query) + # TODO bound statements need to be handled differently + message = QueryMessage(query=query.query_string, consistency_level=query.consistency_level) + + if query.tracing_enabled: + # TODO enable tracing on the message + pass + + future = ResponseFuture(self, message, query) + future.send_request() + return future def prepare(self, query): pass @@ -97,25 +210,6 @@ class Session(object): def shutdown(self): self.cluster.shutdown() - def _execute_query(self, message, query): - if query.tracing_enabled: - # TODO enable tracing on the message - pass - - errors = {} - for host in self._load_balancer.make_query_plan(query): - try: - result = self._query(host) - if result: - return - except Exception, exc: - errors[host] = exc - - def _query(self, host, query): - pool = self._pools.get(host) - if not pool or pool.is_shutdown: - return False - def add_host(self, host): distance = self._load_balancer.distance(host) if distance == HostDistance.IGNORED: @@ -174,8 +268,8 @@ class Session(object): def set_keyspace(self, keyspace): pass - def submit(self, task): - return self.cluster.executor.submit(task) + def submit(self, fn, *args, **kwargs): + return self.cluster.executor.submit(fn, *args, **kwargs) DEFAULT_MIN_REQUESTS = 25 DEFAULT_MAX_REQUESTS = 100 @@ -219,18 +313,6 @@ class _Scheduler(object): time.sleep(0.1) -_loop_started = False -_loop_lock = Lock() -def _start_loop(): - with _loop_lock: - global _loop_started - if not _loop_started: - _loop_started = True - t = Thread(target=asyncore.loop, name="Async event loop") - t.daemon = True - t.start() - - class Cluster(object): port = 9042 @@ -282,9 +364,6 @@ class Cluster(object): self._is_shutdown = False self._lock = Lock() - # start the global event loop - _start_loop() - for address in contact_points: self.add_host(address, signal=False) @@ -472,13 +551,16 @@ class ControlConnection(object): except ConnectionException, exc: errors[host.address] = exc host.monitor.signal_connection_failure(exc) + log.error("Error reconnecting control connection: %s", traceback.format_exc()) except Exception, exc: errors[host.address] = exc + log.error("Error reconnecting control connection: %s", traceback.format_exc()) raise NoHostAvailable("Unable to connect to any servers", errors) def _try_connect(self, host): connection = self._cluster.connection_factory(host.address) + connection.register_watchers({ "TOPOLOGY_CHANGE": self._handle_topology_change, "STATUS_CHANGE": self._handle_status_change, @@ -486,7 +568,7 @@ class ControlConnection(object): }) self.refresh_node_list_and_token_map() - self.refresh_schema() + self._refresh_schema(connection) return connection def reconnect(self): @@ -521,6 +603,9 @@ class ControlConnection(object): self._connection.close() def refresh_schema(self, keyspace=None, table=None): + return self._refresh_schema(self._connection, keyspace, table) + + def _refresh_schema(self, connection, keyspace=None, table=None): where_clause = "" if keyspace: where_clause = " WHERE keyspace_name = '%s'" % (keyspace,) @@ -531,15 +616,15 @@ class ControlConnection(object): if table: ks_query = None else: - ks_query = QueryMessage(query=self.SELECT_KEYSPACES + where_clause, consistency_level=cl) - cf_query = QueryMessage(query=self.SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl) - col_query = QueryMessage(query=self.SELECT_COLUMNS + where_clause, consistency_level=cl) + ks_query = QueryMessage(query=self._SELECT_KEYSPACES + where_clause, consistency_level=cl) + cf_query = QueryMessage(query=self._SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl) + col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl) if ks_query: - ks_result, cf_result, col_result = self._connection.wait_for_requests(ks_query, cf_query, col_query) + ks_result, cf_result, col_result = connection.wait_for_responses(ks_query, cf_query, col_query) else: ks_result = None - cf_result, col_result = self._connection.wait_for_requests(cf_query, col_query) + cf_result, col_result = connection.wait_for_responses(cf_query, col_query) self._cluster.metadata.rebuild_schema(keyspace, table, ks_result, cf_result, col_result) @@ -549,10 +634,10 @@ class ControlConnection(object): return cl = ConsistencyLevel.ONE - peers_query = QueryMessage(query=self.SELECT_PEERS, consistency_level=cl) - local_query = QueryMessage(query=self.SELECT_LOCAL, consistency_level=cl) + peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=cl) + local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=cl) try: - peers_result, local_result = conn.wait_for_requests(peers_query, local_query) + peers_result, local_result = conn.wait_for_responses(peers_query, local_query) except (ConnectionException, BusyConnectionException): self.reconnect() @@ -649,9 +734,9 @@ class ControlConnection(object): elapsed = 0 cl = ConsistencyLevel.ONE while elapsed < MAX_SCHEMA_AGREEMENT_WAIT_MS: - peers_query = QueryMessage(query=self.SELECT_SCHEMA_PEERS, consistency_level=cl) - local_query = QueryMessage(query=self.SELECT_SCHEMA_LOCAL, consistency_level=cl) - peers_result, local_result = self._connection.wait_for_requests(peers_query, local_query) + peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl) + local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl) + peers_result, local_result = self._connection.wait_for_responses(peers_query, local_query) versions = set() if local_result and local_result.rows: diff --git a/cassandra/connection.py b/cassandra/connection.py index 2ce9fd88..2d027b71 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1,14 +1,21 @@ +import asyncore import asynchat +from collections import defaultdict +from functools import partial import itertools +import logging import socket +from threading import RLock, Event, Lock, Thread +import traceback -from threading import RLock, Event from cassandra.marshal import (int8_unpack, int32_unpack) from cassandra.decoder import (OptionsMessage, ReadyMessage, AuthenticateMessage, StartupMessage, ErrorMessage, CredentialsMessage, decode_response) +log = logging.getLogger(__name__) + locally_supported_compressions = {} try: import snappy @@ -45,6 +52,22 @@ class InternalError(Exception): pass +_loop_started = False +_loop_lock = Lock() +def _start_loop(): + global _loop_started + should_start = False + with _loop_lock: + if not _loop_started: + _loop_started = True + should_start = True + + if should_start: + t = Thread(target=asyncore.loop, name="async_event_loop", kwargs={"timeout": 0.01}) + t.daemon = True + t.start() + + class Connection(asynchat.async_chat): @classmethod @@ -67,16 +90,19 @@ class Connection(asynchat.async_chat): self.decompressor = None self.connected_event = Event() - - self.create_socket(socket.AF_INET, socket.SOCK_STREAM) - self.connect((host, port)) + self.in_flight = 0 + self.is_defunct = False self.make_request_id = itertools.cycle(xrange(127)).next - self._waiting_callbacks = {} - self._waiting_events = {} - self._responses = {} + self._callbacks = {} + self._push_watchers = defaultdict(set) self._lock = RLock() + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + # start the global event loop if needed + _start_loop() + self.connect((host, port)) + def handle_read(self): # simpler to do here than collect_incoming_data() header = self.recv(8) @@ -89,30 +115,62 @@ class Connection(asynchat.async_chat): assert body_len >= 0, "Invalid CQL protocol body_len %r" % body_len body = self.recv(body_len) + response = decode_response(stream_id, flags, opcode, body, self.decompressor) if stream_id < 0: - self.handle_pushed(stream_id, flags, opcode, body) + self.handle_pushed(response) else: - try: - cb = self._waiting_callbacks[stream_id] - except KeyError: - # store the response in a location accessible by other threads - self._responses[stream_id] = (flags, opcode, body) - - # signal to the waiting thread that the response is ready - self._waiting_events[stream_id].set() - else: - cb(stream_id, flags, opcode, body) + self._callbacks.pop(stream_id)(response) def readable(self): - # TODO only return True if we have pending responses (except for ControlConnections?) - return True + # TODO this isn't accurate for control connections + return bool(self._callbacks) + + def handle_error(self): + log.error(traceback.format_exc()) + self.is_defunct = True + + def handle_pushed(self, response): + pass + # details = response.recv_body + # for cb in self._push_watchers[response] + + def push(self, data): + # overridden to avoid calling initiate_send() at the end of this + # and hold the lock + sabs = self.ac_out_buffer_size + with self._lock: + if len(data) > sabs: + for i in xrange(0, len(data), sabs): + self.producer_fifo.append(data[i:i + sabs]) + else: + self.producer_fifo.append(data) + + def send_msg(self, msg, cb): + request_id = self.make_request_id() + self._callbacks[request_id] = cb + self.push(msg.to_string(request_id, compression=self.compressor)) + return request_id + + def wait_for_responses(self, *msgs): + waiter = ResponseWaiter(len(msgs)) + for i, msg in enumerate(msgs): + self.send_msg(msg, partial(waiter.got_response, index=i)) + waiter.event.wait() + return waiter.responses + + def register_watcher(self, event_type, callback): + self._push_watchers[event_type].add(callback) + + def register_watchers(self, type_callback_dict): + for event_type, callback in type_callback_dict.items(): + self.register_watcher(event_type, callback) def handle_connect(self): + log.debug("Sending initial message for new Connection to %s" % self.host) self.send_msg(OptionsMessage(), self._handle_options_response) - def _handle_options_response(self, stream_id, flags, opcode, body): - options_response = decode_response(stream_id, flags, opcode, body) - + def _handle_options_response(self, options_response): + log.debug("Received options response on new Connection from %s" % self.host) self.supported_cql_versions = options_response.cql_versions self.remote_supported_compressions = options_response.options['COMPRESSION'] @@ -147,58 +205,39 @@ class Connection(asynchat.async_chat): sm = StartupMessage(cqlversion=self.cql_version, options=opts) self.send_msg(sm, cb=self._handle_startup_response) - def _handle_startup_response(self, stream_id, flags, opcode, body): - startup_response = decode_response( - stream_id, flags, opcode, body, self.decompressor) - + def _handle_startup_response(self, startup_response): if isinstance(startup_response, ReadyMessage): + log.debug("Got ReadyMessage on new Connection from %s" % self.host) if self._compresstype: self.compressor = self._compressor self.connected_event.set() elif isinstance(startup_response, AuthenticateMessage): + log.debug("Got AuthenticateMessage on new Connection from %s" % self.host) self.authenticator = startup_response.authenticator if self.credentials is None: raise ProgrammingError('Remote end requires authentication.') cm = CredentialsMessage(creds=self.credentials) self.send_msg(cm, cb=self._handle_startup_response) elif isinstance(startup_response, ErrorMessage): + log.debug("Received ErrorMessage on new Connection from %s" % self.host) raise ProgrammingError("Server did not accept credentials. %s" % startup_response.summary_msg()) else: raise InternalError("Unexpected response %r during connection setup" % startup_response) - # def handle_error(self): - # print "got connection error" # TODO - # self.close() - - def handle_pushed(self, stream_id, flags, opcode, body): + def set_keyspace(self, keyspace): pass - def push(self, data): - # overridden to avoid calling initiate_send() at the end of this - # and hold the lock - sabs = self.ac_out_buffer_size - with self._lock: - if len(data) > sabs: - for i in xrange(0, len(data), sabs): - self.producer_fifo.append(data[i:i + sabs]) - else: - self.producer_fifo.append(data) +class ResponseWaiter(object): - def send_msg(self, msg, cb=None): - request_id = self.make_request_id() - if cb: - self._waiting_callbacks[request_id] = cb - else: - self._waiting_events[request_id] = Event() + def __init__(self, num_responses): + self.pending = num_responses + self.responses = [None] * num_responses + self.event = Event() - self.push(msg.to_string(request_id, compression=self.compressor)) - return request_id - - def get_response(self, stream_id): - """ Blocking wait for a response """ - # TODO waiting on the event in the loop thread will deadlock - self._waiting_events.pop(stream_id).wait() - (flags, opcode, body) = self._responses.pop(stream_id) - return decode_response(stream_id, flags, opcode, body, self.decompressor) + def got_response(self, response, index): + self.responses[index] = response + self.pending -= 1 + if not self.pending: + self.event.set() diff --git a/cassandra/decoder.py b/cassandra/decoder.py index c75eea04..5c53f616 100644 --- a/cassandra/decoder.py +++ b/cassandra/decoder.py @@ -79,24 +79,6 @@ ConsistencyLevel.name_to_value = { } -class CqlResult: - def __init__(self, column_metadata, rows): - self.column_metadata = column_metadata - self.rows = rows - - def __iter__(self): - return iter(self.rows) - - # TODO this definitely needs to be abstracted at some other level - def as_dicts(self): - colnames = [c[2] for c in self.column_metadata] - return [dict(zip(colnames, row)) for row in self.rows] - - def __str__(self): - return '' \ - % (self.column_metadata, self.rows) - __repr__ = __str__ - class PreparedResult: def __init__(self, queryid, param_metadata): self.queryid = queryid @@ -253,9 +235,9 @@ class UnavailableExceptionErrorMessage(RequestExecutionException): @staticmethod def recv_error_info(f): return { - 'consistency_level': read_consistency_level(f), - 'required': read_int(f), - 'alive': read_int(f), + 'consistency': read_consistency_level(f), + 'required_replicas': read_int(f), + 'alive_replicas': read_int(f), } class OverloadedErrorMessage(RequestExecutionException): @@ -280,10 +262,10 @@ class WriteTimeoutErrorMessage(RequestTimeoutException): @staticmethod def recv_error_info(f): return { - 'consistency_level': read_consistency_level(f), - 'received': read_int(f), - 'blockfor': read_int(f), - 'writetype': read_string(f), + 'consistency': read_consistency_level(f), + 'received_responses': read_int(f), + 'required_responses': read_int(f), + 'write_type': read_string(f), } class ReadTimeoutErrorMessage(RequestTimeoutException): @@ -293,10 +275,10 @@ class ReadTimeoutErrorMessage(RequestTimeoutException): @staticmethod def recv_error_info(f): return { - 'consistency_level': read_consistency_level(f), - 'received': read_int(f), - 'blockfor': read_int(f), - 'data_present': bool(read_byte(f)), + 'consistency': read_consistency_level(f), + 'received_responses': read_int(f), + 'required_responses': read_int(f), + 'data_retrieved': bool(read_byte(f)), } class SyntaxException(RequestValidationException): @@ -455,7 +437,10 @@ class ResultMessage(_MessageType): colspecs = cls.recv_results_metadata(f) rowcount = read_int(f) rows = [cls.recv_row(f, len(colspecs)) for x in xrange(rowcount)] - return CqlResult(column_metadata=colspecs, rows=rows) + colnames = [c[2] for c in colspecs] + coltypes = [c[3] for c in colspecs] + return [dict(zip(colnames, [ctype.from_binary(val) for ctype, val in zip(coltypes, row)])) + for row in rows] @classmethod def recv_results_prepared(cls, f): @@ -596,7 +581,7 @@ def read_consistency_level(f): return ConsistencyLevel.value_to_name[read_short(f)] def write_consistency_level(f, cl): - write_short(f, ConsistencyLevel.name_to_value[cl]) + write_short(f, cl) def read_string(f): size = read_short(f) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 6b5a72df..f1e3bfa0 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -1,3 +1,4 @@ +import json from collections import defaultdict from threading import RLock @@ -20,10 +21,10 @@ class Metadata(object): cf_def_rows = defaultdict(list) col_def_rows = defaultdict(lambda: defaultdict(list)) - for row in cf_results: + for row in cf_results.results: cf_def_rows[row["keyspace_name"]].append(row) - for row in col_results: + for row in col_results.results: ksname = row["keyspace_name"] cfname = row["columnfamily_name"] col_def_rows[ksname][cfname].append(row) @@ -32,13 +33,13 @@ class Metadata(object): if not table: # ks_results is not None added_keyspaces = set() - for row in ks_results: + for row in ks_results.results: keyspace_meta = self._build_keyspace_metadata(row) ksname = keyspace_meta.name if ksname in cf_def_rows: for table_row in cf_def_rows[keyspace_meta.name]: table_meta = self._build_table_metadata( - keyspace_meta, table_row, col_def_rows[keyspace_meta.name]) + keyspace_meta, table_row, col_def_rows[keyspace_meta.name]) keyspace_meta.tables[table_meta.name] = table_meta added_keyspaces.add(keyspace_meta.name) @@ -67,7 +68,7 @@ class Metadata(object): name = row["keyspace_name"] durable_writes = row["durable_writes"] strategy_class = row["strategy_class"] - strategy_options = row["strategy_options"] + strategy_options = json.loads(row["strategy_options"]) return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) def _build_table_metadata(self, keyspace_metadata, row, col_rows): @@ -143,7 +144,7 @@ class Metadata(object): # value alias (if present) if has_value: validator = cqltypes.lookup_casstype(row["default_validator"]) - if not row.get["key_aliases"]: + if not row.get("key_aliases"): value_alias = "value" else: value_alias = row["value_alias"] @@ -156,7 +157,7 @@ class Metadata(object): column_meta = self._build_column_metadata(table_meta, row) table_meta.columns[column_meta.name] = column_meta - table_meta.options = self._build_table_options(is_compact) + table_meta.options = self._build_table_options(row, is_compact) return table_meta def _build_table_options(self, row, is_compact_storage): diff --git a/cassandra/policies.py b/cassandra/policies.py index 3f8505b9..d0a5f848 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -21,13 +21,15 @@ class LoadBalancingPolicy(object): class RoundRobinPolicy(LoadBalancingPolicy): + def __init__(self): + self._lock = RLock() + def populate(self, cluster, hosts): self._live_hosts = set(hosts) if len(hosts) == 1: self._position = 0 else: self._position = randint(0, len(hosts) - 1) - self._lock = RLock() def distance(self, host): return HostDistance.LOCAL diff --git a/cassandra/pool.py b/cassandra/pool.py index b70cce22..1991fdab 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -204,7 +204,7 @@ class HostConnectionPool(object): # TODO potentially use threading.Queue for this core_conns = session.cluster.get_core_connections_per_host(host_distance) - self._connections = [session.connection_factory(host) + self._connections = [session.cluster.connection_factory(host.address) for i in range(core_conns)] self._trash = set() self._open_count = len(self._connections) @@ -249,7 +249,7 @@ class HostConnectionPool(object): least_busy = self._wait_for_conn(timeout) break - least_busy.set_keyspace() # TODO get keyspace from pool state + least_busy.set_keyspace("keyspace") # TODO get keyspace from pool state return least_busy def _create_new_connection(self): @@ -283,16 +283,16 @@ class HostConnectionPool(object): return False def _await_available_conn(self, timeout): - with self._available_conn_condition: - self._available_conn_condition.wait(timeout) + with self._conn_available_condition: + self._conn_available_condition.wait(timeout) def _signal_available_conn(self): - with self._available_conn_condition: - self._available_conn_condition.notify() + with self._conn_available_condition: + self._conn_available_condition.notify() def _signal_all_available_conn(self): - with self._available_conn_condition: - self._available_conn_condition.notify_all() + with self._conn_available_condition: + self._conn_available_condition.notify_all() def _wait_for_conn(self, timeout): start = time.time() diff --git a/cassandra/query.py b/cassandra/query.py index 42b26bb0..0da16759 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -12,10 +12,8 @@ class Query(object): class SimpleStatement(Query): - query = None - - def __init__(self, query): - self.query = query + def __init__(self, query_string): + self._query_string = query_string self._routing_key = None @property @@ -26,3 +24,7 @@ class SimpleStatement(Query): def set_routing_key(self, value): self._routing_key = "".join(struct.pack("HsB", len(component), component, 0) for component in value) + + @property + def query_string(self): + return self._query_string