Add a reader/writer lock

Taskflow has a reader/writer lock that is likely useful to other
projects; it seems more at home in this module.

The class provides a way to create reader/writer locks where many
readers may hold the lock at the same time but only one writer may.
It does not currently allow privilege escalation (a reader upgrading
to a writer), though limited support for this could be added in the
future.
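
A minimal usage sketch (illustrative only; the helper names below are
hypothetical, while ReaderWriterLock, read_lock() and write_lock() are
the API added by this change):

    from oslo_concurrency import lockutils

    shared_state = {}
    rw_lock = lockutils.ReaderWriterLock()

    def lookup(key):
        # Any number of threads may hold the read lock at the same time.
        with rw_lock.read_lock():
            return shared_state.get(key)

    def update(key, value):
        # Only one writer may hold the lock; it excludes all readers.
        with rw_lock.write_lock():
            shared_state[key] = value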

Change-Id: Ie763ef92f31c34869d83a533bc8761b0fbd77217
Joshua Harlow 2014-12-01 16:21:01 -08:00
parent 54c84da50a
commit 90663364f8
3 changed files with 497 additions and 0 deletions


@@ -13,6 +13,7 @@
#    License for the specific language governing permissions and limitations
#    under the License.

import collections
import contextlib
import errno
import functools
@@ -492,6 +493,177 @@ def _lock_wrapper(argv):
    return ret_val


class ReaderWriterLock(object):
    """A reader/writer lock.

    This lock allows for simultaneous readers to exist but only one writer
    to exist for use-cases where it is useful to have such types of locks.

    Currently a reader can not escalate its read lock to a write lock and
    a writer can not acquire a read lock while it owns or is waiting on
    the write lock.

    In the future these restrictions may be relaxed.

    This can be eventually removed if http://bugs.python.org/issue8800 ever
    gets accepted into the python standard threading library...
    """
    WRITER = b'w'
    READER = b'r'

    @staticmethod
    def _fetch_current_thread_functor():
        # Until https://github.com/eventlet/eventlet/issues/172 is resolved
        # or addressed we have to use a complicated workaround to get an
        # object that will not be recycled; the usage of
        # threading.current_thread() doesn't appear to currently be monkey
        # patched and therefore isn't reliable to use (and breaks badly when
        # used as all threads share the same current_thread() object)...
        try:
            import eventlet
            from eventlet import patcher
            green_threaded = patcher.is_monkey_patched('thread')
        except ImportError:
            green_threaded = False
        if green_threaded:
            return lambda: eventlet.getcurrent()
        else:
            return lambda: threading.current_thread()

    def __init__(self):
        self._writer = None
        self._pending_writers = collections.deque()
        self._readers = collections.defaultdict(int)
        self._cond = threading.Condition()
        self._current_thread = self._fetch_current_thread_functor()

    def _has_pending_writers(self):
        """Returns if there are writers waiting to become the *one* writer.

        Internal usage only.

        :return: whether there are any pending writers
        :rtype: boolean
        """
        return bool(self._pending_writers)

    def _is_writer(self, check_pending=True):
        """Returns if the caller is the active writer or a pending writer.

        Internal usage only.

        :param check_pending: checks the pending writers as well, if false
                              then only the current writer is checked (and
                              not those writers that may be in line).
        :return: whether the current thread is an active/pending writer
        :rtype: boolean
        """
        me = self._current_thread()
        with self._cond:
            if self._writer is not None and self._writer == me:
                return True
            if check_pending:
                return me in self._pending_writers
            else:
                return False

    @property
    def owner_type(self):
        """Returns whether the lock is locked by a writer/reader/nobody.

        :return: constant defining what the active owner's type is
        :rtype: WRITER/READER/None
        """
        with self._cond:
            if self._writer is not None:
                return self.WRITER
            if self._readers:
                return self.READER
            return None

    def _is_reader(self):
        """Returns if the caller is one of the readers.

        Internal usage only.

        :return: whether the current thread is an active/pending reader
        :rtype: boolean
        """
        me = self._current_thread()
        with self._cond:
            return me in self._readers

    @contextlib.contextmanager
    def read_lock(self):
        """Context manager that grants a read lock.

        Will wait until no active or pending writers.

        Raises a ``RuntimeError`` if an active or pending writer tries to
        acquire a read lock as this is disallowed.
        """
        me = self._current_thread()
        if self._is_writer():
            raise RuntimeError("Writer %s can not acquire a read lock"
                               " while holding/waiting for the write lock"
                               % me)
        with self._cond:
            while self._writer is not None:
                # An active writer; guess we have to wait.
                self._cond.wait()
            # No active writer; we are good to become a reader.
            self._readers[me] += 1
        try:
            yield self
        finally:
            # I am no longer a reader, remove *one* occurrence of myself.
            # If the current thread acquired two read locks, then it will
            # still have to remove that other read lock; this allows for
            # basic reentrancy to be possible.
            with self._cond:
                claims = self._readers[me]
                if claims == 1:
                    self._readers.pop(me)
                else:
                    self._readers[me] = claims - 1
                if not self._readers:
                    self._cond.notify_all()

    @contextlib.contextmanager
    def write_lock(self):
        """Context manager that grants a write lock.

        Will wait until no active readers. Blocks readers after acquiring.

        Raises a ``RuntimeError`` if an active reader attempts to acquire a
        write lock as this is disallowed.
        """
        me = self._current_thread()
        if self._is_reader():
            raise RuntimeError("Reader %s to writer privilege"
                               " escalation not allowed" % me)
        if self._is_writer(check_pending=False):
            # Already the writer; this allows for basic reentrancy.
            yield self
        else:
            with self._cond:
                # Add ourself to the pending writers and wait until we are
                # the one writer that can run (aka, when we are the first
                # element in the pending writers).
                self._pending_writers.append(me)
                while (self._readers or self._writer is not None
                       or self._pending_writers[0] != me):
                    self._cond.wait()
                self._writer = self._pending_writers.popleft()
            try:
                yield self
            finally:
                with self._cond:
                    self._writer = None
                    self._cond.notify_all()


def main():
    sys.exit(_lock_wrapper(sys.argv))


@@ -29,6 +29,8 @@ from oslo.config import cfg
from oslotest import base as test_base
import six

from concurrent import futures
from oslo.config import fixture as config
from oslo_concurrency.fixture import lockutils as fixtures
from oslo_concurrency import lockutils
@@ -513,6 +515,328 @@ class FileBasedLockingTestCase(test_base.BaseTestCase):
                self.assertEqual(f.read(), 'test')


class ReadWriteLockTest(test_base.BaseTestCase):
    # This test works by spinning up a bunch of threads and then running
    # them all at once and having different threads acquire either a read
    # lock or a write lock, sleeping for a period of time while holding it.
    #
    # After the tests have completed the timings of each thread are checked
    # to ensure that there are no *invalid* overlaps (a writer should never
    # overlap with any readers, for example).

    # We will spend this amount of time doing some "fake" work.
    WORK_TIMES = [(0.01 + x / 100.0) for x in range(0, 5)]

    # NOTE(harlowja): Sleep a little so time.time() can not be the same (which
    # will cause false positives when our overlap detection code runs). If
    # there are real overlaps then they will still exist.
    NAPPY_TIME = 0.05

    @staticmethod
    def _find_overlaps(times, start, end):
        """Counts num of overlaps between start and end in the given times."""
        overlaps = 0
        for (s, e) in times:
            if s >= start and e <= end:
                overlaps += 1
        return overlaps

    @classmethod
    def _spawn_variation(cls, readers, writers, max_workers=None):
        """Spawns the given number of readers and writers."""
        start_stops = collections.deque()
        lock = lockutils.ReaderWriterLock()

        def read_func(ident):
            with lock.read_lock():
                # TODO(harlowja): sometime in the future use a monotonic clock
                # here to avoid problems that can be caused by ntpd resyncing
                # the clock while we are actively running.
                enter_time = time.time()
                time.sleep(cls.WORK_TIMES[ident % len(cls.WORK_TIMES)])
                exit_time = time.time()
                start_stops.append((lock.READER, enter_time, exit_time))
                time.sleep(cls.NAPPY_TIME)

        def write_func(ident):
            with lock.write_lock():
                enter_time = time.time()
                time.sleep(cls.WORK_TIMES[ident % len(cls.WORK_TIMES)])
                exit_time = time.time()
                start_stops.append((lock.WRITER, enter_time, exit_time))
                time.sleep(cls.NAPPY_TIME)

        if max_workers is None:
            max_workers = max(0, readers) + max(0, writers)
        if max_workers > 0:
            with futures.ThreadPoolExecutor(max_workers=max_workers) as e:
                count = 0
                for _i in range(0, readers):
                    e.submit(read_func, count)
                    count += 1
                for _i in range(0, writers):
                    e.submit(write_func, count)
                    count += 1

        writer_times = []
        reader_times = []
        for (lock_type, start, stop) in list(start_stops):
            if lock_type == lock.WRITER:
                writer_times.append((start, stop))
            else:
                reader_times.append((start, stop))
        return (writer_times, reader_times)

    def test_writer_abort(self):
        # Ensures that the lock is released when the writer has an
        # exception...
        lock = lockutils.ReaderWriterLock()
        self.assertFalse(lock.owner_type)

        def blow_up():
            with lock.write_lock():
                self.assertEqual(lock.WRITER, lock.owner_type)
                raise RuntimeError("Broken")

        self.assertRaises(RuntimeError, blow_up)
        self.assertFalse(lock.owner_type)

    def test_reader_abort(self):
        lock = lockutils.ReaderWriterLock()
        self.assertFalse(lock.owner_type)

        def blow_up():
            with lock.read_lock():
                self.assertEqual(lock.READER, lock.owner_type)
                raise RuntimeError("Broken")

        self.assertRaises(RuntimeError, blow_up)
        self.assertFalse(lock.owner_type)

    def test_double_reader_abort(self):
        lock = lockutils.ReaderWriterLock()
        activated = collections.deque()

        def double_bad_reader():
            with lock.read_lock():
                with lock.read_lock():
                    raise RuntimeError("Broken")

        def happy_writer():
            with lock.write_lock():
                activated.append(lock.owner_type)

        # Submit a bunch of work to a pool, and then ensure that the correct
        # number of writers eventually executed (every other thread will
        # be a reader thread that will fail)...
        max_workers = 8
        with futures.ThreadPoolExecutor(max_workers=max_workers) as e:
            for i in range(0, max_workers):
                if i % 2 == 0:
                    e.submit(double_bad_reader)
                else:
                    e.submit(happy_writer)

        self.assertEqual(max_workers / 2,
                         len([a for a in activated
                              if a == lockutils.ReaderWriterLock.WRITER]))

    def test_double_reader_writer(self):
        lock = lockutils.ReaderWriterLock()
        activated = collections.deque()
        active = threading.Event()

        def double_reader():
            with lock.read_lock():
                active.set()
                # Wait for the writer thread to get into pending mode using a
                # simple spin-loop...
                while not lock._has_pending_writers():
                    time.sleep(0.001)
                with lock.read_lock():
                    activated.append(lock.owner_type)

        def happy_writer():
            with lock.write_lock():
                activated.append(lock.owner_type)

        reader = threading.Thread(target=double_reader)
        reader.daemon = True
        reader.start()

        # Wait for the reader to become the active reader.
        active.wait()
        self.assertTrue(active.is_set())

        # Start up the writer (the reader will wait until it's going).
        writer = threading.Thread(target=happy_writer)
        writer.daemon = True
        writer.start()

        # Ensure it went in the order we expected.
        reader.join()
        writer.join()
        self.assertEqual(2, len(activated))
        self.assertEqual([lockutils.ReaderWriterLock.READER,
                          lockutils.ReaderWriterLock.WRITER], list(activated))

    def test_reader_chaotic(self):
        lock = lockutils.ReaderWriterLock()
        activated = collections.deque()

        def chaotic_reader(blow_up):
            with lock.read_lock():
                if blow_up:
                    raise RuntimeError("Broken")
                else:
                    activated.append(lock.owner_type)

        def happy_writer():
            with lock.write_lock():
                activated.append(lock.owner_type)

        # Test that every 4th reader blows up and that we get the expected
        # number of owners with this occurring.
        max_workers = 8
        with futures.ThreadPoolExecutor(max_workers=max_workers) as e:
            for i in range(0, max_workers):
                if i % 2 == 0:
                    e.submit(chaotic_reader, blow_up=bool(i % 4 == 0))
                else:
                    e.submit(happy_writer)

        writers = [a for a in activated
                   if a == lockutils.ReaderWriterLock.WRITER]
        readers = [a for a in activated
                   if a == lockutils.ReaderWriterLock.READER]
        self.assertEqual(4, len(writers))
        self.assertEqual(2, len(readers))

    def test_writer_chaotic(self):
        lock = lockutils.ReaderWriterLock()
        activated = collections.deque()

        def chaotic_writer(blow_up):
            with lock.write_lock():
                if blow_up:
                    raise RuntimeError("Broken")
                else:
                    activated.append(lock.owner_type)

        def happy_reader():
            with lock.read_lock():
                activated.append(lock.owner_type)

        # Test that every 4th writer blows up and that we get the expected
        # number of owners with this occurring.
        max_workers = 8
        with futures.ThreadPoolExecutor(max_workers=max_workers) as e:
            for i in range(0, max_workers):
                if i % 2 == 0:
                    e.submit(chaotic_writer, blow_up=bool(i % 4 == 0))
                else:
                    e.submit(happy_reader)

        writers = [a for a in activated
                   if a == lockutils.ReaderWriterLock.WRITER]
        readers = [a for a in activated
                   if a == lockutils.ReaderWriterLock.READER]
        self.assertEqual(2, len(writers))
        self.assertEqual(4, len(readers))

    def test_single_reader_writer(self):
        results = []
        lock = lockutils.ReaderWriterLock()
        with lock.read_lock():
            self.assertTrue(lock._is_reader())
            self.assertEqual(0, len(results))
        with lock.write_lock():
            results.append(1)
            self.assertTrue(lock._is_writer())
        with lock.read_lock():
            self.assertTrue(lock._is_reader())
            self.assertEqual(1, len(results))
        self.assertFalse(lock._is_reader())
        self.assertFalse(lock._is_writer())

    def test_reader_to_writer(self):
        lock = lockutils.ReaderWriterLock()

        def writer_func():
            with lock.write_lock():
                pass

        with lock.read_lock():
            self.assertRaises(RuntimeError, writer_func)
            self.assertFalse(lock._is_writer())

        self.assertFalse(lock._is_reader())
        self.assertFalse(lock._is_writer())

    def test_writer_to_reader(self):
        lock = lockutils.ReaderWriterLock()

        def reader_func():
            with lock.read_lock():
                pass

        with lock.write_lock():
            self.assertRaises(RuntimeError, reader_func)
            self.assertFalse(lock._is_reader())

        self.assertFalse(lock._is_reader())
        self.assertFalse(lock._is_writer())

    def test_double_writer(self):
        lock = lockutils.ReaderWriterLock()
        with lock.write_lock():
            self.assertFalse(lock._is_reader())
            self.assertTrue(lock._is_writer())
            with lock.write_lock():
                self.assertTrue(lock._is_writer())
            self.assertTrue(lock._is_writer())

        self.assertFalse(lock._is_reader())
        self.assertFalse(lock._is_writer())

    def test_double_reader(self):
        lock = lockutils.ReaderWriterLock()
        with lock.read_lock():
            self.assertTrue(lock._is_reader())
            self.assertFalse(lock._is_writer())
            with lock.read_lock():
                self.assertTrue(lock._is_reader())
            self.assertTrue(lock._is_reader())

        self.assertFalse(lock._is_reader())
        self.assertFalse(lock._is_writer())

    def test_multi_reader_multi_writer(self):
        writer_times, reader_times = self._spawn_variation(10, 10)
        self.assertEqual(10, len(writer_times))
        self.assertEqual(10, len(reader_times))
        for (start, stop) in writer_times:
            self.assertEqual(0, self._find_overlaps(reader_times, start, stop))
            self.assertEqual(1, self._find_overlaps(writer_times, start, stop))
        for (start, stop) in reader_times:
            self.assertEqual(0, self._find_overlaps(writer_times, start, stop))

    def test_multi_reader_single_writer(self):
        writer_times, reader_times = self._spawn_variation(9, 1)
        self.assertEqual(1, len(writer_times))
        self.assertEqual(9, len(reader_times))
        start, stop = writer_times[0]
        self.assertEqual(0, self._find_overlaps(reader_times, start, stop))

    def test_multi_writer(self):
        writer_times, reader_times = self._spawn_variation(0, 10)
        self.assertEqual(10, len(writer_times))
        self.assertEqual(0, len(reader_times))
        for (start, stop) in writer_times:
            self.assertEqual(1, self._find_overlaps(writer_times, start, stop))


class LockutilsModuleTestCase(test_base.BaseTestCase):
    def setUp(self):


@@ -5,6 +5,7 @@
hacking>=0.9.1,<0.10
oslotest>=1.2.0 # Apache-2.0
coverage>=3.6
futures>=2.1.6
# These are needed for docs generation
oslosphinx>=2.2.0 # Apache-2.0