More complete token-aware routing

The TokenAwarePolicy now takes into account what keyspace the
query is being run against so that replication settings can
be taken into account.  This means that all replicas for
the queried row can be used instead of just the first replica.
This commit is contained in:
Tyler Hobbs
2013-09-18 17:02:43 -05:00
parent f3d75e75cf
commit 54699cbc04
6 changed files with 105 additions and 49 deletions

View File

@@ -930,10 +930,10 @@ class ControlConnection(object):
except ConnectionException as exc:
errors[host.address] = exc
host.monitor.signal_connection_failure(exc)
log.warn("[control connection] Error connecting to %s: %s", host, exc)
log.warn("[control connection] Error connecting to %s:", host, exc_info=True)
except Exception as exc:
errors[host.address] = exc
log.warn("[control connection] Error connecting to %s: %s", host, exc)
log.warn("[control connection] Error connecting to %s:", host, exc_info=True)
raise NoHostAvailable("Unable to connect to any servers", errors)
@@ -1354,7 +1354,8 @@ class ResponseFuture(object):
# convert the list/generator/etc to an iterator so that subsequent
# calls to send_request (which retries may do) will resume where
# they last left off
self.query_plan = iter(session._load_balancer.make_query_plan(query))
self.query_plan = iter(session._load_balancer.make_query_plan(
session.keyspace, query))
self._event = Event()
self._errors = {}

View File

@@ -263,18 +263,20 @@ class Metadata(object):
self.token_map = None
return
tokens_to_hosts = defaultdict(set)
token_to_primary_replica = {}
ring = []
for host, token_strings in token_map.iteritems():
for token_string in token_strings:
token = token_class(token_string)
ring.append(token)
tokens_to_hosts[token].add(host)
token_to_primary_replica[token] = host
ring = sorted(ring)
self.token_map = TokenMap(token_class, tokens_to_hosts, ring)
all_tokens = sorted(ring)
self.token_map = TokenMap(
token_class, token_to_primary_replica, all_tokens,
self.keyspaces.values())
def get_replicas(self, key):
def get_replicas(self, keyspace, key):
"""
Returns a list of :class:`.Host` instances that are replicas for a given
partition key.
@@ -283,7 +285,7 @@ class Metadata(object):
if not t:
return []
try:
return t.get_replicas(t.token_class.from_key(key))
return t.get_replicas(keyspace, t.token_class.from_key(key))
except NoMurmur3:
return []
@@ -344,14 +346,15 @@ class SimpleStrategy(ReplicationStrategy):
replication_factor = None
def __init__(self, replication_factor):
self.replication_factor = replication_factor
self.replication_factor = int(replication_factor)
def make_token_replica_map(self, token_to_primary_replica, ring):
replica_map = {}
for i in range(len(ring)):
j, hosts = 0, set()
while len(hosts) < self.replication_factor and j < len(ring):
hosts.add(token_to_primary_replica[ring[(i + j) % len(ring)]])
token = ring[(i + j) % len(ring)]
hosts.add(token_to_primary_replica[token])
j += 1
replica_map[ring[i]] = hosts
@@ -400,6 +403,7 @@ class NetworkTopologyStrategy(ReplicationStrategy):
ret += ", '%s': '%d'" % (dc, repl_factor)
return ret + "}"
class KeyspaceMetadata(object):
"""
A representation of the schema for a single keyspace.
@@ -708,9 +712,10 @@ class TokenMap(object):
A subclass of :class:`.Token`, depending on what partitioner the cluster uses.
"""
tokens_to_hosts = None
tokens_to_hosts_by_ks = None
"""
A map of :class:`.Token` objects to :class:`.Host` objects.
A map of keyspace names to a nested map of :class:`.Token` objects to
sets of :class:`.Host` objects.
"""
ring = None
@@ -718,25 +723,39 @@ class TokenMap(object):
An ordered list of :class:`.Token` instances in the ring.
"""
def __init__(self, token_class, tokens_to_hosts, ring):
def __init__(self, token_class, token_to_primary_replica, all_tokens, keyspaces):
self.token_class = token_class
self.tokens_to_hosts = tokens_to_hosts
self.ring = ring
self.ring = all_tokens
def get_replicas(self, token):
self.tokens_to_hosts_by_ks = {}
for ks_metadata in keyspaces:
strategy = ks_metadata.replication_strategy
if strategy is None:
token_to_hosts = defaultdict(set)
for token, host in token_to_primary_replica.items():
token_to_hosts[token].add(host)
self.tokens_to_hosts_by_ks[ks_metadata.name] = token_to_hosts
else:
self.tokens_to_hosts_by_ks[ks_metadata.name] = \
strategy.make_token_replica_map(
token_to_primary_replica, all_tokens)
def get_replicas(self, keyspace, token):
"""
Get :class:`.Host` instances representing all of the replica nodes
for a given :class:`.Token`.
Get a set of :class:`.Host` instances representing all of the
replica nodes for a given :class:`.Token`.
"""
# TODO depending on keyspace and replication strategy options,
# return full set of replicas
tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None)
if tokens_to_hosts is None:
return set()
point = bisect_left(self.ring, token)
if point == 0 and token != self.ring[0]:
return self.tokens_to_hosts[self.ring[-1]]
return tokens_to_hosts[self.ring[-1]]
elif point == len(self.ring):
return self.tokens_to_hosts[self.ring[0]]
return tokens_to_hosts[self.ring[0]]
else:
return self.tokens_to_hosts[self.ring[point]]
return tokens_to_hosts[self.ring[point]]
class Token(object):
@@ -785,6 +804,10 @@ class Murmur3Token(Token):
""" `token` should be an int or string representing the token """
self.value = int(token)
def __repr__(self):
return "<Murmur3Token: %r" % (self.value,)
__str__ = __repr__
class MD5Token(Token):
"""
@@ -799,6 +822,10 @@ class MD5Token(Token):
""" `token` should be an int or string representing the token """
self.value = int(token)
def __repr__(self):
return "<MD5Token: %d" % (self.value,)
__str__ = __repr__
class BytesToken(Token):
"""
@@ -812,3 +839,7 @@ class BytesToken(Token):
"Tokens for ByteOrderedPartitioner should be strings (got %s)"
% (type(token_string),))
self.value = token_string
def __repr__(self):
return "<BytesToken: %r" % (self.value,)
__str__ = __repr__

View File

@@ -71,7 +71,7 @@ class LoadBalancingPolicy(object):
"""
raise NotImplementedError()
def make_query_plan(self, query=None):
def make_query_plan(self, working_keyspace=None, query=None):
"""
Given a :class:`~.query.Query` instance, return a iterable
of :class:`.Host` instances which should be queried in that
@@ -80,6 +80,10 @@ class LoadBalancingPolicy(object):
Note that the `query` argument may be :const:`None` when preparing
statements.
`working_keyspace` should be the string name of the current keyspace,
as set through :meth:`.Session.set_keyspace()` or with a ``USE``
statement.
"""
raise NotImplementedError()
@@ -133,7 +137,7 @@ class RoundRobinPolicy(LoadBalancingPolicy):
def distance(self, host):
return HostDistance.LOCAL
def make_query_plan(self, query=None):
def make_query_plan(self, working_keyspace=None, query=None):
# not thread-safe, but we don't care much about lost increments
# for the purposes of load balancing
pos = self._position
@@ -219,7 +223,7 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
else:
return HostDistance.IGNORED
def make_query_plan(self, query=None):
def make_query_plan(self, working_keyspace=None, query=None):
# not thread-safe, but we don't care much about lost increments
# for the purposes of load balancing
pos = self._position
@@ -278,24 +282,29 @@ class TokenAwarePolicy(LoadBalancingPolicy):
def distance(self, *args, **kwargs):
return self.child_policy.distance(*args, **kwargs)
def make_query_plan(self, query=None):
def make_query_plan(self, working_keyspace=None, query=None):
if query and query.keyspace:
keyspace = query.keyspace
else:
keyspace = working_keyspace
child = self.child_policy
if query is None:
for host in child.make_query_plan(query):
for host in child.make_query_plan(keyspace, query):
yield host
else:
routing_key = query.routing_key
if routing_key is None:
for host in child.make_query_plan(query):
for host in child.make_query_plan(keyspace, query):
yield host
else:
replicas = self._cluster_metadata.get_replicas(routing_key)
replicas = self._cluster_metadata.get_replicas(keyspace, routing_key)
for replica in replicas:
if replica.monitor.is_up and \
child.distance(replica) == HostDistance.LOCAL:
yield replica
for host in child.make_query_plan(query):
for host in child.make_query_plan(keyspace, query):
# skip if we've already listed this host
if host not in replicas or \
child.distance(host) == HostDistance.REMOTE:

View File

@@ -38,9 +38,16 @@ class Query(object):
The :class:`.ConsistencyLevel` to be used for this operation. Defaults
to :attr:`.ConsistencyLevel.ONE`.
"""
keyspace = None
"""
The string name of the keyspace this query acts on.
"""
_routing_key = None
def __init__(self, retry_policy=None, tracing_enabled=False, consistency_level=ConsistencyLevel.ONE, routing_key=None):
def __init__(self, retry_policy=None, tracing_enabled=False,
consistency_level=ConsistencyLevel.ONE, routing_key=None):
self.retry_policy = retry_policy
self.tracing_enabled = tracing_enabled
self.consistency_level = consistency_level
@@ -245,6 +252,14 @@ class BoundStatement(Query):
return self._routing_key
@property
def keyspace(self):
meta = self.prepared_statement.column_metadata
if meta:
return meta[0][0]
else:
return None
class ValueSequence(object):
"""

View File

@@ -1,10 +1,10 @@
try:
import unittest2 as unittest
except ImportError:
import unittest
import unittest # noqa
from cassandra.cluster import Cluster
from cassandra.metadata import TableMetadata, Token, MD5Token, TokenMap
from cassandra.metadata import KeyspaceMetadata, TableMetadata, Token, MD5Token, TokenMap
from cassandra.policies import SimpleConvictionPolicy
from cassandra.pool import Host
@@ -120,8 +120,8 @@ class SchemaMetadataTest(unittest.TestCase):
self.assertEqual(ksmeta.name, self.ksname)
self.assertTrue(ksmeta.durable_writes)
self.assertTrue(ksmeta.replication['class'].endswith('SimpleStrategy'))
self.assertEqual(ksmeta.replication['replication_factor'], '1')
self.assertEqual(ksmeta.replication_strategy.name, 'SimpleStrategy')
self.assertEqual(ksmeta.replication_strategy.replication_factor, 1)
self.assertTrue(self.cfname in ksmeta.tables)
tablemeta = ksmeta.tables[self.cfname]
@@ -286,27 +286,27 @@ class TokenMetadataTest(unittest.TestCase):
tmap = cluster.metadata.token_map
self.assertTrue(issubclass(tmap.token_class, Token))
self.assertEqual(expected_node_count, len(tmap.ring))
self.assertEqual(expected_node_count, len(tmap.tokens_to_hosts))
cluster.shutdown()
def test_getting_replicas(self):
tokens = [MD5Token(str(i)) for i in range(0, (2 ** 127 - 1), 2 ** 125)]
hosts = [Host("ip%d" % i, SimpleConvictionPolicy) for i in range(len(tokens))]
tokens_to_hosts = dict((t, set([h])) for t, h in zip(tokens, hosts))
token_map = TokenMap(MD5Token, tokens_to_hosts, tokens)
token_to_primary_replica = dict(zip(tokens, hosts))
keyspace = KeyspaceMetadata("ks", True, "SimpleStrategy", {"replication_factor": "1"})
token_map = TokenMap(MD5Token, token_to_primary_replica, tokens, [keyspace])
# tokens match node tokens exactly
for token, expected_host in zip(tokens, hosts):
replicas = token_map.get_replicas(token)
replicas = token_map.get_replicas("ks", token)
self.assertEqual(replicas, set([expected_host]))
# shift the tokens back by one
for token, expected_host in zip(tokens[1:], hosts[1:]):
replicas = token_map.get_replicas(MD5Token(str(token.value - 1)))
replicas = token_map.get_replicas("ks", MD5Token(str(token.value - 1)))
self.assertEqual(replicas, set([expected_host]))
# shift the tokens forward by one
for i, token in enumerate(tokens):
replicas = token_map.get_replicas(MD5Token(str(token.value + 1)))
replicas = token_map.get_replicas("ks", MD5Token(str(token.value + 1)))
expected_host = hosts[(i + 1) % len(hosts)]
self.assertEqual(replicas, set([expected_host]))

View File

@@ -181,7 +181,7 @@ class TokenAwarePolicyTest(unittest.TestCase):
cluster.metadata = Mock(spec=Metadata)
hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]
def get_replicas(packed_key):
def get_replicas(keyspace, packed_key):
index = struct.unpack('>i', packed_key)[0]
return list(islice(cycle(hosts), index, index + 2))
@@ -192,9 +192,9 @@ class TokenAwarePolicyTest(unittest.TestCase):
for i in range(4):
query = Query(routing_key=struct.pack('>i', i))
qplan = list(policy.make_query_plan(query))
qplan = list(policy.make_query_plan(None, query))
replicas = get_replicas(struct.pack('>i', i))
replicas = get_replicas(None, struct.pack('>i', i))
other = set(h for h in hosts if h not in replicas)
self.assertEquals(replicas, qplan[:2])
self.assertEquals(other, set(qplan[2:]))
@@ -208,7 +208,7 @@ class TokenAwarePolicyTest(unittest.TestCase):
for h in hosts[2:]:
h.set_location_info("dc2", "rack1")
def get_replicas(packed_key):
def get_replicas(keyspace, packed_key):
index = struct.unpack('>i', packed_key)[0]
# return one node from each DC
if index % 2 == 0:
@@ -223,8 +223,8 @@ class TokenAwarePolicyTest(unittest.TestCase):
for i in range(4):
query = Query(routing_key=struct.pack('>i', i))
qplan = list(policy.make_query_plan(query))
replicas = get_replicas(struct.pack('>i', i))
qplan = list(policy.make_query_plan(None, query))
replicas = get_replicas(None, struct.pack('>i', i))
# first should be the only local replica
self.assertIn(qplan[0], replicas)