Merge "Use reader/writer locks in storage"
@@ -191,13 +191,12 @@ class ActionEngine(base.EngineBase):

 class SingleThreadedActionEngine(ActionEngine):
     """Engine that runs tasks in serial manner."""
-    _storage_cls = t_storage.Storage
+    _storage_cls = t_storage.SingleThreadedStorage


 class MultiThreadedActionEngine(ActionEngine):
     """Engine that runs tasks in parallel manner."""

-    _storage_cls = t_storage.ThreadSafeStorage
+    _storage_cls = t_storage.MultiThreadedStorage

     def _task_executor_cls(self):
         return executor.ParallelTaskExecutor(self._executor)
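For context, the two engine classes differ only in which storage flavor `_storage_cls` points at; each engine builds whatever class it declares. A minimal sketch of this selection pattern follows (hypothetical simplified classes, not taskflow's real engine construction code):

# Hypothetical sketch of the _storage_cls selection pattern used above;
# class names are simplified stand-ins, not taskflow's actual classes.
class SingleThreadedStorage(object):
    """Stand-in: would use a dummy (no-op) reader/writer lock."""

class MultiThreadedStorage(object):
    """Stand-in: would use a real reader/writer lock."""

class Engine(object):
    _storage_cls = SingleThreadedStorage

    def __init__(self):
        # The engine instantiates whichever storage class it declared.
        self.storage = self._storage_cls()

class ParallelEngine(Engine):
    # Overriding the class attribute is the whole customization point.
    _storage_cls = MultiThreadedStorage

assert isinstance(ParallelEngine().storage, MultiThreadedStorage)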
@@ -16,6 +16,7 @@
 # License for the specific language governing permissions and limitations
 # under the License.

+import abc
 import contextlib
 import logging
@@ -25,13 +26,15 @@ from taskflow import exceptions
 from taskflow.openstack.common import uuidutils
 from taskflow.persistence import logbook
 from taskflow import states
+from taskflow.utils import lock_utils
 from taskflow.utils import misc
-from taskflow.utils import threading_utils as tu
 from taskflow.utils import reflection

 LOG = logging.getLogger(__name__)
 STATES_WITH_RESULTS = (states.SUCCESS, states.REVERTING, states.FAILURE)


+@six.add_metaclass(abc.ABCMeta)
 class Storage(object):
     """Interface between engines and logbook.
@@ -48,11 +51,14 @@ class Storage(object):
         self._reverse_mapping = {}
         self._backend = backend
         self._flowdetail = flow_detail
+        self._lock = self._lock_cls()

         # NOTE(imelnikov): failure serialization looses information,
         # so we cache failures here, in task name -> misc.Failure mapping.
         self._failures = {}
-        self._reload_failures()
+        for td in self._flowdetail:
+            if td.failure is not None:
+                self._failures[td.name] = td.failure

         self._task_name_to_uuid = dict((td.name, td.uuid)
                                        for td in self._flowdetail)
@@ -66,11 +72,20 @@ class Storage(object):
         self._set_result_mapping(injector_td.name,
                                  dict((name, name) for name in names))

+    @abc.abstractproperty
+    def _lock_cls(self):
+        """Lock class used to generate reader/writer locks for protecting
+        read/write access to the underlying storage backend and internally
+        mutating operations.
+        """
+
     def _with_connection(self, functor, *args, **kwargs):
         # NOTE(harlowja): Activate the given function with a backend
         # connection, if a backend is provided in the first place, otherwise
         # don't call the function.
         if self._backend is None:
             LOG.debug("No backend provided, not calling functor '%s'",
                       reflection.get_callable_name(functor))
             return
         with contextlib.closing(self._backend.get_connection()) as conn:
             functor(conn, *args, **kwargs)
@@ -85,6 +100,7 @@ class Storage(object):
         Returns uuid for the task details corresponding to the task with
         given name.
         """
+        with self._lock.write_lock():
             try:
                 task_id = self._task_name_to_uuid[task_name]
             except KeyError:
@@ -99,27 +115,29 @@ class Storage(object):
         Task becomes known to storage by that name and uuid.
         Task state is set to PENDING.
         """

+        def save_both(conn, td):
+            """Saves the flow and the task detail with the same connection."""
+            self._save_flow_detail(conn)
+            self._save_task_detail(conn, td)
+
         # TODO(imelnikov): check that task with same uuid or
         # task name does not exist.
         td = logbook.TaskDetail(name=task_name, uuid=uuid)
         td.state = states.PENDING
         td.version = task_version
         self._flowdetail.add(td)

-        def save_both(conn):
-            """Saves the flow and the task detail with the same connection."""
-            self._save_flow_detail(conn)
-            self._save_task_detail(conn, task_detail=td)
-
-        self._with_connection(save_both)
+        self._with_connection(save_both, td)
         self._task_name_to_uuid[task_name] = uuid

     @property
     def flow_name(self):
         # This never changes (so no read locking needed).
         return self._flowdetail.name

     @property
     def flow_uuid(self):
         # This never changes (so no read locking needed).
         return self._flowdetail.uuid

     def _save_flow_detail(self, conn):
@@ -130,13 +148,9 @@ class Storage(object):

     def _taskdetail_by_name(self, task_name):
         try:
-            td = self._flowdetail.find(self._task_name_to_uuid[task_name])
+            return self._flowdetail.find(self._task_name_to_uuid[task_name])
         except KeyError:
-            td = None
-
-        if td is None:
             raise exceptions.NotFound("Unknown task name: %s" % task_name)
-        return td

     def _save_task_detail(self, conn, task_detail):
         # NOTE(harlowja): we need to update our contained task detail if
@@ -146,34 +160,39 @@ class Storage(object):

     def get_task_uuid(self, task_name):
         """Get task uuid by given name."""
+        with self._lock.read_lock():
             td = self._taskdetail_by_name(task_name)
             return td.uuid

     def set_task_state(self, task_name, state):
         """Set task state."""
+        with self._lock.write_lock():
             td = self._taskdetail_by_name(task_name)
             td.state = state
-            self._with_connection(self._save_task_detail, task_detail=td)
+            self._with_connection(self._save_task_detail, td)

     def get_task_state(self, task_name):
         """Get state of task with given name."""
-        return self._taskdetail_by_name(task_name).state
+        with self._lock.read_lock():
+            td = self._taskdetail_by_name(task_name)
+            return td.state

     def get_tasks_states(self, task_names):
         """Gets all task states."""
+        with self._lock.read_lock():
             return dict((name, self.get_task_state(name))
                         for name in task_names)

     def update_task_metadata(self, task_name, update_with):
         """Updates a tasks metadata."""
         if not update_with:
             return
+        # NOTE(harlowja): this is a read and then write, not in 1 transaction
+        # so it is entirely possible that we could write over another writes
+        # metadata update. Maybe add some merging logic later?
+        with self._lock.write_lock():
             td = self._taskdetail_by_name(task_name)
             if not td.meta:
                 td.meta = {}
             td.meta.update(update_with)
-            self._with_connection(self._save_task_detail, task_detail=td)
+            self._with_connection(self._save_task_detail, td)

     def set_task_progress(self, task_name, progress, details=None):
         """Set task progress.
@@ -204,10 +223,11 @@ class Storage(object):
         :param task_name: task name
         :returns: current task progress value
         """
-        meta = self._taskdetail_by_name(task_name).meta
-        if not meta:
+        with self._lock.read_lock():
+            td = self._taskdetail_by_name(task_name)
+            if not td.meta:
                 return 0.0
-        return meta.get('progress', 0.0)
+            return td.meta.get('progress', 0.0)

     def get_task_progress_details(self, task_name):
         """Get progress details of task with given name.
@@ -216,10 +236,11 @@ class Storage(object):
         :returns: None if progress_details not defined, else progress_details
             dict
         """
-        meta = self._taskdetail_by_name(task_name).meta
-        if not meta:
+        with self._lock.read_lock():
+            td = self._taskdetail_by_name(task_name)
+            if not td.meta:
                 return None
-        return meta.get('progress_details')
+            return td.meta.get('progress_details')

     def _check_all_results_provided(self, task_name, data):
         """Warn if task did not provide some of results.
@@ -228,19 +249,19 @@ class Storage(object):
         without all needed keys. It may also happen if task returns
         result of wrong type.
         """
-        result_mapping = self._result_mappings.get(task_name, None)
-        if result_mapping is None:
+        result_mapping = self._result_mappings.get(task_name)
+        if not result_mapping:
             return
         for name, index in six.iteritems(result_mapping):
             try:
                 misc.item_from(data, index, name=name)
             except exceptions.NotFound:
                 LOG.warning("Task %s did not supply result "
-                            "with index %r (name %s)",
-                            task_name, index, name)
+                            "with index %r (name %s)", task_name, index, name)

     def save(self, task_name, data, state=states.SUCCESS):
         """Put result for task with id 'uuid' to storage."""
+        with self._lock.write_lock():
             td = self._taskdetail_by_name(task_name)
             td.state = state
             if state == states.FAILURE and isinstance(data, misc.Failure):
@@ -251,34 +272,20 @@ class Storage(object):
                 td.results = data
                 td.failure = None
             self._check_all_results_provided(td.name, data)
-        self._with_connection(self._save_task_detail, task_detail=td)
-
-    def _cache_failure(self, name, fail):
-        """Ensure that cache has matching failure for task with this name.
-
-        We leave cached version if it matches as it may contain more
-        information. Returns cached failure.
-        """
-        cached = self._failures.get(name)
-        if fail.matches(cached):
-            return cached
-        self._failures[name] = fail
-        return fail
-
-    def _reload_failures(self):
-        """Refresh failures cache."""
-        for td in self._flowdetail:
-            if td.failure is not None:
-                self._cache_failure(td.name, td.failure)
+            self._with_connection(self._save_task_detail, td)

     def get(self, task_name):
         """Get result for task with name 'task_name' to storage."""
+        with self._lock.read_lock():
             td = self._taskdetail_by_name(task_name)
             if td.failure is not None:
-                return self._cache_failure(td.name, td.failure)
+                cached = self._failures.get(task_name)
+                if td.failure.matches(cached):
+                    return cached
+                return td.failure
             if td.state not in STATES_WITH_RESULTS:
-                raise exceptions.NotFound(
-                    "Result for task %s is not known" % task_name)
+                raise exceptions.NotFound("Result for task %s is not known"
+                                          % task_name)
             return td.results

     def get_failures(self):
@@ -286,10 +293,12 @@ class Storage(object):

         No order guaranteed.
         """
+        with self._lock.read_lock():
             return self._failures.copy()

     def has_failures(self):
         """Returns True if there are failed tasks in the storage."""
+        with self._lock.read_lock():
             return bool(self._failures)

     def _reset_task(self, td, state):
@@ -305,25 +314,28 @@ class Storage(object):

     def reset(self, task_name, state=states.PENDING):
         """Remove result for task with id 'uuid' from storage."""
+        with self._lock.write_lock():
             td = self._taskdetail_by_name(task_name)
             if self._reset_task(td, state):
-                self._with_connection(self._save_task_detail, task_detail=td)
+                self._with_connection(self._save_task_detail, td)

     def reset_tasks(self):
         """Reset all tasks to PENDING state, removing results.

         Returns list of (name, uuid) tuples for all tasks that were reset.
         """
-        result = []
+        reset_results = []

         def do_reset_all(connection):
             for td in self._flowdetail:
                 if self._reset_task(td, states.PENDING):
                     self._save_task_detail(connection, td)
-                    result.append((td.name, td.uuid))
+                    reset_results.append((td.name, td.uuid))

+        with self._lock.write_lock():
             self._with_connection(do_reset_all)
-        return result
+            return reset_results

     def inject(self, pairs):
         """Add values into storage.
@@ -331,18 +343,18 @@ class Storage(object):
         This method should be used to put flow parameters (requirements that
         are not satisfied by any task in the flow) into storage.
         """
+        with self._lock.write_lock():
             try:
-                injector_td = self._taskdetail_by_name(self.injector_name)
+                td = self._taskdetail_by_name(self.injector_name)
             except exceptions.NotFound:
-                injector_uuid = uuidutils.generate_uuid()
-                self._add_task(injector_uuid, self.injector_name)
-                results = dict(pairs)
+                self._add_task(uuidutils.generate_uuid(), self.injector_name)
+                td = self._taskdetail_by_name(self.injector_name)
+                td.results = dict(pairs)
+                td.state = states.SUCCESS
             else:
-                results = injector_td.results.copy()
-                results.update(pairs)
-
-            self.save(self.injector_name, results)
-            names = six.iterkeys(results)
+                td.results.update(pairs)
+            self._with_connection(self._save_task_detail, td)
+            names = six.iterkeys(td.results)
             self._set_result_mapping(self.injector_name,
                                      dict((name, name) for name in names))
@@ -378,12 +390,13 @@ class Storage(object):

     def fetch(self, name):
         """Fetch named task result."""
+        with self._lock.read_lock():
             try:
                 indexes = self._reverse_mapping[name]
             except KeyError:
                 raise exceptions.NotFound("Name %r is not mapped" % name)
             # Return the first one that is found.
-            for task_name, index in reversed(indexes):
+            for (task_name, index) in reversed(indexes):
                 try:
                     result = self.get(task_name)
                     return misc.item_from(result, index, name=name)
@@ -396,32 +409,41 @@ class Storage(object):

         Should be used for debugging and testing purposes mostly.
         """
-        result = {}
+        with self._lock.read_lock():
+            results = {}
             for name in self._reverse_mapping:
                 try:
-                    result[name] = self.fetch(name)
+                    results[name] = self.fetch(name)
                 except exceptions.NotFound:
                     pass
-        return result
+            return results

     def fetch_mapped_args(self, args_mapping):
         """Fetch arguments for the task using arguments mapping."""
+        with self._lock.read_lock():
             return dict((key, self.fetch(name))
                         for key, name in six.iteritems(args_mapping))

     def set_flow_state(self, state):
-        """Set flowdetails state and save it."""
+        """Set flow details state and save it."""
+        with self._lock.write_lock():
             self._flowdetail.state = state
             self._with_connection(self._save_flow_detail)

     def get_flow_state(self):
-        """Set state from flowdetails."""
+        """Get state from flow details."""
+        with self._lock.read_lock():
             state = self._flowdetail.state
             if state is None:
                 state = states.PENDING
             return state


-@six.add_metaclass(tu.ThreadSafeMeta)
-class ThreadSafeStorage(Storage):
-    pass
+class MultiThreadedStorage(Storage):
+    """Storage that uses locks to protect against concurrent access."""
+    _lock_cls = lock_utils.ReaderWriterLock
+
+
+class SingleThreadedStorage(Storage):
+    """Storage that uses dummy locks when you really don't need locks."""
+    _lock_cls = lock_utils.DummyReaderWriterLock
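The `_lock_cls` hook is the crux of the storage changes above: the base `Storage` declares it as an `abc.abstractproperty` with an empty body, and each concrete class satisfies it with a plain class attribute naming a lock class, which `__init__` then instantiates. A minimal self-contained sketch of that pattern (stand-in names, not the real taskflow classes):

# Sketch of the abstract-lock-class pattern (stand-in names only).
import abc
import contextlib

import six


class _DummyLock(object):
    # Stand-in for lock_utils.DummyReaderWriterLock: context managers
    # that hand out "locks" without doing any locking.
    @contextlib.contextmanager
    def read_lock(self):
        yield self

    write_lock = read_lock


@six.add_metaclass(abc.ABCMeta)
class Base(object):
    def __init__(self):
        # The base class never names a concrete lock; it builds whatever
        # lock class the subclass declared.
        self._lock = self._lock_cls()

    @abc.abstractproperty
    def _lock_cls(self):
        """Lock class used to protect access to internal state."""

    def get(self):
        with self._lock.read_lock():
            return 42


class SingleThreaded(Base):
    # A plain class attribute satisfies the abstract property.
    _lock_cls = _DummyLock


assert SingleThreaded().get() == 42

Note that assigning `_lock_cls = _DummyLock` clears the abstractness for `ABCMeta`, while instantiating `Base` directly would still raise `TypeError`.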
@@ -17,6 +17,7 @@
 # under the License.

 import contextlib
+import threading

 import mock
@@ -35,10 +36,22 @@ class StorageTest(test.TestCase):
     def setUp(self):
         super(StorageTest, self).setUp()
         self.backend = impl_memory.MemoryBackend(conf={})
+        self.thread_count = 50

-    def _get_storage(self):
+    def _run_many_threads(self, threads):
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+    def _get_storage(self, threaded=False):
         _lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
-        return storage.Storage(backend=self.backend, flow_detail=flow_detail)
+        if threaded:
+            return storage.MultiThreadedStorage(backend=self.backend,
+                                                flow_detail=flow_detail)
+        else:
+            return storage.SingleThreadedStorage(backend=self.backend,
+                                                 flow_detail=flow_detail)

     def tearDown(self):
         super(StorageTest, self).tearDown()
@@ -48,14 +61,14 @@ class StorageTest(test.TestCase):

     def test_non_saving_storage(self):
         _lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
-        s = storage.Storage(flow_detail=flow_detail)  # no backend
+        s = storage.SingleThreadedStorage(flow_detail=flow_detail)
         s.ensure_task('my_task')
         self.assertTrue(
             uuidutils.is_uuid_like(s.get_task_uuid('my_task')))

     def test_flow_name_and_uuid(self):
         fd = logbook.FlowDetail(name='test-fd', uuid='aaaa')
-        s = storage.Storage(flow_detail=fd)
+        s = storage.SingleThreadedStorage(flow_detail=fd)
         self.assertEqual(s.flow_name, 'test-fd')
         self.assertEqual(s.flow_uuid, 'aaaa')
@@ -79,7 +92,8 @@ class StorageTest(test.TestCase):

     def test_ensure_task_fd(self):
         _lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
-        s = storage.Storage(backend=self.backend, flow_detail=flow_detail)
+        s = storage.SingleThreadedStorage(backend=self.backend,
+                                          flow_detail=flow_detail)
         s.ensure_task('my task', '3.11')
         td = flow_detail.find(s.get_task_uuid('my task'))
         self.assertIsNotNone(td)
@@ -92,7 +106,8 @@ class StorageTest(test.TestCase):
         td = logbook.TaskDetail(name='my_task', uuid='42')
         flow_detail.add(td)

-        s = storage.Storage(backend=self.backend, flow_detail=flow_detail)
+        s = storage.SingleThreadedStorage(backend=self.backend,
+                                          flow_detail=flow_detail)
         self.assertEqual('42', s.get_task_uuid('my_task'))

     def test_ensure_existing_task(self):
@@ -100,7 +115,8 @@ class StorageTest(test.TestCase):
         td = logbook.TaskDetail(name='my_task', uuid='42')
         flow_detail.add(td)

-        s = storage.Storage(backend=self.backend, flow_detail=flow_detail)
+        s = storage.SingleThreadedStorage(backend=self.backend,
+                                          flow_detail=flow_detail)
         s.ensure_task('my_task')
         self.assertEqual('42', s.get_task_uuid('my_task'))
@@ -147,7 +163,8 @@ class StorageTest(test.TestCase):
         s.ensure_task('my task')
         s.save('my task', fail, states.FAILURE)

-        s2 = storage.Storage(backend=self.backend, flow_detail=s._flowdetail)
+        s2 = storage.SingleThreadedStorage(backend=self.backend,
+                                           flow_detail=s._flowdetail)
         self.assertIs(s2.has_failures(), True)
         self.assertEqual(s2.get_failures(), {'my task': fail})
         self.assertEqual(s2.get('my task'), fail)
@@ -305,13 +322,71 @@ class StorageTest(test.TestCase):
         })
         # imagine we are resuming, so we need to make new
         # storage from same flow details
-        s2 = storage.Storage(s._flowdetail, backend=self.backend)
+        s2 = storage.SingleThreadedStorage(s._flowdetail, backend=self.backend)
         # injected data should still be there:
         self.assertEqual(s2.fetch_all(), {
             'foo': 'bar',
             'spam': 'eggs',
         })

+    def test_many_thread_ensure_same_task(self):
+        s = self._get_storage(threaded=True)
+
+        def ensure_my_task():
+            s.ensure_task('my_task', result_mapping={})
+
+        threads = []
+        for i in range(0, self.thread_count):
+            threads.append(threading.Thread(target=ensure_my_task))
+        self._run_many_threads(threads)
+
+        # Only one task should have been made, no more.
+        self.assertEqual(1, len(s._flowdetail))
+
+    def test_many_thread_one_reset(self):
+        s = self._get_storage(threaded=True)
+        s.ensure_task('a')
+        s.set_task_state('a', states.SUSPENDED)
+        s.ensure_task('b')
+        s.set_task_state('b', states.SUSPENDED)
+
+        results = []
+        result_lock = threading.Lock()
+
+        def reset_all():
+            r = s.reset_tasks()
+            with result_lock:
+                results.append(r)
+
+        threads = []
+        for i in range(0, self.thread_count):
+            threads.append(threading.Thread(target=reset_all))
+
+        self._run_many_threads(threads)
+
+        # Only one thread should have actually reset (not anymore)
+        results = [r for r in results if len(r)]
+        self.assertEqual(1, len(results))
+        self.assertEqual(['a', 'b'], sorted([a[0] for a in results[0]]))
+
+    def test_many_thread_inject(self):
+        s = self._get_storage(threaded=True)
+
+        def inject_values(values):
+            s.inject(values)
+
+        threads = []
+        for i in range(0, self.thread_count):
+            values = {
+                str(i): str(i),
+            }
+            threads.append(threading.Thread(target=inject_values,
+                                            args=[values]))
+
+        self._run_many_threads(threads)
+        self.assertEqual(self.thread_count, len(s.fetch_all()))
+        self.assertEqual(1, len(s._flowdetail))
+
     def test_fetch_meapped_args(self):
         s = self._get_storage()
         s.inject({'foo': 'bar', 'spam': 'eggs'})
@@ -357,7 +432,7 @@ class StorageTest(test.TestCase):
         fd.state = states.FAILURE
         with contextlib.closing(self.backend.get_connection()) as conn:
             fd.update(conn.update_flow_details(fd))
-        s = storage.Storage(flow_detail=fd, backend=self.backend)
+        s = storage.SingleThreadedStorage(flow_detail=fd, backend=self.backend)
         self.assertEqual(s.get_flow_state(), states.FAILURE)

     def test_set_and_get_flow_state(self):
taskflow/tests/unit/test_utils_lock_utils.py (new file, 235 lines)
@@ -0,0 +1,235 @@
+# -*- coding: utf-8 -*-
+
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+import collections
+import time
+
+from concurrent import futures
+
+from taskflow import test
+from taskflow.utils import lock_utils
+
+
+def _find_overlaps(times, start, end):
+    overlaps = 0
+    for (s, e) in times:
+        if s >= start and e <= end:
+            overlaps += 1
+    return overlaps
+
+
+def _spawn_variation(readers, writers, max_workers=None):
+    start_stops = collections.deque()
+    lock = lock_utils.ReaderWriterLock()
+
+    def read_func():
+        with lock.read_lock():
+            start_stops.append(('r', time.time(), time.time()))
+
+    def write_func():
+        with lock.write_lock():
+            start_stops.append(('w', time.time(), time.time()))
+
+    if max_workers is None:
+        max_workers = max(0, readers) + max(0, writers)
+    if max_workers > 0:
+        with futures.ThreadPoolExecutor(max_workers=max_workers) as e:
+            for i in range(0, readers):
+                e.submit(read_func)
+            for i in range(0, writers):
+                e.submit(write_func)
+
+    writer_times = []
+    reader_times = []
+    for (t, start, stop) in list(start_stops):
+        if t == 'w':
+            writer_times.append((start, stop))
+        else:
+            reader_times.append((start, stop))
+    return (writer_times, reader_times)
+
+
+class ReadWriteLockTest(test.TestCase):
+    def test_writer_abort(self):
+        lock = lock_utils.ReaderWriterLock()
+        self.assertFalse(lock.owner)
+
+        def blow_up():
+            with lock.write_lock():
+                self.assertEqual(lock.WRITER, lock.owner)
+                raise RuntimeError("Broken")
+
+        self.assertRaises(RuntimeError, blow_up)
+        self.assertFalse(lock.owner)
+
+    def test_reader_abort(self):
+        lock = lock_utils.ReaderWriterLock()
+        self.assertFalse(lock.owner)
+
+        def blow_up():
+            with lock.read_lock():
+                self.assertEqual(lock.READER, lock.owner)
+                raise RuntimeError("Broken")
+
+        self.assertRaises(RuntimeError, blow_up)
+        self.assertFalse(lock.owner)
+
+    def test_reader_chaotic(self):
+        lock = lock_utils.ReaderWriterLock()
+        activated = collections.deque()
+
+        def chaotic_reader(blow_up):
+            with lock.read_lock():
+                if blow_up:
+                    raise RuntimeError("Broken")
+                else:
+                    activated.append(lock.owner)
+
+        def happy_writer():
+            with lock.write_lock():
+                activated.append(lock.owner)
+
+        with futures.ThreadPoolExecutor(max_workers=20) as e:
+            for i in range(0, 20):
+                if i % 2 == 0:
+                    e.submit(chaotic_reader, blow_up=bool(i % 4 == 0))
+                else:
+                    e.submit(happy_writer)
+
+        writers = [a for a in activated if a == 'w']
+        readers = [a for a in activated if a == 'r']
+        self.assertEqual(10, len(writers))
+        self.assertEqual(5, len(readers))
+
+    def test_writer_chaotic(self):
+        lock = lock_utils.ReaderWriterLock()
+        activated = collections.deque()
+
+        def chaotic_writer(blow_up):
+            with lock.write_lock():
+                if blow_up:
+                    raise RuntimeError("Broken")
+                else:
+                    activated.append(lock.owner)
+
+        def happy_reader():
+            with lock.read_lock():
+                activated.append(lock.owner)
+
+        with futures.ThreadPoolExecutor(max_workers=20) as e:
+            for i in range(0, 20):
+                if i % 2 == 0:
+                    e.submit(chaotic_writer, blow_up=bool(i % 4 == 0))
+                else:
+                    e.submit(happy_reader)
+
+        writers = [a for a in activated if a == 'w']
+        readers = [a for a in activated if a == 'r']
+        self.assertEqual(5, len(writers))
+        self.assertEqual(10, len(readers))
+
+    def test_single_reader_writer(self):
+        results = []
+        lock = lock_utils.ReaderWriterLock()
+        with lock.read_lock():
+            self.assertTrue(lock.is_reader())
+            self.assertEqual(0, len(results))
+        with lock.write_lock():
+            results.append(1)
+            self.assertTrue(lock.is_writer())
+        with lock.read_lock():
+            self.assertTrue(lock.is_reader())
+            self.assertEqual(1, len(results))
+        self.assertFalse(lock.is_reader())
+        self.assertFalse(lock.is_writer())
+
+    def test_reader_to_writer(self):
+        lock = lock_utils.ReaderWriterLock()
+
+        def writer_func():
+            with lock.write_lock():
+                pass
+
+        with lock.read_lock():
+            self.assertRaises(RuntimeError, writer_func)
+            self.assertFalse(lock.is_writer())
+
+        self.assertFalse(lock.is_reader())
+        self.assertFalse(lock.is_writer())
+
+    def test_writer_to_reader(self):
+        lock = lock_utils.ReaderWriterLock()
+
+        def reader_func():
+            with lock.read_lock():
+                pass
+
+        with lock.write_lock():
+            self.assertRaises(RuntimeError, reader_func)
+            self.assertFalse(lock.is_reader())
+
+        self.assertFalse(lock.is_reader())
+        self.assertFalse(lock.is_writer())
+
+    def test_double_writer(self):
+        lock = lock_utils.ReaderWriterLock()
+        with lock.write_lock():
+            self.assertFalse(lock.is_reader())
+            self.assertTrue(lock.is_writer())
+            with lock.write_lock():
+                self.assertTrue(lock.is_writer())
+            self.assertTrue(lock.is_writer())
+
+        self.assertFalse(lock.is_reader())
+        self.assertFalse(lock.is_writer())
+
+    def test_double_reader(self):
+        lock = lock_utils.ReaderWriterLock()
+        with lock.read_lock():
+            self.assertTrue(lock.is_reader())
+            self.assertFalse(lock.is_writer())
+            with lock.read_lock():
+                self.assertTrue(lock.is_reader())
+            self.assertTrue(lock.is_reader())
+
+        self.assertFalse(lock.is_reader())
+        self.assertFalse(lock.is_writer())
+
+    def test_multi_reader_multi_writer(self):
+        writer_times, reader_times = _spawn_variation(10, 10)
+        self.assertEqual(10, len(writer_times))
+        self.assertEqual(10, len(reader_times))
+        for (start, stop) in writer_times:
+            self.assertEqual(0, _find_overlaps(reader_times, start, stop))
+            self.assertEqual(1, _find_overlaps(writer_times, start, stop))
+        for (start, stop) in reader_times:
+            self.assertEqual(0, _find_overlaps(writer_times, start, stop))
+
+    def test_multi_reader_single_writer(self):
+        writer_times, reader_times = _spawn_variation(9, 1)
+        self.assertEqual(1, len(writer_times))
+        self.assertEqual(9, len(reader_times))
+        start, stop = writer_times[0]
+        self.assertEqual(0, _find_overlaps(reader_times, start, stop))
+
+    def test_multi_writer(self):
+        writer_times, reader_times = _spawn_variation(0, 10)
+        self.assertEqual(10, len(writer_times))
+        self.assertEqual(0, len(reader_times))
+        for (start, stop) in writer_times:
+            self.assertEqual(1, _find_overlaps(writer_times, start, stop))
@@ -21,6 +21,8 @@
 # pulls in oslo.cfg) and is reduced to only what taskflow currently wants to
 # use from that code.

+import collections
+import contextlib
 import errno
 import logging
 import os
@@ -28,6 +30,7 @@ import threading
 import time

 from taskflow.utils import misc
+from taskflow.utils import threading_utils as tu

 LOG = logging.getLogger(__name__)
@@ -62,6 +65,163 @@ def locked(*args, **kwargs):
     return decorator


+class ReaderWriterLock(object):
+    """A reader/writer lock.
+
+    This lock allows for simultaneous readers to exist but only one writer
+    to exist for use-cases where it is useful to have such types of locks.
+
+    Currently a reader can not escalate its read lock to a write lock and
+    a writer can not acquire a read lock while it owns or is waiting on
+    the write lock.
+
+    In the future these restrictions may be relaxed.
+    """
+    WRITER = 'w'
+    READER = 'r'
+
+    def __init__(self):
+        self._writer = None
+        self._pending_writers = collections.deque()
+        self._readers = collections.deque()
+        self._cond = threading.Condition()
+
+    def is_writer(self, check_pending=True):
+        """Returns if the caller is the active writer or a pending writer."""
+        self._cond.acquire()
+        try:
+            me = tu.get_ident()
+            if self._writer is not None and self._writer == me:
+                return True
+            if check_pending:
+                return me in self._pending_writers
+            else:
+                return False
+        finally:
+            self._cond.release()
+
+    @property
+    def owner(self):
+        """Returns whether the lock is locked by a writer or reader."""
+        self._cond.acquire()
+        try:
+            if self._writer is not None:
+                return self.WRITER
+            if self._readers:
+                return self.READER
+            return None
+        finally:
+            self._cond.release()
+
+    def is_reader(self):
+        """Returns if the caller is one of the readers."""
+        self._cond.acquire()
+        try:
+            return tu.get_ident() in self._readers
+        finally:
+            self._cond.release()
+
+    @contextlib.contextmanager
+    def read_lock(self):
+        """Grants a read lock.
+
+        Will wait until no active or pending writers.
+
+        Raises a RuntimeError if an active or pending writer tries to acquire
+        a read lock.
+        """
+        me = tu.get_ident()
+        if self.is_writer():
+            raise RuntimeError("Writer %s can not acquire a read lock"
+                               " while holding/waiting for the write lock"
+                               % me)
+        self._cond.acquire()
+        try:
+            while True:
+                # No active or pending writers; we are good to become a reader.
+                if self._writer is None and len(self._pending_writers) == 0:
+                    self._readers.append(me)
+                    break
+                # Some writers; guess we have to wait.
+                self._cond.wait()
+        finally:
+            self._cond.release()
+        try:
+            yield self
+        finally:
+            # I am no longer a reader, remove *one* occurrence of myself.
+            # If the current thread acquired two read locks, then it will
+            # still have to remove that other read lock; this allows for
+            # basic reentrancy to be possible.
+            self._cond.acquire()
+            try:
+                self._readers.remove(me)
+                self._cond.notify_all()
+            finally:
+                self._cond.release()
+
+    @contextlib.contextmanager
+    def write_lock(self):
+        """Grants a write lock.
+
+        Will wait until no active readers. Blocks readers after acquiring.
+
+        Raises a RuntimeError if an active reader attempts to acquire a lock.
+        """
+        me = tu.get_ident()
+        if self.is_reader():
+            raise RuntimeError("Reader %s to writer privilege"
+                               " escalation not allowed" % me)
+        if self.is_writer(check_pending=False):
+            # Already the writer; this allows for basic reentrancy.
+            yield self
+        else:
+            self._cond.acquire()
+            try:
+                self._pending_writers.append(me)
+                while True:
+                    # No readers, and no active writer, am I next??
+                    if len(self._readers) == 0 and self._writer is None:
+                        if self._pending_writers[0] == me:
+                            self._writer = self._pending_writers.popleft()
+                            break
+                    self._cond.wait()
+            finally:
+                self._cond.release()
+            try:
+                yield self
+            finally:
+                self._cond.acquire()
+                try:
+                    self._writer = None
+                    self._cond.notify_all()
+                finally:
+                    self._cond.release()
+
+
+class DummyReaderWriterLock(object):
+    """A dummy reader/writer lock that doesn't lock anything but provides
+    same functions as a normal reader/writer lock class.
+    """
+    @contextlib.contextmanager
+    def write_lock(self):
+        yield self
+
+    @contextlib.contextmanager
+    def read_lock(self):
+        yield self
+
+    @property
+    def owner(self):
+        return None
+
+    def is_reader(self):
+        return False
+
+    def is_writer(self):
+        return False
+
+
 class MultiLock(object):
     """A class which can attempt to obtain many locks at once and release
     said locks when exiting.
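To make the new lock's semantics concrete, here is a small illustrative usage sketch (an example written for this summary, not part of the change): concurrent readers share the lock, a writer gets exclusive access, and per the code above a reader that tries to escalate to a write lock gets a RuntimeError.

# Example use of the ReaderWriterLock added above.
from concurrent import futures

from taskflow.utils import lock_utils

lock = lock_utils.ReaderWriterLock()
shared = {'value': 0}


def reader():
    # Shared access: many readers may hold the lock at the same time.
    with lock.read_lock():
        return shared['value']


def writer():
    # Exclusive access: waits for readers to drain and blocks new ones.
    with lock.write_lock():
        shared['value'] += 1


with futures.ThreadPoolExecutor(max_workers=10) as ex:
    for i in range(10):
        ex.submit(writer if i % 2 else reader)

assert shared['value'] == 5  # The five writer submissions ran exclusively.

Both lock types are also reentrant with themselves (a reader may take another read lock, a writer another write lock), which the tests above exercise via test_double_reader and test_double_writer.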
@@ -16,16 +16,15 @@
 # License for the specific language governing permissions and limitations
 # under the License.

 import logging
 import multiprocessing
-import threading
-import types

 import six

-from taskflow.utils import lock_utils
-
 LOG = logging.getLogger(__name__)

+if six.PY2:
+    from thread import get_ident  # noqa
+else:
+    # In python3+ the get_ident call moved (whhhy??)
+    from threading import get_ident  # noqa
@@ -37,20 +36,3 @@ def get_optimal_thread_count():
         # just setup two threads since its hard to know what else we
         # should do in this situation.
         return 2
-
-
-class ThreadSafeMeta(type):
-    """Metaclass that adds locking to all pubic methods of a class."""
-
-    def __new__(cls, name, bases, attrs):
-        for attr_name, attr_value in six.iteritems(attrs):
-            if isinstance(attr_value, types.FunctionType):
-                if attr_name[0] != '_':
-                    attrs[attr_name] = lock_utils.locked(attr_value)
-        return super(ThreadSafeMeta, cls).__new__(cls, name, bases, attrs)
-
-    def __call__(cls, *args, **kwargs):
-        instance = super(ThreadSafeMeta, cls).__call__(*args, **kwargs)
-        if not hasattr(instance, '_lock'):
-            instance._lock = threading.RLock()
-        return instance
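With `ThreadSafeMeta` removed, the thread-safety strategy has moved from implicit coarse-grained locking (every public method wrapped by one shared `RLock`, so even two pure readers serialized) to the explicit read/write sections now used in storage. A simplified before/after sketch (stand-in classes, not taskflow's own):

import threading

from taskflow.utils import lock_utils


class CoarseStorage(object):
    # Before: ThreadSafeMeta-style locking; every method takes the same
    # reentrant lock, so readers block each other needlessly.
    def __init__(self):
        self._lock = threading.RLock()
        self._states = {}

    def get_task_state(self, name):
        with self._lock:
            return self._states.get(name)


class RWStorage(object):
    # After: reader/writer locking; reads run concurrently, writes are
    # exclusive, mirroring how Storage now guards its methods.
    def __init__(self):
        self._lock = lock_utils.ReaderWriterLock()
        self._states = {}

    def get_task_state(self, name):
        with self._lock.read_lock():
            return self._states.get(name)

    def set_task_state(self, name, state):
        with self._lock.write_lock():
            self._states[name] = state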