Merge "Use reader/writer locks in storage"

This commit is contained in:
Jenkins
2014-02-05 17:58:38 +00:00
committed by Gerrit Code Review
6 changed files with 647 additions and 174 deletions

View File

@@ -191,13 +191,12 @@ class ActionEngine(base.EngineBase):
class SingleThreadedActionEngine(ActionEngine):
"""Engine that runs tasks in serial manner."""
_storage_cls = t_storage.Storage
_storage_cls = t_storage.SingleThreadedStorage
class MultiThreadedActionEngine(ActionEngine):
"""Engine that runs tasks in parallel manner."""
_storage_cls = t_storage.ThreadSafeStorage
_storage_cls = t_storage.MultiThreadedStorage
def _task_executor_cls(self):
return executor.ParallelTaskExecutor(self._executor)

View File

@@ -16,6 +16,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import abc
import contextlib
import logging
@@ -25,13 +26,15 @@ from taskflow import exceptions
from taskflow.openstack.common import uuidutils
from taskflow.persistence import logbook
from taskflow import states
from taskflow.utils import lock_utils
from taskflow.utils import misc
from taskflow.utils import threading_utils as tu
from taskflow.utils import reflection
LOG = logging.getLogger(__name__)
STATES_WITH_RESULTS = (states.SUCCESS, states.REVERTING, states.FAILURE)
@six.add_metaclass(abc.ABCMeta)
class Storage(object):
"""Interface between engines and logbook.
@@ -48,11 +51,14 @@ class Storage(object):
self._reverse_mapping = {}
self._backend = backend
self._flowdetail = flow_detail
self._lock = self._lock_cls()
# NOTE(imelnikov): failure serialization looses information,
# so we cache failures here, in task name -> misc.Failure mapping.
self._failures = {}
self._reload_failures()
for td in self._flowdetail:
if td.failure is not None:
self._failures[td.name] = td.failure
self._task_name_to_uuid = dict((td.name, td.uuid)
for td in self._flowdetail)
@@ -66,11 +72,20 @@ class Storage(object):
self._set_result_mapping(injector_td.name,
dict((name, name) for name in names))
@abc.abstractproperty
def _lock_cls(self):
"""Lock class used to generate reader/writer locks for protecting
read/write access to the underlying storage backend and internally
mutating operations.
"""
def _with_connection(self, functor, *args, **kwargs):
# NOTE(harlowja): Activate the given function with a backend
# connection, if a backend is provided in the first place, otherwise
# don't call the function.
if self._backend is None:
LOG.debug("No backend provided, not calling functor '%s'",
reflection.get_callable_name(functor))
return
with contextlib.closing(self._backend.get_connection()) as conn:
functor(conn, *args, **kwargs)
@@ -85,12 +100,13 @@ class Storage(object):
Returns uuid for the task details corresponding to the task with
given name.
"""
try:
task_id = self._task_name_to_uuid[task_name]
except KeyError:
task_id = uuidutils.generate_uuid()
self._add_task(task_id, task_name, task_version)
self._set_result_mapping(task_name, result_mapping)
with self._lock.write_lock():
try:
task_id = self._task_name_to_uuid[task_name]
except KeyError:
task_id = uuidutils.generate_uuid()
self._add_task(task_id, task_name, task_version)
self._set_result_mapping(task_name, result_mapping)
return task_id
def _add_task(self, uuid, task_name, task_version=None):
@@ -99,27 +115,29 @@ class Storage(object):
Task becomes known to storage by that name and uuid.
Task state is set to PENDING.
"""
def save_both(conn, td):
"""Saves the flow and the task detail with the same connection."""
self._save_flow_detail(conn)
self._save_task_detail(conn, td)
# TODO(imelnikov): check that task with same uuid or
# task name does not exist.
td = logbook.TaskDetail(name=task_name, uuid=uuid)
td.state = states.PENDING
td.version = task_version
self._flowdetail.add(td)
def save_both(conn):
"""Saves the flow and the task detail with the same connection."""
self._save_flow_detail(conn)
self._save_task_detail(conn, task_detail=td)
self._with_connection(save_both)
self._with_connection(save_both, td)
self._task_name_to_uuid[task_name] = uuid
@property
def flow_name(self):
# This never changes (so no read locking needed).
return self._flowdetail.name
@property
def flow_uuid(self):
# This never changes (so no read locking needed).
return self._flowdetail.uuid
def _save_flow_detail(self, conn):
@@ -130,13 +148,9 @@ class Storage(object):
def _taskdetail_by_name(self, task_name):
try:
td = self._flowdetail.find(self._task_name_to_uuid[task_name])
return self._flowdetail.find(self._task_name_to_uuid[task_name])
except KeyError:
td = None
if td is None:
raise exceptions.NotFound("Unknown task name: %s" % task_name)
return td
def _save_task_detail(self, conn, task_detail):
# NOTE(harlowja): we need to update our contained task detail if
@@ -146,34 +160,39 @@ class Storage(object):
def get_task_uuid(self, task_name):
"""Get task uuid by given name."""
td = self._taskdetail_by_name(task_name)
return td.uuid
with self._lock.read_lock():
td = self._taskdetail_by_name(task_name)
return td.uuid
def set_task_state(self, task_name, state):
"""Set task state."""
td = self._taskdetail_by_name(task_name)
td.state = state
self._with_connection(self._save_task_detail, task_detail=td)
with self._lock.write_lock():
td = self._taskdetail_by_name(task_name)
td.state = state
self._with_connection(self._save_task_detail, td)
def get_task_state(self, task_name):
"""Get state of task with given name."""
return self._taskdetail_by_name(task_name).state
with self._lock.read_lock():
td = self._taskdetail_by_name(task_name)
return td.state
def get_tasks_states(self, task_names):
return dict((name, self.get_task_state(name))
for name in task_names)
"""Gets all task states."""
with self._lock.read_lock():
return dict((name, self.get_task_state(name))
for name in task_names)
def update_task_metadata(self, task_name, update_with):
"""Updates a tasks metadata."""
if not update_with:
return
# NOTE(harlowja): this is a read and then write, not in 1 transaction
# so it is entirely possible that we could write over another writes
# metadata update. Maybe add some merging logic later?
td = self._taskdetail_by_name(task_name)
if not td.meta:
td.meta = {}
td.meta.update(update_with)
self._with_connection(self._save_task_detail, task_detail=td)
with self._lock.write_lock():
td = self._taskdetail_by_name(task_name)
if not td.meta:
td.meta = {}
td.meta.update(update_with)
self._with_connection(self._save_task_detail, td)
def set_task_progress(self, task_name, progress, details=None):
"""Set task progress.
@@ -204,10 +223,11 @@ class Storage(object):
:param task_name: task name
:returns: current task progress value
"""
meta = self._taskdetail_by_name(task_name).meta
if not meta:
return 0.0
return meta.get('progress', 0.0)
with self._lock.read_lock():
td = self._taskdetail_by_name(task_name)
if not td.meta:
return 0.0
return td.meta.get('progress', 0.0)
def get_task_progress_details(self, task_name):
"""Get progress details of task with given name.
@@ -216,10 +236,11 @@ class Storage(object):
:returns: None if progress_details not defined, else progress_details
dict
"""
meta = self._taskdetail_by_name(task_name).meta
if not meta:
return None
return meta.get('progress_details')
with self._lock.read_lock():
td = self._taskdetail_by_name(task_name)
if not td.meta:
return None
return td.meta.get('progress_details')
def _check_all_results_provided(self, task_name, data):
"""Warn if task did not provide some of results.
@@ -228,69 +249,57 @@ class Storage(object):
without all needed keys. It may also happen if task returns
result of wrong type.
"""
result_mapping = self._result_mappings.get(task_name, None)
if result_mapping is None:
result_mapping = self._result_mappings.get(task_name)
if not result_mapping:
return
for name, index in six.iteritems(result_mapping):
try:
misc.item_from(data, index, name=name)
except exceptions.NotFound:
LOG.warning("Task %s did not supply result "
"with index %r (name %s)",
task_name, index, name)
"with index %r (name %s)", task_name, index, name)
def save(self, task_name, data, state=states.SUCCESS):
"""Put result for task with id 'uuid' to storage."""
td = self._taskdetail_by_name(task_name)
td.state = state
if state == states.FAILURE and isinstance(data, misc.Failure):
td.results = None
td.failure = data
self._failures[td.name] = data
else:
td.results = data
td.failure = None
self._check_all_results_provided(td.name, data)
self._with_connection(self._save_task_detail, task_detail=td)
def _cache_failure(self, name, fail):
"""Ensure that cache has matching failure for task with this name.
We leave cached version if it matches as it may contain more
information. Returns cached failure.
"""
cached = self._failures.get(name)
if fail.matches(cached):
return cached
self._failures[name] = fail
return fail
def _reload_failures(self):
"""Refresh failures cache."""
for td in self._flowdetail:
if td.failure is not None:
self._cache_failure(td.name, td.failure)
with self._lock.write_lock():
td = self._taskdetail_by_name(task_name)
td.state = state
if state == states.FAILURE and isinstance(data, misc.Failure):
td.results = None
td.failure = data
self._failures[td.name] = data
else:
td.results = data
td.failure = None
self._check_all_results_provided(td.name, data)
self._with_connection(self._save_task_detail, td)
def get(self, task_name):
"""Get result for task with name 'task_name' to storage."""
td = self._taskdetail_by_name(task_name)
if td.failure is not None:
return self._cache_failure(td.name, td.failure)
if td.state not in STATES_WITH_RESULTS:
raise exceptions.NotFound(
"Result for task %s is not known" % task_name)
return td.results
with self._lock.read_lock():
td = self._taskdetail_by_name(task_name)
if td.failure is not None:
cached = self._failures.get(task_name)
if td.failure.matches(cached):
return cached
return td.failure
if td.state not in STATES_WITH_RESULTS:
raise exceptions.NotFound("Result for task %s is not known"
% task_name)
return td.results
def get_failures(self):
"""Get list of failures that happened with this flow.
No order guaranteed.
"""
return self._failures.copy()
with self._lock.read_lock():
return self._failures.copy()
def has_failures(self):
"""Returns True if there are failed tasks in the storage."""
return bool(self._failures)
with self._lock.read_lock():
return bool(self._failures)
def _reset_task(self, td, state):
if td.name == self.injector_name:
@@ -305,25 +314,28 @@ class Storage(object):
def reset(self, task_name, state=states.PENDING):
"""Remove result for task with id 'uuid' from storage."""
td = self._taskdetail_by_name(task_name)
if self._reset_task(td, state):
self._with_connection(self._save_task_detail, task_detail=td)
with self._lock.write_lock():
td = self._taskdetail_by_name(task_name)
if self._reset_task(td, state):
self._with_connection(self._save_task_detail, td)
def reset_tasks(self):
"""Reset all tasks to PENDING state, removing results.
Returns list of (name, uuid) tuples for all tasks that were reset.
"""
result = []
reset_results = []
def do_reset_all(connection):
for td in self._flowdetail:
if self._reset_task(td, states.PENDING):
self._save_task_detail(connection, td)
result.append((td.name, td.uuid))
reset_results.append((td.name, td.uuid))
self._with_connection(do_reset_all)
return result
with self._lock.write_lock():
self._with_connection(do_reset_all)
return reset_results
def inject(self, pairs):
"""Add values into storage.
@@ -331,20 +343,20 @@ class Storage(object):
This method should be used to put flow parameters (requirements that
are not satisfied by any task in the flow) into storage.
"""
try:
injector_td = self._taskdetail_by_name(self.injector_name)
except exceptions.NotFound:
injector_uuid = uuidutils.generate_uuid()
self._add_task(injector_uuid, self.injector_name)
results = dict(pairs)
else:
results = injector_td.results.copy()
results.update(pairs)
self.save(self.injector_name, results)
names = six.iterkeys(results)
self._set_result_mapping(self.injector_name,
dict((name, name) for name in names))
with self._lock.write_lock():
try:
td = self._taskdetail_by_name(self.injector_name)
except exceptions.NotFound:
self._add_task(uuidutils.generate_uuid(), self.injector_name)
td = self._taskdetail_by_name(self.injector_name)
td.results = dict(pairs)
td.state = states.SUCCESS
else:
td.results.update(pairs)
self._with_connection(self._save_task_detail, td)
names = six.iterkeys(td.results)
self._set_result_mapping(self.injector_name,
dict((name, name) for name in names))
def _set_result_mapping(self, task_name, mapping):
"""Set mapping for naming task results.
@@ -378,50 +390,60 @@ class Storage(object):
def fetch(self, name):
"""Fetch named task result."""
try:
indexes = self._reverse_mapping[name]
except KeyError:
raise exceptions.NotFound("Name %r is not mapped" % name)
# Return the first one that is found.
for task_name, index in reversed(indexes):
with self._lock.read_lock():
try:
result = self.get(task_name)
return misc.item_from(result, index, name=name)
except exceptions.NotFound:
pass
raise exceptions.NotFound("Unable to find result %r" % name)
indexes = self._reverse_mapping[name]
except KeyError:
raise exceptions.NotFound("Name %r is not mapped" % name)
# Return the first one that is found.
for (task_name, index) in reversed(indexes):
try:
result = self.get(task_name)
return misc.item_from(result, index, name=name)
except exceptions.NotFound:
pass
raise exceptions.NotFound("Unable to find result %r" % name)
def fetch_all(self):
"""Fetch all named task results known so far.
Should be used for debugging and testing purposes mostly.
"""
result = {}
for name in self._reverse_mapping:
try:
result[name] = self.fetch(name)
except exceptions.NotFound:
pass
return result
with self._lock.read_lock():
results = {}
for name in self._reverse_mapping:
try:
results[name] = self.fetch(name)
except exceptions.NotFound:
pass
return results
def fetch_mapped_args(self, args_mapping):
"""Fetch arguments for the task using arguments mapping."""
return dict((key, self.fetch(name))
for key, name in six.iteritems(args_mapping))
with self._lock.read_lock():
return dict((key, self.fetch(name))
for key, name in six.iteritems(args_mapping))
def set_flow_state(self, state):
"""Set flowdetails state and save it."""
self._flowdetail.state = state
self._with_connection(self._save_flow_detail)
"""Set flow details state and save it."""
with self._lock.write_lock():
self._flowdetail.state = state
self._with_connection(self._save_flow_detail)
def get_flow_state(self):
"""Set state from flowdetails."""
state = self._flowdetail.state
if state is None:
state = states.PENDING
return state
"""Get state from flow details."""
with self._lock.read_lock():
state = self._flowdetail.state
if state is None:
state = states.PENDING
return state
@six.add_metaclass(tu.ThreadSafeMeta)
class ThreadSafeStorage(Storage):
pass
class MultiThreadedStorage(Storage):
"""Storage that uses locks to protect against concurrent access."""
_lock_cls = lock_utils.ReaderWriterLock
class SingleThreadedStorage(Storage):
"""Storage that uses dummy locks when you really don't need locks."""
_lock_cls = lock_utils.DummyReaderWriterLock

View File

@@ -17,6 +17,7 @@
# under the License.
import contextlib
import threading
import mock
@@ -35,10 +36,22 @@ class StorageTest(test.TestCase):
def setUp(self):
super(StorageTest, self).setUp()
self.backend = impl_memory.MemoryBackend(conf={})
self.thread_count = 50
def _get_storage(self):
def _run_many_threads(self, threads):
for t in threads:
t.start()
for t in threads:
t.join()
def _get_storage(self, threaded=False):
_lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
return storage.Storage(backend=self.backend, flow_detail=flow_detail)
if threaded:
return storage.MultiThreadedStorage(backend=self.backend,
flow_detail=flow_detail)
else:
return storage.SingleThreadedStorage(backend=self.backend,
flow_detail=flow_detail)
def tearDown(self):
super(StorageTest, self).tearDown()
@@ -48,14 +61,14 @@ class StorageTest(test.TestCase):
def test_non_saving_storage(self):
_lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
s = storage.Storage(flow_detail=flow_detail) # no backend
s = storage.SingleThreadedStorage(flow_detail=flow_detail)
s.ensure_task('my_task')
self.assertTrue(
uuidutils.is_uuid_like(s.get_task_uuid('my_task')))
def test_flow_name_and_uuid(self):
fd = logbook.FlowDetail(name='test-fd', uuid='aaaa')
s = storage.Storage(flow_detail=fd)
s = storage.SingleThreadedStorage(flow_detail=fd)
self.assertEqual(s.flow_name, 'test-fd')
self.assertEqual(s.flow_uuid, 'aaaa')
@@ -79,7 +92,8 @@ class StorageTest(test.TestCase):
def test_ensure_task_fd(self):
_lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
s = storage.Storage(backend=self.backend, flow_detail=flow_detail)
s = storage.SingleThreadedStorage(backend=self.backend,
flow_detail=flow_detail)
s.ensure_task('my task', '3.11')
td = flow_detail.find(s.get_task_uuid('my task'))
self.assertIsNotNone(td)
@@ -92,7 +106,8 @@ class StorageTest(test.TestCase):
td = logbook.TaskDetail(name='my_task', uuid='42')
flow_detail.add(td)
s = storage.Storage(backend=self.backend, flow_detail=flow_detail)
s = storage.SingleThreadedStorage(backend=self.backend,
flow_detail=flow_detail)
self.assertEqual('42', s.get_task_uuid('my_task'))
def test_ensure_existing_task(self):
@@ -100,7 +115,8 @@ class StorageTest(test.TestCase):
td = logbook.TaskDetail(name='my_task', uuid='42')
flow_detail.add(td)
s = storage.Storage(backend=self.backend, flow_detail=flow_detail)
s = storage.SingleThreadedStorage(backend=self.backend,
flow_detail=flow_detail)
s.ensure_task('my_task')
self.assertEqual('42', s.get_task_uuid('my_task'))
@@ -147,7 +163,8 @@ class StorageTest(test.TestCase):
s.ensure_task('my task')
s.save('my task', fail, states.FAILURE)
s2 = storage.Storage(backend=self.backend, flow_detail=s._flowdetail)
s2 = storage.SingleThreadedStorage(backend=self.backend,
flow_detail=s._flowdetail)
self.assertIs(s2.has_failures(), True)
self.assertEqual(s2.get_failures(), {'my task': fail})
self.assertEqual(s2.get('my task'), fail)
@@ -305,13 +322,71 @@ class StorageTest(test.TestCase):
})
# imagine we are resuming, so we need to make new
# storage from same flow details
s2 = storage.Storage(s._flowdetail, backend=self.backend)
s2 = storage.SingleThreadedStorage(s._flowdetail, backend=self.backend)
# injected data should still be there:
self.assertEqual(s2.fetch_all(), {
'foo': 'bar',
'spam': 'eggs',
})
def test_many_thread_ensure_same_task(self):
s = self._get_storage(threaded=True)
def ensure_my_task():
s.ensure_task('my_task', result_mapping={})
threads = []
for i in range(0, self.thread_count):
threads.append(threading.Thread(target=ensure_my_task))
self._run_many_threads(threads)
# Only one task should have been made, no more.
self.assertEqual(1, len(s._flowdetail))
def test_many_thread_one_reset(self):
s = self._get_storage(threaded=True)
s.ensure_task('a')
s.set_task_state('a', states.SUSPENDED)
s.ensure_task('b')
s.set_task_state('b', states.SUSPENDED)
results = []
result_lock = threading.Lock()
def reset_all():
r = s.reset_tasks()
with result_lock:
results.append(r)
threads = []
for i in range(0, self.thread_count):
threads.append(threading.Thread(target=reset_all))
self._run_many_threads(threads)
# Only one thread should have actually reset (not anymore)
results = [r for r in results if len(r)]
self.assertEqual(1, len(results))
self.assertEqual(['a', 'b'], sorted([a[0] for a in results[0]]))
def test_many_thread_inject(self):
s = self._get_storage(threaded=True)
def inject_values(values):
s.inject(values)
threads = []
for i in range(0, self.thread_count):
values = {
str(i): str(i),
}
threads.append(threading.Thread(target=inject_values,
args=[values]))
self._run_many_threads(threads)
self.assertEqual(self.thread_count, len(s.fetch_all()))
self.assertEqual(1, len(s._flowdetail))
def test_fetch_meapped_args(self):
s = self._get_storage()
s.inject({'foo': 'bar', 'spam': 'eggs'})
@@ -357,7 +432,7 @@ class StorageTest(test.TestCase):
fd.state = states.FAILURE
with contextlib.closing(self.backend.get_connection()) as conn:
fd.update(conn.update_flow_details(fd))
s = storage.Storage(flow_detail=fd, backend=self.backend)
s = storage.SingleThreadedStorage(flow_detail=fd, backend=self.backend)
self.assertEqual(s.get_flow_state(), states.FAILURE)
def test_set_and_get_flow_state(self):

View File

@@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import collections
import time
from concurrent import futures
from taskflow import test
from taskflow.utils import lock_utils
def _find_overlaps(times, start, end):
overlaps = 0
for (s, e) in times:
if s >= start and e <= end:
overlaps += 1
return overlaps
def _spawn_variation(readers, writers, max_workers=None):
start_stops = collections.deque()
lock = lock_utils.ReaderWriterLock()
def read_func():
with lock.read_lock():
start_stops.append(('r', time.time(), time.time()))
def write_func():
with lock.write_lock():
start_stops.append(('w', time.time(), time.time()))
if max_workers is None:
max_workers = max(0, readers) + max(0, writers)
if max_workers > 0:
with futures.ThreadPoolExecutor(max_workers=max_workers) as e:
for i in range(0, readers):
e.submit(read_func)
for i in range(0, writers):
e.submit(write_func)
writer_times = []
reader_times = []
for (t, start, stop) in list(start_stops):
if t == 'w':
writer_times.append((start, stop))
else:
reader_times.append((start, stop))
return (writer_times, reader_times)
class ReadWriteLockTest(test.TestCase):
def test_writer_abort(self):
lock = lock_utils.ReaderWriterLock()
self.assertFalse(lock.owner)
def blow_up():
with lock.write_lock():
self.assertEqual(lock.WRITER, lock.owner)
raise RuntimeError("Broken")
self.assertRaises(RuntimeError, blow_up)
self.assertFalse(lock.owner)
def test_reader_abort(self):
lock = lock_utils.ReaderWriterLock()
self.assertFalse(lock.owner)
def blow_up():
with lock.read_lock():
self.assertEqual(lock.READER, lock.owner)
raise RuntimeError("Broken")
self.assertRaises(RuntimeError, blow_up)
self.assertFalse(lock.owner)
def test_reader_chaotic(self):
lock = lock_utils.ReaderWriterLock()
activated = collections.deque()
def chaotic_reader(blow_up):
with lock.read_lock():
if blow_up:
raise RuntimeError("Broken")
else:
activated.append(lock.owner)
def happy_writer():
with lock.write_lock():
activated.append(lock.owner)
with futures.ThreadPoolExecutor(max_workers=20) as e:
for i in range(0, 20):
if i % 2 == 0:
e.submit(chaotic_reader, blow_up=bool(i % 4 == 0))
else:
e.submit(happy_writer)
writers = [a for a in activated if a == 'w']
readers = [a for a in activated if a == 'r']
self.assertEqual(10, len(writers))
self.assertEqual(5, len(readers))
def test_writer_chaotic(self):
lock = lock_utils.ReaderWriterLock()
activated = collections.deque()
def chaotic_writer(blow_up):
with lock.write_lock():
if blow_up:
raise RuntimeError("Broken")
else:
activated.append(lock.owner)
def happy_reader():
with lock.read_lock():
activated.append(lock.owner)
with futures.ThreadPoolExecutor(max_workers=20) as e:
for i in range(0, 20):
if i % 2 == 0:
e.submit(chaotic_writer, blow_up=bool(i % 4 == 0))
else:
e.submit(happy_reader)
writers = [a for a in activated if a == 'w']
readers = [a for a in activated if a == 'r']
self.assertEqual(5, len(writers))
self.assertEqual(10, len(readers))
def test_single_reader_writer(self):
results = []
lock = lock_utils.ReaderWriterLock()
with lock.read_lock():
self.assertTrue(lock.is_reader())
self.assertEqual(0, len(results))
with lock.write_lock():
results.append(1)
self.assertTrue(lock.is_writer())
with lock.read_lock():
self.assertTrue(lock.is_reader())
self.assertEqual(1, len(results))
self.assertFalse(lock.is_reader())
self.assertFalse(lock.is_writer())
def test_reader_to_writer(self):
lock = lock_utils.ReaderWriterLock()
def writer_func():
with lock.write_lock():
pass
with lock.read_lock():
self.assertRaises(RuntimeError, writer_func)
self.assertFalse(lock.is_writer())
self.assertFalse(lock.is_reader())
self.assertFalse(lock.is_writer())
def test_writer_to_reader(self):
lock = lock_utils.ReaderWriterLock()
def reader_func():
with lock.read_lock():
pass
with lock.write_lock():
self.assertRaises(RuntimeError, reader_func)
self.assertFalse(lock.is_reader())
self.assertFalse(lock.is_reader())
self.assertFalse(lock.is_writer())
def test_double_writer(self):
lock = lock_utils.ReaderWriterLock()
with lock.write_lock():
self.assertFalse(lock.is_reader())
self.assertTrue(lock.is_writer())
with lock.write_lock():
self.assertTrue(lock.is_writer())
self.assertTrue(lock.is_writer())
self.assertFalse(lock.is_reader())
self.assertFalse(lock.is_writer())
def test_double_reader(self):
lock = lock_utils.ReaderWriterLock()
with lock.read_lock():
self.assertTrue(lock.is_reader())
self.assertFalse(lock.is_writer())
with lock.read_lock():
self.assertTrue(lock.is_reader())
self.assertTrue(lock.is_reader())
self.assertFalse(lock.is_reader())
self.assertFalse(lock.is_writer())
def test_multi_reader_multi_writer(self):
writer_times, reader_times = _spawn_variation(10, 10)
self.assertEqual(10, len(writer_times))
self.assertEqual(10, len(reader_times))
for (start, stop) in writer_times:
self.assertEqual(0, _find_overlaps(reader_times, start, stop))
self.assertEqual(1, _find_overlaps(writer_times, start, stop))
for (start, stop) in reader_times:
self.assertEqual(0, _find_overlaps(writer_times, start, stop))
def test_multi_reader_single_writer(self):
writer_times, reader_times = _spawn_variation(9, 1)
self.assertEqual(1, len(writer_times))
self.assertEqual(9, len(reader_times))
start, stop = writer_times[0]
self.assertEqual(0, _find_overlaps(reader_times, start, stop))
def test_multi_writer(self):
writer_times, reader_times = _spawn_variation(0, 10)
self.assertEqual(10, len(writer_times))
self.assertEqual(0, len(reader_times))
for (start, stop) in writer_times:
self.assertEqual(1, _find_overlaps(writer_times, start, stop))

View File

@@ -21,6 +21,8 @@
# pulls in oslo.cfg) and is reduced to only what taskflow currently wants to
# use from that code.
import collections
import contextlib
import errno
import logging
import os
@@ -28,6 +30,7 @@ import threading
import time
from taskflow.utils import misc
from taskflow.utils import threading_utils as tu
LOG = logging.getLogger(__name__)
@@ -62,6 +65,163 @@ def locked(*args, **kwargs):
return decorator
class ReaderWriterLock(object):
"""A reader/writer lock.
This lock allows for simultaneous readers to exist but only one writer
to exist for use-cases where it is useful to have such types of locks.
Currently a reader can not escalate its read lock to a write lock and
a writer can not acquire a read lock while it owns or is waiting on
the write lock.
In the future these restrictions may be relaxed.
"""
WRITER = 'w'
READER = 'r'
def __init__(self):
self._writer = None
self._pending_writers = collections.deque()
self._readers = collections.deque()
self._cond = threading.Condition()
def is_writer(self, check_pending=True):
"""Returns if the caller is the active writer or a pending writer."""
self._cond.acquire()
try:
me = tu.get_ident()
if self._writer is not None and self._writer == me:
return True
if check_pending:
return me in self._pending_writers
else:
return False
finally:
self._cond.release()
@property
def owner(self):
"""Returns whether the lock is locked by a writer or reader."""
self._cond.acquire()
try:
if self._writer is not None:
return self.WRITER
if self._readers:
return self.READER
return None
finally:
self._cond.release()
def is_reader(self):
"""Returns if the caller is one of the readers."""
self._cond.acquire()
try:
return tu.get_ident() in self._readers
finally:
self._cond.release()
@contextlib.contextmanager
def read_lock(self):
"""Grants a read lock.
Will wait until no active or pending writers.
Raises a RuntimeError if an active or pending writer tries to acquire
a read lock.
"""
me = tu.get_ident()
if self.is_writer():
raise RuntimeError("Writer %s can not acquire a read lock"
" while holding/waiting for the write lock"
% me)
self._cond.acquire()
try:
while True:
# No active or pending writers; we are good to become a reader.
if self._writer is None and len(self._pending_writers) == 0:
self._readers.append(me)
break
# Some writers; guess we have to wait.
self._cond.wait()
finally:
self._cond.release()
try:
yield self
finally:
# I am no longer a reader, remove *one* occurrence of myself.
# If the current thread acquired two read locks, then it will
# still have to remove that other read lock; this allows for
# basic reentrancy to be possible.
self._cond.acquire()
try:
self._readers.remove(me)
self._cond.notify_all()
finally:
self._cond.release()
@contextlib.contextmanager
def write_lock(self):
"""Grants a write lock.
Will wait until no active readers. Blocks readers after acquiring.
Raises a RuntimeError if an active reader attempts to acquire a lock.
"""
me = tu.get_ident()
if self.is_reader():
raise RuntimeError("Reader %s to writer privilege"
" escalation not allowed" % me)
if self.is_writer(check_pending=False):
# Already the writer; this allows for basic reentrancy.
yield self
else:
self._cond.acquire()
try:
self._pending_writers.append(me)
while True:
# No readers, and no active writer, am I next??
if len(self._readers) == 0 and self._writer is None:
if self._pending_writers[0] == me:
self._writer = self._pending_writers.popleft()
break
self._cond.wait()
finally:
self._cond.release()
try:
yield self
finally:
self._cond.acquire()
try:
self._writer = None
self._cond.notify_all()
finally:
self._cond.release()
class DummyReaderWriterLock(object):
"""A dummy reader/writer lock that doesn't lock anything but provides same
functions as a normal reader/writer lock class.
"""
@contextlib.contextmanager
def write_lock(self):
yield self
@contextlib.contextmanager
def read_lock(self):
yield self
@property
def owner(self):
return None
def is_reader(self):
return False
def is_writer(self):
return False
class MultiLock(object):
"""A class which can attempt to obtain many locks at once and release
said locks when exiting.

View File

@@ -16,16 +16,15 @@
# License for the specific language governing permissions and limitations
# under the License.
import logging
import multiprocessing
import threading
import types
import six
from taskflow.utils import lock_utils
LOG = logging.getLogger(__name__)
if six.PY2:
from thread import get_ident # noqa
else:
# In python3+ the get_ident call moved (whhhy??)
from threading import get_ident # noqa
def get_optimal_thread_count():
@@ -37,20 +36,3 @@ def get_optimal_thread_count():
# just setup two threads since its hard to know what else we
# should do in this situation.
return 2
class ThreadSafeMeta(type):
"""Metaclass that adds locking to all pubic methods of a class."""
def __new__(cls, name, bases, attrs):
for attr_name, attr_value in six.iteritems(attrs):
if isinstance(attr_value, types.FunctionType):
if attr_name[0] != '_':
attrs[attr_name] = lock_utils.locked(attr_value)
return super(ThreadSafeMeta, cls).__new__(cls, name, bases, attrs)
def __call__(cls, *args, **kwargs):
instance = super(ThreadSafeMeta, cls).__call__(*args, **kwargs)
if not hasattr(instance, '_lock'):
instance._lock = threading.RLock()
return instance