From 816ea504c92b141b34759faaf814698499fa0c75 Mon Sep 17 00:00:00 2001
From: Pavlo Shchelokovskyy <shchelokovskyy@gmail.com>
Date: Thu, 10 Jun 2021 16:32:32 +0300
Subject: [PATCH] Enable retries in redis driver

This is a follow-up to Iaab5ce609c0dcf7085f5dd43efbd37eb4b88f17b.

Actually retry for the specified number of tries instead of raising
an error on the first ConnectionError.

Change-Id: Ibca3f568b65dfea252da4b67f6d5105ba7f1ecb1
(cherry picked from commit 47c4d56e446a49726e7145c6cbc3bb7620a431f7)
(cherry picked from commit 11526e594c7e7cef8b5a4a4220f100d76a418036)
(cherry picked from commit c1ae1649d22b356f2bbb2293cb421eb155fdee14)
---
 ...edis-connect-retries-c9adfc81eb06a4ab.yaml |  5 ++
 tooz/drivers/redis.py                         | 72 +++++++++----------
 2 files changed, 41 insertions(+), 36 deletions(-)
 create mode 100644 releasenotes/notes/redis-connect-retries-c9adfc81eb06a4ab.yaml

diff --git a/releasenotes/notes/redis-connect-retries-c9adfc81eb06a4ab.yaml b/releasenotes/notes/redis-connect-retries-c9adfc81eb06a4ab.yaml
new file mode 100644
index 00000000..bf613079
--- /dev/null
+++ b/releasenotes/notes/redis-connect-retries-c9adfc81eb06a4ab.yaml
@@ -0,0 +1,5 @@
+---
+features:
+  - |
+    The Redis driver now attempts an action up to 15 times when it encounters
+    an error connecting to Redis, instead of failing on the first error.
diff --git a/tooz/drivers/redis.py b/tooz/drivers/redis.py
index d4e68048..3105bc87 100644
--- a/tooz/drivers/redis.py
+++ b/tooz/drivers/redis.py
@@ -39,7 +39,7 @@ from tooz import utils
 LOG = logging.getLogger(__name__)
 
 
-def _handle_failures(func=None, n_tries=15):
+def _handle_failures(n_tries=15):
 
     """Translates common redis exceptions into tooz exceptions.
 
@@ -48,37 +48,37 @@ def _handle_failures(func=None, n_tries=15):
     :param func: the function to act on
     :param n_tries: the number of retries
     """
+    def inner(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            ntries = n_tries
+            while ntries:
+                try:
+                    return func(*args, **kwargs)
+                except exceptions.ConnectionError as e:
+                    # retry ntries times and then raise a connection error
+                    ntries -= 1
+                    if not ntries:
+                        LOG.debug(
+                            "Redis connection error, "
+                            "retry limit has been reached, aborting - %s", e
+                        )
+                        utils.raise_with_cause(
+                            coordination.ToozConnectionError,
+                            encodeutils.exception_to_unicode(e),
+                            cause=e)
+                    LOG.debug("Redis connection error, will retry - %s", e)
 
-    if func is None:
-        return functools.partial(
-                _handle_failures,
-                n_tries=n_tries
-                )
-
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        ntries = n_tries
-        while ntries > 1:
-            try:
-                return func(*args, **kwargs)
-            except exceptions.ConnectionError as e:
-                # retry ntries times and then raise a connection error
-                ntries -= 1
-                if ntries >= 1:
+                except (exceptions.TimeoutError) as e:
                     utils.raise_with_cause(coordination.ToozConnectionError,
                                            encodeutils.exception_to_unicode(e),
                                            cause=e)
-
-            except (exceptions.TimeoutError) as e:
-                utils.raise_with_cause(coordination.ToozConnectionError,
-                                       encodeutils.exception_to_unicode(e),
-                                       cause=e)
-            except exceptions.RedisError as e:
-                utils.raise_with_cause(tooz.ToozError,
-                                       encodeutils.exception_to_unicode(e),
-                                       cause=e)
-        return func(*args, **kwargs)
-    return wrapper
+                except exceptions.RedisError as e:
+                    utils.raise_with_cause(tooz.ToozError,
+                                           encodeutils.exception_to_unicode(e),
+                                           cause=e)
+        return wrapper
+    return inner
 
 
 class RedisLock(locking.Lock):
@@ -94,7 +94,7 @@ class RedisLock(locking.Lock):
         self._coord = coord
         self._client = client
 
-    @_handle_failures
+    @_handle_failures()
     def is_still_owner(self):
         lock_tok = self._lock.local.token
         if not lock_tok:
@@ -102,11 +102,11 @@ class RedisLock(locking.Lock):
         owner_tok = self._client.get(self.name)
         return owner_tok == lock_tok
 
-    @_handle_failures
+    @_handle_failures()
     def break_(self):
         return bool(self._client.delete(self.name))
 
-    @_handle_failures
+    @_handle_failures()
     def acquire(self, blocking=True, shared=False):
         if shared:
             raise tooz.NotImplemented
@@ -118,7 +118,7 @@ class RedisLock(locking.Lock):
                 self._coord._acquired_locks.add(self)
         return acquired
 
-    @_handle_failures
+    @_handle_failures()
     def release(self):
         with self._exclusive_access:
             try:
@@ -130,7 +130,7 @@ class RedisLock(locking.Lock):
                 self._coord._acquired_locks.discard(self)
             return True
 
-    @_handle_failures
+    @_handle_failures()
     def heartbeat(self):
         with self._exclusive_access:
             if self.acquired:
@@ -464,7 +464,7 @@ return 1
             return master_client
         return redis.StrictRedis(**kwargs)
 
-    @_handle_failures
+    @_handle_failures()
     def _start(self):
         super(RedisDriver, self)._start()
         try:
@@ -537,7 +537,7 @@ return 1
     def _decode_group_id(self, group_id):
         return utils.to_binary(group_id, encoding=self._encoding)
 
-    @_handle_failures
+    @_handle_failures()
     def heartbeat(self):
         beat_id = self._encode_beat_id(self._member_id)
         expiry_ms = max(0, int(self.membership_timeout * 1000.0))
@@ -552,7 +552,7 @@ return 1
                             exc_info=True)
         return min(self.lock_timeout, self.membership_timeout)
 
-    @_handle_failures
+    @_handle_failures()
     def _stop(self):
         while self._acquired_locks:
             lock = self._acquired_locks.pop()