Added tests for PYTHON-761
This commit is contained in:
@@ -504,20 +504,16 @@ class HostFilterPolicy(LoadBalancingPolicy):
|
||||
self._predicate = predicate
|
||||
|
||||
def on_up(self, host, *args, **kwargs):
|
||||
if self.predicate(host):
|
||||
return self._child_policy.on_up(host, *args, **kwargs)
|
||||
return self._child_policy.on_up(host, *args, **kwargs)
|
||||
|
||||
def on_down(self, host, *args, **kwargs):
|
||||
if self.predicate(host):
|
||||
return self._child_policy.on_down(host, *args, **kwargs)
|
||||
return self._child_policy.on_down(host, *args, **kwargs)
|
||||
|
||||
def on_add(self, host, *args, **kwargs):
|
||||
if self.predicate(host):
|
||||
return self._child_policy.on_add(host, *args, **kwargs)
|
||||
return self._child_policy.on_add(host, *args, **kwargs)
|
||||
|
||||
def on_remove(self, host, *args, **kwargs):
|
||||
if self.predicate(host):
|
||||
return self._child_policy.on_remove(host, *args, **kwargs)
|
||||
return self._child_policy.on_remove(host, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def predicate(self):
|
||||
@@ -547,7 +543,7 @@ class HostFilterPolicy(LoadBalancingPolicy):
|
||||
def populate(self, cluster, hosts):
|
||||
self._child_policy.populate(
|
||||
cluster=cluster,
|
||||
hosts=[h for h in hosts if self.predicate(h)]
|
||||
hosts=hosts
|
||||
)
|
||||
|
||||
def make_query_plan(self, working_keyspace=None, query=None):
|
||||
|
@@ -20,12 +20,14 @@ except ImportError:
|
||||
import unittest # noqa
|
||||
|
||||
from cassandra import OperationTimedOut
|
||||
from cassandra.cluster import ExecutionProfile
|
||||
from cassandra.cluster import ExecutionProfile, Cluster
|
||||
from cassandra.query import SimpleStatement
|
||||
from cassandra.policies import ConstantSpeculativeExecutionPolicy, RoundRobinPolicy
|
||||
from cassandra.policies import ConstantSpeculativeExecutionPolicy, HostFilterPolicy, RoundRobinPolicy, \
|
||||
SimpleConvictionPolicy
|
||||
from cassandra.connection import Connection
|
||||
from cassandra.pool import Host
|
||||
|
||||
from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthancass21
|
||||
from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthancass21, PROTOCOL_VERSION
|
||||
from tests import notwindows
|
||||
|
||||
from mock import patch
|
||||
@@ -47,6 +49,40 @@ class BadRoundRobinPolicy(RoundRobinPolicy):
|
||||
return hosts
|
||||
|
||||
|
||||
class HostFilterPolicyTests(unittest.TestCase):
|
||||
|
||||
def test_predicate_changes(self):
|
||||
restrict = True
|
||||
contact_point = "127.0.0.1"
|
||||
|
||||
single_host = {Host(contact_point, SimpleConvictionPolicy)}
|
||||
all_hosts = {Host("127.0.0.{}".format(i), SimpleConvictionPolicy) for i in (1, 2, 3)}
|
||||
|
||||
predicate = lambda host: host.address == contact_point if restrict else True
|
||||
cluster = Cluster((contact_point,), load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(),
|
||||
predicate=predicate),
|
||||
protocol_version=PROTOCOL_VERSION, topology_event_refresh_window=0,
|
||||
status_event_refresh_window=0)
|
||||
session = cluster.connect(wait_for_all_pools=True)
|
||||
|
||||
queried_hosts = set()
|
||||
for _ in range(100):
|
||||
response = session.execute("SELECT * from system.local")
|
||||
queried_hosts.update(response.response_future.attempted_hosts)
|
||||
|
||||
self.assertEqual(queried_hosts, single_host)
|
||||
|
||||
restrict = False
|
||||
session.update_created_pools()
|
||||
|
||||
queried_hosts = set()
|
||||
for _ in range(100):
|
||||
response = session.execute("SELECT * from system.local")
|
||||
print(response.response_future.attempted_hosts)
|
||||
queried_hosts.update(response.response_future.attempted_hosts)
|
||||
self.assertEqual(queried_hosts, all_hosts)
|
||||
|
||||
|
||||
# This doesn't work well with Windows clock granularity
|
||||
@notwindows
|
||||
class SpecExecTest(BasicSharedKeyspaceUnitTestCase):
|
||||
|
@@ -1402,13 +1402,13 @@ class HostFilterPolicyPopulateTest(unittest.TestCase):
|
||||
hosts=hosts
|
||||
)
|
||||
|
||||
def test_child_not_populated_with_filtered_hosts(self):
|
||||
def test_child_is_populated_with_filtered_hosts(self):
|
||||
hfp = HostFilterPolicy(
|
||||
child_policy=Mock(name='child_policy'),
|
||||
predicate=lambda host: 'acceptme' in host
|
||||
predicate=lambda host: False
|
||||
)
|
||||
mock_cluster, hosts = (Mock(name='cluster'),
|
||||
['acceptme0', 'ignoreme0', 'ignoreme1', 'acceptme1'])
|
||||
['acceptme0', 'acceptme1'])
|
||||
hfp.populate(mock_cluster, hosts)
|
||||
hfp._child_policy.populate.assert_called_once()
|
||||
self.assertEqual(
|
||||
|
Reference in New Issue
Block a user