From 6ebbb24b5d1196f4422c35ed874e2902e97bf4fe Mon Sep 17 00:00:00 2001 From: bjmb Date: Wed, 12 Jul 2017 21:07:08 -0400 Subject: [PATCH] Added integration tests around querying replicas --- tests/integration/standard/test_cluster.py | 65 +++++++++++++++++++++- 1 file changed, 62 insertions(+), 3 deletions(-) diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 03ed4769..aa94549a 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -29,13 +29,14 @@ from cassandra.cluster import Cluster, Session, NoHostAvailable, ExecutionProfil from cassandra.concurrent import execute_concurrent from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance, - WhiteListRoundRobinPolicy, AddressTranslator) + WhiteListRoundRobinPolicy, AddressTranslator, TokenAwarePolicy, HostFilterPolicy) from cassandra.pool import Host from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory -from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions, CASSANDRA_VERSION, DSE_VERSION, execute_until_pass, execute_with_long_wait_retry, get_node,\ - MockLoggingHandler, get_unsupported_lower_protocol, get_unsupported_upper_protocol, protocolv5, local, CASSANDRA_IP +from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions, CASSANDRA_VERSION, \ + execute_until_pass, execute_with_long_wait_retry, get_node, MockLoggingHandler, get_unsupported_lower_protocol, \ + get_unsupported_upper_protocol, protocolv5, local, CASSANDRA_IP from tests.integration.util import assert_quiescent_pool_state import sys @@ -974,6 +975,64 @@ class ClusterTests(unittest.TestCase): else: raise Exception("add_execution_profile didn't timeout after {0} retries".format(max_retry_count)) + def test_replicas_are_queried(self): + """ + Test that replicas are queried first for TokenAwarePolicy. A table with RF 1 + is created. All the queries should go to that replica when TokenAwarePolicy + is used. + Then using HostFilterPolicy the replica is excluded from the considered hosts. + By checking the trace we verify that there are no more replicas. + + @since 3.5 + @jira_ticket PYTHON-653 + @expected_result the replicas are queried for HostFilterPolicy + + @test_category metadata + """ + queried_hosts = set() + with Cluster(protocol_version=PROTOCOL_VERSION, + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy())) as cluster: + session = cluster.connect() + session.execute(''' + CREATE TABLE test1rf.table_with_big_key ( + k1 int, + k2 int, + k3 int, + k4 int, + PRIMARY KEY((k1, k2, k3), k4))''') + prepared = session.prepare("""SELECT * from test1rf.table_with_big_key + WHERE k1 = ? AND k2 = ? AND k3 = ? AND k4 = ?""") + for i in range(10): + result = session.execute(prepared, (i, i, i, i), trace=True) + queried_hosts = self._assert_replica_queried(result.get_query_trace(), only_replicas=True) + last_i = i + + only_replica = queried_hosts.pop() + available_hosts = [host for host in ["127.0.0.1", "127.0.0.2", "127.0.0.3"] if host != only_replica] + with Cluster(contact_points=available_hosts, + protocol_version=PROTOCOL_VERSION, + load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), + predicate=lambda host: host.address != only_replica)) as cluster: + + session = cluster.connect() + prepared = session.prepare("""SELECT * from test1rf.table_with_big_key + WHERE k1 = ? AND k2 = ? AND k3 = ? AND k4 = ?""") + for _ in range(10): + result = session.execute(prepared, (last_i, last_i, last_i, last_i), trace=True) + self._assert_replica_queried(result.get_query_trace(), only_replicas=False) + + session.execute('''DROP TABLE test1rf.table_with_big_key''') + + def _assert_replica_queried(self, trace, only_replicas=True): + queried_hosts = set() + for row in trace.events: + queried_hosts.add(row.source) + if only_replicas: + self.assertEqual(len(queried_hosts), 1, "The hosts queried where {}".format(queried_hosts)) + else: + self.assertGreater(len(queried_hosts), 1, "The host queried was {}".format(queried_hosts)) + return queried_hosts + class LocalHostAdressTranslator(AddressTranslator):