Added tests for PYTHON-761

This commit is contained in:
bjmb
2017-07-13 13:28:49 -04:00
parent 852d39e6c0
commit 186e4092ce
3 changed files with 47 additions and 15 deletions

View File

@@ -504,20 +504,16 @@ class HostFilterPolicy(LoadBalancingPolicy):
self._predicate = predicate
def on_up(self, host, *args, **kwargs):
if self.predicate(host):
return self._child_policy.on_up(host, *args, **kwargs)
return self._child_policy.on_up(host, *args, **kwargs)
def on_down(self, host, *args, **kwargs):
if self.predicate(host):
return self._child_policy.on_down(host, *args, **kwargs)
return self._child_policy.on_down(host, *args, **kwargs)
def on_add(self, host, *args, **kwargs):
if self.predicate(host):
return self._child_policy.on_add(host, *args, **kwargs)
return self._child_policy.on_add(host, *args, **kwargs)
def on_remove(self, host, *args, **kwargs):
if self.predicate(host):
return self._child_policy.on_remove(host, *args, **kwargs)
return self._child_policy.on_remove(host, *args, **kwargs)
@property
def predicate(self):
@@ -547,7 +543,7 @@ class HostFilterPolicy(LoadBalancingPolicy):
def populate(self, cluster, hosts):
self._child_policy.populate(
cluster=cluster,
hosts=[h for h in hosts if self.predicate(h)]
hosts=hosts
)
def make_query_plan(self, working_keyspace=None, query=None):

View File

@@ -20,12 +20,14 @@ except ImportError:
import unittest # noqa
from cassandra import OperationTimedOut
from cassandra.cluster import ExecutionProfile
from cassandra.cluster import ExecutionProfile, Cluster
from cassandra.query import SimpleStatement
from cassandra.policies import ConstantSpeculativeExecutionPolicy, RoundRobinPolicy
from cassandra.policies import ConstantSpeculativeExecutionPolicy, HostFilterPolicy, RoundRobinPolicy, \
SimpleConvictionPolicy
from cassandra.connection import Connection
from cassandra.pool import Host
from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthancass21
from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthancass21, PROTOCOL_VERSION
from tests import notwindows
from mock import patch
@@ -47,6 +49,40 @@ class BadRoundRobinPolicy(RoundRobinPolicy):
return hosts
class HostFilterPolicyTests(unittest.TestCase):
def test_predicate_changes(self):
restrict = True
contact_point = "127.0.0.1"
single_host = {Host(contact_point, SimpleConvictionPolicy)}
all_hosts = {Host("127.0.0.{}".format(i), SimpleConvictionPolicy) for i in (1, 2, 3)}
predicate = lambda host: host.address == contact_point if restrict else True
cluster = Cluster((contact_point,), load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(),
predicate=predicate),
protocol_version=PROTOCOL_VERSION, topology_event_refresh_window=0,
status_event_refresh_window=0)
session = cluster.connect(wait_for_all_pools=True)
queried_hosts = set()
for _ in range(100):
response = session.execute("SELECT * from system.local")
queried_hosts.update(response.response_future.attempted_hosts)
self.assertEqual(queried_hosts, single_host)
restrict = False
session.update_created_pools()
queried_hosts = set()
for _ in range(100):
response = session.execute("SELECT * from system.local")
print(response.response_future.attempted_hosts)
queried_hosts.update(response.response_future.attempted_hosts)
self.assertEqual(queried_hosts, all_hosts)
# This doesn't work well with Windows clock granularity
@notwindows
class SpecExecTest(BasicSharedKeyspaceUnitTestCase):

View File

@@ -1402,13 +1402,13 @@ class HostFilterPolicyPopulateTest(unittest.TestCase):
hosts=hosts
)
def test_child_not_populated_with_filtered_hosts(self):
def test_child_is_populated_with_filtered_hosts(self):
hfp = HostFilterPolicy(
child_policy=Mock(name='child_policy'),
predicate=lambda host: 'acceptme' in host
predicate=lambda host: False
)
mock_cluster, hosts = (Mock(name='cluster'),
['acceptme0', 'ignoreme0', 'ignoreme1', 'acceptme1'])
['acceptme0', 'acceptme1'])
hfp.populate(mock_cluster, hosts)
hfp._child_policy.populate.assert_called_once()
self.assertEqual(