Make sure Lock.heartbeat() returns True/False

- add a test to be sure that no lock refresh is made on non-acquired locks
- add thread safety for etcd3 drivers on heartbeat() (like Redis)

Change-Id: I350ea10163d59a06712c22c4c5be4dfcb6885bf8
This commit is contained in:
Julien Danjou 2017-06-27 17:39:22 +02:00
parent c94b2a39b3
commit 7987f4455a
6 changed files with 49 additions and 19 deletions

View File

@ -188,12 +188,14 @@ class EtcdLock(locking.Lock):
poked = self.client.put(self._lock_url,
data={"ttl": self.ttl,
"prevExist": "true"}, make_url=False)
errorcode = poked.get("errorCode")
if errorcode:
LOG.warning("Unable to heartbeat by updating key '%s' with "
"extended expiry of %s seconds: %d, %s", self.name,
self.ttl, errorcode, poked.get("message"))
self._node = poked['node']
errorcode = poked.get("errorCode")
if not errorcode:
return True
LOG.warning("Unable to heartbeat by updating key '%s' with "
"extended expiry of %s seconds: %d, %s", self.name,
self.ttl, errorcode, poked.get("message"))
return False
class EtcdDriver(coordination.CoordinationDriver):

View File

@ -13,6 +13,7 @@
# under the License.
from __future__ import absolute_import
import threading
import etcd3
from etcd3 import exceptions as etcd3_exc
@ -61,6 +62,7 @@ class Etcd3Lock(locking.Lock):
super(Etcd3Lock, self).__init__(name)
self._coord = coord
self._lock = coord.client.lock(name.decode(), timeout)
self._exclusive_access = threading.Lock()
@_translate_failures
def acquire(self, blocking=True, shared=False):
@ -83,14 +85,19 @@ class Etcd3Lock(locking.Lock):
@_translate_failures
def release(self):
if self.acquired and self._lock.release():
self._coord._acquired_locks.discard(self)
return True
with self._exclusive_access:
if self.acquired and self._lock.release():
self._coord._acquired_locks.discard(self)
return True
return False
@_translate_failures
def heartbeat(self):
self._lock.refresh()
with self._exclusive_access:
if self.acquired:
self._lock.refresh()
return True
return False
class Etcd3Driver(coordination.CoordinationDriver):

View File

@ -14,6 +14,7 @@
from __future__ import absolute_import
import base64
import threading
import uuid
import etcd3gw
@ -68,6 +69,7 @@ class Etcd3Lock(locking.Lock):
self._key_b64 = base64.b64encode(self._key).decode("ascii")
self._uuid = base64.b64encode(uuid.uuid4().bytes).decode("ascii")
self._lease = self._coord.client.lease(self._timeout)
self._exclusive_access = threading.Lock()
@_translate_failures
def acquire(self, blocking=True, shared=False):
@ -126,11 +128,12 @@ class Etcd3Lock(locking.Lock):
}]
}
result = self._coord.client.transaction(txn)
success = result.get('succeeded', False)
if success:
self._coord._acquired_locks.remove(self)
return True
with self._exclusive_access:
result = self._coord.client.transaction(txn)
success = result.get('succeeded', False)
if success:
self._coord._acquired_locks.remove(self)
return True
return False
@_translate_failures
@ -140,9 +143,17 @@ class Etcd3Lock(locking.Lock):
return True
return False
@property
def acquired(self):
return self in self._coord._acquired_locks
@_translate_failures
def heartbeat(self):
self._lease.refresh()
with self._exclusive_access:
if self.acquired:
self._lease.refresh()
return True
return False
class Etcd3Driver(coordination.CoordinationDriver):

View File

@ -165,10 +165,12 @@ class MemcachedLock(locking.Lock):
poked = self.coord.client.touch(self.name,
expire=self.timeout,
noreply=False)
if not poked:
LOG.warning("Unable to heartbeat by updating key '%s' with "
"extended expiry of %s seconds", self.name,
self.timeout)
if poked:
return True
LOG.warning("Unable to heartbeat by updating key '%s' with "
"extended expiry of %s seconds", self.name,
self.timeout)
return False
@_translate_failures
def get_owner(self):

View File

@ -109,6 +109,8 @@ class RedisLock(locking.Lock):
if self.acquired:
with _translate_failures():
self._lock.extend(self._lock.timeout)
return True
return False
@property
def acquired(self):

View File

@ -712,6 +712,12 @@ class TestAPI(tests.TestWithCoordinator):
with lock:
pass
def test_heartbeat_lock_not_acquired(self):
lock = self._coord.get_lock(tests.get_random_uuid())
# Not all locks need heartbeat
if hasattr(lock, "heartbeat"):
self.assertFalse(lock.heartbeat())
def test_get_shared_lock(self):
lock = self._coord.get_lock(tests.get_random_uuid())
self.assertTrue(lock.acquire(shared=True))