Add test coverage, fixes for TokenAwarePolicy

This commit is contained in:
Tyler Hobbs
2013-07-03 14:16:54 -05:00
parent 08a2370157
commit f04aeb35c3
4 changed files with 89 additions and 13 deletions

View File

@@ -280,7 +280,7 @@ class TokenAwarePolicy(LoadBalancingPolicy):
for host in child.make_query_plan(query):
yield host
else:
replicas = self.metadata.get_replicas(routing_key)
replicas = self._cluster_metadata.get_replicas(routing_key)
for replica in replicas:
if replica.monitor.is_up and \
child.distance(replica) == HostDistance.LOCAL:
@@ -289,7 +289,7 @@ class TokenAwarePolicy(LoadBalancingPolicy):
for host in child.make_query_plan(query):
# skip if we've already listed this host
if host not in replicas or \
child.distance(replica) == HostDistance.REMOTE:
child.distance(host) == HostDistance.REMOTE:
yield host
def on_up(self, *args, **kwargs):

View File

@@ -92,7 +92,8 @@ class Host(object):
return self.address
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self.address)
dc = (" %s" % (self._datacenter,)) if self._datacenter else ""
return "<%s: %s%s>" % (self.__class__.__name__, self.address, dc)
class _ReconnectionHandler(object):

View File

@@ -48,12 +48,12 @@ class Query(object):
def _get_routing_key(self):
return self._routing_key
def _set_routing_key(self, key_components):
if len(key_components) == 1:
self._routing_key = key_components[0]
else:
def _set_routing_key(self, key):
if isinstance(key, (list, tuple)):
self._routing_key = "".join(struct.pack("HsB", len(component), component, 0)
for component in key_components)
for component in key)
else:
self._routing_key = key
def _del_routing_key(self):
self._routing_key = None
@@ -66,7 +66,9 @@ class Query(object):
The :attr:`~.TableMetadata.partition_key` portion of the primary key,
which can be used to determine which nodes are replicas for the query.
When setting this attribute, a list or tuple *must* be used.
If the partition key is a composite, a list or tuple must be passed in.
Each key component should be in its packed (binary) format, so all
components should be strings.
""")
class SimpleStatement(Query):

View File

@@ -1,12 +1,19 @@
import unittest
from itertools import islice, cycle
from mock import Mock
import struct
from threading import Thread
import unittest
from cassandra import ConsistencyLevel
from cassandra.cluster import Cluster
from cassandra.metadata import Metadata
from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
SimpleConvictionPolicy, HostDistance,
ExponentialReconnectionPolicy, RetryPolicy,
WriteType, DowngradingConsistencyRetryPolicy)
TokenAwarePolicy, SimpleConvictionPolicy,
HostDistance, ExponentialReconnectionPolicy,
RetryPolicy, WriteType,
DowngradingConsistencyRetryPolicy)
from cassandra.pool import Host
from cassandra.query import Query
class TestRoundRobinPolicy(unittest.TestCase):
@@ -163,6 +170,72 @@ class TestDCAwareRoundRobinPolicy(unittest.TestCase):
self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE)
class TokenAwarePolicyTest(unittest.TestCase):
def test_wrap_round_robin(self):
cluster = Mock(spec=Cluster)
cluster.metadata = Mock(spec=Metadata)
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
def get_replicas(packed_key):
index = struct.unpack('>i', packed_key)[0]
return list(islice(cycle(hosts), index, index + 2))
cluster.metadata.get_replicas.side_effect = get_replicas
policy = TokenAwarePolicy(RoundRobinPolicy())
policy.populate(cluster, hosts)
for i in range(4):
query = Query(routing_key=struct.pack('>i', i))
qplan = list(policy.make_query_plan(query))
replicas = get_replicas(struct.pack('>i', i))
other = set(h for h in hosts if h not in replicas)
self.assertEquals(replicas, qplan[:2])
self.assertEquals(other, set(qplan[2:]))
def test_wrap_dc_aware(self):
cluster = Mock(spec=Cluster)
cluster.metadata = Mock(spec=Metadata)
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
for h in hosts[:2]:
h.set_location_info("dc1", "rack1")
for h in hosts[2:]:
h.set_location_info("dc2", "rack1")
def get_replicas(packed_key):
index = struct.unpack('>i', packed_key)[0]
# return one node from each DC
if index % 2 == 0:
return [hosts[0], hosts[2]]
else:
return [hosts[1], hosts[3]]
cluster.metadata.get_replicas.side_effect = get_replicas
policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1))
policy.populate(cluster, hosts)
for i in range(4):
query = Query(routing_key=struct.pack('>i', i))
qplan = list(policy.make_query_plan(query))
replicas = get_replicas(struct.pack('>i', i))
# first should be the only local replica
self.assertIn(qplan[0], replicas)
self.assertEquals(qplan[0].datacenter, "dc1")
# then the local non-replica
self.assertNotIn(qplan[1], replicas)
self.assertEquals(qplan[1].datacenter, "dc1")
# then one of the remotes (used_hosts_per_remote_dc is 1, so we
# shouldn't see two remotes)
self.assertEquals(qplan[2].datacenter, "dc2")
self.assertEquals(3, len(qplan))
class ExponentialReconnectionPolicyTest(unittest.TestCase):
def test_bad_vals(self):