Added integration tests around querying replicas

This commit is contained in:
bjmb
2017-07-12 21:07:08 -04:00
parent 852d39e6c0
commit cfffc7fb99

View File

@@ -29,7 +29,7 @@ from cassandra.cluster import Cluster, Session, NoHostAvailable, ExecutionProfil
from cassandra.concurrent import execute_concurrent from cassandra.concurrent import execute_concurrent
from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy, from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy,
RetryPolicy, SimpleConvictionPolicy, HostDistance, RetryPolicy, SimpleConvictionPolicy, HostDistance,
WhiteListRoundRobinPolicy, AddressTranslator) WhiteListRoundRobinPolicy, AddressTranslator, TokenAwarePolicy, HostFilterPolicy)
from cassandra.pool import Host from cassandra.pool import Host
from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory
@@ -974,6 +974,37 @@ class ClusterTests(unittest.TestCase):
else: else:
raise Exception("add_execution_profile didn't timeout after {0} retries".format(max_retry_count)) raise Exception("add_execution_profile didn't timeout after {0} retries".format(max_retry_count))
def test_replicas_are_queried(self):
queried_hosts = set()
with Cluster(protocol_version=PROTOCOL_VERSION,
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy())) as cluster:
session = cluster.connect()
prepared = session.prepare("""SELECT * from test1rf.test WHERE k = ?""")
for i in range(100):
result = session.execute(prepared, (i,), trace=True)
queried_hosts = self._assert_replica_queried(result.get_query_trace(), only_replicas=True)
last_i = i
only_replica = queried_hosts.pop()
with Cluster(protocol_version=PROTOCOL_VERSION,
load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(),
predicate=lambda host: host.address != only_replica)) as cluster:
session = cluster.connect()
prepared = session.prepare("""SELECT * from test1rf.test WHERE k = ?""")
for _ in range(100):
result = session.execute(prepared, (last_i,), trace=True)
self._assert_replica_queried(result.get_query_trace(), only_replicas=False)
def _assert_replica_queried(self, trace, only_replicas=True):
queried_hosts = set()
for row in trace.events:
queried_hosts.add(row.source)
if only_replicas:
self.assertEqual(len(queried_hosts), 1, "The hosts queried where {}".format(queried_hosts))
else:
self.assertGreater(len(queried_hosts), 1, "The host queried was {}".format(queried_hosts))
return queried_hosts
class LocalHostAdressTranslator(AddressTranslator): class LocalHostAdressTranslator(AddressTranslator):