Add test coverage, fixes for TokenAwarePolicy
This commit is contained in:
@@ -280,7 +280,7 @@ class TokenAwarePolicy(LoadBalancingPolicy):
|
||||
for host in child.make_query_plan(query):
|
||||
yield host
|
||||
else:
|
||||
replicas = self.metadata.get_replicas(routing_key)
|
||||
replicas = self._cluster_metadata.get_replicas(routing_key)
|
||||
for replica in replicas:
|
||||
if replica.monitor.is_up and \
|
||||
child.distance(replica) == HostDistance.LOCAL:
|
||||
@@ -289,7 +289,7 @@ class TokenAwarePolicy(LoadBalancingPolicy):
|
||||
for host in child.make_query_plan(query):
|
||||
# skip if we've already listed this host
|
||||
if host not in replicas or \
|
||||
child.distance(replica) == HostDistance.REMOTE:
|
||||
child.distance(host) == HostDistance.REMOTE:
|
||||
yield host
|
||||
|
||||
def on_up(self, *args, **kwargs):
|
||||
|
@@ -92,7 +92,8 @@ class Host(object):
|
||||
return self.address
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: %s>" % (self.__class__.__name__, self.address)
|
||||
dc = (" %s" % (self._datacenter,)) if self._datacenter else ""
|
||||
return "<%s: %s%s>" % (self.__class__.__name__, self.address, dc)
|
||||
|
||||
|
||||
class _ReconnectionHandler(object):
|
||||
|
@@ -48,12 +48,12 @@ class Query(object):
|
||||
def _get_routing_key(self):
|
||||
return self._routing_key
|
||||
|
||||
def _set_routing_key(self, key_components):
|
||||
if len(key_components) == 1:
|
||||
self._routing_key = key_components[0]
|
||||
else:
|
||||
def _set_routing_key(self, key):
|
||||
if isinstance(key, (list, tuple)):
|
||||
self._routing_key = "".join(struct.pack("HsB", len(component), component, 0)
|
||||
for component in key_components)
|
||||
for component in key)
|
||||
else:
|
||||
self._routing_key = key
|
||||
|
||||
def _del_routing_key(self):
|
||||
self._routing_key = None
|
||||
@@ -66,7 +66,9 @@ class Query(object):
|
||||
The :attr:`~.TableMetadata.partition_key` portion of the primary key,
|
||||
which can be used to determine which nodes are replicas for the query.
|
||||
|
||||
When setting this attribute, a list or tuple *must* be used.
|
||||
If the partition key is a composite, a list or tuple must be passed in.
|
||||
Each key component should be in its packed (binary) format, so all
|
||||
components should be strings.
|
||||
""")
|
||||
|
||||
class SimpleStatement(Query):
|
||||
|
@@ -1,12 +1,19 @@
|
||||
import unittest
|
||||
from itertools import islice, cycle
|
||||
from mock import Mock
|
||||
import struct
|
||||
from threading import Thread
|
||||
import unittest
|
||||
|
||||
from cassandra import ConsistencyLevel
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.metadata import Metadata
|
||||
from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
|
||||
SimpleConvictionPolicy, HostDistance,
|
||||
ExponentialReconnectionPolicy, RetryPolicy,
|
||||
WriteType, DowngradingConsistencyRetryPolicy)
|
||||
TokenAwarePolicy, SimpleConvictionPolicy,
|
||||
HostDistance, ExponentialReconnectionPolicy,
|
||||
RetryPolicy, WriteType,
|
||||
DowngradingConsistencyRetryPolicy)
|
||||
from cassandra.pool import Host
|
||||
from cassandra.query import Query
|
||||
|
||||
class TestRoundRobinPolicy(unittest.TestCase):
|
||||
|
||||
@@ -163,6 +170,72 @@ class TestDCAwareRoundRobinPolicy(unittest.TestCase):
|
||||
self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE)
|
||||
|
||||
|
||||
class TokenAwarePolicyTest(unittest.TestCase):
|
||||
|
||||
def test_wrap_round_robin(self):
|
||||
cluster = Mock(spec=Cluster)
|
||||
cluster.metadata = Mock(spec=Metadata)
|
||||
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
|
||||
|
||||
def get_replicas(packed_key):
|
||||
index = struct.unpack('>i', packed_key)[0]
|
||||
return list(islice(cycle(hosts), index, index + 2))
|
||||
|
||||
cluster.metadata.get_replicas.side_effect = get_replicas
|
||||
|
||||
policy = TokenAwarePolicy(RoundRobinPolicy())
|
||||
policy.populate(cluster, hosts)
|
||||
|
||||
for i in range(4):
|
||||
query = Query(routing_key=struct.pack('>i', i))
|
||||
qplan = list(policy.make_query_plan(query))
|
||||
|
||||
replicas = get_replicas(struct.pack('>i', i))
|
||||
other = set(h for h in hosts if h not in replicas)
|
||||
self.assertEquals(replicas, qplan[:2])
|
||||
self.assertEquals(other, set(qplan[2:]))
|
||||
|
||||
def test_wrap_dc_aware(self):
|
||||
cluster = Mock(spec=Cluster)
|
||||
cluster.metadata = Mock(spec=Metadata)
|
||||
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
|
||||
for h in hosts[:2]:
|
||||
h.set_location_info("dc1", "rack1")
|
||||
for h in hosts[2:]:
|
||||
h.set_location_info("dc2", "rack1")
|
||||
|
||||
def get_replicas(packed_key):
|
||||
index = struct.unpack('>i', packed_key)[0]
|
||||
# return one node from each DC
|
||||
if index % 2 == 0:
|
||||
return [hosts[0], hosts[2]]
|
||||
else:
|
||||
return [hosts[1], hosts[3]]
|
||||
|
||||
cluster.metadata.get_replicas.side_effect = get_replicas
|
||||
|
||||
policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1))
|
||||
policy.populate(cluster, hosts)
|
||||
|
||||
for i in range(4):
|
||||
query = Query(routing_key=struct.pack('>i', i))
|
||||
qplan = list(policy.make_query_plan(query))
|
||||
replicas = get_replicas(struct.pack('>i', i))
|
||||
|
||||
# first should be the only local replica
|
||||
self.assertIn(qplan[0], replicas)
|
||||
self.assertEquals(qplan[0].datacenter, "dc1")
|
||||
|
||||
# then the local non-replica
|
||||
self.assertNotIn(qplan[1], replicas)
|
||||
self.assertEquals(qplan[1].datacenter, "dc1")
|
||||
|
||||
# then one of the remotes (used_hosts_per_remote_dc is 1, so we
|
||||
# shouldn't see two remotes)
|
||||
self.assertEquals(qplan[2].datacenter, "dc2")
|
||||
self.assertEquals(3, len(qplan))
|
||||
|
||||
|
||||
class ExponentialReconnectionPolicyTest(unittest.TestCase):
|
||||
|
||||
def test_bad_vals(self):
|
||||
|
Reference in New Issue
Block a user