Avoid deadlock when nodes go up and down

Tyler Hobbs
2013-10-17 12:31:09 -05:00
parent 4834683a43
commit 8b78f37d80
10 changed files with 334 additions and 238 deletions

View File

@@ -333,8 +333,6 @@ class Cluster(object):
self.metrics = Metrics(weakref.proxy(self))
self.control_connection = ControlConnection(self)
for address in contact_points:
self.add_host(address, signal=True)
def get_min_requests_per_connection(self, host_distance):
return self._min_requests_per_connection[host_distance]
@@ -398,6 +396,13 @@ class Cluster(object):
raise Exception("Cluster is already shut down")
if not self._is_setup:
for address in self.contact_points:
host = self.add_host(address, signal=False)
if host:
host.set_up()
for listener in self.listeners:
listener.on_add(host)
self.load_balancing_policy.populate(
weakref.proxy(self), self.metadata.all_hosts())
self._is_setup = True
@@ -459,30 +464,105 @@ class Cluster(object):
self.sessions.add(session)
return session
def _on_up_future_completed(self, host, futures, results, lock, finished_future):
with lock:
futures.discard(finished_future)
try:
results.append(finished_future.result())
except Exception as exc:
results.append(exc)
if futures:
return
try:
# all futures have completed at this point
for exc in [f for f in results if isinstance(f, Exception)]:
log.error("Unexpected failure while marking node %s up:", host, exc_info=exc)
return
if not all(results):
log.debug("Connection pool could not be created, not marking node %s up:", host)
return
# mark the host as up and notify all listeners
host.set_up()
for listener in self.listeners:
listener.on_up(host)
finally:
host.lock.release()
# see if there are any pools to add or remove now that the host is marked up
for session in self.sessions:
session.update_created_pools()
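The _on_up_future_completed callback above, together with the registration loop later in on_up, implements a small "run once after the last future completes" pattern. A minimal standalone sketch of the same idea, with hypothetical names and using only the standard library:

from threading import Lock

def when_all_complete(futures, on_all_done):
    # Calls on_all_done(results) exactly once, from whichever thread
    # finishes the last future; exceptions are collected into results
    # instead of being raised (mirroring futures_results above).
    # Assumes futures is non-empty.
    lock = Lock()
    results = []
    remaining = [len(futures)]  # boxed so the closure can mutate it

    def callback(finished):
        with lock:
            try:
                results.append(finished.result())
            except Exception as exc:
                results.append(exc)
            remaining[0] -= 1
            if remaining[0]:
                return  # other futures are still outstanding
        on_all_done(results)

    for future in list(futures):
        future.add_done_callback(callback)

Counting down under a lock, rather than checking a shared set, avoids the window where a callback could observe an empty set before every future has been registered.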
def on_up(self, host):
"""
Called when a host is marked up by its :class:`~.HealthMonitor`.
Intended for internal use only.
"""
reconnector = host.get_and_set_reconnection_handler(None)
if reconnector:
reconnector.cancel()
if self._is_shutdown:
return
self._prepare_all_queries(host)
host.lock.acquire()
try:
if host.is_up:
host.lock.release()
return
self.control_connection.on_up(host)
for session in self.sessions:
session.on_up(host)
log.debug("Host %s has been marked up", host)
def on_down(self, host):
reconnector = host.get_and_set_reconnection_handler(None)
if reconnector:
log.debug("Now that host %s is up, cancelling the reconnection handler", host)
reconnector.cancel()
self._prepare_all_queries(host)
for session in self.sessions:
session.remove_pool(host)
self.load_balancing_policy.on_up(host)
self.control_connection.on_up(host)
futures_lock = Lock()
futures_results = []
futures = set()
callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock)
for session in self.sessions:
future = session.add_or_renew_pool(host, is_host_addition=False)
future.add_done_callback(callback)
futures.add(future)
except Exception:
host.lock.release()
# for testing purposes
return futures
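The on_down method below is decorated with @run_in_executor, keeping the blocking down-handling off event-loop and connection threads. A plausible sketch of such a decorator, assuming the Cluster keeps a ThreadPoolExecutor in self.executor (the real helper is defined elsewhere in cluster.py):

from functools import wraps

def run_in_executor(f):
    # Submit the wrapped method to the instance's thread pool and
    # return the resulting future instead of blocking the caller.
    @wraps(f)
    def new_f(self, *args, **kwargs):
        return self.executor.submit(f, self, *args, **kwargs)
    return new_f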
@run_in_executor
def on_down(self, host, is_host_addition):
"""
Called when a host is marked down by its :class:`~.HealthMonitor`.
Intended for internal use only.
"""
if self._is_shutdown:
return
with host.lock:
if (not host.is_up) or host.is_currently_reconnecting():
return
host.set_down()
log.debug("Host %s has been marked down", host)
self.load_balancing_policy.on_down(host)
self.control_connection.on_down(host)
for session in self.sessions:
session.on_down(host)
for listener in self.listeners:
listener.on_down(host)
schedule = self.reconnection_policy.new_schedule()
# in order to not hold references to this Cluster open and prevent
@@ -491,28 +571,94 @@ class Cluster(object):
conn_factory = self._make_connection_factory(host)
reconnector = _HostReconnectionHandler(
host, conn_factory, self.scheduler, schedule,
host.get_and_set_reconnection_handler, new_handler=None)
host, conn_factory, is_host_addition, self.on_add, self.on_up,
self.scheduler, schedule, host.get_and_set_reconnection_handler,
new_handler=None)
old_reconnector = host.get_and_set_reconnection_handler(reconnector)
if old_reconnector:
log.debug("Old host reconnector found for %s, cancelling", host)
old_reconnector.cancel()
log.debug("Staring reconnector for host %s", host)
reconnector.start()
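The schedule consumed by the reconnection handler is just an iterator of delays (in seconds) produced by reconnection_policy.new_schedule(). As an illustration only (not the driver's actual implementation), an exponential policy in the spirit of the ExponentialReconnectionPolicy(1.0, 600.0) used in the tests below might generate:

def exponential_schedule(base_delay, max_delay):
    # Yields base_delay, 2*base_delay, 4*base_delay, ...
    # capped at max_delay, indefinitely.
    delay = base_delay
    while True:
        yield min(delay, max_delay)
        delay *= 2

# e.g. the first five delays of exponential_schedule(1.0, 600.0)
# are 1.0, 2.0, 4.0, 8.0, 16.0

The handler pulls the next delay after each failed attempt, so a host that stays down is retried progressively less often, up to the cap.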
def on_add(self, host):
if self._is_shutdown:
return
log.debug("Adding new host %s", host)
self._prepare_all_queries(host)
self.load_balancing_policy.on_add(host)
self.control_connection.on_add(host)
futures_lock = Lock()
futures_results = []
futures = set()
def future_completed(future):
with futures_lock:
futures.discard(future)
try:
futures_results.append(future.result())
except Exception as exc:
futures_results.append(exc)
if futures:
return
# all futures have completed at this point
for exc in [f for f in futures_results if isinstance(f, Exception)]:
log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc)
return
if not all(futures_results):
log.debug("Connection pool could not be created, not marking node %s up:", host)
return
# mark the host as up and notify all listeners
host.set_up()
for listener in self.listeners:
listener.on_add(host)
# see if there are any pools to add or remove now that the host is marked up
for session in self.sessions:
session.update_created_pools()
for session in self.sessions:
future = session.add_or_renew_host(host, is_host_addition=True)
future.add_done_callback(future_completed)
def on_remove(self, host):
if self._is_shutdown:
return
log.debug("Removing host %s", host)
host.set_down()
self.load_balancing_policy.on_remove(host)
for session in self.sessions:
session.on_remove(host)
for listener in self.listeners:
listener.on_remove(host)
def signal_connection_failure(self, host, connection_exc, is_host_addition):
is_down = host.signal_connection_failure(connection_exc)
if is_down:
self.on_down(host, is_host_addition)
return is_down
def add_host(self, address, signal):
"""
Called when adding initial contact points and when the control
connection subsequently discovers a new node. Intended for internal
use only.
"""
log.info("Now considering host %s for new connections", address)
new_host = self.metadata.add_host(address)
if new_host and signal:
self._prepare_all_queries(new_host)
self.control_connection.on_add(new_host)
for session in self.sessions: # TODO need to copy/lock?
session.on_add(new_host)
log.info("New Cassandra host %s added", address)
self.on_add(new_host)
return new_host
@@ -521,11 +667,9 @@ class Cluster(object):
Called when the control connection observes that a node has left the
ring. Intended for internal use only.
"""
log.info("Host %s will no longer be considered for new connections", host)
if host and self.metadata.remove_host(host):
self.control_connection.on_remove(host)
for session in self.sessions:
session.on_remove(host)
log.info("Cassandra host %s removed", host)
self.on_remove(host)
def register_listener(self, listener):
"""
@@ -574,7 +718,8 @@ class Cluster(object):
try:
self.control_connection.wait_for_schema_agreement(connection)
except Exception:
pass
log.debug("Error waiting for schema agreement before preparing statements against host %s", host, exc_info=True)
# TODO: potentially error out the connection?
statements = self._prepared_statements.values()
for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace):
@@ -587,7 +732,8 @@ class Cluster(object):
for statement in ks_statements:
message = PrepareMessage(query=statement.query_string)
try:
response = connection.wait_for_response(message)
# TODO: make this timeout configurable somehow?
response = connection.wait_for_response(message, timeout=1.0)
if (not isinstance(response, ResultMessage) or
response.kind != ResultMessage.KIND_PREPARED):
log.debug("Got unexpected response when preparing "
@@ -596,6 +742,8 @@ class Cluster(object):
log.exception("Error trying to prepare statement on "
"host %s", host)
connection.close()
log.debug("Done preparing all known prepared statements against host %s", host)
except Exception:
# log and ignore
log.exception("Error trying to prepare all statements on host %s", host)
@@ -657,7 +805,8 @@ class Session(object):
self._metrics = cluster.metrics
for host in hosts:
self.add_host(host)
future = self.add_or_renew_pool(host, is_host_addition=False)
future.result()
def execute(self, query, parameters=None, trace=False):
"""
@@ -838,71 +987,81 @@ class Session(object):
except TypeError:
pass
def add_host(self, host):
""" Internal """
def add_or_renew_pool(self, host, is_host_addition):
"""
For internal use only.
"""
distance = self._load_balancer.distance(host)
if distance == HostDistance.IGNORED:
return self._pools.get(host)
else:
return None
def run_add_or_renew_pool():
try:
new_pool = HostConnectionPool(host, distance, self)
except AuthenticationFailed as auth_exc:
conn_exc = ConnectionException(str(auth_exc), host=host)
host.monitor.signal_connection_failure(conn_exc)
return self._pools.get(host)
self.cluster.signal_connection_failure(host, conn_exc, is_host_addition)
return False
except Exception as conn_exc:
host.monitor.signal_connection_failure(conn_exc)
return self._pools.get(host)
log.debug("Signaling connection failure during Session.add_host: %s", conn_exc)
self.cluster.signal_connection_failure(host, conn_exc, is_host_addition)
return False
previous = self._pools.get(host)
self._pools[host] = new_pool
return previous
log.debug("Added pool for host %s to session", host)
if previous:
previous.shutdown()
def on_up(self, host):
"""
Called by the parent Cluster instance when a host's :class:`HealthMonitor`
marks it up. Only intended for internal use.
"""
previous_pool = self.add_host(host)
self._load_balancer.on_up(host)
if previous_pool:
previous_pool.shutdown()
return True
def on_down(self, host):
"""
Called by the parent Cluster instance when a host's :class:`HealthMonitor`
marks it down. Only intended for internal use.
"""
self._load_balancer.on_down(host)
return self.submit(run_add_or_renew_pool)
def remove_pool(self, host):
pool = self._pools.pop(host, None)
if pool:
pool.shutdown()
return self.submit(pool.shutdown)
else:
return None
def update_created_pools(self):
"""
When the set of live nodes changes, the load balancer may reassess
host distances. The distance may change for the node that just came
up or went down, but also for other nodes (for instance, if a node
dies, a previously ignored node may now be considered).
This method ensures that all hosts for which a pool should exist
have one, and hosts that shouldn't don't.
For internal use only.
"""
for host in self.cluster.metadata.all_hosts():
if not host.monitor.is_up:
continue
distance = self._load_balancer.distance(host)
if distance != HostDistance.IGNORED:
pool = self._pools.get(host)
if not pool:
self.add_host(host)
pool = self._pools.get(host)
if not pool:
if distance != HostDistance.IGNORED and host.is_up:
self.add_or_renew_pool(host, False)
elif distance != pool.host_distance:
# the distance has changed
if distance == HostDistance.IGNORED:
self.remove_pool(host)
else:
pool.host_distance = distance
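To see why a node going down can change the distance of nodes other than the one that moved, consider a hypothetical policy that treats only the first two live hosts as LOCAL and ignores the rest:

from cassandra.policies import HostDistance, LoadBalancingPolicy

class FirstTwoPolicy(LoadBalancingPolicy):
    # Hypothetical policy for illustration; query-plan methods omitted.
    def __init__(self):
        self._live = []

    def distance(self, host):
        return (HostDistance.LOCAL if host in self._live[:2]
                else HostDistance.IGNORED)

    def on_up(self, host):
        if host not in self._live:
            self._live.append(host)

    def on_down(self, host):
        # Removing a LOCAL host can promote a previously IGNORED host
        # into the first two slots; update_created_pools() is what then
        # opens the missing pool for the promoted host.
        if host in self._live:
            self._live.remove(host)

    on_add = on_up
    on_remove = on_down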
def on_add(self, host):
""" Internal """
previous_pool = self.add_host(host)
self._load_balancer.on_add(host)
if previous_pool:
previous_pool.shutdown()
def on_down(self, host):
"""
Called by the parent Cluster instance when a node is marked down.
Only intended for internal use.
"""
future = self.remove_pool(host)
if future:
future.add_done_callback(lambda f: self.update_created_pools())
def on_remove(self, host):
""" Internal """
self._load_balancer.on_remove(host)
pool = self._pools.pop(host)
if pool:
pool.shutdown()
self.on_down(host)
def set_keyspace(self, keyspace):
"""
@@ -1031,8 +1190,8 @@ class ControlConnection(object):
return self._try_connect(host)
except ConnectionException as exc:
errors[host.address] = exc
host.monitor.signal_connection_failure(exc)
log.warn("[control connection] Error connecting to %s:", host, exc_info=True)
self._cluster.signal_connection_failure(host, exc, is_host_addition=False)
except Exception as exc:
errors[host.address] = exc
log.warn("[control connection] Error connecting to %s:", host, exc_info=True)
@@ -1242,14 +1401,16 @@ class ControlConnection(object):
# this is the first time we've seen the node
self._cluster.scheduler.schedule(1, self._cluster.add_host, addr, signal=True)
else:
self._cluster.scheduler.schedule(1, host.monitor.set_up)
# this will be run by the scheduler
self._cluster.scheduler.schedule(1, self._cluster.on_up, host)
elif change_type == "DOWN":
# Note that there is a slight risk we can receive the event late and thus
# mark the host down even though we already had reconnected successfully.
# But it is unlikely and doesn't have much consequence, since we'll try reconnecting
# right away, so we favor detection to keep Host.is_up more accurate.
if host is not None:
self._cluster.scheduler.schedule(1, host.monitor.set_down)
# this will be run by the scheduler
self._cluster.on_down(host, is_host_addition=False)
def _handle_schema_change(self, event):
keyspace = event['keyspace'] or None
@@ -1294,10 +1455,11 @@ class ControlConnection(object):
rpc = row.get("peer")
peer = self._cluster.metadata.get_host(rpc)
if peer and peer.monitor.is_up:
if peer and peer.is_up:
versions.add(row.get("schema_version"))
if len(versions) == 1:
log.debug("[control connection] Schemas match")
return True
log.debug("[control connection] Schemas mismatched, trying again")
@@ -1307,14 +1469,15 @@ class ControlConnection(object):
return False
def _signal_error(self):
# try just signaling the host monitor, as this will trigger a reconnect
# try just signaling the cluster, as this will trigger a reconnect
# as part of marking the host down
if self._connection and self._connection.is_defunct:
host = self._cluster.metadata.get_host(self._connection.host)
# host may be None if it's already been removed, but that indicates
# that errors have already been reported, so we're fine
if host:
host.monitor.signal_connection_failure(self._connection.last_error)
self._cluster.signal_connection_failure(
host, self._connection.last_error, is_host_addition=False)
return
# if the connection is not defunct or the host already left, reconnect
@@ -1327,26 +1490,22 @@ class ControlConnection(object):
return bool(conn and conn.is_open)
def on_up(self, host):
log.debug("[control connection] Host %s is considered up", host)
self._balancing_policy.on_up(host)
pass
def on_down(self, host):
log.debug("[control connection] Host %s is considered down", host)
self._balancing_policy.on_down(host)
conn = self._connection
if conn and conn.host == host.address and \
self._reconnection_handler is None:
log.debug("[control connection] Control connection host (%s) is "
"considered down, starting reconnection", host)
# this will result in a task being submitted to the executor to reconnect
self.reconnect()
def on_add(self, host):
log.debug("[control connection] Adding host %r and refreshing topology", host)
self._balancing_policy.on_add(host)
self.refresh_node_list_and_token_map()
def on_remove(self, host):
log.debug("[control connection] Removing host %r and refreshing topology", host)
self._balancing_policy.on_remove(host)
self.refresh_node_list_and_token_map()
@@ -1659,13 +1818,13 @@ class ResponseFuture(object):
else:
self._set_final_exception(ConnectionException(
"Got unexpected response when preparing statement "
"on host %s: %s" % (self._host, response)))
"on host %s: %s" % (self._current_host, response)))
elif isinstance(response, ErrorMessage):
self._set_final_exception(response)
else:
self._set_final_exception(ConnectionException(
"Got unexpected response type when preparing "
"statement on host %s: %s" % (self._host, response)))
"statement on host %s: %s" % (self._current_host, response)))
def _set_final_result(self, response):
if self._metrics is not None:

View File

@@ -306,7 +306,6 @@ class Metadata(object):
else:
return None
new_host.monitor.register(cluster)
return new_host
def remove_host(self, host):

View File

@@ -42,7 +42,29 @@ class HostDistance(object):
"""
class LoadBalancingPolicy(object):
class HostStateListener(object):
def on_up(self, host):
""" Called when a node is marked up. """
raise NotImplementedError()
def on_down(self, host):
""" Called when a node is marked down. """
raise NotImplementedError()
def on_add(self, host):
"""
Called when a node is added to the cluster. The newly added node
should be considered up.
"""
raise NotImplementedError()
def on_remove(self, host):
""" Called when a node is removed from the cluster. """
raise NotImplementedError()
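Since any object implementing these four callbacks can be registered with Cluster.register_listener() (shown earlier in this diff), a minimal hypothetical listener might look like:

import logging

log = logging.getLogger(__name__)

class LoggingListener(HostStateListener):
    # Logs every host state transition the cluster reports.
    def on_up(self, host):
        log.info("Host %s was marked up", host)

    def on_down(self, host):
        log.info("Host %s was marked down", host)

    def on_add(self, host):
        log.info("Host %s was added; it should be considered up", host)

    def on_remove(self, host):
        log.info("Host %s was removed from the cluster", host)

# usage: cluster.register_listener(LoggingListener())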
class LoadBalancingPolicy(HostStateListener):
"""
Load balancing policies are used to decide how to distribute
requests among all possible coordinator nodes in the cluster.
@@ -87,36 +109,6 @@ class LoadBalancingPolicy(object):
"""
raise NotImplementedError()
def on_up(self, host):
"""
Called when a :class:`~.pool.Host`'s :class:`~.HealthMonitor`
marks the node up.
"""
raise NotImplementedError()
def on_down(self, host):
"""
Called when a :class:`~.pool.Host`'s :class:`~.HealthMonitor`
marks the node down.
"""
raise NotImplementedError()
def on_add(self, host):
"""
Called when a :class:`.Cluster` instance is first created and
the initial contact points are added as well as when a new
:class:`~.pool.Host` is discovered in the cluster, which may
happen the first time the ring topology is examined or when
a new node joins the cluster.
"""
raise NotImplementedError()
def on_remove(self, host):
"""
Called when a :class:`~.pool.Host` leaves the cluster.
"""
raise NotImplementedError()
class RoundRobinPolicy(LoadBalancingPolicy):
"""
@@ -300,7 +292,7 @@ class TokenAwarePolicy(LoadBalancingPolicy):
else:
replicas = self._cluster_metadata.get_replicas(keyspace, routing_key)
for replica in replicas:
if replica.monitor.is_up and \
if replica.is_up and \
child.distance(replica) == HostDistance.LOCAL:
yield replica

View File

@@ -4,7 +4,7 @@ Connection pooling and host management.
import logging
import time
from threading import Lock, RLock, Condition
from threading import RLock, Condition
import weakref
try:
from weakref import WeakSet
@@ -35,15 +35,18 @@ class Host(object):
The IP address or hostname of the node.
"""
monitor = None
conviction_policy = None
"""
A :class:`.HealthMonitor` instance that tracks whether this node is
up or down.
A :class:`ConvictionPolicy` instance for determining when this node should
be marked up or down.
"""
is_up = None
_datacenter = None
_rack = None
_reconnection_handler = None
lock = None
def __init__(self, inet_address, conviction_policy_factory):
if inet_address is None:
@@ -52,9 +55,8 @@ class Host(object):
raise ValueError("conviction_policy_factory may not be None")
self.address = inet_address
self.monitor = HealthMonitor(conviction_policy_factory(self))
self._reconnection_lock = Lock()
self.conviction_policy = conviction_policy_factory(self)
self.lock = RLock()
@property
def datacenter(self):
@@ -75,12 +77,25 @@ class Host(object):
self._datacenter = datacenter
self._rack = rack
def set_up(self):
self.conviction_policy.reset()
self.is_up = True
def set_down(self):
self.is_up = False
def signal_connection_failure(self, connection_exc):
return self.conviction_policy.add_failure(connection_exc)
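add_failure() returns True when the accumulated failures should convict (mark down) the host, and signal_connection_failure() simply forwards that verdict to Cluster.signal_connection_failure(). The SimpleConvictionPolicy used throughout the tests amounts to roughly this sketch (illustrative; the real class lives in policies.py):

from cassandra.policies import ConvictionPolicy

class SimpleConvictionPolicy(ConvictionPolicy):
    # Convicts the host on any single connection failure.
    def add_failure(self, connection_exc):
        return True

    def reset(self):
        pass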
def is_currently_reconnecting(self):
return self._reconnection_handler is not None
def get_and_set_reconnection_handler(self, new_handler):
"""
Atomically replaces the reconnection handler for this
host. Intended for internal use only.
"""
with self._reconnection_lock:
with self.lock:
old = self._reconnection_handler
self._reconnection_handler = new_handler
return old
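Both call sites in cluster.py depend on this get-and-set being atomic, which guarantees at most one live reconnection handler per host; condensed from the hunks above:

# Cluster.on_down: install a new handler, cancelling any old one
old = host.get_and_set_reconnection_handler(reconnector)
if old:
    old.cancel()
reconnector.start()

# Cluster.on_up: clear and cancel the handler once the host is back
reconnector = host.get_and_set_reconnection_handler(None)
if reconnector:
    reconnector.cancel()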
@@ -175,8 +190,11 @@ class _ReconnectionHandler(object):
class _HostReconnectionHandler(_ReconnectionHandler):
def __init__(self, host, connection_factory, *args, **kwargs):
def __init__(self, host, connection_factory, is_host_addition, on_add, on_up, *args, **kwargs):
_ReconnectionHandler.__init__(self, *args, **kwargs)
self.is_host_addition = is_host_addition
self.on_add = on_add
self.on_up = on_up
self.host = host
self.connection_factory = connection_factory
@@ -184,85 +202,23 @@ class _HostReconnectionHandler(_ReconnectionHandler):
return self.connection_factory()
def on_reconnection(self, connection):
self.host.monitor.reset()
connection.close()
log.info("Successful reconnection to %s, marking node up", self.host)
if self.is_host_addition:
self.on_add(self.host)
else:
self.on_up(self.host)
def on_exception(self, exc, next_delay):
if isinstance(exc, AuthenticationFailed):
return False
else:
log.warn("Error attempting to reconnect to %s: %s", self.host, exc)
log.warn("Error attempting to reconnect to %s, scheduling retry in %f seconds: %s",
self.host, next_delay, exc)
log.debug("Reconnection error details", exc_info=True)
return True
class HealthMonitor(object):
"""
Monitors whether a particular host is marked as up or down.
This class is primarily intended for internal use, although
applications may find it useful to check whether a given node
is up or down.
"""
is_up = True
"""
A boolean representing the current state of the node.
"""
def __init__(self, conviction_policy):
self._conviction_policy = conviction_policy
self._host = conviction_policy.host
# self._listeners will hold, among other things, references to
# Cluster objects. To allow those to be GC'ed (and shutdown) even
# though we've implemented __del__, use weak references.
self._listeners = WeakSet()
self._lock = RLock()
def register(self, listener):
with self._lock:
self._listeners.add(listener)
def unregister(self, listener):
with self._lock:
self._listeners.remove(listener)
def set_up(self):
if self.is_up:
return
self._conviction_policy.reset()
log.info("Host %s is considered up", self._host)
with self._lock:
listeners = self._listeners.copy()
for listener in listeners:
listener.on_up(self._host)
self.is_up = True
def set_down(self):
if not self.is_up:
return
self.is_up = False
log.info("Host %s is considered down", self._host)
with self._lock:
listeners = self._listeners.copy()
for listener in listeners:
listener.on_down(self._host)
def reset(self):
return self.set_up()
def signal_connection_failure(self, connection_exc):
is_down = self._conviction_policy.add_failure(connection_exc)
if is_down:
self.set_down()
return is_down
_MAX_SIMULTANEOUS_CREATION = 1
_NEW_CONNECTION_GRACE_PERIOD = 5
@@ -295,6 +251,7 @@ class HostConnectionPool(object):
self._trash = set()
self.open_count = core_conns
log.debug("Finished initializing new connection pool for host %s", self.host)
def borrow_connection(self, timeout):
if self.is_shutdown:
@@ -395,7 +352,7 @@ class HostConnectionPool(object):
log.exception("Failed to add new connection to pool for host %s", self.host)
with self._lock:
self.open_count -= 1
if self.host.monitor.signal_connection_failure(exc):
if self._session.cluster.signal_connection_failure(self.host, exc, is_host_addition=False):
self.shutdown()
return False
except AuthenticationFailed:
@@ -448,7 +405,8 @@ class HostConnectionPool(object):
if connection.is_defunct or connection.is_closed:
log.debug("Defunct or closed connection (%s) returned to pool, potentially "
"marking host %s as down", id(connection), self.host)
is_down = self.host.monitor.signal_connection_failure(connection.last_error)
is_down = self._session.cluster.signal_connection_failure(
self.host, connection.last_error, is_host_addition=False)
if is_down:
self.shutdown()
else:

View File

@@ -74,7 +74,7 @@ class ClusterTests(unittest.TestCase):
Ensure errors are not thrown when using non-default policies
"""
cluster = Cluster(
Cluster(
load_balancing_policy=RoundRobinPolicy(),
reconnection_policy=ExponentialReconnectionPolicy(1.0, 600.0),
default_retry_policy=RetryPolicy(),
@@ -177,28 +177,6 @@ class ClusterTests(unittest.TestCase):
self.assertIn("newkeyspace", cluster.metadata.keyspaces)
def test_on_down_and_up(self):
"""
Test on_down and on_up handling
"""
cluster = Cluster()
session = cluster.connect()
host = cluster.metadata.all_hosts()[0]
host.monitor.signal_connection_failure(None)
cluster.on_down(host)
self.assertNotIn(host, session._pools)
host_reconnector = host._reconnection_handler
self.assertNotEqual(None, host_reconnector)
host.monitor.is_up = True
cluster.on_up(host)
self.assertEqual(None, host._reconnection_handler)
self.assertTrue(host_reconnector._cancelled)
self.assertIn(host, session._pools)
def test_trace(self):
"""
Ensure trace can be requested for async and non-async queries

View File

@@ -11,7 +11,6 @@ from mock import patch, Mock
from cassandra.connection import (PROTOCOL_VERSION,
HEADER_DIRECTION_TO_CLIENT,
ProtocolError,
ConnectionException)
from cassandra.decoder import (write_stringmultimap, write_int, write_string,
@@ -87,7 +86,7 @@ class LibevConnectionTest(unittest.TestCase):
c.handle_write(None, None)
# read in a SupportedMessage response
header = self.make_header_prefix(SupportedMessage, version=0x04)
header = self.make_header_prefix(SupportedMessage, version=0xa4)
options = self.make_options_body()
c._socket.recv.return_value = self.make_msg(header, options)
c.handle_read(None, None)
@@ -95,7 +94,7 @@ class LibevConnectionTest(unittest.TestCase):
# make sure it errored correctly
self.assertTrue(c.is_defunct)
self.assertTrue(c.connected_event.is_set())
self.assertIsInstance(c.last_error, ProtocolError)
self.assertIsInstance(c.last_error, ConnectionException)
def test_error_message_on_startup(self, *args):
c = self.make_connection()

View File

@@ -23,6 +23,8 @@ class MockMetadata(object):
"192.168.1.1": Host("192.168.1.1", SimpleConvictionPolicy),
"192.168.1.2": Host("192.168.1.2", SimpleConvictionPolicy)
}
for host in self.hosts.values():
host.set_up()
self.cluster_name = None
self.partitioner = None
@@ -44,6 +46,7 @@ class MockCluster(object):
max_schema_agreement_wait = Cluster.max_schema_agreement_wait
load_balancing_policy = RoundRobinPolicy()
reconnection_policy = ConstantReconnectionPolicy(2)
down_host = None
def __init__(self):
self.metadata = MockMetadata()
@@ -60,6 +63,12 @@ class MockCluster(object):
def remove_host(self, host):
self.removed_hosts.append(host)
def on_up(self, host):
pass
def on_down(self, host, is_host_addition):
self.down_host = host
class MockConnection(object):
@@ -142,7 +151,7 @@ class ControlConnectionTest(unittest.TestCase):
# change the schema version on one of the existing entries
self.connection.peer_results[1][1][3] = 'c'
self.cluster.metadata.get_host('192.168.1.1').monitor.is_up = False
self.cluster.metadata.get_host('192.168.1.1').is_up = False
self.assertTrue(self.control_connection.wait_for_schema_agreement())
self.assertEqual(self.time.clock, 0)
@@ -156,7 +165,7 @@ class ControlConnectionTest(unittest.TestCase):
)
host = Host("0.0.0.0", SimpleConvictionPolicy)
self.cluster.metadata.hosts[PEER_IP] = host
host.monitor.is_up = False
host.is_up = False
# even though the new host has a different schema version, it's
# marked as down, so the control connection shouldn't care
@@ -164,7 +173,7 @@ class ControlConnectionTest(unittest.TestCase):
self.assertEqual(self.time.clock, 0)
# but once we mark it up, the control connection will care
host.monitor.is_up = True
host.is_up = True
self.assertFalse(self.control_connection.wait_for_schema_agreement())
self.assertGreaterEqual(self.time.clock, Cluster.max_schema_agreement_wait)
@@ -248,7 +257,7 @@ class ControlConnectionTest(unittest.TestCase):
}
self.control_connection._handle_status_change(event)
host = self.cluster.metadata.hosts['192.168.1.0']
self.cluster.scheduler.schedule.assert_called_with(ANY, host.monitor.set_up)
self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.on_up, host)
self.cluster.scheduler.schedule.reset_mock()
event = {
@@ -265,7 +274,7 @@ class ControlConnectionTest(unittest.TestCase):
}
self.control_connection._handle_status_change(event)
host = self.cluster.metadata.hosts['192.168.1.0']
self.cluster.scheduler.schedule.assert_called_with(ANY, host.monitor.set_down)
self.assertIs(host, self.cluster.down_host)
def test_handle_schema_change(self):

View File

@@ -8,7 +8,7 @@ from threading import Thread, Event
from cassandra.cluster import Session
from cassandra.connection import Connection, MAX_STREAM_PER_CONNECTION
from cassandra.pool import Host, HostConnectionPool, NoConnectionsAvailable, HealthMonitor
from cassandra.pool import Host, HostConnectionPool, NoConnectionsAvailable
from cassandra.policies import HostDistance, SimpleConvictionPolicy
@@ -158,7 +158,7 @@ class HostConnectionPoolTests(unittest.TestCase):
pool.borrow_connection(timeout=0.01)
conn.is_defunct = True
host.monitor.signal_connection_failure.return_value = False
session.cluster.signal_connection_failure.return_value = False
pool.return_connection(conn)
# the connection should be closed and a new creation scheduled
@@ -168,7 +168,6 @@ class HostConnectionPoolTests(unittest.TestCase):
def test_return_defunct_connection_on_down_host(self):
host = Mock(spec=Host, address='ip1')
host.monitor = Mock(spec=HealthMonitor)
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False)
session.cluster.connection_factory.return_value = conn
@@ -178,11 +177,11 @@ class HostConnectionPoolTests(unittest.TestCase):
pool.borrow_connection(timeout=0.01)
conn.is_defunct = True
host.monitor.signal_connection_failure.return_value = True
session.cluster.signal_connection_failure.return_value = True
pool.return_connection(conn)
# the connection should be closed and the pool shut down
host.monitor.signal_connection_failure.assert_called_once()
session.cluster.signal_connection_failure.assert_called_once()
conn.close.assert_called_once()
self.assertFalse(session.submit.called)
self.assertTrue(pool.is_shutdown)
@@ -198,7 +197,7 @@ class HostConnectionPoolTests(unittest.TestCase):
pool.borrow_connection(timeout=0.01)
conn.is_closed = True
host.monitor.signal_connection_failure.return_value = False
session.cluster.signal_connection_failure.return_value = False
pool.return_connection(conn)
# a new creation should be scheduled

View File

@@ -257,6 +257,8 @@ class TokenAwarePolicyTest(unittest.TestCase):
cluster = Mock(spec=Cluster)
cluster.metadata = Mock(spec=Metadata)
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
for host in hosts:
host.set_up()
def get_replicas(keyspace, packed_key):
index = struct.unpack('>i', packed_key)[0]
@@ -286,6 +288,8 @@ class TokenAwarePolicyTest(unittest.TestCase):
cluster = Mock(spec=Cluster)
cluster.metadata = Mock(spec=Metadata)
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
for host in hosts:
host.set_up()
for h in hosts[:2]:
h.set_location_info("dc1", "rack1")
for h in hosts[2:]:
@@ -362,7 +366,6 @@ class TokenAwarePolicyTest(unittest.TestCase):
distances = set([policy.distance(remote_host), policy.distance(second_remote_host)])
self.assertEqual(distances, set([HostDistance.REMOTE, HostDistance.IGNORED]))
def test_status_updates(self):
"""
Same test as DCAwareRoundRobinPolicyTest.test_status_updates()
@@ -468,7 +471,7 @@ class ConstantReconnectionPolicyTest(unittest.TestCase):
max_attempts = -100
try:
policy = ConstantReconnectionPolicy(delay=delay, max_attempts=max_attempts)
ConstantReconnectionPolicy(delay=delay, max_attempts=max_attempts)
self.fail('max_attempts should throw ValueError when negative')
except ValueError:
pass

View File

@@ -68,8 +68,8 @@ class ResponseFutureTests(unittest.TestCase):
kind=ResultMessage.KIND_SET_KEYSPACE,
results="keyspace1")
rf._set_result(result)
rf._set_keyspace_completed({})
self.assertEqual(None, rf.result())
self.assertEqual(session.keyspace, 'keyspace1')
def test_schema_change_result(self):
session = self.make_session()