diff --git a/oslo_messaging/_drivers/amqpdriver.py b/oslo_messaging/_drivers/amqpdriver.py index 46e91c918..c5613b07a 100644 --- a/oslo_messaging/_drivers/amqpdriver.py +++ b/oslo_messaging/_drivers/amqpdriver.py @@ -42,20 +42,58 @@ ACK_REQUEUE_EVERY_SECONDS_MIN = 0.001 ACK_REQUEUE_EVERY_SECONDS_MAX = 1.0 -def do_pending_tasks(tasks): - while True: - try: - task = tasks.get(block=False) - except moves.queue.Empty: - break - else: - task() +class MessageOperationsHandler(object): + """Queue used by message operations to ensure that all tasks are + serialized and run in the same thread, since underlying drivers like kombu + are not thread safe. + """ + def __init__(self, name): + self.name = "%s (%s)" % (name, hex(id(self))) + self._tasks = moves.queue.Queue() + + self._shutdown = threading.Event() + self._shutdown_thread = threading.Thread( + target=self._process_in_background) + self._shutdown_thread.daemon = True + + def stop(self): + self._shutdown.set() + + def process_in_background(self): + """Run all pending tasks queued by do() in a thread during the + shutdown process. + """ + self._shutdown_thread.start() + + def _process_in_background(self): + while not self._shutdown.is_set(): + self.process() + time.sleep(ACK_REQUEUE_EVERY_SECONDS_MIN) + + def process(self): + "Run all pending tasks queued by do()." + + while True: + try: + task, event = self._tasks.get(block=False) + except moves.queue.Empty: + break + try: + task() + finally: + event.set() + + def do(self, task): + "Put the task in the queue and wait until the task is completed."
+ event = threading.Event() + self._tasks.put((task, event)) + event.wait() class AMQPIncomingMessage(base.RpcIncomingMessage): def __init__(self, listener, ctxt, message, unique_id, msg_id, reply_q, - obsolete_reply_queues, pending_message_actions): + obsolete_reply_queues, message_operations_handler): super(AMQPIncomingMessage, self).__init__(ctxt, message) self.listener = listener @@ -63,7 +101,7 @@ class AMQPIncomingMessage(base.RpcIncomingMessage): self.msg_id = msg_id self.reply_q = reply_q self._obsolete_reply_queues = obsolete_reply_queues - self._pending_tasks = pending_message_actions + self._message_operations_handler = message_operations_handler self.stopwatch = timeutils.StopWatch() self.stopwatch.start() @@ -133,7 +171,7 @@ class AMQPIncomingMessage(base.RpcIncomingMessage): return def acknowledge(self): - self._pending_tasks.put(self.message.acknowledge) + self._message_operations_handler.do(self.message.acknowledge) self.listener.msg_id_cache.add(self.unique_id) def requeue(self): @@ -143,7 +181,7 @@ class AMQPIncomingMessage(base.RpcIncomingMessage): # msg_id_cache, the message will be reconsumed, the only difference is # the message stay at the beginning of the queue instead of moving to # the end. 
- self._pending_tasks.put(self.message.requeue) + self._message_operations_handler.do(self.message.requeue) class ObsoleteReplyQueuesCache(object): @@ -199,9 +237,11 @@ class AMQPListener(base.PollStyleListener): self.conn = conn self.msg_id_cache = rpc_amqp._MsgIdCache() self.incoming = [] - self._stopped = threading.Event() + self._shutdown = threading.Event() + self._shutoff = threading.Event() self._obsolete_reply_queues = ObsoleteReplyQueuesCache() - self._pending_tasks = moves.queue.Queue() + self._message_operations_handler = MessageOperationsHandler( + "AMQPListener") self._current_timeout = ACK_REQUEUE_EVERY_SECONDS_MIN def __call__(self, message): @@ -222,14 +262,14 @@ class AMQPListener(base.PollStyleListener): ctxt.msg_id, ctxt.reply_q, self._obsolete_reply_queues, - self._pending_tasks)) + self._message_operations_handler)) @base.batch_poll_helper def poll(self, timeout=None): stopwatch = timeutils.StopWatch(duration=timeout).start() - while not self._stopped.is_set(): - do_pending_tasks(self._pending_tasks) + while not self._shutdown.is_set(): + self._message_operations_handler.process() if self.incoming: return self.incoming.pop(0) @@ -248,12 +288,30 @@ class AMQPListener(base.PollStyleListener): else: self._current_timeout = ACK_REQUEUE_EVERY_SECONDS_MIN + # NOTE(sileht): listener is stopped, just processes remaining messages + # and operations + self._message_operations_handler.process() + if self.incoming: + return self.incoming.pop(0) + + self._shutoff.set() + def stop(self): - self._stopped.set() + self._shutdown.set() self.conn.stop_consuming() - do_pending_tasks(self._pending_tasks) + self._shutoff.wait() + + # NOTE(sileht): Here, the listener is stopped, but some incoming + # messages may still live on server side, because callback is still + # running and message is not yet ack/requeue. It's safe to do the ack + # into another thread, since the polling thread is now terminated.
+ self._message_operations_handler.process_in_background() def cleanup(self): + # NOTE(sileht): server executor is now stopped, we are sure that no + # more incoming messages are live, we can acknowledge + # remaining messages and stop the thread + self._message_operations_handler.stop() # Closes listener connection self.conn.close() @@ -306,7 +364,6 @@ class ReplyWaiter(object): self.allowed_remote_exmods = allowed_remote_exmods self.msg_id_cache = rpc_amqp._MsgIdCache() self.waiters = ReplyWaiters() - self._pending_tasks = moves.queue.Queue() self.conn.declare_direct_consumer(reply_q, self) @@ -321,12 +378,10 @@ class ReplyWaiter(object): self.conn.stop_consuming() self._thread.join() self._thread = None - do_pending_tasks(self._pending_tasks) def poll(self): current_timeout = ACK_REQUEUE_EVERY_SECONDS_MIN while not self._thread_exit_event.is_set(): - do_pending_tasks(self._pending_tasks) try: # ack every ACK_REQUEUE_EVERY_SECONDS_MAX seconds self.conn.consume(timeout=current_timeout) @@ -340,7 +395,11 @@ class ReplyWaiter(object): current_timeout = ACK_REQUEUE_EVERY_SECONDS_MIN def __call__(self, message): - self._pending_tasks.put(message.acknowledge) + # NOTE(sileht): __call__ is running within the polling thread, + # (conn.consume -> conn.conn.drain_events() -> __call__ callback) + # it's threadsafe to acknowledge the message here, no need to wait + # for the next polling + message.acknowledge() incoming_msg_id = message.pop('_msg_id', None) if message.get('ending'): LOG.debug("received reply msg_id: %s", incoming_msg_id)