diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 03ed4769..12951a05 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -29,7 +29,7 @@ from cassandra.cluster import Cluster, Session, NoHostAvailable, ExecutionProfil from cassandra.concurrent import execute_concurrent from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance, - WhiteListRoundRobinPolicy, AddressTranslator) + WhiteListRoundRobinPolicy, AddressTranslator, TokenAwarePolicy, HostFilterPolicy) from cassandra.pool import Host from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory @@ -974,6 +974,37 @@ class ClusterTests(unittest.TestCase): else: raise Exception("add_execution_profile didn't timeout after {0} retries".format(max_retry_count)) + def test_replicas_are_queried(self): + queried_hosts = set() + with Cluster(protocol_version=PROTOCOL_VERSION, + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy())) as cluster: + session = cluster.connect() + prepared = session.prepare("""SELECT * from test1rf.test WHERE k = ?""") + for i in range(100): + result = session.execute(prepared, (i,), trace=True) + queried_hosts = self._assert_replica_queried(result.get_query_trace(), only_replicas=True) + last_i = i + + only_replica = queried_hosts.pop() + with Cluster(protocol_version=PROTOCOL_VERSION, + load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), + predicate=lambda host: host.address != only_replica)) as cluster: + session = cluster.connect() + prepared = session.prepare("""SELECT * from test1rf.test WHERE k = ?""") + for _ in range(100): + result = session.execute(prepared, (last_i,), trace=True) + self._assert_replica_queried(result.get_query_trace(), only_replicas=False) + + def _assert_replica_queried(self, trace, only_replicas=True): + queried_hosts = set() + for row in trace.events: + queried_hosts.add(row.source) + if only_replicas: + self.assertEqual(len(queried_hosts), 1, "The hosts queried where {}".format(queried_hosts)) + else: + self.assertGreater(len(queried_hosts), 1, "The host queried was {}".format(queried_hosts)) + return queried_hosts + class LocalHostAdressTranslator(AddressTranslator):