Merge "Robustify locking in MessageHandlingServer"

This commit is contained in:
Jenkins 2015-10-28 07:00:59 +00:00 committed by Gerrit Code Review
commit d100988993
3 changed files with 402 additions and 91 deletions

View File

@ -116,29 +116,6 @@ def fetch_current_thread_functor():
return lambda: threading.current_thread()
class DummyCondition(object):
def acquire(self):
pass
def notify(self):
pass
def notify_all(self):
pass
def wait(self, timeout=None):
pass
def release(self):
pass
def __enter__(self):
self.acquire()
def __exit__(self, type, value, traceback):
self.release()
class DummyLock(object):
def acquire(self):
pass

View File

@ -23,16 +23,17 @@ __all__ = [
'ServerListenError',
]
import functools
import inspect
import logging
import threading
import traceback
from oslo_service import service
from oslo_utils import timeutils
from stevedore import driver
from oslo_messaging._drivers import base as driver_base
from oslo_messaging._i18n import _LW
from oslo_messaging import _utils
from oslo_messaging import exceptions
LOG = logging.getLogger(__name__)
@ -62,7 +63,170 @@ class ServerListenError(MessagingServerError):
self.ex = ex
class MessageHandlingServer(service.ServiceBase):
class _OrderedTask(object):
"""A task which must be executed in a particular order.
A caller may wait for this task to complete by calling
`wait_for_completion`.
A caller may run this task with `run_once`, which will ensure that however
many times the task is called it only runs once. Simultaneous callers will
block until the running task completes, which means that any caller can be
sure that the task has completed after run_once returns.
"""
INIT = 0 # The task has not yet started
RUNNING = 1 # The task is running somewhere
COMPLETE = 2 # The task has run somewhere
# We generate a log message if we wait for a lock longer than
# LOG_AFTER_WAIT_SECS seconds
LOG_AFTER_WAIT_SECS = 30
def __init__(self, name):
"""Create a new _OrderedTask.
:param name: The name of this task. Used in log messages.
"""
super(_OrderedTask, self).__init__()
self._name = name
self._cond = threading.Condition()
self._state = self.INIT
def _wait(self, condition, warn_msg):
"""Wait while condition() is true. Write a log message if condition()
has not become false within LOG_AFTER_WAIT_SECS.
"""
with timeutils.StopWatch(duration=self.LOG_AFTER_WAIT_SECS) as w:
logged = False
while condition():
wait = None if logged else w.leftover()
self._cond.wait(wait)
if not logged and w.expired():
LOG.warn(warn_msg)
LOG.debug(''.join(traceback.format_stack()))
# Only log once. After than we wait indefinitely without
# logging.
logged = True
def wait_for_completion(self, caller):
"""Wait until this task has completed.
:param caller: The name of the task which is waiting.
"""
with self._cond:
self._wait(lambda: self._state != self.COMPLETE,
'%s has been waiting for %s to complete for longer '
'than %i seconds'
% (caller, self._name, self.LOG_AFTER_WAIT_SECS))
def run_once(self, fn):
"""Run a task exactly once. If it is currently running in another
thread, wait for it to complete. If it has already run, return
immediately without running it again.
:param fn: The task to run. It must be a callable taking no arguments.
It may optionally return another callable, which also takes
no arguments, which will be executed after completion has
been signaled to other threads.
"""
with self._cond:
if self._state == self.INIT:
self._state = self.RUNNING
# Note that nothing waits on RUNNING, so no need to notify
# We need to release the condition lock before calling out to
# prevent deadlocks. Reacquire it immediately afterwards.
self._cond.release()
try:
post_fn = fn()
finally:
self._cond.acquire()
self._state = self.COMPLETE
self._cond.notify_all()
if post_fn is not None:
# Release the condition lock before calling out to prevent
# deadlocks. Reacquire it immediately afterwards.
self._cond.release()
try:
post_fn()
finally:
self._cond.acquire()
elif self._state == self.RUNNING:
self._wait(lambda: self._state == self.RUNNING,
'%s has been waiting on another thread to complete '
'for longer than %i seconds'
% (self._name, self.LOG_AFTER_WAIT_SECS))
class _OrderedTaskRunner(object):
"""Mixin for a class which executes ordered tasks."""
def __init__(self, *args, **kwargs):
super(_OrderedTaskRunner, self).__init__(*args, **kwargs)
# Get a list of methods on this object which have the _ordered
# attribute
self._tasks = [name
for (name, member) in inspect.getmembers(self)
if inspect.ismethod(member) and
getattr(member, '_ordered', False)]
self.init_task_states()
def init_task_states(self):
# Note that we don't need to lock this. Once created, the _states dict
# is immutable. Get and set are (individually) atomic operations in
# Python, and we only set after the dict is fully created.
self._states = {task: _OrderedTask(task) for task in self._tasks}
@staticmethod
def decorate_ordered(fn, state, after):
@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
# Store the states we started with in case the state wraps on us
# while we're sleeping. We must wait and run_once in the same
# epoch. If the epoch ended while we were sleeping, run_once will
# safely do nothing.
states = self._states
# Wait for the given preceding state to complete
if after is not None:
states[after].wait_for_completion(state)
# Run this state
states[state].run_once(lambda: fn(self, *args, **kwargs))
return wrapper
def ordered(after=None):
"""A method which will be executed as an ordered task. The method will be
called exactly once, however many times it is called. If it is called
multiple times simultaneously it will only be called once, but all callers
will wait until execution is complete.
If `after` is given, this method will not run until `after` has completed.
:param after: Optionally, another method decorated with `ordered`. Wait for
the completion of `after` before executing this method.
"""
if after is not None:
after = after.__name__
def _ordered(fn):
# Set an attribute on the method so we can find it later
setattr(fn, '_ordered', True)
state = fn.__name__
return _OrderedTaskRunner.decorate_ordered(fn, state, after)
return _ordered
class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
"""Server for handling messages.
Connect a transport to a dispatcher that knows how to process the
@ -94,29 +258,18 @@ class MessageHandlingServer(service.ServiceBase):
self.dispatcher = dispatcher
self.executor = executor
# NOTE(sileht): we use a lock to protect the state change of the
# server, we don't want to call stop until the transport driver
# is fully started. Except for the blocking executor that have
# start() that doesn't return
if self.executor != "blocking":
self._state_cond = threading.Condition()
self._dummy_cond = False
else:
self._state_cond = _utils.DummyCondition()
self._dummy_cond = True
try:
mgr = driver.DriverManager('oslo.messaging.executors',
self.executor)
except RuntimeError as ex:
raise ExecutorLoadFailure(self.executor, ex)
else:
self._executor_cls = mgr.driver
self._executor_obj = None
self._running = False
super(MessageHandlingServer, self).__init__()
@ordered()
def start(self):
"""Start handling incoming messages.
@ -131,24 +284,21 @@ class MessageHandlingServer(service.ServiceBase):
choose to dispatch messages in a new thread, coroutine or simply the
current thread.
"""
if self._executor_obj is not None:
return
with self._state_cond:
if self._executor_obj is not None:
return
try:
listener = self.dispatcher._listen(self.transport)
except driver_base.TransportDriverError as ex:
raise ServerListenError(self.target, ex)
self._executor_obj = self._executor_cls(self.conf, listener,
self.dispatcher)
self._executor_obj.start()
self._running = True
self._state_cond.notify_all()
executor = self._executor_cls(self.conf, listener, self.dispatcher)
executor.start()
self._executor_obj = executor
if self.executor == 'blocking':
self._executor_obj.execute()
# N.B. This will be executed unlocked and unordered, so
# we can't rely on the value of self._executor_obj when this runs.
# We explicitly pass the local variable.
return lambda: executor.execute()
@ordered(after=start)
def stop(self):
"""Stop handling incoming messages.
@ -157,12 +307,9 @@ class MessageHandlingServer(service.ServiceBase):
some messages, and underlying driver resources associated to this
server are still in use. See 'wait' for more details.
"""
with self._state_cond:
if self._executor_obj is not None:
self._running = False
self._executor_obj.stop()
self._state_cond.notify_all()
@ordered(after=stop)
def wait(self):
"""Wait for message processing to complete.
@ -173,37 +320,14 @@ class MessageHandlingServer(service.ServiceBase):
Once it's finished, the underlying driver resources associated to this
server are released (like closing useless network connections).
"""
with self._state_cond:
if self._running:
LOG.warn(_LW("wait() should be called after stop() as it "
"waits for existing messages to finish "
"processing"))
w = timeutils.StopWatch()
w.start()
while self._running:
# NOTE(harlowja): 1.0 seconds was mostly chosen at
# random, but it seems like a reasonable value to
# use to avoid spamming the logs with to much
# information.
self._state_cond.wait(1.0)
if self._running and not self._dummy_cond:
LOG.warn(
_LW("wait() should have been called"
" after stop() as wait() waits for existing"
" messages to finish processing, it has"
" been %0.2f seconds and stop() still has"
" not been called"), w.elapsed())
executor = self._executor_obj
self._executor_obj = None
if executor is not None:
# We are the lucky calling thread to wait on the executor to
# actually finish.
try:
executor.wait()
self._executor_obj.wait()
finally:
# Close listener connection after processing all messages
executor.listener.cleanup()
executor = None
self._executor_obj.listener.cleanup()
self._executor_obj = None
self.init_task_states()
def reset(self):
"""Reset service.

View File

@ -13,6 +13,8 @@
# License for the specific language governing permissions and limitations
# under the License.
import eventlet
import time
import threading
from oslo_config import cfg
@ -20,6 +22,7 @@ import testscenarios
import mock
import oslo_messaging
from oslo_messaging import server as server_module
from oslo_messaging.tests import utils as test_utils
load_tests = testscenarios.load_tests_apply_scenarios
@ -528,3 +531,210 @@ class TestMultipleServers(test_utils.BaseTestCase, ServerSetupMixin):
TestMultipleServers.generate_scenarios()
class TestServerLocking(test_utils.BaseTestCase):
def setUp(self):
super(TestServerLocking, self).setUp(conf=cfg.ConfigOpts())
def _logmethod(name):
def method(self):
with self._lock:
self._calls.append(name)
return method
executors = []
class FakeExecutor(object):
def __init__(self, *args, **kwargs):
self._lock = threading.Lock()
self._calls = []
self.listener = mock.MagicMock()
executors.append(self)
start = _logmethod('start')
stop = _logmethod('stop')
wait = _logmethod('wait')
execute = _logmethod('execute')
self.executors = executors
self.server = oslo_messaging.MessageHandlingServer(mock.Mock(),
mock.Mock())
self.server._executor_cls = FakeExecutor
def test_start_stop_wait(self):
# Test a simple execution of start, stop, wait in order
thread = eventlet.spawn(self.server.start)
self.server.stop()
self.server.wait()
self.assertEqual(len(self.executors), 1)
executor = self.executors[0]
self.assertEqual(executor._calls,
['start', 'execute', 'stop', 'wait'])
self.assertTrue(executor.listener.cleanup.called)
def test_reversed_order(self):
# Test that if we call wait, stop, start, these will be correctly
# reordered
wait = eventlet.spawn(self.server.wait)
# This is non-deterministic, but there's not a great deal we can do
# about that
eventlet.sleep(0)
stop = eventlet.spawn(self.server.stop)
eventlet.sleep(0)
start = eventlet.spawn(self.server.start)
self.server.wait()
self.assertEqual(len(self.executors), 1)
executor = self.executors[0]
self.assertEqual(executor._calls,
['start', 'execute', 'stop', 'wait'])
def test_wait_for_running_task(self):
# Test that if 2 threads call a method simultaneously, both will wait,
# but only 1 will call the underlying executor method.
start_event = threading.Event()
finish_event = threading.Event()
running_event = threading.Event()
done_event = threading.Event()
runner = [None]
class SteppingFakeExecutor(self.server._executor_cls):
def start(self):
# Tell the test which thread won the race
runner[0] = eventlet.getcurrent()
running_event.set()
start_event.wait()
super(SteppingFakeExecutor, self).start()
done_event.set()
finish_event.wait()
self.server._executor_cls = SteppingFakeExecutor
start1 = eventlet.spawn(self.server.start)
start2 = eventlet.spawn(self.server.start)
# Wait until one of the threads starts running
running_event.wait()
runner = runner[0]
waiter = start2 if runner == start1 else start2
waiter_finished = threading.Event()
waiter.link(lambda _: waiter_finished.set())
# At this point, runner is running start(), and waiter() is waiting for
# it to complete. runner has not yet logged anything.
self.assertEqual(1, len(self.executors))
executor = self.executors[0]
self.assertEqual(executor._calls, [])
self.assertFalse(waiter_finished.is_set())
# Let the runner log the call
start_event.set()
done_event.wait()
# We haven't signalled completion yet, so execute shouldn't have run
self.assertEqual(executor._calls, ['start'])
self.assertFalse(waiter_finished.is_set())
# Let the runner complete
finish_event.set()
waiter.wait()
runner.wait()
# Check that both threads have finished, start was only called once,
# and execute ran
self.assertTrue(waiter_finished.is_set())
self.assertEqual(executor._calls, ['start', 'execute'])
def test_state_wrapping(self):
# Test that we behave correctly if a thread waits, and the server state
# has wrapped when it it next scheduled
# Ensure that if 2 threads wait for the completion of 'start', the
# first will wait until complete_event is signalled, but the second
# will continue
complete_event = threading.Event()
complete_waiting_callback = threading.Event()
start_state = self.server._states['start']
old_wait_for_completion = start_state.wait_for_completion
waited = [False]
def new_wait_for_completion(*args, **kwargs):
if not waited[0]:
waited[0] = True
complete_waiting_callback.set()
complete_event.wait()
old_wait_for_completion(*args, **kwargs)
start_state.wait_for_completion = new_wait_for_completion
# thread1 will wait for start to complete until we signal it
thread1 = eventlet.spawn(self.server.stop)
thread1_finished = threading.Event()
thread1.link(lambda _: thread1_finished.set())
self.server.start()
complete_waiting_callback.wait()
# The server should have started, but stop should not have been called
self.assertEqual(1, len(self.executors))
self.assertEqual(self.executors[0]._calls, ['start', 'execute'])
self.assertFalse(thread1_finished.is_set())
self.server.stop()
self.server.wait()
# We should have gone through all the states, and thread1 should still
# be waiting
self.assertEqual(1, len(self.executors))
self.assertEqual(self.executors[0]._calls, ['start', 'execute',
'stop', 'wait'])
self.assertFalse(thread1_finished.is_set())
# Start again
self.server.start()
# We should now record 2 executors
self.assertEqual(2, len(self.executors))
self.assertEqual(self.executors[0]._calls, ['start', 'execute',
'stop', 'wait'])
self.assertEqual(self.executors[1]._calls, ['start', 'execute'])
self.assertFalse(thread1_finished.is_set())
# Allow thread1 to complete
complete_event.set()
thread1_finished.wait()
# thread1 should now have finished, and stop should not have been
# called again on either the first or second executor
self.assertEqual(2, len(self.executors))
self.assertEqual(self.executors[0]._calls, ['start', 'execute',
'stop', 'wait'])
self.assertEqual(self.executors[1]._calls, ['start', 'execute'])
self.assertTrue(thread1_finished.is_set())
@mock.patch.object(server_module._OrderedTask,
'LOG_AFTER_WAIT_SECS', 1)
@mock.patch.object(server_module, 'LOG')
def test_timeout_logging(self, mock_log):
# Test that we generate a log message if we wait longer than
# LOG_AFTER_WAIT_SECS
log_event = threading.Event()
mock_log.warn.side_effect = lambda _: log_event.set()
# Call stop without calling start. We should log a wait after 1 second
thread = eventlet.spawn(self.server.stop)
log_event.wait()
# Redundant given that we already waited, but it's nice to assert
self.assertTrue(mock_log.warn.called)
thread.kill()