From e5d9d1b74fec1bcf0d93361670a1ffcc22d2bdaa Mon Sep 17 00:00:00 2001 From: Erik Olof Gunnar Andersson Date: Mon, 18 Jan 2021 21:04:45 -0800 Subject: [PATCH] Use common rpc pattern for all services This patch introduces a common rpc pattern to ensure that the rpc transport is shared where possible. This helps prevent rpc connection leaks and should ensure that we are making the best possible use of all available rpc connections. Change-Id: Ib42e368cfda2b148a07df0bd74046739f40f7018 --- murano/common/engine.py | 10 +-- murano/common/rpc.py | 80 ++++++++++++++++++----- murano/common/server.py | 37 ++++------- murano/engine/system/instance_reporter.py | 10 ++- murano/engine/system/status_reporter.py | 19 ++---- murano/tests/unit/common/test_engine.py | 14 ++-- murano/tests/unit/common/test_server.py | 4 +- 7 files changed, 100 insertions(+), 74 deletions(-) diff --git a/murano/common/engine.py b/murano/common/engine.py index 3c5a97a95..9344f2880 100644 --- a/murano/common/engine.py +++ b/murano/common/engine.py @@ -20,8 +20,6 @@ import uuid import eventlet.debug from oslo_config import cfg from oslo_log import log as logging -import oslo_messaging as messaging -from oslo_messaging.rpc import dispatcher from oslo_messaging import target from oslo_serialization import jsonutils from oslo_service import service @@ -58,18 +56,16 @@ class EngineService(service.Service): self.server = None def start(self): + if not rpc.initialized(): + rpc.init() endpoints = [ TaskProcessingEndpoint(), StaticActionEndpoint(), SchemaEndpoint() ] - transport = messaging.get_rpc_transport(CONF) s_target = target.Target('murano', 'tasks', server=str(uuid.uuid4())) - access_policy = dispatcher.DefaultRPCAccessPolicy - self.server = messaging.get_rpc_server( - transport, s_target, endpoints, 'eventlet', - access_policy=access_policy) + self.server = rpc.get_server(s_target, endpoints, executor='eventlet') self.server.start() super(EngineService, self).start() diff --git a/murano/common/rpc.py b/murano/common/rpc.py index bbca38ff2..d54bcca17 100644 --- a/murano/common/rpc.py +++ b/murano/common/rpc.py @@ -14,18 +14,72 @@ from oslo_config import cfg import oslo_messaging as messaging -from oslo_messaging import rpc +from oslo_messaging.rpc import dispatcher from oslo_messaging import target CONF = cfg.CONF +NOTIFICATION_TRANSPORT = None TRANSPORT = None +def init(): + global TRANSPORT, NOTIFICATION_TRANSPORT + TRANSPORT = messaging.get_rpc_transport(CONF) + NOTIFICATION_TRANSPORT = messaging.get_notification_transport(CONF) + + +def initialized(): + return None not in [TRANSPORT, NOTIFICATION_TRANSPORT] + + +def cleanup(): + global TRANSPORT, NOTIFICATION_TRANSPORT + if TRANSPORT is not None: + TRANSPORT.cleanup() + if NOTIFICATION_TRANSPORT is not None: + NOTIFICATION_TRANSPORT.cleanup() + TRANSPORT = NOTIFICATION_TRANSPORT = None + + +def get_client(target, timeout=None): + if TRANSPORT is None: + init() + return messaging.RPCClient( + TRANSPORT, + target, + timeout=timeout + ) + + +def get_server(target, endpoints, executor): + if TRANSPORT is None: + init() + access_policy = dispatcher.DefaultRPCAccessPolicy + return messaging.get_rpc_server( + TRANSPORT, + target, + endpoints, + executor=executor, + access_policy=access_policy + ) + + +def get_notification_listener(targets, endpoints, executor): + if NOTIFICATION_TRANSPORT is None: + init() + return messaging.get_notification_listener( + NOTIFICATION_TRANSPORT, + targets, + endpoints, + executor=executor + ) + + class ApiClient(object): - def __init__(self, transport): + def __init__(self): client_target = target.Target('murano', 'results') - self._client = rpc.RPCClient(transport, client_target, timeout=15) + self._client = get_client(client_target, timeout=15) def process_result(self, result, environment_id): return self._client.call({}, 'process_result', result=result, @@ -33,9 +87,9 @@ class ApiClient(object): class EngineClient(object): - def __init__(self, transport): + def __init__(self): client_target = target.Target('murano', 'tasks') - self._client = rpc.RPCClient(transport, client_target, timeout=15) + self._client = get_client(client_target, timeout=15) def handle_task(self, task): return self._client.cast({}, 'handle_task', task=task) @@ -55,16 +109,12 @@ class EngineClient(object): def api(): - global TRANSPORT - if TRANSPORT is None: - TRANSPORT = messaging.get_rpc_transport(CONF) - - return ApiClient(TRANSPORT) + if not initialized(): + init() + return ApiClient() def engine(): - global TRANSPORT - if TRANSPORT is None: - TRANSPORT = messaging.get_rpc_transport(CONF) - - return EngineClient(TRANSPORT) + if not initialized(): + init() + return EngineClient() diff --git a/murano/common/server.py b/murano/common/server.py index 9410a4af0..af156050d 100644 --- a/murano/common/server.py +++ b/murano/common/server.py @@ -16,8 +16,6 @@ import uuid from oslo_config import cfg from oslo_log import log as logging -import oslo_messaging as messaging -from oslo_messaging.rpc import dispatcher from oslo_messaging import target from oslo_service import service from oslo_utils import timeutils @@ -25,6 +23,7 @@ import pytz from sqlalchemy import desc from murano.common.helpers import token_sanitizer +from murano.common import rpc from murano.db import models from murano.db.services import environments from murano.db.services import instances @@ -226,25 +225,21 @@ class Service(service.Service): def get_notification_listener(): - endpoints = [report_notification, track_instance, untrack_instance] - transport = messaging.get_notification_transport(CONF) s_target = target.Target(topic='murano', server=str(uuid.uuid4())) - listener = messaging.get_notification_listener( - transport, [s_target], endpoints, executor='threading') + listener = rpc.get_notification_listener( + [s_target], endpoints, executor='threading' + ) return listener def get_rpc_server(): - endpoints = [ResultEndpoint()] - transport = messaging.get_rpc_transport(CONF) s_target = target.Target('murano', 'results', server=str(uuid.uuid4())) - access_policy = dispatcher.DefaultRPCAccessPolicy - server = messaging.get_rpc_server( - transport, s_target, endpoints, 'threading', - access_policy=access_policy) + server = rpc.get_server( + s_target, endpoints, executor='threading' + ) return server @@ -256,13 +251,10 @@ class NotificationService(Service): def start(self): endpoints = [report_notification, track_instance, untrack_instance] - - transport = messaging.get_notification_transport(CONF) s_target = target.Target(topic='murano', server=str(uuid.uuid4())) - - self.server = messaging.get_notification_listener( - transport, [s_target], endpoints, executor='eventlet') - + self.server = rpc.get_notification_listener( + [s_target], endpoints, executor='eventlet' + ) self.server.start() super(NotificationService, self).start() @@ -271,12 +263,9 @@ class ApiService(Service): def start(self): endpoints = [ResultEndpoint()] - - transport = messaging.get_rpc_transport(CONF) s_target = target.Target('murano', 'results', server=str(uuid.uuid4())) - access_policy = dispatcher.DefaultRPCAccessPolicy - self.server = messaging.get_rpc_server( - transport, s_target, endpoints, 'eventlet', - access_policy=access_policy) + self.server = rpc.get_server( + s_target, endpoints, executor='eventlet' + ) self.server.start() super(ApiService, self).start() diff --git a/murano/engine/system/instance_reporter.py b/murano/engine/system/instance_reporter.py index aa05879b7..c028a105d 100644 --- a/murano/engine/system/instance_reporter.py +++ b/murano/engine/system/instance_reporter.py @@ -16,6 +16,7 @@ from oslo_config import cfg import oslo_messaging as messaging +from murano.common import rpc from murano.common import uuidutils from murano.dsl import dsl @@ -28,14 +29,11 @@ OS_INSTANCE = 200 @dsl.name('io.murano.system.InstanceNotifier') class InstanceReportNotifier(object): - transport = None - def __init__(self, environment): - if InstanceReportNotifier.transport is None: - InstanceReportNotifier.transport = \ - messaging.get_notification_transport(CONF) + if not rpc.initialized(): + rpc.init() self._notifier = messaging.Notifier( - InstanceReportNotifier.transport, + rpc.NOTIFICATION_TRANSPORT, publisher_id=uuidutils.generate_uuid(), topics=['murano']) self._environment_id = environment.id diff --git a/murano/engine/system/status_reporter.py b/murano/engine/system/status_reporter.py index 80298231e..92aa87fa3 100644 --- a/murano/engine/system/status_reporter.py +++ b/murano/engine/system/status_reporter.py @@ -20,6 +20,7 @@ from oslo_config import cfg from oslo_log import log as logging import oslo_messaging as messaging +from murano.common import rpc from murano.common import uuidutils from murano.dsl import dsl @@ -29,14 +30,11 @@ LOG = logging.getLogger(__name__) @dsl.name('io.murano.system.StatusReporter') class StatusReporter(object): - transport = None - def __init__(self, environment): - if StatusReporter.transport is None: - StatusReporter.transport = messaging.get_notification_transport( - CONF) + if not rpc.initialized(): + rpc.init() self._notifier = messaging.Notifier( - StatusReporter.transport, + rpc.NOTIFICATION_TRANSPORT, publisher_id=uuidutils.generate_uuid(), topics=['murano']) if isinstance(environment, str): @@ -68,16 +66,13 @@ class StatusReporter(object): class Notification(object): - transport = None - def __init__(self): if not CONF.stats.env_audit_enabled: return - - if Notification.transport is None: - Notification.transport = messaging.get_notification_transport(CONF) + if not rpc.initialized(): + rpc.init() self._notifier = messaging.Notifier( - Notification.transport, + rpc.NOTIFICATION_TRANSPORT, publisher_id=('murano.%s' % socket.gethostname()), driver='messaging') diff --git a/murano/tests/unit/common/test_engine.py b/murano/tests/unit/common/test_engine.py index 3687bfef4..3f183c7cb 100644 --- a/murano/tests/unit/common/test_engine.py +++ b/murano/tests/unit/common/test_engine.py @@ -35,12 +35,11 @@ class TestEngineService(base.MuranoTestCase): @mock.patch.object(service.Service, 'reset') @mock.patch.object(service.Service, 'stop') @mock.patch.object(service.Service, 'start') - @mock.patch('murano.common.engine.messaging') - def test_start_stop_reset(self, mock_messaging, mock_start, + @mock.patch('murano.common.rpc.get_server') + def test_start_stop_reset(self, mock_get_server, mock_start, mock_stop, mock_reset): self.engine.start() - self.assertTrue(mock_messaging.get_rpc_transport.called) - self.assertTrue(mock_messaging.get_rpc_server.called) + self.assertTrue(mock_get_server.called) self.assertTrue(mock_start.called) self.engine.stop() self.assertTrue(mock_stop.called) @@ -49,11 +48,10 @@ class TestEngineService(base.MuranoTestCase): @mock.patch.object(service.Service, 'stop') @mock.patch.object(service.Service, 'start') - @mock.patch('murano.common.engine.messaging') - def test_stop_graceful(self, mock_messaging, mock_start, mock_stop): + @mock.patch('murano.common.rpc.get_server') + def test_stop_graceful(self, mock_get_server, mock_start, mock_stop): self.engine.start() - self.assertTrue(mock_messaging.get_rpc_transport.called) - self.assertTrue(mock_messaging.get_rpc_server.called) + self.assertTrue(mock_get_server.called) self.assertTrue(mock_start.called) self.engine.stop(graceful=True) self.assertTrue(mock_stop.called) diff --git a/murano/tests/unit/common/test_server.py b/murano/tests/unit/common/test_server.py index 05602c2ea..b56f4cc6d 100644 --- a/murano/tests/unit/common/test_server.py +++ b/murano/tests/unit/common/test_server.py @@ -280,7 +280,7 @@ class ServerTest(base.MuranoTestCase): service.reset() service.server.reset.assert_called_once_with() - @mock.patch('murano.common.server.messaging') + @mock.patch('murano.common.rpc.messaging') def test_notification_service_class(self, mock_messaging): mock_server = mock.MagicMock() mock_messaging.get_notification_listener.return_value = mock_server @@ -292,7 +292,7 @@ class ServerTest(base.MuranoTestCase): mock_messaging.get_notification_listener.call_count) mock_server.start.assert_called_once_with() - @mock.patch('murano.common.server.messaging') + @mock.patch('murano.common.rpc.messaging') def test_api_service_class(self, mock_messaging): mock_server = mock.MagicMock() mock_messaging.get_rpc_server.return_value = mock_server