diff --git a/nova/rpc/amqp.py b/nova/rpc/amqp.py index 5387eff1..95fe9041 100644 --- a/nova/rpc/amqp.py +++ b/nova/rpc/amqp.py @@ -27,7 +27,6 @@ AMQP, but is deprecated and predates this code. import inspect import sys -import traceback import uuid from eventlet import greenpool @@ -141,11 +140,7 @@ def msg_reply(msg_id, connection_pool, reply=None, failure=None, ending=False): """ with ConnectionContext(connection_pool) as conn: if failure: - message = str(failure[1]) - tb = traceback.format_exception(*failure) - LOG.error(_("Returning exception %s to caller"), message) - LOG.error(tb) - failure = (failure[0].__name__, str(failure[1]), tb) + failure = rpc_common.serialize_remote_exception(failure) try: msg = {'result': reply, 'failure': failure} @@ -285,7 +280,9 @@ class MulticallWaiter(object): def __call__(self, data): """The consume() callback will call this. Store the result.""" if data['failure']: - self._result = rpc_common.RemoteError(*data['failure']) + failure = data['failure'] + self._result = rpc_common.deserialize_remote_exception(failure) + elif data.get('ending', False): self._got_ending = True else: diff --git a/nova/rpc/common.py b/nova/rpc/common.py index 95c24581..51bf2fd2 100644 --- a/nova/rpc/common.py +++ b/nova/rpc/common.py @@ -18,11 +18,14 @@ # under the License. import copy +import sys +import traceback from nova import exception from nova import flags from nova import log as logging from nova.openstack.common import cfg +from nova import utils LOG = logging.getLogger(__name__) @@ -37,9 +40,14 @@ rpc_opts = [ cfg.IntOpt('rpc_response_timeout', default=60, help='Seconds to wait for a response from call or multicall'), + cfg.IntOpt('allowed_rpc_exception_modules', + default=['nova.exception'], + help='Modules of exceptions that are permitted to be recreated' + 'upon receiving exception data from an rpc call.'), ] flags.FLAGS.register_opts(rpc_opts) +FLAGS = flags.FLAGS class RemoteError(exception.NovaException): @@ -158,3 +166,74 @@ def _safe_log(log_func, msg, msg_data): msg_data['auth_token'] = '' return log_func(msg, msg_data) + + +def serialize_remote_exception(failure_info): + """Prepares exception data to be sent over rpc. + + Failure_info should be a sys.exc_info() tuple. + + """ + tb = traceback.format_exception(*failure_info) + failure = failure_info[1] + LOG.error(_("Returning exception %s to caller"), unicode(failure)) + LOG.error(tb) + + kwargs = {} + if hasattr(failure, 'kwargs'): + kwargs = failure.kwargs + + data = { + 'class': str(failure.__class__.__name__), + 'module': str(failure.__class__.__module__), + 'message': unicode(failure), + 'tb': tb, + 'args': failure.args, + 'kwargs': kwargs + } + + json_data = utils.dumps(data) + + return json_data + + +def deserialize_remote_exception(data): + failure = utils.loads(str(data)) + + trace = failure.get('tb', []) + message = failure.get('message', "") + "\n" + "\n".join(trace) + name = failure.get('class') + module = failure.get('module') + + # NOTE(ameade): We DO NOT want to allow just any module to be imported, in + # order to prevent arbitrary code execution. + if not module in FLAGS.allowed_rpc_exception_modules: + return RemoteError(name, failure.get('message'), trace) + + try: + __import__(module) + mod = sys.modules[module] + klass = getattr(mod, name) + if not issubclass(klass, Exception): + raise TypeError("Can only deserialize Exceptions") + + failure = klass(**failure.get('kwargs', {})) + except (AttributeError, TypeError, ImportError): + return RemoteError(name, failure.get('message'), trace) + + ex_type = type(failure) + str_override = lambda self: message + new_ex_type = type(ex_type.__name__ + "_Remote", (ex_type,), + {'__str__': str_override}) + try: + # NOTE(ameade): Dynamically create a new exception type and swap it in + # as the new type for the exception. This only works on user defined + # Exceptions and not core python exceptions. This is important because + # we cannot necessarily change an exception message so we must override + # the __str__ method. + failure.__class__ = new_ex_type + except TypeError as e: + # NOTE(ameade): If a core exception then just add the traceback to the + # first exception argument. + failure.args = (message,) + failure.args[1:] + return failure diff --git a/nova/rpc/impl_fake.py b/nova/rpc/impl_fake.py index 42ed7907..43aed15c 100644 --- a/nova/rpc/impl_fake.py +++ b/nova/rpc/impl_fake.py @@ -77,12 +77,8 @@ class Consumer(object): else: res.append(rval) done.send(res) - except Exception: - exc_info = sys.exc_info() - done.send_exception( - rpc_common.RemoteError(exc_info[0].__name__, - str(exc_info[1]), - ''.join(traceback.format_exception(*exc_info)))) + except Exception as e: + done.send_exception(e) thread = eventlet.greenthread.spawn(_inner) @@ -161,7 +157,7 @@ def call(context, topic, msg, timeout=None): def cast(context, topic, msg): try: call(context, topic, msg) - except rpc_common.RemoteError: + except Exception: pass @@ -184,5 +180,5 @@ def fanout_cast(context, topic, msg): for consumer in CONSUMERS.get(topic, []): try: consumer.call(context, method, args, None) - except rpc_common.RemoteError: + except Exception: pass diff --git a/nova/tests/rpc/common.py b/nova/tests/rpc/common.py index 87cb522c..3524e568 100644 --- a/nova/tests/rpc/common.py +++ b/nova/tests/rpc/common.py @@ -25,6 +25,7 @@ from eventlet import greenthread import nose from nova import context +from nova import exception from nova import log as logging from nova.rpc import amqp as rpc_amqp from nova.rpc import common as rpc_common @@ -100,30 +101,6 @@ class BaseRpcTestCase(test.TestCase): "args": {"value": value}}) self.assertEqual(self.context.to_dict(), result) - def test_call_exception(self): - """Test that exception gets passed back properly. - - rpc.call returns a RemoteError object. The value of the - exception is converted to a string, so we convert it back - to an int in the test. - - """ - value = 42 - self.assertRaises(rpc_common.RemoteError, - self.rpc.call, - self.context, - 'test', - {"method": "fail", - "args": {"value": value}}) - try: - self.rpc.call(self.context, - 'test', - {"method": "fail", - "args": {"value": value}}) - self.fail("should have thrown RemoteError") - except rpc_common.RemoteError as exc: - self.assertEqual(int(exc.value), value) - def test_nested_calls(self): """Test that we can do an rpc.call inside another call.""" class Nested(object): @@ -248,7 +225,12 @@ class TestReceiver(object): @staticmethod def fail(context, value): """Raises an exception with the value sent in.""" - raise Exception(value) + raise NotImplementedError(value) + + @staticmethod + def fail_converted(context, value): + """Raises an exception with the value sent in.""" + raise exception.ConvertedException(explanation=value) @staticmethod def block(context, value): diff --git a/nova/tests/rpc/test_common.py b/nova/tests/rpc/test_common.py new file mode 100644 index 00000000..6220bd01 --- /dev/null +++ b/nova/tests/rpc/test_common.py @@ -0,0 +1,147 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 OpenStack, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +Unit Tests for 'common' functons used through rpc code. +""" + +import json +import sys + +from nova import context +from nova import exception +from nova import flags +from nova import log as logging +from nova import test +from nova.rpc import amqp as rpc_amqp +from nova.rpc import common as rpc_common +from nova.tests.rpc import common + +FLAGS = flags.FLAGS +LOG = logging.getLogger(__name__) + + +def raise_exception(): + raise Exception("test") + + +class FakeUserDefinedException(Exception): + def __init__(self): + Exception.__init__(self, "Test Message") + + +class RpcCommonTestCase(test.TestCase): + def test_serialize_remote_exception(self): + expected = { + 'class': 'Exception', + 'module': 'exceptions', + 'message': 'test', + } + + try: + raise_exception() + except Exception as exc: + failure = rpc_common.serialize_remote_exception(sys.exc_info()) + + failure = json.loads(failure) + #assure the traceback was added + self.assertEqual(expected['class'], failure['class']) + self.assertEqual(expected['module'], failure['module']) + self.assertEqual(expected['message'], failure['message']) + + def test_serialize_remote_nova_exception(self): + def raise_nova_exception(): + raise exception.NovaException("test", code=500) + + expected = { + 'class': 'NovaException', + 'module': 'nova.exception', + 'kwargs': {'code': 500}, + 'message': 'test' + } + + try: + raise_nova_exception() + except Exception as exc: + failure = rpc_common.serialize_remote_exception(sys.exc_info()) + + failure = json.loads(failure) + #assure the traceback was added + self.assertEqual(expected['class'], failure['class']) + self.assertEqual(expected['module'], failure['module']) + self.assertEqual(expected['kwargs'], failure['kwargs']) + self.assertEqual(expected['message'], failure['message']) + + def test_deserialize_remote_exception(self): + failure = { + 'class': 'NovaException', + 'module': 'nova.exception', + 'message': 'test message', + 'tb': ['raise NovaException'], + } + serialized = json.dumps(failure) + + after_exc = rpc_common.deserialize_remote_exception(serialized) + self.assertTrue(isinstance(after_exc, exception.NovaException)) + self.assertTrue('test message' in unicode(after_exc)) + #assure the traceback was added + self.assertTrue('raise NovaException' in unicode(after_exc)) + + def test_deserialize_remote_exception_bad_module(self): + failure = { + 'class': 'popen2', + 'module': 'os', + 'kwargs': {'cmd': '/bin/echo failed'}, + 'message': 'foo', + } + serialized = json.dumps(failure) + + after_exc = rpc_common.deserialize_remote_exception(serialized) + self.assertTrue(isinstance(after_exc, rpc_common.RemoteError)) + + def test_deserialize_remote_exception_user_defined_exception(self): + """Ensure a user defined exception can be deserialized.""" + self.flags(allowed_rpc_exception_modules=[self.__class__.__module__]) + failure = { + 'class': 'FakeUserDefinedException', + 'module': self.__class__.__module__, + 'tb': ['raise FakeUserDefinedException'], + } + serialized = json.dumps(failure) + + after_exc = rpc_common.deserialize_remote_exception(serialized) + self.assertTrue(isinstance(after_exc, FakeUserDefinedException)) + #assure the traceback was added + self.assertTrue('raise FakeUserDefinedException' in unicode(after_exc)) + + def test_deserialize_remote_exception_cannot_recreate(self): + """Ensure a RemoteError is returned on initialization failure. + + If an exception cannot be recreated with it's original class then a + RemoteError with the exception informations should still be returned. + + """ + self.flags(allowed_rpc_exception_modules=[self.__class__.__module__]) + failure = { + 'class': 'FakeIDontExistException', + 'module': self.__class__.__module__, + 'tb': ['raise FakeIDontExistException'], + } + serialized = json.dumps(failure) + + after_exc = rpc_common.deserialize_remote_exception(serialized) + self.assertTrue(isinstance(after_exc, rpc_common.RemoteError)) + #assure the traceback was added + self.assertTrue('raise FakeIDontExistException' in unicode(after_exc)) diff --git a/nova/tests/rpc/test_kombu.py b/nova/tests/rpc/test_kombu.py index aa49b5d5..966cb3a6 100644 --- a/nova/tests/rpc/test_kombu.py +++ b/nova/tests/rpc/test_kombu.py @@ -20,6 +20,7 @@ Unit Tests for remote procedure calls using kombu """ from nova import context +from nova import exception from nova import flags from nova import log as logging from nova import test @@ -292,4 +293,54 @@ class RpcKombuTestCase(common.BaseRpcAMQPTestCase): self.assertEqual(self.received_message, message) # Only called once, because our stub goes away during reconnection - self.assertEqual(info['called'], 1) + + def test_call_exception(self): + """Test that exception gets passed back properly. + + rpc.call returns an Exception object. The value of the + exception is converted to a string. + + """ + self.flags(allowed_rpc_exception_modules=['exceptions']) + value = "This is the exception message" + self.assertRaises(NotImplementedError, + self.rpc.call, + self.context, + 'test', + {"method": "fail", + "args": {"value": value}}) + try: + self.rpc.call(self.context, + 'test', + {"method": "fail", + "args": {"value": value}}) + self.fail("should have thrown Exception") + except NotImplementedError as exc: + self.assertTrue(value in unicode(exc)) + #Traceback should be included in exception message + self.assertTrue('raise NotImplementedError(value)' in unicode(exc)) + + def test_call_converted_exception(self): + """Test that exception gets passed back properly. + + rpc.call returns an Exception object. The value of the + exception is converted to a string. + + """ + value = "This is the exception message" + self.assertRaises(exception.ConvertedException, + self.rpc.call, + self.context, + 'test', + {"method": "fail_converted", + "args": {"value": value}}) + try: + self.rpc.call(self.context, + 'test', + {"method": "fail_converted", + "args": {"value": value}}) + self.fail("should have thrown Exception") + except exception.ConvertedException as exc: + self.assertTrue(value in unicode(exc)) + #Traceback should be included in exception message + self.assertTrue('exception.ConvertedException' in unicode(exc))