From 32cdf827695c7e4506a68c32bff3f738053f3306 Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Thu, 18 Jun 2015 17:28:05 -0700 Subject: [PATCH] Ensure lock(s) acquire/release returns boolean values Some of the returns were missing, so to ensure that the interface is consistent ensure that we always return true or false as needed, and add tests to ensure this keeps on being maintained. Change-Id: I877111c3a1d1d300881625c5900d4a825c5f1684 --- tooz/drivers/file.py | 3 +++ tooz/drivers/mysql.py | 3 +++ tooz/drivers/pgsql.py | 4 ++++ tooz/drivers/zookeeper.py | 10 +++++---- tooz/locking.py | 15 +++++++++++--- tooz/tests/test_coordination.py | 36 +++++++++++++++++++-------------- 6 files changed, 49 insertions(+), 22 deletions(-) diff --git a/tooz/drivers/file.py b/tooz/drivers/file.py index 15d17abd..6d3730b1 100644 --- a/tooz/drivers/file.py +++ b/tooz/drivers/file.py @@ -101,6 +101,9 @@ class FileLock(locking.Lock): self._lock.release() self.acquired = False self._cond.notify_all() + return True + else: + return False def __del__(self): if self.acquired: diff --git a/tooz/drivers/mysql.py b/tooz/drivers/mysql.py index 82c22131..4ef3ea95 100644 --- a/tooz/drivers/mysql.py +++ b/tooz/drivers/mysql.py @@ -73,11 +73,14 @@ class MySQLLock(locking.Lock): return _lock() def release(self): + if not self.acquired: + return False try: with self._conn as cur: cur.execute("SELECT RELEASE_LOCK(%s);", self.name) cur.fetchone() self.acquired = False + return True except pymysql.MySQLError as e: coordination.raise_with_cause(coordination.ToozError, utils.exception_message(e), diff --git a/tooz/drivers/pgsql.py b/tooz/drivers/pgsql.py index 02624cac..0cea947b 100644 --- a/tooz/drivers/pgsql.py +++ b/tooz/drivers/pgsql.py @@ -103,6 +103,7 @@ class PostgresLock(locking.Lock): self.key = h.digest()[0:2] def acquire(self, blocking=True): + @_retry.retry(stop_max_delay=blocking) def _lock(): # NOTE(sileht) One the same session the lock is not exclusive @@ -134,11 +135,14 @@ class PostgresLock(locking.Lock): return _lock() def release(self): + if not self.acquired: + return False with _translating_cursor(self._conn) as cur: cur.execute("SELECT pg_advisory_unlock(%s, %s);", self.key) cur.fetchone() self.acquired = False + return True def __del__(self): if self.acquired: diff --git a/tooz/drivers/zookeeper.py b/tooz/drivers/zookeeper.py index 55e60525..38e07470 100644 --- a/tooz/drivers/zookeeper.py +++ b/tooz/drivers/zookeeper.py @@ -45,12 +45,10 @@ class ZooKeeperLock(locking.Lock): if blocking: raise _retry.Retry return False - if self._lock.acquire(blocking=bool(blocking), timeout=0): self.acquired = True return True - if blocking: raise _retry.Retry return False @@ -58,8 +56,12 @@ class ZooKeeperLock(locking.Lock): return _lock() def release(self): - self._lock.release() - self.acquired = False + if self.acquired: + self._lock.release() + self.acquired = False + return True + else: + return False class BaseZooKeeperDriver(coordination.CoordinationDriver): diff --git a/tooz/locking.py b/tooz/locking.py index 98be2ff0..cccc0b95 100644 --- a/tooz/locking.py +++ b/tooz/locking.py @@ -99,6 +99,15 @@ class SharedWeakLockHelper(Lock): def release(self): with self.LOCKS_LOCK: - l = self.ACQUIRED_LOCKS.pop(self._lock_key) - self.RELEASED_LOCKS[self._lock_key] = l - l.release() + try: + l = self.ACQUIRED_LOCKS.pop(self._lock_key) + except KeyError: + return False + else: + if l.release(): + self.RELEASED_LOCKS[self._lock_key] = l + return True + else: + # Put it back... + self.ACQUIRED_LOCKS[self._lock_key] = l + return False diff --git a/tooz/tests/test_coordination.py b/tooz/tests/test_coordination.py index 7f3c3fd6..c9be3b1a 100644 --- a/tooz/tests/test_coordination.py +++ b/tooz/tests/test_coordination.py @@ -588,8 +588,8 @@ class TestAPI(testscenarios.TestWithScenarios, def test_get_lock(self): lock = self._coord.get_lock(self._get_random_uuid()) - self.assertEqual(True, lock.acquire()) - lock.release() + self.assertTrue(lock.acquire()) + self.assertTrue(lock.release()) with lock: pass @@ -600,7 +600,7 @@ class TestAPI(testscenarios.TestWithScenarios, def thread(): self.assertTrue(lock.acquire()) - lock.release() + self.assertTrue(lock.release()) graceful_ending.set() t = threading.Thread(target=thread) @@ -668,23 +668,29 @@ class TestAPI(testscenarios.TestWithScenarios, lock1 = self._coord.get_lock(name) lock2 = self._coord.get_lock(name) with lock1: - self.assertEqual(False, lock2.acquire(blocking=False)) + self.assertFalse(lock2.acquire(blocking=False)) def test_get_lock_locked_twice(self): name = self._get_random_uuid() lock = self._coord.get_lock(name) with lock: - self.assertEqual(False, lock.acquire(blocking=False)) + self.assertFalse(lock.acquire(blocking=False)) def test_get_multiple_locks_with_same_coord(self): name = self._get_random_uuid() lock1 = self._coord.get_lock(name) lock2 = self._coord.get_lock(name) - self.assertEqual(True, lock1.acquire()) - self.assertEqual(False, lock2.acquire(blocking=False)) - self.assertEqual(False, - self._coord.get_lock(name).acquire(blocking=False)) - lock1.release() + self.assertTrue(lock1.acquire()) + self.assertFalse(lock2.acquire(blocking=False)) + self.assertFalse(self._coord.get_lock(name).acquire(blocking=False)) + self.assertTrue(lock1.release()) + + def test_ensure_acquire_release_return(self): + name = self._get_random_uuid() + lock1 = self._coord.get_lock(name) + self.assertTrue(lock1.acquire()) + self.assertTrue(lock1.release()) + self.assertFalse(lock1.release()) def test_get_lock_multiple_coords(self): member_id2 = self._get_random_uuid() @@ -694,13 +700,13 @@ class TestAPI(testscenarios.TestWithScenarios, lock_name = self._get_random_uuid() lock = self._coord.get_lock(lock_name) - self.assertEqual(True, lock.acquire()) + self.assertTrue(lock.acquire()) lock2 = client2.get_lock(lock_name) - self.assertEqual(False, lock2.acquire(blocking=False)) - lock.release() - self.assertEqual(True, lock2.acquire(blocking=True)) - lock2.release() + self.assertFalse(lock2.acquire(blocking=False)) + self.assertTrue(lock.release()) + self.assertTrue(lock2.acquire(blocking=True)) + self.assertTrue(lock2.release()) @staticmethod def _get_random_uuid():