From 49a104079d93b73f9a2caec77312ee173b1d459c Mon Sep 17 00:00:00 2001 From: Alex Meade Date: Wed, 21 Mar 2012 19:13:16 +0000 Subject: [PATCH] Add the serialization of exceptions for RPC calls. This change uses json to serialize an exception so that it can be sent through RPC calls to be reconstructed on the other side. The traceback is added to the exception message. If recreating the exception fails for whatever reason then a RemoteError is created containing all of the exception information. Adds flag 'allowed_rpc_exception_modules' to prevent dangerous modules from being accessed and allowing arbitrary code to be run. Fixes bug 920705 Fixes bug 940500 Change-Id: Ife3b64b19fe8abbc730184d4ee7d9fcabfd29db3 --- nova/rpc/amqp.py | 11 +-- nova/rpc/common.py | 79 ++++++++++++++++++ nova/rpc/impl_fake.py | 12 +-- nova/tests/rpc/common.py | 32 ++------ nova/tests/rpc/test_common.py | 147 ++++++++++++++++++++++++++++++++++ nova/tests/rpc/test_kombu.py | 53 +++++++++++- 6 files changed, 293 insertions(+), 41 deletions(-) create mode 100644 nova/tests/rpc/test_common.py diff --git a/nova/rpc/amqp.py b/nova/rpc/amqp.py index 444ade48..df6e1b97 100644 --- a/nova/rpc/amqp.py +++ b/nova/rpc/amqp.py @@ -27,7 +27,6 @@ AMQP, but is deprecated and predates this code. import inspect import sys -import traceback import uuid from eventlet import greenpool @@ -141,11 +140,7 @@ def msg_reply(msg_id, connection_pool, reply=None, failure=None, ending=False): """ with ConnectionContext(connection_pool) as conn: if failure: - message = str(failure[1]) - tb = traceback.format_exception(*failure) - LOG.error(_("Returning exception %s to caller"), message) - LOG.error(tb) - failure = (failure[0].__name__, str(failure[1]), tb) + failure = rpc_common.serialize_remote_exception(failure) try: msg = {'result': reply, 'failure': failure} @@ -285,7 +280,9 @@ class MulticallWaiter(object): def __call__(self, data): """The consume() callback will call this. Store the result.""" if data['failure']: - self._result = rpc_common.RemoteError(*data['failure']) + failure = data['failure'] + self._result = rpc_common.deserialize_remote_exception(failure) + elif data.get('ending', False): self._got_ending = True else: diff --git a/nova/rpc/common.py b/nova/rpc/common.py index 95c24581..51bf2fd2 100644 --- a/nova/rpc/common.py +++ b/nova/rpc/common.py @@ -18,11 +18,14 @@ # under the License. import copy +import sys +import traceback from nova import exception from nova import flags from nova import log as logging from nova.openstack.common import cfg +from nova import utils LOG = logging.getLogger(__name__) @@ -37,9 +40,14 @@ rpc_opts = [ cfg.IntOpt('rpc_response_timeout', default=60, help='Seconds to wait for a response from call or multicall'), + cfg.IntOpt('allowed_rpc_exception_modules', + default=['nova.exception'], + help='Modules of exceptions that are permitted to be recreated' + 'upon receiving exception data from an rpc call.'), ] flags.FLAGS.register_opts(rpc_opts) +FLAGS = flags.FLAGS class RemoteError(exception.NovaException): @@ -158,3 +166,74 @@ def _safe_log(log_func, msg, msg_data): msg_data['auth_token'] = '' return log_func(msg, msg_data) + + +def serialize_remote_exception(failure_info): + """Prepares exception data to be sent over rpc. + + Failure_info should be a sys.exc_info() tuple. + + """ + tb = traceback.format_exception(*failure_info) + failure = failure_info[1] + LOG.error(_("Returning exception %s to caller"), unicode(failure)) + LOG.error(tb) + + kwargs = {} + if hasattr(failure, 'kwargs'): + kwargs = failure.kwargs + + data = { + 'class': str(failure.__class__.__name__), + 'module': str(failure.__class__.__module__), + 'message': unicode(failure), + 'tb': tb, + 'args': failure.args, + 'kwargs': kwargs + } + + json_data = utils.dumps(data) + + return json_data + + +def deserialize_remote_exception(data): + failure = utils.loads(str(data)) + + trace = failure.get('tb', []) + message = failure.get('message', "") + "\n" + "\n".join(trace) + name = failure.get('class') + module = failure.get('module') + + # NOTE(ameade): We DO NOT want to allow just any module to be imported, in + # order to prevent arbitrary code execution. + if not module in FLAGS.allowed_rpc_exception_modules: + return RemoteError(name, failure.get('message'), trace) + + try: + __import__(module) + mod = sys.modules[module] + klass = getattr(mod, name) + if not issubclass(klass, Exception): + raise TypeError("Can only deserialize Exceptions") + + failure = klass(**failure.get('kwargs', {})) + except (AttributeError, TypeError, ImportError): + return RemoteError(name, failure.get('message'), trace) + + ex_type = type(failure) + str_override = lambda self: message + new_ex_type = type(ex_type.__name__ + "_Remote", (ex_type,), + {'__str__': str_override}) + try: + # NOTE(ameade): Dynamically create a new exception type and swap it in + # as the new type for the exception. This only works on user defined + # Exceptions and not core python exceptions. This is important because + # we cannot necessarily change an exception message so we must override + # the __str__ method. + failure.__class__ = new_ex_type + except TypeError as e: + # NOTE(ameade): If a core exception then just add the traceback to the + # first exception argument. + failure.args = (message,) + failure.args[1:] + return failure diff --git a/nova/rpc/impl_fake.py b/nova/rpc/impl_fake.py index 42ed7907..43aed15c 100644 --- a/nova/rpc/impl_fake.py +++ b/nova/rpc/impl_fake.py @@ -77,12 +77,8 @@ class Consumer(object): else: res.append(rval) done.send(res) - except Exception: - exc_info = sys.exc_info() - done.send_exception( - rpc_common.RemoteError(exc_info[0].__name__, - str(exc_info[1]), - ''.join(traceback.format_exception(*exc_info)))) + except Exception as e: + done.send_exception(e) thread = eventlet.greenthread.spawn(_inner) @@ -161,7 +157,7 @@ def call(context, topic, msg, timeout=None): def cast(context, topic, msg): try: call(context, topic, msg) - except rpc_common.RemoteError: + except Exception: pass @@ -184,5 +180,5 @@ def fanout_cast(context, topic, msg): for consumer in CONSUMERS.get(topic, []): try: consumer.call(context, method, args, None) - except rpc_common.RemoteError: + except Exception: pass diff --git a/nova/tests/rpc/common.py b/nova/tests/rpc/common.py index 87cb522c..3524e568 100644 --- a/nova/tests/rpc/common.py +++ b/nova/tests/rpc/common.py @@ -25,6 +25,7 @@ from eventlet import greenthread import nose from nova import context +from nova import exception from nova import log as logging from nova.rpc import amqp as rpc_amqp from nova.rpc import common as rpc_common @@ -100,30 +101,6 @@ class BaseRpcTestCase(test.TestCase): "args": {"value": value}}) self.assertEqual(self.context.to_dict(), result) - def test_call_exception(self): - """Test that exception gets passed back properly. - - rpc.call returns a RemoteError object. The value of the - exception is converted to a string, so we convert it back - to an int in the test. - - """ - value = 42 - self.assertRaises(rpc_common.RemoteError, - self.rpc.call, - self.context, - 'test', - {"method": "fail", - "args": {"value": value}}) - try: - self.rpc.call(self.context, - 'test', - {"method": "fail", - "args": {"value": value}}) - self.fail("should have thrown RemoteError") - except rpc_common.RemoteError as exc: - self.assertEqual(int(exc.value), value) - def test_nested_calls(self): """Test that we can do an rpc.call inside another call.""" class Nested(object): @@ -248,7 +225,12 @@ class TestReceiver(object): @staticmethod def fail(context, value): """Raises an exception with the value sent in.""" - raise Exception(value) + raise NotImplementedError(value) + + @staticmethod + def fail_converted(context, value): + """Raises an exception with the value sent in.""" + raise exception.ConvertedException(explanation=value) @staticmethod def block(context, value): diff --git a/nova/tests/rpc/test_common.py b/nova/tests/rpc/test_common.py new file mode 100644 index 00000000..6220bd01 --- /dev/null +++ b/nova/tests/rpc/test_common.py @@ -0,0 +1,147 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 OpenStack, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +Unit Tests for 'common' functons used through rpc code. +""" + +import json +import sys + +from nova import context +from nova import exception +from nova import flags +from nova import log as logging +from nova import test +from nova.rpc import amqp as rpc_amqp +from nova.rpc import common as rpc_common +from nova.tests.rpc import common + +FLAGS = flags.FLAGS +LOG = logging.getLogger(__name__) + + +def raise_exception(): + raise Exception("test") + + +class FakeUserDefinedException(Exception): + def __init__(self): + Exception.__init__(self, "Test Message") + + +class RpcCommonTestCase(test.TestCase): + def test_serialize_remote_exception(self): + expected = { + 'class': 'Exception', + 'module': 'exceptions', + 'message': 'test', + } + + try: + raise_exception() + except Exception as exc: + failure = rpc_common.serialize_remote_exception(sys.exc_info()) + + failure = json.loads(failure) + #assure the traceback was added + self.assertEqual(expected['class'], failure['class']) + self.assertEqual(expected['module'], failure['module']) + self.assertEqual(expected['message'], failure['message']) + + def test_serialize_remote_nova_exception(self): + def raise_nova_exception(): + raise exception.NovaException("test", code=500) + + expected = { + 'class': 'NovaException', + 'module': 'nova.exception', + 'kwargs': {'code': 500}, + 'message': 'test' + } + + try: + raise_nova_exception() + except Exception as exc: + failure = rpc_common.serialize_remote_exception(sys.exc_info()) + + failure = json.loads(failure) + #assure the traceback was added + self.assertEqual(expected['class'], failure['class']) + self.assertEqual(expected['module'], failure['module']) + self.assertEqual(expected['kwargs'], failure['kwargs']) + self.assertEqual(expected['message'], failure['message']) + + def test_deserialize_remote_exception(self): + failure = { + 'class': 'NovaException', + 'module': 'nova.exception', + 'message': 'test message', + 'tb': ['raise NovaException'], + } + serialized = json.dumps(failure) + + after_exc = rpc_common.deserialize_remote_exception(serialized) + self.assertTrue(isinstance(after_exc, exception.NovaException)) + self.assertTrue('test message' in unicode(after_exc)) + #assure the traceback was added + self.assertTrue('raise NovaException' in unicode(after_exc)) + + def test_deserialize_remote_exception_bad_module(self): + failure = { + 'class': 'popen2', + 'module': 'os', + 'kwargs': {'cmd': '/bin/echo failed'}, + 'message': 'foo', + } + serialized = json.dumps(failure) + + after_exc = rpc_common.deserialize_remote_exception(serialized) + self.assertTrue(isinstance(after_exc, rpc_common.RemoteError)) + + def test_deserialize_remote_exception_user_defined_exception(self): + """Ensure a user defined exception can be deserialized.""" + self.flags(allowed_rpc_exception_modules=[self.__class__.__module__]) + failure = { + 'class': 'FakeUserDefinedException', + 'module': self.__class__.__module__, + 'tb': ['raise FakeUserDefinedException'], + } + serialized = json.dumps(failure) + + after_exc = rpc_common.deserialize_remote_exception(serialized) + self.assertTrue(isinstance(after_exc, FakeUserDefinedException)) + #assure the traceback was added + self.assertTrue('raise FakeUserDefinedException' in unicode(after_exc)) + + def test_deserialize_remote_exception_cannot_recreate(self): + """Ensure a RemoteError is returned on initialization failure. + + If an exception cannot be recreated with it's original class then a + RemoteError with the exception informations should still be returned. + + """ + self.flags(allowed_rpc_exception_modules=[self.__class__.__module__]) + failure = { + 'class': 'FakeIDontExistException', + 'module': self.__class__.__module__, + 'tb': ['raise FakeIDontExistException'], + } + serialized = json.dumps(failure) + + after_exc = rpc_common.deserialize_remote_exception(serialized) + self.assertTrue(isinstance(after_exc, rpc_common.RemoteError)) + #assure the traceback was added + self.assertTrue('raise FakeIDontExistException' in unicode(after_exc)) diff --git a/nova/tests/rpc/test_kombu.py b/nova/tests/rpc/test_kombu.py index aa49b5d5..966cb3a6 100644 --- a/nova/tests/rpc/test_kombu.py +++ b/nova/tests/rpc/test_kombu.py @@ -20,6 +20,7 @@ Unit Tests for remote procedure calls using kombu """ from nova import context +from nova import exception from nova import flags from nova import log as logging from nova import test @@ -292,4 +293,54 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): self.assertEqual(self.received_message, message) # Only called once, because our stub goes away during reconnection - self.assertEqual(info['called'], 1) + + def test_call_exception(self): + """Test that exception gets passed back properly. + + rpc.call returns an Exception object. The value of the + exception is converted to a string. + + """ + self.flags(allowed_rpc_exception_modules=['exceptions']) + value = "This is the exception message" + self.assertRaises(NotImplementedError, + self.rpc.call, + self.context, + 'test', + {"method": "fail", + "args": {"value": value}}) + try: + self.rpc.call(self.context, + 'test', + {"method": "fail", + "args": {"value": value}}) + self.fail("should have thrown Exception") + except NotImplementedError as exc: + self.assertTrue(value in unicode(exc)) + #Traceback should be included in exception message + self.assertTrue('raise NotImplementedError(value)' in unicode(exc)) + + def test_call_converted_exception(self): + """Test that exception gets passed back properly. + + rpc.call returns an Exception object. The value of the + exception is converted to a string. + + """ + value = "This is the exception message" + self.assertRaises(exception.ConvertedException, + self.rpc.call, + self.context, + 'test', + {"method": "fail_converted", + "args": {"value": value}}) + try: + self.rpc.call(self.context, + 'test', + {"method": "fail_converted", + "args": {"value": value}}) + self.fail("should have thrown Exception") + except exception.ConvertedException as exc: + self.assertTrue(value in unicode(exc)) + #Traceback should be included in exception message + self.assertTrue('exception.ConvertedException' in unicode(exc))