From f04aeb35c3cc924c99489e3cf11f883a2eab60cd Mon Sep 17 00:00:00 2001 From: Tyler Hobbs Date: Wed, 3 Jul 2013 14:16:54 -0500 Subject: [PATCH] Add test coverage, fixes for TokenAwarePolicy --- cassandra/policies.py | 4 +- cassandra/pool.py | 3 +- cassandra/query.py | 14 ++++--- tests/unit/test_policies.py | 81 +++++++++++++++++++++++++++++++++++-- 4 files changed, 89 insertions(+), 13 deletions(-) diff --git a/cassandra/policies.py b/cassandra/policies.py index 8bb3280f..247b0315 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -280,7 +280,7 @@ class TokenAwarePolicy(LoadBalancingPolicy): for host in child.make_query_plan(query): yield host else: - replicas = self.metadata.get_replicas(routing_key) + replicas = self._cluster_metadata.get_replicas(routing_key) for replica in replicas: if replica.monitor.is_up and \ child.distance(replica) == HostDistance.LOCAL: @@ -289,7 +289,7 @@ class TokenAwarePolicy(LoadBalancingPolicy): for host in child.make_query_plan(query): # skip if we've already listed this host if host not in replicas or \ - child.distance(replica) == HostDistance.REMOTE: + child.distance(host) == HostDistance.REMOTE: yield host def on_up(self, *args, **kwargs): diff --git a/cassandra/pool.py b/cassandra/pool.py index dc8cbe86..e8026088 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -92,7 +92,8 @@ class Host(object): return self.address def __repr__(self): - return "<%s: %s>" % (self.__class__.__name__, self.address) + dc = (" %s" % (self._datacenter,)) if self._datacenter else "" + return "<%s: %s%s>" % (self.__class__.__name__, self.address, dc) class _ReconnectionHandler(object): diff --git a/cassandra/query.py b/cassandra/query.py index bd912f5b..aad68739 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -48,12 +48,12 @@ class Query(object): def _get_routing_key(self): return self._routing_key - def _set_routing_key(self, key_components): - if len(key_components) == 1: - self._routing_key = key_components[0] - else: + def _set_routing_key(self, key): + if isinstance(key, (list, tuple)): self._routing_key = "".join(struct.pack("HsB", len(component), component, 0) - for component in key_components) + for component in key) + else: + self._routing_key = key def _del_routing_key(self): self._routing_key = None @@ -66,7 +66,9 @@ class Query(object): The :attr:`~.TableMetadata.partition_key` portion of the primary key, which can be used to determine which nodes are replicas for the query. - When setting this attribute, a list or tuple *must* be used. + If the partition key is a composite, a list or tuple must be passed in. + Each key component should be in its packed (binary) format, so all + components should be strings. """) class SimpleStatement(Query): diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index a87f881f..46b16b62 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -1,12 +1,19 @@ -import unittest +from itertools import islice, cycle +from mock import Mock +import struct from threading import Thread +import unittest from cassandra import ConsistencyLevel +from cassandra.cluster import Cluster +from cassandra.metadata import Metadata from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy, - SimpleConvictionPolicy, HostDistance, - ExponentialReconnectionPolicy, RetryPolicy, - WriteType, DowngradingConsistencyRetryPolicy) + TokenAwarePolicy, SimpleConvictionPolicy, + HostDistance, ExponentialReconnectionPolicy, + RetryPolicy, WriteType, + DowngradingConsistencyRetryPolicy) from cassandra.pool import Host +from cassandra.query import Query class TestRoundRobinPolicy(unittest.TestCase): @@ -163,6 +170,72 @@ class TestDCAwareRoundRobinPolicy(unittest.TestCase): self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE) +class TokenAwarePolicyTest(unittest.TestCase): + + def test_wrap_round_robin(self): + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)] + + def get_replicas(packed_key): + index = struct.unpack('>i', packed_key)[0] + return list(islice(cycle(hosts), index, index + 2)) + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy(RoundRobinPolicy()) + policy.populate(cluster, hosts) + + for i in range(4): + query = Query(routing_key=struct.pack('>i', i)) + qplan = list(policy.make_query_plan(query)) + + replicas = get_replicas(struct.pack('>i', i)) + other = set(h for h in hosts if h not in replicas) + self.assertEquals(replicas, qplan[:2]) + self.assertEquals(other, set(qplan[2:])) + + def test_wrap_dc_aware(self): + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)] + for h in hosts[:2]: + h.set_location_info("dc1", "rack1") + for h in hosts[2:]: + h.set_location_info("dc2", "rack1") + + def get_replicas(packed_key): + index = struct.unpack('>i', packed_key)[0] + # return one node from each DC + if index % 2 == 0: + return [hosts[0], hosts[2]] + else: + return [hosts[1], hosts[3]] + + cluster.metadata.get_replicas.side_effect = get_replicas + + policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1)) + policy.populate(cluster, hosts) + + for i in range(4): + query = Query(routing_key=struct.pack('>i', i)) + qplan = list(policy.make_query_plan(query)) + replicas = get_replicas(struct.pack('>i', i)) + + # first should be the only local replica + self.assertIn(qplan[0], replicas) + self.assertEquals(qplan[0].datacenter, "dc1") + + # then the local non-replica + self.assertNotIn(qplan[1], replicas) + self.assertEquals(qplan[1].datacenter, "dc1") + + # then one of the remotes (used_hosts_per_remote_dc is 1, so we + # shouldn't see two remotes) + self.assertEquals(qplan[2].datacenter, "dc2") + self.assertEquals(3, len(qplan)) + + class ExponentialReconnectionPolicyTest(unittest.TestCase): def test_bad_vals(self):