Added integration tests around querying replicas

This commit is contained in:
bjmb
2017-07-12 21:07:08 -04:00
parent 852d39e6c0
commit cfffc7fb99

View File

@@ -29,7 +29,7 @@ from cassandra.cluster import Cluster, Session, NoHostAvailable, ExecutionProfil
from cassandra.concurrent import execute_concurrent
from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy,
RetryPolicy, SimpleConvictionPolicy, HostDistance,
WhiteListRoundRobinPolicy, AddressTranslator)
WhiteListRoundRobinPolicy, AddressTranslator, TokenAwarePolicy, HostFilterPolicy)
from cassandra.pool import Host
from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory
@@ -974,6 +974,37 @@ class ClusterTests(unittest.TestCase):
else:
raise Exception("add_execution_profile didn't timeout after {0} retries".format(max_retry_count))
def test_replicas_are_queried(self):
queried_hosts = set()
with Cluster(protocol_version=PROTOCOL_VERSION,
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy())) as cluster:
session = cluster.connect()
prepared = session.prepare("""SELECT * from test1rf.test WHERE k = ?""")
for i in range(100):
result = session.execute(prepared, (i,), trace=True)
queried_hosts = self._assert_replica_queried(result.get_query_trace(), only_replicas=True)
last_i = i
only_replica = queried_hosts.pop()
with Cluster(protocol_version=PROTOCOL_VERSION,
load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(),
predicate=lambda host: host.address != only_replica)) as cluster:
session = cluster.connect()
prepared = session.prepare("""SELECT * from test1rf.test WHERE k = ?""")
for _ in range(100):
result = session.execute(prepared, (last_i,), trace=True)
self._assert_replica_queried(result.get_query_trace(), only_replicas=False)
def _assert_replica_queried(self, trace, only_replicas=True):
queried_hosts = set()
for row in trace.events:
queried_hosts.add(row.source)
if only_replicas:
self.assertEqual(len(queried_hosts), 1, "The hosts queried where {}".format(queried_hosts))
else:
self.assertGreater(len(queried_hosts), 1, "The host queried was {}".format(queried_hosts))
return queried_hosts
class LocalHostAdressTranslator(AddressTranslator):