From f4da21353956def06cb40e790b3a6f5275a68814 Mon Sep 17 00:00:00 2001 From: Mehdi Abaakouk Date: Thu, 24 Apr 2014 12:04:20 +0200 Subject: [PATCH] Remove amqp default exchange hack This change remove the hack to set the default exchange of a transport in the amqp driver, by removing the usage of the configuration object to get the default exchange in rabbit and qpid driver, and instead use the value passed to the driver constructor into all amqp publishers and consumers class/method that needs it. Closes-bug: #1256345 Change-Id: Iba54ca79a49f8545854205c1451b2403735c1006 --- oslo/messaging/_drivers/amqp.py | 4 -- oslo/messaging/_drivers/amqpdriver.py | 30 ++++---- oslo/messaging/_drivers/impl_qpid.py | 40 +++++------ oslo/messaging/_drivers/impl_rabbit.py | 34 ++++----- tests/test_qpid.py | 97 ++++++++++++++++++++++++-- 5 files changed, 145 insertions(+), 60 deletions(-) diff --git a/oslo/messaging/_drivers/amqp.py b/oslo/messaging/_drivers/amqp.py index 325c3609a..e0d33b393 100644 --- a/oslo/messaging/_drivers/amqp.py +++ b/oslo/messaging/_drivers/amqp.py @@ -249,7 +249,3 @@ def _add_unique_id(msg): unique_id = uuid.uuid4().hex msg.update({UNIQUE_ID: unique_id}) LOG.debug('UNIQUE_ID is %s.' % (unique_id)) - - -def get_control_exchange(conf): - return conf.control_exchange diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py index 16626d0b2..d990e90a7 100644 --- a/oslo/messaging/_drivers/amqpdriver.py +++ b/oslo/messaging/_drivers/amqpdriver.py @@ -297,10 +297,6 @@ class AMQPDriverBase(base.BaseDriver): self._default_exchange = default_exchange - # FIXME(markmc): temp hack - if self._default_exchange: - self.conf.set_override('control_exchange', self._default_exchange) - self._connection_pool = connection_pool self._reply_q_lock = threading.Lock() @@ -308,6 +304,9 @@ class AMQPDriverBase(base.BaseDriver): self._reply_q_conn = None self._waiter = None + def _get_exchange(self, target): + return target.exchange or self._default_exchange + def _get_connection(self, pooled=True): return rpc_amqp.ConnectionContext(self.conf, self._url, @@ -364,14 +363,16 @@ class AMQPDriverBase(base.BaseDriver): try: with self._get_connection() as conn: if notify: - conn.notify_send(target.topic, msg) + conn.notify_send(self._get_exchange(target), + target.topic, msg) elif target.fanout: conn.fanout_send(target.topic, msg) else: topic = target.topic if target.server: topic = '%s.%s' % (target.topic, target.server) - conn.topic_send(topic, msg, timeout=timeout) + conn.topic_send(exchange_name=self._get_exchange(target), + topic=topic, msg=msg, timeout=timeout) if wait_for_reply: result = self._waiter.wait(msg_id, timeout) @@ -394,9 +395,13 @@ class AMQPDriverBase(base.BaseDriver): listener = AMQPListener(self, conn) - conn.declare_topic_consumer(target.topic, listener) - conn.declare_topic_consumer('%s.%s' % (target.topic, target.server), - listener) + conn.declare_topic_consumer(exchange_name=self._get_exchange(target), + topic=target.topic, + callback=listener) + conn.declare_topic_consumer(exchange_name=self._get_exchange(target), + topic='%s.%s' % (target.topic, + target.server), + callback=listener) conn.declare_fanout_consumer(target.topic, listener) return listener @@ -406,9 +411,10 @@ class AMQPDriverBase(base.BaseDriver): listener = AMQPListener(self, conn) for target, priority in targets_and_priorities: - conn.declare_topic_consumer('%s.%s' % (target.topic, priority), - callback=listener, - exchange_name=target.exchange) + conn.declare_topic_consumer( + exchange_name=self._get_exchange(target), + topic='%s.%s' % (target.topic, priority), + callback=listener) return listener def cleanup(self): diff --git a/oslo/messaging/_drivers/impl_qpid.py b/oslo/messaging/_drivers/impl_qpid.py index 10fd7207c..def074baf 100644 --- a/oslo/messaging/_drivers/impl_qpid.py +++ b/oslo/messaging/_drivers/impl_qpid.py @@ -248,8 +248,8 @@ class DirectConsumer(ConsumerBase): class TopicConsumer(ConsumerBase): """Consumer class for 'topic'.""" - def __init__(self, conf, session, topic, callback, name=None, - exchange_name=None): + def __init__(self, conf, session, topic, callback, exchange_name, + name=None): """Init a 'topic' queue. :param session: the amqp session to use @@ -259,7 +259,6 @@ class TopicConsumer(ConsumerBase): :param name: optional queue name, defaults to topic """ - exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf) link_opts = { "auto-delete": conf.amqp_auto_delete, "durable": conf.amqp_durable_queues, @@ -376,14 +375,14 @@ class Publisher(object): class DirectPublisher(Publisher): """Publisher class for 'direct'.""" - def __init__(self, conf, session, msg_id): + def __init__(self, conf, session, topic): """Init a 'direct' publisher.""" if conf.qpid_topology_version == 1: - node_name = "%s/%s" % (msg_id, msg_id) + node_name = "%s/%s" % (topic, topic) node_opts = {"type": "direct"} elif conf.qpid_topology_version == 2: - node_name = "amq.direct/%s" % msg_id + node_name = "amq.direct/%s" % topic node_opts = {} else: raise_invalid_topology_version(conf) @@ -394,11 +393,9 @@ class DirectPublisher(Publisher): class TopicPublisher(Publisher): """Publisher class for 'topic'.""" - def __init__(self, conf, session, topic): + def __init__(self, conf, session, exchange_name, topic): """Init a 'topic' publisher. """ - exchange_name = rpc_amqp.get_control_exchange(conf) - if conf.qpid_topology_version == 1: node_name = "%s/%s" % (exchange_name, topic) elif conf.qpid_topology_version == 2: @@ -430,10 +427,9 @@ class FanoutPublisher(Publisher): class NotifyPublisher(Publisher): """Publisher class for notifications.""" - def __init__(self, conf, session, topic): + def __init__(self, conf, session, exchange_name, topic): """Init a 'topic' publisher. """ - exchange_name = rpc_amqp.get_control_exchange(conf) node_opts = {"durable": True} if conf.qpid_topology_version == 1: @@ -618,7 +614,7 @@ class Connection(object): raise StopIteration yield self.ensure(_error_callback, _consume) - def publisher_send(self, cls, topic, msg): + def publisher_send(self, cls, topic, msg, **kwargs): """Send to a publisher based on the publisher class.""" def _connect_error(exc): @@ -627,7 +623,7 @@ class Connection(object): "'%(topic)s': %(err_str)s") % log_info) def _publisher_send(): - publisher = cls(self.conf, self.session, topic) + publisher = cls(self.conf, self.session, topic=topic, **kwargs) publisher.send(msg) return self.ensure(_connect_error, _publisher_send) @@ -639,8 +635,8 @@ class Connection(object): """ self.declare_consumer(DirectConsumer, topic, callback) - def declare_topic_consumer(self, topic, callback=None, queue_name=None, - exchange_name=None): + def declare_topic_consumer(self, exchange_name, topic, callback=None, + queue_name=None): """Create a 'topic' consumer.""" self.declare_consumer(functools.partial(TopicConsumer, name=queue_name, @@ -654,9 +650,9 @@ class Connection(object): def direct_send(self, msg_id, msg): """Send a 'direct' message.""" - self.publisher_send(DirectPublisher, msg_id, msg) + self.publisher_send(DirectPublisher, topic=msg_id, msg=msg) - def topic_send(self, topic, msg, timeout=None): + def topic_send(self, exchange_name, topic, msg, timeout=None): """Send a 'topic' message.""" # # We want to create a message with attributes, e.g. a TTL. We @@ -669,15 +665,17 @@ class Connection(object): # will need to be altered accordingly. # qpid_message = qpid_messaging.Message(content=msg, ttl=timeout) - self.publisher_send(TopicPublisher, topic, qpid_message) + self.publisher_send(TopicPublisher, topic=topic, msg=qpid_message, + exchange_name=exchange_name) def fanout_send(self, topic, msg): """Send a 'fanout' message.""" - self.publisher_send(FanoutPublisher, topic, msg) + self.publisher_send(FanoutPublisher, topic=topic, msg=msg) - def notify_send(self, topic, msg, **kwargs): + def notify_send(self, exchange_name, topic, msg, **kwargs): """Send a notify message on a topic.""" - self.publisher_send(NotifyPublisher, topic, msg) + self.publisher_send(NotifyPublisher, topic=topic, msg=msg, + exchange_name=exchange_name) def consume(self, limit=None, timeout=None): """Consume from all queues/consumers.""" diff --git a/oslo/messaging/_drivers/impl_rabbit.py b/oslo/messaging/_drivers/impl_rabbit.py index 29d101a07..f7ca5e41c 100644 --- a/oslo/messaging/_drivers/impl_rabbit.py +++ b/oslo/messaging/_drivers/impl_rabbit.py @@ -247,8 +247,8 @@ class DirectConsumer(ConsumerBase): class TopicConsumer(ConsumerBase): """Consumer class for 'topic'.""" - def __init__(self, conf, channel, topic, callback, tag, name=None, - exchange_name=None, **kwargs): + def __init__(self, conf, channel, topic, callback, tag, exchange_name, + name=None, **kwargs): """Init a 'topic' queue. :param channel: the amqp channel to use @@ -256,6 +256,7 @@ class TopicConsumer(ConsumerBase): :paramtype topic: str :param callback: the callback to call when messages are received :param tag: a unique ID for the consumer on the channel + :param exchange_name: the exchange name to use :param name: optional queue name, defaults to topic :paramtype name: str @@ -267,7 +268,6 @@ class TopicConsumer(ConsumerBase): 'auto_delete': conf.amqp_auto_delete, 'exclusive': False} options.update(kwargs) - exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf) exchange = kombu.entity.Exchange(name=exchange_name, type='topic', durable=options['durable'], @@ -347,7 +347,7 @@ class Publisher(object): class DirectPublisher(Publisher): """Publisher class for 'direct'.""" - def __init__(self, conf, channel, msg_id, **kwargs): + def __init__(self, conf, channel, topic, **kwargs): """Init a 'direct' publisher. Kombu options may be passed as keyword args to override defaults @@ -357,13 +357,13 @@ class DirectPublisher(Publisher): 'auto_delete': True, 'exclusive': False} options.update(kwargs) - super(DirectPublisher, self).__init__(channel, msg_id, msg_id, + super(DirectPublisher, self).__init__(channel, topic, topic, type='direct', **options) class TopicPublisher(Publisher): """Publisher class for 'topic'.""" - def __init__(self, conf, channel, topic, **kwargs): + def __init__(self, conf, channel, exchange_name, topic, **kwargs): """Init a 'topic' publisher. Kombu options may be passed as keyword args to override defaults @@ -372,7 +372,6 @@ class TopicPublisher(Publisher): 'auto_delete': conf.amqp_auto_delete, 'exclusive': False} options.update(kwargs) - exchange_name = rpc_amqp.get_control_exchange(conf) super(TopicPublisher, self).__init__(channel, exchange_name, topic, @@ -398,10 +397,11 @@ class FanoutPublisher(Publisher): class NotifyPublisher(TopicPublisher): """Publisher class for 'notify'.""" - def __init__(self, conf, channel, topic, **kwargs): + def __init__(self, conf, channel, exchange_name, topic, **kwargs): self.durable = kwargs.pop('durable', conf.amqp_durable_queues) self.queue_arguments = _get_queue_arguments(conf) - super(NotifyPublisher, self).__init__(conf, channel, topic, **kwargs) + super(NotifyPublisher, self).__init__(conf, channel, exchange_name, + topic, **kwargs) def reconnect(self, channel): super(NotifyPublisher, self).reconnect(channel) @@ -731,7 +731,7 @@ class Connection(object): "'%(topic)s': %(err_str)s") % log_info) def _publish(): - publisher = cls(self.conf, self.channel, topic, **kwargs) + publisher = cls(self.conf, self.channel, topic=topic, **kwargs) publisher.send(msg, timeout) self.ensure(_error_callback, _publish) @@ -743,8 +743,8 @@ class Connection(object): """ self.declare_consumer(DirectConsumer, topic, callback) - def declare_topic_consumer(self, topic, callback=None, queue_name=None, - exchange_name=None): + def declare_topic_consumer(self, exchange_name, topic, callback=None, + queue_name=None): """Create a 'topic' consumer.""" self.declare_consumer(functools.partial(TopicConsumer, name=queue_name, @@ -760,17 +760,19 @@ class Connection(object): """Send a 'direct' message.""" self.publisher_send(DirectPublisher, msg_id, msg) - def topic_send(self, topic, msg, timeout=None): + def topic_send(self, exchange_name, topic, msg, timeout=None): """Send a 'topic' message.""" - self.publisher_send(TopicPublisher, topic, msg, timeout) + self.publisher_send(TopicPublisher, topic, msg, timeout, + exchange_name=exchange_name) def fanout_send(self, topic, msg): """Send a 'fanout' message.""" self.publisher_send(FanoutPublisher, topic, msg) - def notify_send(self, topic, msg, **kwargs): + def notify_send(self, exchange_name, topic, msg, **kwargs): """Send a notify message on a topic.""" - self.publisher_send(NotifyPublisher, topic, msg, None, **kwargs) + self.publisher_send(NotifyPublisher, topic, msg, timeout=None, + exchange_name=exchange_name, **kwargs) def consume(self, limit=None, timeout=None): """Consume from all queues/consumers.""" diff --git a/tests/test_qpid.py b/tests/test_qpid.py index 23145518a..976d5eb37 100644 --- a/tests/test_qpid.py +++ b/tests/test_qpid.py @@ -167,11 +167,17 @@ class TestQpidInvalidTopologyVersion(_QpidBaseTestCase): scenarios = [ ('direct', dict(consumer_cls=qpid_driver.DirectConsumer, - publisher_cls=qpid_driver.DirectPublisher)), + consumer_kwargs={}, + publisher_cls=qpid_driver.DirectPublisher, + publisher_kwargs={})), ('topic', dict(consumer_cls=qpid_driver.TopicConsumer, - publisher_cls=qpid_driver.TopicPublisher)), + consumer_kwargs={'exchange_name': 'openstack'}, + publisher_cls=qpid_driver.TopicPublisher, + publisher_kwargs={'exchange_name': 'openstack'})), ('fanout', dict(consumer_cls=qpid_driver.FanoutConsumer, - publisher_cls=qpid_driver.FanoutPublisher)), + consumer_kwargs={}, + publisher_cls=qpid_driver.FanoutPublisher, + publisher_kwargs={})), ] def setUp(self): @@ -195,7 +201,8 @@ class TestQpidInvalidTopologyVersion(_QpidBaseTestCase): self.consumer_cls(self.conf, self.session_receive, msgid_or_topic, - consumer_callback) + consumer_callback, + **self.consumer_kwargs) except Exception as e: recvd_exc_msg = e.message @@ -205,7 +212,8 @@ class TestQpidInvalidTopologyVersion(_QpidBaseTestCase): try: self.publisher_cls(self.conf, self.session_send, - msgid_or_topic) + topic=msgid_or_topic, + **self.publisher_kwargs) except Exception as e: recvd_exc_msg = e.message @@ -307,11 +315,15 @@ class TestQpidTopicAndFanout(_QpidBaseTestCase): ] _exchange_class = [ ('topic', dict(consumer_cls=qpid_driver.TopicConsumer, + consumer_kwargs={'exchange_name': 'openstack'}, publisher_cls=qpid_driver.TopicPublisher, + publisher_kwargs={'exchange_name': 'openstack'}, topic='topictest.test', receive_topic='topictest.test')), ('fanout', dict(consumer_cls=qpid_driver.FanoutConsumer, + consumer_kwargs={}, publisher_cls=qpid_driver.FanoutPublisher, + publisher_kwargs={}, topic='fanouttest', receive_topic='fanouttest')), ] @@ -404,7 +416,8 @@ class TestQpidTopicAndFanout(_QpidBaseTestCase): consumer = self.consumer_cls(self.conf, self.session_receive, self.receive_topic, - self.consumer_callback) + self.consumer_callback, + **self.consumer_kwargs) self._receivers.append(consumer) # create receivers threads @@ -415,7 +428,8 @@ class TestQpidTopicAndFanout(_QpidBaseTestCase): for sender_id in range(self.no_senders): publisher = self.publisher_cls(self.conf, self.session_send, - self.topic) + topic=self.topic, + **self.publisher_kwargs) self._senders.append(publisher) # create sender threads @@ -450,6 +464,75 @@ class TestQpidTopicAndFanout(_QpidBaseTestCase): TestQpidTopicAndFanout.generate_scenarios() +class AddressNodeMatcher(object): + def __init__(self, node): + self.node = node + + def __eq__(self, address): + return address.split(';')[0].strip() == self.node + + +class TestDriverInterface(_QpidBaseTestCase): + """Unit Test cases to test the amqpdriver with qpid + """ + + def setUp(self): + super(TestDriverInterface, self).setUp() + self.config(qpid_topology_version=2) + transport = messaging.get_transport(self.conf) + self.driver = transport._driver + + def test_listen_and_direct_send(self): + target = messaging.Target(exchange="exchange_test", + topic="topic_test", + server="server_test") + + with mock.patch('qpid.messaging.Connection') as conn_cls: + conn = conn_cls.return_value + session = conn.session.return_value + session.receiver.side_effect = [mock.Mock(), mock.Mock(), + mock.Mock()] + + listener = self.driver.listen(target) + listener.conn.direct_send("msg_id", {}) + + self.assertEqual(3, len(listener.conn.consumers)) + + expected_calls = [ + mock.call(AddressNodeMatcher( + 'amq.topic/topic/exchange_test/topic_test')), + mock.call(AddressNodeMatcher( + 'amq.topic/topic/exchange_test/topic_test.server_test')), + mock.call(AddressNodeMatcher('amq.topic/fanout/topic_test')), + ] + session.receiver.assert_has_calls(expected_calls) + session.sender.assert_called_with( + AddressNodeMatcher("amq.direct/msg_id")) + + def test_send(self): + target = messaging.Target(exchange="exchange_test", + topic="topic_test", + server="server_test") + with mock.patch('qpid.messaging.Connection') as conn_cls: + conn = conn_cls.return_value + session = conn.session.return_value + + self.driver.send(target, {}, {}) + session.sender.assert_called_with(AddressNodeMatcher( + "amq.topic/topic/exchange_test/topic_test.server_test")) + + def test_send_notification(self): + target = messaging.Target(exchange="exchange_test", + topic="topic_test.info") + with mock.patch('qpid.messaging.Connection') as conn_cls: + conn = conn_cls.return_value + session = conn.session.return_value + + self.driver.send_notification(target, {}, {}, "2.0") + session.sender.assert_called_with(AddressNodeMatcher( + "amq.topic/topic/exchange_test/topic_test.info")) + + class TestQpidReconnectOrder(test_utils.BaseTestCase): """Unit Test cases to test reconnection """