diff --git a/cassandra/policies.py b/cassandra/policies.py index 7d7d1d90..e990dd87 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -158,7 +158,7 @@ class ExponentialReconnectionPolicy(object): self.max_delay = max_delay def new_schedule(self): - return (min(self.base_delay * (i ** 2), self.max_delay) for i in xrange(64)) + return (min(self.base_delay * (2 ** i), self.max_delay) for i in xrange(64)) class WriteType(object): diff --git a/tests/test_policies.py b/tests/test_policies.py index eeae6164..88ce7b8f 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -2,7 +2,8 @@ import unittest from threading import Thread from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy, - SimpleConvictionPolicy, HostDistance) + SimpleConvictionPolicy, HostDistance, + ExponentialReconnectionPolicy) from cassandra.pool import Host class TestRoundRobinPolicy(unittest.TestCase): @@ -107,8 +108,8 @@ class TestDCAwareRoundRobinPolicy(unittest.TestCase): second_remote_host = Host("ip3", SimpleConvictionPolicy) second_remote_host.set_location_info("dc2", "rack1") policy.populate(None, [host, remote_host, second_remote_host]) - self.assertEqual(policy.distance(remote_host), HostDistance.REMOTE) - self.assertEqual(policy.distance(second_remote_host), HostDistance.IGNORED) + distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) + self.assertEqual(distances, set([HostDistance.REMOTE, HostDistance.IGNORED])) def test_status_updates(self): hosts = [Host(i, SimpleConvictionPolicy) for i in range(4)] @@ -137,3 +138,23 @@ class TestDCAwareRoundRobinPolicy(unittest.TestCase): # since we have hosts in dc9000, the distance shouldn't be IGNORED self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE) + + +class ExponentialReconnectionPolicyTest(unittest.TestCase): + + def test_bad_vals(self): + self.assertRaises(ValueError, ExponentialReconnectionPolicy, -1, 0) + self.assertRaises(ValueError, ExponentialReconnectionPolicy, 0, -1) + self.assertRaises(ValueError, ExponentialReconnectionPolicy, 9000, 1) + + def test_schedule(self): + policy = ExponentialReconnectionPolicy(base_delay=2, max_delay=100) + schedule = list(policy.new_schedule()) + self.assertEqual(len(schedule), 64) + for i, delay in enumerate(schedule): + if i == 0: + self.assertEqual(delay, 2) + elif i < 6: + self.assertEqual(delay, schedule[i - 1] * 2) + else: + self.assertEqual(delay, 100)