From 1f9bb019c780a49a2f5e580c495927812287d216 Mon Sep 17 00:00:00 2001 From: bjmb Date: Thu, 13 Jul 2017 08:25:04 -0400 Subject: [PATCH] Replaced WhiteHost policy for FilterHost policy in tests --- cassandra/policies.py | 17 ++---- tests/integration/long/test_failure_types.py | 6 +- tests/integration/standard/test_cluster.py | 55 ++++++++++++++----- tests/integration/standard/test_connection.py | 10 +++- tests/integration/standard/test_metrics.py | 6 +- 5 files changed, 62 insertions(+), 32 deletions(-) diff --git a/cassandra/policies.py b/cassandra/policies.py index 1144f4f4..25565a2c 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -504,20 +504,16 @@ class HostFilterPolicy(LoadBalancingPolicy): self._predicate = predicate def on_up(self, host, *args, **kwargs): - if self.predicate(host): - return self._child_policy.on_up(host, *args, **kwargs) + return self._child_policy.on_up(host, *args, **kwargs) def on_down(self, host, *args, **kwargs): - if self.predicate(host): - return self._child_policy.on_down(host, *args, **kwargs) + return self._child_policy.on_down(host, *args, **kwargs) def on_add(self, host, *args, **kwargs): - if self.predicate(host): - return self._child_policy.on_add(host, *args, **kwargs) + return self._child_policy.on_add(host, *args, **kwargs) def on_remove(self, host, *args, **kwargs): - if self.predicate(host): - return self._child_policy.on_remove(host, *args, **kwargs) + return self._child_policy.on_remove(host, *args, **kwargs) @property def predicate(self): @@ -545,10 +541,7 @@ class HostFilterPolicy(LoadBalancingPolicy): return HostDistance.IGNORED def populate(self, cluster, hosts): - self._child_policy.populate( - cluster=cluster, - hosts=[h for h in hosts if self.predicate(h)] - ) + self._child_policy.populate(cluster=cluster, hosts=hosts) def make_query_plan(self, working_keyspace=None, query=None): """ diff --git a/tests/integration/long/test_failure_types.py b/tests/integration/long/test_failure_types.py index c39903ec..a8f69397 100644 --- a/tests/integration/long/test_failure_types.py +++ b/tests/integration/long/test_failure_types.py @@ -17,7 +17,7 @@ import sys,logging, traceback, time, re from cassandra import (ConsistencyLevel, OperationTimedOut, ReadTimeout, WriteTimeout, ReadFailure, WriteFailure, FunctionFailure, ProtocolVersion) from cassandra.cluster import Cluster, NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT -from cassandra.policies import WhiteListRoundRobinPolicy +from cassandra.policies import HostFilterPolicy, RoundRobinPolicy from cassandra.concurrent import execute_concurrent_with_args from cassandra.query import SimpleStatement from tests.integration import use_singledc, PROTOCOL_VERSION, get_cluster, setup_keyspace, remove_cluster, get_node @@ -327,7 +327,9 @@ class TimeoutTimerTest(unittest.TestCase): # self.node1, self.node2, self.node3 = get_cluster().nodes.values() node1 = ExecutionProfile( - load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1']) + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.1" + ) ) self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, execution_profiles={EXEC_PROFILE_DEFAULT: node1}) self.session = self.cluster.connect(wait_for_all_pools=True) diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 03ed4769..7acae6a1 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -29,7 +29,7 @@ from cassandra.cluster import Cluster, Session, NoHostAvailable, ExecutionProfil from cassandra.concurrent import execute_concurrent from cassandra.policies import (RoundRobinPolicy, ExponentialReconnectionPolicy, RetryPolicy, SimpleConvictionPolicy, HostDistance, - WhiteListRoundRobinPolicy, AddressTranslator) + AddressTranslator, HostFilterPolicy) from cassandra.pool import Host from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory @@ -477,7 +477,10 @@ class ClusterTests(unittest.TestCase): def test_refresh_schema_no_wait(self): contact_points = [CASSANDRA_IP] cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=10, - contact_points=contact_points, load_balancing_policy=WhiteListRoundRobinPolicy(contact_points)) + contact_points=contact_points, + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == CASSANDRA_IP + )) session = cluster.connect() schema_ver = session.execute("SELECT schema_version FROM system.local WHERE key='local'")[0][0] @@ -618,7 +621,7 @@ class ClusterTests(unittest.TestCase): try: result = future.get_query_trace(-1.0) # In case the result has time to come back before this timeout due to a race condition - check_trace(result) + self.check_trace(result) except TraceUnavailable: break else: @@ -630,7 +633,7 @@ class ClusterTests(unittest.TestCase): try: result = future.get_query_trace(max_wait=120) # In case the result has been set check the trace - check_trace(result) + self.check_trace(result) except TraceUnavailable: break else: @@ -774,7 +777,11 @@ class ClusterTests(unittest.TestCase): @test_category config_profiles """ query = "select release_version from system.local" - node1 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy([CASSANDRA_IP])) + node1 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == CASSANDRA_IP + ) + ) with Cluster(execution_profiles={'node1': node1}) as cluster: session = cluster.connect(wait_for_all_pools=True) @@ -925,8 +932,16 @@ class ClusterTests(unittest.TestCase): @test_category config_profiles """ - node1 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) - node2 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.2'])) + node1 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.1" + ) + ) + node2 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.2" + ) + ) with Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: node1, 'node2': node2}) as cluster: session = cluster.connect(wait_for_all_pools=True) pools = session.get_pool_state() @@ -935,7 +950,11 @@ class ClusterTests(unittest.TestCase): self.assertEqual(set(h.address for h in pools), set(('127.0.0.1', '127.0.0.2'))) # dynamically update pools on add - node3 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.3'])) + node3 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.3" + ) + ) cluster.add_execution_profile('node3', node3) pools = session.get_pool_state() self.assertEqual(set(h.address for h in pools), set(('127.0.0.1', '127.0.0.2', '127.0.0.3'))) @@ -953,14 +972,22 @@ class ClusterTests(unittest.TestCase): """ max_retry_count = 10 for i in range(max_retry_count): - node1 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) + node1 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == "127.0.0.1" + ) + ) with Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: node1}) as cluster: session = cluster.connect(wait_for_all_pools=True) pools = session.get_pool_state() self.assertGreater(len(cluster.metadata.all_hosts()), 2) self.assertEqual(set(h.address for h in pools), set(('127.0.0.1',))) - node2 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.2', '127.0.0.3'])) + node2 = ExecutionProfile( + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address in ["127.0.0.2", "127.0.0.3"] + ) + ) start = time.time() try: @@ -1030,7 +1057,9 @@ class TestAddressTranslation(unittest.TestCase): @local class ContextManagementTest(unittest.TestCase): - load_balancing_policy = WhiteListRoundRobinPolicy([CASSANDRA_IP]) + load_balancing_policy = HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address == CASSANDRA_IP + ) cluster_kwargs = {'execution_profiles': {EXEC_PROFILE_DEFAULT: ExecutionProfile(load_balancing_policy= load_balancing_policy)}, 'schema_metadata_enabled': False, @@ -1150,7 +1179,6 @@ class HostStateTest(unittest.TestCase): @local class DontPrepareOnIgnoredHostsTest(unittest.TestCase): - ignored_addresses = ['127.0.0.3'] ignore_node_3_policy = IgnoredHostPolicy(ignored_addresses) @@ -1189,7 +1217,8 @@ class DontPrepareOnIgnoredHostsTest(unittest.TestCase): @local class DuplicateRpcTest(unittest.TestCase): - load_balancing_policy = WhiteListRoundRobinPolicy(['127.0.0.1']) + load_balancing_policy = HostFilterPolicy(RoundRobinPolicy(), + lambda host: host.address == "127.0.0.1") def setUp(self): self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, load_balancing_policy=self.load_balancing_policy) diff --git a/tests/integration/standard/test_connection.py b/tests/integration/standard/test_connection.py index 4b9f73b5..4b0f8cab 100644 --- a/tests/integration/standard/test_connection.py +++ b/tests/integration/standard/test_connection.py @@ -29,7 +29,7 @@ from cassandra.cluster import NoHostAvailable, ConnectionShutdown, Cluster from cassandra.io.asyncorereactor import AsyncoreConnection from cassandra.protocol import QueryMessage from cassandra.connection import Connection -from cassandra.policies import WhiteListRoundRobinPolicy, HostStateListener +from cassandra.policies import HostFilterPolicy, RoundRobinPolicy, HostStateListener from cassandra.pool import HostConnectionPool from tests import is_monkey_patched, notwindows @@ -50,8 +50,12 @@ class ConnectionTimeoutTest(unittest.TestCase): def setUp(self): self.defaultInFlight = Connection.max_in_flight Connection.max_in_flight = 2 - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, load_balancing_policy= - WhiteListRoundRobinPolicy([CASSANDRA_IP])) + self.cluster = Cluster( + protocol_version=PROTOCOL_VERSION, + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), predicate=lambda host: host.address == CASSANDRA_IP + ) + ) self.session = self.cluster.connect() def tearDown(self): diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index 38113255..90af9ee4 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -14,7 +14,7 @@ import time -from cassandra.policies import WhiteListRoundRobinPolicy, FallthroughRetryPolicy +from cassandra.policies import HostFilterPolicy, RoundRobinPolicy, FallthroughRetryPolicy try: import unittest2 as unittest @@ -39,7 +39,9 @@ class MetricsTests(unittest.TestCase): def setUp(self): contact_point = ['127.0.0.2'] self.cluster = Cluster(contact_points=contact_point, metrics_enabled=True, protocol_version=PROTOCOL_VERSION, - load_balancing_policy=WhiteListRoundRobinPolicy(contact_point), + load_balancing_policy=HostFilterPolicy( + RoundRobinPolicy(), lambda host: host.address in contact_point + ), default_retry_policy=FallthroughRetryPolicy()) self.session = self.cluster.connect("test3rf", wait_for_all_pools=True)