From 90663364f82966489856ffa0fbe03a1824bebb55 Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Mon, 1 Dec 2014 16:21:01 -0800 Subject: [PATCH] Add a reader/writer lock Taskflow has a reader/writer lock that is likely useful to other projects; and it seems better at home in this module. The class provides a way to create reader/writer locks where there may be many readers at the same time (but only one writer). It does not allow (currently) for privilege escalation (but this could be added with limited support in the future). Change-Id: Ie763ef92f31c34869d83a533bc8761b0fbd77217 --- oslo_concurrency/lockutils.py | 172 ++++++++++ oslo_concurrency/tests/unit/test_lockutils.py | 324 ++++++++++++++++++ test-requirements.txt | 1 + 3 files changed, 497 insertions(+) diff --git a/oslo_concurrency/lockutils.py b/oslo_concurrency/lockutils.py index e9fd4d7..1c389b1 100644 --- a/oslo_concurrency/lockutils.py +++ b/oslo_concurrency/lockutils.py @@ -13,6 +13,7 @@ # License for the specific language governing permissions and limitations # under the License. +import collections import contextlib import errno import functools @@ -492,6 +493,177 @@ def _lock_wrapper(argv): return ret_val +class ReaderWriterLock(object): + """A reader/writer lock. + + This lock allows for simultaneous readers to exist but only one writer + to exist for use-cases where it is useful to have such types of locks. + + Currently a reader can not escalate its read lock to a write lock and + a writer can not acquire a read lock while it owns or is waiting on + the write lock. + + In the future these restrictions may be relaxed. + + This can be eventually removed if http://bugs.python.org/issue8800 ever + gets accepted into the python standard threading library... + """ + WRITER = b'w' + READER = b'r' + + @staticmethod + def _fetch_current_thread_functor(): + # Until https://github.com/eventlet/eventlet/issues/172 is resolved + # or addressed we have to use complicated workaround to get a object + # that will not be recycled; the usage of threading.current_thread() + # doesn't appear to currently be monkey patched and therefore isn't + # reliable to use (and breaks badly when used as all threads share + # the same current_thread() object)... + try: + import eventlet + from eventlet import patcher + green_threaded = patcher.is_monkey_patched('thread') + except ImportError: + green_threaded = False + if green_threaded: + return lambda: eventlet.getcurrent() + else: + return lambda: threading.current_thread() + + def __init__(self): + self._writer = None + self._pending_writers = collections.deque() + self._readers = collections.defaultdict(int) + self._cond = threading.Condition() + self._current_thread = self._fetch_current_thread_functor() + + def _has_pending_writers(self): + """Returns if there are writers waiting to become the *one* writer. + + Internal usage only. + + :return: whether there are any pending writers + :rtype: boolean + """ + return bool(self._pending_writers) + + def _is_writer(self, check_pending=True): + """Returns if the caller is the active writer or a pending writer. + + Internal usage only. + + :param check_pending: checks the pending writes as well, if false then + only the current writer is checked (and not those + writers that may be in line). + + :return: whether the current thread is a active/pending writer + :rtype: boolean + """ + me = self._current_thread() + with self._cond: + if self._writer is not None and self._writer == me: + return True + if check_pending: + return me in self._pending_writers + else: + return False + + @property + def owner_type(self): + """Returns whether the lock is locked by a writer/reader/nobody. + + :return: constant defining what the active owners type is + :rtype: WRITER/READER/None + """ + with self._cond: + if self._writer is not None: + return self.WRITER + if self._readers: + return self.READER + return None + + def _is_reader(self): + """Returns if the caller is one of the readers. + + Internal usage only. + + :return: whether the current thread is a active/pending reader + :rtype: boolean + """ + me = self._current_thread() + with self._cond: + return me in self._readers + + @contextlib.contextmanager + def read_lock(self): + """Context manager that grants a read lock. + + Will wait until no active or pending writers. + + Raises a ``RuntimeError`` if an active or pending writer tries to + acquire a read lock as this is disallowed. + """ + me = self._current_thread() + if self._is_writer(): + raise RuntimeError("Writer %s can not acquire a read lock" + " while holding/waiting for the write lock" + % me) + with self._cond: + while self._writer is not None: + # An active writer; guess we have to wait. + self._cond.wait() + # No active writer; we are good to become a reader. + self._readers[me] += 1 + try: + yield self + finally: + # I am no longer a reader, remove *one* occurrence of myself. + # If the current thread acquired two read locks, then it will + # still have to remove that other read lock; this allows for + # basic reentrancy to be possible. + with self._cond: + claims = self._readers[me] + if claims == 1: + self._readers.pop(me) + else: + self._readers[me] = claims - 1 + if not self._readers: + self._cond.notify_all() + + @contextlib.contextmanager + def write_lock(self): + """Context manager that grants a write lock. + + Will wait until no active readers. Blocks readers after acquiring. + + Raises a ``RuntimeError`` if an active reader attempts to acquire a + writer lock as this is disallowed. + """ + me = self._current_thread() + if self._is_reader(): + raise RuntimeError("Reader %s to writer privilege" + " escalation not allowed" % me) + if self._is_writer(check_pending=False): + # Already the writer; this allows for basic reentrancy. + yield self + else: + with self._cond: + # Add ourself to the pending writes and wait until we are + # the one writer that can run (aka, when we are the first + # element in the pending writers). + self._pending_writers.append(me) + while (self._readers or self._writer is not None + or self._pending_writers[0] != me): + self._cond.wait() + self._writer = self._pending_writers.popleft() + try: + yield self + finally: + with self._cond: + self._writer = None + self._cond.notify_all() + + def main(): sys.exit(_lock_wrapper(sys.argv)) diff --git a/oslo_concurrency/tests/unit/test_lockutils.py b/oslo_concurrency/tests/unit/test_lockutils.py index f665ebd..f24cce3 100644 --- a/oslo_concurrency/tests/unit/test_lockutils.py +++ b/oslo_concurrency/tests/unit/test_lockutils.py @@ -29,6 +29,8 @@ from oslo.config import cfg from oslotest import base as test_base import six +from concurrent import futures + from oslo.config import fixture as config from oslo_concurrency.fixture import lockutils as fixtures from oslo_concurrency import lockutils @@ -513,6 +515,328 @@ class FileBasedLockingTestCase(test_base.BaseTestCase): self.assertEqual(f.read(), 'test') +class ReadWriteLockTest(test_base.BaseTestCase): + # This test works by sending up a bunch of threads and then running + # them all at once and having different threads either a read lock + # or a write lock; and sleeping for a period of time while using it. + # + # After the tests have completed the timings of each thread are checked + # to ensure that there are no *invalid* overlaps (a writer should never + # overlap with any readers, for example). + + # We will spend this amount of time doing some "fake" work. + WORK_TIMES = [(0.01 + x / 100.0) for x in range(0, 5)] + + # NOTE(harlowja): Sleep a little so time.time() can not be the same (which + # will cause false positives when our overlap detection code runs). If + # there are real overlaps then they will still exist. + NAPPY_TIME = 0.05 + + @staticmethod + def _find_overlaps(times, start, end): + """Counts num of overlaps between start and end in the given times.""" + overlaps = 0 + for (s, e) in times: + if s >= start and e <= end: + overlaps += 1 + return overlaps + + @classmethod + def _spawn_variation(cls, readers, writers, max_workers=None): + """Spawns the given number of readers and writers.""" + + start_stops = collections.deque() + lock = lockutils.ReaderWriterLock() + + def read_func(ident): + with lock.read_lock(): + # TODO(harlowja): sometime in the future use a monotonic clock + # here to avoid problems that can be caused by ntpd resyncing + # the clock while we are actively running. + enter_time = time.time() + time.sleep(cls.WORK_TIMES[ident % len(cls.WORK_TIMES)]) + exit_time = time.time() + start_stops.append((lock.READER, enter_time, exit_time)) + time.sleep(cls.NAPPY_TIME) + + def write_func(ident): + with lock.write_lock(): + enter_time = time.time() + time.sleep(cls.WORK_TIMES[ident % len(cls.WORK_TIMES)]) + exit_time = time.time() + start_stops.append((lock.WRITER, enter_time, exit_time)) + time.sleep(cls.NAPPY_TIME) + + if max_workers is None: + max_workers = max(0, readers) + max(0, writers) + if max_workers > 0: + with futures.ThreadPoolExecutor(max_workers=max_workers) as e: + count = 0 + for _i in range(0, readers): + e.submit(read_func, count) + count += 1 + for _i in range(0, writers): + e.submit(write_func, count) + count += 1 + + writer_times = [] + reader_times = [] + for (lock_type, start, stop) in list(start_stops): + if lock_type == lock.WRITER: + writer_times.append((start, stop)) + else: + reader_times.append((start, stop)) + return (writer_times, reader_times) + + def test_writer_abort(self): + # Ensures that the lock is released when the writer has an + # exception... + lock = lockutils.ReaderWriterLock() + self.assertFalse(lock.owner_type) + + def blow_up(): + with lock.write_lock(): + self.assertEqual(lock.WRITER, lock.owner_type) + raise RuntimeError("Broken") + + self.assertRaises(RuntimeError, blow_up) + self.assertFalse(lock.owner_type) + + def test_reader_abort(self): + lock = lockutils.ReaderWriterLock() + self.assertFalse(lock.owner_type) + + def blow_up(): + with lock.read_lock(): + self.assertEqual(lock.READER, lock.owner_type) + raise RuntimeError("Broken") + + self.assertRaises(RuntimeError, blow_up) + self.assertFalse(lock.owner_type) + + def test_double_reader_abort(self): + lock = lockutils.ReaderWriterLock() + activated = collections.deque() + + def double_bad_reader(): + with lock.read_lock(): + with lock.read_lock(): + raise RuntimeError("Broken") + + def happy_writer(): + with lock.write_lock(): + activated.append(lock.owner_type) + + # Submit a bunch of work to a pool, and then ensure that the correct + # number of writers eventually executed (every other thread will + # be a reader thread that will fail)... + max_workers = 8 + with futures.ThreadPoolExecutor(max_workers=max_workers) as e: + for i in range(0, max_workers): + if i % 2 == 0: + e.submit(double_bad_reader) + else: + e.submit(happy_writer) + + self.assertEqual(max_workers / 2, + len([a for a in activated + if a == lockutils.ReaderWriterLock.WRITER])) + + def test_double_reader_writer(self): + lock = lockutils.ReaderWriterLock() + activated = collections.deque() + active = threading.Event() + + def double_reader(): + with lock.read_lock(): + active.set() + # Wait for the writer thread to get into pending mode using a + # simple spin-loop... + while not lock._has_pending_writers(): + time.sleep(0.001) + with lock.read_lock(): + activated.append(lock.owner_type) + + def happy_writer(): + with lock.write_lock(): + activated.append(lock.owner_type) + + reader = threading.Thread(target=double_reader) + reader.daemon = True + reader.start() + + # Wait for the reader to become the active reader. + active.wait() + self.assertTrue(active.is_set()) + + # Start up the writer (the reader will wait until its going). + writer = threading.Thread(target=happy_writer) + writer.daemon = True + writer.start() + + # Ensure it went in the order we expected. + reader.join() + writer.join() + self.assertEqual(2, len(activated)) + self.assertEqual([lockutils.ReaderWriterLock.READER, + lockutils.ReaderWriterLock.WRITER], list(activated)) + + def test_reader_chaotic(self): + lock = lockutils.ReaderWriterLock() + activated = collections.deque() + + def chaotic_reader(blow_up): + with lock.read_lock(): + if blow_up: + raise RuntimeError("Broken") + else: + activated.append(lock.owner_type) + + def happy_writer(): + with lock.write_lock(): + activated.append(lock.owner_type) + + # Test that every 4th reader blows up and that we get the expected + # number of owners with this occuring. + max_workers = 8 + with futures.ThreadPoolExecutor(max_workers=max_workers) as e: + for i in range(0, max_workers): + if i % 2 == 0: + e.submit(chaotic_reader, blow_up=bool(i % 4 == 0)) + else: + e.submit(happy_writer) + + writers = [a for a in activated + if a == lockutils.ReaderWriterLock.WRITER] + readers = [a for a in activated + if a == lockutils.ReaderWriterLock.READER] + self.assertEqual(4, len(writers)) + self.assertEqual(2, len(readers)) + + def test_writer_chaotic(self): + lock = lockutils.ReaderWriterLock() + activated = collections.deque() + + def chaotic_writer(blow_up): + with lock.write_lock(): + if blow_up: + raise RuntimeError("Broken") + else: + activated.append(lock.owner_type) + + def happy_reader(): + with lock.read_lock(): + activated.append(lock.owner_type) + + # Test that every 4th reader blows up and that we get the expected + # number of owners with this occuring. + max_workers = 8 + with futures.ThreadPoolExecutor(max_workers=max_workers) as e: + for i in range(0, max_workers): + if i % 2 == 0: + e.submit(chaotic_writer, blow_up=bool(i % 4 == 0)) + else: + e.submit(happy_reader) + + writers = [a for a in activated + if a == lockutils.ReaderWriterLock.WRITER] + readers = [a for a in activated + if a == lockutils.ReaderWriterLock.READER] + self.assertEqual(2, len(writers)) + self.assertEqual(4, len(readers)) + + def test_single_reader_writer(self): + results = [] + lock = lockutils.ReaderWriterLock() + with lock.read_lock(): + self.assertTrue(lock._is_reader()) + self.assertEqual(0, len(results)) + with lock.write_lock(): + results.append(1) + self.assertTrue(lock._is_writer()) + with lock.read_lock(): + self.assertTrue(lock._is_reader()) + self.assertEqual(1, len(results)) + self.assertFalse(lock._is_reader()) + self.assertFalse(lock._is_writer()) + + def test_reader_to_writer(self): + lock = lockutils.ReaderWriterLock() + + def writer_func(): + with lock.write_lock(): + pass + + with lock.read_lock(): + self.assertRaises(RuntimeError, writer_func) + self.assertFalse(lock._is_writer()) + + self.assertFalse(lock._is_reader()) + self.assertFalse(lock._is_writer()) + + def test_writer_to_reader(self): + lock = lockutils.ReaderWriterLock() + + def reader_func(): + with lock.read_lock(): + pass + + with lock.write_lock(): + self.assertRaises(RuntimeError, reader_func) + self.assertFalse(lock._is_reader()) + + self.assertFalse(lock._is_reader()) + self.assertFalse(lock._is_writer()) + + def test_double_writer(self): + lock = lockutils.ReaderWriterLock() + with lock.write_lock(): + self.assertFalse(lock._is_reader()) + self.assertTrue(lock._is_writer()) + with lock.write_lock(): + self.assertTrue(lock._is_writer()) + self.assertTrue(lock._is_writer()) + + self.assertFalse(lock._is_reader()) + self.assertFalse(lock._is_writer()) + + def test_double_reader(self): + lock = lockutils.ReaderWriterLock() + with lock.read_lock(): + self.assertTrue(lock._is_reader()) + self.assertFalse(lock._is_writer()) + with lock.read_lock(): + self.assertTrue(lock._is_reader()) + self.assertTrue(lock._is_reader()) + + self.assertFalse(lock._is_reader()) + self.assertFalse(lock._is_writer()) + + def test_multi_reader_multi_writer(self): + writer_times, reader_times = self._spawn_variation(10, 10) + self.assertEqual(10, len(writer_times)) + self.assertEqual(10, len(reader_times)) + for (start, stop) in writer_times: + self.assertEqual(0, self._find_overlaps(reader_times, start, stop)) + self.assertEqual(1, self._find_overlaps(writer_times, start, stop)) + for (start, stop) in reader_times: + self.assertEqual(0, self._find_overlaps(writer_times, start, stop)) + + def test_multi_reader_single_writer(self): + writer_times, reader_times = self._spawn_variation(9, 1) + self.assertEqual(1, len(writer_times)) + self.assertEqual(9, len(reader_times)) + start, stop = writer_times[0] + self.assertEqual(0, self._find_overlaps(reader_times, start, stop)) + + def test_multi_writer(self): + writer_times, reader_times = self._spawn_variation(0, 10) + self.assertEqual(10, len(writer_times)) + self.assertEqual(0, len(reader_times)) + for (start, stop) in writer_times: + self.assertEqual(1, self._find_overlaps(writer_times, start, stop)) + + class LockutilsModuleTestCase(test_base.BaseTestCase): def setUp(self): diff --git a/test-requirements.txt b/test-requirements.txt index c424655..06a765a 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -5,6 +5,7 @@ hacking>=0.9.1,<0.10 oslotest>=1.2.0 # Apache-2.0 coverage>=3.6 +futures>=2.1.6 # These are needed for docs generation oslosphinx>=2.2.0 # Apache-2.0