diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 2beb6bcf..6e16658f 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -333,8 +333,6 @@ class Cluster(object): self.metrics = Metrics(weakref.proxy(self)) self.control_connection = ControlConnection(self) - for address in contact_points: - self.add_host(address, signal=True) def get_min_requests_per_connection(self, host_distance): return self._min_requests_per_connection[host_distance] @@ -398,6 +396,13 @@ class Cluster(object): raise Exception("Cluster is already shut down") if not self._is_setup: + for address in self.contact_points: + host = self.add_host(address, signal=False) + if host: + host.set_up() + for listener in self.listeners: + listener.on_add(host) + self.load_balancing_policy.populate( weakref.proxy(self), self.metadata.all_hosts()) self._is_setup = True @@ -459,30 +464,105 @@ class Cluster(object): self.sessions.add(session) return session + def _on_up_future_completed(self, host, futures, results, lock, finished_future): + with lock: + futures.discard(finished_future) + + try: + results.append(finished_future.result()) + except Exception as exc: + results.append(exc) + + if futures: + return + + try: + # all futures have completed at this point + for exc in [f for f in results if isinstance(f, Exception)]: + log.error("Unexpected failure while marking node %s up:", host, exc_info=exc) + return + + if not all(results): + log.debug("Connection pool could not be created, not marking node %s up:", host) + return + + # mark the host as up and notify all listeners + host.set_up() + for listener in self.listeners: + listener.on_up(host) + finally: + host.lock.release() + + # see if there are any pools to add or remove now that the host is marked up + for session in self.sessions: + session.update_created_pools() + def on_up(self, host): """ - Called when a host is marked up by its :class:`~.HealthMonitor`. Intended for internal use only. """ - reconnector = host.get_and_set_reconnection_handler(None) - if reconnector: - reconnector.cancel() + if self._is_shutdown: + return - self._prepare_all_queries(host) + host.lock.acquire() + try: + if host.is_up: + host.lock.release() + return - self.control_connection.on_up(host) - for session in self.sessions: - session.on_up(host) + log.debug("Host %s has been marked up", host) - def on_down(self, host): + reconnector = host.get_and_set_reconnection_handler(None) + if reconnector: + log.debug("Now that host %s is up, cancelling the reconnection handler", host) + reconnector.cancel() + + self._prepare_all_queries(host) + + for session in self.sessions: + session.remove_pool(host) + + self.load_balancing_policy.on_up(host) + self.control_connection.on_up(host) + + futures_lock = Lock() + futures_results = [] + futures = set() + callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) + for session in self.sessions: + future = session.add_or_renew_pool(host, is_host_addition=False) + future.add_done_callback(callback) + futures.add(future) + except Exception: + host.lock.release() + + # for testing purposes + return futures + + @run_in_executor + def on_down(self, host, is_host_addition): """ - Called when a host is marked down by its :class:`~.HealthMonitor`. Intended for internal use only. """ + if self._is_shutdown: + return + + with host.lock: + if (not host.is_up) or host.is_currently_reconnecting(): + return + + host.set_down() + + log.debug("Host %s has been marked down", host) + + self.load_balancing_policy.on_down(host) self.control_connection.on_down(host) for session in self.sessions: session.on_down(host) + for listener in self.listeners: + listener.on_down(host) + schedule = self.reconnection_policy.new_schedule() # in order to not hold references to this Cluster open and prevent @@ -491,28 +571,94 @@ class Cluster(object): conn_factory = self._make_connection_factory(host) reconnector = _HostReconnectionHandler( - host, conn_factory, self.scheduler, schedule, - host.get_and_set_reconnection_handler, new_handler=None) + host, conn_factory, is_host_addition, self.on_add, self.on_up, + self.scheduler, schedule, host.get_and_set_reconnection_handler, + new_handler=None) old_reconnector = host.get_and_set_reconnection_handler(reconnector) if old_reconnector: + log.debug("Old host reconnector found for %s, cancelling", host) old_reconnector.cancel() + log.debug("Staring reconnector for host %s", host) reconnector.start() + def on_add(self, host): + if self._is_shutdown: + return + + log.debug("Adding new host %s", host) + self._prepare_all_queries(host) + + self.load_balancing_policy.on_add(host) + self.control_connection.on_add(host) + + futures_lock = Lock() + futures_results = [] + futures = set() + + def future_completed(future): + with futures_lock: + futures.discard(future) + + try: + futures_results.append(future.result()) + except Exception as exc: + futures_results.append(exc) + + if futures: + return + + # all futures have completed at this point + for exc in [f for f in futures_results if isinstance(f, Exception)]: + log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc) + return + + if not all(futures_results): + log.debug("Connection pool could not be created, not marking node %s up:", host) + return + + # mark the host as up and notify all listeners + host.set_up() + for listener in self.listeners: + listener.on_add(host) + + # see if there are any pools to add or remove now that the host is marked up + for session in self.sessions: + session.update_created_pools() + + for session in self.sessions: + future = session.add_or_renew_host(host, is_host_addition=True) + future.add_done_callback(future_completed) + + def on_remove(self, host): + if self._is_shutdown: + return + + log.debug("Removing host %s", host) + host.set_down() + self.load_balancing_policy.on_remove(host) + for session in self.sessions: + session.on_remove() + for listener in self.listeners: + listener.on_remove() + + def signal_connection_failure(self, host, connection_exc, is_host_addition): + is_down = host.signal_connection_failure(connection_exc) + if is_down: + self.on_down(host, is_host_addition) + return is_down + def add_host(self, address, signal): """ Called when adding initial contact points and when the control connection subsequently discovers a new node. Intended for internal use only. """ - log.info("Now considering host %s for new connections", address) new_host = self.metadata.add_host(address) if new_host and signal: - self._prepare_all_queries(new_host) - self.control_connection.on_add(new_host) - for session in self.sessions: # TODO need to copy/lock? - session.on_add(new_host) + log.info("New Cassandra host %s added", address) + self.on_add(new_host) return new_host @@ -521,11 +667,9 @@ class Cluster(object): Called when the control connection observes that a node has left the ring. Intended for internal use only. """ - log.info("Host %s will no longer be considered for new connections", host) if host and self.metadata.remove_host(host): - self.control_connection.on_remove(host) - for session in self.sessions: - session.on_remove(host) + log.info("Cassandra host %s removed", host) + self.on_remove(host) def register_listener(self, listener): """ @@ -574,7 +718,8 @@ class Cluster(object): try: self.control_connection.wait_for_schema_agreement(connection) except Exception: - pass + log.debug("Error waiting for schema agreement before preparing statements against host %s", host, exc_info=True) + # TODO: potentially error out the connection? statements = self._prepared_statements.values() for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace): @@ -587,7 +732,8 @@ class Cluster(object): for statement in ks_statements: message = PrepareMessage(query=statement.query_string) try: - response = connection.wait_for_response(message) + # TODO: make this timeout configurable somehow? + response = connection.wait_for_response(message, timeout=1.0) if (not isinstance(response, ResultMessage) or response.kind != ResultMessage.KIND_PREPARED): log.debug("Got unexpected response when preparing " @@ -596,6 +742,8 @@ class Cluster(object): log.exception("Error trying to prepare statement on " "host %s", host) + connection.close() + log.debug("Done preparing all known prepared statements against host %s", host) except Exception: # log and ignore log.exception("Error trying to prepare all statements on host %s", host) @@ -657,7 +805,8 @@ class Session(object): self._metrics = cluster.metrics for host in hosts: - self.add_host(host) + future = self.add_or_renew_pool(host, is_host_addition=False) + future.result() def execute(self, query, parameters=None, trace=False): """ @@ -838,71 +987,81 @@ class Session(object): except TypeError: pass - def add_host(self, host): - """ Internal """ + def add_or_renew_pool(self, host, is_host_addition): + """ + For internal use only. + """ distance = self._load_balancer.distance(host) if distance == HostDistance.IGNORED: - return self._pools.get(host) - else: + return None + + def run_add_or_renew_pool(): try: new_pool = HostConnectionPool(host, distance, self) except AuthenticationFailed as auth_exc: conn_exc = ConnectionException(str(auth_exc), host=host) - host.monitor.signal_connection_failure(conn_exc) - return self._pools.get(host) + self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) + return False except Exception as conn_exc: - host.monitor.signal_connection_failure(conn_exc) - return self._pools.get(host) + log.debug("Signaling connection failure during Session.add_host: %s", conn_exc) + self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) + return False previous = self._pools.get(host) self._pools[host] = new_pool - return previous + log.debug("Added pool for host %s to session", host) + if previous: + previous.shutdown() - def on_up(self, host): - """ - Called by the parent Cluster instance when a host's :class:`HealthMonitor` - marks it up. Only intended for internal use. - """ - previous_pool = self.add_host(host) - self._load_balancer.on_up(host) - if previous_pool: - previous_pool.shutdown() + return True - def on_down(self, host): - """ - Called by the parent Cluster instance when a host's :class:`HealthMonitor` - marks it down. Only intended for internal use. - """ - self._load_balancer.on_down(host) + return self.submit(run_add_or_renew_pool) + + def remove_pool(self, host): pool = self._pools.pop(host, None) if pool: - pool.shutdown() + return self.submit(pool.shutdown) + else: + return None + def update_created_pools(self): + """ + When the set of live nodes change, the loadbalancer will change its + mind on host distances. It might change it on the node that came/left + but also on other nodes (for instance, if a node dies, another + previously ignored node may be now considered). + + This method ensures that all hosts for which a pool should exist + have one, and hosts that shouldn't don't. + + For internal use only. + """ for host in self.cluster.metadata.all_hosts(): - if not host.monitor.is_up: - continue - distance = self._load_balancer.distance(host) - if distance != HostDistance.IGNORED: - pool = self._pools.get(host) - if not pool: - self.add_host(host) + pool = self._pools.get(host) + + if not pool: + if distance != HostDistance.IGNORED and host.is_up: + self.add_or_renew_pool(host, False) + elif distance != pool.host_distance: + # the distance has changed + if distance == HostDistance.IGNORED: + self.remove_pool(host) else: pool.host_distance = distance - def on_add(self, host): - """ Internal """ - previous_pool = self.add_host(host) - self._load_balancer.on_add(host) - if previous_pool: - previous_pool.shutdown() + def on_down(self, host): + """ + Called by the parent Cluster instance when a node is marked down. + Only intended for internal use. + """ + future = self.remove_pool(host) + if future: + future.add_done_callback(lambda f: self.update_created_pools()) def on_remove(self, host): """ Internal """ - self._load_balancer.on_remove(host) - pool = self._pools.pop(host) - if pool: - pool.shutdown() + self.on_down(host) def set_keyspace(self, keyspace): """ @@ -1031,8 +1190,8 @@ class ControlConnection(object): return self._try_connect(host) except ConnectionException as exc: errors[host.address] = exc - host.monitor.signal_connection_failure(exc) log.warn("[control connection] Error connecting to %s:", host, exc_info=True) + self._cluster.signal_connection_failure(host, exc, is_host_addition=False) except Exception as exc: errors[host.address] = exc log.warn("[control connection] Error connecting to %s:", host, exc_info=True) @@ -1242,14 +1401,16 @@ class ControlConnection(object): # this is the first time we've seen the node self._cluster.scheduler.schedule(1, self._cluster.add_host, addr, signal=True) else: - self._cluster.scheduler.schedule(1, host.monitor.set_up) + # this will be run by the scheduler + self._cluster.scheduler.schedule(1, self._cluster.on_up, host) elif change_type == "DOWN": # Note that there is a slight risk we can receive the event late and thus # mark the host down even though we already had reconnected successfully. # But it is unlikely, and don't have too much consequence since we'll try reconnecting # right away, so we favor the detection to make the Host.is_up more accurate. if host is not None: - self._cluster.scheduler.schedule(1, host.monitor.set_down) + # this will be run by the scheduler + self._cluster.on_down(host, is_host_addition=False) def _handle_schema_change(self, event): keyspace = event['keyspace'] or None @@ -1294,10 +1455,11 @@ class ControlConnection(object): rpc = row.get("peer") peer = self._cluster.metadata.get_host(rpc) - if peer and peer.monitor.is_up: + if peer and peer.is_up: versions.add(row.get("schema_version")) if len(versions) == 1: + log.debug("[control connection] Schemas match") return True log.debug("[control connection] Schemas mismatched, trying again") @@ -1307,14 +1469,15 @@ class ControlConnection(object): return False def _signal_error(self): - # try just signaling the host monitor, as this will trigger a reconnect + # try just signaling the cluster, as this will trigger a reconnect # as part of marking the host down if self._connection and self._connection.is_defunct: host = self._cluster.metadata.get_host(self._connection.host) # host may be None if it's already been removed, but that indicates # that errors have already been reported, so we're fine if host: - host.monitor.signal_connection_failure(self._connection.last_error) + self._cluster.signal_connection_failure( + host, self._connection.last_error, is_host_addition=False) return # if the connection is not defunct or the host already left, reconnect @@ -1327,26 +1490,22 @@ class ControlConnection(object): return bool(conn and conn.is_open) def on_up(self, host): - log.debug("[control connection] Host %s is considered up", host) - self._balancing_policy.on_up(host) + pass def on_down(self, host): - log.debug("[control connection] Host %s is considered down", host) - self._balancing_policy.on_down(host) conn = self._connection if conn and conn.host == host.address and \ self._reconnection_handler is None: + log.debug("[control connection] Control connection host (%s) is " + "considered down, starting reconnection", host) + # this will result in a task being submitted to the executor to reconnect self.reconnect() def on_add(self, host): - log.debug("[control connection] Adding host %r and refreshing topology", host) - self._balancing_policy.on_add(host) self.refresh_node_list_and_token_map() def on_remove(self, host): - log.debug("[control connection] Removing host %r and refreshing topology", host) - self._balancing_policy.on_remove(host) self.refresh_node_list_and_token_map() @@ -1659,13 +1818,13 @@ class ResponseFuture(object): else: self._set_final_exception(ConnectionException( "Got unexpected response when preparing statement " - "on host %s: %s" % (self._host, response))) + "on host %s: %s" % (self._current_host, response))) elif isinstance(response, ErrorMessage): self._set_final_exception(response) else: self._set_final_exception(ConnectionException( "Got unexpected response type when preparing " - "statement on host %s: %s" % (self._host, response))) + "statement on host %s: %s" % (self._current_host, response))) def _set_final_result(self, response): if self._metrics is not None: diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 52918738..8131bed6 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -306,7 +306,6 @@ class Metadata(object): else: return None - new_host.monitor.register(cluster) return new_host def remove_host(self, host): diff --git a/cassandra/policies.py b/cassandra/policies.py index 6b7bf583..0e250487 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -42,7 +42,29 @@ class HostDistance(object): """ -class LoadBalancingPolicy(object): +class HostStateListener(object): + + def on_up(self, host): + """ Called when a node is marked up. """ + raise NotImplementedError() + + def on_down(self, host): + """ Called when a node is marked down. """ + raise NotImplementedError() + + def on_add(self, host): + """ + Called when a node is added to the cluster. The newly added node + should be considered up. + """ + raise NotImplementedError() + + def on_remove(self, host): + """ Called when a node is removed from the cluster. """ + raise NotImplementedError() + + +class LoadBalancingPolicy(HostStateListener): """ Load balancing policies are used to decide how to distribute requests among all possible coordinator nodes in the cluster. @@ -87,36 +109,6 @@ class LoadBalancingPolicy(object): """ raise NotImplementedError() - def on_up(self, host): - """ - Called when a :class:`~.pool.Host`'s :class:`~.HealthMonitor` - marks the node up. - """ - raise NotImplementedError() - - def on_down(self, host): - """ - Called when a :class:`~.pool.Host`'s :class:`~.HealthMonitor` - marks the node down. - """ - raise NotImplementedError() - - def on_add(self, host): - """ - Called when a :class:`.Cluster` instance is first created and - the initial contact points are added as well as when a new - :class:`~.pool.Host` is discovered in the cluster, which may - happen the first time the ring topology is examined or when - a new node joins the cluster. - """ - raise NotImplementedError() - - def on_remove(self, host): - """ - Called when a :class:`~.pool.Host` leaves the cluster. - """ - raise NotImplementedError() - class RoundRobinPolicy(LoadBalancingPolicy): """ @@ -300,7 +292,7 @@ class TokenAwarePolicy(LoadBalancingPolicy): else: replicas = self._cluster_metadata.get_replicas(keyspace, routing_key) for replica in replicas: - if replica.monitor.is_up and \ + if replica.is_up and \ child.distance(replica) == HostDistance.LOCAL: yield replica diff --git a/cassandra/pool.py b/cassandra/pool.py index f1bc6423..b2634869 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -4,7 +4,7 @@ Connection pooling and host management. import logging import time -from threading import Lock, RLock, Condition +from threading import RLock, Condition import weakref try: from weakref import WeakSet @@ -35,15 +35,18 @@ class Host(object): The IP address or hostname of the node. """ - monitor = None + conviction_policy = None """ - A :class:`.HealthMonitor` instance that tracks whether this node is - up or down. + A class:`ConvictionPolicy` instance for determining when this node should + be marked up or down. """ + is_up = None + _datacenter = None _rack = None _reconnection_handler = None + lock = None def __init__(self, inet_address, conviction_policy_factory): if inet_address is None: @@ -52,9 +55,8 @@ class Host(object): raise ValueError("conviction_policy_factory may not be None") self.address = inet_address - self.monitor = HealthMonitor(conviction_policy_factory(self)) - - self._reconnection_lock = Lock() + self.conviction_policy = conviction_policy_factory(self) + self.lock = RLock() @property def datacenter(self): @@ -75,12 +77,25 @@ class Host(object): self._datacenter = datacenter self._rack = rack + def set_up(self): + self.conviction_policy.reset() + self.is_up = True + + def set_down(self): + self.is_up = False + + def signal_connection_failure(self, connection_exc): + return self.conviction_policy.add_failure(connection_exc) + + def is_currently_reconnecting(self): + return self._reconnection_handler is not None + def get_and_set_reconnection_handler(self, new_handler): """ Atomically replaces the reconnection handler for this host. Intended for internal use only. """ - with self._reconnection_lock: + with self.lock: old = self._reconnection_handler self._reconnection_handler = new_handler return old @@ -175,8 +190,11 @@ class _ReconnectionHandler(object): class _HostReconnectionHandler(_ReconnectionHandler): - def __init__(self, host, connection_factory, *args, **kwargs): + def __init__(self, host, connection_factory, is_host_addition, on_add, on_up, *args, **kwargs): _ReconnectionHandler.__init__(self, *args, **kwargs) + self.is_host_addition = is_host_addition + self.on_add = on_add + self.on_up = on_up self.host = host self.connection_factory = connection_factory @@ -184,85 +202,23 @@ class _HostReconnectionHandler(_ReconnectionHandler): return self.connection_factory() def on_reconnection(self, connection): - self.host.monitor.reset() + connection.close() + log.info("Successful reconnection to %s, marking node up", self.host) + if self.is_host_addition: + self.on_add(self.host) + else: + self.on_up(self.host) def on_exception(self, exc, next_delay): if isinstance(exc, AuthenticationFailed): return False else: - log.warn("Error attempting to reconnect to %s: %s", self.host, exc) + log.warn("Error attempting to reconnect to %s, scheduling retry in %f seconds: %s", + self.host, next_delay, exc) log.debug("Reconnection error details", exc_info=True) return True -class HealthMonitor(object): - """ - Monitors whether a particular host is marked as up or down. - This class is primarily intended for internal use, although - applications may find it useful to check whether a given node - is up or down. - """ - - is_up = True - """ - A boolean representing the current state of the node. - """ - - def __init__(self, conviction_policy): - self._conviction_policy = conviction_policy - self._host = conviction_policy.host - # self._listeners will hold, among other things, references to - # Cluster objects. To allow those to be GC'ed (and shutdown) even - # though we've implemented __del__, use weak references. - self._listeners = WeakSet() - self._lock = RLock() - - def register(self, listener): - with self._lock: - self._listeners.add(listener) - - def unregister(self, listener): - with self._lock: - self._listeners.remove(listener) - - def set_up(self): - if self.is_up: - return - - self._conviction_policy.reset() - log.info("Host %s is considered up", self._host) - - with self._lock: - listeners = self._listeners.copy() - - for listener in listeners: - listener.on_up(self._host) - - self.is_up = True - - def set_down(self): - if not self.is_up: - return - - self.is_up = False - log.info("Host %s is considered down", self._host) - - with self._lock: - listeners = self._listeners.copy() - - for listener in listeners: - listener.on_down(self._host) - - def reset(self): - return self.set_up() - - def signal_connection_failure(self, connection_exc): - is_down = self._conviction_policy.add_failure(connection_exc) - if is_down: - self.set_down() - return is_down - - _MAX_SIMULTANEOUS_CREATION = 1 _NEW_CONNECTION_GRACE_PERIOD = 5 @@ -295,6 +251,7 @@ class HostConnectionPool(object): self._trash = set() self.open_count = core_conns + log.debug("Finished initializing new connection pool for host %s", self.host) def borrow_connection(self, timeout): if self.is_shutdown: @@ -395,7 +352,7 @@ class HostConnectionPool(object): log.exception("Failed to add new connection to pool for host %s", self.host) with self._lock: self.open_count -= 1 - if self.host.monitor.signal_connection_failure(exc): + if self._session.cluster.signal_connection_failure(self.host, exc, is_host_addition=False): self.shutdown() return False except AuthenticationFailed: @@ -448,7 +405,8 @@ class HostConnectionPool(object): if connection.is_defunct or connection.is_closed: log.debug("Defunct or closed connection (%s) returned to pool, potentially " "marking host %s as down", id(connection), self.host) - is_down = self.host.monitor.signal_connection_failure(connection.last_error) + is_down = self._session.cluster.signal_connection_failure( + self.host, connection.last_error, is_host_addition=False) if is_down: self.shutdown() else: diff --git a/tests/integration/test_cluster.py b/tests/integration/test_cluster.py index 7e8d3222..085a17bd 100644 --- a/tests/integration/test_cluster.py +++ b/tests/integration/test_cluster.py @@ -74,7 +74,7 @@ class ClusterTests(unittest.TestCase): Ensure errors are not thrown when using non-default policies """ - cluster = Cluster( + Cluster( load_balancing_policy=RoundRobinPolicy(), reconnection_policy=ExponentialReconnectionPolicy(1.0, 600.0), default_retry_policy=RetryPolicy(), @@ -177,28 +177,6 @@ class ClusterTests(unittest.TestCase): self.assertIn("newkeyspace", cluster.metadata.keyspaces) - def test_on_down_and_up(self): - """ - Test on_down and on_up handling - """ - - cluster = Cluster() - session = cluster.connect() - host = cluster.metadata.all_hosts()[0] - host.monitor.signal_connection_failure(None) - cluster.on_down(host) - self.assertNotIn(host, session._pools) - host_reconnector = host._reconnection_handler - self.assertNotEqual(None, host_reconnector) - - host.monitor.is_up = True - - cluster.on_up(host) - - self.assertEqual(None, host._reconnection_handler) - self.assertTrue(host_reconnector._cancelled) - self.assertIn(host, session._pools) - def test_trace(self): """ Ensure trace can be requested for async and non-async queries diff --git a/tests/unit/io/test_libevreactor.py b/tests/unit/io/test_libevreactor.py index edba157b..b46f175b 100644 --- a/tests/unit/io/test_libevreactor.py +++ b/tests/unit/io/test_libevreactor.py @@ -11,7 +11,6 @@ from mock import patch, Mock from cassandra.connection import (PROTOCOL_VERSION, HEADER_DIRECTION_TO_CLIENT, - ProtocolError, ConnectionException) from cassandra.decoder import (write_stringmultimap, write_int, write_string, @@ -87,7 +86,7 @@ class LibevConnectionTest(unittest.TestCase): c.handle_write(None, None) # read in a SupportedMessage response - header = self.make_header_prefix(SupportedMessage, version=0x04) + header = self.make_header_prefix(SupportedMessage, version=0xa4) options = self.make_options_body() c._socket.recv.return_value = self.make_msg(header, options) c.handle_read(None, None) @@ -95,7 +94,7 @@ class LibevConnectionTest(unittest.TestCase): # make sure it errored correctly self.assertTrue(c.is_defunct) self.assertTrue(c.connected_event.is_set()) - self.assertIsInstance(c.last_error, ProtocolError) + self.assertIsInstance(c.last_error, ConnectionException) def test_error_message_on_startup(self, *args): c = self.make_connection() diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index e1f57ee9..6b820842 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -23,6 +23,8 @@ class MockMetadata(object): "192.168.1.1": Host("192.168.1.1", SimpleConvictionPolicy), "192.168.1.2": Host("192.168.1.2", SimpleConvictionPolicy) } + for host in self.hosts.values(): + host.set_up() self.cluster_name = None self.partitioner = None @@ -44,6 +46,7 @@ class MockCluster(object): max_schema_agreement_wait = Cluster.max_schema_agreement_wait load_balancing_policy = RoundRobinPolicy() reconnection_policy = ConstantReconnectionPolicy(2) + down_host = None def __init__(self): self.metadata = MockMetadata() @@ -60,6 +63,12 @@ class MockCluster(object): def remove_host(self, host): self.removed_hosts.append(host) + def on_up(self, host): + pass + + def on_down(self, host, is_host_addition): + self.down_host = host + class MockConnection(object): @@ -142,7 +151,7 @@ class ControlConnectionTest(unittest.TestCase): # change the schema version on one of the existing entries self.connection.peer_results[1][1][3] = 'c' - self.cluster.metadata.get_host('192.168.1.1').monitor.is_up = False + self.cluster.metadata.get_host('192.168.1.1').is_up = False self.assertTrue(self.control_connection.wait_for_schema_agreement()) self.assertEqual(self.time.clock, 0) @@ -156,7 +165,7 @@ class ControlConnectionTest(unittest.TestCase): ) host = Host("0.0.0.0", SimpleConvictionPolicy) self.cluster.metadata.hosts[PEER_IP] = host - host.monitor.is_up = False + host.is_up = False # even though the new host has a different schema version, it's # marked as down, so the control connection shouldn't care @@ -164,7 +173,7 @@ class ControlConnectionTest(unittest.TestCase): self.assertEqual(self.time.clock, 0) # but once we mark it up, the control connection will care - host.monitor.is_up = True + host.is_up = True self.assertFalse(self.control_connection.wait_for_schema_agreement()) self.assertGreaterEqual(self.time.clock, Cluster.max_schema_agreement_wait) @@ -248,7 +257,7 @@ class ControlConnectionTest(unittest.TestCase): } self.control_connection._handle_status_change(event) host = self.cluster.metadata.hosts['192.168.1.0'] - self.cluster.scheduler.schedule.assert_called_with(ANY, host.monitor.set_up) + self.cluster.scheduler.schedule.assert_called_with(ANY, self.cluster.on_up, host) self.cluster.scheduler.schedule.reset_mock() event = { @@ -265,7 +274,7 @@ class ControlConnectionTest(unittest.TestCase): } self.control_connection._handle_status_change(event) host = self.cluster.metadata.hosts['192.168.1.0'] - self.cluster.scheduler.schedule.assert_called_with(ANY, host.monitor.set_down) + self.assertIs(host, self.cluster.down_host) def test_handle_schema_change(self): diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index 843e2e1a..871d091f 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -8,7 +8,7 @@ from threading import Thread, Event from cassandra.cluster import Session from cassandra.connection import Connection, MAX_STREAM_PER_CONNECTION -from cassandra.pool import Host, HostConnectionPool, NoConnectionsAvailable, HealthMonitor +from cassandra.pool import Host, HostConnectionPool, NoConnectionsAvailable from cassandra.policies import HostDistance, SimpleConvictionPolicy @@ -158,7 +158,7 @@ class HostConnectionPoolTests(unittest.TestCase): pool.borrow_connection(timeout=0.01) conn.is_defunct = True - host.monitor.signal_connection_failure.return_value = False + session.cluster.signal_connection_failure.return_value = False pool.return_connection(conn) # the connection should be closed a new creation scheduled @@ -168,7 +168,6 @@ class HostConnectionPoolTests(unittest.TestCase): def test_return_defunct_connection_on_down_host(self): host = Mock(spec=Host, address='ip1') - host.monitor = Mock(spec=HealthMonitor) session = self.make_session() conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False) session.cluster.connection_factory.return_value = conn @@ -178,11 +177,11 @@ class HostConnectionPoolTests(unittest.TestCase): pool.borrow_connection(timeout=0.01) conn.is_defunct = True - host.monitor.signal_connection_failure.return_value = True + session.cluster.signal_connection_failure.return_value = True pool.return_connection(conn) # the connection should be closed a new creation scheduled - host.monitor.signal_connection_failure.assert_called_once() + session.cluster.signal_connection_failure.assert_called_once() conn.close.assert_called_once() self.assertFalse(session.submit.called) self.assertTrue(pool.is_shutdown) @@ -198,7 +197,7 @@ class HostConnectionPoolTests(unittest.TestCase): pool.borrow_connection(timeout=0.01) conn.is_closed = True - host.monitor.signal_connection_failure.return_value = False + session.cluster.signal_connection_failure.return_value = False pool.return_connection(conn) # a new creation should be scheduled diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 60eabf14..d461295c 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -257,6 +257,8 @@ class TokenAwarePolicyTest(unittest.TestCase): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)] + for host in hosts: + host.set_up() def get_replicas(keyspace, packed_key): index = struct.unpack('>i', packed_key)[0] @@ -286,6 +288,8 @@ class TokenAwarePolicyTest(unittest.TestCase): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)] + for host in hosts: + host.set_up() for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:]: @@ -362,7 +366,6 @@ class TokenAwarePolicyTest(unittest.TestCase): distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) self.assertEqual(distances, set([HostDistance.REMOTE, HostDistance.IGNORED])) - def test_status_updates(self): """ Same test as DCAwareRoundRobinPolicyTest.test_status_updates() @@ -468,7 +471,7 @@ class ConstantReconnectionPolicyTest(unittest.TestCase): max_attempts = -100 try: - policy = ConstantReconnectionPolicy(delay=delay, max_attempts=max_attempts) + ConstantReconnectionPolicy(delay=delay, max_attempts=max_attempts) self.fail('max_attempts should throw ValueError when negative') except ValueError: pass diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index ed9877e3..e7227394 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -68,8 +68,8 @@ class ResponseFutureTests(unittest.TestCase): kind=ResultMessage.KIND_SET_KEYSPACE, results="keyspace1") rf._set_result(result) + rf._set_keyspace_completed({}) self.assertEqual(None, rf.result()) - self.assertEqual(session.keyspace, 'keyspace1') def test_schema_change_result(self): session = self.make_session()