Added tests for PYTHON-761
This commit is contained in:
@@ -504,20 +504,16 @@ class HostFilterPolicy(LoadBalancingPolicy):
|
|||||||
self._predicate = predicate
|
self._predicate = predicate
|
||||||
|
|
||||||
def on_up(self, host, *args, **kwargs):
|
def on_up(self, host, *args, **kwargs):
|
||||||
if self.predicate(host):
|
return self._child_policy.on_up(host, *args, **kwargs)
|
||||||
return self._child_policy.on_up(host, *args, **kwargs)
|
|
||||||
|
|
||||||
def on_down(self, host, *args, **kwargs):
|
def on_down(self, host, *args, **kwargs):
|
||||||
if self.predicate(host):
|
return self._child_policy.on_down(host, *args, **kwargs)
|
||||||
return self._child_policy.on_down(host, *args, **kwargs)
|
|
||||||
|
|
||||||
def on_add(self, host, *args, **kwargs):
|
def on_add(self, host, *args, **kwargs):
|
||||||
if self.predicate(host):
|
return self._child_policy.on_add(host, *args, **kwargs)
|
||||||
return self._child_policy.on_add(host, *args, **kwargs)
|
|
||||||
|
|
||||||
def on_remove(self, host, *args, **kwargs):
|
def on_remove(self, host, *args, **kwargs):
|
||||||
if self.predicate(host):
|
return self._child_policy.on_remove(host, *args, **kwargs)
|
||||||
return self._child_policy.on_remove(host, *args, **kwargs)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def predicate(self):
|
def predicate(self):
|
||||||
@@ -547,7 +543,7 @@ class HostFilterPolicy(LoadBalancingPolicy):
|
|||||||
def populate(self, cluster, hosts):
|
def populate(self, cluster, hosts):
|
||||||
self._child_policy.populate(
|
self._child_policy.populate(
|
||||||
cluster=cluster,
|
cluster=cluster,
|
||||||
hosts=[h for h in hosts if self.predicate(h)]
|
hosts=hosts
|
||||||
)
|
)
|
||||||
|
|
||||||
def make_query_plan(self, working_keyspace=None, query=None):
|
def make_query_plan(self, working_keyspace=None, query=None):
|
||||||
|
|||||||
@@ -20,12 +20,14 @@ except ImportError:
|
|||||||
import unittest # noqa
|
import unittest # noqa
|
||||||
|
|
||||||
from cassandra import OperationTimedOut
|
from cassandra import OperationTimedOut
|
||||||
from cassandra.cluster import ExecutionProfile
|
from cassandra.cluster import ExecutionProfile, Cluster
|
||||||
from cassandra.query import SimpleStatement
|
from cassandra.query import SimpleStatement
|
||||||
from cassandra.policies import ConstantSpeculativeExecutionPolicy, RoundRobinPolicy
|
from cassandra.policies import ConstantSpeculativeExecutionPolicy, HostFilterPolicy, RoundRobinPolicy, \
|
||||||
|
SimpleConvictionPolicy
|
||||||
from cassandra.connection import Connection
|
from cassandra.connection import Connection
|
||||||
|
from cassandra.pool import Host
|
||||||
|
|
||||||
from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthancass21
|
from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthancass21, PROTOCOL_VERSION
|
||||||
from tests import notwindows
|
from tests import notwindows
|
||||||
|
|
||||||
from mock import patch
|
from mock import patch
|
||||||
@@ -47,6 +49,40 @@ class BadRoundRobinPolicy(RoundRobinPolicy):
|
|||||||
return hosts
|
return hosts
|
||||||
|
|
||||||
|
|
||||||
|
class HostFilterPolicyTests(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_predicate_changes(self):
|
||||||
|
restrict = True
|
||||||
|
contact_point = "127.0.0.1"
|
||||||
|
|
||||||
|
single_host = {Host(contact_point, SimpleConvictionPolicy)}
|
||||||
|
all_hosts = {Host("127.0.0.{}".format(i), SimpleConvictionPolicy) for i in (1, 2, 3)}
|
||||||
|
|
||||||
|
predicate = lambda host: host.address == contact_point if restrict else True
|
||||||
|
cluster = Cluster((contact_point,), load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(),
|
||||||
|
predicate=predicate),
|
||||||
|
protocol_version=PROTOCOL_VERSION, topology_event_refresh_window=0,
|
||||||
|
status_event_refresh_window=0)
|
||||||
|
session = cluster.connect(wait_for_all_pools=True)
|
||||||
|
|
||||||
|
queried_hosts = set()
|
||||||
|
for _ in range(100):
|
||||||
|
response = session.execute("SELECT * from system.local")
|
||||||
|
queried_hosts.update(response.response_future.attempted_hosts)
|
||||||
|
|
||||||
|
self.assertEqual(queried_hosts, single_host)
|
||||||
|
|
||||||
|
restrict = False
|
||||||
|
session.update_created_pools()
|
||||||
|
|
||||||
|
queried_hosts = set()
|
||||||
|
for _ in range(100):
|
||||||
|
response = session.execute("SELECT * from system.local")
|
||||||
|
print(response.response_future.attempted_hosts)
|
||||||
|
queried_hosts.update(response.response_future.attempted_hosts)
|
||||||
|
self.assertEqual(queried_hosts, all_hosts)
|
||||||
|
|
||||||
|
|
||||||
# This doesn't work well with Windows clock granularity
|
# This doesn't work well with Windows clock granularity
|
||||||
@notwindows
|
@notwindows
|
||||||
class SpecExecTest(BasicSharedKeyspaceUnitTestCase):
|
class SpecExecTest(BasicSharedKeyspaceUnitTestCase):
|
||||||
|
|||||||
@@ -1402,13 +1402,13 @@ class HostFilterPolicyPopulateTest(unittest.TestCase):
|
|||||||
hosts=hosts
|
hosts=hosts
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_child_not_populated_with_filtered_hosts(self):
|
def test_child_is_populated_with_filtered_hosts(self):
|
||||||
hfp = HostFilterPolicy(
|
hfp = HostFilterPolicy(
|
||||||
child_policy=Mock(name='child_policy'),
|
child_policy=Mock(name='child_policy'),
|
||||||
predicate=lambda host: 'acceptme' in host
|
predicate=lambda host: False
|
||||||
)
|
)
|
||||||
mock_cluster, hosts = (Mock(name='cluster'),
|
mock_cluster, hosts = (Mock(name='cluster'),
|
||||||
['acceptme0', 'ignoreme0', 'ignoreme1', 'acceptme1'])
|
['acceptme0', 'acceptme1'])
|
||||||
hfp.populate(mock_cluster, hosts)
|
hfp.populate(mock_cluster, hosts)
|
||||||
hfp._child_policy.populate.assert_called_once()
|
hfp._child_policy.populate.assert_called_once()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
|||||||
Reference in New Issue
Block a user