diff --git a/tooz/coordination.py b/tooz/coordination.py index ee033e7c..e6d7f5f5 100644 --- a/tooz/coordination.py +++ b/tooz/coordination.py @@ -267,6 +267,16 @@ class CoordinationDriver(object): """ raise NotImplementedError + @staticmethod + def get_lock(name): + """Return a distributed lock. + + :param name: The lock name that is used to identify it across all + nodes. + + """ + raise NotImplementedError + @staticmethod def heartbeat(): """Method to run once in a while to be sure that the member is not dead diff --git a/tooz/drivers/zookeeper.py b/tooz/drivers/zookeeper.py index 675af738..1ba581c3 100644 --- a/tooz/drivers/zookeeper.py +++ b/tooz/drivers/zookeeper.py @@ -24,6 +24,19 @@ import six from zake import fake_client from tooz import coordination +from tooz import locking + + +class ZooKeeperLock(locking.Lock): + def __init__(self, lock): + self._lock = lock + + def acquire(self, blocking=True, timeout=None): + return self._lock.acquire(blocking=blocking, + timeout=timeout) + + def release(self): + return self._lock.release() class BaseZooKeeperDriver(coordination.CoordinationDriver): @@ -317,6 +330,12 @@ class KazooDriver(BaseZooKeeperDriver): leader = None return ZooAsyncResult(None, lambda *args: leader) + def get_lock(self, name): + return ZooKeeperLock( + self._coord.Lock( + self.paths_join(b"/", self._TOOZ_NAMESPACE, b"locks", name), + self._member_id.decode('ascii'))) + def run_watchers(self): ret = [] while True: diff --git a/tooz/locking.py b/tooz/locking.py new file mode 100644 index 00000000..0d246e74 --- /dev/null +++ b/tooz/locking.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2014 eNovance Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import abc + +import six + + +@six.add_metaclass(abc.ABCMeta) +class Lock(object): + def __enter__(self): + self.acquire() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.release() + + @abc.abstractmethod + def release(self): + pass + + @abc.abstractmethod + def acquire(self): + pass diff --git a/tooz/tests/test_coordination.py b/tooz/tests/test_coordination.py index 35312064..e54ebee3 100644 --- a/tooz/tests/test_coordination.py +++ b/tooz/tests/test_coordination.py @@ -457,6 +457,29 @@ class TestAPI(testscenarios.TestWithScenarios, self.assertEqual(self.group_id, self.event.group_id) + def test_get_lock(self): + lock = self._coord.get_lock(self._get_random_uuid()) + self.assertEqual(True, lock.acquire()) + lock.release() + with lock: + pass + + def test_get_lock_multiple_coords(self): + member_id2 = self._get_random_uuid() + client2 = tooz.coordination.get_coordinator(self.backend, + member_id2, + **self.kwargs) + client2.start() + + lock_name = self._get_random_uuid() + lock = self._coord.get_lock(lock_name) + self.assertEqual(True, lock.acquire()) + + lock2 = client2.get_lock(lock_name) + self.assertEqual(False, lock2.acquire(blocking=False)) + lock.release() + self.assertEqual(True, lock2.acquire(blocking=False)) + @staticmethod def _get_random_uuid(): return str(uuid.uuid4()).encode('ascii')