diff --git a/oslo_messaging/_utils.py b/oslo_messaging/_utils.py index cec94bb48..1bb20b089 100644 --- a/oslo_messaging/_utils.py +++ b/oslo_messaging/_utils.py @@ -116,6 +116,29 @@ def fetch_current_thread_functor(): return lambda: threading.current_thread() +class DummyCondition(object): + def acquire(self): + pass + + def notify(self): + pass + + def notify_all(self): + pass + + def wait(self, timeout=None): + pass + + def release(self): + pass + + def __enter__(self): + self.acquire() + + def __exit__(self, type, value, traceback): + self.release() + + class DummyLock(object): def acquire(self): pass diff --git a/oslo_messaging/server.py b/oslo_messaging/server.py index f1739ad90..491ccbf52 100644 --- a/oslo_messaging/server.py +++ b/oslo_messaging/server.py @@ -23,17 +23,16 @@ __all__ = [ 'ServerListenError', ] -import functools -import inspect import logging import threading -import traceback from oslo_service import service from oslo_utils import timeutils from stevedore import driver from oslo_messaging._drivers import base as driver_base +from oslo_messaging._i18n import _LW +from oslo_messaging import _utils from oslo_messaging import exceptions LOG = logging.getLogger(__name__) @@ -63,170 +62,7 @@ class ServerListenError(MessagingServerError): self.ex = ex -class _OrderedTask(object): - """A task which must be executed in a particular order. - - A caller may wait for this task to complete by calling - `wait_for_completion`. - - A caller may run this task with `run_once`, which will ensure that however - many times the task is called it only runs once. Simultaneous callers will - block until the running task completes, which means that any caller can be - sure that the task has completed after run_once returns. - """ - - INIT = 0 # The task has not yet started - RUNNING = 1 # The task is running somewhere - COMPLETE = 2 # The task has run somewhere - - # We generate a log message if we wait for a lock longer than - # LOG_AFTER_WAIT_SECS seconds - LOG_AFTER_WAIT_SECS = 30 - - def __init__(self, name): - """Create a new _OrderedTask. - - :param name: The name of this task. Used in log messages. - """ - - super(_OrderedTask, self).__init__() - - self._name = name - self._cond = threading.Condition() - self._state = self.INIT - - def _wait(self, condition, warn_msg): - """Wait while condition() is true. Write a log message if condition() - has not become false within LOG_AFTER_WAIT_SECS. - """ - with timeutils.StopWatch(duration=self.LOG_AFTER_WAIT_SECS) as w: - logged = False - while condition(): - wait = None if logged else w.leftover() - self._cond.wait(wait) - - if not logged and w.expired(): - LOG.warn(warn_msg) - LOG.debug(''.join(traceback.format_stack())) - # Only log once. After than we wait indefinitely without - # logging. - logged = True - - def wait_for_completion(self, caller): - """Wait until this task has completed. - - :param caller: The name of the task which is waiting. - """ - with self._cond: - self._wait(lambda: self._state != self.COMPLETE, - '%s has been waiting for %s to complete for longer ' - 'than %i seconds' - % (caller, self._name, self.LOG_AFTER_WAIT_SECS)) - - def run_once(self, fn): - """Run a task exactly once. If it is currently running in another - thread, wait for it to complete. If it has already run, return - immediately without running it again. - - :param fn: The task to run. It must be a callable taking no arguments. - It may optionally return another callable, which also takes - no arguments, which will be executed after completion has - been signaled to other threads. - """ - with self._cond: - if self._state == self.INIT: - self._state = self.RUNNING - # Note that nothing waits on RUNNING, so no need to notify - - # We need to release the condition lock before calling out to - # prevent deadlocks. Reacquire it immediately afterwards. - self._cond.release() - try: - post_fn = fn() - finally: - self._cond.acquire() - self._state = self.COMPLETE - self._cond.notify_all() - - if post_fn is not None: - # Release the condition lock before calling out to prevent - # deadlocks. Reacquire it immediately afterwards. - self._cond.release() - try: - post_fn() - finally: - self._cond.acquire() - elif self._state == self.RUNNING: - self._wait(lambda: self._state == self.RUNNING, - '%s has been waiting on another thread to complete ' - 'for longer than %i seconds' - % (self._name, self.LOG_AFTER_WAIT_SECS)) - - -class _OrderedTaskRunner(object): - """Mixin for a class which executes ordered tasks.""" - - def __init__(self, *args, **kwargs): - super(_OrderedTaskRunner, self).__init__(*args, **kwargs) - - # Get a list of methods on this object which have the _ordered - # attribute - self._tasks = [name - for (name, member) in inspect.getmembers(self) - if inspect.ismethod(member) and - getattr(member, '_ordered', False)] - self.init_task_states() - - def init_task_states(self): - # Note that we don't need to lock this. Once created, the _states dict - # is immutable. Get and set are (individually) atomic operations in - # Python, and we only set after the dict is fully created. - self._states = {task: _OrderedTask(task) for task in self._tasks} - - @staticmethod - def decorate_ordered(fn, state, after): - - @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - # Store the states we started with in case the state wraps on us - # while we're sleeping. We must wait and run_once in the same - # epoch. If the epoch ended while we were sleeping, run_once will - # safely do nothing. - states = self._states - - # Wait for the given preceding state to complete - if after is not None: - states[after].wait_for_completion(state) - - # Run this state - states[state].run_once(lambda: fn(self, *args, **kwargs)) - return wrapper - - -def ordered(after=None): - """A method which will be executed as an ordered task. The method will be - called exactly once, however many times it is called. If it is called - multiple times simultaneously it will only be called once, but all callers - will wait until execution is complete. - - If `after` is given, this method will not run until `after` has completed. - - :param after: Optionally, another method decorated with `ordered`. Wait for - the completion of `after` before executing this method. - """ - if after is not None: - after = after.__name__ - - def _ordered(fn): - # Set an attribute on the method so we can find it later - setattr(fn, '_ordered', True) - state = fn.__name__ - - return _OrderedTaskRunner.decorate_ordered(fn, state, after) - return _ordered - - -class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner): +class MessageHandlingServer(service.ServiceBase): """Server for handling messages. Connect a transport to a dispatcher that knows how to process the @@ -258,18 +94,29 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner): self.dispatcher = dispatcher self.executor = executor + # NOTE(sileht): we use a lock to protect the state change of the + # server, we don't want to call stop until the transport driver + # is fully started. Except for the blocking executor that have + # start() that doesn't return + if self.executor != "blocking": + self._state_cond = threading.Condition() + self._dummy_cond = False + else: + self._state_cond = _utils.DummyCondition() + self._dummy_cond = True + try: mgr = driver.DriverManager('oslo.messaging.executors', self.executor) except RuntimeError as ex: raise ExecutorLoadFailure(self.executor, ex) - - self._executor_cls = mgr.driver - self._executor_obj = None + else: + self._executor_cls = mgr.driver + self._executor_obj = None + self._running = False super(MessageHandlingServer, self).__init__() - @ordered() def start(self): """Start handling incoming messages. @@ -284,21 +131,24 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner): choose to dispatch messages in a new thread, coroutine or simply the current thread. """ - try: - listener = self.dispatcher._listen(self.transport) - except driver_base.TransportDriverError as ex: - raise ServerListenError(self.target, ex) - executor = self._executor_cls(self.conf, listener, self.dispatcher) - executor.start() - self._executor_obj = executor + if self._executor_obj is not None: + return + with self._state_cond: + if self._executor_obj is not None: + return + try: + listener = self.dispatcher._listen(self.transport) + except driver_base.TransportDriverError as ex: + raise ServerListenError(self.target, ex) + self._executor_obj = self._executor_cls(self.conf, listener, + self.dispatcher) + self._executor_obj.start() + self._running = True + self._state_cond.notify_all() if self.executor == 'blocking': - # N.B. This will be executed unlocked and unordered, so - # we can't rely on the value of self._executor_obj when this runs. - # We explicitly pass the local variable. - return lambda: executor.execute() + self._executor_obj.execute() - @ordered(after=start) def stop(self): """Stop handling incoming messages. @@ -307,9 +157,12 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner): some messages, and underlying driver resources associated to this server are still in use. See 'wait' for more details. """ - self._executor_obj.stop() + with self._state_cond: + if self._executor_obj is not None: + self._running = False + self._executor_obj.stop() + self._state_cond.notify_all() - @ordered(after=stop) def wait(self): """Wait for message processing to complete. @@ -320,14 +173,37 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner): Once it's finished, the underlying driver resources associated to this server are released (like closing useless network connections). """ - try: - self._executor_obj.wait() - finally: - # Close listener connection after processing all messages - self._executor_obj.listener.cleanup() + with self._state_cond: + if self._running: + LOG.warn(_LW("wait() should be called after stop() as it " + "waits for existing messages to finish " + "processing")) + w = timeutils.StopWatch() + w.start() + while self._running: + # NOTE(harlowja): 1.0 seconds was mostly chosen at + # random, but it seems like a reasonable value to + # use to avoid spamming the logs with to much + # information. + self._state_cond.wait(1.0) + if self._running and not self._dummy_cond: + LOG.warn( + _LW("wait() should have been called" + " after stop() as wait() waits for existing" + " messages to finish processing, it has" + " been %0.2f seconds and stop() still has" + " not been called"), w.elapsed()) + executor = self._executor_obj self._executor_obj = None - - self.init_task_states() + if executor is not None: + # We are the lucky calling thread to wait on the executor to + # actually finish. + try: + executor.wait() + finally: + # Close listener connection after processing all messages + executor.listener.cleanup() + executor = None def reset(self): """Reset service. diff --git a/oslo_messaging/tests/rpc/test_server.py b/oslo_messaging/tests/rpc/test_server.py index 1a2d2aa63..258dacb24 100644 --- a/oslo_messaging/tests/rpc/test_server.py +++ b/oslo_messaging/tests/rpc/test_server.py @@ -13,8 +13,6 @@ # License for the specific language governing permissions and limitations # under the License. -import eventlet -import time import threading from oslo_config import cfg @@ -22,7 +20,6 @@ import testscenarios import mock import oslo_messaging -from oslo_messaging import server as server_module from oslo_messaging.tests import utils as test_utils load_tests = testscenarios.load_tests_apply_scenarios @@ -531,210 +528,3 @@ class TestMultipleServers(test_utils.BaseTestCase, ServerSetupMixin): TestMultipleServers.generate_scenarios() - -class TestServerLocking(test_utils.BaseTestCase): - def setUp(self): - super(TestServerLocking, self).setUp(conf=cfg.ConfigOpts()) - - def _logmethod(name): - def method(self): - with self._lock: - self._calls.append(name) - return method - - executors = [] - class FakeExecutor(object): - def __init__(self, *args, **kwargs): - self._lock = threading.Lock() - self._calls = [] - self.listener = mock.MagicMock() - executors.append(self) - - start = _logmethod('start') - stop = _logmethod('stop') - wait = _logmethod('wait') - execute = _logmethod('execute') - self.executors = executors - - self.server = oslo_messaging.MessageHandlingServer(mock.Mock(), - mock.Mock()) - self.server._executor_cls = FakeExecutor - - def test_start_stop_wait(self): - # Test a simple execution of start, stop, wait in order - - thread = eventlet.spawn(self.server.start) - self.server.stop() - self.server.wait() - - self.assertEqual(len(self.executors), 1) - executor = self.executors[0] - self.assertEqual(executor._calls, - ['start', 'execute', 'stop', 'wait']) - self.assertTrue(executor.listener.cleanup.called) - - def test_reversed_order(self): - # Test that if we call wait, stop, start, these will be correctly - # reordered - - wait = eventlet.spawn(self.server.wait) - # This is non-deterministic, but there's not a great deal we can do - # about that - eventlet.sleep(0) - - stop = eventlet.spawn(self.server.stop) - eventlet.sleep(0) - - start = eventlet.spawn(self.server.start) - - self.server.wait() - - self.assertEqual(len(self.executors), 1) - executor = self.executors[0] - self.assertEqual(executor._calls, - ['start', 'execute', 'stop', 'wait']) - - def test_wait_for_running_task(self): - # Test that if 2 threads call a method simultaneously, both will wait, - # but only 1 will call the underlying executor method. - - start_event = threading.Event() - finish_event = threading.Event() - - running_event = threading.Event() - done_event = threading.Event() - - runner = [None] - class SteppingFakeExecutor(self.server._executor_cls): - def start(self): - # Tell the test which thread won the race - runner[0] = eventlet.getcurrent() - running_event.set() - - start_event.wait() - super(SteppingFakeExecutor, self).start() - done_event.set() - - finish_event.wait() - self.server._executor_cls = SteppingFakeExecutor - - start1 = eventlet.spawn(self.server.start) - start2 = eventlet.spawn(self.server.start) - - # Wait until one of the threads starts running - running_event.wait() - runner = runner[0] - waiter = start2 if runner == start1 else start2 - - waiter_finished = threading.Event() - waiter.link(lambda _: waiter_finished.set()) - - # At this point, runner is running start(), and waiter() is waiting for - # it to complete. runner has not yet logged anything. - self.assertEqual(1, len(self.executors)) - executor = self.executors[0] - - self.assertEqual(executor._calls, []) - self.assertFalse(waiter_finished.is_set()) - - # Let the runner log the call - start_event.set() - done_event.wait() - - # We haven't signalled completion yet, so execute shouldn't have run - self.assertEqual(executor._calls, ['start']) - self.assertFalse(waiter_finished.is_set()) - - # Let the runner complete - finish_event.set() - waiter.wait() - runner.wait() - - # Check that both threads have finished, start was only called once, - # and execute ran - self.assertTrue(waiter_finished.is_set()) - self.assertEqual(executor._calls, ['start', 'execute']) - - def test_state_wrapping(self): - # Test that we behave correctly if a thread waits, and the server state - # has wrapped when it it next scheduled - - # Ensure that if 2 threads wait for the completion of 'start', the - # first will wait until complete_event is signalled, but the second - # will continue - complete_event = threading.Event() - complete_waiting_callback = threading.Event() - - start_state = self.server._states['start'] - old_wait_for_completion = start_state.wait_for_completion - waited = [False] - def new_wait_for_completion(*args, **kwargs): - if not waited[0]: - waited[0] = True - complete_waiting_callback.set() - complete_event.wait() - old_wait_for_completion(*args, **kwargs) - start_state.wait_for_completion = new_wait_for_completion - - # thread1 will wait for start to complete until we signal it - thread1 = eventlet.spawn(self.server.stop) - thread1_finished = threading.Event() - thread1.link(lambda _: thread1_finished.set()) - - self.server.start() - complete_waiting_callback.wait() - - # The server should have started, but stop should not have been called - self.assertEqual(1, len(self.executors)) - self.assertEqual(self.executors[0]._calls, ['start', 'execute']) - self.assertFalse(thread1_finished.is_set()) - - self.server.stop() - self.server.wait() - - # We should have gone through all the states, and thread1 should still - # be waiting - self.assertEqual(1, len(self.executors)) - self.assertEqual(self.executors[0]._calls, ['start', 'execute', - 'stop', 'wait']) - self.assertFalse(thread1_finished.is_set()) - - # Start again - self.server.start() - - # We should now record 2 executors - self.assertEqual(2, len(self.executors)) - self.assertEqual(self.executors[0]._calls, ['start', 'execute', - 'stop', 'wait']) - self.assertEqual(self.executors[1]._calls, ['start', 'execute']) - self.assertFalse(thread1_finished.is_set()) - - # Allow thread1 to complete - complete_event.set() - thread1_finished.wait() - - # thread1 should now have finished, and stop should not have been - # called again on either the first or second executor - self.assertEqual(2, len(self.executors)) - self.assertEqual(self.executors[0]._calls, ['start', 'execute', - 'stop', 'wait']) - self.assertEqual(self.executors[1]._calls, ['start', 'execute']) - self.assertTrue(thread1_finished.is_set()) - - @mock.patch.object(server_module._OrderedTask, - 'LOG_AFTER_WAIT_SECS', 1) - @mock.patch.object(server_module, 'LOG') - def test_timeout_logging(self, mock_log): - # Test that we generate a log message if we wait longer than - # LOG_AFTER_WAIT_SECS - - log_event = threading.Event() - mock_log.warn.side_effect = lambda _: log_event.set() - - # Call stop without calling start. We should log a wait after 1 second - thread = eventlet.spawn(self.server.stop) - log_event.wait() - - # Redundant given that we already waited, but it's nice to assert - self.assertTrue(mock_log.warn.called) - thread.kill()