Added tests for PYTHON-761

This commit is contained in:
bjmb
2017-07-13 13:28:49 -04:00
parent 852d39e6c0
commit 186e4092ce
3 changed files with 47 additions and 15 deletions

View File

@@ -504,20 +504,16 @@ class HostFilterPolicy(LoadBalancingPolicy):
self._predicate = predicate self._predicate = predicate
def on_up(self, host, *args, **kwargs): def on_up(self, host, *args, **kwargs):
if self.predicate(host): return self._child_policy.on_up(host, *args, **kwargs)
return self._child_policy.on_up(host, *args, **kwargs)
def on_down(self, host, *args, **kwargs): def on_down(self, host, *args, **kwargs):
if self.predicate(host): return self._child_policy.on_down(host, *args, **kwargs)
return self._child_policy.on_down(host, *args, **kwargs)
def on_add(self, host, *args, **kwargs): def on_add(self, host, *args, **kwargs):
if self.predicate(host): return self._child_policy.on_add(host, *args, **kwargs)
return self._child_policy.on_add(host, *args, **kwargs)
def on_remove(self, host, *args, **kwargs): def on_remove(self, host, *args, **kwargs):
if self.predicate(host): return self._child_policy.on_remove(host, *args, **kwargs)
return self._child_policy.on_remove(host, *args, **kwargs)
@property @property
def predicate(self): def predicate(self):
@@ -547,7 +543,7 @@ class HostFilterPolicy(LoadBalancingPolicy):
def populate(self, cluster, hosts): def populate(self, cluster, hosts):
self._child_policy.populate( self._child_policy.populate(
cluster=cluster, cluster=cluster,
hosts=[h for h in hosts if self.predicate(h)] hosts=hosts
) )
def make_query_plan(self, working_keyspace=None, query=None): def make_query_plan(self, working_keyspace=None, query=None):

View File

@@ -20,12 +20,14 @@ except ImportError:
import unittest # noqa import unittest # noqa
from cassandra import OperationTimedOut from cassandra import OperationTimedOut
from cassandra.cluster import ExecutionProfile from cassandra.cluster import ExecutionProfile, Cluster
from cassandra.query import SimpleStatement from cassandra.query import SimpleStatement
from cassandra.policies import ConstantSpeculativeExecutionPolicy, RoundRobinPolicy from cassandra.policies import ConstantSpeculativeExecutionPolicy, HostFilterPolicy, RoundRobinPolicy, \
SimpleConvictionPolicy
from cassandra.connection import Connection from cassandra.connection import Connection
from cassandra.pool import Host
from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthancass21 from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthancass21, PROTOCOL_VERSION
from tests import notwindows from tests import notwindows
from mock import patch from mock import patch
@@ -47,6 +49,40 @@ class BadRoundRobinPolicy(RoundRobinPolicy):
return hosts return hosts
class HostFilterPolicyTests(unittest.TestCase):
def test_predicate_changes(self):
restrict = True
contact_point = "127.0.0.1"
single_host = {Host(contact_point, SimpleConvictionPolicy)}
all_hosts = {Host("127.0.0.{}".format(i), SimpleConvictionPolicy) for i in (1, 2, 3)}
predicate = lambda host: host.address == contact_point if restrict else True
cluster = Cluster((contact_point,), load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(),
predicate=predicate),
protocol_version=PROTOCOL_VERSION, topology_event_refresh_window=0,
status_event_refresh_window=0)
session = cluster.connect(wait_for_all_pools=True)
queried_hosts = set()
for _ in range(100):
response = session.execute("SELECT * from system.local")
queried_hosts.update(response.response_future.attempted_hosts)
self.assertEqual(queried_hosts, single_host)
restrict = False
session.update_created_pools()
queried_hosts = set()
for _ in range(100):
response = session.execute("SELECT * from system.local")
print(response.response_future.attempted_hosts)
queried_hosts.update(response.response_future.attempted_hosts)
self.assertEqual(queried_hosts, all_hosts)
# This doesn't work well with Windows clock granularity # This doesn't work well with Windows clock granularity
@notwindows @notwindows
class SpecExecTest(BasicSharedKeyspaceUnitTestCase): class SpecExecTest(BasicSharedKeyspaceUnitTestCase):

View File

@@ -1402,13 +1402,13 @@ class HostFilterPolicyPopulateTest(unittest.TestCase):
hosts=hosts hosts=hosts
) )
def test_child_not_populated_with_filtered_hosts(self): def test_child_is_populated_with_filtered_hosts(self):
hfp = HostFilterPolicy( hfp = HostFilterPolicy(
child_policy=Mock(name='child_policy'), child_policy=Mock(name='child_policy'),
predicate=lambda host: 'acceptme' in host predicate=lambda host: False
) )
mock_cluster, hosts = (Mock(name='cluster'), mock_cluster, hosts = (Mock(name='cluster'),
['acceptme0', 'ignoreme0', 'ignoreme1', 'acceptme1']) ['acceptme0', 'acceptme1'])
hfp.populate(mock_cluster, hosts) hfp.populate(mock_cluster, hosts)
hfp._child_policy.populate.assert_called_once() hfp._child_policy.populate.assert_called_once()
self.assertEqual( self.assertEqual(