Merge "Use reader/writer locks in storage"
This commit is contained in:
@@ -191,13 +191,12 @@ class ActionEngine(base.EngineBase):
|
|||||||
|
|
||||||
class SingleThreadedActionEngine(ActionEngine):
|
class SingleThreadedActionEngine(ActionEngine):
|
||||||
"""Engine that runs tasks in serial manner."""
|
"""Engine that runs tasks in serial manner."""
|
||||||
_storage_cls = t_storage.Storage
|
_storage_cls = t_storage.SingleThreadedStorage
|
||||||
|
|
||||||
|
|
||||||
class MultiThreadedActionEngine(ActionEngine):
|
class MultiThreadedActionEngine(ActionEngine):
|
||||||
"""Engine that runs tasks in parallel manner."""
|
"""Engine that runs tasks in parallel manner."""
|
||||||
|
_storage_cls = t_storage.MultiThreadedStorage
|
||||||
_storage_cls = t_storage.ThreadSafeStorage
|
|
||||||
|
|
||||||
def _task_executor_cls(self):
|
def _task_executor_cls(self):
|
||||||
return executor.ParallelTaskExecutor(self._executor)
|
return executor.ParallelTaskExecutor(self._executor)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
import abc
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -25,13 +26,15 @@ from taskflow import exceptions
|
|||||||
from taskflow.openstack.common import uuidutils
|
from taskflow.openstack.common import uuidutils
|
||||||
from taskflow.persistence import logbook
|
from taskflow.persistence import logbook
|
||||||
from taskflow import states
|
from taskflow import states
|
||||||
|
from taskflow.utils import lock_utils
|
||||||
from taskflow.utils import misc
|
from taskflow.utils import misc
|
||||||
from taskflow.utils import threading_utils as tu
|
from taskflow.utils import reflection
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
STATES_WITH_RESULTS = (states.SUCCESS, states.REVERTING, states.FAILURE)
|
STATES_WITH_RESULTS = (states.SUCCESS, states.REVERTING, states.FAILURE)
|
||||||
|
|
||||||
|
|
||||||
|
@six.add_metaclass(abc.ABCMeta)
|
||||||
class Storage(object):
|
class Storage(object):
|
||||||
"""Interface between engines and logbook.
|
"""Interface between engines and logbook.
|
||||||
|
|
||||||
@@ -48,11 +51,14 @@ class Storage(object):
|
|||||||
self._reverse_mapping = {}
|
self._reverse_mapping = {}
|
||||||
self._backend = backend
|
self._backend = backend
|
||||||
self._flowdetail = flow_detail
|
self._flowdetail = flow_detail
|
||||||
|
self._lock = self._lock_cls()
|
||||||
|
|
||||||
# NOTE(imelnikov): failure serialization looses information,
|
# NOTE(imelnikov): failure serialization looses information,
|
||||||
# so we cache failures here, in task name -> misc.Failure mapping.
|
# so we cache failures here, in task name -> misc.Failure mapping.
|
||||||
self._failures = {}
|
self._failures = {}
|
||||||
self._reload_failures()
|
for td in self._flowdetail:
|
||||||
|
if td.failure is not None:
|
||||||
|
self._failures[td.name] = td.failure
|
||||||
|
|
||||||
self._task_name_to_uuid = dict((td.name, td.uuid)
|
self._task_name_to_uuid = dict((td.name, td.uuid)
|
||||||
for td in self._flowdetail)
|
for td in self._flowdetail)
|
||||||
@@ -66,11 +72,20 @@ class Storage(object):
|
|||||||
self._set_result_mapping(injector_td.name,
|
self._set_result_mapping(injector_td.name,
|
||||||
dict((name, name) for name in names))
|
dict((name, name) for name in names))
|
||||||
|
|
||||||
|
@abc.abstractproperty
|
||||||
|
def _lock_cls(self):
|
||||||
|
"""Lock class used to generate reader/writer locks for protecting
|
||||||
|
read/write access to the underlying storage backend and internally
|
||||||
|
mutating operations.
|
||||||
|
"""
|
||||||
|
|
||||||
def _with_connection(self, functor, *args, **kwargs):
|
def _with_connection(self, functor, *args, **kwargs):
|
||||||
# NOTE(harlowja): Activate the given function with a backend
|
# NOTE(harlowja): Activate the given function with a backend
|
||||||
# connection, if a backend is provided in the first place, otherwise
|
# connection, if a backend is provided in the first place, otherwise
|
||||||
# don't call the function.
|
# don't call the function.
|
||||||
if self._backend is None:
|
if self._backend is None:
|
||||||
|
LOG.debug("No backend provided, not calling functor '%s'",
|
||||||
|
reflection.get_callable_name(functor))
|
||||||
return
|
return
|
||||||
with contextlib.closing(self._backend.get_connection()) as conn:
|
with contextlib.closing(self._backend.get_connection()) as conn:
|
||||||
functor(conn, *args, **kwargs)
|
functor(conn, *args, **kwargs)
|
||||||
@@ -85,12 +100,13 @@ class Storage(object):
|
|||||||
Returns uuid for the task details corresponding to the task with
|
Returns uuid for the task details corresponding to the task with
|
||||||
given name.
|
given name.
|
||||||
"""
|
"""
|
||||||
try:
|
with self._lock.write_lock():
|
||||||
task_id = self._task_name_to_uuid[task_name]
|
try:
|
||||||
except KeyError:
|
task_id = self._task_name_to_uuid[task_name]
|
||||||
task_id = uuidutils.generate_uuid()
|
except KeyError:
|
||||||
self._add_task(task_id, task_name, task_version)
|
task_id = uuidutils.generate_uuid()
|
||||||
self._set_result_mapping(task_name, result_mapping)
|
self._add_task(task_id, task_name, task_version)
|
||||||
|
self._set_result_mapping(task_name, result_mapping)
|
||||||
return task_id
|
return task_id
|
||||||
|
|
||||||
def _add_task(self, uuid, task_name, task_version=None):
|
def _add_task(self, uuid, task_name, task_version=None):
|
||||||
@@ -99,27 +115,29 @@ class Storage(object):
|
|||||||
Task becomes known to storage by that name and uuid.
|
Task becomes known to storage by that name and uuid.
|
||||||
Task state is set to PENDING.
|
Task state is set to PENDING.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def save_both(conn, td):
|
||||||
|
"""Saves the flow and the task detail with the same connection."""
|
||||||
|
self._save_flow_detail(conn)
|
||||||
|
self._save_task_detail(conn, td)
|
||||||
|
|
||||||
# TODO(imelnikov): check that task with same uuid or
|
# TODO(imelnikov): check that task with same uuid or
|
||||||
# task name does not exist.
|
# task name does not exist.
|
||||||
td = logbook.TaskDetail(name=task_name, uuid=uuid)
|
td = logbook.TaskDetail(name=task_name, uuid=uuid)
|
||||||
td.state = states.PENDING
|
td.state = states.PENDING
|
||||||
td.version = task_version
|
td.version = task_version
|
||||||
self._flowdetail.add(td)
|
self._flowdetail.add(td)
|
||||||
|
self._with_connection(save_both, td)
|
||||||
def save_both(conn):
|
|
||||||
"""Saves the flow and the task detail with the same connection."""
|
|
||||||
self._save_flow_detail(conn)
|
|
||||||
self._save_task_detail(conn, task_detail=td)
|
|
||||||
|
|
||||||
self._with_connection(save_both)
|
|
||||||
self._task_name_to_uuid[task_name] = uuid
|
self._task_name_to_uuid[task_name] = uuid
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def flow_name(self):
|
def flow_name(self):
|
||||||
|
# This never changes (so no read locking needed).
|
||||||
return self._flowdetail.name
|
return self._flowdetail.name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def flow_uuid(self):
|
def flow_uuid(self):
|
||||||
|
# This never changes (so no read locking needed).
|
||||||
return self._flowdetail.uuid
|
return self._flowdetail.uuid
|
||||||
|
|
||||||
def _save_flow_detail(self, conn):
|
def _save_flow_detail(self, conn):
|
||||||
@@ -130,13 +148,9 @@ class Storage(object):
|
|||||||
|
|
||||||
def _taskdetail_by_name(self, task_name):
|
def _taskdetail_by_name(self, task_name):
|
||||||
try:
|
try:
|
||||||
td = self._flowdetail.find(self._task_name_to_uuid[task_name])
|
return self._flowdetail.find(self._task_name_to_uuid[task_name])
|
||||||
except KeyError:
|
except KeyError:
|
||||||
td = None
|
|
||||||
|
|
||||||
if td is None:
|
|
||||||
raise exceptions.NotFound("Unknown task name: %s" % task_name)
|
raise exceptions.NotFound("Unknown task name: %s" % task_name)
|
||||||
return td
|
|
||||||
|
|
||||||
def _save_task_detail(self, conn, task_detail):
|
def _save_task_detail(self, conn, task_detail):
|
||||||
# NOTE(harlowja): we need to update our contained task detail if
|
# NOTE(harlowja): we need to update our contained task detail if
|
||||||
@@ -146,34 +160,39 @@ class Storage(object):
|
|||||||
|
|
||||||
def get_task_uuid(self, task_name):
|
def get_task_uuid(self, task_name):
|
||||||
"""Get task uuid by given name."""
|
"""Get task uuid by given name."""
|
||||||
td = self._taskdetail_by_name(task_name)
|
with self._lock.read_lock():
|
||||||
return td.uuid
|
td = self._taskdetail_by_name(task_name)
|
||||||
|
return td.uuid
|
||||||
|
|
||||||
def set_task_state(self, task_name, state):
|
def set_task_state(self, task_name, state):
|
||||||
"""Set task state."""
|
"""Set task state."""
|
||||||
td = self._taskdetail_by_name(task_name)
|
with self._lock.write_lock():
|
||||||
td.state = state
|
td = self._taskdetail_by_name(task_name)
|
||||||
self._with_connection(self._save_task_detail, task_detail=td)
|
td.state = state
|
||||||
|
self._with_connection(self._save_task_detail, td)
|
||||||
|
|
||||||
def get_task_state(self, task_name):
|
def get_task_state(self, task_name):
|
||||||
"""Get state of task with given name."""
|
"""Get state of task with given name."""
|
||||||
return self._taskdetail_by_name(task_name).state
|
with self._lock.read_lock():
|
||||||
|
td = self._taskdetail_by_name(task_name)
|
||||||
|
return td.state
|
||||||
|
|
||||||
def get_tasks_states(self, task_names):
|
def get_tasks_states(self, task_names):
|
||||||
return dict((name, self.get_task_state(name))
|
"""Gets all task states."""
|
||||||
for name in task_names)
|
with self._lock.read_lock():
|
||||||
|
return dict((name, self.get_task_state(name))
|
||||||
|
for name in task_names)
|
||||||
|
|
||||||
def update_task_metadata(self, task_name, update_with):
|
def update_task_metadata(self, task_name, update_with):
|
||||||
|
"""Updates a tasks metadata."""
|
||||||
if not update_with:
|
if not update_with:
|
||||||
return
|
return
|
||||||
# NOTE(harlowja): this is a read and then write, not in 1 transaction
|
with self._lock.write_lock():
|
||||||
# so it is entirely possible that we could write over another writes
|
td = self._taskdetail_by_name(task_name)
|
||||||
# metadata update. Maybe add some merging logic later?
|
if not td.meta:
|
||||||
td = self._taskdetail_by_name(task_name)
|
td.meta = {}
|
||||||
if not td.meta:
|
td.meta.update(update_with)
|
||||||
td.meta = {}
|
self._with_connection(self._save_task_detail, td)
|
||||||
td.meta.update(update_with)
|
|
||||||
self._with_connection(self._save_task_detail, task_detail=td)
|
|
||||||
|
|
||||||
def set_task_progress(self, task_name, progress, details=None):
|
def set_task_progress(self, task_name, progress, details=None):
|
||||||
"""Set task progress.
|
"""Set task progress.
|
||||||
@@ -204,10 +223,11 @@ class Storage(object):
|
|||||||
:param task_name: task name
|
:param task_name: task name
|
||||||
:returns: current task progress value
|
:returns: current task progress value
|
||||||
"""
|
"""
|
||||||
meta = self._taskdetail_by_name(task_name).meta
|
with self._lock.read_lock():
|
||||||
if not meta:
|
td = self._taskdetail_by_name(task_name)
|
||||||
return 0.0
|
if not td.meta:
|
||||||
return meta.get('progress', 0.0)
|
return 0.0
|
||||||
|
return td.meta.get('progress', 0.0)
|
||||||
|
|
||||||
def get_task_progress_details(self, task_name):
|
def get_task_progress_details(self, task_name):
|
||||||
"""Get progress details of task with given name.
|
"""Get progress details of task with given name.
|
||||||
@@ -216,10 +236,11 @@ class Storage(object):
|
|||||||
:returns: None if progress_details not defined, else progress_details
|
:returns: None if progress_details not defined, else progress_details
|
||||||
dict
|
dict
|
||||||
"""
|
"""
|
||||||
meta = self._taskdetail_by_name(task_name).meta
|
with self._lock.read_lock():
|
||||||
if not meta:
|
td = self._taskdetail_by_name(task_name)
|
||||||
return None
|
if not td.meta:
|
||||||
return meta.get('progress_details')
|
return None
|
||||||
|
return td.meta.get('progress_details')
|
||||||
|
|
||||||
def _check_all_results_provided(self, task_name, data):
|
def _check_all_results_provided(self, task_name, data):
|
||||||
"""Warn if task did not provide some of results.
|
"""Warn if task did not provide some of results.
|
||||||
@@ -228,69 +249,57 @@ class Storage(object):
|
|||||||
without all needed keys. It may also happen if task returns
|
without all needed keys. It may also happen if task returns
|
||||||
result of wrong type.
|
result of wrong type.
|
||||||
"""
|
"""
|
||||||
result_mapping = self._result_mappings.get(task_name, None)
|
result_mapping = self._result_mappings.get(task_name)
|
||||||
if result_mapping is None:
|
if not result_mapping:
|
||||||
return
|
return
|
||||||
for name, index in six.iteritems(result_mapping):
|
for name, index in six.iteritems(result_mapping):
|
||||||
try:
|
try:
|
||||||
misc.item_from(data, index, name=name)
|
misc.item_from(data, index, name=name)
|
||||||
except exceptions.NotFound:
|
except exceptions.NotFound:
|
||||||
LOG.warning("Task %s did not supply result "
|
LOG.warning("Task %s did not supply result "
|
||||||
"with index %r (name %s)",
|
"with index %r (name %s)", task_name, index, name)
|
||||||
task_name, index, name)
|
|
||||||
|
|
||||||
def save(self, task_name, data, state=states.SUCCESS):
|
def save(self, task_name, data, state=states.SUCCESS):
|
||||||
"""Put result for task with id 'uuid' to storage."""
|
"""Put result for task with id 'uuid' to storage."""
|
||||||
td = self._taskdetail_by_name(task_name)
|
with self._lock.write_lock():
|
||||||
td.state = state
|
td = self._taskdetail_by_name(task_name)
|
||||||
if state == states.FAILURE and isinstance(data, misc.Failure):
|
td.state = state
|
||||||
td.results = None
|
if state == states.FAILURE and isinstance(data, misc.Failure):
|
||||||
td.failure = data
|
td.results = None
|
||||||
self._failures[td.name] = data
|
td.failure = data
|
||||||
else:
|
self._failures[td.name] = data
|
||||||
td.results = data
|
else:
|
||||||
td.failure = None
|
td.results = data
|
||||||
self._check_all_results_provided(td.name, data)
|
td.failure = None
|
||||||
self._with_connection(self._save_task_detail, task_detail=td)
|
self._check_all_results_provided(td.name, data)
|
||||||
|
self._with_connection(self._save_task_detail, td)
|
||||||
def _cache_failure(self, name, fail):
|
|
||||||
"""Ensure that cache has matching failure for task with this name.
|
|
||||||
|
|
||||||
We leave cached version if it matches as it may contain more
|
|
||||||
information. Returns cached failure.
|
|
||||||
"""
|
|
||||||
cached = self._failures.get(name)
|
|
||||||
if fail.matches(cached):
|
|
||||||
return cached
|
|
||||||
self._failures[name] = fail
|
|
||||||
return fail
|
|
||||||
|
|
||||||
def _reload_failures(self):
|
|
||||||
"""Refresh failures cache."""
|
|
||||||
for td in self._flowdetail:
|
|
||||||
if td.failure is not None:
|
|
||||||
self._cache_failure(td.name, td.failure)
|
|
||||||
|
|
||||||
def get(self, task_name):
|
def get(self, task_name):
|
||||||
"""Get result for task with name 'task_name' to storage."""
|
"""Get result for task with name 'task_name' to storage."""
|
||||||
td = self._taskdetail_by_name(task_name)
|
with self._lock.read_lock():
|
||||||
if td.failure is not None:
|
td = self._taskdetail_by_name(task_name)
|
||||||
return self._cache_failure(td.name, td.failure)
|
if td.failure is not None:
|
||||||
if td.state not in STATES_WITH_RESULTS:
|
cached = self._failures.get(task_name)
|
||||||
raise exceptions.NotFound(
|
if td.failure.matches(cached):
|
||||||
"Result for task %s is not known" % task_name)
|
return cached
|
||||||
return td.results
|
return td.failure
|
||||||
|
if td.state not in STATES_WITH_RESULTS:
|
||||||
|
raise exceptions.NotFound("Result for task %s is not known"
|
||||||
|
% task_name)
|
||||||
|
return td.results
|
||||||
|
|
||||||
def get_failures(self):
|
def get_failures(self):
|
||||||
"""Get list of failures that happened with this flow.
|
"""Get list of failures that happened with this flow.
|
||||||
|
|
||||||
No order guaranteed.
|
No order guaranteed.
|
||||||
"""
|
"""
|
||||||
return self._failures.copy()
|
with self._lock.read_lock():
|
||||||
|
return self._failures.copy()
|
||||||
|
|
||||||
def has_failures(self):
|
def has_failures(self):
|
||||||
"""Returns True if there are failed tasks in the storage."""
|
"""Returns True if there are failed tasks in the storage."""
|
||||||
return bool(self._failures)
|
with self._lock.read_lock():
|
||||||
|
return bool(self._failures)
|
||||||
|
|
||||||
def _reset_task(self, td, state):
|
def _reset_task(self, td, state):
|
||||||
if td.name == self.injector_name:
|
if td.name == self.injector_name:
|
||||||
@@ -305,25 +314,28 @@ class Storage(object):
|
|||||||
|
|
||||||
def reset(self, task_name, state=states.PENDING):
|
def reset(self, task_name, state=states.PENDING):
|
||||||
"""Remove result for task with id 'uuid' from storage."""
|
"""Remove result for task with id 'uuid' from storage."""
|
||||||
td = self._taskdetail_by_name(task_name)
|
with self._lock.write_lock():
|
||||||
if self._reset_task(td, state):
|
td = self._taskdetail_by_name(task_name)
|
||||||
self._with_connection(self._save_task_detail, task_detail=td)
|
if self._reset_task(td, state):
|
||||||
|
self._with_connection(self._save_task_detail, td)
|
||||||
|
|
||||||
def reset_tasks(self):
|
def reset_tasks(self):
|
||||||
"""Reset all tasks to PENDING state, removing results.
|
"""Reset all tasks to PENDING state, removing results.
|
||||||
|
|
||||||
Returns list of (name, uuid) tuples for all tasks that were reset.
|
Returns list of (name, uuid) tuples for all tasks that were reset.
|
||||||
"""
|
"""
|
||||||
result = []
|
reset_results = []
|
||||||
|
|
||||||
def do_reset_all(connection):
|
def do_reset_all(connection):
|
||||||
for td in self._flowdetail:
|
for td in self._flowdetail:
|
||||||
if self._reset_task(td, states.PENDING):
|
if self._reset_task(td, states.PENDING):
|
||||||
self._save_task_detail(connection, td)
|
self._save_task_detail(connection, td)
|
||||||
result.append((td.name, td.uuid))
|
reset_results.append((td.name, td.uuid))
|
||||||
|
|
||||||
self._with_connection(do_reset_all)
|
with self._lock.write_lock():
|
||||||
return result
|
self._with_connection(do_reset_all)
|
||||||
|
|
||||||
|
return reset_results
|
||||||
|
|
||||||
def inject(self, pairs):
|
def inject(self, pairs):
|
||||||
"""Add values into storage.
|
"""Add values into storage.
|
||||||
@@ -331,20 +343,20 @@ class Storage(object):
|
|||||||
This method should be used to put flow parameters (requirements that
|
This method should be used to put flow parameters (requirements that
|
||||||
are not satisfied by any task in the flow) into storage.
|
are not satisfied by any task in the flow) into storage.
|
||||||
"""
|
"""
|
||||||
try:
|
with self._lock.write_lock():
|
||||||
injector_td = self._taskdetail_by_name(self.injector_name)
|
try:
|
||||||
except exceptions.NotFound:
|
td = self._taskdetail_by_name(self.injector_name)
|
||||||
injector_uuid = uuidutils.generate_uuid()
|
except exceptions.NotFound:
|
||||||
self._add_task(injector_uuid, self.injector_name)
|
self._add_task(uuidutils.generate_uuid(), self.injector_name)
|
||||||
results = dict(pairs)
|
td = self._taskdetail_by_name(self.injector_name)
|
||||||
else:
|
td.results = dict(pairs)
|
||||||
results = injector_td.results.copy()
|
td.state = states.SUCCESS
|
||||||
results.update(pairs)
|
else:
|
||||||
|
td.results.update(pairs)
|
||||||
self.save(self.injector_name, results)
|
self._with_connection(self._save_task_detail, td)
|
||||||
names = six.iterkeys(results)
|
names = six.iterkeys(td.results)
|
||||||
self._set_result_mapping(self.injector_name,
|
self._set_result_mapping(self.injector_name,
|
||||||
dict((name, name) for name in names))
|
dict((name, name) for name in names))
|
||||||
|
|
||||||
def _set_result_mapping(self, task_name, mapping):
|
def _set_result_mapping(self, task_name, mapping):
|
||||||
"""Set mapping for naming task results.
|
"""Set mapping for naming task results.
|
||||||
@@ -378,50 +390,60 @@ class Storage(object):
|
|||||||
|
|
||||||
def fetch(self, name):
|
def fetch(self, name):
|
||||||
"""Fetch named task result."""
|
"""Fetch named task result."""
|
||||||
try:
|
with self._lock.read_lock():
|
||||||
indexes = self._reverse_mapping[name]
|
|
||||||
except KeyError:
|
|
||||||
raise exceptions.NotFound("Name %r is not mapped" % name)
|
|
||||||
# Return the first one that is found.
|
|
||||||
for task_name, index in reversed(indexes):
|
|
||||||
try:
|
try:
|
||||||
result = self.get(task_name)
|
indexes = self._reverse_mapping[name]
|
||||||
return misc.item_from(result, index, name=name)
|
except KeyError:
|
||||||
except exceptions.NotFound:
|
raise exceptions.NotFound("Name %r is not mapped" % name)
|
||||||
pass
|
# Return the first one that is found.
|
||||||
raise exceptions.NotFound("Unable to find result %r" % name)
|
for (task_name, index) in reversed(indexes):
|
||||||
|
try:
|
||||||
|
result = self.get(task_name)
|
||||||
|
return misc.item_from(result, index, name=name)
|
||||||
|
except exceptions.NotFound:
|
||||||
|
pass
|
||||||
|
raise exceptions.NotFound("Unable to find result %r" % name)
|
||||||
|
|
||||||
def fetch_all(self):
|
def fetch_all(self):
|
||||||
"""Fetch all named task results known so far.
|
"""Fetch all named task results known so far.
|
||||||
|
|
||||||
Should be used for debugging and testing purposes mostly.
|
Should be used for debugging and testing purposes mostly.
|
||||||
"""
|
"""
|
||||||
result = {}
|
with self._lock.read_lock():
|
||||||
for name in self._reverse_mapping:
|
results = {}
|
||||||
try:
|
for name in self._reverse_mapping:
|
||||||
result[name] = self.fetch(name)
|
try:
|
||||||
except exceptions.NotFound:
|
results[name] = self.fetch(name)
|
||||||
pass
|
except exceptions.NotFound:
|
||||||
return result
|
pass
|
||||||
|
return results
|
||||||
|
|
||||||
def fetch_mapped_args(self, args_mapping):
|
def fetch_mapped_args(self, args_mapping):
|
||||||
"""Fetch arguments for the task using arguments mapping."""
|
"""Fetch arguments for the task using arguments mapping."""
|
||||||
return dict((key, self.fetch(name))
|
with self._lock.read_lock():
|
||||||
for key, name in six.iteritems(args_mapping))
|
return dict((key, self.fetch(name))
|
||||||
|
for key, name in six.iteritems(args_mapping))
|
||||||
|
|
||||||
def set_flow_state(self, state):
|
def set_flow_state(self, state):
|
||||||
"""Set flowdetails state and save it."""
|
"""Set flow details state and save it."""
|
||||||
self._flowdetail.state = state
|
with self._lock.write_lock():
|
||||||
self._with_connection(self._save_flow_detail)
|
self._flowdetail.state = state
|
||||||
|
self._with_connection(self._save_flow_detail)
|
||||||
|
|
||||||
def get_flow_state(self):
|
def get_flow_state(self):
|
||||||
"""Set state from flowdetails."""
|
"""Get state from flow details."""
|
||||||
state = self._flowdetail.state
|
with self._lock.read_lock():
|
||||||
if state is None:
|
state = self._flowdetail.state
|
||||||
state = states.PENDING
|
if state is None:
|
||||||
return state
|
state = states.PENDING
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
@six.add_metaclass(tu.ThreadSafeMeta)
|
class MultiThreadedStorage(Storage):
|
||||||
class ThreadSafeStorage(Storage):
|
"""Storage that uses locks to protect against concurrent access."""
|
||||||
pass
|
_lock_cls = lock_utils.ReaderWriterLock
|
||||||
|
|
||||||
|
|
||||||
|
class SingleThreadedStorage(Storage):
|
||||||
|
"""Storage that uses dummy locks when you really don't need locks."""
|
||||||
|
_lock_cls = lock_utils.DummyReaderWriterLock
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import threading
|
||||||
|
|
||||||
import mock
|
import mock
|
||||||
|
|
||||||
@@ -35,10 +36,22 @@ class StorageTest(test.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(StorageTest, self).setUp()
|
super(StorageTest, self).setUp()
|
||||||
self.backend = impl_memory.MemoryBackend(conf={})
|
self.backend = impl_memory.MemoryBackend(conf={})
|
||||||
|
self.thread_count = 50
|
||||||
|
|
||||||
def _get_storage(self):
|
def _run_many_threads(self, threads):
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
def _get_storage(self, threaded=False):
|
||||||
_lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
|
_lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
|
||||||
return storage.Storage(backend=self.backend, flow_detail=flow_detail)
|
if threaded:
|
||||||
|
return storage.MultiThreadedStorage(backend=self.backend,
|
||||||
|
flow_detail=flow_detail)
|
||||||
|
else:
|
||||||
|
return storage.SingleThreadedStorage(backend=self.backend,
|
||||||
|
flow_detail=flow_detail)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
super(StorageTest, self).tearDown()
|
super(StorageTest, self).tearDown()
|
||||||
@@ -48,14 +61,14 @@ class StorageTest(test.TestCase):
|
|||||||
|
|
||||||
def test_non_saving_storage(self):
|
def test_non_saving_storage(self):
|
||||||
_lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
|
_lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
|
||||||
s = storage.Storage(flow_detail=flow_detail) # no backend
|
s = storage.SingleThreadedStorage(flow_detail=flow_detail)
|
||||||
s.ensure_task('my_task')
|
s.ensure_task('my_task')
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
uuidutils.is_uuid_like(s.get_task_uuid('my_task')))
|
uuidutils.is_uuid_like(s.get_task_uuid('my_task')))
|
||||||
|
|
||||||
def test_flow_name_and_uuid(self):
|
def test_flow_name_and_uuid(self):
|
||||||
fd = logbook.FlowDetail(name='test-fd', uuid='aaaa')
|
fd = logbook.FlowDetail(name='test-fd', uuid='aaaa')
|
||||||
s = storage.Storage(flow_detail=fd)
|
s = storage.SingleThreadedStorage(flow_detail=fd)
|
||||||
self.assertEqual(s.flow_name, 'test-fd')
|
self.assertEqual(s.flow_name, 'test-fd')
|
||||||
self.assertEqual(s.flow_uuid, 'aaaa')
|
self.assertEqual(s.flow_uuid, 'aaaa')
|
||||||
|
|
||||||
@@ -79,7 +92,8 @@ class StorageTest(test.TestCase):
|
|||||||
|
|
||||||
def test_ensure_task_fd(self):
|
def test_ensure_task_fd(self):
|
||||||
_lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
|
_lb, flow_detail = p_utils.temporary_flow_detail(self.backend)
|
||||||
s = storage.Storage(backend=self.backend, flow_detail=flow_detail)
|
s = storage.SingleThreadedStorage(backend=self.backend,
|
||||||
|
flow_detail=flow_detail)
|
||||||
s.ensure_task('my task', '3.11')
|
s.ensure_task('my task', '3.11')
|
||||||
td = flow_detail.find(s.get_task_uuid('my task'))
|
td = flow_detail.find(s.get_task_uuid('my task'))
|
||||||
self.assertIsNotNone(td)
|
self.assertIsNotNone(td)
|
||||||
@@ -92,7 +106,8 @@ class StorageTest(test.TestCase):
|
|||||||
td = logbook.TaskDetail(name='my_task', uuid='42')
|
td = logbook.TaskDetail(name='my_task', uuid='42')
|
||||||
flow_detail.add(td)
|
flow_detail.add(td)
|
||||||
|
|
||||||
s = storage.Storage(backend=self.backend, flow_detail=flow_detail)
|
s = storage.SingleThreadedStorage(backend=self.backend,
|
||||||
|
flow_detail=flow_detail)
|
||||||
self.assertEqual('42', s.get_task_uuid('my_task'))
|
self.assertEqual('42', s.get_task_uuid('my_task'))
|
||||||
|
|
||||||
def test_ensure_existing_task(self):
|
def test_ensure_existing_task(self):
|
||||||
@@ -100,7 +115,8 @@ class StorageTest(test.TestCase):
|
|||||||
td = logbook.TaskDetail(name='my_task', uuid='42')
|
td = logbook.TaskDetail(name='my_task', uuid='42')
|
||||||
flow_detail.add(td)
|
flow_detail.add(td)
|
||||||
|
|
||||||
s = storage.Storage(backend=self.backend, flow_detail=flow_detail)
|
s = storage.SingleThreadedStorage(backend=self.backend,
|
||||||
|
flow_detail=flow_detail)
|
||||||
s.ensure_task('my_task')
|
s.ensure_task('my_task')
|
||||||
self.assertEqual('42', s.get_task_uuid('my_task'))
|
self.assertEqual('42', s.get_task_uuid('my_task'))
|
||||||
|
|
||||||
@@ -147,7 +163,8 @@ class StorageTest(test.TestCase):
|
|||||||
s.ensure_task('my task')
|
s.ensure_task('my task')
|
||||||
s.save('my task', fail, states.FAILURE)
|
s.save('my task', fail, states.FAILURE)
|
||||||
|
|
||||||
s2 = storage.Storage(backend=self.backend, flow_detail=s._flowdetail)
|
s2 = storage.SingleThreadedStorage(backend=self.backend,
|
||||||
|
flow_detail=s._flowdetail)
|
||||||
self.assertIs(s2.has_failures(), True)
|
self.assertIs(s2.has_failures(), True)
|
||||||
self.assertEqual(s2.get_failures(), {'my task': fail})
|
self.assertEqual(s2.get_failures(), {'my task': fail})
|
||||||
self.assertEqual(s2.get('my task'), fail)
|
self.assertEqual(s2.get('my task'), fail)
|
||||||
@@ -305,13 +322,71 @@ class StorageTest(test.TestCase):
|
|||||||
})
|
})
|
||||||
# imagine we are resuming, so we need to make new
|
# imagine we are resuming, so we need to make new
|
||||||
# storage from same flow details
|
# storage from same flow details
|
||||||
s2 = storage.Storage(s._flowdetail, backend=self.backend)
|
s2 = storage.SingleThreadedStorage(s._flowdetail, backend=self.backend)
|
||||||
# injected data should still be there:
|
# injected data should still be there:
|
||||||
self.assertEqual(s2.fetch_all(), {
|
self.assertEqual(s2.fetch_all(), {
|
||||||
'foo': 'bar',
|
'foo': 'bar',
|
||||||
'spam': 'eggs',
|
'spam': 'eggs',
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def test_many_thread_ensure_same_task(self):
|
||||||
|
s = self._get_storage(threaded=True)
|
||||||
|
|
||||||
|
def ensure_my_task():
|
||||||
|
s.ensure_task('my_task', result_mapping={})
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for i in range(0, self.thread_count):
|
||||||
|
threads.append(threading.Thread(target=ensure_my_task))
|
||||||
|
self._run_many_threads(threads)
|
||||||
|
|
||||||
|
# Only one task should have been made, no more.
|
||||||
|
self.assertEqual(1, len(s._flowdetail))
|
||||||
|
|
||||||
|
def test_many_thread_one_reset(self):
|
||||||
|
s = self._get_storage(threaded=True)
|
||||||
|
s.ensure_task('a')
|
||||||
|
s.set_task_state('a', states.SUSPENDED)
|
||||||
|
s.ensure_task('b')
|
||||||
|
s.set_task_state('b', states.SUSPENDED)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
result_lock = threading.Lock()
|
||||||
|
|
||||||
|
def reset_all():
|
||||||
|
r = s.reset_tasks()
|
||||||
|
with result_lock:
|
||||||
|
results.append(r)
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for i in range(0, self.thread_count):
|
||||||
|
threads.append(threading.Thread(target=reset_all))
|
||||||
|
|
||||||
|
self._run_many_threads(threads)
|
||||||
|
|
||||||
|
# Only one thread should have actually reset (not anymore)
|
||||||
|
results = [r for r in results if len(r)]
|
||||||
|
self.assertEqual(1, len(results))
|
||||||
|
self.assertEqual(['a', 'b'], sorted([a[0] for a in results[0]]))
|
||||||
|
|
||||||
|
def test_many_thread_inject(self):
|
||||||
|
s = self._get_storage(threaded=True)
|
||||||
|
|
||||||
|
def inject_values(values):
|
||||||
|
s.inject(values)
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for i in range(0, self.thread_count):
|
||||||
|
values = {
|
||||||
|
str(i): str(i),
|
||||||
|
}
|
||||||
|
threads.append(threading.Thread(target=inject_values,
|
||||||
|
args=[values]))
|
||||||
|
|
||||||
|
self._run_many_threads(threads)
|
||||||
|
self.assertEqual(self.thread_count, len(s.fetch_all()))
|
||||||
|
self.assertEqual(1, len(s._flowdetail))
|
||||||
|
|
||||||
def test_fetch_meapped_args(self):
|
def test_fetch_meapped_args(self):
|
||||||
s = self._get_storage()
|
s = self._get_storage()
|
||||||
s.inject({'foo': 'bar', 'spam': 'eggs'})
|
s.inject({'foo': 'bar', 'spam': 'eggs'})
|
||||||
@@ -357,7 +432,7 @@ class StorageTest(test.TestCase):
|
|||||||
fd.state = states.FAILURE
|
fd.state = states.FAILURE
|
||||||
with contextlib.closing(self.backend.get_connection()) as conn:
|
with contextlib.closing(self.backend.get_connection()) as conn:
|
||||||
fd.update(conn.update_flow_details(fd))
|
fd.update(conn.update_flow_details(fd))
|
||||||
s = storage.Storage(flow_detail=fd, backend=self.backend)
|
s = storage.SingleThreadedStorage(flow_detail=fd, backend=self.backend)
|
||||||
self.assertEqual(s.get_flow_state(), states.FAILURE)
|
self.assertEqual(s.get_flow_state(), states.FAILURE)
|
||||||
|
|
||||||
def test_set_and_get_flow_state(self):
|
def test_set_and_get_flow_state(self):
|
||||||
|
|||||||
235
taskflow/tests/unit/test_utils_lock_utils.py
Normal file
235
taskflow/tests/unit/test_utils_lock_utils.py
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||||
|
|
||||||
|
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
# not use this file except in compliance with the License. You may obtain
|
||||||
|
# a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# License for the specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import time
|
||||||
|
|
||||||
|
from concurrent import futures
|
||||||
|
|
||||||
|
from taskflow import test
|
||||||
|
from taskflow.utils import lock_utils
|
||||||
|
|
||||||
|
|
||||||
|
def _find_overlaps(times, start, end):
|
||||||
|
overlaps = 0
|
||||||
|
for (s, e) in times:
|
||||||
|
if s >= start and e <= end:
|
||||||
|
overlaps += 1
|
||||||
|
return overlaps
|
||||||
|
|
||||||
|
|
||||||
|
def _spawn_variation(readers, writers, max_workers=None):
|
||||||
|
start_stops = collections.deque()
|
||||||
|
lock = lock_utils.ReaderWriterLock()
|
||||||
|
|
||||||
|
def read_func():
|
||||||
|
with lock.read_lock():
|
||||||
|
start_stops.append(('r', time.time(), time.time()))
|
||||||
|
|
||||||
|
def write_func():
|
||||||
|
with lock.write_lock():
|
||||||
|
start_stops.append(('w', time.time(), time.time()))
|
||||||
|
|
||||||
|
if max_workers is None:
|
||||||
|
max_workers = max(0, readers) + max(0, writers)
|
||||||
|
if max_workers > 0:
|
||||||
|
with futures.ThreadPoolExecutor(max_workers=max_workers) as e:
|
||||||
|
for i in range(0, readers):
|
||||||
|
e.submit(read_func)
|
||||||
|
for i in range(0, writers):
|
||||||
|
e.submit(write_func)
|
||||||
|
|
||||||
|
writer_times = []
|
||||||
|
reader_times = []
|
||||||
|
for (t, start, stop) in list(start_stops):
|
||||||
|
if t == 'w':
|
||||||
|
writer_times.append((start, stop))
|
||||||
|
else:
|
||||||
|
reader_times.append((start, stop))
|
||||||
|
return (writer_times, reader_times)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadWriteLockTest(test.TestCase):
|
||||||
|
def test_writer_abort(self):
|
||||||
|
lock = lock_utils.ReaderWriterLock()
|
||||||
|
self.assertFalse(lock.owner)
|
||||||
|
|
||||||
|
def blow_up():
|
||||||
|
with lock.write_lock():
|
||||||
|
self.assertEqual(lock.WRITER, lock.owner)
|
||||||
|
raise RuntimeError("Broken")
|
||||||
|
|
||||||
|
self.assertRaises(RuntimeError, blow_up)
|
||||||
|
self.assertFalse(lock.owner)
|
||||||
|
|
||||||
|
def test_reader_abort(self):
|
||||||
|
lock = lock_utils.ReaderWriterLock()
|
||||||
|
self.assertFalse(lock.owner)
|
||||||
|
|
||||||
|
def blow_up():
|
||||||
|
with lock.read_lock():
|
||||||
|
self.assertEqual(lock.READER, lock.owner)
|
||||||
|
raise RuntimeError("Broken")
|
||||||
|
|
||||||
|
self.assertRaises(RuntimeError, blow_up)
|
||||||
|
self.assertFalse(lock.owner)
|
||||||
|
|
||||||
|
def test_reader_chaotic(self):
|
||||||
|
lock = lock_utils.ReaderWriterLock()
|
||||||
|
activated = collections.deque()
|
||||||
|
|
||||||
|
def chaotic_reader(blow_up):
|
||||||
|
with lock.read_lock():
|
||||||
|
if blow_up:
|
||||||
|
raise RuntimeError("Broken")
|
||||||
|
else:
|
||||||
|
activated.append(lock.owner)
|
||||||
|
|
||||||
|
def happy_writer():
|
||||||
|
with lock.write_lock():
|
||||||
|
activated.append(lock.owner)
|
||||||
|
|
||||||
|
with futures.ThreadPoolExecutor(max_workers=20) as e:
|
||||||
|
for i in range(0, 20):
|
||||||
|
if i % 2 == 0:
|
||||||
|
e.submit(chaotic_reader, blow_up=bool(i % 4 == 0))
|
||||||
|
else:
|
||||||
|
e.submit(happy_writer)
|
||||||
|
|
||||||
|
writers = [a for a in activated if a == 'w']
|
||||||
|
readers = [a for a in activated if a == 'r']
|
||||||
|
self.assertEqual(10, len(writers))
|
||||||
|
self.assertEqual(5, len(readers))
|
||||||
|
|
||||||
|
def test_writer_chaotic(self):
|
||||||
|
lock = lock_utils.ReaderWriterLock()
|
||||||
|
activated = collections.deque()
|
||||||
|
|
||||||
|
def chaotic_writer(blow_up):
|
||||||
|
with lock.write_lock():
|
||||||
|
if blow_up:
|
||||||
|
raise RuntimeError("Broken")
|
||||||
|
else:
|
||||||
|
activated.append(lock.owner)
|
||||||
|
|
||||||
|
def happy_reader():
|
||||||
|
with lock.read_lock():
|
||||||
|
activated.append(lock.owner)
|
||||||
|
|
||||||
|
with futures.ThreadPoolExecutor(max_workers=20) as e:
|
||||||
|
for i in range(0, 20):
|
||||||
|
if i % 2 == 0:
|
||||||
|
e.submit(chaotic_writer, blow_up=bool(i % 4 == 0))
|
||||||
|
else:
|
||||||
|
e.submit(happy_reader)
|
||||||
|
|
||||||
|
writers = [a for a in activated if a == 'w']
|
||||||
|
readers = [a for a in activated if a == 'r']
|
||||||
|
self.assertEqual(5, len(writers))
|
||||||
|
self.assertEqual(10, len(readers))
|
||||||
|
|
||||||
|
def test_single_reader_writer(self):
|
||||||
|
results = []
|
||||||
|
lock = lock_utils.ReaderWriterLock()
|
||||||
|
with lock.read_lock():
|
||||||
|
self.assertTrue(lock.is_reader())
|
||||||
|
self.assertEqual(0, len(results))
|
||||||
|
with lock.write_lock():
|
||||||
|
results.append(1)
|
||||||
|
self.assertTrue(lock.is_writer())
|
||||||
|
with lock.read_lock():
|
||||||
|
self.assertTrue(lock.is_reader())
|
||||||
|
self.assertEqual(1, len(results))
|
||||||
|
self.assertFalse(lock.is_reader())
|
||||||
|
self.assertFalse(lock.is_writer())
|
||||||
|
|
||||||
|
def test_reader_to_writer(self):
|
||||||
|
lock = lock_utils.ReaderWriterLock()
|
||||||
|
|
||||||
|
def writer_func():
|
||||||
|
with lock.write_lock():
|
||||||
|
pass
|
||||||
|
|
||||||
|
with lock.read_lock():
|
||||||
|
self.assertRaises(RuntimeError, writer_func)
|
||||||
|
self.assertFalse(lock.is_writer())
|
||||||
|
|
||||||
|
self.assertFalse(lock.is_reader())
|
||||||
|
self.assertFalse(lock.is_writer())
|
||||||
|
|
||||||
|
def test_writer_to_reader(self):
|
||||||
|
lock = lock_utils.ReaderWriterLock()
|
||||||
|
|
||||||
|
def reader_func():
|
||||||
|
with lock.read_lock():
|
||||||
|
pass
|
||||||
|
|
||||||
|
with lock.write_lock():
|
||||||
|
self.assertRaises(RuntimeError, reader_func)
|
||||||
|
self.assertFalse(lock.is_reader())
|
||||||
|
|
||||||
|
self.assertFalse(lock.is_reader())
|
||||||
|
self.assertFalse(lock.is_writer())
|
||||||
|
|
||||||
|
def test_double_writer(self):
|
||||||
|
lock = lock_utils.ReaderWriterLock()
|
||||||
|
with lock.write_lock():
|
||||||
|
self.assertFalse(lock.is_reader())
|
||||||
|
self.assertTrue(lock.is_writer())
|
||||||
|
with lock.write_lock():
|
||||||
|
self.assertTrue(lock.is_writer())
|
||||||
|
self.assertTrue(lock.is_writer())
|
||||||
|
|
||||||
|
self.assertFalse(lock.is_reader())
|
||||||
|
self.assertFalse(lock.is_writer())
|
||||||
|
|
||||||
|
def test_double_reader(self):
|
||||||
|
lock = lock_utils.ReaderWriterLock()
|
||||||
|
with lock.read_lock():
|
||||||
|
self.assertTrue(lock.is_reader())
|
||||||
|
self.assertFalse(lock.is_writer())
|
||||||
|
with lock.read_lock():
|
||||||
|
self.assertTrue(lock.is_reader())
|
||||||
|
self.assertTrue(lock.is_reader())
|
||||||
|
|
||||||
|
self.assertFalse(lock.is_reader())
|
||||||
|
self.assertFalse(lock.is_writer())
|
||||||
|
|
||||||
|
def test_multi_reader_multi_writer(self):
|
||||||
|
writer_times, reader_times = _spawn_variation(10, 10)
|
||||||
|
self.assertEqual(10, len(writer_times))
|
||||||
|
self.assertEqual(10, len(reader_times))
|
||||||
|
for (start, stop) in writer_times:
|
||||||
|
self.assertEqual(0, _find_overlaps(reader_times, start, stop))
|
||||||
|
self.assertEqual(1, _find_overlaps(writer_times, start, stop))
|
||||||
|
for (start, stop) in reader_times:
|
||||||
|
self.assertEqual(0, _find_overlaps(writer_times, start, stop))
|
||||||
|
|
||||||
|
def test_multi_reader_single_writer(self):
|
||||||
|
writer_times, reader_times = _spawn_variation(9, 1)
|
||||||
|
self.assertEqual(1, len(writer_times))
|
||||||
|
self.assertEqual(9, len(reader_times))
|
||||||
|
start, stop = writer_times[0]
|
||||||
|
self.assertEqual(0, _find_overlaps(reader_times, start, stop))
|
||||||
|
|
||||||
|
def test_multi_writer(self):
|
||||||
|
writer_times, reader_times = _spawn_variation(0, 10)
|
||||||
|
self.assertEqual(10, len(writer_times))
|
||||||
|
self.assertEqual(0, len(reader_times))
|
||||||
|
for (start, stop) in writer_times:
|
||||||
|
self.assertEqual(1, _find_overlaps(writer_times, start, stop))
|
||||||
@@ -21,6 +21,8 @@
|
|||||||
# pulls in oslo.cfg) and is reduced to only what taskflow currently wants to
|
# pulls in oslo.cfg) and is reduced to only what taskflow currently wants to
|
||||||
# use from that code.
|
# use from that code.
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import contextlib
|
||||||
import errno
|
import errno
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -28,6 +30,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from taskflow.utils import misc
|
from taskflow.utils import misc
|
||||||
|
from taskflow.utils import threading_utils as tu
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -62,6 +65,163 @@ def locked(*args, **kwargs):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class ReaderWriterLock(object):
|
||||||
|
"""A reader/writer lock.
|
||||||
|
|
||||||
|
This lock allows for simultaneous readers to exist but only one writer
|
||||||
|
to exist for use-cases where it is useful to have such types of locks.
|
||||||
|
|
||||||
|
Currently a reader can not escalate its read lock to a write lock and
|
||||||
|
a writer can not acquire a read lock while it owns or is waiting on
|
||||||
|
the write lock.
|
||||||
|
|
||||||
|
In the future these restrictions may be relaxed.
|
||||||
|
"""
|
||||||
|
WRITER = 'w'
|
||||||
|
READER = 'r'
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._writer = None
|
||||||
|
self._pending_writers = collections.deque()
|
||||||
|
self._readers = collections.deque()
|
||||||
|
self._cond = threading.Condition()
|
||||||
|
|
||||||
|
def is_writer(self, check_pending=True):
|
||||||
|
"""Returns if the caller is the active writer or a pending writer."""
|
||||||
|
self._cond.acquire()
|
||||||
|
try:
|
||||||
|
me = tu.get_ident()
|
||||||
|
if self._writer is not None and self._writer == me:
|
||||||
|
return True
|
||||||
|
if check_pending:
|
||||||
|
return me in self._pending_writers
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
self._cond.release()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def owner(self):
|
||||||
|
"""Returns whether the lock is locked by a writer or reader."""
|
||||||
|
self._cond.acquire()
|
||||||
|
try:
|
||||||
|
if self._writer is not None:
|
||||||
|
return self.WRITER
|
||||||
|
if self._readers:
|
||||||
|
return self.READER
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
self._cond.release()
|
||||||
|
|
||||||
|
def is_reader(self):
|
||||||
|
"""Returns if the caller is one of the readers."""
|
||||||
|
self._cond.acquire()
|
||||||
|
try:
|
||||||
|
return tu.get_ident() in self._readers
|
||||||
|
finally:
|
||||||
|
self._cond.release()
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def read_lock(self):
|
||||||
|
"""Grants a read lock.
|
||||||
|
|
||||||
|
Will wait until no active or pending writers.
|
||||||
|
|
||||||
|
Raises a RuntimeError if an active or pending writer tries to acquire
|
||||||
|
a read lock.
|
||||||
|
"""
|
||||||
|
me = tu.get_ident()
|
||||||
|
if self.is_writer():
|
||||||
|
raise RuntimeError("Writer %s can not acquire a read lock"
|
||||||
|
" while holding/waiting for the write lock"
|
||||||
|
% me)
|
||||||
|
self._cond.acquire()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# No active or pending writers; we are good to become a reader.
|
||||||
|
if self._writer is None and len(self._pending_writers) == 0:
|
||||||
|
self._readers.append(me)
|
||||||
|
break
|
||||||
|
# Some writers; guess we have to wait.
|
||||||
|
self._cond.wait()
|
||||||
|
finally:
|
||||||
|
self._cond.release()
|
||||||
|
try:
|
||||||
|
yield self
|
||||||
|
finally:
|
||||||
|
# I am no longer a reader, remove *one* occurrence of myself.
|
||||||
|
# If the current thread acquired two read locks, then it will
|
||||||
|
# still have to remove that other read lock; this allows for
|
||||||
|
# basic reentrancy to be possible.
|
||||||
|
self._cond.acquire()
|
||||||
|
try:
|
||||||
|
self._readers.remove(me)
|
||||||
|
self._cond.notify_all()
|
||||||
|
finally:
|
||||||
|
self._cond.release()
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def write_lock(self):
|
||||||
|
"""Grants a write lock.
|
||||||
|
|
||||||
|
Will wait until no active readers. Blocks readers after acquiring.
|
||||||
|
|
||||||
|
Raises a RuntimeError if an active reader attempts to acquire a lock.
|
||||||
|
"""
|
||||||
|
me = tu.get_ident()
|
||||||
|
if self.is_reader():
|
||||||
|
raise RuntimeError("Reader %s to writer privilege"
|
||||||
|
" escalation not allowed" % me)
|
||||||
|
if self.is_writer(check_pending=False):
|
||||||
|
# Already the writer; this allows for basic reentrancy.
|
||||||
|
yield self
|
||||||
|
else:
|
||||||
|
self._cond.acquire()
|
||||||
|
try:
|
||||||
|
self._pending_writers.append(me)
|
||||||
|
while True:
|
||||||
|
# No readers, and no active writer, am I next??
|
||||||
|
if len(self._readers) == 0 and self._writer is None:
|
||||||
|
if self._pending_writers[0] == me:
|
||||||
|
self._writer = self._pending_writers.popleft()
|
||||||
|
break
|
||||||
|
self._cond.wait()
|
||||||
|
finally:
|
||||||
|
self._cond.release()
|
||||||
|
try:
|
||||||
|
yield self
|
||||||
|
finally:
|
||||||
|
self._cond.acquire()
|
||||||
|
try:
|
||||||
|
self._writer = None
|
||||||
|
self._cond.notify_all()
|
||||||
|
finally:
|
||||||
|
self._cond.release()
|
||||||
|
|
||||||
|
|
||||||
|
class DummyReaderWriterLock(object):
|
||||||
|
"""A dummy reader/writer lock that doesn't lock anything but provides same
|
||||||
|
functions as a normal reader/writer lock class.
|
||||||
|
"""
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def write_lock(self):
|
||||||
|
yield self
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def read_lock(self):
|
||||||
|
yield self
|
||||||
|
|
||||||
|
@property
|
||||||
|
def owner(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_reader(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def is_writer(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class MultiLock(object):
|
class MultiLock(object):
|
||||||
"""A class which can attempt to obtain many locks at once and release
|
"""A class which can attempt to obtain many locks at once and release
|
||||||
said locks when exiting.
|
said locks when exiting.
|
||||||
|
|||||||
@@ -16,16 +16,15 @@
|
|||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import logging
|
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import threading
|
|
||||||
import types
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from taskflow.utils import lock_utils
|
if six.PY2:
|
||||||
|
from thread import get_ident # noqa
|
||||||
LOG = logging.getLogger(__name__)
|
else:
|
||||||
|
# In python3+ the get_ident call moved (whhhy??)
|
||||||
|
from threading import get_ident # noqa
|
||||||
|
|
||||||
|
|
||||||
def get_optimal_thread_count():
|
def get_optimal_thread_count():
|
||||||
@@ -37,20 +36,3 @@ def get_optimal_thread_count():
|
|||||||
# just setup two threads since its hard to know what else we
|
# just setup two threads since its hard to know what else we
|
||||||
# should do in this situation.
|
# should do in this situation.
|
||||||
return 2
|
return 2
|
||||||
|
|
||||||
|
|
||||||
class ThreadSafeMeta(type):
|
|
||||||
"""Metaclass that adds locking to all pubic methods of a class."""
|
|
||||||
|
|
||||||
def __new__(cls, name, bases, attrs):
|
|
||||||
for attr_name, attr_value in six.iteritems(attrs):
|
|
||||||
if isinstance(attr_value, types.FunctionType):
|
|
||||||
if attr_name[0] != '_':
|
|
||||||
attrs[attr_name] = lock_utils.locked(attr_value)
|
|
||||||
return super(ThreadSafeMeta, cls).__new__(cls, name, bases, attrs)
|
|
||||||
|
|
||||||
def __call__(cls, *args, **kwargs):
|
|
||||||
instance = super(ThreadSafeMeta, cls).__call__(*args, **kwargs)
|
|
||||||
if not hasattr(instance, '_lock'):
|
|
||||||
instance._lock = threading.RLock()
|
|
||||||
return instance
|
|
||||||
|
|||||||
Reference in New Issue
Block a user