Merge "Two locks acquired from one coord must works"
This commit is contained in:
commit
3f64bf5d13
@ -15,7 +15,6 @@
|
||||
# under the License.
|
||||
import errno
|
||||
import os
|
||||
import weakref
|
||||
|
||||
import tooz
|
||||
from tooz import coordination
|
||||
@ -94,8 +93,6 @@ else:
|
||||
class FileDriver(coordination.CoordinationDriver):
|
||||
"""A file based driver."""
|
||||
|
||||
LOCKS = weakref.WeakValueDictionary()
|
||||
|
||||
def __init__(self, member_id, parsed_url, options):
|
||||
"""Initialize the file driver."""
|
||||
super(FileDriver, self).__init__()
|
||||
@ -103,8 +100,7 @@ class FileDriver(coordination.CoordinationDriver):
|
||||
|
||||
def get_lock(self, name):
|
||||
path = os.path.abspath(os.path.join(self._lockdir, name.decode()))
|
||||
lock = LockClass(path)
|
||||
return self.LOCKS.setdefault(path, lock)
|
||||
return locking.SharedWeakLockHelper(self._lockdir, LockClass, path)
|
||||
|
||||
@staticmethod
|
||||
def watch_join_group(group_id, callback):
|
||||
|
@ -25,41 +25,38 @@ from tooz import utils
|
||||
class MySQLLock(locking.Lock):
|
||||
"""A MySQL based lock."""
|
||||
|
||||
def __init__(self, name, connection):
|
||||
def __init__(self, name, parsed_url, options):
|
||||
super(MySQLLock, self).__init__(name)
|
||||
self._conn = connection
|
||||
self._conn = MySQLDriver.get_connection(parsed_url, options)
|
||||
|
||||
def acquire(self, blocking=True):
|
||||
if blocking is False:
|
||||
def _acquire(retry=False):
|
||||
try:
|
||||
cur = self._conn.cursor()
|
||||
cur.execute("SELECT GET_LOCK(%s, 0);", self.name)
|
||||
# Can return NULL on error
|
||||
if cur.fetchone()[0] is 1:
|
||||
return True
|
||||
return False
|
||||
except pymysql.MySQLError as e:
|
||||
raise coordination.ToozError(utils.exception_message(e))
|
||||
else:
|
||||
def _acquire():
|
||||
try:
|
||||
cur = self._conn.cursor()
|
||||
with self._conn as cur:
|
||||
cur.execute("SELECT GET_LOCK(%s, 0);", self.name)
|
||||
# Can return NULL on error
|
||||
if cur.fetchone()[0] is 1:
|
||||
return True
|
||||
except pymysql.MySQLError as e:
|
||||
raise coordination.ToozError(utils.exception_message(e))
|
||||
except pymysql.MySQLError as e:
|
||||
raise coordination.ToozError(utils.exception_message(e))
|
||||
if retry:
|
||||
raise _retry.Retry
|
||||
else:
|
||||
return False
|
||||
|
||||
if blocking is False:
|
||||
return _acquire()
|
||||
else:
|
||||
kwargs = _retry.RETRYING_KWARGS.copy()
|
||||
if blocking is not True:
|
||||
kwargs['stop_max_delay'] = blocking
|
||||
return _retry.Retrying(**kwargs).call(_acquire)
|
||||
return _retry.Retrying(**kwargs).call(_acquire, retry=True)
|
||||
|
||||
def release(self):
|
||||
try:
|
||||
cur = self._conn.cursor()
|
||||
cur.execute("SELECT RELEASE_LOCK(%s);", self.name)
|
||||
return cur.fetchone()[0]
|
||||
with self._conn as cur:
|
||||
cur.execute("SELECT RELEASE_LOCK(%s);", self.name)
|
||||
return cur.fetchone()[0]
|
||||
except pymysql.MySQLError as e:
|
||||
raise coordination.ToozError(utils.exception_message(e))
|
||||
|
||||
@ -70,35 +67,20 @@ class MySQLDriver(coordination.CoordinationDriver):
|
||||
def __init__(self, member_id, parsed_url, options):
|
||||
"""Initialize the MySQL driver."""
|
||||
super(MySQLDriver, self).__init__()
|
||||
self._host = parsed_url.netloc
|
||||
self._port = parsed_url.port
|
||||
self._dbname = parsed_url.path[1:]
|
||||
self._username = parsed_url.username
|
||||
self._password = parsed_url.password
|
||||
self._unix_socket = options.get("unix_socket", [None])[-1]
|
||||
self._parsed_url = parsed_url
|
||||
self._options = options
|
||||
|
||||
def _start(self):
|
||||
try:
|
||||
if self._unix_socket:
|
||||
self._conn = pymysql.Connect(unix_socket=self._unix_socket,
|
||||
port=self._port,
|
||||
user=self._username,
|
||||
passwd=self._password,
|
||||
database=self._dbname)
|
||||
else:
|
||||
self._conn = pymysql.Connect(host=self._host,
|
||||
port=self._port,
|
||||
user=self._username,
|
||||
passwd=self._password,
|
||||
database=self._dbname)
|
||||
except pymysql.err.OperationalError as e:
|
||||
raise coordination.ToozConnectionError(utils.exception_message(e))
|
||||
self._conn = MySQLDriver.get_connection(self._parsed_url,
|
||||
self._options)
|
||||
|
||||
def _stop(self):
|
||||
self._conn.close()
|
||||
|
||||
def get_lock(self, name):
|
||||
return MySQLLock(name, self._conn)
|
||||
return locking.WeakLockHelper(
|
||||
self._parsed_url.geturl(),
|
||||
MySQLLock, name, self._parsed_url, self._options)
|
||||
|
||||
@staticmethod
|
||||
def watch_join_group(group_id, callback):
|
||||
@ -123,3 +105,28 @@ class MySQLDriver(coordination.CoordinationDriver):
|
||||
@staticmethod
|
||||
def unwatch_elected_as_leader(group_id, callback):
|
||||
raise tooz.NotImplemented
|
||||
|
||||
@staticmethod
|
||||
def get_connection(parsed_url, options):
|
||||
host = parsed_url.netloc
|
||||
port = parsed_url.port
|
||||
dbname = parsed_url.path[1:]
|
||||
username = parsed_url.username
|
||||
password = parsed_url.password
|
||||
unix_socket = options.get("unix_socket", [None])[-1]
|
||||
|
||||
try:
|
||||
if unix_socket:
|
||||
return pymysql.Connect(unix_socket=unix_socket,
|
||||
port=port,
|
||||
user=username,
|
||||
passwd=password,
|
||||
database=dbname)
|
||||
else:
|
||||
return pymysql.Connect(host=host,
|
||||
port=port,
|
||||
user=username,
|
||||
passwd=password,
|
||||
database=dbname)
|
||||
except pymysql.err.OperationalError as e:
|
||||
raise coordination.ToozConnectionError(utils.exception_message(e))
|
||||
|
@ -87,9 +87,9 @@ def _translating_cursor(conn):
|
||||
class PostgresLock(locking.Lock):
|
||||
"""A PostgreSQL based lock."""
|
||||
|
||||
def __init__(self, name, connection):
|
||||
def __init__(self, name, parsed_url, options):
|
||||
super(PostgresLock, self).__init__(name)
|
||||
self._conn = connection
|
||||
self._conn = PostgresDriver.get_connection(parsed_url, options)
|
||||
h = hashlib.md5()
|
||||
h.update(name)
|
||||
if six.PY2:
|
||||
@ -130,27 +130,20 @@ class PostgresDriver(coordination.CoordinationDriver):
|
||||
def __init__(self, member_id, parsed_url, options):
|
||||
"""Initialize the PostgreSQL driver."""
|
||||
super(PostgresDriver, self).__init__()
|
||||
self._host = options.get("host", [None])[-1]
|
||||
self._port = parsed_url.port or options.get("port", [None])[-1]
|
||||
self._dbname = parsed_url.path[1:] or options.get("dbname", [None])[-1]
|
||||
self._username = parsed_url.username
|
||||
self._password = parsed_url.password
|
||||
self._parsed_url = parsed_url
|
||||
self._options = options
|
||||
|
||||
def _start(self):
|
||||
try:
|
||||
self._conn = psycopg2.connect(host=self._host,
|
||||
port=self._port,
|
||||
user=self._username,
|
||||
password=self._password,
|
||||
database=self._dbname)
|
||||
except psycopg2.Error as e:
|
||||
raise coordination.ToozConnectionError(_format_exception(e))
|
||||
self._conn = PostgresDriver.get_connection(self._parsed_url,
|
||||
self._options)
|
||||
|
||||
def _stop(self):
|
||||
self._conn.close()
|
||||
|
||||
def get_lock(self, name):
|
||||
return PostgresLock(name, self._conn)
|
||||
return locking.WeakLockHelper(
|
||||
self._parsed_url.geturl(),
|
||||
PostgresLock, name, self._parsed_url, self._options)
|
||||
|
||||
@staticmethod
|
||||
def watch_join_group(group_id, callback):
|
||||
@ -175,3 +168,20 @@ class PostgresDriver(coordination.CoordinationDriver):
|
||||
@staticmethod
|
||||
def unwatch_elected_as_leader(group_id, callback):
|
||||
raise tooz.NotImplemented
|
||||
|
||||
@staticmethod
|
||||
def get_connection(parsed_url, options):
|
||||
host = options.get("host", [None])[-1]
|
||||
port = parsed_url.port or options.get("port", [None])[-1]
|
||||
dbname = parsed_url.path[1:] or options.get("dbname", [None])[-1]
|
||||
username = parsed_url.username
|
||||
password = parsed_url.password
|
||||
|
||||
try:
|
||||
return psycopg2.connect(host=host,
|
||||
port=port,
|
||||
user=username,
|
||||
password=password,
|
||||
database=dbname)
|
||||
except psycopg2.Error as e:
|
||||
raise coordination.ToozConnectionError(_format_exception(e))
|
||||
|
@ -16,6 +16,8 @@
|
||||
import abc
|
||||
|
||||
import six
|
||||
import threading
|
||||
import weakref
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
@ -55,3 +57,67 @@ class Lock(object):
|
||||
:rtype: bool
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class SharedWeakLockHelper(Lock):
|
||||
"""Helper for lock that need to rely on a state in memory and
|
||||
be the same object across each coordinator.get_lock(...)
|
||||
"""
|
||||
|
||||
LOCKS_LOCK = threading.Lock()
|
||||
ACQUIRED_LOCKS = dict()
|
||||
RELEASED_LOCKS = weakref.WeakValueDictionary()
|
||||
|
||||
def __init__(self, namespace, lockclass, name, *args, **kwargs):
|
||||
super(SharedWeakLockHelper, self).__init__(name)
|
||||
self.acquired = False
|
||||
self._lock_key = "%s:%s" % (namespace, name)
|
||||
self._newlock = lambda: lockclass(
|
||||
self.name, *args, **kwargs)
|
||||
|
||||
def acquire(self, blocking=True):
|
||||
with self.LOCKS_LOCK:
|
||||
try:
|
||||
l = self.ACQUIRED_LOCKS[self._lock_key]
|
||||
except KeyError:
|
||||
l = self.RELEASED_LOCKS.setdefault(
|
||||
self._lock_key, self._newlock())
|
||||
|
||||
if l.acquire(blocking):
|
||||
with self.LOCKS_LOCK:
|
||||
self.RELEASED_LOCKS.pop(self._lock_key, None)
|
||||
self.ACQUIRED_LOCKS[self._lock_key] = l
|
||||
return True
|
||||
return False
|
||||
|
||||
def release(self):
|
||||
with self.LOCKS_LOCK:
|
||||
l = self.ACQUIRED_LOCKS.pop(self._lock_key)
|
||||
l.release()
|
||||
self.RELEASED_LOCKS[self._lock_key] = l
|
||||
|
||||
|
||||
class WeakLockHelper(Lock):
|
||||
"""Helper for lock that need to rely on a state in memory and
|
||||
be a diffrent object across each coordinator.get_lock(...)
|
||||
"""
|
||||
|
||||
LOCKS_LOCK = threading.Lock()
|
||||
ACQUIRED_LOCKS = dict()
|
||||
|
||||
def __init__(self, namespace, lockclass, name, *args, **kwargs):
|
||||
super(WeakLockHelper, self).__init__(name)
|
||||
self._lock_key = "%s:%s" % (namespace, name)
|
||||
self._lock = lockclass(self.name, *args, **kwargs)
|
||||
|
||||
def acquire(self, blocking=True):
|
||||
if self._lock.acquire(blocking):
|
||||
with self.LOCKS_LOCK:
|
||||
self.ACQUIRED_LOCKS[self._lock_key] = self._lock
|
||||
return True
|
||||
return False
|
||||
|
||||
def release(self):
|
||||
with self.LOCKS_LOCK:
|
||||
self._lock.release()
|
||||
self.ACQUIRED_LOCKS.pop(self._lock_key)
|
||||
|
@ -577,6 +577,31 @@ class TestAPI(testscenarios.TestWithScenarios,
|
||||
with lock:
|
||||
pass
|
||||
|
||||
def test_get_multiple_locks_with_same_coord(self):
|
||||
name = self._get_random_uuid()
|
||||
lock1 = self._coord.get_lock(name)
|
||||
lock2 = self._coord.get_lock(name)
|
||||
self.assertEqual(True, lock1.acquire())
|
||||
self.assertEqual(False, lock2.acquire(blocking=False))
|
||||
self.assertEqual(False,
|
||||
self._coord.get_lock(name).acquire(blocking=False))
|
||||
lock1.release()
|
||||
|
||||
def test_get_multiple_locks_with_same_coord_without_ref(self):
|
||||
# NOTE(sileht): weird test case who want a lock that can't be
|
||||
# released ? This tests is here to ensure that the
|
||||
# acquired first lock in not vanished by the gc and get accidentally
|
||||
# released.
|
||||
# This test ensures that the consumer application will stuck when it
|
||||
# looses the ref of a acquired lock instead of create a race.
|
||||
# Also, by its nature this tests don't cleanup the created
|
||||
# semaphore by the ipc:// driver, don't close opened files and
|
||||
# sql connections and that the desired behavior.
|
||||
name = self._get_random_uuid()
|
||||
self.assertEqual(True, self._coord.get_lock(name).acquire())
|
||||
self.assertEqual(False,
|
||||
self._coord.get_lock(name).acquire(blocking=False))
|
||||
|
||||
def test_get_lock_multiple_coords(self):
|
||||
member_id2 = self._get_random_uuid()
|
||||
client2 = tooz.coordination.get_coordinator(self.url,
|
||||
|
Loading…
x
Reference in New Issue
Block a user