diff --git a/cassandra/policies.py b/cassandra/policies.py index 347907ec..5b325bfd 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -14,7 +14,7 @@ from itertools import islice, cycle, groupby, repeat import logging -from random import randint +from random import randint, shuffle from threading import Lock import socket @@ -320,13 +320,18 @@ class TokenAwarePolicy(LoadBalancingPolicy): If no :attr:`~.Statement.routing_key` is set on the query, the child policy's query plan will be used as is. + + If :attr:`.shuffle_replicas` is truthy, :attr:`~.HostDistance.LOCAL` + replicas will be yielded in a random order, followed by the remaining + hosts in the order provided child policy's query plan. """ _child_policy = None _cluster_metadata = None - def __init__(self, child_policy): + def __init__(self, child_policy, shuffle_replicas=False): self._child_policy = child_policy + self.shuffle_replicas = shuffle_replicas def populate(self, cluster, hosts): self._cluster_metadata = cluster.metadata @@ -361,6 +366,8 @@ class TokenAwarePolicy(LoadBalancingPolicy): yield host else: replicas = self._cluster_metadata.get_replicas(keyspace, routing_key) + if self.shuffle_replicas: + shuffle(replicas) for replica in replicas: if replica.is_up and \ child.distance(replica) == HostDistance.LOCAL: