From bcfa8123c165143ae3296085cac8f149cec9ea07 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 1 Jan 2016 13:19:53 +0800 Subject: [PATCH] improve ExponentialReconnectionPolicy,now can custom max attempts time --- cassandra/policies.py | 15 +++++++++++++-- tests/unit/test_policies.py | 11 ++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/cassandra/policies.py b/cassandra/policies.py index 8132f8ab..3d4b4a4b 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -528,10 +528,14 @@ class ExponentialReconnectionPolicy(ReconnectionPolicy): a set maximum delay. """ - def __init__(self, base_delay, max_delay): + def __init__(self, base_delay, max_delay, max_attempts=64): """ `base_delay` and `max_delay` should be in floating point units of seconds. + + `max_attempts` should be a total number of attempts to be made before + giving up, or :const:`None` to continue reconnection attempts forever. + The default is 64. """ if base_delay < 0 or max_delay < 0: raise ValueError("Delays may not be negative") @@ -539,11 +543,18 @@ class ExponentialReconnectionPolicy(ReconnectionPolicy): if max_delay < base_delay: raise ValueError("Max delay must be greater than base delay") + if max_attempts is not None and max_attempts < 0: + raise ValueError("max_attempts must not be negative") + self.base_delay = base_delay self.max_delay = max_delay + self.max_attempts = max_attempts def new_schedule(self): - return (min(self.base_delay * (2 ** i), self.max_delay) for i in range(64)) + i=0 + while self.max_attempts == None or i < self.max_attempts: + yield min(self.base_delay * (2 ** i), self.max_delay) + i += 1 class WriteType(object): diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 97e24455..2ca63491 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -814,9 +814,18 @@ class ExponentialReconnectionPolicyTest(unittest.TestCase): self.assertRaises(ValueError, ExponentialReconnectionPolicy, -1, 0) self.assertRaises(ValueError, ExponentialReconnectionPolicy, 0, -1) self.assertRaises(ValueError, ExponentialReconnectionPolicy, 9000, 1) + self.assertRaises(ValueError, ExponentialReconnectionPolicy, 1, 2,-1) def test_schedule(self): - policy = ExponentialReconnectionPolicy(base_delay=2, max_delay=100) + policy = ExponentialReconnectionPolicy(base_delay=2, max_delay=100, max_attempts=None) + i=0; + for delay in policy.new_schedule(): + i += 1 + if i > 10000: + break; + self.assertEqual(i, 10001) + + policy = ExponentialReconnectionPolicy(base_delay=2, max_delay=100, max_attempts=64) schedule = list(policy.new_schedule()) self.assertEqual(len(schedule), 64) for i, delay in enumerate(schedule):