diff --git a/oslo_messaging/_drivers/protocols/amqp/controller.py b/oslo_messaging/_drivers/protocols/amqp/controller.py index d89eb0ae7..1dd91f7ec 100644 --- a/oslo_messaging/_drivers/protocols/amqp/controller.py +++ b/oslo_messaging/_drivers/protocols/amqp/controller.py @@ -158,11 +158,12 @@ class Server(pyngus.ReceiverEventHandler): from a given target. Messages arriving on the links are placed on the 'incoming' queue. """ - def __init__(self, addresses, incoming): + def __init__(self, addresses, incoming, subscription_id): self._incoming = incoming self._addresses = addresses self._capacity = 500 # credit per link self._receivers = None + self._id = subscription_id def attach(self, connection): """Create receiver links over the given connection for all the @@ -267,7 +268,8 @@ class Controller(pyngus.ConnectionEventHandler): self._max_task_batch = 50 # cache of sending links indexed by address: self._senders = {} - # Servers (set of receiving links), indexed by target: + # Servers indexed by target. Each entry is a map indexed by the + # specific ProtonListener's identifier: self._servers = {} opt_group = cfg.OptGroup(name='oslo_messaging_amqp', @@ -329,8 +331,9 @@ class Controller(pyngus.ConnectionEventHandler): self.processor = None self._tasks = None self._senders = None - for server in self._servers.values(): - server.destroy() + for servers in self._servers.values(): + for server in servers.values(): + server.destroy() self._servers.clear() self._socket_connection = None if self._replies: @@ -382,7 +385,7 @@ class Controller(pyngus.ConnectionEventHandler): LOG.debug("Sending response to %s", address) self._send(address, response) - def subscribe(self, target, in_queue): + def subscribe(self, target, in_queue, subscription_id): """Subscribe to messages sent to 'target', place received messages on 'in_queue'. """ @@ -391,20 +394,25 @@ class Controller(pyngus.ConnectionEventHandler): self._broadcast_address(target), self._group_request_address(target) ] - self._subscribe(target, addresses, in_queue) + self._subscribe(target, addresses, in_queue, subscription_id) - def subscribe_notifications(self, target, in_queue): + def subscribe_notifications(self, target, in_queue, subscription_id): """Subscribe for notifications on 'target', place received messages on 'in_queue'. """ addresses = [self._group_request_address(target)] - self._subscribe(target, addresses, in_queue) + self._subscribe(target, addresses, in_queue, subscription_id) - def _subscribe(self, target, addresses, in_queue): + def _subscribe(self, target, addresses, in_queue, subscription_id): LOG.debug("Subscribing to %(target)s (%(addresses)s)", {'target': target, 'addresses': addresses}) - self._servers[target] = Server(addresses, in_queue) - self._servers[target].attach(self._socket_connection.connection) + server = Server(addresses, in_queue, subscription_id) + servers = self._servers.get(target) + if servers is None: + servers = {} + self._servers[target] = servers + servers[subscription_id] = server + server.attach(self._socket_connection.connection) def _resolve(self, target): """Return a link address for a given target.""" @@ -583,8 +591,9 @@ class Controller(pyngus.ConnectionEventHandler): LOG.debug("Connection active (%(hostname)s:%(port)s), subscribing...", {'hostname': self.hosts.current.hostname, 'port': self.hosts.current.port}) - for s in self._servers.values(): - s.attach(self._socket_connection.connection) + for servers in self._servers.values(): + for server in servers.values(): + server.attach(self._socket_connection.connection) self._replies = Replies(self._socket_connection.connection, lambda: self._reply_link_ready()) self._delay = 0 diff --git a/oslo_messaging/_drivers/protocols/amqp/driver.py b/oslo_messaging/_drivers/protocols/amqp/driver.py index 04feb2de1..68fbbf4d8 100644 --- a/oslo_messaging/_drivers/protocols/amqp/driver.py +++ b/oslo_messaging/_drivers/protocols/amqp/driver.py @@ -25,6 +25,7 @@ import logging import os import threading import time +import uuid from oslo_serialization import jsonutils from oslo_utils import importutils @@ -149,6 +150,7 @@ class ProtonListener(base.Listener): super(ProtonListener, self).__init__(driver.prefetch_size) self.driver = driver self.incoming = Queue() + self.id = uuid.uuid4().hex def stop(self): self.incoming.stop() diff --git a/oslo_messaging/_drivers/protocols/amqp/drivertasks.py b/oslo_messaging/_drivers/protocols/amqp/drivertasks.py index 04943961d..0addc0758 100644 --- a/oslo_messaging/_drivers/protocols/amqp/drivertasks.py +++ b/oslo_messaging/_drivers/protocols/amqp/drivertasks.py @@ -83,9 +83,12 @@ class ListenTask(controller.Task): """ if self._notifications: controller.subscribe_notifications(self._target, - self._listener.incoming) + self._listener.incoming, + self._listener.id) else: - controller.subscribe(self._target, self._listener.incoming) + controller.subscribe(self._target, + self._listener.incoming, + self._listener.id) class ReplyTask(controller.Task): diff --git a/oslo_messaging/tests/test_amqp_driver.py b/oslo_messaging/tests/test_amqp_driver.py index 909bc599d..b011a6383 100644 --- a/oslo_messaging/tests/test_amqp_driver.py +++ b/oslo_messaging/tests/test_amqp_driver.py @@ -47,6 +47,12 @@ CYRUS_ENABLED = (pyngus and pyngus.VERSION >= (2, 0, 0) and _proton LOG = logging.getLogger(__name__) +def _wait_until(predicate, timeout): + deadline = timeout + time.time() + while not predicate() and deadline > time.time(): + time.sleep(0.1) + + class _ListenerThread(threading.Thread): """Run a blocking listener in a thread.""" def __init__(self, listener, msg_count): @@ -55,10 +61,13 @@ class _ListenerThread(threading.Thread): self.msg_count = msg_count self.messages = moves.queue.Queue() self.daemon = True + self.started = threading.Event() self.start() + self.started.wait() def run(self): LOG.debug("Listener started") + self.started.set() while self.msg_count > 0: in_msg = self.listener.poll()[0] self.messages.put(in_msg) @@ -515,12 +524,19 @@ class TestFailover(test_utils.BaseTestCase): target = oslo_messaging.Target(topic="my-topic") listener = _ListenerThread(driver.listen(target), 2) + # wait for listener links to come up + # 4 == 3 links per listener + 1 for the global reply queue + predicate = lambda: self._brokers[0].sender_link_count == 4 + _wait_until(predicate, 30) + self.assertTrue(predicate()) + rc = driver.send(target, {"context": "whatever"}, {"method": "echo", "id": "echo-1"}, wait_for_reply=True, timeout=30) self.assertIsNotNone(rc) self.assertEqual(rc.get('correlation-id'), 'echo-1') + # 1 request msg, 1 response: self.assertEqual(self._brokers[0].topic_count, 1) self.assertEqual(self._brokers[0].direct_count, 1) @@ -528,28 +544,25 @@ class TestFailover(test_utils.BaseTestCase): # fail broker 0 and start broker 1: self._brokers[0].stop() self._brokers[1].start() - deadline = time.time() + 30 - responded = False - sequence = 2 - while deadline > time.time() and not responded: - if not listener.isAlive(): - # listener may have exited after replying to an old correlation - # id: restart new listener - listener = _ListenerThread(driver.listen(target), 1) - try: - rc = driver.send(target, {"context": "whatever"}, - {"method": "echo", - "id": "echo-%d" % sequence}, - wait_for_reply=True, - timeout=2) - self.assertIsNotNone(rc) - self.assertEqual(rc.get('correlation-id'), - 'echo-%d' % sequence) - responded = True - except oslo_messaging.MessagingTimeout: - sequence += 1 - self.assertTrue(responded) + # wait for listener links to re-establish + # 4 = 3 links per listener + 1 for the global reply queue + predicate = lambda: self._brokers[1].sender_link_count == 4 + _wait_until(predicate, 30) + self.assertTrue(predicate()) + + rc = driver.send(target, + {"context": "whatever"}, + {"method": "echo", "id": "echo-2"}, + wait_for_reply=True, + timeout=2) + self.assertIsNotNone(rc) + self.assertEqual(rc.get('correlation-id'), 'echo-2') + + # 1 request msg, 1 response: + self.assertEqual(self._brokers[1].topic_count, 1) + self.assertEqual(self._brokers[1].direct_count, 1) + listener.join(timeout=30) self.assertFalse(listener.isAlive()) @@ -558,6 +571,55 @@ class TestFailover(test_utils.BaseTestCase): self._brokers[1].stop() driver.cleanup() + def test_listener_failover(self): + """Verify that Listeners are re-established after failover. + """ + self._brokers[0].start() + driver = amqp_driver.ProtonDriver(self.conf, self._broker_url) + + target = oslo_messaging.Target(topic="my-topic") + bcast = oslo_messaging.Target(topic="my-topic", fanout=True) + listener1 = _ListenerThread(driver.listen(target), 2) + listener2 = _ListenerThread(driver.listen(target), 2) + + # wait for 7 sending links to become active on the broker. + # 7 = 3 links per Listener + 1 global reply link + predicate = lambda: self._brokers[0].sender_link_count == 7 + _wait_until(predicate, 30) + self.assertTrue(predicate()) + + driver.send(bcast, {"context": "whatever"}, + {"method": "ignore", "id": "echo-1"}) + + # 1 message per listener + predicate = lambda: self._brokers[0].fanout_sent_count == 2 + _wait_until(predicate, 30) + self.assertTrue(predicate()) + + # fail broker 0 and start broker 1: + self._brokers[0].stop() + self._brokers[1].start() + + # wait again for 7 sending links to re-establish + predicate = lambda: self._brokers[1].sender_link_count == 7 + _wait_until(predicate, 30) + self.assertTrue(predicate()) + + driver.send(bcast, {"context": "whatever"}, + {"method": "ignore", "id": "echo-2"}) + + # 1 message per listener + predicate = lambda: self._brokers[1].fanout_sent_count == 2 + _wait_until(predicate, 30) + self.assertTrue(predicate()) + + listener1.join(timeout=30) + listener2.join(timeout=30) + self.assertFalse(listener1.isAlive() or listener2.isAlive()) + + self._brokers[1].stop() + driver.cleanup() + class FakeBroker(threading.Thread): """A test AMQP message 'broker'.""" @@ -638,12 +700,16 @@ class FakeBroker(threading.Thread): # Pyngus ConnectionEventHandler callbacks: + def connection_active(self, connection): + self.server.connection_count += 1 + def connection_remote_closed(self, connection, reason): """Peer has closed the connection.""" self.connection.close() def connection_closed(self, connection): """Connection close completed.""" + self.server.connection_count -= 1 self.closed = True # main loop will destroy def connection_failed(self, connection, error): @@ -712,6 +778,7 @@ class FakeBroker(threading.Thread): # Pyngus SenderEventHandler callbacks: def sender_active(self, sender_link): + self.server.sender_link_count += 1 self.server.add_route(self.link.source_address, self) self.routed = True @@ -720,6 +787,7 @@ class FakeBroker(threading.Thread): self.link.close() def sender_closed(self, sender_link): + self.server.sender_link_count -= 1 self.destroy() class ReceiverLink(pyngus.ReceiverEventHandler): @@ -746,10 +814,14 @@ class FakeBroker(threading.Thread): # ReceiverEventHandler callbacks: + def receiver_active(self, receiver_link): + self.server.receiver_link_count += 1 + def receiver_remote_closed(self, receiver_link, error): self.link.close() def receiver_closed(self, receiver_link): + self.server.receiver_link_count -= 1 self.destroy() def message_received(self, receiver_link, message, handle): @@ -795,7 +867,12 @@ class FakeBroker(threading.Thread): self.direct_count = 0 self.topic_count = 0 self.fanout_count = 0 + self.fanout_sent_count = 0 self.dropped_count = 0 + # counts for active links and connections: + self.connection_count = 0 + self.sender_link_count = 0 + self.receiver_link_count = 0 def start(self): """Start the server.""" @@ -907,6 +984,7 @@ class FakeBroker(threading.Thread): if dest.startswith(self._broadcast_prefix): self.fanout_count += 1 for link in self._sources[dest]: + self.fanout_sent_count += 1 LOG.debug("Broadcast to %s", dest) link.send_message(message) elif dest.startswith(self._group_prefix):