From b336bb34dc9fdf1cc9a84b9d94f578564360b18d Mon Sep 17 00:00:00 2001 From: Erik Olof Gunnar Andersson Date: Tue, 16 Feb 2021 20:21:08 -0800 Subject: [PATCH] Re-factored rpc serializer This patch cleans up the current rpc implementation by moving the seralizer back to the rpc module, this is more in line with other projects; such as Nova. - Moved _init_serializer back into rpc. - Added back unit-tests for profiler. Change-Id: Ia148b2d3bc352e96e7633f7af82ecd26b5f35e35 --- magnum/common/rpc.py | 8 +++++ magnum/common/rpc_service.py | 29 ++++++------------ magnum/tests/unit/common/test_rpc.py | 45 ++++++++++++++++++++++++++-- 3 files changed, 60 insertions(+), 22 deletions(-) diff --git a/magnum/common/rpc.py b/magnum/common/rpc.py index c33f40219f..4bfe830372 100644 --- a/magnum/common/rpc.py +++ b/magnum/common/rpc.py @@ -142,6 +142,10 @@ def get_transport_url(url_str=None): def get_client(target, version_cap=None, serializer=None, timeout=None): assert TRANSPORT is not None + if profiler: + serializer = ProfilerRequestContextSerializer(serializer) + else: + serializer = RequestContextSerializer(serializer) return messaging.RPCClient(TRANSPORT, target, version_cap=version_cap, @@ -151,6 +155,10 @@ def get_client(target, version_cap=None, serializer=None, timeout=None): def get_server(target, endpoints, serializer=None): assert TRANSPORT is not None + if profiler: + serializer = ProfilerRequestContextSerializer(serializer) + else: + serializer = RequestContextSerializer(serializer) access_policy = dispatcher.DefaultRPCAccessPolicy return messaging.get_rpc_server(TRANSPORT, target, diff --git a/magnum/common/rpc_service.py b/magnum/common/rpc_service.py index 99aa75c43d..4345809a84 100644 --- a/magnum/common/rpc_service.py +++ b/magnum/common/rpc_service.py @@ -16,7 +16,6 @@ import oslo_messaging as messaging from oslo_service import service -from oslo_utils import importutils from magnum.common import profiler from magnum.common import rpc @@ -26,30 +25,19 @@ from magnum.service import periodic from magnum.servicegroup import magnum_service_periodic as servicegroup -osprofiler = importutils.try_import("osprofiler.profiler") - CONF = magnum.conf.CONF -def _init_serializer(): - serializer = rpc.RequestContextSerializer( - objects_base.MagnumObjectSerializer()) - if osprofiler: - serializer = rpc.ProfilerRequestContextSerializer(serializer) - else: - serializer = rpc.RequestContextSerializer(serializer) - return serializer - - class Service(service.Service): def __init__(self, topic, server, handlers, binary): super(Service, self).__init__() - serializer = _init_serializer() # TODO(asalkeld) add support for version='x.y' target = messaging.Target(topic=topic, server=server) - self._server = rpc.get_server(target, handlers, - serializer=serializer) + self._server = rpc.get_server( + target, handlers, + serializer=objects_base.MagnumObjectSerializer() + ) self.binary = binary profiler.setup(binary, CONF.host) @@ -77,14 +65,15 @@ class Service(service.Service): class API(object): def __init__(self, context=None, topic=None, server=None, timeout=None): - serializer = _init_serializer() self._context = context if topic is None: topic = '' target = messaging.Target(topic=topic, server=server) - self._client = rpc.get_client(target, - serializer=serializer, - timeout=timeout) + self._client = rpc.get_client( + target, + serializer=objects_base.MagnumObjectSerializer(), + timeout=timeout + ) def _call(self, method, *args, **kwargs): return self._client.call(self._context, method, *args, **kwargs) diff --git a/magnum/tests/unit/common/test_rpc.py b/magnum/tests/unit/common/test_rpc.py index 2ea39dd432..71a981103b 100644 --- a/magnum/tests/unit/common/test_rpc.py +++ b/magnum/tests/unit/common/test_rpc.py @@ -26,12 +26,32 @@ from magnum.tests import base class TestRpc(base.TestCase): @mock.patch.object(rpc, 'profiler', None) + @mock.patch.object(rpc, 'RequestContextSerializer') @mock.patch.object(messaging, 'RPCClient') - def test_get_client(self, mock_client): + def test_get_client(self, mock_client, mock_ser): rpc.TRANSPORT = mock.Mock() tgt = mock.Mock() ser = mock.Mock() mock_client.return_value = 'client' + mock_ser.return_value = ser + + client = rpc.get_client(tgt, version_cap='1.0', serializer=ser, + timeout=6969) + + mock_client.assert_called_once_with(rpc.TRANSPORT, + tgt, version_cap='1.0', + serializer=ser, timeout=6969) + self.assertEqual('client', client) + + @mock.patch.object(rpc, 'profiler', mock.Mock()) + @mock.patch.object(rpc, 'ProfilerRequestContextSerializer') + @mock.patch.object(messaging, 'RPCClient') + def test_get_client_profiler_enabled(self, mock_client, mock_ser): + rpc.TRANSPORT = mock.Mock() + tgt = mock.Mock() + ser = mock.Mock() + mock_client.return_value = 'client' + mock_ser.return_value = ser client = rpc.get_client(tgt, version_cap='1.0', serializer=ser, timeout=6969) @@ -42,13 +62,15 @@ class TestRpc(base.TestCase): self.assertEqual('client', client) @mock.patch.object(rpc, 'profiler', None) + @mock.patch.object(rpc, 'RequestContextSerializer') @mock.patch.object(messaging, 'get_rpc_server') - def test_get_server(self, mock_get): + def test_get_server(self, mock_get, mock_ser): rpc.TRANSPORT = mock.Mock() ser = mock.Mock() tgt = mock.Mock() ends = mock.Mock() mock_get.return_value = 'server' + mock_ser.return_value = ser access_policy = dispatcher.DefaultRPCAccessPolicy server = rpc.get_server(tgt, ends, serializer=ser) @@ -57,6 +79,25 @@ class TestRpc(base.TestCase): access_policy=access_policy) self.assertEqual('server', server) + @mock.patch.object(rpc, 'profiler', mock.Mock()) + @mock.patch.object(rpc, 'ProfilerRequestContextSerializer') + @mock.patch.object(messaging, 'get_rpc_server') + def test_get_server_profiler_enabled(self, mock_get, mock_ser): + rpc.TRANSPORT = mock.Mock() + ser = mock.Mock() + tgt = mock.Mock() + ends = mock.Mock() + mock_ser.return_value = ser + mock_get.return_value = 'server' + access_policy = dispatcher.DefaultRPCAccessPolicy + server = rpc.get_server(tgt, ends, serializer='foo') + + mock_ser.assert_called_once_with('foo') + mock_get.assert_called_once_with(rpc.TRANSPORT, tgt, ends, + executor='eventlet', serializer=ser, + access_policy=access_policy) + self.assertEqual('server', server) + @mock.patch.object(messaging, 'TransportURL') def test_get_transport_url(self, mock_url): conf = mock.Mock()