From 5160fbed147f38aabaa90ca3c01ecc3c8a262b96 Mon Sep 17 00:00:00 2001 From: Tyler Hobbs Date: Wed, 10 Apr 2013 14:11:38 -0500 Subject: [PATCH] Add tests for load balancing policies, related fixes --- cassandra/policies.py | 34 +++++----- cassandra/pool.py | 8 +++ tests/test_policies.py | 139 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 165 insertions(+), 16 deletions(-) create mode 100644 tests/test_policies.py diff --git a/cassandra/policies.py b/cassandra/policies.py index d0a5f848..7d7d1d90 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -59,35 +59,37 @@ class RoundRobinPolicy(LoadBalancingPolicy): class DCAwareRoundRobinPolicy(LoadBalancingPolicy): def __init__(self, local_dc, used_hosts_per_remote_dc=0): - LoadBalancingPolicy.__init__(self) self.local_dc = local_dc self.used_hosts_per_remote_dc = used_hosts_per_remote_dc self._dc_live_hosts = {} + self._lock = RLock() def populate(self, cluster, hosts): - for dc, hosts in groupby(hosts, lambda h: h.dc): - self._dc_live_hosts[dc] = set(hosts) + for dc, dc_hosts in groupby(hosts, lambda h: h.datacenter): + self._dc_live_hosts[dc] = set(dc_hosts) - if len(hosts) == 1: + # position is currently only used for local hosts + local_live = self._dc_live_hosts.get(self.local_dc) + if len(local_live) == 1: self._position = 0 else: - self._position = randint(0, len(hosts) - 1) + self._position = randint(0, len(local_live) - 1) def distance(self, host): - if host.dc == self.local_dc: + if host.datacenter == self.local_dc: return HostDistance.LOCAL if not self.used_hosts_per_remote_dc: - return HostDistance.IGNORE + return HostDistance.IGNORED else: - dc_hosts = self._dc_live_hosts.get(host.dc) + dc_hosts = self._dc_live_hosts.get(host.datacenter) if not dc_hosts: - return HostDistance.IGNORE + return HostDistance.IGNORED if host in list(dc_hosts)[:self.used_hosts_per_remote_dc]: return HostDistance.REMOTE else: - return HostDistance.IGNORE + return HostDistance.IGNORED def make_query_plan(self, query=None): with self._lock: @@ -96,7 +98,7 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy): local_live = list(self._dc_live_hosts.get(self.local_dc)) pos %= len(local_live) - for host in islice(cycle(local_live, pos, pos + len(local_live))): + for host in islice(cycle(local_live), pos, pos + len(local_live)): yield host for dc, current_dc_hosts in self._dc_live_hosts.iteritems(): @@ -107,16 +109,16 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy): yield host def on_up(self, host): - self._dc_live_hosts.setdefault(host.dc, set()).add(host) + self._dc_live_hosts.setdefault(host.datacenter, set()).add(host) def on_down(self, host): - self._dc_live_hosts.setdefault(host.dc, set()).discard(host) + self._dc_live_hosts.setdefault(host.datacenter, set()).discard(host) def on_add(self, host): - self._dc_live_hosts.setdefault(host.dc, set()).add(host) + self._dc_live_hosts.setdefault(host.datacenter, set()).add(host) def on_remove(self, host): - self._dc_live_hosts.setdefault(host.dc, set()).discard(host) + self._dc_live_hosts.setdefault(host.datacenter, set()).discard(host) class SimpleConvictionPolicy(object): @@ -236,7 +238,7 @@ class FallthroughRetryPolicy(RetryPolicy): if attempt_num != 0: return (self.RETHROW, None) elif write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.COUNTER): - return (self.IGNORE, None) + return (self.IGNORED, None) elif write_type == WriteType.UNLOGGED_BATCH: return self._pick_consistency(received_responses) elif write_type == WriteType.BATCH_LOG: diff --git a/cassandra/pool.py b/cassandra/pool.py index 1991fdab..e0719834 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -44,6 +44,14 @@ class Host(object): self._reconnection_handler = None self._reconnection_lock = Lock() + @property + def datacenter(self): + return self._datacenter + + @property + def rack(self): + return self._rack + def set_location_info(self, datacenter, rack): self._datacenter = datacenter self._rack = rack diff --git a/tests/test_policies.py b/tests/test_policies.py new file mode 100644 index 00000000..eeae6164 --- /dev/null +++ b/tests/test_policies.py @@ -0,0 +1,139 @@ +import unittest +from threading import Thread + +from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy, + SimpleConvictionPolicy, HostDistance) +from cassandra.pool import Host + +class TestRoundRobinPolicy(unittest.TestCase): + + def test_basic(self): + hosts = [0, 1, 2, 3] + policy = RoundRobinPolicy() + policy.populate(None, hosts) + qplan = list(policy.make_query_plan()) + self.assertEqual(sorted(qplan), hosts) + + def test_multiple_query_plans(self): + hosts = [0, 1, 2, 3] + policy = RoundRobinPolicy() + policy.populate(None, hosts) + for i in xrange(20): + qplan = list(policy.make_query_plan()) + self.assertEqual(sorted(qplan), hosts) + + def test_single_host(self): + policy = RoundRobinPolicy() + policy.populate(None, [0]) + qplan = list(policy.make_query_plan()) + self.assertEqual(qplan, [0]) + + def test_status_updates(self): + hosts = [0, 1, 2, 3] + policy = RoundRobinPolicy() + policy.populate(None, hosts) + policy.on_down(0) + policy.on_remove(1) + policy.on_up(4) + policy.on_add(5) + qplan = list(policy.make_query_plan()) + self.assertEqual(sorted(qplan), [2, 3, 4, 5]) + + def test_thread_safety(self): + hosts = range(100) + policy = RoundRobinPolicy() + policy.populate(None, hosts) + + def check_query_plan(): + for i in range(100): + qplan = list(policy.make_query_plan()) + self.assertEqual(sorted(qplan), hosts) + + threads = [Thread(target=check_query_plan) for i in range(4)] + map(lambda t: t.start(), threads) + map(lambda t: t.join(), threads) + + +class TestDCAwareRoundRobinPolicy(unittest.TestCase): + + def test_no_remote(self): + hosts = [] + for i in range(4): + h = Host(i, SimpleConvictionPolicy) + h.set_location_info("dc1", "rack1") + hosts.append(h) + + policy = DCAwareRoundRobinPolicy("dc1") + policy.populate(None, hosts) + qplan = list(policy.make_query_plan()) + self.assertEqual(sorted(qplan), sorted(hosts)) + + def test_with_remotes(self): + hosts = [Host(i, SimpleConvictionPolicy) for i in range(4)] + for h in hosts[:2]: + h.set_location_info("dc1", "rack1") + for h in hosts[2:]: + h.set_location_info("dc2", "rack1") + + policy = DCAwareRoundRobinPolicy("dc1") + policy.populate(None, hosts) + qplan = list(policy.make_query_plan()) + self.assertEqual(set(qplan[:2]), set(h for h in hosts if h.datacenter == "dc1")) + self.assertEqual(set(qplan[2:]), set(h for h in hosts if h.datacenter != "dc1")) + + def test_get_distance(self): + policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0) + host = Host("ip1", SimpleConvictionPolicy) + host.set_location_info("dc1", "rack1") + policy.populate(None, [host]) + + self.assertEqual(policy.distance(host), HostDistance.LOCAL) + + # used_hosts_per_remote_dc is set to 0, so ignore it + remote_host = Host("ip2", SimpleConvictionPolicy) + remote_host.set_location_info("dc2", "rack1") + self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED) + + # dc2 isn't registered in the policy's live_hosts dict + policy.used_hosts_per_remote_dc = 1 + self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED) + + # make sure the policy has both dcs registered + policy.populate(None, [host, remote_host]) + self.assertEqual(policy.distance(remote_host), HostDistance.REMOTE) + + # since used_hosts_per_remote_dc is set to 1, only the first + # remote host in dc2 will be REMOTE, the rest are IGNORED + second_remote_host = Host("ip3", SimpleConvictionPolicy) + second_remote_host.set_location_info("dc2", "rack1") + policy.populate(None, [host, remote_host, second_remote_host]) + self.assertEqual(policy.distance(remote_host), HostDistance.REMOTE) + self.assertEqual(policy.distance(second_remote_host), HostDistance.IGNORED) + + def test_status_updates(self): + hosts = [Host(i, SimpleConvictionPolicy) for i in range(4)] + for h in hosts[:2]: + h.set_location_info("dc1", "rack1") + for h in hosts[2:]: + h.set_location_info("dc2", "rack1") + + policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1) + policy.populate(None, hosts) + policy.on_down(hosts[0]) + policy.on_remove(hosts[2]) + + new_local_host = Host(4, SimpleConvictionPolicy) + new_local_host.set_location_info("dc1", "rack1") + policy.on_up(new_local_host) + + new_remote_host = Host(5, SimpleConvictionPolicy) + new_remote_host.set_location_info("dc9000", "rack1") + policy.on_add(new_remote_host) + + # we now have two local hosts and two remote hosts in separate dcs + qplan = list(policy.make_query_plan()) + self.assertEqual(set(qplan[:2]), set([hosts[1], new_local_host])) + self.assertEqual(set(qplan[2:]), set([hosts[3], new_remote_host])) + + # since we have hosts in dc9000, the distance shouldn't be IGNORED + self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE)