Merge "Robustify locking in MessageHandlingServer"

This commit is contained in:
Jenkins 2015-11-17 23:37:39 +00:00 committed by Gerrit Code Review
commit 9dbbc1f243
3 changed files with 585 additions and 92 deletions

View File

@ -116,29 +116,6 @@ def fetch_current_thread_functor():
return lambda: threading.current_thread()
class DummyCondition(object):
    """Do-nothing object exposing the method names of a condition variable.

    Every operation is a no-op that returns ``None``. The object can also be
    used in a ``with`` statement; entering and leaving the block delegate to
    `acquire` and `release` respectively.
    """

    def acquire(self):
        """Pretend to take the lock; nothing is actually acquired."""
        return None

    def notify(self):
        """Pretend to wake one waiter; there are never any waiters."""
        return None

    def notify_all(self):
        """Pretend to wake all waiters; there are never any waiters."""
        return None

    def wait(self, timeout=None):
        """Return immediately without blocking; `timeout` is ignored."""
        return None

    def release(self):
        """Pretend to drop the lock; nothing was ever held."""
        return None

    def __enter__(self):
        # Mirrors acquire(), so the value bound by ``with ... as x`` is None.
        return self.acquire()

    def __exit__(self, type, value, traceback):
        # Mirrors release(); the None result never suppresses an exception.
        return self.release()
class DummyLock(object):
def acquire(self):
pass

View File

@ -23,20 +23,25 @@ __all__ = [
'ServerListenError',
]
import functools
import inspect
import logging
import threading
import traceback
from oslo_service import service
from oslo_utils import timeutils
from stevedore import driver
from oslo_messaging._drivers import base as driver_base
from oslo_messaging._i18n import _LW
from oslo_messaging import _utils
from oslo_messaging import exceptions
LOG = logging.getLogger(__name__)
# The default number of seconds of waiting after which we will emit a log
# message
DEFAULT_LOG_AFTER = 30
class MessagingServerError(exceptions.MessagingException):
    """Common parent of the exceptions raised by MessageHandlingServer.

    It adds no behaviour of its own; it exists so server failures can be
    caught with a single ``except`` clause.
    """
@ -62,7 +67,223 @@ class ServerListenError(MessagingServerError):
self.ex = ex
class MessageHandlingServer(service.ServiceBase):
class TaskTimeout(MessagingServerError):
    """Signals that the wait for a task to complete exceeded its timeout."""
class _OrderedTask(object):
"""A task which must be executed in a particular order.
A caller may wait for this task to complete by calling
`wait_for_completion`.
A caller may run this task with `run_once`, which will ensure that however
many times the task is called it only runs once. Simultaneous callers will
block until the running task completes, which means that any caller can be
sure that the task has completed after run_once returns.
"""
INIT = 0 # The task has not yet started
RUNNING = 1 # The task is running somewhere
COMPLETE = 2 # The task has run somewhere
def __init__(self, name):
"""Create a new _OrderedTask.
:param name: The name of this task. Used in log messages.
"""
super(_OrderedTask, self).__init__()
self._name = name
self._cond = threading.Condition()
self._state = self.INIT
def _wait(self, condition, msg, log_after, timeout_timer):
"""Wait while condition() is true. Write a log message if condition()
has not become false within `log_after` seconds. Raise TaskTimeout if
timeout_timer expires while waiting.
"""
log_timer = None
if log_after != 0:
log_timer = timeutils.StopWatch(duration=log_after)
log_timer.start()
while condition():
if log_timer is not None and log_timer.expired():
LOG.warn('Possible hang: %s' % msg)
LOG.debug(''.join(traceback.format_stack()))
# Only log once. After than we wait indefinitely without
# logging.
log_timer = None
if timeout_timer is not None and timeout_timer.expired():
raise TaskTimeout(msg)
timeouts = []
if log_timer is not None:
timeouts.append(log_timer.leftover())
if timeout_timer is not None:
timeouts.append(timeout_timer.leftover())
wait = None
if timeouts:
wait = min(timeouts)
self._cond.wait(wait)
@property
def complete(self):
return self._state == self.COMPLETE
def wait_for_completion(self, caller, log_after, timeout_timer):
"""Wait until this task has completed.
:param caller: The name of the task which is waiting.
:param log_after: Emit a log message if waiting longer than `log_after`
seconds.
:param timeout_timer: Raise TaskTimeout if StopWatch object
`timeout_timer` expires while waiting.
"""
with self._cond:
msg = '%s is waiting for %s to complete' % (caller, self._name)
self._wait(lambda: not self.complete,
msg, log_after, timeout_timer)
def run_once(self, fn, log_after, timeout_timer):
"""Run a task exactly once. If it is currently running in another
thread, wait for it to complete. If it has already run, return
immediately without running it again.
:param fn: The task to run. It must be a callable taking no arguments.
It may optionally return another callable, which also takes
no arguments, which will be executed after completion has
been signaled to other threads.
:param log_after: Emit a log message if waiting longer than `log_after`
seconds.
:param timeout_timer: Raise TaskTimeout if StopWatch object
`timeout_timer` expires while waiting.
"""
with self._cond:
if self._state == self.INIT:
self._state = self.RUNNING
# Note that nothing waits on RUNNING, so no need to notify
# We need to release the condition lock before calling out to
# prevent deadlocks. Reacquire it immediately afterwards.
self._cond.release()
try:
post_fn = fn()
finally:
self._cond.acquire()
self._state = self.COMPLETE
self._cond.notify_all()
if post_fn is not None:
# Release the condition lock before calling out to prevent
# deadlocks. Reacquire it immediately afterwards.
self._cond.release()
try:
post_fn()
finally:
self._cond.acquire()
elif self._state == self.RUNNING:
msg = ('%s is waiting for another thread to complete'
% self._name)
self._wait(lambda: self._state == self.RUNNING,
msg, log_after, timeout_timer)
class _OrderedTaskRunner(object):
    """Mixin giving a class per-method _OrderedTask state tracking.

    At construction it collects the names of all bound methods that carry a
    true `_ordered` attribute and creates one _OrderedTask per name.
    """

    def __init__(self, *args, **kwargs):
        super(_OrderedTaskRunner, self).__init__(*args, **kwargs)

        # Names of the methods on this object flagged with `_ordered`.
        ordered_names = []
        for name, member in inspect.getmembers(self):
            if inspect.ismethod(member) and getattr(member, '_ordered',
                                                    False):
                ordered_names.append(name)
        self._tasks = ordered_names

        self.reset_states()
        self._reset_lock = threading.Lock()

    def reset_states(self):
        """Replace every task state with a fresh _OrderedTask."""
        self._states = dict((task, _OrderedTask(task))
                            for task in self._tasks)

    @staticmethod
    def decorate_ordered(fn, state, after, reset_after):
        """Wrap `fn` so that it runs as the ordered task named `state`.

        :param fn: The undecorated method.
        :param state: Name under which the task state for `fn` is stored.
        :param after: Name of the task that must complete first, or None.
        :param reset_after: Name of the task whose completion triggers a
                            reset of all states on the next call, or None.
        :returns: The wrapper. It accepts the extra keyword arguments
                  `log_after` and `timeout`, which are consumed here and not
                  passed on to `fn`.
        """
        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            # NOTE(mdbooth): Resetting here is ugly and needs external
            # locking to be deterministic across threads. Consider a thread
            # doing server.stop() then server.wait(): if another thread
            # causes a reset in between, the result is not what was
            # intended. It is safe without external locking if the caller
            # instantiates a new object instead.
            with self._reset_lock:
                if reset_after is not None:
                    if self._states[reset_after].complete:
                        # The previous cycle finished: allow a new one.
                        self.reset_states()

            # Keep hold of the states of the epoch we started in. If the
            # epoch ends while we sleep, run_once on these old states will
            # safely do nothing.
            states = self._states

            log_after = kwargs.pop('log_after', DEFAULT_LOG_AFTER)
            timeout = kwargs.pop('timeout', None)

            if timeout is None:
                timeout_timer = None
            else:
                timeout_timer = timeutils.StopWatch(duration=timeout)
                timeout_timer.start()

            # Block until the prerequisite task, if any, has completed.
            if after is not None:
                states[after].wait_for_completion(state,
                                                  log_after, timeout_timer)

            # Run this task, at most once per epoch.
            states[state].run_once(
                functools.partial(fn, self, *args, **kwargs),
                log_after, timeout_timer)
        return wrapper
def ordered(after=None, reset_after=None):
    """Decorator turning a method into an ordered task.

    The decorated method runs at most once however many times it is called.
    Simultaneous callers all wait until the single execution is complete.

    If `after` is given, the method will not run until `after` has
    completed. If `reset_after` is given and that method has completed, all
    task states are reset when this method is called, allowing it to run
    again.

    :param after: Optionally, the name of another `ordered` method. Wait for
                  the completion of `after` before executing this method.
    :param reset_after: Optionally, the name of another `ordered` method.
                        Reset all states when calling this method if
                        `reset_after` has completed.
    """
    def _ordered(fn):
        # Flag the function so _OrderedTaskRunner can find it later.
        fn._ordered = True
        return _OrderedTaskRunner.decorate_ordered(fn, fn.__name__, after,
                                                   reset_after)
    return _ordered
class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
"""Server for handling messages.
Connect a transport to a dispatcher that knows how to process the
@ -94,29 +315,20 @@ class MessageHandlingServer(service.ServiceBase):
self.dispatcher = dispatcher
self.executor = executor
# NOTE(sileht): we use a lock to protect the state change of the
# server, we don't want to call stop until the transport driver
# is fully started. Except for the blocking executor that have
# start() that doesn't return
if self.executor != "blocking":
self._state_cond = threading.Condition()
self._dummy_cond = False
else:
self._state_cond = _utils.DummyCondition()
self._dummy_cond = True
try:
mgr = driver.DriverManager('oslo.messaging.executors',
self.executor)
except RuntimeError as ex:
raise ExecutorLoadFailure(self.executor, ex)
else:
self._executor_cls = mgr.driver
self._executor_obj = None
self._running = False
self._executor_cls = mgr.driver
self._executor_obj = None
self._started = False
super(MessageHandlingServer, self).__init__()
@ordered(reset_after='stop')
def start(self):
"""Start handling incoming messages.
@ -130,25 +342,39 @@ class MessageHandlingServer(service.ServiceBase):
registering a callback with an event loop. Similarly, the executor may
choose to dispatch messages in a new thread, coroutine or simply the
current thread.
:param log_after: Emit a log message if waiting longer than `log_after`
seconds to run this task. If set to zero, no log
message will be emitted. Defaults to 30 seconds.
:type log_after: int
:param timeout: Raise `TaskTimeout` if the task has to wait longer than
`timeout` seconds before executing.
:type timeout: int
"""
if self._executor_obj is not None:
return
with self._state_cond:
if self._executor_obj is not None:
return
try:
listener = self.dispatcher._listen(self.transport)
except driver_base.TransportDriverError as ex:
raise ServerListenError(self.target, ex)
self._executor_obj = self._executor_cls(self.conf, listener,
self.dispatcher)
self._executor_obj.start()
self._running = True
self._state_cond.notify_all()
# Warn that restarting will be deprecated
if self._started:
LOG.warn('Restarting a MessageHandlingServer is inherently racy. '
'It is deprecated, and will become a noop in a future '
'release of oslo.messaging. If you need to restart '
'MessageHandlingServer you should instantiate a new '
'object.')
self._started = True
try:
listener = self.dispatcher._listen(self.transport)
except driver_base.TransportDriverError as ex:
raise ServerListenError(self.target, ex)
executor = self._executor_cls(self.conf, listener, self.dispatcher)
executor.start()
self._executor_obj = executor
if self.executor == 'blocking':
self._executor_obj.execute()
# N.B. This will be executed unlocked and unordered, so
# we can't rely on the value of self._executor_obj when this runs.
# We explicitly pass the local variable.
return lambda: executor.execute()
@ordered(after='start')
def stop(self):
"""Stop handling incoming messages.
@ -156,13 +382,18 @@ class MessageHandlingServer(service.ServiceBase):
the server. However, the server may still be in the process of handling
some messages, and underlying driver resources associated to this
server are still in use. See 'wait' for more details.
"""
with self._state_cond:
if self._executor_obj is not None:
self._running = False
self._executor_obj.stop()
self._state_cond.notify_all()
:param log_after: Emit a log message if waiting longer than `log_after`
seconds to run this task. If set to zero, no log
message will be emitted. Defaults to 30 seconds.
:type log_after: int
:param timeout: Raise `TaskTimeout` if the task has to wait longer than
`timeout` seconds before executing.
:type timeout: int
"""
self._executor_obj.stop()
@ordered(after='stop')
def wait(self):
"""Wait for message processing to complete.
@ -172,38 +403,21 @@ class MessageHandlingServer(service.ServiceBase):
Once it's finished, the underlying driver resources associated to this
server are released (like closing useless network connections).
:param log_after: Emit a log message if waiting longer than `log_after`
seconds to run this task. If set to zero, no log
message will be emitted. Defaults to 30 seconds.
:type log_after: int
:param timeout: Raise `TaskTimeout` if the task has to wait longer than
`timeout` seconds before executing.
:type timeout: int
"""
with self._state_cond:
if self._running:
LOG.warn(_LW("wait() should be called after stop() as it "
"waits for existing messages to finish "
"processing"))
w = timeutils.StopWatch()
w.start()
while self._running:
# NOTE(harlowja): 1.0 seconds was mostly chosen at
# random, but it seems like a reasonable value to
# use to avoid spamming the logs with to much
# information.
self._state_cond.wait(1.0)
if self._running and not self._dummy_cond:
LOG.warn(
_LW("wait() should have been called"
" after stop() as wait() waits for existing"
" messages to finish processing, it has"
" been %0.2f seconds and stop() still has"
" not been called"), w.elapsed())
executor = self._executor_obj
try:
self._executor_obj.wait()
finally:
# Close listener connection after processing all messages
self._executor_obj.listener.cleanup()
self._executor_obj = None
if executor is not None:
# We are the lucky calling thread to wait on the executor to
# actually finish.
try:
executor.wait()
finally:
# Close listener connection after processing all messages
executor.listener.cleanup()
executor = None
def reset(self):
"""Reset service.

View File

@ -13,6 +13,8 @@
# License for the specific language governing permissions and limitations
# under the License.
import eventlet
import time
import threading
from oslo_config import cfg
@ -20,6 +22,7 @@ import testscenarios
import mock
import oslo_messaging
from oslo_messaging import server as server_module
from oslo_messaging.tests import utils as test_utils
load_tests = testscenarios.load_tests_apply_scenarios
@ -528,3 +531,302 @@ class TestMultipleServers(test_utils.BaseTestCase, ServerSetupMixin):
TestMultipleServers.generate_scenarios()
class TestServerLocking(test_utils.BaseTestCase):
def setUp(self):
    """Build a server whose executor class records the calls made to it."""
    super(TestServerLocking, self).setUp(conf=cfg.ConfigOpts())

    # Every FakeExecutor registers itself here when it is constructed.
    created = []

    def _recorder(name):
        # Make a method that appends `name` to the instance's call log.
        def method(self):
            with self._lock:
                self._calls.append(name)
        return method

    class FakeExecutor(object):
        def __init__(self, *args, **kwargs):
            self._lock = threading.Lock()
            self._calls = []
            self.listener = mock.MagicMock()
            created.append(self)

        start = _recorder('start')
        stop = _recorder('stop')
        wait = _recorder('wait')
        execute = _recorder('execute')

    self.executors = created

    self.server = oslo_messaging.MessageHandlingServer(mock.Mock(),
                                                       mock.Mock())
    self.server._executor_cls = FakeExecutor
def test_start_stop_wait(self):
    """start, stop and wait issued in order reach the executor in order."""
    # start() runs in its own greenthread. Its handle is never used, so it
    # is not bound to a name.
    eventlet.spawn(self.server.start)
    self.server.stop()
    self.server.wait()

    self.assertEqual(len(self.executors), 1)
    executor = self.executors[0]
    self.assertEqual(executor._calls,
                     ['start', 'execute', 'stop', 'wait'])
    self.assertTrue(executor.listener.cleanup.called)
def test_reversed_order(self):
    """wait, stop, start issued in reverse are reordered correctly."""
    # The greenthread handles are never used, so they are not bound to
    # names.
    eventlet.spawn(self.server.wait)
    # sleep(0) gives the greenthread just spawned a chance to run first.
    # This is non-deterministic, but there's not a great deal we can do
    # about that.
    eventlet.sleep(0)
    eventlet.spawn(self.server.stop)
    eventlet.sleep(0)
    eventlet.spawn(self.server.start)

    self.server.wait()

    self.assertEqual(len(self.executors), 1)
    executor = self.executors[0]
    self.assertEqual(executor._calls,
                     ['start', 'execute', 'stop', 'wait'])
def test_wait_for_running_task(self):
    """Two simultaneous start() calls: one runs, the other waits for it.

    Only one of the two callers may invoke the underlying executor method,
    and the other must not return until the first has finished.
    """
    start_event = threading.Event()
    finish_event = threading.Event()

    running_event = threading.Event()
    done_event = threading.Event()

    # One-slot holder through which the executor reports which greenthread
    # won the race. It keeps its own name so that reading the result below
    # does not rebind the variable the closure writes through.
    winner = [None]

    class SteppingFakeExecutor(self.server._executor_cls):
        def start(self):
            # Tell the test which thread won the race
            winner[0] = eventlet.getcurrent()
            running_event.set()

            start_event.wait()
            super(SteppingFakeExecutor, self).start()
            done_event.set()

            finish_event.wait()

    self.server._executor_cls = SteppingFakeExecutor

    start1 = eventlet.spawn(self.server.start)
    start2 = eventlet.spawn(self.server.start)

    # Wait until one of the threads starts running
    running_event.wait()
    runner = winner[0]
    # The waiter is whichever greenthread did NOT win the race.
    waiter = start2 if runner == start1 else start1

    waiter_finished = threading.Event()
    waiter.link(lambda _: waiter_finished.set())

    # At this point, runner is running start(), and waiter is waiting for
    # it to complete. runner has not yet logged anything.
    self.assertEqual(1, len(self.executors))
    executor = self.executors[0]

    self.assertEqual(executor._calls, [])
    self.assertFalse(waiter_finished.is_set())

    # Let the runner log the call
    start_event.set()
    done_event.wait()

    # We haven't signalled completion yet, so execute shouldn't have run
    self.assertEqual(executor._calls, ['start'])
    self.assertFalse(waiter_finished.is_set())

    # Let the runner complete
    finish_event.set()
    waiter.wait()
    runner.wait()

    # Check that both threads have finished, start was only called once,
    # and execute ran
    self.assertTrue(waiter_finished.is_set())
    self.assertEqual(executor._calls, ['start', 'execute'])
def test_start_stop_wait_stop_wait(self):
    """Repeating stop/wait after a full cycle has no further effect."""
    self.server.start()
    # The second stop/wait pair is expected to be a noop.
    for _ in range(2):
        self.server.stop()
        self.server.wait()

    self.assertEqual(len(self.executors), 1)
    executor = self.executors[0]
    expected_calls = ['start', 'execute', 'stop', 'wait']
    self.assertEqual(executor._calls, expected_calls)
    self.assertTrue(executor.listener.cleanup.called)
def test_state_wrapping(self):
    """A waiter from an old epoch does nothing once the state has wrapped.

    A thread starts waiting on 'start', the server then goes through a full
    start/stop/wait cycle and is started again, and only then is the waiter
    allowed to continue.
    """
    full_cycle = ['start', 'execute', 'stop', 'wait']
    started_only = ['start', 'execute']

    # Of the callers waiting for 'start' to complete, only the first is
    # held until release_first_waiter is set. Later ones go straight
    # through.
    release_first_waiter = threading.Event()
    first_waiter_parked = threading.Event()

    start_state = self.server._states['start']
    real_wait = start_state.wait_for_completion
    seen = [False]

    def new_wait_for_completion(*args, **kwargs):
        if not seen[0]:
            seen[0] = True
            first_waiter_parked.set()
            release_first_waiter.wait()
        real_wait(*args, **kwargs)

    start_state.wait_for_completion = new_wait_for_completion

    # thread1 will wait for start to complete until we signal it
    thread1 = eventlet.spawn(self.server.stop)
    thread1_finished = threading.Event()
    thread1.link(lambda _: thread1_finished.set())

    self.server.start()
    first_waiter_parked.wait()

    # The server should have started, but stop should not have been called
    self.assertEqual(1, len(self.executors))
    self.assertEqual(self.executors[0]._calls, started_only)
    self.assertFalse(thread1_finished.is_set())

    self.server.stop()
    self.server.wait()

    # We should have gone through all the states, and thread1 should still
    # be waiting
    self.assertEqual(1, len(self.executors))
    self.assertEqual(self.executors[0]._calls, full_cycle)
    self.assertFalse(thread1_finished.is_set())

    # Start again
    self.server.start()

    # We should now record 2 executors
    self.assertEqual(2, len(self.executors))
    self.assertEqual(self.executors[0]._calls, full_cycle)
    self.assertEqual(self.executors[1]._calls, started_only)
    self.assertFalse(thread1_finished.is_set())

    # Allow thread1 to complete
    release_first_waiter.set()
    thread1_finished.wait()

    # thread1 should now have finished, and stop should not have been
    # called again on either the first or second executor
    self.assertEqual(2, len(self.executors))
    self.assertEqual(self.executors[0]._calls, full_cycle)
    self.assertEqual(self.executors[1]._calls, started_only)
    self.assertTrue(thread1_finished.is_set())
@mock.patch.object(server_module, 'DEFAULT_LOG_AFTER', 1)
@mock.patch.object(server_module, 'LOG')
def test_logging(self, mock_log):
    """A warning is logged once a wait exceeds DEFAULT_LOG_AFTER."""
    warned = threading.Event()

    def on_warn(_):
        return warned.set()

    mock_log.warn.side_effect = on_warn

    # stop() without a preceding start() has to wait, so with the default
    # patched to 1 a warning is expected after 1 second.
    stopper = eventlet.spawn(self.server.stop)
    warned.wait()

    # Redundant given that we already waited, but it's nice to assert
    self.assertTrue(mock_log.warn.called)
    stopper.kill()
@mock.patch.object(server_module, 'LOG')
def test_logging_explicit_wait(self, mock_log):
    """A warning is logged once a wait exceeds the log_after argument."""
    warned = threading.Event()

    def on_warn(_):
        return warned.set()

    mock_log.warn.side_effect = on_warn

    # stop() without a preceding start() has to wait, so a warning is
    # expected after the 1 second passed as log_after.
    stopper = eventlet.spawn(self.server.stop, log_after=1)
    warned.wait()

    # Redundant given that we already waited, but it's nice to assert
    self.assertTrue(mock_log.warn.called)
    stopper.kill()
@mock.patch.object(server_module, 'LOG')
def test_logging_with_timeout(self, mock_log):
    """The log_after warning still fires when a timeout is also given."""
    warned = threading.Event()

    def on_warn(_):
        return warned.set()

    mock_log.warn.side_effect = on_warn

    # stop() without a preceding start() has to wait. The warning is due
    # after 1 second, before the 2 second timeout.
    stopper = eventlet.spawn(self.server.stop, log_after=1, timeout=2)
    warned.wait()

    # Redundant given that we already waited, but it's nice to assert
    self.assertTrue(mock_log.warn.called)
    stopper.kill()
def test_timeout_wait(self):
    """stop() with a timeout raises TaskTimeout if start() never ran."""
    # The preceding task is never satisfied, so the wait must time out.
    expected = server_module.TaskTimeout
    self.assertRaises(expected, self.server.stop, timeout=1)
def test_timeout_running(self):
    """Waiting on another thread that is running the same task times out."""
    # Start the server, which will also instantiate an executor
    self.server.start()

    in_slow_stop = threading.Event()

    def slow_stop():
        # Signal that stop has been entered, then take a long time.
        in_slow_stop.set()
        eventlet.sleep(10)

    # Patch the executor's stop method to be very slow
    self.executors[0].stop = slow_stop

    # Call stop in a new thread and wait until it is inside slow_stop
    background_stop = eventlet.spawn(self.server.stop)
    in_slow_stop.wait()

    # A second stop from the main thread, with a timeout, must give up
    expected = server_module.TaskTimeout
    self.assertRaises(expected, self.server.stop, timeout=1)
    background_stop.kill()
@mock.patch.object(server_module, 'LOG')
def test_log_after_zero(self, mock_log):
    """No warning is logged while waiting if the caller gave log_after=0."""
    # Call stop without calling start, so the call waits until the
    # 2 second timeout expires.
    self.assertRaises(server_module.TaskTimeout,
                      self.server.stop, log_after=0, timeout=2)

    # We timed out. Ensure we didn't log anything.
    self.assertFalse(mock_log.warn.called)