import logging
import time

import Queue
from threading import Lock, RLock, Thread

from futures import ThreadPoolExecutor

from connection import Connection
from decoder import ConsistencyLevel, QueryMessage
from metadata import Metadata
from policies import (RoundRobinPolicy, SimpleConvictionPolicy,
                      ExponentialReconnectionPolicy, HostDistance)
from query import SimpleStatement
from pool import (ConnectionException, BusyConnectionException,
                  AuthenticationException, _ReconnectionHandler,
                  _HostReconnectionHandler)

log = logging.getLogger(__name__)

# TODO: we might want to make this configurable
MAX_SCHEMA_AGREEMENT_WAIT_MS = 10000

class Session(object):

    def __init__(self, cluster, hosts):
        self.cluster = cluster
        self.hosts = hosts

        self._lock = RLock()
        self._is_shutdown = False
        self._pools = {}
        self._load_balancer = RoundRobinPolicy()

    def execute(self, query):
        # TODO: block on the result; for now this just delegates, since
        # execute_async() already normalizes string queries
        return self.execute_async(query)

    def execute_async(self, query):
        if isinstance(query, basestring):
            query = SimpleStatement(query)

        qmsg = QueryMessage(query=query.query, consistency_level=query.consistency_level)
        return self._execute_query(qmsg, query)

    def prepare(self, query):
        # TODO: not implemented yet
        pass

    def shutdown(self):
        self.cluster.shutdown()

    def _execute_query(self, message, query):
        if query.tracing_enabled:
            # TODO enable tracing on the message
            pass

        errors = {}
        query_plan = self._load_balancer.make_query_plan(query)
        for host in query_plan:
            try:
                result = self._query(host, message)
                if result:
                    return result
            except Exception as exc:
                errors[host] = exc

        raise NoHostAvailable(
            "Unable to complete the operation against any hosts", errors)

    def _query(self, host, message):
        pool = self._pools.get(host)
        if not pool or pool.is_shutdown:
            return False
        # TODO: borrow a connection from the pool and send the message

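# Example (sketch): the intended public API flow, assuming a node listening
# on 127.0.0.1:9042. Session.execute() and _query() are still stubs, so this
# illustrates the design target rather than currently working behavior:
#
#   cluster = Cluster(['127.0.0.1'])
#   session = cluster.connect()
#   session.execute("SELECT cluster_name FROM system.local")
#   cluster.shutdown()
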
DEFAULT_MIN_REQUESTS = 25
DEFAULT_MAX_REQUESTS = 100

DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST = 2
DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST = 8

DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST = 1
DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST = 2

class _Scheduler(object):

    def __init__(self, executor):
        self._scheduled = Queue.PriorityQueue()
        self._executor = executor

        t = Thread(target=self.run, name="Task Scheduler")
        t.daemon = True
        t.start()

    # TODO add a shutdown method to stop processing the queue?

    def schedule(self, delay, fn, *args, **kwargs):
        run_at = time.time() + delay
        self._scheduled.put_nowait((run_at, (fn, args, kwargs)))

    def run(self):
        while True:
            try:
                while True:
                    run_at, task = self._scheduled.get(block=True, timeout=None)
                    if run_at <= time.time():
                        fn, args, kwargs = task
                        self._executor.submit(fn, *args, **kwargs)
                    else:
                        # the earliest task isn't due yet; put it back and
                        # sleep before checking again
                        self._scheduled.put_nowait((run_at, task))
                        break
            except Queue.Empty:
                pass

            time.sleep(0.1)

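# Example (sketch): how delayed tasks run. A callable scheduled with a delay
# is submitted to the executor once its target time has passed
# (`some_callback` is hypothetical):
#
#   executor = ThreadPoolExecutor(max_workers=1)
#   scheduler = _Scheduler(executor)
#   scheduler.schedule(2.0, some_callback)  # runs after roughly two seconds
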
class Cluster(object):

    port = 9042

    auth_provider = None

    load_balancing_policy = None
    reconnection_policy = ExponentialReconnectionPolicy(2 * 1000, 5 * 60 * 1000)
    retry_policy = None

    compression = None
    metrics_enabled = False
    socket_options = None

    conviction_policy_factory = SimpleConvictionPolicy

    def __init__(self, contact_points):
        self.contact_points = contact_points
        self.sessions = set()
        self.metadata = Metadata(self)

        self._min_requests_per_connection = {
            HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
            HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
        }

        self._max_requests_per_connection = {
            HostDistance.LOCAL: DEFAULT_MAX_REQUESTS,
            HostDistance.REMOTE: DEFAULT_MAX_REQUESTS
        }

        self._core_connections_per_host = {
            HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
            HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST
        }

        self._max_connections_per_host = {
            HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
            HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
        }

        # TODO real factory based on config
        self._connection_factory = Connection

        # TODO make the pool size configurable somewhere
        self.executor = ThreadPoolExecutor(max_workers=3)
        self.scheduler = _Scheduler(self.executor)

        self._is_shutdown = False
        self._lock = Lock()

        self._control_connection = ControlConnection(self, self.metadata)
        try:
            self._control_connection.connect()
        except:
            self.shutdown()
            raise

    def get_min_requests_per_connection(self, host_distance):
        return self._min_requests_per_connection[host_distance]

    def set_min_requests_per_connection(self, host_distance, min_requests):
        self._min_requests_per_connection[host_distance] = min_requests

    def get_max_requests_per_connection(self, host_distance):
        return self._max_requests_per_connection[host_distance]

    def set_max_requests_per_connection(self, host_distance, max_requests):
        self._max_requests_per_connection[host_distance] = max_requests

    def get_core_connections_per_host(self, host_distance):
        return self._core_connections_per_host[host_distance]

    def set_core_connections_per_host(self, host_distance, core_connections):
        old = self._core_connections_per_host[host_distance]
        self._core_connections_per_host[host_distance] = core_connections
        if old < core_connections:
            self.ensure_core_connections()

    def get_max_connections_per_host(self, host_distance):
        return self._max_connections_per_host[host_distance]

    def set_max_connections_per_host(self, host_distance, max_connections):
        self._max_connections_per_host[host_distance] = max_connections

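    # Example (sketch): tuning pool sizes for hosts in the local datacenter.
    # Raising the core count above its old value grows live pools via
    # ensure_core_connections():
    #
    #   cluster = Cluster(['10.0.0.1'])
    #   cluster.set_core_connections_per_host(HostDistance.LOCAL, 4)
    #   cluster.set_max_connections_per_host(HostDistance.LOCAL, 16)
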
    def connect(self, keyspace=None):
        # TODO set keyspace if not None
        return self._new_session()

    def shutdown(self):
        with self._lock:
            if self._is_shutdown:
                return
            else:
                self._is_shutdown = True

        self._control_connection.shutdown()

        for session in self.sessions:
            session.shutdown()

        self.executor.shutdown()

    def _new_session(self):
        session = Session(self, self.metadata.all_hosts())
        self.sessions.add(session)
        return session

    def on_up(self, host):
        reconnector = host.get_and_set_reconnection_handler(None)
        if reconnector:
            reconnector.cancel()

        # TODO prepare_all_queries(host)

        self._control_connection.on_up(host)
        for session in self.sessions:
            session.on_up(host)

    def on_down(self, host):
        self._control_connection.on_down(host)
        for session in self.sessions:
            session.on_down(host)

        schedule = self.reconnection_policy.new_schedule()
        reconnector = _HostReconnectionHandler(
            host, self._connection_factory, self.scheduler, schedule,
            callback=host.get_and_set_reconnection_handler,
            callback_kwargs=dict(new_handler=None))

        old_reconnector = host.get_and_set_reconnection_handler(reconnector)
        if old_reconnector:
            old_reconnector.cancel()

        reconnector.start()

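    # Note (assumption): the reconnection policy's new_schedule() is assumed
    # to yield doubling delays from the 2s base up to the 5min cap, e.g.:
    #
    #   schedule = cluster.reconnection_policy.new_schedule()
    #   # -> roughly 2000, 4000, 8000, ..., 300000 (ms); the handler walks
    #   #    this schedule until a reconnect succeeds or it is cancelled
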
    def on_add(self, host):
        # TODO prepare_all_queries(host)
        self._control_connection.on_add(host)
        for session in self.sessions:  # TODO need to copy/lock?
            session.on_add(host)

    def on_remove(self, host):
        self._control_connection.on_remove(host)
        for session in self.sessions:
            session.on_remove(host)

    def add_host(self, address, signal):
        new_host = self.metadata.add_host(address)
        if new_host and signal:
            self.on_add(new_host)
        return new_host

    def remove_host(self, host):
        if host and self.metadata.remove_host(host):
            self.on_remove(host)

    def ensure_core_connections(self):
        for session in self.sessions:
            for pool in session._pools.values():
                pool.ensure_core_connections()


class NoHostAvailable(Exception):
    pass


class _ControlReconnectionHandler(_ReconnectionHandler):

    def __init__(self, control_connection, *args, **kwargs):
        _ReconnectionHandler.__init__(self, *args, **kwargs)
        self.control_connection = control_connection

    def try_reconnect(self):
        return self.control_connection._reconnect_internal()

    def on_reconnection(self, connection):
        self.control_connection._set_new_connection(connection)

    def on_exception(self, exc, next_delay):
        if isinstance(exc, AuthenticationException):
            return False
        else:
            log.warn("Error reconnecting control connection, retrying in %s ms: %s",
                     next_delay, exc)
            return True


class ControlConnection(object):

    _SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces"
    _SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies"
    _SELECT_COLUMNS = "SELECT * FROM system.schema_columns"

    _SELECT_PEERS = "SELECT peer, data_center, rack, tokens, rpc_address FROM system.peers"
    _SELECT_LOCAL = "SELECT cluster_name, data_center, rack, tokens, partitioner FROM system.local WHERE key='local'"

    _SELECT_SCHEMA_PEERS = "SELECT rpc_address, schema_version FROM system.peers"
    _SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'"

    def __init__(self, cluster, metadata):
        self._cluster = cluster
        self._balancing_policy = RoundRobinPolicy()
        self._balancing_policy.populate(cluster, metadata.all_hosts())
        self._reconnection_policy = ExponentialReconnectionPolicy(2 * 1000, 5 * 60 * 1000)
        self._connection = None

        self._lock = RLock()

        self._reconnection_handler = None
        self._reconnection_lock = RLock()

        self._is_shutdown = False

    def connect(self):
        if self._is_shutdown:
            return

        self._set_new_connection(self._reconnect_internal())

    def _set_new_connection(self, conn):
        with self._lock:
            old = self._connection
            self._connection = conn

        if old and not old.is_closed():  # TODO is_closed() may not exist
            old.close()

    def _reconnect_internal(self):
        errors = {}
        for host in self._balancing_policy.make_query_plan():
            try:
                return self._try_connect(host)
            except ConnectionException as exc:
                errors[host.address] = exc
                host.monitor.signal_connection_failure(exc)
            except Exception as exc:
                errors[host.address] = exc

        raise NoHostAvailable("Unable to connect to any servers", errors)

    def _try_connect(self, host):
        connection = self._cluster._connection_factory.open(host)
        connection.register_watchers({
            "TOPOLOGY_CHANGE": self._handle_topology_change,
            "STATUS_CHANGE": self._handle_status_change,
            "SCHEMA_CHANGE": self._handle_schema_change
        })

        self.refresh_node_list_and_token_map()
        self.refresh_schema()
        return connection

    def reconnect(self):
        if self._is_shutdown:
            return

        try:
            self._set_new_connection(self._reconnect_internal())
        except NoHostAvailable:
            schedule = self._reconnection_policy.new_schedule()
            with self._reconnection_lock:
                if self._reconnection_handler:
                    self._reconnection_handler.cancel()
                self._reconnection_handler = _ControlReconnectionHandler(
                    self, self._cluster.scheduler, schedule,
                    callback=self._get_and_set_reconnection_handler,
                    callback_kwargs=dict(new_handler=None))
                self._reconnection_handler.start()

    def _get_and_set_reconnection_handler(self, new_handler):
        with self._reconnection_lock:
            if self._reconnection_handler:
                return self._reconnection_handler
            else:
                self._reconnection_handler = new_handler
                return None

    def shutdown(self):
        self._is_shutdown = True
        with self._lock:
            if self._connection:
                self._connection.close()

    def refresh_schema(self, keyspace=None, table=None):
        where_clause = ""
        if keyspace:
            where_clause = " WHERE keyspace_name = '%s'" % (keyspace,)
            if table:
                where_clause += " AND columnfamily_name = '%s'" % (table,)

        cl = ConsistencyLevel.ONE
        if table:
            ks_query = None
        else:
            ks_query = QueryMessage(query=self._SELECT_KEYSPACES + where_clause, consistency_level=cl)
        cf_query = QueryMessage(query=self._SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl)
        col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl)

        if ks_query:
            ks_result, cf_result, col_result = self._connection.wait_for_requests(ks_query, cf_query, col_query)
        else:
            ks_result = None
            cf_result, col_result = self._connection.wait_for_requests(cf_query, col_query)

        self._cluster.metadata.rebuild_schema(keyspace, table, ks_result, cf_result, col_result)

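    # Example (sketch): with keyspace='ks1' and table='t1', the refresh issues
    #   SELECT * FROM system.schema_columnfamilies
    #       WHERE keyspace_name = 'ks1' AND columnfamily_name = 't1'
    # plus the matching system.schema_columns query; the keyspaces query is
    # skipped because a table-level change can't alter keyspace metadata.
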
    def refresh_node_list_and_token_map(self):
        conn = self._connection
        if not conn:
            return

        cl = ConsistencyLevel.ONE
        peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=cl)
        local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=cl)
        try:
            peers_result, local_result = conn.wait_for_requests(peers_query, local_query)
        except (ConnectionException, BusyConnectionException):
            self.reconnect()
            return

        partitioner = None
        token_map = {}

        if local_result and local_result.rows:  # TODO: probably just check local_result.rows
            local_row = local_result.as_dicts()[0]
            cluster_name = local_row["cluster_name"]
            self._cluster.metadata.cluster_name = cluster_name

            host = self._cluster.metadata.get_host(conn.host)
            if host:
                host.set_location_info(local_row["data_center"], local_row["rack"])

            partitioner = local_row.get("partitioner")
            tokens = local_row.get("tokens")
            if partitioner and tokens:
                token_map[host] = tokens

        found_hosts = set()

        for row in peers_result.as_dicts():
            addr = row.get("rpc_address")
            if not addr or addr == "0.0.0.0":  # TODO handle ipv6 equivalent
                addr = row.get("peer")

            found_hosts.add(addr)

            host = self._cluster.metadata.get_host(addr)
            if host is None:
                host = self._cluster.add_host(addr, signal=True)
            host.set_location_info(row.get("data_center"), row.get("rack"))

            tokens = row.get("tokens")
            if partitioner and tokens:
                token_map[host] = tokens

        for old_host in self._cluster.metadata.all_hosts():
            if old_host.address != conn.host and \
                    old_host.address not in found_hosts:
                self._cluster.remove_host(old_host)

        if partitioner:
            self._cluster.metadata.rebuild_token_map(partitioner, token_map)

    def _handle_topology_change(self, event):
        # TODO schedule on executor
        change_type = event["change_type"]
        addr, port = event["address"]
        if change_type == "NEW_NODE":
            # TODO check Host constructor
            self._cluster.scheduler.schedule(1, self._cluster.add_host, addr, signal=True)
        elif change_type == "REMOVED_NODE":
            host = self._cluster.metadata.get_host(addr)
            self._cluster.scheduler.schedule(1, self._cluster.remove_host, host)
        elif change_type == "MOVED_NODE":
            self._cluster.scheduler.schedule(1, self.refresh_node_list_and_token_map)

    def _handle_status_change(self, event):
        change_type = event["change_type"]
        addr, port = event["address"]
        if change_type == "UP":
            host = self._cluster.metadata.get_host(addr)
            if not host:
                self._cluster.scheduler.schedule(1, self._cluster.add_host, addr, signal=True)
            else:
                self._cluster.scheduler.schedule(1, self._cluster.on_up, host)
        elif change_type == "DOWN":
            # Ignore the down event. The connection will realize a node is
            # dead quickly enough when it writes to it, and there is no point
            # in taking the risk of marking the node down mistakenly because
            # we didn't receive the event in a timely fashion.
            pass

    def _handle_schema_change(self, event):
        # assumed keys, matching the dict-style events the other handlers use
        change_type = event["change_type"]
        ks = event["keyspace"]
        cf = event["table"]
        if change_type in ("CREATED", "DROPPED"):
            if not cf:
                self._cluster.executor.submit(self.refresh_schema)
            else:
                self._cluster.executor.submit(self.refresh_schema, ks)
        elif change_type == "UPDATED":
            if not cf:
                self._cluster.executor.submit(self.refresh_schema, ks)
            else:
                self._cluster.executor.submit(self.refresh_schema, ks, cf)

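    # Example (sketch): event payload shapes implied by the handlers above;
    # the keys are inferred from this module's usage, not from a protocol
    # spec:
    #
    #   {"change_type": "NEW_NODE", "address": ("192.168.1.5", 9042)}
    #   {"change_type": "UP", "address": ("192.168.1.5", 9042)}
    #   {"change_type": "UPDATED", "keyspace": "ks1", "table": "t1"}
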
    def wait_for_schema_agreement(self):
        # TODO is returning True/False the best option for this? Potentially raise Exception?
        start = time.time()
        elapsed = 0
        cl = ConsistencyLevel.ONE
        while elapsed < MAX_SCHEMA_AGREEMENT_WAIT_MS:
            peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl)
            local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl)
            peers_result, local_result = self._connection.wait_for_requests(peers_query, local_query)

            versions = set()
            if local_result and local_result.rows:
                local_row = local_result.as_dicts()[0]
                if local_row.get("schema_version"):
                    versions.add(local_row.get("schema_version"))

            for row in peers_result.as_dicts():
                if not row.get("rpc_address") or not row.get("schema_version"):
                    continue

                rpc = row.get("rpc_address")
                if rpc == "0.0.0.0":  # TODO ipv6 check
                    rpc = row.get("peer")

                peer = self._cluster.metadata.get_host(rpc)
                if peer and peer.monitor.is_up:
                    versions.add(row.get("schema_version"))

            if len(versions) == 1:
                return True

            time.sleep(0.2)
            # the wait cap is in milliseconds, so convert the elapsed time
            elapsed = (time.time() - start) * 1000

        return False

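    # Example (sketch): intended use after a schema-altering statement; poll
    # until every live host reports one schema_version or the wait cap
    # elapses:
    #
    #   if not control_connection.wait_for_schema_agreement():
    #       log.warn("Schema still not in agreement after %d ms",
    #                MAX_SCHEMA_AGREEMENT_WAIT_MS)
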
    @property
    def is_open(self):
        conn = self._connection
        return bool(conn and conn.is_open)

    def on_up(self, host):
        self._balancing_policy.on_up(host)

    def on_down(self, host):
        self._balancing_policy.on_down(host)

        conn = self._connection
        if conn and conn.host == host.address:  # TODO and reconnection_attempt is None
            self.reconnect()

    def on_add(self, host):
        self._balancing_policy.on_add(host)
        self.refresh_node_list_and_token_map()

    def on_remove(self, host):
        self._balancing_policy.on_remove(host)
        self.refresh_node_list_and_token_map()