From e9a319d7d317514a9b4d95e2ebc8247a3c41058d Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Mon, 27 Jan 2014 11:35:04 -0800 Subject: [PATCH] Use reader/writer locks in storage Switch to using a reader/writer lock scheme to protect against simultaneous storage mutations, typically seen when running in a multi-threaded mode. For the single-threaded mode provide a dummy reader/writer lock which will mimic the locking api but not actually lock anything. Closes-Bug: 1273146 Change-Id: I954f542d9ab34b693e8da71c9fc913f823e869ba --- taskflow/engines/action_engine/engine.py | 5 +- taskflow/storage.py | 298 ++++++++++--------- taskflow/tests/unit/test_storage.py | 95 +++++- taskflow/tests/unit/test_utils_lock_utils.py | 235 +++++++++++++++ taskflow/utils/lock_utils.py | 160 ++++++++++ taskflow/utils/threading_utils.py | 28 +- 6 files changed, 647 insertions(+), 174 deletions(-) create mode 100644 taskflow/tests/unit/test_utils_lock_utils.py diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index 73dc8a76..77cb7fd5 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -191,13 +191,12 @@ class ActionEngine(base.EngineBase): class SingleThreadedActionEngine(ActionEngine): """Engine that runs tasks in serial manner.""" - _storage_cls = t_storage.Storage + _storage_cls = t_storage.SingleThreadedStorage class MultiThreadedActionEngine(ActionEngine): """Engine that runs tasks in parallel manner.""" - - _storage_cls = t_storage.ThreadSafeStorage + _storage_cls = t_storage.MultiThreadedStorage def _task_executor_cls(self): return executor.ParallelTaskExecutor(self._executor) diff --git a/taskflow/storage.py b/taskflow/storage.py index 8e71ee6e..bea79072 100644 --- a/taskflow/storage.py +++ b/taskflow/storage.py @@ -16,6 +16,7 @@ # License for the specific language governing permissions and limitations # under the License. +import abc import contextlib import logging @@ -25,13 +26,15 @@ from taskflow import exceptions from taskflow.openstack.common import uuidutils from taskflow.persistence import logbook from taskflow import states +from taskflow.utils import lock_utils from taskflow.utils import misc -from taskflow.utils import threading_utils as tu +from taskflow.utils import reflection LOG = logging.getLogger(__name__) STATES_WITH_RESULTS = (states.SUCCESS, states.REVERTING, states.FAILURE) +@six.add_metaclass(abc.ABCMeta) class Storage(object): """Interface between engines and logbook. @@ -48,11 +51,14 @@ class Storage(object): self._reverse_mapping = {} self._backend = backend self._flowdetail = flow_detail + self._lock = self._lock_cls() # NOTE(imelnikov): failure serialization looses information, # so we cache failures here, in task name -> misc.Failure mapping. self._failures = {} - self._reload_failures() + for td in self._flowdetail: + if td.failure is not None: + self._failures[td.name] = td.failure self._task_name_to_uuid = dict((td.name, td.uuid) for td in self._flowdetail) @@ -66,11 +72,20 @@ class Storage(object): self._set_result_mapping(injector_td.name, dict((name, name) for name in names)) + @abc.abstractproperty + def _lock_cls(self): + """Lock class used to generate reader/writer locks for protecting + read/write access to the underlying storage backend and internally + mutating operations. 
+ """ + def _with_connection(self, functor, *args, **kwargs): # NOTE(harlowja): Activate the given function with a backend # connection, if a backend is provided in the first place, otherwise # don't call the function. if self._backend is None: + LOG.debug("No backend provided, not calling functor '%s'", + reflection.get_callable_name(functor)) return with contextlib.closing(self._backend.get_connection()) as conn: functor(conn, *args, **kwargs) @@ -85,12 +100,13 @@ class Storage(object): Returns uuid for the task details corresponding to the task with given name. """ - try: - task_id = self._task_name_to_uuid[task_name] - except KeyError: - task_id = uuidutils.generate_uuid() - self._add_task(task_id, task_name, task_version) - self._set_result_mapping(task_name, result_mapping) + with self._lock.write_lock(): + try: + task_id = self._task_name_to_uuid[task_name] + except KeyError: + task_id = uuidutils.generate_uuid() + self._add_task(task_id, task_name, task_version) + self._set_result_mapping(task_name, result_mapping) return task_id def _add_task(self, uuid, task_name, task_version=None): @@ -99,27 +115,29 @@ class Storage(object): Task becomes known to storage by that name and uuid. Task state is set to PENDING. """ + + def save_both(conn, td): + """Saves the flow and the task detail with the same connection.""" + self._save_flow_detail(conn) + self._save_task_detail(conn, td) + # TODO(imelnikov): check that task with same uuid or # task name does not exist. td = logbook.TaskDetail(name=task_name, uuid=uuid) td.state = states.PENDING td.version = task_version self._flowdetail.add(td) - - def save_both(conn): - """Saves the flow and the task detail with the same connection.""" - self._save_flow_detail(conn) - self._save_task_detail(conn, task_detail=td) - - self._with_connection(save_both) + self._with_connection(save_both, td) self._task_name_to_uuid[task_name] = uuid @property def flow_name(self): + # This never changes (so no read locking needed). return self._flowdetail.name @property def flow_uuid(self): + # This never changes (so no read locking needed). 
        return self._flowdetail.uuid

    def _save_flow_detail(self, conn):
@@ -130,13 +148,9 @@ class Storage(object):

    def _taskdetail_by_name(self, task_name):
        try:
-            td = self._flowdetail.find(self._task_name_to_uuid[task_name])
+            return self._flowdetail.find(self._task_name_to_uuid[task_name])
        except KeyError:
-            td = None
-
-        if td is None:
            raise exceptions.NotFound("Unknown task name: %s" % task_name)
-        return td

    def _save_task_detail(self, conn, task_detail):
        # NOTE(harlowja): we need to update our contained task detail if
@@ -146,34 +160,39 @@ class Storage(object):

    def get_task_uuid(self, task_name):
        """Get task uuid by given name."""
-        td = self._taskdetail_by_name(task_name)
-        return td.uuid
+        with self._lock.read_lock():
+            td = self._taskdetail_by_name(task_name)
+            return td.uuid

    def set_task_state(self, task_name, state):
        """Set task state."""
-        td = self._taskdetail_by_name(task_name)
-        td.state = state
-        self._with_connection(self._save_task_detail, task_detail=td)
+        with self._lock.write_lock():
+            td = self._taskdetail_by_name(task_name)
+            td.state = state
+            self._with_connection(self._save_task_detail, td)

    def get_task_state(self, task_name):
        """Get state of task with given name."""
-        return self._taskdetail_by_name(task_name).state
+        with self._lock.read_lock():
+            td = self._taskdetail_by_name(task_name)
+            return td.state

    def get_tasks_states(self, task_names):
-        return dict((name, self.get_task_state(name))
-                    for name in task_names)
+        """Gets the states of the given tasks."""
+        with self._lock.read_lock():
+            return dict((name, self.get_task_state(name))
+                        for name in task_names)

    def update_task_metadata(self, task_name, update_with):
+        """Updates a task's metadata."""
        if not update_with:
            return
-        # NOTE(harlowja): this is a read and then write, not in 1 transaction
-        # so it is entirely possible that we could write over another writes
-        # metadata update. Maybe add some merging logic later?
-        td = self._taskdetail_by_name(task_name)
-        if not td.meta:
-            td.meta = {}
-        td.meta.update(update_with)
-        self._with_connection(self._save_task_detail, task_detail=td)
+        with self._lock.write_lock():
+            td = self._taskdetail_by_name(task_name)
+            if not td.meta:
+                td.meta = {}
+            td.meta.update(update_with)
+            self._with_connection(self._save_task_detail, td)

    def set_task_progress(self, task_name, progress, details=None):
        """Set task progress.
@@ -204,10 +223,11 @@
        :param task_name: task name
        :returns: current task progress value
        """
-        meta = self._taskdetail_by_name(task_name).meta
-        if not meta:
-            return 0.0
-        return meta.get('progress', 0.0)
+        with self._lock.read_lock():
+            td = self._taskdetail_by_name(task_name)
+            if not td.meta:
+                return 0.0
+            return td.meta.get('progress', 0.0)

    def get_task_progress_details(self, task_name):
        """Get progress details of task with given name.
@@ -216,10 +236,11 @@
        :param task_name: task name
        :returns: None if progress_details not defined, else progress_details
                  dict
        """
-        meta = self._taskdetail_by_name(task_name).meta
-        if not meta:
-            return None
-        return meta.get('progress_details')
+        with self._lock.read_lock():
+            td = self._taskdetail_by_name(task_name)
+            if not td.meta:
+                return None
+            return td.meta.get('progress_details')

    def _check_all_results_provided(self, task_name, data):
        """Warn if the task did not provide some of its results.

        This may happen if the task returns a shorter tuple or list, or a
        dict without all the needed keys. It may also happen if the task
        returns a result of the wrong type.
""" - result_mapping = self._result_mappings.get(task_name, None) - if result_mapping is None: + result_mapping = self._result_mappings.get(task_name) + if not result_mapping: return for name, index in six.iteritems(result_mapping): try: misc.item_from(data, index, name=name) except exceptions.NotFound: LOG.warning("Task %s did not supply result " - "with index %r (name %s)", - task_name, index, name) + "with index %r (name %s)", task_name, index, name) def save(self, task_name, data, state=states.SUCCESS): """Put result for task with id 'uuid' to storage.""" - td = self._taskdetail_by_name(task_name) - td.state = state - if state == states.FAILURE and isinstance(data, misc.Failure): - td.results = None - td.failure = data - self._failures[td.name] = data - else: - td.results = data - td.failure = None - self._check_all_results_provided(td.name, data) - self._with_connection(self._save_task_detail, task_detail=td) - - def _cache_failure(self, name, fail): - """Ensure that cache has matching failure for task with this name. - - We leave cached version if it matches as it may contain more - information. Returns cached failure. - """ - cached = self._failures.get(name) - if fail.matches(cached): - return cached - self._failures[name] = fail - return fail - - def _reload_failures(self): - """Refresh failures cache.""" - for td in self._flowdetail: - if td.failure is not None: - self._cache_failure(td.name, td.failure) + with self._lock.write_lock(): + td = self._taskdetail_by_name(task_name) + td.state = state + if state == states.FAILURE and isinstance(data, misc.Failure): + td.results = None + td.failure = data + self._failures[td.name] = data + else: + td.results = data + td.failure = None + self._check_all_results_provided(td.name, data) + self._with_connection(self._save_task_detail, td) def get(self, task_name): """Get result for task with name 'task_name' to storage.""" - td = self._taskdetail_by_name(task_name) - if td.failure is not None: - return self._cache_failure(td.name, td.failure) - if td.state not in STATES_WITH_RESULTS: - raise exceptions.NotFound( - "Result for task %s is not known" % task_name) - return td.results + with self._lock.read_lock(): + td = self._taskdetail_by_name(task_name) + if td.failure is not None: + cached = self._failures.get(task_name) + if td.failure.matches(cached): + return cached + return td.failure + if td.state not in STATES_WITH_RESULTS: + raise exceptions.NotFound("Result for task %s is not known" + % task_name) + return td.results def get_failures(self): """Get list of failures that happened with this flow. No order guaranteed. """ - return self._failures.copy() + with self._lock.read_lock(): + return self._failures.copy() def has_failures(self): """Returns True if there are failed tasks in the storage.""" - return bool(self._failures) + with self._lock.read_lock(): + return bool(self._failures) def _reset_task(self, td, state): if td.name == self.injector_name: @@ -305,25 +314,28 @@ class Storage(object): def reset(self, task_name, state=states.PENDING): """Remove result for task with id 'uuid' from storage.""" - td = self._taskdetail_by_name(task_name) - if self._reset_task(td, state): - self._with_connection(self._save_task_detail, task_detail=td) + with self._lock.write_lock(): + td = self._taskdetail_by_name(task_name) + if self._reset_task(td, state): + self._with_connection(self._save_task_detail, td) def reset_tasks(self): """Reset all tasks to PENDING state, removing results. 
Returns list of (name, uuid) tuples for all tasks that were reset. """ - result = [] + reset_results = [] def do_reset_all(connection): for td in self._flowdetail: if self._reset_task(td, states.PENDING): self._save_task_detail(connection, td) - result.append((td.name, td.uuid)) + reset_results.append((td.name, td.uuid)) - self._with_connection(do_reset_all) - return result + with self._lock.write_lock(): + self._with_connection(do_reset_all) + + return reset_results def inject(self, pairs): """Add values into storage. @@ -331,20 +343,20 @@ class Storage(object): This method should be used to put flow parameters (requirements that are not satisfied by any task in the flow) into storage. """ - try: - injector_td = self._taskdetail_by_name(self.injector_name) - except exceptions.NotFound: - injector_uuid = uuidutils.generate_uuid() - self._add_task(injector_uuid, self.injector_name) - results = dict(pairs) - else: - results = injector_td.results.copy() - results.update(pairs) - - self.save(self.injector_name, results) - names = six.iterkeys(results) - self._set_result_mapping(self.injector_name, - dict((name, name) for name in names)) + with self._lock.write_lock(): + try: + td = self._taskdetail_by_name(self.injector_name) + except exceptions.NotFound: + self._add_task(uuidutils.generate_uuid(), self.injector_name) + td = self._taskdetail_by_name(self.injector_name) + td.results = dict(pairs) + td.state = states.SUCCESS + else: + td.results.update(pairs) + self._with_connection(self._save_task_detail, td) + names = six.iterkeys(td.results) + self._set_result_mapping(self.injector_name, + dict((name, name) for name in names)) def _set_result_mapping(self, task_name, mapping): """Set mapping for naming task results. @@ -378,50 +390,60 @@ class Storage(object): def fetch(self, name): """Fetch named task result.""" - try: - indexes = self._reverse_mapping[name] - except KeyError: - raise exceptions.NotFound("Name %r is not mapped" % name) - # Return the first one that is found. - for task_name, index in reversed(indexes): + with self._lock.read_lock(): try: - result = self.get(task_name) - return misc.item_from(result, index, name=name) - except exceptions.NotFound: - pass - raise exceptions.NotFound("Unable to find result %r" % name) + indexes = self._reverse_mapping[name] + except KeyError: + raise exceptions.NotFound("Name %r is not mapped" % name) + # Return the first one that is found. + for (task_name, index) in reversed(indexes): + try: + result = self.get(task_name) + return misc.item_from(result, index, name=name) + except exceptions.NotFound: + pass + raise exceptions.NotFound("Unable to find result %r" % name) def fetch_all(self): """Fetch all named task results known so far. Should be used for debugging and testing purposes mostly. 
""" - result = {} - for name in self._reverse_mapping: - try: - result[name] = self.fetch(name) - except exceptions.NotFound: - pass - return result + with self._lock.read_lock(): + results = {} + for name in self._reverse_mapping: + try: + results[name] = self.fetch(name) + except exceptions.NotFound: + pass + return results def fetch_mapped_args(self, args_mapping): """Fetch arguments for the task using arguments mapping.""" - return dict((key, self.fetch(name)) - for key, name in six.iteritems(args_mapping)) + with self._lock.read_lock(): + return dict((key, self.fetch(name)) + for key, name in six.iteritems(args_mapping)) def set_flow_state(self, state): - """Set flowdetails state and save it.""" - self._flowdetail.state = state - self._with_connection(self._save_flow_detail) + """Set flow details state and save it.""" + with self._lock.write_lock(): + self._flowdetail.state = state + self._with_connection(self._save_flow_detail) def get_flow_state(self): - """Set state from flowdetails.""" - state = self._flowdetail.state - if state is None: - state = states.PENDING - return state + """Get state from flow details.""" + with self._lock.read_lock(): + state = self._flowdetail.state + if state is None: + state = states.PENDING + return state -@six.add_metaclass(tu.ThreadSafeMeta) -class ThreadSafeStorage(Storage): - pass +class MultiThreadedStorage(Storage): + """Storage that uses locks to protect against concurrent access.""" + _lock_cls = lock_utils.ReaderWriterLock + + +class SingleThreadedStorage(Storage): + """Storage that uses dummy locks when you really don't need locks.""" + _lock_cls = lock_utils.DummyReaderWriterLock diff --git a/taskflow/tests/unit/test_storage.py b/taskflow/tests/unit/test_storage.py index db814b96..c3494f24 100644 --- a/taskflow/tests/unit/test_storage.py +++ b/taskflow/tests/unit/test_storage.py @@ -17,6 +17,7 @@ # under the License. 
import contextlib +import threading import mock @@ -35,10 +36,22 @@ class StorageTest(test.TestCase): def setUp(self): super(StorageTest, self).setUp() self.backend = impl_memory.MemoryBackend(conf={}) + self.thread_count = 50 - def _get_storage(self): + def _run_many_threads(self, threads): + for t in threads: + t.start() + for t in threads: + t.join() + + def _get_storage(self, threaded=False): _lb, flow_detail = p_utils.temporary_flow_detail(self.backend) - return storage.Storage(backend=self.backend, flow_detail=flow_detail) + if threaded: + return storage.MultiThreadedStorage(backend=self.backend, + flow_detail=flow_detail) + else: + return storage.SingleThreadedStorage(backend=self.backend, + flow_detail=flow_detail) def tearDown(self): super(StorageTest, self).tearDown() @@ -48,14 +61,14 @@ class StorageTest(test.TestCase): def test_non_saving_storage(self): _lb, flow_detail = p_utils.temporary_flow_detail(self.backend) - s = storage.Storage(flow_detail=flow_detail) # no backend + s = storage.SingleThreadedStorage(flow_detail=flow_detail) s.ensure_task('my_task') self.assertTrue( uuidutils.is_uuid_like(s.get_task_uuid('my_task'))) def test_flow_name_and_uuid(self): fd = logbook.FlowDetail(name='test-fd', uuid='aaaa') - s = storage.Storage(flow_detail=fd) + s = storage.SingleThreadedStorage(flow_detail=fd) self.assertEqual(s.flow_name, 'test-fd') self.assertEqual(s.flow_uuid, 'aaaa') @@ -79,7 +92,8 @@ class StorageTest(test.TestCase): def test_ensure_task_fd(self): _lb, flow_detail = p_utils.temporary_flow_detail(self.backend) - s = storage.Storage(backend=self.backend, flow_detail=flow_detail) + s = storage.SingleThreadedStorage(backend=self.backend, + flow_detail=flow_detail) s.ensure_task('my task', '3.11') td = flow_detail.find(s.get_task_uuid('my task')) self.assertIsNotNone(td) @@ -92,7 +106,8 @@ class StorageTest(test.TestCase): td = logbook.TaskDetail(name='my_task', uuid='42') flow_detail.add(td) - s = storage.Storage(backend=self.backend, flow_detail=flow_detail) + s = storage.SingleThreadedStorage(backend=self.backend, + flow_detail=flow_detail) self.assertEqual('42', s.get_task_uuid('my_task')) def test_ensure_existing_task(self): @@ -100,7 +115,8 @@ class StorageTest(test.TestCase): td = logbook.TaskDetail(name='my_task', uuid='42') flow_detail.add(td) - s = storage.Storage(backend=self.backend, flow_detail=flow_detail) + s = storage.SingleThreadedStorage(backend=self.backend, + flow_detail=flow_detail) s.ensure_task('my_task') self.assertEqual('42', s.get_task_uuid('my_task')) @@ -147,7 +163,8 @@ class StorageTest(test.TestCase): s.ensure_task('my task') s.save('my task', fail, states.FAILURE) - s2 = storage.Storage(backend=self.backend, flow_detail=s._flowdetail) + s2 = storage.SingleThreadedStorage(backend=self.backend, + flow_detail=s._flowdetail) self.assertIs(s2.has_failures(), True) self.assertEqual(s2.get_failures(), {'my task': fail}) self.assertEqual(s2.get('my task'), fail) @@ -305,13 +322,71 @@ class StorageTest(test.TestCase): }) # imagine we are resuming, so we need to make new # storage from same flow details - s2 = storage.Storage(s._flowdetail, backend=self.backend) + s2 = storage.SingleThreadedStorage(s._flowdetail, backend=self.backend) # injected data should still be there: self.assertEqual(s2.fetch_all(), { 'foo': 'bar', 'spam': 'eggs', }) + def test_many_thread_ensure_same_task(self): + s = self._get_storage(threaded=True) + + def ensure_my_task(): + s.ensure_task('my_task', result_mapping={}) + + threads = [] + for i in range(0, self.thread_count): 
+            threads.append(threading.Thread(target=ensure_my_task))
+        self._run_many_threads(threads)
+
+        # Only one task should have been made, no more.
+        self.assertEqual(1, len(s._flowdetail))
+
+    def test_many_thread_one_reset(self):
+        s = self._get_storage(threaded=True)
+        s.ensure_task('a')
+        s.set_task_state('a', states.SUSPENDED)
+        s.ensure_task('b')
+        s.set_task_state('b', states.SUSPENDED)
+
+        results = []
+        result_lock = threading.Lock()
+
+        def reset_all():
+            r = s.reset_tasks()
+            with result_lock:
+                results.append(r)
+
+        threads = []
+        for i in range(0, self.thread_count):
+            threads.append(threading.Thread(target=reset_all))
+
+        self._run_many_threads(threads)
+
+        # Only one thread should have actually performed the reset; the
+        # rest should have found nothing left to reset.
+        results = [r for r in results if len(r)]
+        self.assertEqual(1, len(results))
+        self.assertEqual(['a', 'b'], sorted([a[0] for a in results[0]]))
+
+    def test_many_thread_inject(self):
+        s = self._get_storage(threaded=True)
+
+        def inject_values(values):
+            s.inject(values)
+
+        threads = []
+        for i in range(0, self.thread_count):
+            values = {
+                str(i): str(i),
+            }
+            threads.append(threading.Thread(target=inject_values,
+                                            args=[values]))
+
+        self._run_many_threads(threads)
+        self.assertEqual(self.thread_count, len(s.fetch_all()))
+        self.assertEqual(1, len(s._flowdetail))
+
     def test_fetch_meapped_args(self):
         s = self._get_storage()
         s.inject({'foo': 'bar', 'spam': 'eggs'})
@@ -357,7 +432,7 @@ class StorageTest(test.TestCase):
         fd.state = states.FAILURE
         with contextlib.closing(self.backend.get_connection()) as conn:
             fd.update(conn.update_flow_details(fd))
-        s = storage.Storage(flow_detail=fd, backend=self.backend)
+        s = storage.SingleThreadedStorage(flow_detail=fd, backend=self.backend)
         self.assertEqual(s.get_flow_state(), states.FAILURE)

     def test_set_and_get_flow_state(self):
diff --git a/taskflow/tests/unit/test_utils_lock_utils.py b/taskflow/tests/unit/test_utils_lock_utils.py
new file mode 100644
index 00000000..6846b13a
--- /dev/null
+++ b/taskflow/tests/unit/test_utils_lock_utils.py
@@ -0,0 +1,235 @@
+# -*- coding: utf-8 -*-
+
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
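+
+# NOTE: the tests below exercise the new ReaderWriterLock: shared reads,
+# exclusive writes, basic reentrancy, cleanup when the code run under a
+# lock raises, and (via thread pools) that reader and writer critical
+# sections never overlap in time.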
+
+import collections
+import time
+
+from concurrent import futures
+
+from taskflow import test
+from taskflow.utils import lock_utils
+
+
+def _find_overlaps(times, start, end):
+    overlaps = 0
+    for (s, e) in times:
+        if s >= start and e <= end:
+            overlaps += 1
+    return overlaps
+
+
+def _spawn_variation(readers, writers, max_workers=None):
+    start_stops = collections.deque()
+    lock = lock_utils.ReaderWriterLock()
+
+    def read_func():
+        with lock.read_lock():
+            start_stops.append(('r', time.time(), time.time()))
+
+    def write_func():
+        with lock.write_lock():
+            start_stops.append(('w', time.time(), time.time()))
+
+    if max_workers is None:
+        max_workers = max(0, readers) + max(0, writers)
+    if max_workers > 0:
+        with futures.ThreadPoolExecutor(max_workers=max_workers) as e:
+            for i in range(0, readers):
+                e.submit(read_func)
+            for i in range(0, writers):
+                e.submit(write_func)
+
+    writer_times = []
+    reader_times = []
+    for (t, start, stop) in list(start_stops):
+        if t == 'w':
+            writer_times.append((start, stop))
+        else:
+            reader_times.append((start, stop))
+    return (writer_times, reader_times)
+
+
+class ReadWriteLockTest(test.TestCase):
+    def test_writer_abort(self):
+        lock = lock_utils.ReaderWriterLock()
+        self.assertFalse(lock.owner)
+
+        def blow_up():
+            with lock.write_lock():
+                self.assertEqual(lock.WRITER, lock.owner)
+                raise RuntimeError("Broken")
+
+        self.assertRaises(RuntimeError, blow_up)
+        self.assertFalse(lock.owner)
+
+    def test_reader_abort(self):
+        lock = lock_utils.ReaderWriterLock()
+        self.assertFalse(lock.owner)
+
+        def blow_up():
+            with lock.read_lock():
+                self.assertEqual(lock.READER, lock.owner)
+                raise RuntimeError("Broken")
+
+        self.assertRaises(RuntimeError, blow_up)
+        self.assertFalse(lock.owner)
+
+    def test_reader_chaotic(self):
+        lock = lock_utils.ReaderWriterLock()
+        activated = collections.deque()
+
+        def chaotic_reader(blow_up):
+            with lock.read_lock():
+                if blow_up:
+                    raise RuntimeError("Broken")
+                else:
+                    activated.append(lock.owner)
+
+        def happy_writer():
+            with lock.write_lock():
+                activated.append(lock.owner)
+
+        with futures.ThreadPoolExecutor(max_workers=20) as e:
+            for i in range(0, 20):
+                if i % 2 == 0:
+                    e.submit(chaotic_reader, blow_up=bool(i % 4 == 0))
+                else:
+                    e.submit(happy_writer)
+
+        writers = [a for a in activated if a == 'w']
+        readers = [a for a in activated if a == 'r']
+        self.assertEqual(10, len(writers))
+        self.assertEqual(5, len(readers))
+
+    def test_writer_chaotic(self):
+        lock = lock_utils.ReaderWriterLock()
+        activated = collections.deque()
+
+        def chaotic_writer(blow_up):
+            with lock.write_lock():
+                if blow_up:
+                    raise RuntimeError("Broken")
+                else:
+                    activated.append(lock.owner)
+
+        def happy_reader():
+            with lock.read_lock():
+                activated.append(lock.owner)
+
+        with futures.ThreadPoolExecutor(max_workers=20) as e:
+            for i in range(0, 20):
+                if i % 2 == 0:
+                    e.submit(chaotic_writer, blow_up=bool(i % 4 == 0))
+                else:
+                    e.submit(happy_reader)
+
+        writers = [a for a in activated if a == 'w']
+        readers = [a for a in activated if a == 'r']
+        self.assertEqual(5, len(writers))
+        self.assertEqual(10, len(readers))
+
+    def test_single_reader_writer(self):
+        results = []
+        lock = lock_utils.ReaderWriterLock()
+        with lock.read_lock():
+            self.assertTrue(lock.is_reader())
+            self.assertEqual(0, len(results))
+        with lock.write_lock():
+            results.append(1)
+            self.assertTrue(lock.is_writer())
+        with lock.read_lock():
+            self.assertTrue(lock.is_reader())
+            self.assertEqual(1, len(results))
+        self.assertFalse(lock.is_reader())
self.assertFalse(lock.is_writer()) + + def test_reader_to_writer(self): + lock = lock_utils.ReaderWriterLock() + + def writer_func(): + with lock.write_lock(): + pass + + with lock.read_lock(): + self.assertRaises(RuntimeError, writer_func) + self.assertFalse(lock.is_writer()) + + self.assertFalse(lock.is_reader()) + self.assertFalse(lock.is_writer()) + + def test_writer_to_reader(self): + lock = lock_utils.ReaderWriterLock() + + def reader_func(): + with lock.read_lock(): + pass + + with lock.write_lock(): + self.assertRaises(RuntimeError, reader_func) + self.assertFalse(lock.is_reader()) + + self.assertFalse(lock.is_reader()) + self.assertFalse(lock.is_writer()) + + def test_double_writer(self): + lock = lock_utils.ReaderWriterLock() + with lock.write_lock(): + self.assertFalse(lock.is_reader()) + self.assertTrue(lock.is_writer()) + with lock.write_lock(): + self.assertTrue(lock.is_writer()) + self.assertTrue(lock.is_writer()) + + self.assertFalse(lock.is_reader()) + self.assertFalse(lock.is_writer()) + + def test_double_reader(self): + lock = lock_utils.ReaderWriterLock() + with lock.read_lock(): + self.assertTrue(lock.is_reader()) + self.assertFalse(lock.is_writer()) + with lock.read_lock(): + self.assertTrue(lock.is_reader()) + self.assertTrue(lock.is_reader()) + + self.assertFalse(lock.is_reader()) + self.assertFalse(lock.is_writer()) + + def test_multi_reader_multi_writer(self): + writer_times, reader_times = _spawn_variation(10, 10) + self.assertEqual(10, len(writer_times)) + self.assertEqual(10, len(reader_times)) + for (start, stop) in writer_times: + self.assertEqual(0, _find_overlaps(reader_times, start, stop)) + self.assertEqual(1, _find_overlaps(writer_times, start, stop)) + for (start, stop) in reader_times: + self.assertEqual(0, _find_overlaps(writer_times, start, stop)) + + def test_multi_reader_single_writer(self): + writer_times, reader_times = _spawn_variation(9, 1) + self.assertEqual(1, len(writer_times)) + self.assertEqual(9, len(reader_times)) + start, stop = writer_times[0] + self.assertEqual(0, _find_overlaps(reader_times, start, stop)) + + def test_multi_writer(self): + writer_times, reader_times = _spawn_variation(0, 10) + self.assertEqual(10, len(writer_times)) + self.assertEqual(0, len(reader_times)) + for (start, stop) in writer_times: + self.assertEqual(1, _find_overlaps(writer_times, start, stop)) diff --git a/taskflow/utils/lock_utils.py b/taskflow/utils/lock_utils.py index 7f227983..dc987df9 100644 --- a/taskflow/utils/lock_utils.py +++ b/taskflow/utils/lock_utils.py @@ -21,6 +21,8 @@ # pulls in oslo.cfg) and is reduced to only what taskflow currently wants to # use from that code. +import collections +import contextlib import errno import logging import os @@ -28,6 +30,7 @@ import threading import time from taskflow.utils import misc +from taskflow.utils import threading_utils as tu LOG = logging.getLogger(__name__) @@ -62,6 +65,163 @@ def locked(*args, **kwargs): return decorator +class ReaderWriterLock(object): + """A reader/writer lock. + + This lock allows for simultaneous readers to exist but only one writer + to exist for use-cases where it is useful to have such types of locks. + + Currently a reader can not escalate its read lock to a write lock and + a writer can not acquire a read lock while it owns or is waiting on + the write lock. + + In the future these restrictions may be relaxed. 
+    """
+    WRITER = 'w'
+    READER = 'r'
+
+    def __init__(self):
+        self._writer = None
+        self._pending_writers = collections.deque()
+        self._readers = collections.deque()
+        self._cond = threading.Condition()
+
+    def is_writer(self, check_pending=True):
+        """Returns if the caller is the active writer or a pending writer."""
+        self._cond.acquire()
+        try:
+            me = tu.get_ident()
+            if self._writer is not None and self._writer == me:
+                return True
+            if check_pending:
+                return me in self._pending_writers
+            else:
+                return False
+        finally:
+            self._cond.release()
+
+    @property
+    def owner(self):
+        """Returns the lock's owner type: WRITER, READER or None."""
+        self._cond.acquire()
+        try:
+            if self._writer is not None:
+                return self.WRITER
+            if self._readers:
+                return self.READER
+            return None
+        finally:
+            self._cond.release()
+
+    def is_reader(self):
+        """Returns if the caller is one of the readers."""
+        self._cond.acquire()
+        try:
+            return tu.get_ident() in self._readers
+        finally:
+            self._cond.release()
+
+    @contextlib.contextmanager
+    def read_lock(self):
+        """Grants a read lock.
+
+        Will wait until there are no active or pending writers.
+
+        Raises a RuntimeError if an active or pending writer tries to acquire
+        a read lock.
+        """
+        me = tu.get_ident()
+        if self.is_writer():
+            raise RuntimeError("Writer %s can not acquire a read lock"
+                               " while holding/waiting for the write lock"
+                               % me)
+        self._cond.acquire()
+        try:
+            while True:
+                # No active or pending writers; we are good to become a reader.
+                if self._writer is None and len(self._pending_writers) == 0:
+                    self._readers.append(me)
+                    break
+                # Some writers are active or pending; we have to wait.
+                self._cond.wait()
+        finally:
+            self._cond.release()
+        try:
+            yield self
+        finally:
+            # I am no longer a reader, remove *one* occurrence of myself.
+            # If the current thread acquired two read locks, then it will
+            # still have to remove that other read lock; this allows for
+            # basic reentrancy to be possible.
+            self._cond.acquire()
+            try:
+                self._readers.remove(me)
+                self._cond.notify_all()
+            finally:
+                self._cond.release()
+
+    @contextlib.contextmanager
+    def write_lock(self):
+        """Grants a write lock.
+
+        Will wait until there are no active readers; once acquired, new
+        readers will be blocked.
+
+        Raises a RuntimeError if an active reader attempts to acquire the
+        write lock.
+        """
+        me = tu.get_ident()
+        if self.is_reader():
+            raise RuntimeError("Reader %s to writer privilege"
+                               " escalation not allowed" % me)
+        if self.is_writer(check_pending=False):
+            # Already the writer; this allows for basic reentrancy.
+            yield self
+        else:
+            self._cond.acquire()
+            try:
+                self._pending_writers.append(me)
+                while True:
+                    # No readers and no active writer; check if we are next
+                    # in line to become the writer.
+                    if len(self._readers) == 0 and self._writer is None:
+                        if self._pending_writers[0] == me:
+                            self._writer = self._pending_writers.popleft()
+                            break
+                    self._cond.wait()
+            finally:
+                self._cond.release()
+            try:
+                yield self
+            finally:
+                self._cond.acquire()
+                try:
+                    self._writer = None
+                    self._cond.notify_all()
+                finally:
+                    self._cond.release()
+
+
+class DummyReaderWriterLock(object):
+    """A dummy reader/writer lock that doesn't actually lock anything but
+    provides the same API as a real reader/writer lock.
+    """
+    @contextlib.contextmanager
+    def write_lock(self):
+        yield self
+
+    @contextlib.contextmanager
+    def read_lock(self):
+        yield self
+
+    @property
+    def owner(self):
+        return None
+
+    def is_reader(self):
+        return False
+
+    def is_writer(self):
+        return False
+
+
 class MultiLock(object):
     """A class which can attempt to obtain many locks at once and release
     said locks when exiting.
diff --git a/taskflow/utils/threading_utils.py b/taskflow/utils/threading_utils.py
index b47fb701..0542633d 100644
--- a/taskflow/utils/threading_utils.py
+++ b/taskflow/utils/threading_utils.py
@@ -16,16 +16,15 @@
 # License for the specific language governing permissions and limitations
 # under the License.

-import logging
 import multiprocessing
-import threading
-import types

 import six

-from taskflow.utils import lock_utils
-
-LOG = logging.getLogger(__name__)
+if six.PY2:
+    from thread import get_ident  # noqa
+else:
+    # In Python 3 the get_ident function moved into the threading module.
+    from threading import get_ident  # noqa


 def get_optimal_thread_count():
@@ -37,20 +36,3 @@ def get_optimal_thread_count():
     # just setup two threads since its hard to know what else we
     # should do in this situation.
     return 2
-
-
-class ThreadSafeMeta(type):
-    """Metaclass that adds locking to all pubic methods of a class."""
-
-    def __new__(cls, name, bases, attrs):
-        for attr_name, attr_value in six.iteritems(attrs):
-            if isinstance(attr_value, types.FunctionType):
-                if attr_name[0] != '_':
-                    attrs[attr_name] = lock_utils.locked(attr_value)
-        return super(ThreadSafeMeta, cls).__new__(cls, name, bases, attrs)
-
-    def __call__(cls, *args, **kwargs):
-        instance = super(ThreadSafeMeta, cls).__call__(*args, **kwargs)
-        if not hasattr(instance, '_lock'):
-            instance._lock = threading.RLock()
-        return instance
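
For reference, a short usage sketch of the reader/writer lock and storage
classes introduced above; 'backend' and 'flow_detail' are placeholders for
whatever persistence objects the caller already has:

    from taskflow import storage
    from taskflow.utils import lock_utils

    lock = lock_utils.ReaderWriterLock()

    # Any number of readers may hold the lock at the same time.
    with lock.read_lock():
        assert lock.is_reader()

    # A writer waits for all readers (and earlier pending writers) to
    # finish and then runs exclusively; nested write_lock() calls are
    # allowed (basic reentrancy), but calling read_lock() here would raise
    # RuntimeError since escalation/de-escalation is not supported.
    with lock.write_lock():
        assert lock.is_writer()

    # Engines now select the matching storage class: the multi-threaded
    # engine uses a real ReaderWriterLock, while the single-threaded engine
    # uses the dummy lock that provides the same API without any locking.
    store = storage.MultiThreadedStorage(backend=backend,
                                         flow_detail=flow_detail)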