diff --git a/oslo_messaging/_drivers/impl_rabbit.py b/oslo_messaging/_drivers/impl_rabbit.py index aa7cffa7..7d180760 100644 --- a/oslo_messaging/_drivers/impl_rabbit.py +++ b/oslo_messaging/_drivers/impl_rabbit.py @@ -271,22 +271,10 @@ class Consumer(object): message.ack() -class DummyConnectionLock(object): - def acquire(self): - pass - - def release(self): - pass - +class DummyConnectionLock(_utils.DummyLock): def heartbeat_acquire(self): pass - def __enter__(self): - self.acquire() - - def __exit__(self, type, value, traceback): - self.release() - class ConnectionLock(DummyConnectionLock): """Lock object to protect access the the kombu connection diff --git a/oslo_messaging/_utils.py b/oslo_messaging/_utils.py index 1c816de5..eec0d917 100644 --- a/oslo_messaging/_utils.py +++ b/oslo_messaging/_utils.py @@ -114,3 +114,17 @@ def fetch_current_thread_functor(): return lambda: eventlet.getcurrent() else: return lambda: threading.current_thread() + + +class DummyLock(object): + def acquire(self): + pass + + def release(self): + pass + + def __enter__(self): + self.acquire() + + def __exit__(self, type, value, traceback): + self.release() diff --git a/oslo_messaging/server.py b/oslo_messaging/server.py index a2b49338..07c52839 100644 --- a/oslo_messaging/server.py +++ b/oslo_messaging/server.py @@ -24,6 +24,7 @@ __all__ = [ ] import logging +import threading from oslo_service import service from stevedore import driver @@ -92,7 +93,14 @@ class MessageHandlingServer(service.ServiceBase): self.dispatcher = dispatcher self.executor = executor - self._get_thread_id = _utils.fetch_current_thread_functor() + # NOTE(sileht): we use a lock to protect the state change of the + # server, we don't want to call stop until the transport driver + # is fully started. Except for the blocking executor that have + # start() that doesn't return + if self.executor != "blocking": + self._state_lock = threading.Lock() + else: + self._state_lock = _utils.DummyLock() try: mgr = driver.DriverManager('oslo.messaging.executors', @@ -103,7 +111,6 @@ class MessageHandlingServer(service.ServiceBase): self._executor_cls = mgr.driver self._executor = None self._running = False - self._thread_id = None super(MessageHandlingServer, self).__init__() @@ -121,8 +128,6 @@ class MessageHandlingServer(service.ServiceBase): choose to dispatch messages in a new thread, coroutine or simply the current thread. """ - self._check_same_thread_id() - if self._executor is not None: return try: @@ -130,20 +135,11 @@ class MessageHandlingServer(service.ServiceBase): except driver_base.TransportDriverError as ex: raise ServerListenError(self.target, ex) - self._running = True - self._executor = self._executor_cls(self.conf, listener, - self.dispatcher) - self._executor.start() - - def _check_same_thread_id(self): - if self._thread_id is None: - self._thread_id = self._get_thread_id() - elif self._thread_id != self._get_thread_id(): - # NOTE(dims): Need to change this to raise RuntimeError after - # verifying/fixing other openstack projects (like Neutron) - # work ok with this change - LOG.warn(_LW("start/stop/wait must be called in the " - "same thread")) + with self._state_lock: + self._running = True + self._executor = self._executor_cls(self.conf, listener, + self.dispatcher) + self._executor.start() def stop(self): """Stop handling incoming messages. @@ -153,11 +149,10 @@ class MessageHandlingServer(service.ServiceBase): some messages, and underlying driver resources associated to this server are still in use. See 'wait' for more details. """ - self._check_same_thread_id() - - if self._executor is not None: - self._running = False - self._executor.stop() + with self._state_lock: + if self._executor is not None: + self._running = False + self._executor.stop() def wait(self): """Wait for message processing to complete. @@ -169,25 +164,21 @@ class MessageHandlingServer(service.ServiceBase): Once it's finished, the underlying driver resources associated to this server are released (like closing useless network connections). """ - self._check_same_thread_id() + with self._state_lock: + if self._running: + # NOTE(dims): Need to change this to raise RuntimeError after + # verifying/fixing other openstack projects (like Neutron) + # work ok with this change + LOG.warn(_LW("wait() should be called after stop() as it " + "waits for existing messages to finish " + "processing")) - if self._running: - # NOTE(dims): Need to change this to raise RuntimeError after - # verifying/fixing other openstack projects (like Neutron) - # work ok with this change - LOG.warn(_LW("wait() should be called after stop() as it " - "waits for existing messages to finish " - "processing")) + if self._executor is not None: + self._executor.wait() + # Close listener connection after processing all messages + self._executor.listener.cleanup() - if self._executor is not None: - self._executor.wait() - # Close listener connection after processing all messages - self._executor.listener.cleanup() - - self._executor = None - # NOTE(sileht): executor/listener have been properly stopped - # allow to restart it into another thread - self._thread_id = None + self._executor = None def reset(self): """Reset service. diff --git a/oslo_messaging/tests/rpc/test_server.py b/oslo_messaging/tests/rpc/test_server.py index dfa900cd..a0730341 100644 --- a/oslo_messaging/tests/rpc/test_server.py +++ b/oslo_messaging/tests/rpc/test_server.py @@ -21,7 +21,6 @@ import testscenarios import mock import oslo_messaging from oslo_messaging.tests import utils as test_utils -from six.moves import mock load_tests = testscenarios.load_tests_apply_scenarios @@ -151,31 +150,6 @@ class TestRPCServer(test_utils.BaseTestCase, ServerSetupMixin): 'stop() as it waits for existing ' 'messages to finish processing') - @mock.patch('oslo_messaging._executors.impl_pooledexecutor.' - 'PooledExecutor.wait') - def test_server_invalid_stop_from_other_thread(self, mock_wait): - transport = oslo_messaging.get_transport(self.conf, url='fake:') - target = oslo_messaging.Target(topic='foo', server='bar') - endpoints = [object()] - serializer = object() - - server = oslo_messaging.get_rpc_server(transport, target, endpoints, - serializer=serializer, - executor='eventlet') - - t = test_utils.ServerThreadHelper(server) - t.start() - self.addCleanup(t.join) - self.addCleanup(t.stop) - with mock.patch('logging.Logger.warn') as warn: - server.stop() - warn.assert_called_with('start/stop/wait must be called ' - 'in the same thread') - with mock.patch('logging.Logger.warn') as warn: - server.wait() - warn.assert_called_with('start/stop/wait must be called ' - 'in the same thread') - def test_no_target_server(self): transport = oslo_messaging.get_transport(self.conf, url='fake:')