Add test coverage, fixes for TokenAwarePolicy
This commit is contained in:
@@ -280,7 +280,7 @@ class TokenAwarePolicy(LoadBalancingPolicy):
|
|||||||
for host in child.make_query_plan(query):
|
for host in child.make_query_plan(query):
|
||||||
yield host
|
yield host
|
||||||
else:
|
else:
|
||||||
replicas = self.metadata.get_replicas(routing_key)
|
replicas = self._cluster_metadata.get_replicas(routing_key)
|
||||||
for replica in replicas:
|
for replica in replicas:
|
||||||
if replica.monitor.is_up and \
|
if replica.monitor.is_up and \
|
||||||
child.distance(replica) == HostDistance.LOCAL:
|
child.distance(replica) == HostDistance.LOCAL:
|
||||||
@@ -289,7 +289,7 @@ class TokenAwarePolicy(LoadBalancingPolicy):
|
|||||||
for host in child.make_query_plan(query):
|
for host in child.make_query_plan(query):
|
||||||
# skip if we've already listed this host
|
# skip if we've already listed this host
|
||||||
if host not in replicas or \
|
if host not in replicas or \
|
||||||
child.distance(replica) == HostDistance.REMOTE:
|
child.distance(host) == HostDistance.REMOTE:
|
||||||
yield host
|
yield host
|
||||||
|
|
||||||
def on_up(self, *args, **kwargs):
|
def on_up(self, *args, **kwargs):
|
||||||
|
@@ -92,7 +92,8 @@ class Host(object):
|
|||||||
return self.address
|
return self.address
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<%s: %s>" % (self.__class__.__name__, self.address)
|
dc = (" %s" % (self._datacenter,)) if self._datacenter else ""
|
||||||
|
return "<%s: %s%s>" % (self.__class__.__name__, self.address, dc)
|
||||||
|
|
||||||
|
|
||||||
class _ReconnectionHandler(object):
|
class _ReconnectionHandler(object):
|
||||||
|
@@ -48,12 +48,12 @@ class Query(object):
|
|||||||
def _get_routing_key(self):
|
def _get_routing_key(self):
|
||||||
return self._routing_key
|
return self._routing_key
|
||||||
|
|
||||||
def _set_routing_key(self, key_components):
|
def _set_routing_key(self, key):
|
||||||
if len(key_components) == 1:
|
if isinstance(key, (list, tuple)):
|
||||||
self._routing_key = key_components[0]
|
|
||||||
else:
|
|
||||||
self._routing_key = "".join(struct.pack("HsB", len(component), component, 0)
|
self._routing_key = "".join(struct.pack("HsB", len(component), component, 0)
|
||||||
for component in key_components)
|
for component in key)
|
||||||
|
else:
|
||||||
|
self._routing_key = key
|
||||||
|
|
||||||
def _del_routing_key(self):
|
def _del_routing_key(self):
|
||||||
self._routing_key = None
|
self._routing_key = None
|
||||||
@@ -66,7 +66,9 @@ class Query(object):
|
|||||||
The :attr:`~.TableMetadata.partition_key` portion of the primary key,
|
The :attr:`~.TableMetadata.partition_key` portion of the primary key,
|
||||||
which can be used to determine which nodes are replicas for the query.
|
which can be used to determine which nodes are replicas for the query.
|
||||||
|
|
||||||
When setting this attribute, a list or tuple *must* be used.
|
If the partition key is a composite, a list or tuple must be passed in.
|
||||||
|
Each key component should be in its packed (binary) format, so all
|
||||||
|
components should be strings.
|
||||||
""")
|
""")
|
||||||
|
|
||||||
class SimpleStatement(Query):
|
class SimpleStatement(Query):
|
||||||
|
@@ -1,12 +1,19 @@
|
|||||||
import unittest
|
from itertools import islice, cycle
|
||||||
|
from mock import Mock
|
||||||
|
import struct
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
import unittest
|
||||||
|
|
||||||
from cassandra import ConsistencyLevel
|
from cassandra import ConsistencyLevel
|
||||||
|
from cassandra.cluster import Cluster
|
||||||
|
from cassandra.metadata import Metadata
|
||||||
from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
|
from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
|
||||||
SimpleConvictionPolicy, HostDistance,
|
TokenAwarePolicy, SimpleConvictionPolicy,
|
||||||
ExponentialReconnectionPolicy, RetryPolicy,
|
HostDistance, ExponentialReconnectionPolicy,
|
||||||
WriteType, DowngradingConsistencyRetryPolicy)
|
RetryPolicy, WriteType,
|
||||||
|
DowngradingConsistencyRetryPolicy)
|
||||||
from cassandra.pool import Host
|
from cassandra.pool import Host
|
||||||
|
from cassandra.query import Query
|
||||||
|
|
||||||
class TestRoundRobinPolicy(unittest.TestCase):
|
class TestRoundRobinPolicy(unittest.TestCase):
|
||||||
|
|
||||||
@@ -163,6 +170,72 @@ class TestDCAwareRoundRobinPolicy(unittest.TestCase):
|
|||||||
self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE)
|
self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenAwarePolicyTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_wrap_round_robin(self):
|
||||||
|
cluster = Mock(spec=Cluster)
|
||||||
|
cluster.metadata = Mock(spec=Metadata)
|
||||||
|
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
|
||||||
|
|
||||||
|
def get_replicas(packed_key):
|
||||||
|
index = struct.unpack('>i', packed_key)[0]
|
||||||
|
return list(islice(cycle(hosts), index, index + 2))
|
||||||
|
|
||||||
|
cluster.metadata.get_replicas.side_effect = get_replicas
|
||||||
|
|
||||||
|
policy = TokenAwarePolicy(RoundRobinPolicy())
|
||||||
|
policy.populate(cluster, hosts)
|
||||||
|
|
||||||
|
for i in range(4):
|
||||||
|
query = Query(routing_key=struct.pack('>i', i))
|
||||||
|
qplan = list(policy.make_query_plan(query))
|
||||||
|
|
||||||
|
replicas = get_replicas(struct.pack('>i', i))
|
||||||
|
other = set(h for h in hosts if h not in replicas)
|
||||||
|
self.assertEquals(replicas, qplan[:2])
|
||||||
|
self.assertEquals(other, set(qplan[2:]))
|
||||||
|
|
||||||
|
def test_wrap_dc_aware(self):
|
||||||
|
cluster = Mock(spec=Cluster)
|
||||||
|
cluster.metadata = Mock(spec=Metadata)
|
||||||
|
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
|
||||||
|
for h in hosts[:2]:
|
||||||
|
h.set_location_info("dc1", "rack1")
|
||||||
|
for h in hosts[2:]:
|
||||||
|
h.set_location_info("dc2", "rack1")
|
||||||
|
|
||||||
|
def get_replicas(packed_key):
|
||||||
|
index = struct.unpack('>i', packed_key)[0]
|
||||||
|
# return one node from each DC
|
||||||
|
if index % 2 == 0:
|
||||||
|
return [hosts[0], hosts[2]]
|
||||||
|
else:
|
||||||
|
return [hosts[1], hosts[3]]
|
||||||
|
|
||||||
|
cluster.metadata.get_replicas.side_effect = get_replicas
|
||||||
|
|
||||||
|
policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1))
|
||||||
|
policy.populate(cluster, hosts)
|
||||||
|
|
||||||
|
for i in range(4):
|
||||||
|
query = Query(routing_key=struct.pack('>i', i))
|
||||||
|
qplan = list(policy.make_query_plan(query))
|
||||||
|
replicas = get_replicas(struct.pack('>i', i))
|
||||||
|
|
||||||
|
# first should be the only local replica
|
||||||
|
self.assertIn(qplan[0], replicas)
|
||||||
|
self.assertEquals(qplan[0].datacenter, "dc1")
|
||||||
|
|
||||||
|
# then the local non-replica
|
||||||
|
self.assertNotIn(qplan[1], replicas)
|
||||||
|
self.assertEquals(qplan[1].datacenter, "dc1")
|
||||||
|
|
||||||
|
# then one of the remotes (used_hosts_per_remote_dc is 1, so we
|
||||||
|
# shouldn't see two remotes)
|
||||||
|
self.assertEquals(qplan[2].datacenter, "dc2")
|
||||||
|
self.assertEquals(3, len(qplan))
|
||||||
|
|
||||||
|
|
||||||
class ExponentialReconnectionPolicyTest(unittest.TestCase):
|
class ExponentialReconnectionPolicyTest(unittest.TestCase):
|
||||||
|
|
||||||
def test_bad_vals(self):
|
def test_bad_vals(self):
|
||||||
|
Reference in New Issue
Block a user