More complete token-aware routing

The TokenAwarePolicy now takes into account which keyspace a
query is being run against, so that the keyspace's replication
settings can be considered when routing.  This means that all
replicas for the queried row can be used instead of just the
first replica.
Tyler Hobbs
2013-09-18 17:02:43 -05:00
parent f3d75e75cf
commit 54699cbc04
6 changed files with 105 additions and 49 deletions
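
Not part of the commit, but for orientation, a rough sketch of how token-aware routing is meant to be used once the keyspace is threaded through. The configuration hook (assigning load_balancing_policy on the Cluster before connecting) and the SimpleStatement constructor arguments are assumptions about this era of the driver and may differ:

import struct

from cassandra.cluster import Cluster
from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy
from cassandra.query import SimpleStatement

cluster = Cluster(['127.0.0.1'])
# assumption: the policy is attached as an attribute before connect()
cluster.load_balancing_policy = TokenAwarePolicy(RoundRobinPolicy())

session = cluster.connect()
session.set_keyspace('my_ks')  # working keyspace, used when a query carries no keyspace of its own

# a statement with a routing key can now be sent straight to one of its replicas;
# the packed key below is purely illustrative
stmt = SimpleStatement("SELECT * FROM users WHERE user_id = 42",
                       routing_key=struct.pack('>i', 42))
rows = session.execute(stmt)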

View File

@@ -930,10 +930,10 @@ class ControlConnection(object):
             except ConnectionException as exc:
                 errors[host.address] = exc
                 host.monitor.signal_connection_failure(exc)
-                log.warn("[control connection] Error connecting to %s: %s", host, exc)
+                log.warn("[control connection] Error connecting to %s:", host, exc_info=True)
             except Exception as exc:
                 errors[host.address] = exc
-                log.warn("[control connection] Error connecting to %s: %s", host, exc)
+                log.warn("[control connection] Error connecting to %s:", host, exc_info=True)

         raise NoHostAvailable("Unable to connect to any servers", errors)
@@ -1354,7 +1354,8 @@ class ResponseFuture(object):
         # convert the list/generator/etc to an iterator so that subsequent
         # calls to send_request (which retries may do) will resume where
         # they last left off
-        self.query_plan = iter(session._load_balancer.make_query_plan(query))
+        self.query_plan = iter(session._load_balancer.make_query_plan(
+            session.keyspace, query))

         self._event = Event()
         self._errors = {}

View File

@@ -263,18 +263,20 @@ class Metadata(object):
             self.token_map = None
             return

-        tokens_to_hosts = defaultdict(set)
+        token_to_primary_replica = {}
         ring = []
         for host, token_strings in token_map.iteritems():
             for token_string in token_strings:
                 token = token_class(token_string)
                 ring.append(token)
-                tokens_to_hosts[token].add(host)
+                token_to_primary_replica[token] = host

-        ring = sorted(ring)
-        self.token_map = TokenMap(token_class, tokens_to_hosts, ring)
+        all_tokens = sorted(ring)
+        self.token_map = TokenMap(
+            token_class, token_to_primary_replica, all_tokens,
+            self.keyspaces.values())

-    def get_replicas(self, key):
+    def get_replicas(self, keyspace, key):
         """
         Returns a list of :class:`.Host` instances that are replicas for a given
         partition key.
@@ -283,7 +285,7 @@ class Metadata(object):
         if not t:
             return []
         try:
-            return t.get_replicas(t.token_class.from_key(key))
+            return t.get_replicas(keyspace, t.token_class.from_key(key))
         except NoMurmur3:
             return []
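
Not from the commit itself, but a short sketch of exercising the reworked lookup directly against cluster metadata, e.g. to debug routing. The keyspace name is illustrative; the key must already be in its serialized form (the same packed form the unit tests below use):

import struct

from cassandra.cluster import Cluster

cluster = Cluster(['127.0.0.1'])
session = cluster.connect()  # metadata is populated once the control connection is up

packed_key = struct.pack('>i', 42)  # a 4-byte int partition key, packed big-endian
replicas = cluster.metadata.get_replicas('my_ks', packed_key)
for host in replicas:
    print(host)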
@@ -344,14 +346,15 @@ class SimpleStrategy(ReplicationStrategy):
     replication_factor = None

     def __init__(self, replication_factor):
-        self.replication_factor = replication_factor
+        self.replication_factor = int(replication_factor)

     def make_token_replica_map(self, token_to_primary_replica, ring):
         replica_map = {}
         for i in range(len(ring)):
             j, hosts = 0, set()
             while len(hosts) < self.replication_factor and j < len(ring):
-                hosts.add(token_to_primary_replica[ring[(i + j) % len(ring)]])
+                token = ring[(i + j) % len(ring)]
+                hosts.add(token_to_primary_replica[token])
                 j += 1

             replica_map[ring[i]] = hosts
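
To make the replica walk above concrete, here is a toy, self-contained sketch (plain integers and strings instead of driver Token and Host objects) of how taking the next replication_factor distinct primary owners around the sorted ring yields each token's replica set:

# toy illustration of the SimpleStrategy walk; not driver code
ring = [0, 25, 50, 75]                          # sorted tokens
primary = {0: "A", 25: "B", 50: "C", 75: "A"}   # token -> primary owner
rf = 2

replica_map = {}
for i, token in enumerate(ring):
    j, hosts = 0, set()
    # walk clockwise until rf distinct hosts are found or the ring is exhausted
    while len(hosts) < rf and j < len(ring):
        hosts.add(primary[ring[(i + j) % len(ring)]])
        j += 1
    replica_map[token] = hosts

# replica_map == {0: {"A", "B"}, 25: {"B", "C"}, 50: {"C", "A"}, 75: {"A", "B"}}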
@@ -400,6 +403,7 @@ class NetworkTopologyStrategy(ReplicationStrategy):
             ret += ", '%s': '%d'" % (dc, repl_factor)
         return ret + "}"
+

 class KeyspaceMetadata(object):
     """
     A representation of the schema for a single keyspace.
@@ -708,9 +712,10 @@ class TokenMap(object):
     A subclass of :class:`.Token`, depending on what partitioner the cluster uses.
     """

-    tokens_to_hosts = None
+    tokens_to_hosts_by_ks = None
     """
-    A map of :class:`.Token` objects to :class:`.Host` objects.
+    A map of keyspace names to a nested map of :class:`.Token` objects to
+    sets of :class:`.Host` objects.
     """

     ring = None
@@ -718,25 +723,39 @@ class TokenMap(object):
     An ordered list of :class:`.Token` instances in the ring.
     """

-    def __init__(self, token_class, tokens_to_hosts, ring):
+    def __init__(self, token_class, token_to_primary_replica, all_tokens, keyspaces):
         self.token_class = token_class
-        self.tokens_to_hosts = tokens_to_hosts
-        self.ring = ring
+        self.ring = all_tokens

-    def get_replicas(self, token):
+        self.tokens_to_hosts_by_ks = {}
+        for ks_metadata in keyspaces:
+            strategy = ks_metadata.replication_strategy
+            if strategy is None:
+                token_to_hosts = defaultdict(set)
+                for token, host in token_to_primary_replica.items():
+                    token_to_hosts[token].add(host)
+                self.tokens_to_hosts_by_ks[ks_metadata.name] = token_to_hosts
+            else:
+                self.tokens_to_hosts_by_ks[ks_metadata.name] = \
+                    strategy.make_token_replica_map(
+                        token_to_primary_replica, all_tokens)
+
+    def get_replicas(self, keyspace, token):
         """
-        Get :class:`.Host` instances representing all of the replica nodes
-        for a given :class:`.Token`.
+        Get a set of :class:`.Host` instances representing all of the
+        replica nodes for a given :class:`.Token`.
         """
-        # TODO depending on keyspace and replication strategy options,
-        # return full set of replicas
+        tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None)
+        if tokens_to_hosts is None:
+            return set()
+
         point = bisect_left(self.ring, token)
         if point == 0 and token != self.ring[0]:
-            return self.tokens_to_hosts[self.ring[-1]]
+            return tokens_to_hosts[self.ring[-1]]
         elif point == len(self.ring):
-            return self.tokens_to_hosts[self.ring[0]]
+            return tokens_to_hosts[self.ring[0]]
         else:
-            return self.tokens_to_hosts[self.ring[point]]
+            return tokens_to_hosts[self.ring[point]]


 class Token(object):
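
The lookup above is a bisect over the sorted ring: find the first ring token greater than or equal to the key's token and return that entry's replica set, with the two boundary cases handled explicitly. A toy sketch of the same arithmetic using plain integers rather than driver objects:

from bisect import bisect_left

ring = [0, 25, 50, 75]                                 # sorted ring tokens
owners = {0: {"A"}, 25: {"B"}, 50: {"C"}, 75: {"D"}}   # token -> replica set

def replicas_for(token):
    point = bisect_left(ring, token)
    if point == 0 and token != ring[0]:
        return owners[ring[-1]]      # token sorts before the first ring entry
    elif point == len(ring):
        return owners[ring[0]]       # token sorts after the last ring entry: wrap around
    else:
        return owners[ring[point]]   # first ring token >= the lookup token

print(replicas_for(30))   # {'C'}  -- 30 falls between 25 and 50
print(replicas_for(80))   # {'A'}  -- past 75, wraps to the first token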
@@ -785,6 +804,10 @@ class Murmur3Token(Token):
         """ `token` should be an int or string representing the token """
         self.value = int(token)

+    def __repr__(self):
+        return "<Murmur3Token: %r" % (self.value,)
+    __str__ = __repr__
+

 class MD5Token(Token):
     """
@@ -799,6 +822,10 @@ class MD5Token(Token):
         """ `token` should be an int or string representing the token """
         self.value = int(token)

+    def __repr__(self):
+        return "<MD5Token: %d" % (self.value,)
+    __str__ = __repr__
+

 class BytesToken(Token):
     """
@@ -812,3 +839,7 @@ class BytesToken(Token):
                 "Tokens for ByteOrderedPartitioner should be strings (got %s)"
                 % (type(token_string),))
         self.value = token_string
+
+    def __repr__(self):
+        return "<BytesToken: %r" % (self.value,)
+    __str__ = __repr__

View File

@@ -71,7 +71,7 @@ class LoadBalancingPolicy(object):
         """
         raise NotImplementedError()

-    def make_query_plan(self, query=None):
+    def make_query_plan(self, working_keyspace=None, query=None):
         """
         Given a :class:`~.query.Query` instance, return a iterable
         of :class:`.Host` instances which should be queried in that
@@ -80,6 +80,10 @@ class LoadBalancingPolicy(object):
         Note that the `query` argument may be :const:`None` when preparing
         statements.
+
+        `working_keyspace` should be the string name of the current keyspace,
+        as set through :meth:`.Session.set_keyspace()` or with a ``USE``
+        statement.
         """
         raise NotImplementedError()
@@ -133,7 +137,7 @@ class RoundRobinPolicy(LoadBalancingPolicy):
     def distance(self, host):
         return HostDistance.LOCAL

-    def make_query_plan(self, query=None):
+    def make_query_plan(self, working_keyspace=None, query=None):
         # not thread-safe, but we don't care much about lost increments
         # for the purposes of load balancing
         pos = self._position
@@ -219,7 +223,7 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
         else:
             return HostDistance.IGNORED

-    def make_query_plan(self, query=None):
+    def make_query_plan(self, working_keyspace=None, query=None):
         # not thread-safe, but we don't care much about lost increments
         # for the purposes of load balancing
         pos = self._position
@@ -278,24 +282,29 @@ class TokenAwarePolicy(LoadBalancingPolicy):
     def distance(self, *args, **kwargs):
         return self.child_policy.distance(*args, **kwargs)

-    def make_query_plan(self, query=None):
+    def make_query_plan(self, working_keyspace=None, query=None):
+        if query and query.keyspace:
+            keyspace = query.keyspace
+        else:
+            keyspace = working_keyspace
+
         child = self.child_policy
         if query is None:
-            for host in child.make_query_plan(query):
+            for host in child.make_query_plan(keyspace, query):
                 yield host
         else:
             routing_key = query.routing_key
             if routing_key is None:
-                for host in child.make_query_plan(query):
+                for host in child.make_query_plan(keyspace, query):
                     yield host
             else:
-                replicas = self._cluster_metadata.get_replicas(routing_key)
+                replicas = self._cluster_metadata.get_replicas(keyspace, routing_key)
                 for replica in replicas:
                     if replica.monitor.is_up and \
                             child.distance(replica) == HostDistance.LOCAL:
                         yield replica

-                for host in child.make_query_plan(query):
+                for host in child.make_query_plan(keyspace, query):
                     # skip if we've already listed this host
                     if host not in replicas or \
                             child.distance(host) == HostDistance.REMOTE:
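
In short, the plan above prefers the query's own keyspace over the session's working keyspace, asks the cluster metadata for that keyspace's replicas of the routing key, yields the live LOCAL replicas first, and then falls through to the child policy's plan while skipping hosts already listed. A condensed sketch of the resulting ordering, with illustrative host names and the is_up/distance checks assumed to pass:

# toy sketch of the host ordering for one query
replicas   = ["ip1", "ip2"]                   # metadata.get_replicas(keyspace, routing_key)
child_plan = ["ip3", "ip1", "ip4", "ip2"]     # whatever the wrapped child policy would yield

plan = []
for host in replicas:                         # replicas of the partition go first
    plan.append(host)
for host in child_plan:                       # then the child's plan, minus hosts already listed
    if host not in replicas:
        plan.append(host)

print(plan)   # ['ip1', 'ip2', 'ip3', 'ip4']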

View File

@@ -38,9 +38,16 @@ class Query(object):
     The :class:`.ConsistencyLevel` to be used for this operation. Defaults
     to :attr:`.ConsistencyLevel.ONE`.
     """

+    keyspace = None
+    """
+    The string name of the keyspace this query acts on.
+    """
+
     _routing_key = None

-    def __init__(self, retry_policy=None, tracing_enabled=False, consistency_level=ConsistencyLevel.ONE, routing_key=None):
+    def __init__(self, retry_policy=None, tracing_enabled=False,
+                 consistency_level=ConsistencyLevel.ONE, routing_key=None):
         self.retry_policy = retry_policy
         self.tracing_enabled = tracing_enabled
         self.consistency_level = consistency_level
@@ -245,6 +252,14 @@ class BoundStatement(Query):
         return self._routing_key

+    @property
+    def keyspace(self):
+        meta = self.prepared_statement.column_metadata
+        if meta:
+            return meta[0][0]
+        else:
+            return None
+

 class ValueSequence(object):
     """

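The new keyspace property relies on the shape of a prepared statement's column_metadata, whose entries start with the keyspace name (the remaining elements, table, column name and type, are not needed here). A stand-in sketch of that lookup, using fake objects rather than a real prepared statement:

# stand-in objects; in the driver these come from preparing a statement
class FakePrepared(object):
    # assumed entry shape: (keyspace, table, column, type)
    column_metadata = [("users_ks", "users", "user_id", None)]

class FakeBound(object):
    def __init__(self, prepared):
        self.prepared_statement = prepared

    @property
    def keyspace(self):
        meta = self.prepared_statement.column_metadata
        return meta[0][0] if meta else None

print(FakeBound(FakePrepared()).keyspace)   # 'users_ks'
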
View File

@@ -1,10 +1,10 @@
 try:
     import unittest2 as unittest
 except ImportError:
-    import unittest
+    import unittest  # noqa

 from cassandra.cluster import Cluster
-from cassandra.metadata import TableMetadata, Token, MD5Token, TokenMap
+from cassandra.metadata import KeyspaceMetadata, TableMetadata, Token, MD5Token, TokenMap
 from cassandra.policies import SimpleConvictionPolicy
 from cassandra.pool import Host
@@ -120,8 +120,8 @@ class SchemaMetadataTest(unittest.TestCase):
         self.assertEqual(ksmeta.name, self.ksname)
         self.assertTrue(ksmeta.durable_writes)
-        self.assertTrue(ksmeta.replication['class'].endswith('SimpleStrategy'))
-        self.assertEqual(ksmeta.replication['replication_factor'], '1')
+        self.assertEqual(ksmeta.replication_strategy.name, 'SimpleStrategy')
+        self.assertEqual(ksmeta.replication_strategy.replication_factor, 1)
         self.assertTrue(self.cfname in ksmeta.tables)

         tablemeta = ksmeta.tables[self.cfname]
@@ -286,27 +286,27 @@ class TokenMetadataTest(unittest.TestCase):
         tmap = cluster.metadata.token_map
         self.assertTrue(issubclass(tmap.token_class, Token))
         self.assertEqual(expected_node_count, len(tmap.ring))
-        self.assertEqual(expected_node_count, len(tmap.tokens_to_hosts))
         cluster.shutdown()

     def test_getting_replicas(self):
         tokens = [MD5Token(str(i)) for i in range(0, (2 ** 127 - 1), 2 ** 125)]
         hosts = [Host("ip%d" % i, SimpleConvictionPolicy) for i in range(len(tokens))]
-        tokens_to_hosts = dict((t, set([h])) for t, h in zip(tokens, hosts))
-        token_map = TokenMap(MD5Token, tokens_to_hosts, tokens)
+        token_to_primary_replica = dict(zip(tokens, hosts))
+        keyspace = KeyspaceMetadata("ks", True, "SimpleStrategy", {"replication_factor": "1"})
+        token_map = TokenMap(MD5Token, token_to_primary_replica, tokens, [keyspace])

         # tokens match node tokens exactly
         for token, expected_host in zip(tokens, hosts):
-            replicas = token_map.get_replicas(token)
+            replicas = token_map.get_replicas("ks", token)
             self.assertEqual(replicas, set([expected_host]))

         # shift the tokens back by one
         for token, expected_host in zip(tokens[1:], hosts[1:]):
-            replicas = token_map.get_replicas(MD5Token(str(token.value - 1)))
+            replicas = token_map.get_replicas("ks", MD5Token(str(token.value - 1)))
             self.assertEqual(replicas, set([expected_host]))

         # shift the tokens forward by one
         for i, token in enumerate(tokens):
-            replicas = token_map.get_replicas(MD5Token(str(token.value + 1)))
+            replicas = token_map.get_replicas("ks", MD5Token(str(token.value + 1)))
             expected_host = hosts[(i + 1) % len(hosts)]
             self.assertEqual(replicas, set([expected_host]))

View File

@@ -181,7 +181,7 @@ class TokenAwarePolicyTest(unittest.TestCase):
         cluster.metadata = Mock(spec=Metadata)
         hosts = [Host(str(i), SimpleConvictionPolicy) for i in range(4)]

-        def get_replicas(packed_key):
+        def get_replicas(keyspace, packed_key):
             index = struct.unpack('>i', packed_key)[0]
             return list(islice(cycle(hosts), index, index + 2))
@@ -192,9 +192,9 @@ class TokenAwarePolicyTest(unittest.TestCase):
         for i in range(4):
             query = Query(routing_key=struct.pack('>i', i))
-            qplan = list(policy.make_query_plan(query))
+            qplan = list(policy.make_query_plan(None, query))

-            replicas = get_replicas(struct.pack('>i', i))
+            replicas = get_replicas(None, struct.pack('>i', i))
             other = set(h for h in hosts if h not in replicas)
             self.assertEquals(replicas, qplan[:2])
             self.assertEquals(other, set(qplan[2:]))
@@ -208,7 +208,7 @@ class TokenAwarePolicyTest(unittest.TestCase):
         for h in hosts[2:]:
             h.set_location_info("dc2", "rack1")

-        def get_replicas(packed_key):
+        def get_replicas(keyspace, packed_key):
             index = struct.unpack('>i', packed_key)[0]
             # return one node from each DC
             if index % 2 == 0:
@@ -223,8 +223,8 @@ class TokenAwarePolicyTest(unittest.TestCase):
         for i in range(4):
             query = Query(routing_key=struct.pack('>i', i))
-            qplan = list(policy.make_query_plan(query))
-            replicas = get_replicas(struct.pack('>i', i))
+            qplan = list(policy.make_query_plan(None, query))
+            replicas = get_replicas(None, struct.pack('>i', i))

             # first should be the only local replica
             self.assertIn(qplan[0], replicas)