Merge "Robustify locking in MessageHandlingServer"
This commit is contained in:
commit
d100988993
@ -116,29 +116,6 @@ def fetch_current_thread_functor():
|
||||
return lambda: threading.current_thread()
|
||||
|
||||
|
||||
class DummyCondition(object):
|
||||
def acquire(self):
|
||||
pass
|
||||
|
||||
def notify(self):
|
||||
pass
|
||||
|
||||
def notify_all(self):
|
||||
pass
|
||||
|
||||
def wait(self, timeout=None):
|
||||
pass
|
||||
|
||||
def release(self):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
self.acquire()
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.release()
|
||||
|
||||
|
||||
class DummyLock(object):
|
||||
def acquire(self):
|
||||
pass
|
||||
|
@ -23,16 +23,17 @@ __all__ = [
|
||||
'ServerListenError',
|
||||
]
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
from oslo_service import service
|
||||
from oslo_utils import timeutils
|
||||
from stevedore import driver
|
||||
|
||||
from oslo_messaging._drivers import base as driver_base
|
||||
from oslo_messaging._i18n import _LW
|
||||
from oslo_messaging import _utils
|
||||
from oslo_messaging import exceptions
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -62,7 +63,170 @@ class ServerListenError(MessagingServerError):
|
||||
self.ex = ex
|
||||
|
||||
|
||||
class MessageHandlingServer(service.ServiceBase):
|
||||
class _OrderedTask(object):
|
||||
"""A task which must be executed in a particular order.
|
||||
|
||||
A caller may wait for this task to complete by calling
|
||||
`wait_for_completion`.
|
||||
|
||||
A caller may run this task with `run_once`, which will ensure that however
|
||||
many times the task is called it only runs once. Simultaneous callers will
|
||||
block until the running task completes, which means that any caller can be
|
||||
sure that the task has completed after run_once returns.
|
||||
"""
|
||||
|
||||
INIT = 0 # The task has not yet started
|
||||
RUNNING = 1 # The task is running somewhere
|
||||
COMPLETE = 2 # The task has run somewhere
|
||||
|
||||
# We generate a log message if we wait for a lock longer than
|
||||
# LOG_AFTER_WAIT_SECS seconds
|
||||
LOG_AFTER_WAIT_SECS = 30
|
||||
|
||||
def __init__(self, name):
|
||||
"""Create a new _OrderedTask.
|
||||
|
||||
:param name: The name of this task. Used in log messages.
|
||||
"""
|
||||
|
||||
super(_OrderedTask, self).__init__()
|
||||
|
||||
self._name = name
|
||||
self._cond = threading.Condition()
|
||||
self._state = self.INIT
|
||||
|
||||
def _wait(self, condition, warn_msg):
|
||||
"""Wait while condition() is true. Write a log message if condition()
|
||||
has not become false within LOG_AFTER_WAIT_SECS.
|
||||
"""
|
||||
with timeutils.StopWatch(duration=self.LOG_AFTER_WAIT_SECS) as w:
|
||||
logged = False
|
||||
while condition():
|
||||
wait = None if logged else w.leftover()
|
||||
self._cond.wait(wait)
|
||||
|
||||
if not logged and w.expired():
|
||||
LOG.warn(warn_msg)
|
||||
LOG.debug(''.join(traceback.format_stack()))
|
||||
# Only log once. After than we wait indefinitely without
|
||||
# logging.
|
||||
logged = True
|
||||
|
||||
def wait_for_completion(self, caller):
|
||||
"""Wait until this task has completed.
|
||||
|
||||
:param caller: The name of the task which is waiting.
|
||||
"""
|
||||
with self._cond:
|
||||
self._wait(lambda: self._state != self.COMPLETE,
|
||||
'%s has been waiting for %s to complete for longer '
|
||||
'than %i seconds'
|
||||
% (caller, self._name, self.LOG_AFTER_WAIT_SECS))
|
||||
|
||||
def run_once(self, fn):
|
||||
"""Run a task exactly once. If it is currently running in another
|
||||
thread, wait for it to complete. If it has already run, return
|
||||
immediately without running it again.
|
||||
|
||||
:param fn: The task to run. It must be a callable taking no arguments.
|
||||
It may optionally return another callable, which also takes
|
||||
no arguments, which will be executed after completion has
|
||||
been signaled to other threads.
|
||||
"""
|
||||
with self._cond:
|
||||
if self._state == self.INIT:
|
||||
self._state = self.RUNNING
|
||||
# Note that nothing waits on RUNNING, so no need to notify
|
||||
|
||||
# We need to release the condition lock before calling out to
|
||||
# prevent deadlocks. Reacquire it immediately afterwards.
|
||||
self._cond.release()
|
||||
try:
|
||||
post_fn = fn()
|
||||
finally:
|
||||
self._cond.acquire()
|
||||
self._state = self.COMPLETE
|
||||
self._cond.notify_all()
|
||||
|
||||
if post_fn is not None:
|
||||
# Release the condition lock before calling out to prevent
|
||||
# deadlocks. Reacquire it immediately afterwards.
|
||||
self._cond.release()
|
||||
try:
|
||||
post_fn()
|
||||
finally:
|
||||
self._cond.acquire()
|
||||
elif self._state == self.RUNNING:
|
||||
self._wait(lambda: self._state == self.RUNNING,
|
||||
'%s has been waiting on another thread to complete '
|
||||
'for longer than %i seconds'
|
||||
% (self._name, self.LOG_AFTER_WAIT_SECS))
|
||||
|
||||
|
||||
class _OrderedTaskRunner(object):
|
||||
"""Mixin for a class which executes ordered tasks."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(_OrderedTaskRunner, self).__init__(*args, **kwargs)
|
||||
|
||||
# Get a list of methods on this object which have the _ordered
|
||||
# attribute
|
||||
self._tasks = [name
|
||||
for (name, member) in inspect.getmembers(self)
|
||||
if inspect.ismethod(member) and
|
||||
getattr(member, '_ordered', False)]
|
||||
self.init_task_states()
|
||||
|
||||
def init_task_states(self):
|
||||
# Note that we don't need to lock this. Once created, the _states dict
|
||||
# is immutable. Get and set are (individually) atomic operations in
|
||||
# Python, and we only set after the dict is fully created.
|
||||
self._states = {task: _OrderedTask(task) for task in self._tasks}
|
||||
|
||||
@staticmethod
|
||||
def decorate_ordered(fn, state, after):
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
# Store the states we started with in case the state wraps on us
|
||||
# while we're sleeping. We must wait and run_once in the same
|
||||
# epoch. If the epoch ended while we were sleeping, run_once will
|
||||
# safely do nothing.
|
||||
states = self._states
|
||||
|
||||
# Wait for the given preceding state to complete
|
||||
if after is not None:
|
||||
states[after].wait_for_completion(state)
|
||||
|
||||
# Run this state
|
||||
states[state].run_once(lambda: fn(self, *args, **kwargs))
|
||||
return wrapper
|
||||
|
||||
|
||||
def ordered(after=None):
|
||||
"""A method which will be executed as an ordered task. The method will be
|
||||
called exactly once, however many times it is called. If it is called
|
||||
multiple times simultaneously it will only be called once, but all callers
|
||||
will wait until execution is complete.
|
||||
|
||||
If `after` is given, this method will not run until `after` has completed.
|
||||
|
||||
:param after: Optionally, another method decorated with `ordered`. Wait for
|
||||
the completion of `after` before executing this method.
|
||||
"""
|
||||
if after is not None:
|
||||
after = after.__name__
|
||||
|
||||
def _ordered(fn):
|
||||
# Set an attribute on the method so we can find it later
|
||||
setattr(fn, '_ordered', True)
|
||||
state = fn.__name__
|
||||
|
||||
return _OrderedTaskRunner.decorate_ordered(fn, state, after)
|
||||
return _ordered
|
||||
|
||||
|
||||
class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
|
||||
"""Server for handling messages.
|
||||
|
||||
Connect a transport to a dispatcher that knows how to process the
|
||||
@ -94,29 +258,18 @@ class MessageHandlingServer(service.ServiceBase):
|
||||
self.dispatcher = dispatcher
|
||||
self.executor = executor
|
||||
|
||||
# NOTE(sileht): we use a lock to protect the state change of the
|
||||
# server, we don't want to call stop until the transport driver
|
||||
# is fully started. Except for the blocking executor that have
|
||||
# start() that doesn't return
|
||||
if self.executor != "blocking":
|
||||
self._state_cond = threading.Condition()
|
||||
self._dummy_cond = False
|
||||
else:
|
||||
self._state_cond = _utils.DummyCondition()
|
||||
self._dummy_cond = True
|
||||
|
||||
try:
|
||||
mgr = driver.DriverManager('oslo.messaging.executors',
|
||||
self.executor)
|
||||
except RuntimeError as ex:
|
||||
raise ExecutorLoadFailure(self.executor, ex)
|
||||
else:
|
||||
|
||||
self._executor_cls = mgr.driver
|
||||
self._executor_obj = None
|
||||
self._running = False
|
||||
|
||||
super(MessageHandlingServer, self).__init__()
|
||||
|
||||
@ordered()
|
||||
def start(self):
|
||||
"""Start handling incoming messages.
|
||||
|
||||
@ -131,24 +284,21 @@ class MessageHandlingServer(service.ServiceBase):
|
||||
choose to dispatch messages in a new thread, coroutine or simply the
|
||||
current thread.
|
||||
"""
|
||||
if self._executor_obj is not None:
|
||||
return
|
||||
with self._state_cond:
|
||||
if self._executor_obj is not None:
|
||||
return
|
||||
try:
|
||||
listener = self.dispatcher._listen(self.transport)
|
||||
except driver_base.TransportDriverError as ex:
|
||||
raise ServerListenError(self.target, ex)
|
||||
self._executor_obj = self._executor_cls(self.conf, listener,
|
||||
self.dispatcher)
|
||||
self._executor_obj.start()
|
||||
self._running = True
|
||||
self._state_cond.notify_all()
|
||||
executor = self._executor_cls(self.conf, listener, self.dispatcher)
|
||||
executor.start()
|
||||
self._executor_obj = executor
|
||||
|
||||
if self.executor == 'blocking':
|
||||
self._executor_obj.execute()
|
||||
# N.B. This will be executed unlocked and unordered, so
|
||||
# we can't rely on the value of self._executor_obj when this runs.
|
||||
# We explicitly pass the local variable.
|
||||
return lambda: executor.execute()
|
||||
|
||||
@ordered(after=start)
|
||||
def stop(self):
|
||||
"""Stop handling incoming messages.
|
||||
|
||||
@ -157,12 +307,9 @@ class MessageHandlingServer(service.ServiceBase):
|
||||
some messages, and underlying driver resources associated to this
|
||||
server are still in use. See 'wait' for more details.
|
||||
"""
|
||||
with self._state_cond:
|
||||
if self._executor_obj is not None:
|
||||
self._running = False
|
||||
self._executor_obj.stop()
|
||||
self._state_cond.notify_all()
|
||||
|
||||
@ordered(after=stop)
|
||||
def wait(self):
|
||||
"""Wait for message processing to complete.
|
||||
|
||||
@ -173,37 +320,14 @@ class MessageHandlingServer(service.ServiceBase):
|
||||
Once it's finished, the underlying driver resources associated to this
|
||||
server are released (like closing useless network connections).
|
||||
"""
|
||||
with self._state_cond:
|
||||
if self._running:
|
||||
LOG.warn(_LW("wait() should be called after stop() as it "
|
||||
"waits for existing messages to finish "
|
||||
"processing"))
|
||||
w = timeutils.StopWatch()
|
||||
w.start()
|
||||
while self._running:
|
||||
# NOTE(harlowja): 1.0 seconds was mostly chosen at
|
||||
# random, but it seems like a reasonable value to
|
||||
# use to avoid spamming the logs with to much
|
||||
# information.
|
||||
self._state_cond.wait(1.0)
|
||||
if self._running and not self._dummy_cond:
|
||||
LOG.warn(
|
||||
_LW("wait() should have been called"
|
||||
" after stop() as wait() waits for existing"
|
||||
" messages to finish processing, it has"
|
||||
" been %0.2f seconds and stop() still has"
|
||||
" not been called"), w.elapsed())
|
||||
executor = self._executor_obj
|
||||
self._executor_obj = None
|
||||
if executor is not None:
|
||||
# We are the lucky calling thread to wait on the executor to
|
||||
# actually finish.
|
||||
try:
|
||||
executor.wait()
|
||||
self._executor_obj.wait()
|
||||
finally:
|
||||
# Close listener connection after processing all messages
|
||||
executor.listener.cleanup()
|
||||
executor = None
|
||||
self._executor_obj.listener.cleanup()
|
||||
self._executor_obj = None
|
||||
|
||||
self.init_task_states()
|
||||
|
||||
def reset(self):
|
||||
"""Reset service.
|
||||
|
@ -13,6 +13,8 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import eventlet
|
||||
import time
|
||||
import threading
|
||||
|
||||
from oslo_config import cfg
|
||||
@ -20,6 +22,7 @@ import testscenarios
|
||||
|
||||
import mock
|
||||
import oslo_messaging
|
||||
from oslo_messaging import server as server_module
|
||||
from oslo_messaging.tests import utils as test_utils
|
||||
|
||||
load_tests = testscenarios.load_tests_apply_scenarios
|
||||
@ -528,3 +531,210 @@ class TestMultipleServers(test_utils.BaseTestCase, ServerSetupMixin):
|
||||
|
||||
|
||||
TestMultipleServers.generate_scenarios()
|
||||
|
||||
class TestServerLocking(test_utils.BaseTestCase):
|
||||
def setUp(self):
|
||||
super(TestServerLocking, self).setUp(conf=cfg.ConfigOpts())
|
||||
|
||||
def _logmethod(name):
|
||||
def method(self):
|
||||
with self._lock:
|
||||
self._calls.append(name)
|
||||
return method
|
||||
|
||||
executors = []
|
||||
class FakeExecutor(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._lock = threading.Lock()
|
||||
self._calls = []
|
||||
self.listener = mock.MagicMock()
|
||||
executors.append(self)
|
||||
|
||||
start = _logmethod('start')
|
||||
stop = _logmethod('stop')
|
||||
wait = _logmethod('wait')
|
||||
execute = _logmethod('execute')
|
||||
self.executors = executors
|
||||
|
||||
self.server = oslo_messaging.MessageHandlingServer(mock.Mock(),
|
||||
mock.Mock())
|
||||
self.server._executor_cls = FakeExecutor
|
||||
|
||||
def test_start_stop_wait(self):
|
||||
# Test a simple execution of start, stop, wait in order
|
||||
|
||||
thread = eventlet.spawn(self.server.start)
|
||||
self.server.stop()
|
||||
self.server.wait()
|
||||
|
||||
self.assertEqual(len(self.executors), 1)
|
||||
executor = self.executors[0]
|
||||
self.assertEqual(executor._calls,
|
||||
['start', 'execute', 'stop', 'wait'])
|
||||
self.assertTrue(executor.listener.cleanup.called)
|
||||
|
||||
def test_reversed_order(self):
|
||||
# Test that if we call wait, stop, start, these will be correctly
|
||||
# reordered
|
||||
|
||||
wait = eventlet.spawn(self.server.wait)
|
||||
# This is non-deterministic, but there's not a great deal we can do
|
||||
# about that
|
||||
eventlet.sleep(0)
|
||||
|
||||
stop = eventlet.spawn(self.server.stop)
|
||||
eventlet.sleep(0)
|
||||
|
||||
start = eventlet.spawn(self.server.start)
|
||||
|
||||
self.server.wait()
|
||||
|
||||
self.assertEqual(len(self.executors), 1)
|
||||
executor = self.executors[0]
|
||||
self.assertEqual(executor._calls,
|
||||
['start', 'execute', 'stop', 'wait'])
|
||||
|
||||
def test_wait_for_running_task(self):
|
||||
# Test that if 2 threads call a method simultaneously, both will wait,
|
||||
# but only 1 will call the underlying executor method.
|
||||
|
||||
start_event = threading.Event()
|
||||
finish_event = threading.Event()
|
||||
|
||||
running_event = threading.Event()
|
||||
done_event = threading.Event()
|
||||
|
||||
runner = [None]
|
||||
class SteppingFakeExecutor(self.server._executor_cls):
|
||||
def start(self):
|
||||
# Tell the test which thread won the race
|
||||
runner[0] = eventlet.getcurrent()
|
||||
running_event.set()
|
||||
|
||||
start_event.wait()
|
||||
super(SteppingFakeExecutor, self).start()
|
||||
done_event.set()
|
||||
|
||||
finish_event.wait()
|
||||
self.server._executor_cls = SteppingFakeExecutor
|
||||
|
||||
start1 = eventlet.spawn(self.server.start)
|
||||
start2 = eventlet.spawn(self.server.start)
|
||||
|
||||
# Wait until one of the threads starts running
|
||||
running_event.wait()
|
||||
runner = runner[0]
|
||||
waiter = start2 if runner == start1 else start2
|
||||
|
||||
waiter_finished = threading.Event()
|
||||
waiter.link(lambda _: waiter_finished.set())
|
||||
|
||||
# At this point, runner is running start(), and waiter() is waiting for
|
||||
# it to complete. runner has not yet logged anything.
|
||||
self.assertEqual(1, len(self.executors))
|
||||
executor = self.executors[0]
|
||||
|
||||
self.assertEqual(executor._calls, [])
|
||||
self.assertFalse(waiter_finished.is_set())
|
||||
|
||||
# Let the runner log the call
|
||||
start_event.set()
|
||||
done_event.wait()
|
||||
|
||||
# We haven't signalled completion yet, so execute shouldn't have run
|
||||
self.assertEqual(executor._calls, ['start'])
|
||||
self.assertFalse(waiter_finished.is_set())
|
||||
|
||||
# Let the runner complete
|
||||
finish_event.set()
|
||||
waiter.wait()
|
||||
runner.wait()
|
||||
|
||||
# Check that both threads have finished, start was only called once,
|
||||
# and execute ran
|
||||
self.assertTrue(waiter_finished.is_set())
|
||||
self.assertEqual(executor._calls, ['start', 'execute'])
|
||||
|
||||
def test_state_wrapping(self):
|
||||
# Test that we behave correctly if a thread waits, and the server state
|
||||
# has wrapped when it it next scheduled
|
||||
|
||||
# Ensure that if 2 threads wait for the completion of 'start', the
|
||||
# first will wait until complete_event is signalled, but the second
|
||||
# will continue
|
||||
complete_event = threading.Event()
|
||||
complete_waiting_callback = threading.Event()
|
||||
|
||||
start_state = self.server._states['start']
|
||||
old_wait_for_completion = start_state.wait_for_completion
|
||||
waited = [False]
|
||||
def new_wait_for_completion(*args, **kwargs):
|
||||
if not waited[0]:
|
||||
waited[0] = True
|
||||
complete_waiting_callback.set()
|
||||
complete_event.wait()
|
||||
old_wait_for_completion(*args, **kwargs)
|
||||
start_state.wait_for_completion = new_wait_for_completion
|
||||
|
||||
# thread1 will wait for start to complete until we signal it
|
||||
thread1 = eventlet.spawn(self.server.stop)
|
||||
thread1_finished = threading.Event()
|
||||
thread1.link(lambda _: thread1_finished.set())
|
||||
|
||||
self.server.start()
|
||||
complete_waiting_callback.wait()
|
||||
|
||||
# The server should have started, but stop should not have been called
|
||||
self.assertEqual(1, len(self.executors))
|
||||
self.assertEqual(self.executors[0]._calls, ['start', 'execute'])
|
||||
self.assertFalse(thread1_finished.is_set())
|
||||
|
||||
self.server.stop()
|
||||
self.server.wait()
|
||||
|
||||
# We should have gone through all the states, and thread1 should still
|
||||
# be waiting
|
||||
self.assertEqual(1, len(self.executors))
|
||||
self.assertEqual(self.executors[0]._calls, ['start', 'execute',
|
||||
'stop', 'wait'])
|
||||
self.assertFalse(thread1_finished.is_set())
|
||||
|
||||
# Start again
|
||||
self.server.start()
|
||||
|
||||
# We should now record 2 executors
|
||||
self.assertEqual(2, len(self.executors))
|
||||
self.assertEqual(self.executors[0]._calls, ['start', 'execute',
|
||||
'stop', 'wait'])
|
||||
self.assertEqual(self.executors[1]._calls, ['start', 'execute'])
|
||||
self.assertFalse(thread1_finished.is_set())
|
||||
|
||||
# Allow thread1 to complete
|
||||
complete_event.set()
|
||||
thread1_finished.wait()
|
||||
|
||||
# thread1 should now have finished, and stop should not have been
|
||||
# called again on either the first or second executor
|
||||
self.assertEqual(2, len(self.executors))
|
||||
self.assertEqual(self.executors[0]._calls, ['start', 'execute',
|
||||
'stop', 'wait'])
|
||||
self.assertEqual(self.executors[1]._calls, ['start', 'execute'])
|
||||
self.assertTrue(thread1_finished.is_set())
|
||||
|
||||
@mock.patch.object(server_module._OrderedTask,
|
||||
'LOG_AFTER_WAIT_SECS', 1)
|
||||
@mock.patch.object(server_module, 'LOG')
|
||||
def test_timeout_logging(self, mock_log):
|
||||
# Test that we generate a log message if we wait longer than
|
||||
# LOG_AFTER_WAIT_SECS
|
||||
|
||||
log_event = threading.Event()
|
||||
mock_log.warn.side_effect = lambda _: log_event.set()
|
||||
|
||||
# Call stop without calling start. We should log a wait after 1 second
|
||||
thread = eventlet.spawn(self.server.stop)
|
||||
log_event.wait()
|
||||
|
||||
# Redundant given that we already waited, but it's nice to assert
|
||||
self.assertTrue(mock_log.warn.called)
|
||||
thread.kill()
|
||||
|
Loading…
Reference in New Issue
Block a user