diff --git a/tooz/coordination.py b/tooz/coordination.py index 45780a58..ab7bad68 100644 --- a/tooz/coordination.py +++ b/tooz/coordination.py @@ -436,6 +436,10 @@ class OperationTimedOut(ToozError): """Exception raised when an operation times out.""" +class LockAcquireFailed(ToozError): + """Exception raised when a lock acquire fails in a context manager.""" + + class GroupNotCreated(ToozError): """Exception raised when the caller request a group which does not exist. diff --git a/tooz/locking.py b/tooz/locking.py index 381fecb4..3f8b2331 100644 --- a/tooz/locking.py +++ b/tooz/locking.py @@ -19,6 +19,8 @@ import six import threading import weakref +from tooz import coordination + @six.add_metaclass(abc.ABCMeta) class Lock(object): @@ -32,7 +34,12 @@ class Lock(object): return self._name def __enter__(self): - self.acquire() + acquired = self.acquire() + if not acquired: + msg = u'Acquiring lock %s failed' % self.name + raise coordination.LockAcquireFailed(msg) + + return self def __exit__(self, exc_type, exc_val, exc_tb): self.release() diff --git a/tooz/tests/test_coordination.py b/tooz/tests/test_coordination.py index 12007a18..26745c61 100644 --- a/tooz/tests/test_coordination.py +++ b/tooz/tests/test_coordination.py @@ -21,6 +21,7 @@ import uuid from concurrent import futures import fixtures +import mock import testscenarios from testtools import matchers from testtools import testcase @@ -682,6 +683,22 @@ class TestAPI(testscenarios.TestWithScenarios, with lock1: self.assertFalse(lock2.acquire(blocking=False)) + def test_get_lock_context_fails(self): + name = self._get_random_uuid() + lock1 = self._coord.get_lock(name) + lock2 = self._coord.get_lock(name) + with mock.patch.object(lock2, 'acquire', return_value=False): + with lock1: + self.assertRaises( + tooz.coordination.LockAcquireFailed, + lock2.__enter__) + + def test_get_lock_context_check_value(self): + name = self._get_random_uuid() + lock = self._coord.get_lock(name) + with lock as returned_lock: + self.assertEqual(lock, returned_lock) + def test_get_lock_locked_twice(self): name = self._get_random_uuid() lock = self._coord.get_lock(name)