fix race adding connection pool to session while handling keyspace change

PYTHON-628
This commit is contained in:
Adam Holmberg
2016-08-31 09:50:45 -05:00
parent 0ba2dc9353
commit 8bbe7ce3e3
2 changed files with 37 additions and 11 deletions

View File

@@ -2320,7 +2320,27 @@ class Session(object):
return False return False
previous = self._pools.get(host) previous = self._pools.get(host)
self._pools[host] = new_pool with self._lock:
while new_pool._keyspace != self.keyspace:
self._lock.release()
set_keyspace_event = Event()
errors_returned = []
def callback(pool, errors):
errors_returned.extend(errors)
set_keyspace_event.set()
new_pool._set_keyspace_for_all_conns(self.keyspace, callback)
set_keyspace_event.wait(self.cluster.connect_timeout)
if not set_keyspace_event.is_set() or errors_returned:
log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned)
self.cluster.on_down(host, is_host_addition)
new_pool.shutdown()
self._lock.acquire()
return False
self._lock.acquire()
self._pools[host] = new_pool
log.debug("Added pool for host %s to session", host) log.debug("Added pool for host %s to session", host)
if previous: if previous:
previous.shutdown() previous.shutdown()
@@ -2397,9 +2417,9 @@ class Session(object):
called with a dictionary of all errors that occurred, keyed called with a dictionary of all errors that occurred, keyed
by the `Host` that they occurred against. by the `Host` that they occurred against.
""" """
self.keyspace = keyspace with self._lock:
self.keyspace = keyspace
remaining_callbacks = set(self._pools.values()) remaining_callbacks = set(self._pools.values())
errors = {} errors = {}
if not remaining_callbacks: if not remaining_callbacks:

View File

@@ -307,6 +307,7 @@ class HostConnection(object):
_session = None _session = None
_connection = None _connection = None
_lock = None _lock = None
_keyspace = None
def __init__(self, host, host_distance, session): def __init__(self, host, host_distance, session):
self.host = host self.host = host
@@ -326,8 +327,9 @@ class HostConnection(object):
log.debug("Initializing connection for host %s", self.host) log.debug("Initializing connection for host %s", self.host)
self._connection = session.cluster.connection_factory(host.address) self._connection = session.cluster.connection_factory(host.address)
if session.keyspace: self._keyspace = session.keyspace
self._connection.set_keyspace_blocking(session.keyspace) if self._keyspace:
self._connection.set_keyspace_blocking(self._keyspace)
log.debug("Finished initializing connection for host %s", self.host) log.debug("Finished initializing connection for host %s", self.host)
def borrow_connection(self, timeout): def borrow_connection(self, timeout):
@@ -381,8 +383,8 @@ class HostConnection(object):
log.debug("Replacing connection (%s) to %s", id(connection), self.host) log.debug("Replacing connection (%s) to %s", id(connection), self.host)
try: try:
conn = self._session.cluster.connection_factory(self.host.address) conn = self._session.cluster.connection_factory(self.host.address)
if self._session.keyspace: if self._keyspace:
conn.set_keyspace_blocking(self._session.keyspace) conn.set_keyspace_blocking(self._keyspace)
self._connection = conn self._connection = conn
except Exception: except Exception:
log.warning("Failed reconnecting %s. Retrying." % (self.host.address,)) log.warning("Failed reconnecting %s. Retrying." % (self.host.address,))
@@ -412,6 +414,7 @@ class HostConnection(object):
errors = [] if not error else [error] errors = [] if not error else [error]
callback(self, errors) callback(self, errors)
self._keyspace = keyspace
self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace) self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace)
def get_connections(self): def get_connections(self):
@@ -445,6 +448,7 @@ class HostConnectionPool(object):
open_count = 0 open_count = 0
_scheduled_for_creation = 0 _scheduled_for_creation = 0
_next_trash_allowed_at = 0 _next_trash_allowed_at = 0
_keyspace = None
def __init__(self, host, host_distance, session): def __init__(self, host, host_distance, session):
self.host = host self.host = host
@@ -459,9 +463,10 @@ class HostConnectionPool(object):
self._connections = [session.cluster.connection_factory(host.address) self._connections = [session.cluster.connection_factory(host.address)
for i in range(core_conns)] for i in range(core_conns)]
if session.keyspace: self._keyspace = session.keyspace
if self._keyspace:
for conn in self._connections: for conn in self._connections:
conn.set_keyspace_blocking(session.keyspace) conn.set_keyspace_blocking(self._keyspace)
self._trash = set() self._trash = set()
self._next_trash_allowed_at = time.time() self._next_trash_allowed_at = time.time()
@@ -560,7 +565,7 @@ class HostConnectionPool(object):
log.debug("Going to open new connection to host %s", self.host) log.debug("Going to open new connection to host %s", self.host)
try: try:
conn = self._session.cluster.connection_factory(self.host.address) conn = self._session.cluster.connection_factory(self.host.address)
if self._session.keyspace: if self._keyspace:
conn.set_keyspace_blocking(self._session.keyspace) conn.set_keyspace_blocking(self._session.keyspace)
self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
with self._lock: with self._lock:
@@ -761,6 +766,7 @@ class HostConnectionPool(object):
if not remaining_callbacks: if not remaining_callbacks:
callback(self, errors) callback(self, errors)
self._keyspace = keyspace
for conn in self._connections: for conn in self._connections:
conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace) conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace)