Merge pull request #778 from datastax/python-host-filter-policy_PYTHON-761
Python host filter policy PYTHON-761
This commit is contained in:
		@@ -4,6 +4,7 @@
 | 
			
		||||
Features
 | 
			
		||||
--------
 | 
			
		||||
* Add idle_heartbeat_timeout cluster option to tune how long to wait for heartbeat responses. (PYTHON-762)
 | 
			
		||||
* Add HostFilterPolicy (PYTHON-761)
 | 
			
		||||
 | 
			
		||||
Bug Fixes
 | 
			
		||||
---------
 | 
			
		||||
@@ -20,6 +21,10 @@ Other
 | 
			
		||||
* Bump Cython dependency version to 0.25.2 (PYTHON-754)
 | 
			
		||||
* Fix DeprecationWarning when using lz4 (PYTHON-769)
 | 
			
		||||
 | 
			
		||||
Other
 | 
			
		||||
-----
 | 
			
		||||
* Deprecate WhiteListRoundRobinPolicy (PYTHON-759)
 | 
			
		||||
 | 
			
		||||
3.10.0
 | 
			
		||||
======
 | 
			
		||||
May 24, 2017
 | 
			
		||||
 
 | 
			
		||||
@@ -17,6 +17,7 @@ import logging
 | 
			
		||||
from random import randint, shuffle
 | 
			
		||||
from threading import Lock
 | 
			
		||||
import socket
 | 
			
		||||
from warnings import warn
 | 
			
		||||
 | 
			
		||||
from cassandra import ConsistencyLevel, OperationTimedOut
 | 
			
		||||
 | 
			
		||||
@@ -396,6 +397,10 @@ class TokenAwarePolicy(LoadBalancingPolicy):
 | 
			
		||||
 | 
			
		||||
class WhiteListRoundRobinPolicy(RoundRobinPolicy):
 | 
			
		||||
    """
 | 
			
		||||
    |wlrrp| **is deprecated. It will be removed in 4.0.** It can effectively be
 | 
			
		||||
    reimplemented using :class:`.HostFilterPolicy`. For more information, see
 | 
			
		||||
    PYTHON-758_.
 | 
			
		||||
 | 
			
		||||
    A subclass of :class:`.RoundRobinPolicy` which evenly
 | 
			
		||||
    distributes queries across all nodes in the cluster,
 | 
			
		||||
    regardless of what datacenter the nodes may be in, but
 | 
			
		||||
@@ -405,12 +410,25 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy):
 | 
			
		||||
    https://datastax-oss.atlassian.net/browse/JAVA-145
 | 
			
		||||
    Where connection errors occur when connection
 | 
			
		||||
    attempts are made to private IP addresses remotely
 | 
			
		||||
 | 
			
		||||
    .. |wlrrp| raw:: html
 | 
			
		||||
 | 
			
		||||
       <b><code>WhiteListRoundRobinPolicy</code></b>
 | 
			
		||||
 | 
			
		||||
    .. _PYTHON-758: https://datastax-oss.atlassian.net/browse/PYTHON-758
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, hosts):
 | 
			
		||||
        """
 | 
			
		||||
        The `hosts` parameter should be a sequence of hosts to permit
 | 
			
		||||
        connections to.
 | 
			
		||||
        """
 | 
			
		||||
        msg = ('WhiteListRoundRobinPolicy is deprecated. '
 | 
			
		||||
               'It will be removed in 4.0. '
 | 
			
		||||
               'It can effectively be reimplemented using HostFilterPolicy.')
 | 
			
		||||
        warn(msg, DeprecationWarning)
 | 
			
		||||
        # DeprecationWarnings are silent by default so we also log the message
 | 
			
		||||
        log.warning(msg)
 | 
			
		||||
 | 
			
		||||
        self._allowed_hosts = hosts
 | 
			
		||||
        self._allowed_hosts_resolved = [endpoint[4][0] for a in self._allowed_hosts
 | 
			
		||||
@@ -441,6 +459,116 @@ class WhiteListRoundRobinPolicy(RoundRobinPolicy):
 | 
			
		||||
            RoundRobinPolicy.on_add(self, host)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HostFilterPolicy(LoadBalancingPolicy):
 | 
			
		||||
    """
 | 
			
		||||
    A :class:`.LoadBalancingPolicy` subclass configured with a child policy,
 | 
			
		||||
    and a single-argument predicate. This policy defers to the child policy for
 | 
			
		||||
    hosts where ``predicate(host)`` is truthy. Hosts for which
 | 
			
		||||
    ``predicate(host)`` is falsey will be considered :attr:`.IGNORED`, and will
 | 
			
		||||
    not be used in a query plan.
 | 
			
		||||
 | 
			
		||||
    This can be used in the cases where you need a whitelist or blacklist
 | 
			
		||||
    policy, e.g. to prepare for decommissioning nodes or for testing:
 | 
			
		||||
 | 
			
		||||
    .. code-block:: python
 | 
			
		||||
 | 
			
		||||
        def address_is_ignored(host):
 | 
			
		||||
            return host.address in [ignored_address0, ignored_address1]
 | 
			
		||||
 | 
			
		||||
        blacklist_filter_policy = HostFilterPolicy(
 | 
			
		||||
            child_policy=RoundRobinPolicy(),
 | 
			
		||||
            predicate=address_is_ignored
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        cluster = Cluster(
 | 
			
		||||
            primary_host,
 | 
			
		||||
            load_balancing_policy=blacklist_filter_policy,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    Please note that whitelist and blacklist policies are not recommended for
 | 
			
		||||
    general, day-to-day use. You probably want something like
 | 
			
		||||
    :class:`.DCAwareRoundRobinPolicy`, which prefers a local DC but has
 | 
			
		||||
    fallbacks, over a brute-force method like whitelisting or blacklisting.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, child_policy, predicate):
 | 
			
		||||
        """
 | 
			
		||||
        :param child_policy: an instantiated :class:`.LoadBalancingPolicy`
 | 
			
		||||
                             that this one will defer to.
 | 
			
		||||
        :param predicate: a one-parameter function that takes a :class:`.Host`.
 | 
			
		||||
                          If it returns a falsey value, the :class:`.Host` will
 | 
			
		||||
                          be :attr:`.IGNORED` and not returned in query plans.
 | 
			
		||||
        """
 | 
			
		||||
        super(HostFilterPolicy, self).__init__()
 | 
			
		||||
        self._child_policy = child_policy
 | 
			
		||||
        self._predicate = predicate
 | 
			
		||||
 | 
			
		||||
    def on_up(self, host, *args, **kwargs):
 | 
			
		||||
        if self.predicate(host):
 | 
			
		||||
            return self._child_policy.on_up(host, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def on_down(self, host, *args, **kwargs):
 | 
			
		||||
        if self.predicate(host):
 | 
			
		||||
            return self._child_policy.on_down(host, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def on_add(self, host, *args, **kwargs):
 | 
			
		||||
        if self.predicate(host):
 | 
			
		||||
            return self._child_policy.on_add(host, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def on_remove(self, host, *args, **kwargs):
 | 
			
		||||
        if self.predicate(host):
 | 
			
		||||
            return self._child_policy.on_remove(host, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def predicate(self):
 | 
			
		||||
        """
 | 
			
		||||
        A predicate, set on object initialization, that takes a :class:`.Host`
 | 
			
		||||
        and returns a value. If the value is falsy, the :class:`.Host` is
 | 
			
		||||
        :class:`~HostDistance.IGNORED`. If the value is truthy,
 | 
			
		||||
        :class:`.HostFilterPolicy` defers to the child policy to determine the
 | 
			
		||||
        host's distance.
 | 
			
		||||
 | 
			
		||||
        This is a read-only value set in ``__init__``, implemented as a
 | 
			
		||||
        ``property``.
 | 
			
		||||
        """
 | 
			
		||||
        return self._predicate
 | 
			
		||||
 | 
			
		||||
    def distance(self, host):
 | 
			
		||||
        """
 | 
			
		||||
        Checks if ``predicate(host)``, then returns
 | 
			
		||||
        :attr:`~HostDistance.IGNORED` if falsey, and defers to the child policy
 | 
			
		||||
        otherwise.
 | 
			
		||||
        """
 | 
			
		||||
        if self.predicate(host):
 | 
			
		||||
            return self._child_policy.distance(host)
 | 
			
		||||
        else:
 | 
			
		||||
            return HostDistance.IGNORED
 | 
			
		||||
 | 
			
		||||
    def populate(self, cluster, hosts):
 | 
			
		||||
        self._child_policy.populate(
 | 
			
		||||
            cluster=cluster,
 | 
			
		||||
            hosts=[h for h in hosts if self.predicate(h)]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def make_query_plan(self, working_keyspace=None, query=None):
 | 
			
		||||
        """
 | 
			
		||||
        Defers to the child policy's
 | 
			
		||||
        :meth:`.LoadBalancingPolicy.make_query_plan`. Since host changes (up,
 | 
			
		||||
        down, addition, and removal) have not been propagated to the child
 | 
			
		||||
        policy, the child policy will only ever return policies for which
 | 
			
		||||
        :meth:`.predicate(host)` was truthy when that change occurred.
 | 
			
		||||
        """
 | 
			
		||||
        child_qp = self._child_policy.make_query_plan(
 | 
			
		||||
            working_keyspace=working_keyspace, query=query
 | 
			
		||||
        )
 | 
			
		||||
        for host in child_qp:
 | 
			
		||||
            if self.predicate(host):
 | 
			
		||||
                yield host
 | 
			
		||||
 | 
			
		||||
    def check_supported(self):
 | 
			
		||||
        return self._child_policy.check_supported()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ConvictionPolicy(object):
 | 
			
		||||
    """
 | 
			
		||||
    A policy which decides when hosts should be considered down
 | 
			
		||||
@@ -619,6 +747,7 @@ class WriteType(object):
 | 
			
		||||
    A lighweight-transaction write, such as "DELETE ... IF EXISTS".
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
WriteType.name_to_value = {
 | 
			
		||||
    'SIMPLE': WriteType.SIMPLE,
 | 
			
		||||
    'BATCH': WriteType.BATCH,
 | 
			
		||||
 
 | 
			
		||||
@@ -24,6 +24,14 @@ Load Balancing
 | 
			
		||||
.. autoclass:: TokenAwarePolicy
 | 
			
		||||
   :members:
 | 
			
		||||
 | 
			
		||||
.. autoclass:: HostFilterPolicy
 | 
			
		||||
 | 
			
		||||
   # we document these methods manually so we can specify a param to predicate
 | 
			
		||||
 | 
			
		||||
   .. automethod:: predicate(host)
 | 
			
		||||
   .. automethod:: distance
 | 
			
		||||
   .. automethod:: make_query_plan
 | 
			
		||||
 | 
			
		||||
Translating Server Node Addresses
 | 
			
		||||
---------------------------------
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -12,15 +12,19 @@
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import struct, time, logging, sys, traceback
 | 
			
		||||
import logging
 | 
			
		||||
import struct
 | 
			
		||||
import sys
 | 
			
		||||
import traceback
 | 
			
		||||
 | 
			
		||||
from cassandra import ConsistencyLevel, Unavailable, OperationTimedOut, ReadTimeout, ReadFailure, \
 | 
			
		||||
    WriteTimeout, WriteFailure
 | 
			
		||||
from cassandra.cluster import Cluster, NoHostAvailable, ExecutionProfile
 | 
			
		||||
from cassandra.cluster import Cluster, NoHostAvailable
 | 
			
		||||
from cassandra.concurrent import execute_concurrent_with_args
 | 
			
		||||
from cassandra.metadata import murmur3
 | 
			
		||||
from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
 | 
			
		||||
                                TokenAwarePolicy, WhiteListRoundRobinPolicy)
 | 
			
		||||
                                TokenAwarePolicy, WhiteListRoundRobinPolicy,
 | 
			
		||||
                                HostFilterPolicy)
 | 
			
		||||
from cassandra.query import SimpleStatement
 | 
			
		||||
 | 
			
		||||
from tests.integration import use_singledc, use_multidc, remove_cluster, PROTOCOL_VERSION
 | 
			
		||||
@@ -105,7 +109,7 @@ class LoadBalancingPolicyTests(unittest.TestCase):
 | 
			
		||||
            query_string = 'SELECT * FROM %s.cf WHERE k = ?' % keyspace
 | 
			
		||||
            if not self.prepared or self.prepared.query_string != query_string:
 | 
			
		||||
                self.prepared = session.prepare(query_string)
 | 
			
		||||
                self.prepared.consistency_level=consistency_level
 | 
			
		||||
                self.prepared.consistency_level = consistency_level
 | 
			
		||||
            for i in range(count):
 | 
			
		||||
                tries = 0
 | 
			
		||||
                while True:
 | 
			
		||||
@@ -508,7 +512,7 @@ class LoadBalancingPolicyTests(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
        self.coordinator_stats.reset_counts()
 | 
			
		||||
        stop(2)
 | 
			
		||||
        self._wait_for_nodes_down([2],cluster)
 | 
			
		||||
        self._wait_for_nodes_down([2], cluster)
 | 
			
		||||
 | 
			
		||||
        self._query(session, keyspace)
 | 
			
		||||
 | 
			
		||||
@@ -662,3 +666,37 @@ class LoadBalancingPolicyTests(unittest.TestCase):
 | 
			
		||||
            pass
 | 
			
		||||
        finally:
 | 
			
		||||
            cluster.shutdown()
 | 
			
		||||
 | 
			
		||||
    def test_black_list_with_host_filter_policy(self):
 | 
			
		||||
        use_singledc()
 | 
			
		||||
        keyspace = 'test_black_list_with_hfp'
 | 
			
		||||
        ignored_address = (IP_FORMAT % 2)
 | 
			
		||||
        hfp = HostFilterPolicy(
 | 
			
		||||
            child_policy=RoundRobinPolicy(),
 | 
			
		||||
            predicate=lambda host: host.address != ignored_address
 | 
			
		||||
        )
 | 
			
		||||
        cluster = Cluster(
 | 
			
		||||
            (IP_FORMAT % 1,),
 | 
			
		||||
            load_balancing_policy=hfp,
 | 
			
		||||
            protocol_version=PROTOCOL_VERSION,
 | 
			
		||||
            topology_event_refresh_window=0,
 | 
			
		||||
            status_event_refresh_window=0
 | 
			
		||||
        )
 | 
			
		||||
        self.addCleanup(cluster.shutdown)
 | 
			
		||||
        session = cluster.connect()
 | 
			
		||||
        self._wait_for_nodes_up([1, 2, 3])
 | 
			
		||||
 | 
			
		||||
        self.assertNotIn(ignored_address, [h.address for h in hfp.make_query_plan()])
 | 
			
		||||
 | 
			
		||||
        create_schema(cluster, session, keyspace)
 | 
			
		||||
        self._insert(session, keyspace)
 | 
			
		||||
        self._query(session, keyspace)
 | 
			
		||||
 | 
			
		||||
        self.coordinator_stats.assert_query_count_equals(self, 1, 6)
 | 
			
		||||
        self.coordinator_stats.assert_query_count_equals(self, 2, 0)
 | 
			
		||||
        self.coordinator_stats.assert_query_count_equals(self, 3, 6)
 | 
			
		||||
 | 
			
		||||
        # policy should not allow reconnecting to ignored host
 | 
			
		||||
        force_stop(2)
 | 
			
		||||
        self._wait_for_nodes_down([2])
 | 
			
		||||
        self.assertFalse(cluster.metadata._hosts[ignored_address].is_currently_reconnecting())
 | 
			
		||||
 
 | 
			
		||||
@@ -18,9 +18,10 @@ except ImportError:
 | 
			
		||||
    import unittest  # noqa
 | 
			
		||||
 | 
			
		||||
from itertools import islice, cycle
 | 
			
		||||
from mock import Mock, patch
 | 
			
		||||
from mock import Mock, patch, call
 | 
			
		||||
from random import randint
 | 
			
		||||
import six
 | 
			
		||||
from six.moves._thread import LockType
 | 
			
		||||
import sys
 | 
			
		||||
import struct
 | 
			
		||||
from threading import Thread
 | 
			
		||||
@@ -34,7 +35,7 @@ from cassandra.policies import (RoundRobinPolicy, WhiteListRoundRobinPolicy, DCA
 | 
			
		||||
                                RetryPolicy, WriteType,
 | 
			
		||||
                                DowngradingConsistencyRetryPolicy, ConstantReconnectionPolicy,
 | 
			
		||||
                                LoadBalancingPolicy, ConvictionPolicy, ReconnectionPolicy, FallthroughRetryPolicy,
 | 
			
		||||
                                IdentityTranslator, EC2MultiRegionTranslator)
 | 
			
		||||
                                IdentityTranslator, EC2MultiRegionTranslator, HostFilterPolicy)
 | 
			
		||||
from cassandra.pool import Host
 | 
			
		||||
from cassandra.query import Statement
 | 
			
		||||
 | 
			
		||||
@@ -421,7 +422,6 @@ class DCAwareRoundRobinPolicyTest(unittest.TestCase):
 | 
			
		||||
        policy.on_up(hosts[2])
 | 
			
		||||
        policy.on_up(hosts[3])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        another_host = Host(5, SimpleConvictionPolicy)
 | 
			
		||||
        another_host.set_location_info("dc3", "rack1")
 | 
			
		||||
        new_host.set_location_info("dc3", "rack1")
 | 
			
		||||
@@ -884,7 +884,7 @@ class ExponentialReconnectionPolicyTest(unittest.TestCase):
 | 
			
		||||
        self.assertRaises(ValueError, ExponentialReconnectionPolicy, -1, 0)
 | 
			
		||||
        self.assertRaises(ValueError, ExponentialReconnectionPolicy, 0, -1)
 | 
			
		||||
        self.assertRaises(ValueError, ExponentialReconnectionPolicy, 9000, 1)
 | 
			
		||||
        self.assertRaises(ValueError, ExponentialReconnectionPolicy, 1, 2,-1)
 | 
			
		||||
        self.assertRaises(ValueError, ExponentialReconnectionPolicy, 1, 2, -1)
 | 
			
		||||
 | 
			
		||||
    def test_schedule_no_max(self):
 | 
			
		||||
        base_delay = 2.0
 | 
			
		||||
@@ -1232,11 +1232,27 @@ class WhiteListRoundRobinPolicyTest(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(policy.distance(host), HostDistance.LOCAL)
 | 
			
		||||
 | 
			
		||||
    def test_deprecated(self):
 | 
			
		||||
        import warnings
 | 
			
		||||
 | 
			
		||||
        warnings.resetwarnings()  # in case we've instantiated one before
 | 
			
		||||
 | 
			
		||||
        # set up warning filters to allow all, set up restore when this test is done
 | 
			
		||||
        filters_backup, warnings.filters = warnings.filters, []
 | 
			
		||||
        self.addCleanup(setattr, warnings, 'filters', filters_backup)
 | 
			
		||||
 | 
			
		||||
        with warnings.catch_warnings(record=True) as caught_warnings:
 | 
			
		||||
            WhiteListRoundRobinPolicy([])
 | 
			
		||||
            self.assertEqual(len(caught_warnings), 1)
 | 
			
		||||
            warning_message = caught_warnings[-1]
 | 
			
		||||
            self.assertEqual(warning_message.category, DeprecationWarning)
 | 
			
		||||
            self.assertIn('4.0', warning_message.message.args[0])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AddressTranslatorTest(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_identity_translator(self):
 | 
			
		||||
        it = IdentityTranslator()
 | 
			
		||||
        addr = '127.0.0.1'
 | 
			
		||||
        IdentityTranslator()
 | 
			
		||||
 | 
			
		||||
    @patch('socket.getfqdn', return_value='localhost')
 | 
			
		||||
    def test_ec2_multi_region_translator(self, *_):
 | 
			
		||||
@@ -1245,3 +1261,181 @@ class AddressTranslatorTest(unittest.TestCase):
 | 
			
		||||
        translated = ec2t.translate(addr)
 | 
			
		||||
        self.assertIsNot(translated, addr)  # verifies that the resolver path is followed
 | 
			
		||||
        self.assertEqual(translated, addr)  # and that it resolves to the same address
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HostFilterPolicyInitTest(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.child_policy, self.predicate = (Mock(name='child_policy'),
 | 
			
		||||
                                             Mock(name='predicate'))
 | 
			
		||||
 | 
			
		||||
    def _check_init(self, hfp):
 | 
			
		||||
        self.assertIs(hfp._child_policy, self.child_policy)
 | 
			
		||||
        self.assertIsInstance(hfp._hosts_lock, LockType)
 | 
			
		||||
 | 
			
		||||
        # we can't use a simple assertIs because we wrap the function
 | 
			
		||||
        arg0, arg1 = Mock(name='arg0'), Mock(name='arg1')
 | 
			
		||||
        hfp.predicate(arg0)
 | 
			
		||||
        hfp.predicate(arg1)
 | 
			
		||||
        self.predicate.assert_has_calls([call(arg0), call(arg1)])
 | 
			
		||||
 | 
			
		||||
    def test_init_arg_order(self):
 | 
			
		||||
        self._check_init(HostFilterPolicy(self.child_policy, self.predicate))
 | 
			
		||||
 | 
			
		||||
    def test_init_kwargs(self):
 | 
			
		||||
        self._check_init(HostFilterPolicy(
 | 
			
		||||
            predicate=self.predicate, child_policy=self.child_policy
 | 
			
		||||
        ))
 | 
			
		||||
 | 
			
		||||
    def test_immutable_predicate(self):
 | 
			
		||||
        expected_message_regex = "can't set attribute"
 | 
			
		||||
        hfp = HostFilterPolicy(child_policy=Mock(name='child_policy'),
 | 
			
		||||
                               predicate=Mock(name='predicate'))
 | 
			
		||||
        with self.assertRaisesRegexp(AttributeError, expected_message_regex):
 | 
			
		||||
            hfp.predicate = object()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HostFilterPolicyDeferralTest(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.passthrough_hfp = HostFilterPolicy(
 | 
			
		||||
            child_policy=Mock(name='child_policy'),
 | 
			
		||||
            predicate=Mock(name='passthrough_predicate',
 | 
			
		||||
                           return_value=True)
 | 
			
		||||
        )
 | 
			
		||||
        self.filterall_hfp = HostFilterPolicy(
 | 
			
		||||
            child_policy=Mock(name='child_policy'),
 | 
			
		||||
            predicate=Mock(name='filterall_predicate',
 | 
			
		||||
                           return_value=False)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _check_host_triggered_method(self, policy, name):
 | 
			
		||||
        arg, kwarg = Mock(name='arg'), Mock(name='kwarg')
 | 
			
		||||
        expect_deferral = policy is self.passthrough_hfp
 | 
			
		||||
        method, child_policy_method = (getattr(policy, name),
 | 
			
		||||
                                       getattr(policy._child_policy, name))
 | 
			
		||||
 | 
			
		||||
        result = method(arg, kw=kwarg)
 | 
			
		||||
 | 
			
		||||
        if expect_deferral:
 | 
			
		||||
            # method calls the child policy's method...
 | 
			
		||||
            child_policy_method.assert_called_once_with(arg, kw=kwarg)
 | 
			
		||||
            # and returns its return value
 | 
			
		||||
            self.assertIs(result, child_policy_method.return_value)
 | 
			
		||||
        else:
 | 
			
		||||
            child_policy_method.assert_not_called()
 | 
			
		||||
 | 
			
		||||
    def test_defer_on_up_to_child_policy(self):
 | 
			
		||||
        self._check_host_triggered_method(self.passthrough_hfp, 'on_up')
 | 
			
		||||
 | 
			
		||||
    def test_defer_on_down_to_child_policy(self):
 | 
			
		||||
        self._check_host_triggered_method(self.passthrough_hfp, 'on_down')
 | 
			
		||||
 | 
			
		||||
    def test_defer_on_add_to_child_policy(self):
 | 
			
		||||
        self._check_host_triggered_method(self.passthrough_hfp, 'on_add')
 | 
			
		||||
 | 
			
		||||
    def test_defer_on_remove_to_child_policy(self):
 | 
			
		||||
        self._check_host_triggered_method(self.passthrough_hfp, 'on_remove')
 | 
			
		||||
 | 
			
		||||
    def test_filtered_host_on_up_doesnt_call_child_policy(self):
 | 
			
		||||
        self._check_host_triggered_method(self.filterall_hfp, 'on_up')
 | 
			
		||||
 | 
			
		||||
    def test_filtered_host_on_down_doesnt_call_child_policy(self):
 | 
			
		||||
        self._check_host_triggered_method(self.filterall_hfp, 'on_down')
 | 
			
		||||
 | 
			
		||||
    def test_filtered_host_on_add_doesnt_call_child_policy(self):
 | 
			
		||||
        self._check_host_triggered_method(self.filterall_hfp, 'on_add')
 | 
			
		||||
 | 
			
		||||
    def test_filtered_host_on_remove_doesnt_call_child_policy(self):
 | 
			
		||||
        self._check_host_triggered_method(self.filterall_hfp, 'on_remove')
 | 
			
		||||
 | 
			
		||||
    def _check_check_supported_deferral(self, policy):
 | 
			
		||||
        policy.check_supported()
 | 
			
		||||
        policy._child_policy.check_supported.assert_called_once()
 | 
			
		||||
 | 
			
		||||
    def test_check_supported_defers_to_child(self):
 | 
			
		||||
        self._check_check_supported_deferral(self.passthrough_hfp)
 | 
			
		||||
 | 
			
		||||
    def test_check_supported_defers_to_child_when_predicate_filtered(self):
 | 
			
		||||
        self._check_check_supported_deferral(self.filterall_hfp)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HostFilterPolicyDistanceTest(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.hfp = HostFilterPolicy(
 | 
			
		||||
            child_policy=Mock(name='child_policy', distance=Mock(name='distance')),
 | 
			
		||||
            predicate=lambda host: host.address == 'acceptme'
 | 
			
		||||
        )
 | 
			
		||||
        self.ignored_host = Host(inet_address='ignoreme', conviction_policy_factory=Mock())
 | 
			
		||||
        self.accepted_host = Host(inet_address='acceptme', conviction_policy_factory=Mock())
 | 
			
		||||
 | 
			
		||||
    def test_ignored_with_filter(self):
 | 
			
		||||
        self.assertEqual(self.hfp.distance(self.ignored_host),
 | 
			
		||||
                         HostDistance.IGNORED)
 | 
			
		||||
        self.assertNotEqual(self.hfp.distance(self.accepted_host),
 | 
			
		||||
                            HostDistance.IGNORED)
 | 
			
		||||
 | 
			
		||||
    def test_accepted_filter_defers_to_child_policy(self):
 | 
			
		||||
        self.hfp._child_policy.distance.side_effect = distances = Mock(), Mock()
 | 
			
		||||
 | 
			
		||||
        # getting the distance for an ignored host shouldn't affect subsequent results
 | 
			
		||||
        self.hfp.distance(self.ignored_host)
 | 
			
		||||
        # first call of _child_policy with count() side effect
 | 
			
		||||
        self.assertEqual(self.hfp.distance(self.accepted_host), distances[0])
 | 
			
		||||
        # second call of _child_policy with count() side effect
 | 
			
		||||
        self.assertEqual(self.hfp.distance(self.accepted_host), distances[1])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HostFilterPolicyPopulateTest(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_populate_deferred_to_child(self):
 | 
			
		||||
        hfp = HostFilterPolicy(
 | 
			
		||||
            child_policy=Mock(name='child_policy'),
 | 
			
		||||
            predicate=lambda host: True
 | 
			
		||||
        )
 | 
			
		||||
        mock_cluster, hosts = (Mock(name='cluster'),
 | 
			
		||||
                               ['host1', 'host2', 'host3'])
 | 
			
		||||
        hfp.populate(mock_cluster, hosts)
 | 
			
		||||
        hfp._child_policy.populate.assert_called_once_with(
 | 
			
		||||
            cluster=mock_cluster,
 | 
			
		||||
            hosts=hosts
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_child_not_populated_with_filtered_hosts(self):
 | 
			
		||||
        hfp = HostFilterPolicy(
 | 
			
		||||
            child_policy=Mock(name='child_policy'),
 | 
			
		||||
            predicate=lambda host: 'acceptme' in host
 | 
			
		||||
        )
 | 
			
		||||
        mock_cluster, hosts = (Mock(name='cluster'),
 | 
			
		||||
                               ['acceptme0', 'ignoreme0', 'ignoreme1', 'acceptme1'])
 | 
			
		||||
        hfp.populate(mock_cluster, hosts)
 | 
			
		||||
        hfp._child_policy.populate.assert_called_once()
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            hfp._child_policy.populate.call_args[1]['hosts'],
 | 
			
		||||
            ['acceptme0', 'acceptme1']
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HostFilterPolicyQueryPlanTest(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_query_plan_deferred_to_child(self):
 | 
			
		||||
        child_policy = Mock(
 | 
			
		||||
            name='child_policy',
 | 
			
		||||
            make_query_plan=Mock(
 | 
			
		||||
                return_value=[object(), object(), object()]
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        hfp = HostFilterPolicy(
 | 
			
		||||
            child_policy=child_policy,
 | 
			
		||||
            predicate=lambda host: True
 | 
			
		||||
        )
 | 
			
		||||
        working_keyspace, query = (Mock(name='working_keyspace'),
 | 
			
		||||
                                   Mock(name='query'))
 | 
			
		||||
        qp = list(hfp.make_query_plan(working_keyspace=working_keyspace,
 | 
			
		||||
                                      query=query))
 | 
			
		||||
        hfp._child_policy.make_query_plan.assert_called_once_with(
 | 
			
		||||
            working_keyspace=working_keyspace,
 | 
			
		||||
            query=query
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(qp, hfp._child_policy.make_query_plan.return_value)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user