Merge "Use reader/writer locks in storage"

Jenkins, 2014-02-05 17:58:38 +00:00 (committed by Gerrit Code Review)
6 changed files with 647 additions and 174 deletions

taskflow/engines/action_engine/engine.py

@@ -191,13 +191,12 @@ class ActionEngine(base.EngineBase):

 class SingleThreadedActionEngine(ActionEngine):
     """Engine that runs tasks in serial manner."""
-    _storage_cls = t_storage.Storage
+    _storage_cls = t_storage.SingleThreadedStorage


 class MultiThreadedActionEngine(ActionEngine):
     """Engine that runs tasks in parallel manner."""
-
-    _storage_cls = t_storage.ThreadSafeStorage
+    _storage_cls = t_storage.MultiThreadedStorage

     def _task_executor_cls(self):
         return executor.ParallelTaskExecutor(self._executor)
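The hunk above relies on simple delegation: the base engine instantiates whatever class its subclass names in _storage_cls, so the serial engine gets lock-free storage and the parallel engine gets the lock-protected variant. A minimal sketch of that pattern, with illustrative names only (taskflow's real ActionEngine carries much more state than this):

    # Illustrative sketch, not taskflow's actual engine code.
    from taskflow import storage as t_storage

    class BaseEngine(object):
        _storage_cls = None  # each subclass picks a storage implementation

        def __init__(self, flow_detail, backend=None):
            # Defer to whatever storage class the subclass selected.
            self.storage = self._storage_cls(flow_detail, backend=backend)

    class SerialEngine(BaseEngine):
        _storage_cls = t_storage.SingleThreadedStorage

    class ParallelEngine(BaseEngine):
        _storage_cls = t_storage.MultiThreadedStorage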

taskflow/storage.py

@@ -16,6 +16,7 @@
 # License for the specific language governing permissions and limitations
 # under the License.

+import abc
 import contextlib
 import logging

@@ -25,13 +26,15 @@ from taskflow import exceptions
 from taskflow.openstack.common import uuidutils
 from taskflow.persistence import logbook
 from taskflow import states
+from taskflow.utils import lock_utils
 from taskflow.utils import misc
-from taskflow.utils import threading_utils as tu
+from taskflow.utils import reflection

 LOG = logging.getLogger(__name__)

 STATES_WITH_RESULTS = (states.SUCCESS, states.REVERTING, states.FAILURE)


+@six.add_metaclass(abc.ABCMeta)
 class Storage(object):
     """Interface between engines and logbook.

@@ -48,11 +51,14 @@ class Storage(object):
         self._reverse_mapping = {}
         self._backend = backend
         self._flowdetail = flow_detail
+        self._lock = self._lock_cls()

         # NOTE(imelnikov): failure serialization loses information,
         # so we cache failures here, in task name -> misc.Failure mapping.
         self._failures = {}
-        self._reload_failures()
+        for td in self._flowdetail:
+            if td.failure is not None:
+                self._failures[td.name] = td.failure

         self._task_name_to_uuid = dict((td.name, td.uuid)
                                        for td in self._flowdetail)
@@ -66,11 +72,20 @@ class Storage(object):
             self._set_result_mapping(injector_td.name,
                                      dict((name, name) for name in names))

+    @abc.abstractproperty
+    def _lock_cls(self):
+        """Lock class used to generate reader/writer locks that protect
+        read/write access to the underlying storage backend and internally
+        mutating operations.
+        """
+
     def _with_connection(self, functor, *args, **kwargs):
         # NOTE(harlowja): Activate the given function with a backend
         # connection, if a backend is provided in the first place, otherwise
         # don't call the function.
         if self._backend is None:
+            LOG.debug("No backend provided, not calling functor '%s'",
+                      reflection.get_callable_name(functor))
             return
         with contextlib.closing(self._backend.get_connection()) as conn:
             functor(conn, *args, **kwargs)

@@ -85,12 +100,13 @@ class Storage(object):
         Returns uuid for the task details corresponding to the task with
         given name.
         """
-        try:
-            task_id = self._task_name_to_uuid[task_name]
-        except KeyError:
-            task_id = uuidutils.generate_uuid()
-            self._add_task(task_id, task_name, task_version)
-        self._set_result_mapping(task_name, result_mapping)
+        with self._lock.write_lock():
+            try:
+                task_id = self._task_name_to_uuid[task_name]
+            except KeyError:
+                task_id = uuidutils.generate_uuid()
+                self._add_task(task_id, task_name, task_version)
+            self._set_result_mapping(task_name, result_mapping)
         return task_id

     def _add_task(self, uuid, task_name, task_version=None):
@@ -99,27 +115,29 @@ class Storage(object):
         Task becomes known to storage by that name and uuid.
         Task state is set to PENDING.
         """
-        def save_both(conn, td):
-            """Saves the flow and the task detail with the same connection."""
-            self._save_flow_detail(conn)
-            self._save_task_detail(conn, td)
-
         # TODO(imelnikov): check that task with same uuid or
         # task name does not exist.
         td = logbook.TaskDetail(name=task_name, uuid=uuid)
         td.state = states.PENDING
         td.version = task_version
         self._flowdetail.add(td)
-        self._with_connection(save_both, td)
+
+        def save_both(conn):
+            """Saves the flow and the task detail with the same connection."""
+            self._save_flow_detail(conn)
+            self._save_task_detail(conn, task_detail=td)
+
+        self._with_connection(save_both)
         self._task_name_to_uuid[task_name] = uuid

     @property
     def flow_name(self):
+        # This never changes (so no read locking needed).
         return self._flowdetail.name

     @property
     def flow_uuid(self):
+        # This never changes (so no read locking needed).
         return self._flowdetail.uuid

     def _save_flow_detail(self, conn):

@@ -130,13 +148,9 @@ class Storage(object):
     def _taskdetail_by_name(self, task_name):
         try:
-            td = self._flowdetail.find(self._task_name_to_uuid[task_name])
+            return self._flowdetail.find(self._task_name_to_uuid[task_name])
         except KeyError:
-            td = None
-        if td is None:
             raise exceptions.NotFound("Unknown task name: %s" % task_name)
-        return td

     def _save_task_detail(self, conn, task_detail):
         # NOTE(harlowja): we need to update our contained task detail if
@@ -146,34 +160,39 @@ class Storage(object):
     def get_task_uuid(self, task_name):
         """Get task uuid by given name."""
-        td = self._taskdetail_by_name(task_name)
-        return td.uuid
+        with self._lock.read_lock():
+            td = self._taskdetail_by_name(task_name)
+            return td.uuid

     def set_task_state(self, task_name, state):
         """Set task state."""
-        td = self._taskdetail_by_name(task_name)
-        td.state = state
-        self._with_connection(self._save_task_detail, task_detail=td)
+        with self._lock.write_lock():
+            td = self._taskdetail_by_name(task_name)
+            td.state = state
+            self._with_connection(self._save_task_detail, td)

     def get_task_state(self, task_name):
         """Get state of task with given name."""
-        return self._taskdetail_by_name(task_name).state
+        with self._lock.read_lock():
+            td = self._taskdetail_by_name(task_name)
+            return td.state

     def get_tasks_states(self, task_names):
-        return dict((name, self.get_task_state(name))
-                    for name in task_names)
+        """Gets all task states."""
+        with self._lock.read_lock():
+            return dict((name, self.get_task_state(name))
+                        for name in task_names)

     def update_task_metadata(self, task_name, update_with):
+        """Updates a task's metadata."""
         if not update_with:
             return
-        # NOTE(harlowja): this is a read and then write, not in 1 transaction
-        # so it is entirely possible that we could write over another writes
-        # metadata update. Maybe add some merging logic later?
-        td = self._taskdetail_by_name(task_name)
-        if not td.meta:
-            td.meta = {}
-        td.meta.update(update_with)
-        self._with_connection(self._save_task_detail, task_detail=td)
+        with self._lock.write_lock():
+            td = self._taskdetail_by_name(task_name)
+            if not td.meta:
+                td.meta = {}
+            td.meta.update(update_with)
+            self._with_connection(self._save_task_detail, td)

     def set_task_progress(self, task_name, progress, details=None):
         """Set task progress.
@@ -204,10 +223,11 @@ class Storage(object):
         :param task_name: task name
         :returns: current task progress value
         """
-        meta = self._taskdetail_by_name(task_name).meta
-        if not meta:
-            return 0.0
-        return meta.get('progress', 0.0)
+        with self._lock.read_lock():
+            td = self._taskdetail_by_name(task_name)
+            if not td.meta:
+                return 0.0
+            return td.meta.get('progress', 0.0)

     def get_task_progress_details(self, task_name):
         """Get progress details of task with given name.

@@ -216,10 +236,11 @@ class Storage(object):
         :returns: None if progress_details not defined, else progress_details
             dict
         """
-        meta = self._taskdetail_by_name(task_name).meta
-        if not meta:
-            return None
-        return meta.get('progress_details')
+        with self._lock.read_lock():
+            td = self._taskdetail_by_name(task_name)
+            if not td.meta:
+                return None
+            return td.meta.get('progress_details')

     def _check_all_results_provided(self, task_name, data):
         """Warn if task did not provide some of results.
@@ -228,69 +249,57 @@ class Storage(object):
         without all needed keys. It may also happen if task returns
         result of wrong type.
         """
-        result_mapping = self._result_mappings.get(task_name, None)
-        if result_mapping is None:
+        result_mapping = self._result_mappings.get(task_name)
+        if not result_mapping:
             return
         for name, index in six.iteritems(result_mapping):
             try:
                 misc.item_from(data, index, name=name)
             except exceptions.NotFound:
                 LOG.warning("Task %s did not supply result "
-                            "with index %r (name %s)",
-                            task_name, index, name)
+                            "with index %r (name %s)", task_name, index, name)

     def save(self, task_name, data, state=states.SUCCESS):
         """Put result for task with id 'uuid' to storage."""
-        td = self._taskdetail_by_name(task_name)
-        td.state = state
-        if state == states.FAILURE and isinstance(data, misc.Failure):
-            td.results = None
-            td.failure = data
-            self._failures[td.name] = data
-        else:
-            td.results = data
-            td.failure = None
-            self._check_all_results_provided(td.name, data)
-        self._with_connection(self._save_task_detail, task_detail=td)
-
-    def _cache_failure(self, name, fail):
-        """Ensure that cache has matching failure for task with this name.
-
-        We leave cached version if it matches as it may contain more
-        information. Returns cached failure.
-        """
-        cached = self._failures.get(name)
-        if fail.matches(cached):
-            return cached
-        self._failures[name] = fail
-        return fail
-
-    def _reload_failures(self):
-        """Refresh failures cache."""
-        for td in self._flowdetail:
-            if td.failure is not None:
-                self._cache_failure(td.name, td.failure)
+        with self._lock.write_lock():
+            td = self._taskdetail_by_name(task_name)
+            td.state = state
+            if state == states.FAILURE and isinstance(data, misc.Failure):
+                td.results = None
+                td.failure = data
+                self._failures[td.name] = data
+            else:
+                td.results = data
+                td.failure = None
+                self._check_all_results_provided(td.name, data)
+            self._with_connection(self._save_task_detail, td)

     def get(self, task_name):
-        """Get result for task with name 'task_name' from storage."""
-        td = self._taskdetail_by_name(task_name)
-        if td.failure is not None:
-            return self._cache_failure(td.name, td.failure)
-        if td.state not in STATES_WITH_RESULTS:
-            raise exceptions.NotFound(
-                "Result for task %s is not known" % task_name)
-        return td.results
+        """Get result for task with name 'task_name' from storage."""
+        with self._lock.read_lock():
+            td = self._taskdetail_by_name(task_name)
+            if td.failure is not None:
+                cached = self._failures.get(task_name)
+                if td.failure.matches(cached):
+                    return cached
+                return td.failure
+            if td.state not in STATES_WITH_RESULTS:
+                raise exceptions.NotFound("Result for task %s is not known"
+                                          % task_name)
+            return td.results

     def get_failures(self):
         """Get list of failures that happened with this flow.

         No order guaranteed.
         """
-        return self._failures.copy()
+        with self._lock.read_lock():
+            return self._failures.copy()

     def has_failures(self):
         """Returns True if there are failed tasks in the storage."""
-        return bool(self._failures)
+        with self._lock.read_lock():
+            return bool(self._failures)

     def _reset_task(self, td, state):
         if td.name == self.injector_name:
@@ -305,25 +314,28 @@ class Storage(object):
     def reset(self, task_name, state=states.PENDING):
         """Remove result for task with id 'uuid' from storage."""
-        td = self._taskdetail_by_name(task_name)
-        if self._reset_task(td, state):
-            self._with_connection(self._save_task_detail, task_detail=td)
+        with self._lock.write_lock():
+            td = self._taskdetail_by_name(task_name)
+            if self._reset_task(td, state):
+                self._with_connection(self._save_task_detail, td)

     def reset_tasks(self):
         """Reset all tasks to PENDING state, removing results.

         Returns list of (name, uuid) tuples for all tasks that were reset.
         """
-        result = []
+        reset_results = []

         def do_reset_all(connection):
             for td in self._flowdetail:
                 if self._reset_task(td, states.PENDING):
                     self._save_task_detail(connection, td)
-                    result.append((td.name, td.uuid))
+                    reset_results.append((td.name, td.uuid))

-        self._with_connection(do_reset_all)
-        return result
+        with self._lock.write_lock():
+            self._with_connection(do_reset_all)
+            return reset_results

     def inject(self, pairs):
         """Add values into storage.

@@ -331,20 +343,20 @@ class Storage(object):
         This method should be used to put flow parameters (requirements that
         are not satisfied by any task in the flow) into storage.
         """
-        try:
-            injector_td = self._taskdetail_by_name(self.injector_name)
-        except exceptions.NotFound:
-            injector_uuid = uuidutils.generate_uuid()
-            self._add_task(injector_uuid, self.injector_name)
-            results = dict(pairs)
-        else:
-            results = injector_td.results.copy()
-            results.update(pairs)
-
-        self.save(self.injector_name, results)
-        names = six.iterkeys(results)
-        self._set_result_mapping(self.injector_name,
-                                 dict((name, name) for name in names))
+        with self._lock.write_lock():
+            try:
+                td = self._taskdetail_by_name(self.injector_name)
+            except exceptions.NotFound:
+                self._add_task(uuidutils.generate_uuid(), self.injector_name)
+                td = self._taskdetail_by_name(self.injector_name)
+                td.results = dict(pairs)
+                td.state = states.SUCCESS
+            else:
+                td.results.update(pairs)
+            self._with_connection(self._save_task_detail, td)
+            names = six.iterkeys(td.results)
+            self._set_result_mapping(self.injector_name,
+                                     dict((name, name) for name in names))

     def _set_result_mapping(self, task_name, mapping):
         """Set mapping for naming task results.
@@ -378,50 +390,60 @@ class Storage(object):
     def fetch(self, name):
         """Fetch named task result."""
-        try:
-            indexes = self._reverse_mapping[name]
-        except KeyError:
-            raise exceptions.NotFound("Name %r is not mapped" % name)
-        # Return the first one that is found.
-        for task_name, index in reversed(indexes):
-            try:
-                result = self.get(task_name)
-                return misc.item_from(result, index, name=name)
-            except exceptions.NotFound:
-                pass
-        raise exceptions.NotFound("Unable to find result %r" % name)
+        with self._lock.read_lock():
+            try:
+                indexes = self._reverse_mapping[name]
+            except KeyError:
+                raise exceptions.NotFound("Name %r is not mapped" % name)
+            # Return the first one that is found.
+            for (task_name, index) in reversed(indexes):
+                try:
+                    result = self.get(task_name)
+                    return misc.item_from(result, index, name=name)
+                except exceptions.NotFound:
+                    pass
+            raise exceptions.NotFound("Unable to find result %r" % name)

     def fetch_all(self):
         """Fetch all named task results known so far.

         Should be used for debugging and testing purposes mostly.
         """
-        result = {}
-        for name in self._reverse_mapping:
-            try:
-                result[name] = self.fetch(name)
-            except exceptions.NotFound:
-                pass
-        return result
+        with self._lock.read_lock():
+            results = {}
+            for name in self._reverse_mapping:
+                try:
+                    results[name] = self.fetch(name)
+                except exceptions.NotFound:
+                    pass
+            return results

     def fetch_mapped_args(self, args_mapping):
         """Fetch arguments for the task using arguments mapping."""
-        return dict((key, self.fetch(name))
-                    for key, name in six.iteritems(args_mapping))
+        with self._lock.read_lock():
+            return dict((key, self.fetch(name))
+                        for key, name in six.iteritems(args_mapping))

     def set_flow_state(self, state):
-        """Set flowdetails state and save it."""
-        self._flowdetail.state = state
-        self._with_connection(self._save_flow_detail)
+        """Set flow details state and save it."""
+        with self._lock.write_lock():
+            self._flowdetail.state = state
+            self._with_connection(self._save_flow_detail)

     def get_flow_state(self):
-        """Set state from flowdetails."""
-        state = self._flowdetail.state
-        if state is None:
-            state = states.PENDING
-        return state
+        """Get state from flow details."""
+        with self._lock.read_lock():
+            state = self._flowdetail.state
+            if state is None:
+                state = states.PENDING
+            return state


-@six.add_metaclass(tu.ThreadSafeMeta)
-class ThreadSafeStorage(Storage):
-    pass
+class MultiThreadedStorage(Storage):
+    """Storage that uses locks to protect against concurrent access."""
+    _lock_cls = lock_utils.ReaderWriterLock
+
+
+class SingleThreadedStorage(Storage):
+    """Storage that uses dummy locks when you really don't need locks."""
+    _lock_cls = lock_utils.DummyReaderWriterLock
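With these changes every public read path enters read_lock() and every mutating path enters write_lock(), while the concrete lock class is supplied per subclass through the new _lock_cls abstract property. A hedged usage sketch follows; the flow detail setup mirrors the unit tests, and running without a backend is allowed (persistence is simply skipped):

    from taskflow.persistence import logbook
    from taskflow import storage

    fd = logbook.FlowDetail(name='demo', uuid='aaaa')
    # No backend given: _with_connection() just logs and returns.
    s = storage.MultiThreadedStorage(flow_detail=fd)
    s.ensure_task('my_task', result_mapping={'x': 'x'})  # write lock
    s.save('my_task', {'x': 42})                         # write lock
    print(s.fetch('x'))                                  # read lock; prints 42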

taskflow/tests/unit/test_storage.py

@@ -17,6 +17,7 @@
 # under the License.

 import contextlib
+import threading

 import mock
@@ -35,10 +36,22 @@ class StorageTest(test.TestCase):
     def setUp(self):
         super(StorageTest, self).setUp()
         self.backend = impl_memory.MemoryBackend(conf={})
+        self.thread_count = 50
+
+    def _run_many_threads(self, threads):
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()

-    def _get_storage(self):
+    def _get_storage(self, threaded=False):
         _lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
-        return storage.Storage(backend=self.backend, flow_detail=flow_detail)
+        if threaded:
+            return storage.MultiThreadedStorage(backend=self.backend,
+                                                flow_detail=flow_detail)
+        else:
+            return storage.SingleThreadedStorage(backend=self.backend,
+                                                 flow_detail=flow_detail)

     def tearDown(self):
         super(StorageTest, self).tearDown()
@@ -48,14 +61,14 @@ class StorageTest(test.TestCase):
     def test_non_saving_storage(self):
         _lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
-        s = storage.Storage(flow_detail=flow_detail)  # no backend
+        s = storage.SingleThreadedStorage(flow_detail=flow_detail)
         s.ensure_task('my_task')
         self.assertTrue(
             uuidutils.is_uuid_like(s.get_task_uuid('my_task')))

     def test_flow_name_and_uuid(self):
         fd = logbook.FlowDetail(name='test-fd', uuid='aaaa')
-        s = storage.Storage(flow_detail=fd)
+        s = storage.SingleThreadedStorage(flow_detail=fd)
         self.assertEqual(s.flow_name, 'test-fd')
         self.assertEqual(s.flow_uuid, 'aaaa')
@@ -79,7 +92,8 @@ class StorageTest(test.TestCase):
     def test_ensure_task_fd(self):
         _lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
-        s = storage.Storage(backend=self.backend, flow_detail=flow_detail)
+        s = storage.SingleThreadedStorage(backend=self.backend,
+                                          flow_detail=flow_detail)
         s.ensure_task('my task', '3.11')
         td = flow_detail.find(s.get_task_uuid('my task'))
         self.assertIsNotNone(td)

@@ -92,7 +106,8 @@ class StorageTest(test.TestCase):
         td = logbook.TaskDetail(name='my_task', uuid='42')
         flow_detail.add(td)
-        s = storage.Storage(backend=self.backend, flow_detail=flow_detail)
+        s = storage.SingleThreadedStorage(backend=self.backend,
+                                          flow_detail=flow_detail)
         self.assertEqual('42', s.get_task_uuid('my_task'))

     def test_ensure_existing_task(self):

@@ -100,7 +115,8 @@ class StorageTest(test.TestCase):
         td = logbook.TaskDetail(name='my_task', uuid='42')
         flow_detail.add(td)
-        s = storage.Storage(backend=self.backend, flow_detail=flow_detail)
+        s = storage.SingleThreadedStorage(backend=self.backend,
+                                          flow_detail=flow_detail)
         s.ensure_task('my_task')
         self.assertEqual('42', s.get_task_uuid('my_task'))
@@ -147,7 +163,8 @@ class StorageTest(test.TestCase):
         s.ensure_task('my task')
         s.save('my task', fail, states.FAILURE)
-        s2 = storage.Storage(backend=self.backend, flow_detail=s._flowdetail)
+        s2 = storage.SingleThreadedStorage(backend=self.backend,
+                                           flow_detail=s._flowdetail)
         self.assertIs(s2.has_failures(), True)
         self.assertEqual(s2.get_failures(), {'my task': fail})
         self.assertEqual(s2.get('my task'), fail)
@@ -305,13 +322,71 @@ class StorageTest(test.TestCase):
         })
         # imagine we are resuming, so we need to make new
         # storage from same flow details
-        s2 = storage.Storage(s._flowdetail, backend=self.backend)
+        s2 = storage.SingleThreadedStorage(s._flowdetail, backend=self.backend)
         # injected data should still be there:
         self.assertEqual(s2.fetch_all(), {
             'foo': 'bar',
             'spam': 'eggs',
         })

+    def test_many_thread_ensure_same_task(self):
+        s = self._get_storage(threaded=True)
+
+        def ensure_my_task():
+            s.ensure_task('my_task', result_mapping={})
+
+        threads = []
+        for i in range(0, self.thread_count):
+            threads.append(threading.Thread(target=ensure_my_task))
+        self._run_many_threads(threads)
+
+        # Only one task should have been made, no more.
+        self.assertEqual(1, len(s._flowdetail))
+
+    def test_many_thread_one_reset(self):
+        s = self._get_storage(threaded=True)
+        s.ensure_task('a')
+        s.set_task_state('a', states.SUSPENDED)
+        s.ensure_task('b')
+        s.set_task_state('b', states.SUSPENDED)
+
+        results = []
+        result_lock = threading.Lock()
+
+        def reset_all():
+            r = s.reset_tasks()
+            with result_lock:
+                results.append(r)
+
+        threads = []
+        for i in range(0, self.thread_count):
+            threads.append(threading.Thread(target=reset_all))
+        self._run_many_threads(threads)
+
+        # Only one thread should have actually performed the reset.
+        results = [r for r in results if len(r)]
+        self.assertEqual(1, len(results))
+        self.assertEqual(['a', 'b'], sorted([a[0] for a in results[0]]))
+
+    def test_many_thread_inject(self):
+        s = self._get_storage(threaded=True)
+
+        def inject_values(values):
+            s.inject(values)
+
+        threads = []
+        for i in range(0, self.thread_count):
+            values = {
+                str(i): str(i),
+            }
+            threads.append(threading.Thread(target=inject_values,
+                                            args=[values]))
+        self._run_many_threads(threads)
+
+        self.assertEqual(self.thread_count, len(s.fetch_all()))
+        self.assertEqual(1, len(s._flowdetail))
+
     def test_fetch_mapped_args(self):
         s = self._get_storage()
         s.inject({'foo': 'bar', 'spam': 'eggs'})
@@ -357,7 +432,7 @@ class StorageTest(test.TestCase):
         fd.state = states.FAILURE
         with contextlib.closing(self.backend.get_connection()) as conn:
             fd.update(conn.update_flow_details(fd))
-        s = storage.Storage(flow_detail=fd, backend=self.backend)
+        s = storage.SingleThreadedStorage(flow_detail=fd, backend=self.backend)
         self.assertEqual(s.get_flow_state(), states.FAILURE)

     def test_set_and_get_flow_state(self):

taskflow/tests/unit/test_utils_lock_utils.py (new file)

@@ -0,0 +1,235 @@
+# -*- coding: utf-8 -*-
+
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+#    Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+import collections
+import time
+
+from concurrent import futures
+
+from taskflow import test
+from taskflow.utils import lock_utils
+
+
+def _find_overlaps(times, start, end):
+    overlaps = 0
+    for (s, e) in times:
+        if s >= start and e <= end:
+            overlaps += 1
+    return overlaps
+
+
+def _spawn_variation(readers, writers, max_workers=None):
+    start_stops = collections.deque()
+    lock = lock_utils.ReaderWriterLock()
+
+    def read_func():
+        with lock.read_lock():
+            start_stops.append(('r', time.time(), time.time()))
+
+    def write_func():
+        with lock.write_lock():
+            start_stops.append(('w', time.time(), time.time()))
+
+    if max_workers is None:
+        max_workers = max(0, readers) + max(0, writers)
+    if max_workers > 0:
+        with futures.ThreadPoolExecutor(max_workers=max_workers) as e:
+            for i in range(0, readers):
+                e.submit(read_func)
+            for i in range(0, writers):
+                e.submit(write_func)
+
+    writer_times = []
+    reader_times = []
+    for (t, start, stop) in list(start_stops):
+        if t == 'w':
+            writer_times.append((start, stop))
+        else:
+            reader_times.append((start, stop))
+    return (writer_times, reader_times)
+
+
+class ReadWriteLockTest(test.TestCase):
+    def test_writer_abort(self):
+        lock = lock_utils.ReaderWriterLock()
+        self.assertFalse(lock.owner)
+
+        def blow_up():
+            with lock.write_lock():
+                self.assertEqual(lock.WRITER, lock.owner)
+                raise RuntimeError("Broken")
+
+        self.assertRaises(RuntimeError, blow_up)
+        self.assertFalse(lock.owner)
+
+    def test_reader_abort(self):
+        lock = lock_utils.ReaderWriterLock()
+        self.assertFalse(lock.owner)
+
+        def blow_up():
+            with lock.read_lock():
+                self.assertEqual(lock.READER, lock.owner)
+                raise RuntimeError("Broken")
+
+        self.assertRaises(RuntimeError, blow_up)
+        self.assertFalse(lock.owner)
+
+    def test_reader_chaotic(self):
+        lock = lock_utils.ReaderWriterLock()
+        activated = collections.deque()
+
+        def chaotic_reader(blow_up):
+            with lock.read_lock():
+                if blow_up:
+                    raise RuntimeError("Broken")
+                else:
+                    activated.append(lock.owner)
+
+        def happy_writer():
+            with lock.write_lock():
+                activated.append(lock.owner)
+
+        with futures.ThreadPoolExecutor(max_workers=20) as e:
+            for i in range(0, 20):
+                if i % 2 == 0:
+                    e.submit(chaotic_reader, blow_up=bool(i % 4 == 0))
+                else:
+                    e.submit(happy_writer)
+
+        writers = [a for a in activated if a == 'w']
+        readers = [a for a in activated if a == 'r']
+        self.assertEqual(10, len(writers))
+        self.assertEqual(5, len(readers))
+
+    def test_writer_chaotic(self):
+        lock = lock_utils.ReaderWriterLock()
+        activated = collections.deque()
+
+        def chaotic_writer(blow_up):
+            with lock.write_lock():
+                if blow_up:
+                    raise RuntimeError("Broken")
+                else:
+                    activated.append(lock.owner)
+
+        def happy_reader():
+            with lock.read_lock():
+                activated.append(lock.owner)
+
+        with futures.ThreadPoolExecutor(max_workers=20) as e:
+            for i in range(0, 20):
+                if i % 2 == 0:
+                    e.submit(chaotic_writer, blow_up=bool(i % 4 == 0))
+                else:
+                    e.submit(happy_reader)
+
+        writers = [a for a in activated if a == 'w']
+        readers = [a for a in activated if a == 'r']
+        self.assertEqual(5, len(writers))
+        self.assertEqual(10, len(readers))
+
+    def test_single_reader_writer(self):
+        results = []
+        lock = lock_utils.ReaderWriterLock()
+        with lock.read_lock():
+            self.assertTrue(lock.is_reader())
+            self.assertEqual(0, len(results))
+        with lock.write_lock():
+            results.append(1)
+            self.assertTrue(lock.is_writer())
+        with lock.read_lock():
+            self.assertTrue(lock.is_reader())
+            self.assertEqual(1, len(results))
+        self.assertFalse(lock.is_reader())
+        self.assertFalse(lock.is_writer())
+
+    def test_reader_to_writer(self):
+        lock = lock_utils.ReaderWriterLock()
+
+        def writer_func():
+            with lock.write_lock():
+                pass
+
+        with lock.read_lock():
+            self.assertRaises(RuntimeError, writer_func)
+            self.assertFalse(lock.is_writer())
+
+        self.assertFalse(lock.is_reader())
+        self.assertFalse(lock.is_writer())
+
+    def test_writer_to_reader(self):
+        lock = lock_utils.ReaderWriterLock()
+
+        def reader_func():
+            with lock.read_lock():
+                pass
+
+        with lock.write_lock():
+            self.assertRaises(RuntimeError, reader_func)
+            self.assertFalse(lock.is_reader())
+
+        self.assertFalse(lock.is_reader())
+        self.assertFalse(lock.is_writer())
+
+    def test_double_writer(self):
+        lock = lock_utils.ReaderWriterLock()
+        with lock.write_lock():
+            self.assertFalse(lock.is_reader())
+            self.assertTrue(lock.is_writer())
+            with lock.write_lock():
+                self.assertTrue(lock.is_writer())
+            self.assertTrue(lock.is_writer())
+
+        self.assertFalse(lock.is_reader())
+        self.assertFalse(lock.is_writer())
+
+    def test_double_reader(self):
+        lock = lock_utils.ReaderWriterLock()
+        with lock.read_lock():
+            self.assertTrue(lock.is_reader())
+            self.assertFalse(lock.is_writer())
+            with lock.read_lock():
+                self.assertTrue(lock.is_reader())
+            self.assertTrue(lock.is_reader())
+
+        self.assertFalse(lock.is_reader())
+        self.assertFalse(lock.is_writer())
+
+    def test_multi_reader_multi_writer(self):
+        writer_times, reader_times = _spawn_variation(10, 10)
+        self.assertEqual(10, len(writer_times))
+        self.assertEqual(10, len(reader_times))
+        for (start, stop) in writer_times:
+            self.assertEqual(0, _find_overlaps(reader_times, start, stop))
+            self.assertEqual(1, _find_overlaps(writer_times, start, stop))
+        for (start, stop) in reader_times:
+            self.assertEqual(0, _find_overlaps(writer_times, start, stop))
+
+    def test_multi_reader_single_writer(self):
+        writer_times, reader_times = _spawn_variation(9, 1)
+        self.assertEqual(1, len(writer_times))
+        self.assertEqual(9, len(reader_times))
+        start, stop = writer_times[0]
+        self.assertEqual(0, _find_overlaps(reader_times, start, stop))
+
+    def test_multi_writer(self):
+        writer_times, reader_times = _spawn_variation(0, 10)
+        self.assertEqual(10, len(writer_times))
+        self.assertEqual(0, len(reader_times))
+        for (start, stop) in writer_times:
+            self.assertEqual(1, _find_overlaps(writer_times, start, stop))
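The _spawn_variation helper above timestamps each acquisition, and _find_overlaps counts how many recorded (start, stop) windows fall inside a given window, so the assertions encode the core invariant: a writer's window contains no reader windows and exactly one writer window (its own). Illustrated with invented timestamps:

    # Invented timestamps, purely to show _find_overlaps semantics.
    writer_times = [(10.0, 10.0)]
    reader_times = [(11.0, 11.0), (12.0, 12.0)]
    assert _find_overlaps(reader_times, 10.0, 10.0) == 0  # no readers inside
    assert _find_overlaps(writer_times, 10.0, 10.0) == 1  # only itself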

taskflow/utils/lock_utils.py

@@ -21,6 +21,8 @@
 # pulls in oslo.cfg) and is reduced to only what taskflow currently wants to
 # use from that code.

+import collections
+import contextlib
 import errno
 import logging
 import os

@@ -28,6 +30,7 @@ import threading
 import time

 from taskflow.utils import misc
+from taskflow.utils import threading_utils as tu

 LOG = logging.getLogger(__name__)

@@ -62,6 +65,163 @@
     return decorator


+class ReaderWriterLock(object):
+    """A reader/writer lock.
+
+    This lock allows for simultaneous readers to exist but only one writer
+    to exist for use-cases where it is useful to have such types of locks.
+
+    Currently a reader can not escalate its read lock to a write lock and
+    a writer can not acquire a read lock while it owns or is waiting on
+    the write lock.
+
+    In the future these restrictions may be relaxed.
+    """
+    WRITER = 'w'
+    READER = 'r'
+
+    def __init__(self):
+        self._writer = None
+        self._pending_writers = collections.deque()
+        self._readers = collections.deque()
+        self._cond = threading.Condition()
+
+    def is_writer(self, check_pending=True):
+        """Returns if the caller is the active writer or a pending writer."""
+        self._cond.acquire()
+        try:
+            me = tu.get_ident()
+            if self._writer is not None and self._writer == me:
+                return True
+            if check_pending:
+                return me in self._pending_writers
+            else:
+                return False
+        finally:
+            self._cond.release()
+
+    @property
+    def owner(self):
+        """Returns whether the lock is locked by a writer or reader."""
+        self._cond.acquire()
+        try:
+            if self._writer is not None:
+                return self.WRITER
+            if self._readers:
+                return self.READER
+            return None
+        finally:
+            self._cond.release()
+
+    def is_reader(self):
+        """Returns if the caller is one of the readers."""
+        self._cond.acquire()
+        try:
+            return tu.get_ident() in self._readers
+        finally:
+            self._cond.release()
+
+    @contextlib.contextmanager
+    def read_lock(self):
+        """Grants a read lock.
+
+        Will wait until no active or pending writers.
+
+        Raises a RuntimeError if an active or pending writer tries to
+        acquire a read lock.
+        """
+        me = tu.get_ident()
+        if self.is_writer():
+            raise RuntimeError("Writer %s can not acquire a read lock"
+                               " while holding/waiting for the write lock"
+                               % me)
+        self._cond.acquire()
+        try:
+            while True:
+                # No active or pending writers; we are good to become a
+                # reader.
+                if self._writer is None and len(self._pending_writers) == 0:
+                    self._readers.append(me)
+                    break
+                # Some writers; guess we have to wait.
+                self._cond.wait()
+        finally:
+            self._cond.release()
+        try:
+            yield self
+        finally:
+            # I am no longer a reader, remove *one* occurrence of myself.
+            # If the current thread acquired two read locks, then it will
+            # still have to remove that other read lock; this allows for
+            # basic reentrancy to be possible.
+            self._cond.acquire()
+            try:
+                self._readers.remove(me)
+                self._cond.notify_all()
+            finally:
+                self._cond.release()
+
+    @contextlib.contextmanager
+    def write_lock(self):
+        """Grants a write lock.
+
+        Will wait until no active readers. Blocks readers after acquiring.
+
+        Raises a RuntimeError if an active reader attempts to acquire the
+        write lock.
+        """
+        me = tu.get_ident()
+        if self.is_reader():
+            raise RuntimeError("Reader %s to writer privilege"
+                               " escalation not allowed" % me)
+        if self.is_writer(check_pending=False):
+            # Already the writer; this allows for basic reentrancy.
+            yield self
+        else:
+            self._cond.acquire()
+            try:
+                self._pending_writers.append(me)
+                while True:
+                    # No readers, and no active writer; am I next??
+                    if len(self._readers) == 0 and self._writer is None:
+                        if self._pending_writers[0] == me:
+                            self._writer = self._pending_writers.popleft()
+                            break
+                    self._cond.wait()
+            finally:
+                self._cond.release()
+            try:
+                yield self
+            finally:
+                self._cond.acquire()
+                try:
+                    self._writer = None
+                    self._cond.notify_all()
+                finally:
+                    self._cond.release()
+
+
+class DummyReaderWriterLock(object):
+    """A dummy reader/writer lock that doesn't lock anything but provides
+    the same functions as a normal reader/writer lock class.
+    """
+    @contextlib.contextmanager
+    def write_lock(self):
+        yield self
+
+    @contextlib.contextmanager
+    def read_lock(self):
+        yield self
+
+    @property
+    def owner(self):
+        return None
+
+    def is_reader(self):
+        return False
+
+    def is_writer(self):
+        return False
+
+
 class MultiLock(object):
     """A class which can attempt to obtain many locks at once and release
     said locks when exiting.
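A short usage sketch of the new primitives, matching the behavior the unit tests above exercise:

    from taskflow.utils import lock_utils

    lock = lock_utils.ReaderWriterLock()

    with lock.read_lock():
        # Many threads may hold read locks at once; a reader may also
        # re-enter its own read lock.
        assert lock.is_reader()

    with lock.write_lock():
        # Exclusive: new readers and other writers wait until release.
        assert lock.is_writer()
        with lock.write_lock():  # reentrant for the owning thread
            pass

    # Escalation in either direction (write_lock inside read_lock, or
    # read_lock inside write_lock) raises RuntimeError.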

taskflow/utils/threading_utils.py

@@ -16,16 +16,15 @@
 # License for the specific language governing permissions and limitations
 # under the License.

-import logging
 import multiprocessing
-import threading
-import types

 import six

-from taskflow.utils import lock_utils
-
-LOG = logging.getLogger(__name__)
+if six.PY2:
+    from thread import get_ident  # noqa
+else:
+    # In python3+ the get_ident call moved (whhhy??)
+    from threading import get_ident  # noqa


 def get_optimal_thread_count():

@@ -37,20 +36,3 @@ def get_optimal_thread_count():
     # just setup two threads since its hard to know what else we
     # should do in this situation.
     return 2
-
-
-class ThreadSafeMeta(type):
-    """Metaclass that adds locking to all public methods of a class."""
-
-    def __new__(cls, name, bases, attrs):
-        for attr_name, attr_value in six.iteritems(attrs):
-            if isinstance(attr_value, types.FunctionType):
-                if attr_name[0] != '_':
-                    attrs[attr_name] = lock_utils.locked(attr_value)
-        return super(ThreadSafeMeta, cls).__new__(cls, name, bases, attrs)
-
-    def __call__(cls, *args, **kwargs):
-        instance = super(ThreadSafeMeta, cls).__call__(*args, **kwargs)
-        if not hasattr(instance, '_lock'):
-            instance._lock = threading.RLock()
-        return instance
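With the ThreadSafeMeta metaclass removed, the module's remaining contribution to locking is a portable get_ident: ReaderWriterLock uses the returned thread identifier as its ownership key for readers and writers. A one-line usage sketch:

    from taskflow.utils import threading_utils as tu

    print(tu.get_ident())  # the current thread's identifier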