# deb-python-cassandra-driver/cassandra/cluster.py
# (574 lines, 20 KiB, Python)
import time
from threading import Lock, RLock, Thread
import Queue
from futures import ThreadPoolExecutor
from connection import Connection
from decoder import ConsistencyLevel, QueryMessage
from metadata import Metadata
from policies import (RoundRobinPolicy, SimpleConvictionPolicy,
ExponentialReconnectionPolicy, HostDistance)
from query import SimpleStatement
from pool import (ConnectionException, BusyConnectionException,
AuthenticationException, _ReconnectionHandler,
_HostReconnectionHandler)
# Maximum time, in milliseconds, to wait for all live hosts to report the
# same schema version (used by ControlConnection.wait_for_schema_agreement).
# TODO: we might want to make this configurable
MAX_SCHEMA_AGREEMENT_WAIT_MS = 10000
class Session(object):
def __init__(self, cluster, hosts):
self.cluster = cluster
self.hosts = hosts
self._lock = RLock()
self._is_shutdown = False
self._pools = {}
self._load_balancer = RoundRobinPolicy()
def execute(self, query):
if isinstance(query, basestring):
query = SimpleStatement(query)
def execute_async(self, query):
if isinstance(query, basestring):
query = SimpleStatement(query)
qmsg = QueryMessage(query=query.query, consistency_level=query.consistency_level)
return self._execute_query(qmsg, query)
def prepare(self, query):
pass
def shutdown(self):
self.cluster.shutdown()
def _execute_query(self, message, query):
if query.tracing_enabled:
# TODO enable tracing on the message
pass
errors = {}
query_plan = self._load_balancer.make_query_plan(query)
for host in query_plan:
try:
result = self._query(host)
if result:
return
except Exception, exc:
errors[host] = exc
def _query(self, host, query):
pool = self._pools.get(host)
if not pool or pool.is_shutdown:
return False
# Per-connection request thresholds that drive pool growth and shrinkage
# (see Cluster.get_min_requests_per_connection() and friends).
DEFAULT_MIN_REQUESTS = 25
DEFAULT_MAX_REQUESTS = 100

# Default pool sizes for hosts at HostDistance.LOCAL
DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST = 2
DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST = 8

# Default pool sizes for hosts at HostDistance.REMOTE
DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST = 1
DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST = 2
class _Scheduler(object):
    """
    Runs delayed tasks on a daemon thread: schedule() enqueues
    (run_at, task) pairs on a priority queue, and the run() loop submits
    each task to the shared executor once its time has arrived.
    """

    def __init__(self, executor):
        self._scheduled = Queue.PriorityQueue()
        self._executor = executor

        t = Thread(target=self.run, name="Task Scheduler")
        t.daemon = True
        t.start()
        # TODO add a shutdown method to stop processing the queue?

    def schedule(self, delay, fn, *args, **kwargs):
        """Schedule fn(*args, **kwargs) to run `delay` seconds from now."""
        run_at = time.time() + delay
        self._scheduled.put_nowait((run_at, (fn, args, kwargs)))

    def run(self):
        while True:
            try:
                while True:
                    run_at, task = self._scheduled.get(block=True, timeout=None)
                    if run_at <= time.time():
                        fn, args, kwargs = task
                        self._executor.submit(fn, *args, **kwargs)
                    else:
                        # not due yet: put it back and sleep before retrying
                        self._scheduled.put_nowait((run_at, task))
                        break
            except Queue.Empty:
                # bug fix: the exception class is Queue.Empty; the previous
                # "except Queue.empty" raised AttributeError instead of
                # catching the empty-queue condition
                pass

            time.sleep(0.1)
class Cluster(object):
    """
    The main entry point for interacting with a Cassandra cluster.  Holds
    cluster-wide configuration, the control connection used for event and
    metadata handling, and the set of open Sessions.
    """

    # native protocol port on each contact point
    port = 9042

    auth_provider = None
    load_balancing_policy = None
    # delays are in milliseconds: 2s base, 5min max
    reconnection_policy = ExponentialReconnectionPolicy(2 * 1000, 5 * 60 * 1000)
    retry_policy = None
    compression = None
    metrics_enabled = False
    socket_options = None
    conviction_policy_factory = SimpleConvictionPolicy

    def __init__(self, contact_points):
        self.contact_points = contact_points
        self.sessions = set()
        self.metadata = Metadata(self)

        # per-distance thresholds controlling when connections are added
        # to or trimmed from a host's pool
        self._min_requests_per_connection = {
            HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
            HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
        }
        self._max_requests_per_connection = {
            HostDistance.LOCAL: DEFAULT_MAX_REQUESTS,
            HostDistance.REMOTE: DEFAULT_MAX_REQUESTS
        }
        self._core_connections_per_host = {
            HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
            HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST
        }
        self._max_connections_per_host = {
            HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
            HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
        }

        # TODO real factory based on config
        self._connection_factory = Connection

        # TODO make the pool size configurable somewhere
        self.executor = ThreadPoolExecutor(max_workers=3)
        self.scheduler = _Scheduler(self.executor)

        self._is_shutdown = False
        self._lock = Lock()

        self._control_connection = ControlConnection(self, self.metadata)
        try:
            self._control_connection.connect()
        except:
            # don't leak the executor/scheduler threads if the initial
            # connection attempt fails; re-raise for the caller
            self.shutdown()
            raise

    def get_min_requests_per_connection(self, host_distance):
        return self._min_requests_per_connection[host_distance]

    def set_min_requests_per_connection(self, host_distance, min_requests):
        self._min_requests_per_connection[host_distance] = min_requests

    def get_max_requests_per_connection(self, host_distance):
        return self._max_requests_per_connection[host_distance]

    def set_max_requests_per_connection(self, host_distance, max_requests):
        self._max_requests_per_connection[host_distance] = max_requests

    def get_core_connections_per_host(self, host_distance):
        return self._core_connections_per_host[host_distance]

    def set_core_connections_per_host(self, host_distance, core_connections):
        old = self._core_connections_per_host[host_distance]
        self._core_connections_per_host[host_distance] = core_connections
        if old < core_connections:
            # bug fix: was self.ensure_pool_sizing(), which does not exist;
            # grow existing pools to the new core size
            self.ensure_core_connections()

    def get_max_connections_per_host(self, host_distance):
        return self._max_connections_per_host[host_distance]

    def set_max_connections_per_host(self, host_distance, max_connections):
        self._max_connections_per_host[host_distance] = max_connections

    def connect(self, keyspace=None):
        """Create and return a new Session."""
        # TODO set keyspace if not None
        return self._new_session()

    def shutdown(self):
        """Shut down the control connection, all sessions, and the executor.
        Safe to call more than once."""
        with self._lock:
            if self._is_shutdown:
                return
            else:
                self._is_shutdown = True

        self._control_connection.shutdown()

        for session in self.sessions:
            session.shutdown()

        self.executor.shutdown()

    def _new_session(self):
        session = Session(self, self.metadata.all_hosts())
        self.sessions.add(session)
        return session

    def on_up(self, host):
        # a pending reconnection is no longer needed once the host is up
        reconnector = host.get_and_set_reconnection_handler(None)
        if reconnector:
            reconnector.cancel()

        # TODO prepareAllQueries(host)

        self._control_connection.on_up(host)
        for session in self.sessions:
            session.on_up(host)

    def on_down(self, host):
        self._control_connection.on_down(host)

        for session in self.sessions:
            session.on_down(host)

        # start periodically attempting to reconnect to the downed host,
        # clearing the host's handler slot when the schedule completes
        schedule = self.reconnection_policy.new_schedule()
        reconnector = _HostReconnectionHandler(
            host, self._connection_factory, self.scheduler, schedule,
            callback=host.get_and_set_reconnection_handler,
            callback_kwargs=dict(new_handler=None))

        old_reconnector = host.get_and_set_reconnection_handler(reconnector)
        if old_reconnector:
            old_reconnector.cancel()

        reconnector.start()

    def on_add(self, host):
        # bug fix: self.prepare_all_queries(host) raised AttributeError --
        # no such method exists yet (see the TODO in on_up)
        # TODO prepare_all_queries(host)
        self._control_connection.on_add(host)
        for session in self.sessions:  # TODO need to copy/lock?
            session.on_add(host)

    def on_remove(self, host):
        self._control_connection.on_remove(host)
        for session in self.sessions:
            session.on_remove(host)

    def add_host(self, address, signal):
        """Add a host by address to the metadata; when `signal` is truthy
        and the host is new, fire on_add(). Returns the new Host or None."""
        new_host = self.metadata.add_host(address)
        if new_host and signal:
            self.on_add(new_host)
        return new_host

    def remove_host(self, host):
        # bug fix: "self.metdata" typo raised AttributeError
        if host and self.metadata.remove_host(host):
            self.on_remove(host)

    def ensure_core_connections(self):
        """Grow every session's pools up to the configured core size."""
        # bug fix: was "self.session" (no such attribute)
        for session in self.sessions:
            for pool in session._pools.values():
                pool.ensure_core_connections()
class NoHostAvailable(Exception):
    """
    Raised when an operation could not be completed against any host.

    The exception is constructed with a message and a dict mapping each
    attempted host (or address) to the error it produced; that dict was
    previously passed in (see _reconnect_internal) but never exposed, so
    callers could not inspect why each host failed.
    """

    def __init__(self, message, errors=None):
        Exception.__init__(self, message, errors)
        # maps host (or address) -> the Exception raised for that host
        self.errors = errors
class _ControlReconnectionHandler(_ReconnectionHandler):
    """
    Reconnection handler used by the ControlConnection: delegates the
    actual connection attempt and the install-new-connection step back to
    the owning ControlConnection instance.
    """

    def __init__(self, control_connection, *args, **kwargs):
        _ReconnectionHandler.__init__(self, *args, **kwargs)
        self.control_connection = control_connection

    def try_reconnect(self):
        return self.control_connection._reconnect_internal()

    def on_reconnection(self, connection):
        self.control_connection._set_new_connection(connection)

    def on_exception(self, exc, next_delay):
        # TODO only overridden to add logging, so add logging
        # keep retrying unless authentication itself failed
        return not isinstance(exc, AuthenticationException)
class ControlConnection(object):
_SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces"
_SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies"
_SELECT_COLUMNS = "SELECT * FROM system.schema_columns"
_SELECT_PEERS = "SELECT peer, data_center, rack, tokens, rpc_address FROM system.peers"
_SELECT_LOCAL = "SELECT cluster_name, data_center, rack, tokens, partitioner FROM system.local WHERE key='local'"
_SELECT_SCHEMA_PEERS = "SELECT rpc_address, schema_version FROM system.peers"
_SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'"
def __init__(self, cluster, metadata):
self._cluster = cluster
self._balancing_policy = RoundRobinPolicy()
self._balancing_policy.populate(cluster, metadata.all_hosts())
self._reconnection_policy = ExponentialReconnectionPolicy(2 * 1000, 5 * 60 * 1000)
self._connection = None
self._lock = RLock()
self._reconnection_handler = None
self._reconnection_lock = RLock()
self._is_shutdown = False
def connect(self):
if self._is_shutdown:
return
self._set_new_connection(self._reconnect_internal())
def _set_new_connection(self, conn):
with self._lock:
old = self._connection
self._connection = conn
if old and not old.is_closed(): # TODO is_closed() may not exist
old.close()
def _reconnect_internal(self):
errors = {}
for host in self._balancing_policy:
try:
return self._try_connect(host)
except ConnectionException, exc:
errors[host.address] = exc
host.monitor.signal_connection_failure(exc)
except Exception, exc:
errors[host.address] = exc
raise NoHostAvailable("Unable to connect to any servers", errors)
def _try_connect(self, host):
connection = self._cluster.connection_factory.open(host)
connection.register_watchers({
"TOPOLOGY_CHANGE": self._handle_topology_change,
"STATUS_CHANGE": self._handle_status_change,
"SCHEMA_CHANGE": self._handle_schema_change
})
self.refresh_node_list_and_token_map()
self.refresh_schema()
return connection
def reconnect(self):
if self._is_shutdown:
return
try:
self._set_new_connection(self._reconnect_internal())
except NoHostAvailable:
schedule = self._reconnection_policy.new_schedule()
with self._reconnection_lock:
if self._reconnection_handler:
self._reconnection_handler.cancel()
self._reconnection_handler = _ControlReconnectionHandler(
self, self._cluster.scheduler, schedule,
callback=self._get_and_set_reconnection_handler,
callback_kwargs=dict(new_handler=None))
self._reconnection_handler.start()
def _get_and_set_reconnection_handler(self, new_handler):
with self._reconnection_lock:
if self._reconnection_handler:
return self._reconnection_handler
else:
self._reconnection_handler = new_handler
return None
def shutdown(self):
self._is_shutdown = True
with self._lock:
if self._connection:
self._connection.close()
def refresh_schema(self, keyspace=None, table=None):
where_clause = ""
if keyspace:
where_clause = " WHERE keyspace_name = '%s'" % (keyspace,)
if table:
where_clause += " AND columnfamily_name = '%s'" % (table,)
cl = ConsistencyLevel.ONE
if table:
ks_query = None
else:
ks_query = QueryMessage(query=self.SELECT_KEYSPACES + where_clause, consistency_level=cl)
cf_query = QueryMessage(query=self.SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl)
col_query = QueryMessage(query=self.SELECT_COLUMNS + where_clause, consistency_level=cl)
if ks_query:
ks_result, cf_result, col_result = self._connection.wait_for_requests(ks_query, cf_query, col_query)
else:
ks_result = None
cf_result, col_result = self._connection.wait_for_requests(cf_query, col_query)
self._cluster.metadata.rebuild_schema(keyspace, table, ks_result, cf_result, col_result)
def refresh_node_list_and_token_map(self):
conn = self._connection
if not conn:
return
cl = ConsistencyLevel.ONE
peers_query = QueryMessage(query=self.SELECT_PEERS, consistency_level=cl)
local_query = QueryMessage(query=self.SELECT_LOCAL, consistency_level=cl)
try:
peers_result, local_result = conn.wait_for_requests(peers_query, local_query)
except (ConnectionException, BusyConnectionException):
self.reconnect()
partitioner = None
token_map = {}
if local_result and local_result.rows: # TODO: probably just check local_result.rows
local_row = local_result.as_dicts()[0]
cluster_name = local_row["cluster_name"]
self._cluster.metadata.cluster_name = cluster_name
host = self._cluster.metadata.get_host(conn.host)
if host:
host.set_location_info(local_row["data_center"], local_row["rack"])
partitioner = local_row.get("partitioner")
tokens = local_row.get("tokens")
if partitioner and tokens:
token_map[host] = tokens
found_hosts = set()
for row in peers_result.as_dicts():
addr = row.get("rpc_address")
if not addr:
addr = row.get("peer")
elif addr == "0.0.0.0": # TODO handle ipv6 equivalent
addr = row.get("peer")
found_hosts.add(addr)
host = self._cluster.metadata.getHost(addr)
if host is None:
host = self._cluster.addHost(addr, True)
host.set_location_info(row.get("data_center"), row.get("rack"))
tokens = row.get("tokens")
if partitioner and tokens:
token_map[host] = tokens
for old_host in self._cluster.metadata.all_hosts():
if old_host.address != conn.address and \
old_host.address not in found_hosts:
self._cluster.remove_host(old_host)
if partitioner:
self._cluster.metadata.rebuild_token_map(partitioner, token_map)
def _handle_topology_change(self, event):
# TODO schedule on executor
change_type = event["change_type"]
addr, port = event["address"]
if change_type == "NEW_NODE":
# TODO check Host constructor
self._cluster.scheduler.schedule(1, self.add_host, addr, signal=True)
elif change_type == "REMOVED_NODE":
host = self._cluster.metadata.get_host(addr)
self._cluster.scheduler.schedule(1, self.remove_host, host)
elif change_type == "MOVED_NODE":
self._cluster.scheduler.schedule(1, self.refresh_node_list_and_token_map)
def _handle_status_change(self, event):
change_type = event["change_type"]
addr, port = event["address"]
if change_type == "UP":
host = self._cluster.metadata.get_host(addr)
if not host:
self._cluster.scheduler.schedule(1, self.add_host, addr, signal=True)
else:
self._cluster.scheduler.schedule(1, self.on_up, host)
elif change_type == "DOWN":
# Ignore down event. Connection will realize a node is dead quicly
# enough when it writes to it, and there is no point in taking the
# risk of marking the node down mistakenly because we didn't
# receive the event in a timely fashion
pass
def _handle_schema_change(self, event):
change_type, ks, cf = event
if change_type in ("CREATED", "DROPPED"):
if not cf:
self._cluster.executor.submit(self.refresh_schema)
else:
self._cluster.executor.submit(self.refresh_schema, ks)
elif change_type == "UPDATED":
if not cf:
self._cluster.executor.submit(self.refresh_schema, ks)
else:
self._cluster.executor.submit(self.refresh_schema, ks, cf)
def wait_for_schema_agreement(self):
# TODO is returning True/False the best option for this? Potentially raise Exception?
start = time.time()
elapsed = 0
cl = ConsistencyLevel.ONE
while elapsed < MAX_SCHEMA_AGREEMENT_WAIT_MS:
peers_query = QueryMessage(query=self.SELECT_SCHEMA_PEERS, consistency_level=cl)
local_query = QueryMessage(query=self.SELECT_SCHEMA_LOCAL, consistency_level=cl)
peers_result, local_result = self._connection.wait_for_requests(peers_query, local_query)
versions = set()
if local_result and local_result.rows:
local_row = local_result.as_dicts()[0]
if local_row.get("schema_version"):
versions.add(local_row.get("schema_version"))
for row in peers_result.as_dicts():
if not row.get("rpc_address") or not row.get("schema_version"):
continue
rpc = row.get("rpc_address")
if rpc == "0.0.0.0": # TODO ipv6 check
rpc = row.get("peer")
peer = self._cluster.metadata.get_host(rpc)
if peer and peer.monitor.is_up:
versions.add(row.get("schema_version"))
if len(versions) == 1:
return True
time.sleep(0.2)
elapsed = time.time() - start
return False
@property
def is_open(self):
conn = self._connection
return bool(conn and conn.is_open)
def on_up(self, host):
self._balancing_policy.on_up(host)
def on_down(self, host):
self._balancing_policy.on_down(host)
conn = self._connection
if conn and conn.address == host.address: # TODO and reconnection_attempt is None
self.reconnect()
def on_add(self, host):
self._balancing_policy.on_add(host)
self.refresh_node_list_and_token_map()
def on_remove(self, host):
self._balancing_policy.on_remove(host)
self.refresh_node_list_and_token_map()