diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py index 3142aca1e..dd1247fbe 100644 --- a/oslo/messaging/_drivers/amqpdriver.py +++ b/oslo/messaging/_drivers/amqpdriver.py @@ -37,6 +37,9 @@ class AMQPIncomingMessage(base.IncomingMessage): self.reply_q = reply_q def _send_reply(self, conn, reply=None, failure=None, ending=False): + if failure: + failure = rpc_common.serialize_remote_exception(failure) + # FIXME(markmc): is the reply format really driver specific? msg = {'result': reply, 'failure': failure} @@ -306,7 +309,10 @@ class AMQPDriverBase(base.BaseDriver): if wait_for_reply: # FIXME(markmc): timeout? - return self._waiter.wait(msg_id) + result = self._waiter.wait(msg_id) + if isinstance(result, Exception): + raise result + return result finally: if wait_for_reply: self._waiter.unlisten(msg_id) diff --git a/tests/test_rabbit.py b/tests/test_rabbit.py index bb2880e15..5f14182f1 100644 --- a/tests/test_rabbit.py +++ b/tests/test_rabbit.py @@ -15,6 +15,7 @@ # under the License. import datetime +import sys import threading import uuid @@ -31,10 +32,10 @@ from tests import utils as test_utils load_tests = testscenarios.load_tests_apply_scenarios -class TestRabbitDriver(test_utils.BaseTestCase): +class TestRabbitDriverLoad(test_utils.BaseTestCase): def setUp(self): - super(TestRabbitDriver, self).setUp() + super(TestRabbitDriverLoad, self).setUp() self.conf.register_opts(msg_transport._transport_opts) self.conf.register_opts(rabbit_driver.rabbit_opts) self.config(rpc_backend='rabbit') @@ -45,6 +46,37 @@ class TestRabbitDriver(test_utils.BaseTestCase): self.assertTrue(isinstance(transport._driver, rabbit_driver.RabbitDriver)) + +class TestSendReceive(test_utils.BaseTestCase): + + _n_senders = [ + ('single_sender', dict(n_senders=1)), + ('multiple_senders', dict(n_senders=10)), + ] + + _context = [ + ('empty_context', dict(ctxt={})), + ('with_context', dict(ctxt={'user': 'mark'})), + ] + + _failure = [ + ('success', dict(failure=False)), + ('failure', dict(failure=True)), + ] + + @classmethod + def generate_scenarios(cls): + cls.scenarios = testscenarios.multiply_scenarios(cls._n_senders, + cls._context, + cls._failure) + + def setUp(self): + super(TestSendReceive, self).setUp() + self.conf.register_opts(msg_transport._transport_opts) + self.conf.register_opts(rabbit_driver.rabbit_opts) + self.config(rpc_backend='rabbit') + self.config(fake_rabbit=True) + def test_send_receive(self): transport = messaging.get_transport(self.conf) self.addCleanup(transport.cleanup) @@ -60,12 +92,17 @@ class TestRabbitDriver(test_utils.BaseTestCase): msgs = [] def send_and_wait_for_reply(i): - replies.append(driver.send(target, - {}, - {'foo': i}, - wait_for_reply=True)) + try: + replies.append(driver.send(target, + self.ctxt, + {'foo': i}, + wait_for_reply=True)) + self.assertFalse(self.failure) + except ZeroDivisionError as e: + replies.append(e) + self.assertTrue(self.failure) - while len(senders) < 10: + while len(senders) < self.n_senders: senders.append(threading.Thread(target=send_and_wait_for_reply, args=(len(senders), ))) @@ -74,21 +111,35 @@ class TestRabbitDriver(test_utils.BaseTestCase): received = listener.poll() self.assertTrue(received is not None) - self.assertEqual(received.ctxt, {}) + self.assertEqual(received.ctxt, self.ctxt) self.assertEqual(received.message, {'foo': i}) msgs.append(received) # reply in reverse, except reply to the first guy second from last order = range(len(senders)-1, -1, -1) - order[-1], order[-2] = order[-2], order[-1] + if len(order) > 1: + order[-1], order[-2] = order[-2], order[-1] for i in order: - msgs[i].reply({'bar': msgs[i].message['foo']}) + if self.failure: + try: + raise ZeroDivisionError + except Exception: + failure = sys.exc_info() + msgs[i].reply(failure=failure) + else: + msgs[i].reply({'bar': msgs[i].message['foo']}) senders[i].join() self.assertEqual(len(replies), len(senders)) for i, reply in enumerate(replies): - self.assertEqual(reply, {'bar': order[i]}) + if not self.failure: + self.assertEqual(reply, {'bar': order[i]}) + else: + self.assertTrue(isinstance(reply, ZeroDivisionError)) + + +TestSendReceive.generate_scenarios() def _declare_queue(target):