789 lines
		
	
	
		
			27 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			789 lines
		
	
	
		
			27 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import time
 | |
| from threading import Lock, RLock, Thread, Event
 | |
| import Queue
 | |
| import logging
 | |
| import traceback
 | |
| 
 | |
| from futures import ThreadPoolExecutor
 | |
| 
 | |
| from connection import Connection
 | |
| from decoder import (ConsistencyLevel, QueryMessage, ResultMessage, ErrorMessage,
 | |
|                      ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, UnavailableExceptionErrorMessage,
 | |
|                      OverloadedErrorMessage, IsBootstrappingErrorMessage)
 | |
| from metadata import Metadata
 | |
| from policies import (RoundRobinPolicy, SimpleConvictionPolicy,
 | |
|                       ExponentialReconnectionPolicy, HostDistance, RetryPolicy)
 | |
| from query import SimpleStatement
 | |
| from pool import (ConnectionException, BusyConnectionException,
 | |
|                   AuthenticationException, _ReconnectionHandler,
 | |
|                   _HostReconnectionHandler, HostConnectionPool)
 | |
| 
 | |
# TODO: we might want to make this configurable
# Upper bound (the name indicates milliseconds) on the polling loop in
# ControlConnection.wait_for_schema_agreement().
MAX_SCHEMA_AGREEMENT_WAIT_MS = 10000

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
class ResponseFuture(object):
    """
    Tracks one request issued through a session and its eventual outcome.

    The future walks the query plan produced by the session's load balancer,
    sends the message on a connection borrowed from the first usable host
    pool, and records either a final result or a final exception once a
    response has been handled.  Callers either block on :meth:`deliver` or
    register callbacks with :meth:`addCallback`, :meth:`addErrback` or
    :meth:`addCallbacks`.
    """

    def __init__(self, session, message, query):
        """
        :param session: object exposing ``_load_balancer``, ``_pools`` and
            ``submit()``, all of which are used by the methods below.
        :param message: the protocol message to send.
        :param query: the statement being executed; it is handed to the load
            balancer and to the retry policy callbacks.
        """
        self.session = session
        self.message = message
        # Kept because _set_result() consults query.retry_policy.
        self.query = query

        self.query_plan = session._load_balancer.make_query_plan(query)

        self._req_id = None
        self._final_result = None
        self._final_exception = None
        self._current_host = None
        self._current_pool = None
        self._connection = None
        self._event = Event()
        # host -> exception raised while trying to query that host
        self._errors = {}
        self._query_retries = 0
        # each is None or a (fn, args, kwargs) triple
        self._callback = self._errback = None

    def send_request(self):
        """
        Send the message to the next host in the query plan that accepts it.

        When no remaining host accepts the request, the future is completed
        with ``NoHostAvailable`` carrying the per-host errors collected so
        far, which wakes any waiter and fires the errback.
        """
        for host in self.query_plan:
            req_id = self._query(host)
            # _query() signals failure with None; any other value (including
            # a request id of 0) means the message was sent.
            if req_id is not None:
                self._req_id = req_id
                return

        self._set_final_exception(NoHostAvailable(self._errors))

    def _query(self, host):
        """
        Try to send the message to ``host``.

        :returns: the value returned by ``connection.send_msg()`` on success,
            or ``None`` when the host has no usable pool or sending failed
            (the exception is recorded in ``self._errors[host]``).
        """
        pool = self.session._pools.get(host)
        if not pool or pool.is_shutdown:
            return None

        connection = None
        try:
            # TODO get connectTimeout from cluster settings
            connection = pool.borrow_connection(timeout=2.0)
            request_id = connection.send_msg(self.message, cb=self._set_result)
            self._current_host = host
            self._current_pool = pool
            self._connection = connection
            return request_id
        except Exception as exc:
            self._errors[host] = exc
            # Only hand the connection back if we actually borrowed one.
            if connection:
                pool.return_connection(connection)
            return None

    def _set_result(self, response):
        """
        Callback passed to ``connection.send_msg()``; handles the response.

        The borrowed connection is returned to its pool first.  Non-error
        responses complete the future.  Read-timeout, write-timeout and
        unavailable errors are passed to the query's retry policy, whose
        ``(decision, consistency)`` answer selects between retrying,
        delivering the error response, or delivering ``None``.  Overloaded
        and bootstrapping errors are retried against another host.  Any other
        error response is delivered as the final result.
        """
        self._current_pool.return_connection(self._connection)

        if isinstance(response, ResultMessage):
            self._set_final_result(response)
            return
        if not isinstance(response, ErrorMessage):
            # we got some other kind of response message
            self._set_final_result(response)
            return

        # TODO also check manager.configuration for retry policy if None
        retry_policy = self.query.retry_policy

        if isinstance(response, ReadTimeoutErrorMessage):
            decide = retry_policy.on_read_timeout
        elif isinstance(response, WriteTimeoutErrorMessage):
            decide = retry_policy.on_write_timeout
        elif isinstance(response, UnavailableExceptionErrorMessage):
            # NOTE(review): assumes the retry policy exposes on_unavailable();
            # confirm against policies.RetryPolicy.
            decide = retry_policy.on_unavailable
        elif isinstance(response, (OverloadedErrorMessage,
                                   IsBootstrappingErrorMessage)):
            # need to retry against a different host here; the retry policy
            # is not consulted for these errors
            self._retry(False, None)
            return
        else:
            # TODO need to define the PreparedQueryNotFound class and handle
            # it here.  Until then, hand unrecognised errors to the caller
            # instead of leaving the future unresolved.
            self._set_final_result(response)
            return

        details = response.recv_error_info()
        retry_type, consistency = decide(
            self.query, attempt_num=self._query_retries, **details)

        if retry_type == RetryPolicy.RETRY:
            self._query_retries += 1
            self._retry(True, consistency)
        elif retry_type == RetryPolicy.RETHROW:
            self._set_final_result(response)
        else:  # IGNORE
            self._set_final_result(None)

    def _set_final_result(self, response):
        """Record ``response`` as the result, wake waiters, fire the callback."""
        self._final_result = response
        self._event.set()
        if self._callback:
            fn, args, kwargs = self._callback
            fn(response, *args, **kwargs)

    def _set_final_exception(self, response):
        """Record ``response`` as the failure, wake waiters, fire the errback."""
        self._final_exception = response
        self._event.set()
        if self._errback:
            fn, args, kwargs = self._errback
            fn(response, *args, **kwargs)

    def _retry(self, reuse_connection, consistency_level):
        """
        Schedule another attempt on the session's executor.

        :param reuse_connection: when true, the host used for the previous
            attempt is tried first.
        :param consistency_level: consistency level for the new attempt;
            ``None`` leaves the message's current level untouched.
        """
        if consistency_level is not None:
            self.message.consistency_level = consistency_level
        # don't retry on the event loop thread
        self.session.submit(self._retry_task, reuse_connection)

    def _retry_task(self, reuse_connection):
        """Executor task: retry on the previous host if asked, else move on."""
        if reuse_connection:
            req_id = self._query(self._current_host)
            if req_id is not None:
                self._req_id = req_id
                return

        # otherwise, move onto another host
        self.send_request()

    def deliver(self):
        """
        Block until the request has completed and return its result.

        :raises: the stored final exception, if the request failed.
        """
        if self._final_exception is None and not self._final_result:
            self._event.wait()

        if self._final_exception is not None:
            raise self._final_exception
        return self._final_result

    def addCallback(self, fn, *args, **kwargs):
        """
        Register ``fn(result, *args, **kwargs)`` to run on completion.

        Replaces any previously registered callback.  Returns ``self``.
        """
        self._callback = (fn, args, kwargs)
        return self

    def addErrback(self, fn, *args, **kwargs):
        """
        Register ``fn(exception, *args, **kwargs)`` to run on failure.

        Replaces any previously registered errback.  Returns ``self``.
        """
        self._errback = (fn, args, kwargs)
        return self

    def addCallbacks(self,
            callback, errback,
            callback_args=(), callback_kwargs=None,
            errback_args=(), errback_kwargs=None):
        """
        Register a callback and an errback in a single call.

        ``callback_kwargs`` and ``errback_kwargs`` default to empty dicts.
        Returns ``self``, like :meth:`addCallback` and :meth:`addErrback`.
        """
        self._callback = (callback, callback_args, callback_kwargs or {})
        self._errback = (errback, errback_args, errback_kwargs or {})
        return self
 | |
| 
 | |
| 
 | |
class Session(object):
    """
    Executes queries against a cluster through per-host connection pools.

    A session keeps one ``HostConnectionPool`` per non-ignored host and a
    round-robin load balancer that orders hosts for each request.
    """

    def __init__(self, cluster, hosts):
        """
        :param cluster: the owning cluster; its ``executor`` and
            ``metadata`` are used by this class.
        :param hosts: the hosts to populate the load balancer with and to
            open pools for.
        """
        self.cluster = cluster
        self.hosts = hosts

        self._lock = RLock()
        self._is_shutdown = False
        self._pools = {}
        self._load_balancer = RoundRobinPolicy()
        self._load_balancer.populate(cluster, hosts)

        for host in hosts:
            self.add_host(host)

    def execute(self, query):
        """Run ``query`` and block until its result is available."""
        future = self.execute_async(query)
        return future.deliver()

    def execute_async(self, query):
        """
        Send ``query`` and return a :class:`ResponseFuture` for it.

        :param query: a statement object or a plain string; strings are
            wrapped in ``SimpleStatement``.
        """
        if isinstance(query, basestring):
            query = SimpleStatement(query)

        # TODO bound statements need to be handled differently
        message = QueryMessage(query=query.query_string, consistency_level=query.consistency_level)

        if query.tracing_enabled:
            # TODO enable tracing on the message
            pass

        future = ResponseFuture(self, message, query)
        future.send_request()
        return future

    def prepare(self, query):
        """Not implemented yet."""
        pass

    def shutdown(self):
        """
        Shut down this session's pools, then the owning cluster.

        Safe to call more than once; the cluster calls back into this method
        for every session it owns, and those repeat calls are no-ops.
        """
        with self._lock:
            if self._is_shutdown:
                return
            self._is_shutdown = True

        for pool in list(self._pools.values()):
            pool.shutdown()

        self.cluster.shutdown()

    def add_host(self, host):
        """
        Open a connection pool for ``host`` unless its distance is IGNORED.

        :returns: the pool previously registered for ``host`` (possibly
            ``None``).  When no new pool was created, because the host is
            ignored or the connection attempt failed, this is the pool that
            is still registered.
        """
        distance = self._load_balancer.distance(host)
        if distance == HostDistance.IGNORED:
            return self._pools.get(host)

        try:
            new_pool = HostConnectionPool(host, distance, self)
        except AuthenticationException as auth_exc:
            # Reported to the host monitor as a plain connection failure.
            conn_exc = ConnectionException(str(auth_exc), host=host)
            host.monitor.signal_connection_failure(conn_exc)
            return self._pools.get(host)
        except ConnectionException as conn_exc:
            host.monitor.signal_connection_failure(conn_exc)
            return self._pools.get(host)

        previous = self._pools.get(host)
        self._pools[host] = new_pool
        return previous

    def on_up(self, host):
        """Handle ``host`` coming back up: (re)create its pool."""
        previous_pool = self.add_host(host)
        self._load_balancer.on_up(host)
        if previous_pool:
            previous_pool.shutdown()

    def on_down(self, host):
        """
        Handle ``host`` going down.

        The host's pool is shut down.  Then every host that is up and not
        ignored either gets a pool (if it had none) or has its pool's
        distance refreshed.
        """
        self._load_balancer.on_down(host)
        pool = self._pools.pop(host, None)
        if pool:
            pool.shutdown()

        for other in self.cluster.metadata.all_hosts():
            if not other.monitor.is_up:
                continue

            distance = self._load_balancer.distance(other)
            if distance != HostDistance.IGNORED:
                other_pool = self._pools.get(other)
                if not other_pool:
                    self.add_host(other)
                else:
                    other_pool.host_distance = distance

    def on_add(self, host):
        """Handle a newly added ``host``: create its pool."""
        previous_pool = self.add_host(host)
        self._load_balancer.on_add(host)
        if previous_pool:
            previous_pool.shutdown()

    def on_remove(self, host):
        """Handle removal of ``host``: drop and shut down its pool, if any."""
        self._load_balancer.on_remove(host)
        # The host may never have had a pool (ignored, or connection failed).
        pool = self._pools.pop(host, None)
        if pool:
            pool.shutdown()

    def set_keyspace(self, keyspace):
        """Not implemented yet."""
        pass

    def submit(self, fn, *args, **kwargs):
        """Run ``fn(*args, **kwargs)`` on the cluster's executor."""
        return self.cluster.executor.submit(fn, *args, **kwargs)
 | |
| 
 | |
# Defaults installed by Cluster.__init__ for both LOCAL and REMOTE hosts as
# the per-connection request thresholds.
DEFAULT_MIN_REQUESTS = 25
DEFAULT_MAX_REQUESTS = 100

# Default core/max connection counts for hosts at HostDistance.LOCAL.
DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST = 2
DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST = 8

# Default core/max connection counts for hosts at HostDistance.REMOTE.
DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST = 1
DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST = 2
 | |
| 
 | |
class _Scheduler(object):
    """
    Runs callables on an executor after a delay.

    Tasks are kept in a priority queue ordered by their due time.  A daemon
    thread started in ``__init__`` submits due tasks to the executor.
    """

    def __init__(self, executor):
        """
        :param executor: object with a ``submit(fn, *args, **kwargs)`` method
            that due tasks are handed to.
        """
        self._scheduled = Queue.PriorityQueue()
        self._executor = executor

        t = Thread(target=self.run, name="Task Scheduler")
        t.daemon = True
        t.start()

    # TODO add a shutdown method to stop processing the queue?

    def schedule(self, delay, fn, *args, **kwargs):
        """Arrange for ``fn(*args, **kwargs)`` to run ``delay`` seconds from now."""
        run_at = time.time() + delay
        self._scheduled.put_nowait((run_at, (fn, args, kwargs)))

    def run(self):
        """
        Scheduler thread body; never returns.

        Repeatedly takes the earliest task.  Due tasks are submitted to the
        executor.  A task that is not due yet is put back, and the thread
        sleeps briefly before looking again.
        """
        while True:
            try:
                while True:
                    run_at, task = self._scheduled.get(block=True, timeout=None)
                    if run_at <= time.time():
                        fn, args, kwargs = task
                        self._executor.submit(fn, *args, **kwargs)
                    else:
                        # Not due yet: re-queue it and back off.
                        self._scheduled.put_nowait((run_at, task))
                        break
            except Queue.Empty:
                pass

            time.sleep(0.1)
 | |
| 
 | |
| 
 | |
class Cluster(object):
    """
    Entry point of the driver: owns cluster metadata, the control connection,
    the shared executor/scheduler, and every :class:`Session` it creates.
    """

    # Port used for connections.
    port = 9042

    auth_provider = None

    load_balancing_policy = None
    # Produces reconnection schedules for hosts that go down (see on_down()).
    reconnection_policy = ExponentialReconnectionPolicy(2 * 1000, 5 * 60 * 1000)
    retry_policy = None

    compression = None
    metrics_enabled = False
    socket_options = None

    conviction_policy_factory = SimpleConvictionPolicy

    def __init__(self, contact_points):
        """
        Register the contact points and open the control connection.

        :param contact_points: iterable of host addresses used to bootstrap
            the cluster metadata.
        :raises: whatever the control connection raises when it cannot
            connect; the cluster is shut down before the error propagates.
        """
        self.contact_points = contact_points
        self.sessions = set()
        self.metadata = Metadata(self)

        self._min_requests_per_connection = {
            HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
            HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
        }

        self._max_requests_per_connection = {
            HostDistance.LOCAL: DEFAULT_MAX_REQUESTS,
            HostDistance.REMOTE: DEFAULT_MAX_REQUESTS
        }

        self._core_connections_per_host = {
            HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
            HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST
        }

        self._max_connections_per_host = {
            HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
            HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
        }

        # TODO real factory based on config
        self.connection_factory = Connection.factory

        # TODO make the pool size configurable somewhere
        self.executor = ThreadPoolExecutor(max_workers=3)
        self.scheduler = _Scheduler(self.executor)

        self._is_shutdown = False
        self._lock = Lock()

        for address in contact_points:
            self.add_host(address, signal=False)

        self._control_connection = ControlConnection(self, self.metadata)
        try:
            self._control_connection.connect()
        except:
            # Release the executor and control connection, then re-raise
            # whatever went wrong (including KeyboardInterrupt).
            self.shutdown()
            raise

    def get_min_requests_per_connection(self, host_distance):
        """Return the minimum-requests threshold for ``host_distance``."""
        return self._min_requests_per_connection[host_distance]

    def set_min_requests_per_connection(self, host_distance, min_requests):
        """Set the minimum-requests threshold for ``host_distance``."""
        self._min_requests_per_connection[host_distance] = min_requests

    def get_max_requests_per_connection(self, host_distance):
        """Return the maximum-requests threshold for ``host_distance``."""
        return self._max_requests_per_connection[host_distance]

    def set_max_requests_per_connection(self, host_distance, max_requests):
        """Set the maximum-requests threshold for ``host_distance``."""
        self._max_requests_per_connection[host_distance] = max_requests

    def get_core_connections_per_host(self, host_distance):
        """Return the core connection count for ``host_distance``."""
        return self._core_connections_per_host[host_distance]

    def set_core_connections_per_host(self, host_distance, core_connections):
        """
        Set the core connection count for ``host_distance``.

        When the count is raised, every pool is asked to open the additional
        core connections right away.
        """
        old = self._core_connections_per_host[host_distance]
        self._core_connections_per_host[host_distance] = core_connections
        if old < core_connections:
            self.ensure_core_connections()

    def get_max_connections_per_host(self, host_distance):
        """Return the maximum connection count for ``host_distance``."""
        return self._max_connections_per_host[host_distance]

    def set_max_connections_per_host(self, host_distance, max_connections):
        """Set the maximum connection count for ``host_distance``."""
        self._max_connections_per_host[host_distance] = max_connections

    def connect(self, keyspace=None):
        """Create and return a new :class:`Session` (``keyspace`` is unused so far)."""
        # TODO set keyspace if not None
        return self._new_session()

    def shutdown(self):
        """
        Shut down the control connection, all sessions and the executor.

        Only the first call does any work; later calls return immediately.
        """
        with self._lock:
            if self._is_shutdown:
                return
            else:
                self._is_shutdown = True

        self._control_connection.shutdown()

        for session in self.sessions:
            session.shutdown()

        self.executor.shutdown()

    def _new_session(self):
        """Create a session over all known hosts and remember it."""
        session = Session(self, self.metadata.all_hosts())
        self.sessions.add(session)
        return session

    def on_up(self, host):
        """Cancel any pending reconnection for ``host`` and notify listeners."""
        reconnector = host.get_and_set_reconnection_handler(None)
        if reconnector:
            reconnector.cancel()

        # TODO prepareAllQueries(host)

        self._control_connection.on_up(host)
        for session in self.sessions:
            session.on_up(host)

    def on_down(self, host):
        """
        Notify listeners that ``host`` is down and start reconnecting to it.

        Any reconnection handler already attached to the host is cancelled
        and replaced by a fresh one built from ``reconnection_policy``.
        """
        self._control_connection.on_down(host)
        for session in self.sessions:
            session.on_down(host)

        schedule = self.reconnection_policy.new_schedule()
        reconnector = _HostReconnectionHandler(
            host, self.connection_factory, self.scheduler, schedule,
            callback=host.get_and_set_reconnection_handler,
            callback_kwargs=dict(new_handler=None))

        old_reconnector = host.get_and_set_reconnection_handler(reconnector)
        if old_reconnector:
            old_reconnector.cancel()

        reconnector.start()

    def on_add(self, host):
        """Notify the control connection and all sessions of a new ``host``."""
        # TODO prepareAllQueries(host)

        self._control_connection.on_add(host)
        for session in self.sessions:  # TODO need to copy/lock?
            session.on_add(host)

    def on_remove(self, host):
        """Notify the control connection and all sessions that ``host`` is gone."""
        self._control_connection.on_remove(host)
        for session in self.sessions:
            session.on_remove(host)

    def add_host(self, address, signal):
        """
        Add ``address`` to the cluster metadata.

        :param signal: when true and the metadata returned a host,
            :meth:`on_add` is invoked for it.
        :returns: whatever ``metadata.add_host()`` returned.
        """
        new_host = self.metadata.add_host(address)
        if new_host and signal:
            self.on_add(new_host)
        return new_host

    def remove_host(self, host):
        """Remove ``host`` from the metadata; call :meth:`on_remove` if that succeeded."""
        if host and self.metadata.remove_host(host):
            self.on_remove(host)

    def ensure_core_connections(self):
        """Ask every pool of every session to open its core connections."""
        for session in self.sessions:
            for pool in session._pools.values():
                pool.ensure_core_connections()
 | |
| 
 | |
| 
 | |
class NoHostAvailable(Exception):
    """Raised when a request could not be sent to any host."""
 | |
| 
 | |
| 
 | |
class _ControlReconnectionHandler(_ReconnectionHandler):
    """Reconnection handler that re-establishes the control connection."""

    def __init__(self, control_connection, *args, **kwargs):
        """
        :param control_connection: the ``ControlConnection`` to reconnect;
            all remaining arguments go to ``_ReconnectionHandler``.
        """
        _ReconnectionHandler.__init__(self, *args, **kwargs)
        self.control_connection = control_connection

    def try_reconnect(self):
        """Attempt a reconnection via the control connection."""
        new_connection = self.control_connection._reconnect_internal()
        return new_connection

    def on_reconnection(self, connection):
        """Install the freshly opened ``connection`` on the control connection."""
        self.control_connection._set_new_connection(connection)

    def on_exception(self, exc, next_delay):
        """Return ``False`` for authentication failures, ``True`` otherwise."""
        # TODO only overridden to add logging, so add logging
        return not isinstance(exc, AuthenticationException)
 | |
| 
 | |
| 
 | |
class ControlConnection(object):
    """
    Maintains one connection used for cluster bookkeeping.

    Through that connection it registers for topology, status and schema
    events, and it queries the system tables to keep the cluster metadata
    (hosts, token map, schema) up to date.
    """

    _SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces"
    _SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies"
    _SELECT_COLUMNS = "SELECT * FROM system.schema_columns"

    _SELECT_PEERS = "SELECT peer, data_center, rack, tokens, rpc_address FROM system.peers"
    _SELECT_LOCAL = "SELECT cluster_name, data_center, rack, tokens, partitioner FROM system.local WHERE key='local'"

    # "peer" is selected because wait_for_schema_agreement() falls back to it
    # when rpc_address is 0.0.0.0.
    _SELECT_SCHEMA_PEERS = "SELECT peer, rpc_address, schema_version FROM system.peers"
    _SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'"

    def __init__(self, cluster, metadata):
        """
        :param cluster: the owning cluster; its ``connection_factory``,
            ``scheduler``, ``executor`` and ``metadata`` are used below.
        :param metadata: metadata whose hosts seed the balancing policy.
        """
        self._cluster = cluster
        self._balancing_policy = RoundRobinPolicy()
        self._balancing_policy.populate(cluster, metadata.all_hosts())
        self._reconnection_policy = ExponentialReconnectionPolicy(2 * 1000, 5 * 60 * 1000)
        self._connection = None

        # guards self._connection
        self._lock = RLock()

        self._reconnection_handler = None
        # guards self._reconnection_handler
        self._reconnection_lock = RLock()

        self._is_shutdown = False

    def connect(self):
        """
        Open the control connection; a no-op after shutdown.

        :raises NoHostAvailable: when no host could be connected to.
        """
        if self._is_shutdown:
            return

        self._set_new_connection(self._reconnect_internal())

    def _set_new_connection(self, conn):
        """Install ``conn`` as the current connection and close the old one."""
        with self._lock:
            old = self._connection
            self._connection = conn

        if old and not old.is_closed():  # TODO is_closed() may not exist
            old.close()

    def _reconnect_internal(self):
        """
        Try each host in the balancing policy's query plan in turn.

        :returns: the first connection that was fully set up.
        :raises NoHostAvailable: when every host failed; the per-address
            errors are passed as the second argument.
        """
        errors = {}
        for host in self._balancing_policy.make_query_plan():
            try:
                return self._try_connect(host)
            except ConnectionException as exc:
                errors[host.address] = exc
                host.monitor.signal_connection_failure(exc)
                log.error("Error reconnecting control connection: %s", traceback.format_exc())
            except Exception as exc:
                errors[host.address] = exc
                log.error("Error reconnecting control connection: %s", traceback.format_exc())

        raise NoHostAvailable("Unable to connect to any servers", errors)

    def _try_connect(self, host):
        """
        Open a connection to ``host``, register the event watchers and
        refresh the node list and schema through that new connection.

        The connection is closed again if any of those steps fails, and the
        error propagates to the caller.
        """
        connection = self._cluster.connection_factory(host.address)

        try:
            connection.register_watchers({
                "TOPOLOGY_CHANGE": self._handle_topology_change,
                "STATUS_CHANGE": self._handle_status_change,
                "SCHEMA_CHANGE": self._handle_schema_change
            })

            # Use the connection being established: self._connection is still
            # unset (first connect) or points at the old connection.
            self._refresh_node_list_and_token_map(connection)
            self._refresh_schema(connection)
        except Exception:
            connection.close()
            raise

        return connection

    def reconnect(self):
        """
        Re-establish the control connection; a no-op after shutdown.

        If no host is reachable right now, a ``_ControlReconnectionHandler``
        is started on a fresh schedule, replacing any handler already running.
        """
        if self._is_shutdown:
            return

        try:
            self._set_new_connection(self._reconnect_internal())
        except NoHostAvailable:
            schedule = self._reconnection_policy.new_schedule()
            with self._reconnection_lock:
                if self._reconnection_handler:
                    self._reconnection_handler.cancel()
                self._reconnection_handler = _ControlReconnectionHandler(
                    self, self._cluster.scheduler, schedule,
                    callback=self._get_and_set_reconnection_handler,
                    callback_kwargs=dict(new_handler=None))
                self._reconnection_handler.start()

    def _get_and_set_reconnection_handler(self, new_handler):
        """
        Swap in ``new_handler`` and return the handler it replaced.

        Passed to the reconnection handler as its callback with
        ``new_handler=None``, which clears the stored handler.
        """
        with self._reconnection_lock:
            old = self._reconnection_handler
            self._reconnection_handler = new_handler
            return old

    def shutdown(self):
        """Mark the control connection as shut down and close its connection."""
        self._is_shutdown = True
        with self._lock:
            if self._connection:
                self._connection.close()

    def refresh_schema(self, keyspace=None, table=None):
        """Refresh schema metadata through the current connection."""
        return self._refresh_schema(self._connection, keyspace, table)

    def _refresh_schema(self, connection, keyspace=None, table=None):
        """
        Query the schema tables over ``connection`` and rebuild the metadata.

        :param keyspace: restrict the refresh to this keyspace, if given.
        :param table: restrict the refresh to this table; only applied when
            ``keyspace`` is also given.  When set, the keyspaces table is not
            queried.
        """
        where_clause = ""
        if keyspace:
            where_clause = " WHERE keyspace_name = '%s'" % (keyspace,)
            if table:
                where_clause += " AND columnfamily_name = '%s'" % (table,)

        cl = ConsistencyLevel.ONE
        if table:
            ks_query = None
        else:
            ks_query = QueryMessage(query=self._SELECT_KEYSPACES + where_clause, consistency_level=cl)
        cf_query = QueryMessage(query=self._SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl)
        col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl)

        if ks_query:
            ks_result, cf_result, col_result = connection.wait_for_responses(ks_query, cf_query, col_query)
        else:
            ks_result = None
            cf_result, col_result = connection.wait_for_responses(cf_query, col_query)

        self._cluster.metadata.rebuild_schema(keyspace, table, ks_result, cf_result, col_result)

    def refresh_node_list_and_token_map(self):
        """
        Refresh hosts and the token map through the current connection.

        Does nothing when there is no connection.  If the queries fail with a
        connection error, a reconnect is triggered and the refresh is
        abandoned.
        """
        conn = self._connection
        if not conn:
            return

        try:
            peers_result, local_result = self._query_peers_and_local(conn)
        except (ConnectionException, BusyConnectionException):
            self.reconnect()
            return

        self._apply_node_list(conn, peers_result, local_result)

    def _refresh_node_list_and_token_map(self, connection):
        """Refresh hosts and the token map over ``connection``; errors propagate."""
        peers_result, local_result = self._query_peers_and_local(connection)
        self._apply_node_list(connection, peers_result, local_result)

    def _query_peers_and_local(self, connection):
        """Run the peers and local queries; return ``(peers_result, local_result)``."""
        cl = ConsistencyLevel.ONE
        peers_query = QueryMessage(query=self._SELECT_PEERS, consistency_level=cl)
        local_query = QueryMessage(query=self._SELECT_LOCAL, consistency_level=cl)
        return connection.wait_for_responses(peers_query, local_query)

    def _apply_node_list(self, conn, peers_result, local_result):
        """Update cluster metadata from the peers/local query results."""
        metadata = self._cluster.metadata
        partitioner = None
        token_map = {}

        if local_result and local_result.rows:  # TODO: probably just check local_result.rows
            local_row = local_result.as_dicts()[0]
            metadata.cluster_name = local_row["cluster_name"]

            # NOTE(review): this reads conn.host while the code below and
            # on_down() read conn.address -- confirm which attribute the
            # connection class provides.
            host = metadata.get_host(conn.host)
            if host:
                host.set_location_info(local_row["data_center"], local_row["rack"])

            partitioner = local_row.get("partitioner")
            tokens = local_row.get("tokens")
            # Skip the token map entry when the local host is unknown.
            if host and partitioner and tokens:
                token_map[host] = tokens

        found_hosts = set()

        for row in peers_result.as_dicts():
            addr = row.get("rpc_address")
            if not addr or addr == "0.0.0.0":  # TODO handle ipv6 equivalent
                addr = row.get("peer")

            found_hosts.add(addr)

            host = metadata.get_host(addr)
            if host is None:
                host = self._cluster.add_host(addr, signal=True)
                if not host:
                    # add_host() returned nothing to update.
                    continue
            host.set_location_info(row.get("data_center"), row.get("rack"))

            tokens = row.get("tokens")
            if partitioner and tokens:
                token_map[host] = tokens

        # Iterate over a copy in case remove_host() changes the collection
        # that all_hosts() returned.
        for old_host in list(metadata.all_hosts()):
            if old_host.address != conn.address and \
                    old_host.address not in found_hosts:
                self._cluster.remove_host(old_host)

        if partitioner:
            metadata.rebuild_token_map(partitioner, token_map)

    def _handle_topology_change(self, event):
        """
        Watcher for TOPOLOGY_CHANGE events.

        ``event`` is indexed with ``"change_type"`` and ``"address"`` (an
        ``(addr, port)`` pair).  The resulting work is scheduled to run one
        second later.
        """
        # TODO schedule on executor
        change_type = event["change_type"]
        addr, port = event["address"]
        if change_type == "NEW_NODE":
            # TODO check Host constructor
            self._cluster.scheduler.schedule(1, self._cluster.add_host, addr, signal=True)
        elif change_type == "REMOVED_NODE":
            host = self._cluster.metadata.get_host(addr)
            self._cluster.scheduler.schedule(1, self._cluster.remove_host, host)
        elif change_type == "MOVED_NODE":
            self._cluster.scheduler.schedule(1, self.refresh_node_list_and_token_map)

    def _handle_status_change(self, event):
        """
        Watcher for STATUS_CHANGE events.

        An UP event for an unknown address adds the host; for a known host
        the cluster's ``on_up`` is scheduled.  DOWN events are ignored.
        """
        change_type = event["change_type"]
        addr, port = event["address"]
        if change_type == "UP":
            host = self._cluster.metadata.get_host(addr)
            if not host:
                self._cluster.scheduler.schedule(1, self._cluster.add_host, addr, signal=True)
            else:
                self._cluster.scheduler.schedule(1, self._cluster.on_up, host)
        elif change_type == "DOWN":
            # Ignore down event. Connection will realize a node is dead quickly
            # enough when it writes to it, and there is no point in taking the
            # risk of marking the node down mistakenly because we didn't
            # receive the event in a timely fashion
            pass

    def _handle_schema_change(self, event):
        """
        Watcher for SCHEMA_CHANGE events; submits a schema refresh.

        NOTE(review): ``event`` is unpacked as a 3-item sequence here, while
        the other watchers index their event by key -- confirm the shape the
        connection actually delivers.
        """
        change_type, ks, cf = event
        if change_type in ("CREATED", "DROPPED"):
            if not cf:
                self._cluster.executor.submit(self.refresh_schema)
            else:
                self._cluster.executor.submit(self.refresh_schema, ks)
        elif change_type == "UPDATED":
            if not cf:
                self._cluster.executor.submit(self.refresh_schema, ks)
            else:
                self._cluster.executor.submit(self.refresh_schema, ks, cf)

    def wait_for_schema_agreement(self):
        """
        Poll until the local node and all known, up peers report one schema
        version.

        :returns: ``True`` once a single version is observed, ``False`` if
            ``MAX_SCHEMA_AGREEMENT_WAIT_MS`` elapses first.
        """
        # TODO is returning True/False the best option for this? Potentially raise Exception?
        # time.time() measures seconds; the limit is expressed in milliseconds.
        max_wait = MAX_SCHEMA_AGREEMENT_WAIT_MS / 1000.0
        start = time.time()
        elapsed = 0
        cl = ConsistencyLevel.ONE
        while elapsed < max_wait:
            peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl)
            local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl)
            peers_result, local_result = self._connection.wait_for_responses(peers_query, local_query)

            versions = set()
            if local_result and local_result.rows:
                local_row = local_result.as_dicts()[0]
                if local_row.get("schema_version"):
                    versions.add(local_row.get("schema_version"))

            for row in peers_result.as_dicts():
                if not row.get("rpc_address") or not row.get("schema_version"):
                    continue

                rpc = row.get("rpc_address")
                if rpc == "0.0.0.0":  # TODO ipv6 check
                    rpc = row.get("peer")

                peer = self._cluster.metadata.get_host(rpc)
                if peer and peer.monitor.is_up:
                    versions.add(row.get("schema_version"))

            if len(versions) == 1:
                return True

            time.sleep(0.2)
            elapsed = time.time() - start

        return False

    @property
    def is_open(self):
        """``True`` when a connection exists and reports itself open."""
        conn = self._connection
        return bool(conn and conn.is_open)

    def on_up(self, host):
        """Tell the balancing policy that ``host`` is up."""
        self._balancing_policy.on_up(host)

    def on_down(self, host):
        """
        Tell the balancing policy that ``host`` is down; reconnect if the
        control connection was attached to that host.
        """
        self._balancing_policy.on_down(host)

        conn = self._connection
        if conn and conn.address == host.address:  # TODO and reconnection_attempt is None
            self.reconnect()

    def on_add(self, host):
        """Tell the balancing policy about a new ``host`` and refresh the node list."""
        self._balancing_policy.on_add(host)
        self.refresh_node_list_and_token_map()

    def on_remove(self, host):
        """Tell the balancing policy ``host`` is gone and refresh the node list."""
        self._balancing_policy.on_remove(host)
        self.refresh_node_list_and_token_map()
 | 
