From ac2176cde3630b5953cf368c6c1511fb203c13d7 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Wed, 7 Aug 2013 13:07:05 +0100 Subject: [PATCH] Add a per-transport allow_remote_exmods API Currently we have a allowed_rpc_exception_modules configuration variable which we use to configure a per-project list of modules which we will allow exceptions to be instantiated from when deserializing remote errors. It makes no sense for this to be user configurable, instead the list of modules should be set when you create a transport. Closes-Bug: #1031719 Change-Id: Ib40e92cb920996ec5e8f63d6f2cbd88fd01a90f2 --- oslo/messaging/_drivers/amqpdriver.py | 16 ++++++---- oslo/messaging/_drivers/base.py | 4 ++- oslo/messaging/_drivers/common.py | 6 ++-- oslo/messaging/_drivers/impl_fake.py | 6 ++-- oslo/messaging/_drivers/impl_qpid.py | 6 ++-- oslo/messaging/_drivers/impl_rabbit.py | 6 ++-- oslo/messaging/transport.py | 8 +++-- tests/test_exception_serialization.py | 26 ++++++---------- tests/test_transport.py | 43 ++++++++++++++++++-------- 9 files changed, 73 insertions(+), 48 deletions(-) diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py index 3486a53a6..40c1fcc82 100644 --- a/oslo/messaging/_drivers/amqpdriver.py +++ b/oslo/messaging/_drivers/amqpdriver.py @@ -135,10 +135,11 @@ class ReplyWaiters(object): class ReplyWaiter(object): - def __init__(self, conf, reply_q, conn): + def __init__(self, conf, reply_q, conn, allowed_remote_exmods): self.conf = conf self.conn = conn self.reply_q = reply_q + self.allowed_remote_exmods = allowed_remote_exmods self.conn_lock = threading.Lock() self.incoming = [] @@ -163,8 +164,8 @@ class ReplyWaiter(object): self.msg_id_cache.check_duplicate_message(data) if data['failure']: failure = data['failure'] - result = rpc_common.deserialize_remote_exception(self.conf, - failure) + result = rpc_common.deserialize_remote_exception( + failure, self.allowed_remote_exmods) elif data.get('ending', False): ending = True else: @@ -241,8 +242,10 @@ class ReplyWaiter(object): class AMQPDriverBase(base.BaseDriver): - def __init__(self, conf, connection_pool, url=None, default_exchange=None): - super(AMQPDriverBase, self).__init__(conf, url, default_exchange) + def __init__(self, conf, connection_pool, url=None, default_exchange=None, + allowed_remote_exmods=[]): + super(AMQPDriverBase, self).__init__(conf, url, default_exchange, + allowed_remote_exmods) self._default_exchange = urls.exchange_from_url(url, default_exchange) @@ -271,7 +274,8 @@ class AMQPDriverBase(base.BaseDriver): conn = self._get_connection(pooled=False) - self._waiter = ReplyWaiter(self.conf, reply_q, conn) + self._waiter = ReplyWaiter(self.conf, reply_q, conn, + self._allowed_remote_exmods) self._reply_q = reply_q self._reply_q_conn = conn diff --git a/oslo/messaging/_drivers/base.py b/oslo/messaging/_drivers/base.py index 0868d6d1d..085b28dd6 100644 --- a/oslo/messaging/_drivers/base.py +++ b/oslo/messaging/_drivers/base.py @@ -55,10 +55,12 @@ class BaseDriver(object): __metaclass__ = abc.ABCMeta - def __init__(self, conf, url=None, default_exchange=None): + def __init__(self, conf, url=None, default_exchange=None, + allowed_remote_exmods=[]): self.conf = conf self._url = url self._default_exchange = default_exchange + self._allowed_remote_exmods = allowed_remote_exmods @abc.abstractmethod def send(self, target, ctxt, message, diff --git a/oslo/messaging/_drivers/common.py b/oslo/messaging/_drivers/common.py index b9b68d6e6..9fd169cb0 100644 --- a/oslo/messaging/_drivers/common.py +++ b/oslo/messaging/_drivers/common.py @@ -73,7 +73,6 @@ _MESSAGE_KEY = 'oslo.message' _REMOTE_POSTFIX = '_Remote' -# FIXME(markmc): add an API to replace this option _exception_opts = [ cfg.ListOpt('allowed_rpc_exception_modules', default=['oslo.messaging.exceptions', @@ -330,7 +329,7 @@ def serialize_remote_exception(failure_info, log_failure=True): return json_data -def deserialize_remote_exception(conf, data): +def deserialize_remote_exception(data, allowed_remote_exmods): failure = jsonutils.loads(str(data)) trace = failure.get('tb', []) @@ -340,8 +339,7 @@ def deserialize_remote_exception(conf, data): # NOTE(ameade): We DO NOT want to allow just any module to be imported, in # order to prevent arbitrary code execution. - conf.register_opts(_exception_opts) - if module not in conf.allowed_rpc_exception_modules: + if module != 'exceptions' and module not in allowed_remote_exmods: return messaging.RemoteError(name, failure.get('message'), trace) try: diff --git a/oslo/messaging/_drivers/impl_fake.py b/oslo/messaging/_drivers/impl_fake.py index ffce9677a..42ddd86b4 100644 --- a/oslo/messaging/_drivers/impl_fake.py +++ b/oslo/messaging/_drivers/impl_fake.py @@ -87,8 +87,10 @@ class FakeExchange(object): class FakeDriver(base.BaseDriver): - def __init__(self, conf, url=None, default_exchange=None): - super(FakeDriver, self).__init__(conf, url, default_exchange) + def __init__(self, conf, url=None, default_exchange=None, + allowed_remote_exmods=[]): + super(FakeDriver, self).__init__(conf, url, default_exchange, + allowed_remote_exmods=[]) self._default_exchange = urls.exchange_from_url(url, default_exchange) diff --git a/oslo/messaging/_drivers/impl_qpid.py b/oslo/messaging/_drivers/impl_qpid.py index cdb77924a..e159277ce 100644 --- a/oslo/messaging/_drivers/impl_qpid.py +++ b/oslo/messaging/_drivers/impl_qpid.py @@ -742,11 +742,13 @@ def cleanup(): class QpidDriver(amqpdriver.AMQPDriverBase): - def __init__(self, conf, url=None, default_exchange=None): + def __init__(self, conf, url=None, default_exchange=None, + allowed_remote_exmods=[]): conf.register_opts(qpid_opts) conf.register_opts(rpc_amqp.amqp_opts) connection_pool = rpc_amqp.get_connection_pool(conf, Connection) super(QpidDriver, self).__init__(conf, connection_pool, - url, default_exchange) + url, default_exchange, + allowed_remote_exmods) diff --git a/oslo/messaging/_drivers/impl_rabbit.py b/oslo/messaging/_drivers/impl_rabbit.py index d997d78ec..71f65b49f 100644 --- a/oslo/messaging/_drivers/impl_rabbit.py +++ b/oslo/messaging/_drivers/impl_rabbit.py @@ -873,11 +873,13 @@ def cleanup(): class RabbitDriver(amqpdriver.AMQPDriverBase): - def __init__(self, conf, url=None, default_exchange=None): + def __init__(self, conf, url=None, default_exchange=None, + allowed_remote_exmods=[]): conf.register_opts(rabbit_opts) conf.register_opts(rpc_amqp.amqp_opts) connection_pool = rpc_amqp.get_connection_pool(conf, Connection) super(RabbitDriver, self).__init__(conf, connection_pool, - url, default_exchange) + url, default_exchange, + allowed_remote_exmods) diff --git a/oslo/messaging/transport.py b/oslo/messaging/transport.py index 20083688f..87572f5ff 100644 --- a/oslo/messaging/transport.py +++ b/oslo/messaging/transport.py @@ -119,7 +119,7 @@ class DriverLoadFailure(exceptions.MessagingException): self.ex = ex -def get_transport(conf, url=None): +def get_transport(conf, url=None, allowed_remote_exmods=[]): """A factory method for Transport objects. This method will construct a Transport object from transport configuration @@ -140,6 +140,9 @@ def get_transport(conf, url=None): :type conf: cfg.ConfigOpts :param url: a transport URL :type url: str + :param allowed_remote_exmods: a list of modules which a client using this + transport will deserialize remote exceptions from + :type allowed_remote_exmods: list """ conf.register_opts(_transport_opts) @@ -151,7 +154,8 @@ def get_transport(conf, url=None): else: rpc_backend = conf.rpc_backend - kwargs = dict(default_exchange=conf.control_exchange) + kwargs = dict(default_exchange=conf.control_exchange, + allowed_remote_exmods=allowed_remote_exmods) if url is not None: kwargs['url'] = url diff --git a/tests/test_exception_serialization.py b/tests/test_exception_serialization.py index 71195b806..8884c60d3 100644 --- a/tests/test_exception_serialization.py +++ b/tests/test_exception_serialization.py @@ -150,7 +150,7 @@ SerializeRemoteExceptionTestCase.generate_scenarios() class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): - _standard_allowed = [__name__, 'exceptions'] + _standard_allowed = [__name__] scenarios = [ ('bog_standard', @@ -203,18 +203,18 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): remote_kwargs={})), ('not_allowed', dict(allowed=[], - clsname='Exception', - modname='exceptions', + clsname='NovaStyleException', + modname=__name__, cls=messaging.RemoteError, args=[], kwargs={}, - str=("Remote error: Exception test\n" + str=("Remote error: NovaStyleException test\n" "[u'traceback\\ntraceback\\n']."), - msg=("Remote error: Exception test\n" + msg=("Remote error: NovaStyleException test\n" "[u'traceback\\ntraceback\\n']."), remote_name='RemoteError', remote_args=(), - remote_kwargs={'exc_type': 'Exception', + remote_kwargs={'exc_type': 'NovaStyleException', 'value': 'test', 'traceback': 'traceback\ntraceback\n'})), ('unknown_module', @@ -234,7 +234,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): 'value': 'test', 'traceback': 'traceback\ntraceback\n'})), ('unknown_exception', - dict(allowed=['exceptions'], + dict(allowed=[], clsname='FarcicalError', modname='exceptions', cls=messaging.RemoteError, @@ -250,7 +250,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): 'value': 'test', 'traceback': 'traceback\ntraceback\n'})), ('unknown_kwarg', - dict(allowed=['exceptions'], + dict(allowed=[], clsname='Exception', modname='exceptions', cls=messaging.RemoteError, @@ -266,7 +266,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): 'value': 'test', 'traceback': 'traceback\ntraceback\n'})), ('system_exit', - dict(allowed=['exceptions'], + dict(allowed=[], clsname='SystemExit', modname='exceptions', cls=messaging.RemoteError, @@ -283,13 +283,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): 'traceback': 'traceback\ntraceback\n'})), ] - def setUp(self): - super(DeserializeRemoteExceptionTestCase, self).setUp() - self.conf.register_opts(exceptions._exception_opts) - def test_deserialize_remote_exception(self): - self.config(allowed_rpc_exception_modules=self.allowed) - failure = { 'class': self.clsname, 'module': self.modname, @@ -301,7 +295,7 @@ class DeserializeRemoteExceptionTestCase(test_utils.BaseTestCase): serialized = jsonutils.dumps(failure) - ex = exceptions.deserialize_remote_exception(self.conf, serialized) + ex = exceptions.deserialize_remote_exception(serialized, self.allowed) self.assertIsInstance(ex, self.cls) self.assertEqual(ex.__class__.__name__, self.remote_name) diff --git a/tests/test_transport.py b/tests/test_transport.py index 0722ac554..bbdc0b796 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -54,34 +54,46 @@ class GetTransportTestCase(test_utils.BaseTestCase): scenarios = [ ('all_none', dict(url=None, transport_url=None, rpc_backend=None, - control_exchange=None, + control_exchange=None, allowed=None, expect=dict(backend=None, exchange=None, - url=None))), + url=None, + allowed=[]))), ('rpc_backend', dict(url=None, transport_url=None, rpc_backend='testbackend', - control_exchange=None, + control_exchange=None, allowed=None, expect=dict(backend='testbackend', exchange=None, - url=None))), + url=None, + allowed=[]))), ('control_exchange', dict(url=None, transport_url=None, rpc_backend=None, - control_exchange='testexchange', + control_exchange='testexchange', allowed=None, expect=dict(backend=None, exchange='testexchange', - url=None))), + url=None, + allowed=[]))), ('transport_url', dict(url=None, transport_url='testtransport:', rpc_backend=None, - control_exchange=None, + control_exchange=None, allowed=None, expect=dict(backend='testtransport', exchange=None, - url='testtransport:'))), + url='testtransport:', + allowed=[]))), ('url_param', dict(url='testtransport:', transport_url=None, rpc_backend=None, - control_exchange=None, + control_exchange=None, allowed=None, expect=dict(backend='testtransport', exchange=None, - url='testtransport:'))), + url='testtransport:', + allowed=[]))), + ('allowed_remote_exmods', + dict(url=None, transport_url=None, rpc_backend=None, + control_exchange=None, allowed=['foo', 'bar'], + expect=dict(backend=None, + exchange=None, + url=None, + allowed=['foo', 'bar']))), ] def setUp(self): @@ -96,7 +108,8 @@ class GetTransportTestCase(test_utils.BaseTestCase): self.mox.StubOutWithMock(driver, 'DriverManager') invoke_args = [self.conf] - invoke_kwds = dict(default_exchange=self.expect['exchange']) + invoke_kwds = dict(default_exchange=self.expect['exchange'], + allowed_remote_exmods=self.expect['allowed']) if self.expect['url']: invoke_kwds['url'] = self.expect['url'] @@ -110,7 +123,10 @@ class GetTransportTestCase(test_utils.BaseTestCase): self.mox.ReplayAll() - transport = messaging.get_transport(self.conf, url=self.url) + kwargs = dict(url=self.url) + if self.allowed is not None: + kwargs['allowed_remote_exmods'] = self.allowed + transport = messaging.get_transport(self.conf, **kwargs) self.assertIsNotNone(transport) self.assertIs(transport.conf, self.conf) @@ -149,7 +165,8 @@ class GetTransportSadPathTestCase(test_utils.BaseTestCase): self.mox.StubOutWithMock(driver, 'DriverManager') invoke_args = [self.conf] - invoke_kwds = dict(default_exchange='openstack') + invoke_kwds = dict(default_exchange='openstack', + allowed_remote_exmods=[]) driver.DriverManager('oslo.messaging.drivers', self.rpc_backend,