Add tests for load balancing policies, related fixes
This commit is contained in:
@@ -59,35 +59,37 @@ class RoundRobinPolicy(LoadBalancingPolicy):
|
||||
class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
|
||||
|
||||
def __init__(self, local_dc, used_hosts_per_remote_dc=0):
|
||||
LoadBalancingPolicy.__init__(self)
|
||||
self.local_dc = local_dc
|
||||
self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
|
||||
self._dc_live_hosts = {}
|
||||
self._lock = RLock()
|
||||
|
||||
def populate(self, cluster, hosts):
|
||||
for dc, hosts in groupby(hosts, lambda h: h.dc):
|
||||
self._dc_live_hosts[dc] = set(hosts)
|
||||
for dc, dc_hosts in groupby(hosts, lambda h: h.datacenter):
|
||||
self._dc_live_hosts[dc] = set(dc_hosts)
|
||||
|
||||
if len(hosts) == 1:
|
||||
# position is currently only used for local hosts
|
||||
local_live = self._dc_live_hosts.get(self.local_dc)
|
||||
if len(local_live) == 1:
|
||||
self._position = 0
|
||||
else:
|
||||
self._position = randint(0, len(hosts) - 1)
|
||||
self._position = randint(0, len(local_live) - 1)
|
||||
|
||||
def distance(self, host):
|
||||
if host.dc == self.local_dc:
|
||||
if host.datacenter == self.local_dc:
|
||||
return HostDistance.LOCAL
|
||||
|
||||
if not self.used_hosts_per_remote_dc:
|
||||
return HostDistance.IGNORE
|
||||
return HostDistance.IGNORED
|
||||
else:
|
||||
dc_hosts = self._dc_live_hosts.get(host.dc)
|
||||
dc_hosts = self._dc_live_hosts.get(host.datacenter)
|
||||
if not dc_hosts:
|
||||
return HostDistance.IGNORE
|
||||
return HostDistance.IGNORED
|
||||
|
||||
if host in list(dc_hosts)[:self.used_hosts_per_remote_dc]:
|
||||
return HostDistance.REMOTE
|
||||
else:
|
||||
return HostDistance.IGNORE
|
||||
return HostDistance.IGNORED
|
||||
|
||||
def make_query_plan(self, query=None):
|
||||
with self._lock:
|
||||
@@ -96,7 +98,7 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
|
||||
|
||||
local_live = list(self._dc_live_hosts.get(self.local_dc))
|
||||
pos %= len(local_live)
|
||||
for host in islice(cycle(local_live, pos, pos + len(local_live))):
|
||||
for host in islice(cycle(local_live), pos, pos + len(local_live)):
|
||||
yield host
|
||||
|
||||
for dc, current_dc_hosts in self._dc_live_hosts.iteritems():
|
||||
@@ -107,16 +109,16 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
|
||||
yield host
|
||||
|
||||
def on_up(self, host):
|
||||
self._dc_live_hosts.setdefault(host.dc, set()).add(host)
|
||||
self._dc_live_hosts.setdefault(host.datacenter, set()).add(host)
|
||||
|
||||
def on_down(self, host):
|
||||
self._dc_live_hosts.setdefault(host.dc, set()).discard(host)
|
||||
self._dc_live_hosts.setdefault(host.datacenter, set()).discard(host)
|
||||
|
||||
def on_add(self, host):
|
||||
self._dc_live_hosts.setdefault(host.dc, set()).add(host)
|
||||
self._dc_live_hosts.setdefault(host.datacenter, set()).add(host)
|
||||
|
||||
def on_remove(self, host):
|
||||
self._dc_live_hosts.setdefault(host.dc, set()).discard(host)
|
||||
self._dc_live_hosts.setdefault(host.datacenter, set()).discard(host)
|
||||
|
||||
|
||||
class SimpleConvictionPolicy(object):
|
||||
@@ -236,7 +238,7 @@ class FallthroughRetryPolicy(RetryPolicy):
|
||||
if attempt_num != 0:
|
||||
return (self.RETHROW, None)
|
||||
elif write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.COUNTER):
|
||||
return (self.IGNORE, None)
|
||||
return (self.IGNORED, None)
|
||||
elif write_type == WriteType.UNLOGGED_BATCH:
|
||||
return self._pick_consistency(received_responses)
|
||||
elif write_type == WriteType.BATCH_LOG:
|
||||
|
@@ -44,6 +44,14 @@ class Host(object):
|
||||
self._reconnection_handler = None
|
||||
self._reconnection_lock = Lock()
|
||||
|
||||
@property
|
||||
def datacenter(self):
|
||||
return self._datacenter
|
||||
|
||||
@property
|
||||
def rack(self):
|
||||
return self._rack
|
||||
|
||||
def set_location_info(self, datacenter, rack):
|
||||
self._datacenter = datacenter
|
||||
self._rack = rack
|
||||
|
139
tests/test_policies.py
Normal file
139
tests/test_policies.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import unittest
|
||||
from threading import Thread
|
||||
|
||||
from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
|
||||
SimpleConvictionPolicy, HostDistance)
|
||||
from cassandra.pool import Host
|
||||
|
||||
class TestRoundRobinPolicy(unittest.TestCase):
|
||||
|
||||
def test_basic(self):
|
||||
hosts = [0, 1, 2, 3]
|
||||
policy = RoundRobinPolicy()
|
||||
policy.populate(None, hosts)
|
||||
qplan = list(policy.make_query_plan())
|
||||
self.assertEqual(sorted(qplan), hosts)
|
||||
|
||||
def test_multiple_query_plans(self):
|
||||
hosts = [0, 1, 2, 3]
|
||||
policy = RoundRobinPolicy()
|
||||
policy.populate(None, hosts)
|
||||
for i in xrange(20):
|
||||
qplan = list(policy.make_query_plan())
|
||||
self.assertEqual(sorted(qplan), hosts)
|
||||
|
||||
def test_single_host(self):
|
||||
policy = RoundRobinPolicy()
|
||||
policy.populate(None, [0])
|
||||
qplan = list(policy.make_query_plan())
|
||||
self.assertEqual(qplan, [0])
|
||||
|
||||
def test_status_updates(self):
|
||||
hosts = [0, 1, 2, 3]
|
||||
policy = RoundRobinPolicy()
|
||||
policy.populate(None, hosts)
|
||||
policy.on_down(0)
|
||||
policy.on_remove(1)
|
||||
policy.on_up(4)
|
||||
policy.on_add(5)
|
||||
qplan = list(policy.make_query_plan())
|
||||
self.assertEqual(sorted(qplan), [2, 3, 4, 5])
|
||||
|
||||
def test_thread_safety(self):
|
||||
hosts = range(100)
|
||||
policy = RoundRobinPolicy()
|
||||
policy.populate(None, hosts)
|
||||
|
||||
def check_query_plan():
|
||||
for i in range(100):
|
||||
qplan = list(policy.make_query_plan())
|
||||
self.assertEqual(sorted(qplan), hosts)
|
||||
|
||||
threads = [Thread(target=check_query_plan) for i in range(4)]
|
||||
map(lambda t: t.start(), threads)
|
||||
map(lambda t: t.join(), threads)
|
||||
|
||||
|
||||
class TestDCAwareRoundRobinPolicy(unittest.TestCase):
|
||||
|
||||
def test_no_remote(self):
|
||||
hosts = []
|
||||
for i in range(4):
|
||||
h = Host(i, SimpleConvictionPolicy)
|
||||
h.set_location_info("dc1", "rack1")
|
||||
hosts.append(h)
|
||||
|
||||
policy = DCAwareRoundRobinPolicy("dc1")
|
||||
policy.populate(None, hosts)
|
||||
qplan = list(policy.make_query_plan())
|
||||
self.assertEqual(sorted(qplan), sorted(hosts))
|
||||
|
||||
def test_with_remotes(self):
|
||||
hosts = [Host(i, SimpleConvictionPolicy) for i in range(4)]
|
||||
for h in hosts[:2]:
|
||||
h.set_location_info("dc1", "rack1")
|
||||
for h in hosts[2:]:
|
||||
h.set_location_info("dc2", "rack1")
|
||||
|
||||
policy = DCAwareRoundRobinPolicy("dc1")
|
||||
policy.populate(None, hosts)
|
||||
qplan = list(policy.make_query_plan())
|
||||
self.assertEqual(set(qplan[:2]), set(h for h in hosts if h.datacenter == "dc1"))
|
||||
self.assertEqual(set(qplan[2:]), set(h for h in hosts if h.datacenter != "dc1"))
|
||||
|
||||
def test_get_distance(self):
|
||||
policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0)
|
||||
host = Host("ip1", SimpleConvictionPolicy)
|
||||
host.set_location_info("dc1", "rack1")
|
||||
policy.populate(None, [host])
|
||||
|
||||
self.assertEqual(policy.distance(host), HostDistance.LOCAL)
|
||||
|
||||
# used_hosts_per_remote_dc is set to 0, so ignore it
|
||||
remote_host = Host("ip2", SimpleConvictionPolicy)
|
||||
remote_host.set_location_info("dc2", "rack1")
|
||||
self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED)
|
||||
|
||||
# dc2 isn't registered in the policy's live_hosts dict
|
||||
policy.used_hosts_per_remote_dc = 1
|
||||
self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED)
|
||||
|
||||
# make sure the policy has both dcs registered
|
||||
policy.populate(None, [host, remote_host])
|
||||
self.assertEqual(policy.distance(remote_host), HostDistance.REMOTE)
|
||||
|
||||
# since used_hosts_per_remote_dc is set to 1, only the first
|
||||
# remote host in dc2 will be REMOTE, the rest are IGNORED
|
||||
second_remote_host = Host("ip3", SimpleConvictionPolicy)
|
||||
second_remote_host.set_location_info("dc2", "rack1")
|
||||
policy.populate(None, [host, remote_host, second_remote_host])
|
||||
self.assertEqual(policy.distance(remote_host), HostDistance.REMOTE)
|
||||
self.assertEqual(policy.distance(second_remote_host), HostDistance.IGNORED)
|
||||
|
||||
def test_status_updates(self):
|
||||
hosts = [Host(i, SimpleConvictionPolicy) for i in range(4)]
|
||||
for h in hosts[:2]:
|
||||
h.set_location_info("dc1", "rack1")
|
||||
for h in hosts[2:]:
|
||||
h.set_location_info("dc2", "rack1")
|
||||
|
||||
policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1)
|
||||
policy.populate(None, hosts)
|
||||
policy.on_down(hosts[0])
|
||||
policy.on_remove(hosts[2])
|
||||
|
||||
new_local_host = Host(4, SimpleConvictionPolicy)
|
||||
new_local_host.set_location_info("dc1", "rack1")
|
||||
policy.on_up(new_local_host)
|
||||
|
||||
new_remote_host = Host(5, SimpleConvictionPolicy)
|
||||
new_remote_host.set_location_info("dc9000", "rack1")
|
||||
policy.on_add(new_remote_host)
|
||||
|
||||
# we now have two local hosts and two remote hosts in separate dcs
|
||||
qplan = list(policy.make_query_plan())
|
||||
self.assertEqual(set(qplan[:2]), set([hosts[1], new_local_host]))
|
||||
self.assertEqual(set(qplan[2:]), set([hosts[3], new_remote_host]))
|
||||
|
||||
# since we have hosts in dc9000, the distance shouldn't be IGNORED
|
||||
self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE)
|
Reference in New Issue
Block a user