From 4275044b0d03197ec025ef9215ffd9b2bc30fa60 Mon Sep 17 00:00:00 2001 From: dukhlov Date: Fri, 26 Feb 2016 13:11:31 +0200 Subject: [PATCH] Implements pika thread safe connection ThreadSafePikaConnection works over SelectConnection which has asynchronous interface and provides synchronous interface It allows to use single connection in concurrent environment with nonblocking io approach. For this goal internal thread is added for processing ioloop. Also this approach allows to remove poller's thread and use ioloop's thread for calling on_incoming calbacks. It is also done by this patch This patch is important to use single connection for whole process in future Change-Id: I12c16715f6bf8a99e438bc054f9d0132a09cecf3 Depends-On: I41a768c5624fa2212257ce20bf9a67d09de0c4ab --- oslo_messaging/_drivers/impl_pika.py | 22 +- .../_drivers/pika_driver/pika_connection.py | 497 ++++++++++++++++++ .../_drivers/pika_driver/pika_engine.py | 193 ++++--- .../_drivers/pika_driver/pika_listener.py | 57 +- .../_drivers/pika_driver/pika_poller.py | 361 ++++++++----- .../tests/drivers/pika/test_poller.py | 243 ++++----- 6 files changed, 919 insertions(+), 454 deletions(-) create mode 100644 oslo_messaging/_drivers/pika_driver/pika_connection.py diff --git a/oslo_messaging/_drivers/impl_pika.py b/oslo_messaging/_drivers/impl_pika.py index 42686350d..12590abe9 100644 --- a/oslo_messaging/_drivers/impl_pika.py +++ b/oslo_messaging/_drivers/impl_pika.py @@ -335,26 +335,18 @@ class PikaDriver(base.BaseDriver): ) def listen(self, target, batch_size, batch_timeout): - listener = pika_drv_poller.RpcServicePikaPoller( - self._pika_engine, target, - prefetch_count=self._pika_engine.rpc_listener_prefetch_count + return pika_drv_poller.RpcServicePikaPoller( + self._pika_engine, target, batch_size, batch_timeout, + self._pika_engine.rpc_listener_prefetch_count ) - listener.start() - return base.PollStyleListenerAdapter(listener, batch_size, - batch_timeout) def listen_for_notifications(self, targets_and_priorities, pool, batch_size, batch_timeout): - listener = pika_drv_poller.NotificationPikaPoller( - self._pika_engine, targets_and_priorities, - prefetch_count=( - self._pika_engine.notification_listener_prefetch_count - ), - queue_name=pool + return pika_drv_poller.NotificationPikaPoller( + self._pika_engine, targets_and_priorities, batch_size, + batch_timeout, + self._pika_engine.notification_listener_prefetch_count, pool ) - listener.start() - return base.PollStyleListenerAdapter(listener, batch_size, - batch_timeout) def cleanup(self): self._reply_listener.cleanup() diff --git a/oslo_messaging/_drivers/pika_driver/pika_connection.py b/oslo_messaging/_drivers/pika_driver/pika_connection.py new file mode 100644 index 000000000..a2b1d00f8 --- /dev/null +++ b/oslo_messaging/_drivers/pika_driver/pika_connection.py @@ -0,0 +1,497 @@ +# Copyright 2016 Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import collections +import logging +import os +import threading + +import futurist +import pika +from pika.adapters import select_connection +from pika import exceptions as pika_exceptions +from pika import spec as pika_spec + +from oslo_utils import eventletutils + +current_thread = eventletutils.fetch_current_thread_functor() + +LOG = logging.getLogger(__name__) + + +class ThreadSafePikaConnection(object): + def __init__(self, params=None): + self.params = params + self._connection_lock = threading.Lock() + self._evt_closed = threading.Event() + self._task_queue = collections.deque() + self._pending_connection_futures = set() + + create_connection_future = self._register_pending_future() + + def on_open_error(conn, err): + create_connection_future.set_exception( + pika_exceptions.AMQPConnectionError(err) + ) + + self._impl = pika.SelectConnection( + parameters=params, + on_open_callback=create_connection_future.set_result, + on_open_error_callback=on_open_error, + on_close_callback=self._on_connection_close, + stop_ioloop_on_close=False, + ) + self._interrupt_pipein, self._interrupt_pipeout = os.pipe() + self._impl.ioloop.add_handler(self._interrupt_pipein, + self._impl.ioloop.read_interrupt, + select_connection.READ) + + self._thread = threading.Thread(target=self._process_io) + self._thread.daemon = True + self._thread_id = None + self._thread.start() + + create_connection_future.result() + + def _execute_task(self, func, *args, **kwargs): + if current_thread() == self._thread_id: + return func(*args, **kwargs) + + future = futurist.Future() + self._task_queue.append((func, args, kwargs, future)) + + if self._evt_closed.is_set(): + self._notify_all_futures_connection_close() + elif self._interrupt_pipeout is not None: + os.write(self._interrupt_pipeout, b'X') + + return future.result() + + def _register_pending_future(self): + future = futurist.Future() + self._pending_connection_futures.add(future) + + def on_done_callback(fut): + try: + self._pending_connection_futures.remove(fut) + except KeyError: + pass + + future.add_done_callback(on_done_callback) + + if self._evt_closed.is_set(): + self._notify_all_futures_connection_close() + return future + + def _notify_all_futures_connection_close(self): + while self._task_queue: + try: + method_res_future = self._task_queue.pop()[3] + except KeyError: + break + else: + method_res_future.set_exception( + pika_exceptions.ConnectionClosed() + ) + + while self._pending_connection_futures: + try: + pending_connection_future = ( + self._pending_connection_futures.pop() + ) + except KeyError: + break + else: + pending_connection_future.set_exception( + pika_exceptions.ConnectionClosed() + ) + + def _on_connection_close(self, conn, reply_code, reply_text): + self._evt_closed.set() + self._notify_all_futures_connection_close() + if self._interrupt_pipeout: + os.close(self._interrupt_pipeout) + os.close(self._interrupt_pipein) + + def add_on_close_callback(self, callback): + return self._execute_task(self._impl.add_on_close_callback, callback) + + def _do_process_io(self): + while self._task_queue: + func, args, kwargs, future = self._task_queue.pop() + try: + res = func(*args, **kwargs) + except BaseException as e: + LOG.exception(e) + future.set_exception(e) + else: + future.set_result(res) + + self._impl.ioloop.poll() + self._impl.ioloop.process_timeouts() + + def _process_io(self): + self._thread_id = current_thread() + while not self._evt_closed.is_set(): + try: + self._do_process_io() + except BaseException: + LOG.exception("Error during processing connection's IO") + + def close(self, *args, **kwargs): + res = self._execute_task(self._impl.close, *args, **kwargs) + + self._evt_closed.wait() + self._thread.join() + return res + + def channel(self, channel_number=None): + channel_opened_future = self._register_pending_future() + + impl_channel = self._execute_task( + self._impl.channel, + on_open_callback=channel_opened_future.set_result, + channel_number=channel_number + ) + + # Create our proxy channel + channel = ThreadSafePikaChannel(impl_channel, self) + + # Link implementation channel with our proxy channel + impl_channel._set_cookie(channel) + + channel_opened_future.result() + return channel + + def add_timeout(self, timeout, callback): + return self._execute_task(self._impl.add_timeout, timeout, callback) + + def remove_timeout(self, timeout_id): + return self._execute_task(self._impl.remove_timeout, timeout_id) + + @property + def is_closed(self): + return self._impl.is_closed + + @property + def is_closing(self): + return self._impl.is_closing + + @property + def is_open(self): + return self._impl.is_open + + +class ThreadSafePikaChannel(object): # pylint: disable=R0904,R0902 + + def __init__(self, channel_impl, connection): + self._impl = channel_impl + self._connection = connection + + self._delivery_confirmation = False + + self._message_returned = False + self._current_future = None + + self._evt_closed = threading.Event() + + self.add_on_close_callback(self._on_channel_close) + + def _execute_task(self, func, *args, **kwargs): + return self._connection._execute_task(func, *args, **kwargs) + + def _on_channel_close(self, channel, reply_code, reply_text): + self._evt_closed.set() + + if self._current_future: + self._current_future.set_exception( + pika_exceptions.ChannelClosed(reply_code, reply_text)) + + def _on_message_confirmation(self, frame): + self._current_future.set_result(frame) + + def add_on_close_callback(self, callback): + self._execute_task(self._impl.add_on_close_callback, callback) + + def add_on_cancel_callback(self, callback): + self._execute_task(self._impl.add_on_cancel_callback, callback) + + def __int__(self): + return self.channel_number + + @property + def channel_number(self): + return self._impl.channel_number + + @property + def is_closed(self): + return self._impl.is_closed + + @property + def is_closing(self): + return self._impl.is_closing + + @property + def is_open(self): + return self._impl.is_open + + def close(self, reply_code=0, reply_text="Normal Shutdown"): + self._impl.close(reply_code=reply_code, reply_text=reply_text) + self._evt_closed.wait() + + def flow(self, active): + self._current_future = futurist.Future() + self._execute_task( + self._impl.flow, callback=self._current_future.set_result, + active=active + ) + return self._current_future.result() + + def basic_consume(self, # pylint: disable=R0913 + consumer_callback, + queue, + no_ack=False, + exclusive=False, + consumer_tag=None, + arguments=None): + self._current_future = futurist.Future() + self._execute_task( + self._impl.add_callback, self._current_future.set_result, + replies=[pika_spec.Basic.ConsumeOk], one_shot=True + ) + + self._impl.add_callback(self._current_future.set_result, + replies=[pika_spec.Basic.ConsumeOk], + one_shot=True) + tag = self._execute_task( + self._impl.basic_consume, + consumer_callback=consumer_callback, + queue=queue, + no_ack=no_ack, + exclusive=exclusive, + consumer_tag=consumer_tag, + arguments=arguments + ) + + self._current_future.result() + return tag + + def basic_cancel(self, consumer_tag): + self._current_future = futurist.Future() + self._execute_task( + self._impl.basic_cancel, + callback=self._current_future.set_result, + consumer_tag=consumer_tag, + nowait=False) + self._current_future.result() + + def basic_ack(self, delivery_tag=0, multiple=False): + return self._execute_task( + self._impl.basic_ack, delivery_tag=delivery_tag, multiple=multiple) + + def basic_nack(self, delivery_tag=None, multiple=False, requeue=True): + return self._execute_task( + self._impl.basic_nack, delivery_tag=delivery_tag, + multiple=multiple, requeue=requeue + ) + + def publish(self, exchange, routing_key, body, # pylint: disable=R0913 + properties=None, mandatory=False, immediate=False): + + if self._delivery_confirmation: + # In publisher-acknowledgments mode + self._message_returned = False + self._current_future = futurist.Future() + + self._execute_task(self._impl.basic_publish, + exchange=exchange, + routing_key=routing_key, + body=body, + properties=properties, + mandatory=mandatory, + immediate=immediate) + + conf_method = self._current_future.result().method + + if isinstance(conf_method, pika_spec.Basic.Nack): + raise pika_exceptions.NackError((None,)) + else: + assert isinstance(conf_method, pika_spec.Basic.Ack), ( + conf_method) + + if self._message_returned: + raise pika_exceptions.UnroutableError((None,)) + else: + # In non-publisher-acknowledgments mode + self._execute_task(self._impl.basic_publish, + exchange=exchange, + routing_key=routing_key, + body=body, + properties=properties, + mandatory=mandatory, + immediate=immediate) + + def basic_qos(self, prefetch_size=0, prefetch_count=0, all_channels=False): + self._current_future = futurist.Future() + self._execute_task(self._impl.basic_qos, + callback=self._current_future.set_result, + prefetch_size=prefetch_size, + prefetch_count=prefetch_count, + all_channels=all_channels) + self._current_future.result() + + def basic_recover(self, requeue=False): + self._current_future = futurist.Future() + self._execute_task( + self._impl.basic_recover, + callback=lambda: self._current_future.set_result(None), + requeue=requeue + ) + self._current_future.result() + + def basic_reject(self, delivery_tag=None, requeue=True): + self._execute_task(self._impl.basic_reject, + delivery_tag=delivery_tag, + requeue=requeue) + + def _on_message_returned(self, *args, **kwargs): + self._message_returned = True + + def confirm_delivery(self): + self._current_future = futurist.Future() + self._execute_task(self._impl.add_callback, + callback=self._current_future.set_result, + replies=[pika_spec.Confirm.SelectOk], + one_shot=True) + self._execute_task(self._impl.confirm_delivery, + callback=self._on_message_confirmation, + nowait=False) + self._current_future.result() + + self._delivery_confirmation = True + self._execute_task(self._impl.add_on_return_callback, + self._on_message_returned) + + def exchange_declare(self, exchange=None, # pylint: disable=R0913 + exchange_type='direct', passive=False, durable=False, + auto_delete=False, internal=False, + arguments=None, **kwargs): + self._current_future = futurist.Future() + self._execute_task(self._impl.exchange_declare, + callback=self._current_future.set_result, + exchange=exchange, + exchange_type=exchange_type, + passive=passive, + durable=durable, + auto_delete=auto_delete, + internal=internal, + nowait=False, + arguments=arguments, + type=kwargs["type"] if kwargs else None) + + return self._current_future.result() + + def exchange_delete(self, exchange=None, if_unused=False): + self._current_future = futurist.Future() + self._execute_task(self._impl.exchange_delete, + callback=self._current_future.set_result, + exchange=exchange, + if_unused=if_unused, + nowait=False) + + return self._current_future.result() + + def exchange_bind(self, destination=None, source=None, routing_key='', + arguments=None): + self._current_future = futurist.Future() + self._execute_task(self._impl.exchange_bind, + callback=self._current_future.set_result, + destination=destination, + source=source, + routing_key=routing_key, + nowait=False, + arguments=arguments) + + return self._current_future.result() + + def exchange_unbind(self, destination=None, source=None, routing_key='', + arguments=None): + self._current_future = futurist.Future() + self._execute_task(self._impl.exchange_unbind, + callback=self._current_future.set_result, + destination=destination, + source=source, + routing_key=routing_key, + nowait=False, + arguments=arguments) + + return self._current_future.result() + + def queue_declare(self, queue='', passive=False, durable=False, + exclusive=False, auto_delete=False, + arguments=None): + self._current_future = futurist.Future() + self._execute_task(self._impl.queue_declare, + callback=self._current_future.set_result, + queue=queue, + passive=passive, + durable=durable, + exclusive=exclusive, + auto_delete=auto_delete, + nowait=False, + arguments=arguments) + + return self._current_future.result() + + def queue_delete(self, queue='', if_unused=False, if_empty=False): + self._current_future = futurist.Future() + self._execute_task(self._impl.queue_delete, + callback=self._current_future.set_result, + queue=queue, + if_unused=if_unused, + if_empty=if_empty, + nowait=False) + + return self._current_future.result() + + def queue_purge(self, queue=''): + self._current_future = futurist.Future() + self._execute_task(self._impl.queue_purge, + callback=self._current_future.set_result, + queue=queue, + nowait=False) + return self._current_future.result() + + def queue_bind(self, queue, exchange, routing_key=None, + arguments=None): + self._current_future = futurist.Future() + self._execute_task(self._impl.queue_bind, + callback=self._current_future.set_result, + queue=queue, + exchange=exchange, + routing_key=routing_key, + nowait=False, + arguments=arguments) + return self._current_future.result() + + def queue_unbind(self, queue='', exchange=None, routing_key=None, + arguments=None): + self._current_future = futurist.Future() + self._execute_task(self._impl.queue_unbind, + callback=self._current_future.set_result, + queue=queue, + exchange=exchange, + routing_key=routing_key, + arguments=arguments) + return self._current_future.result() diff --git a/oslo_messaging/_drivers/pika_driver/pika_engine.py b/oslo_messaging/_drivers/pika_driver/pika_engine.py index 3b762c327..65e87fc99 100644 --- a/oslo_messaging/_drivers/pika_driver/pika_engine.py +++ b/oslo_messaging/_drivers/pika_driver/pika_engine.py @@ -11,7 +11,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. - +import os import random import socket import threading @@ -19,38 +19,18 @@ import time from oslo_log import log as logging import pika -from pika.adapters import select_connection from pika import credentials as pika_credentials import pika_pool - import uuid from oslo_messaging._drivers.pika_driver import pika_commons as pika_drv_cmns +from oslo_messaging._drivers.pika_driver import pika_connection from oslo_messaging._drivers.pika_driver import pika_exceptions as pika_drv_exc LOG = logging.getLogger(__name__) - -def _create_select_poller_connection_impl( - parameters, on_open_callback, on_open_error_callback, - on_close_callback, stop_ioloop_on_close): - """Used for disabling autochoise of poller ('select', 'poll', 'epool', etc) - inside default 'SelectConnection.__init__(...)' logic. It is necessary to - force 'select' poller usage if eventlet is monkeypatched because eventlet - patches only 'select' system call - - Method signature is copied form 'SelectConnection.__init__(...)', because - it is used as replacement of 'SelectConnection' class to create instances - """ - return select_connection.SelectConnection( - parameters=parameters, - on_open_callback=on_open_callback, - on_open_error_callback=on_open_error_callback, - on_close_callback=on_close_callback, - stop_ioloop_on_close=stop_ioloop_on_close, - custom_ioloop=select_connection.SelectPoller() - ) +_PID = None class _PooledConnectionWithConfirmations(pika_pool.Connection): @@ -84,10 +64,6 @@ class PikaEngine(object): allowed_remote_exmods=None): self.conf = conf - self._force_select_poller_use = ( - pika_drv_cmns.is_eventlet_monkey_patched('select') - ) - # processing rpc options self.default_rpc_exchange = ( conf.oslo_messaging_pika.default_rpc_exchange @@ -168,7 +144,7 @@ class PikaEngine(object): ) # initializing connection parameters for configured RabbitMQ hosts - common_pika_params = { + self._common_pika_params = { 'virtual_host': url.virtual_host, 'channel_max': self.conf.oslo_messaging_pika.channel_max, 'frame_max': self.conf.oslo_messaging_pika.frame_max, @@ -177,31 +153,18 @@ class PikaEngine(object): 'socket_timeout': self.conf.oslo_messaging_pika.socket_timeout, } - self._connection_lock = threading.Lock() + self._connection_lock = threading.RLock() + self._pid = None - self._connection_host_param_list = [] - self._connection_host_status_list = [] + self._connection_host_status = {} if not url.hosts: raise ValueError("You should provide at least one RabbitMQ host") - for transport_host in url.hosts: - pika_params = common_pika_params.copy() - pika_params.update( - host=transport_host.hostname, - port=transport_host.port, - credentials=pika_credentials.PlainCredentials( - transport_host.username, transport_host.password - ), - ) - self._connection_host_param_list.append(pika_params) - self._connection_host_status_list.append({ - self.HOST_CONNECTION_LAST_TRY_TIME: 0, - self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME: 0 - }) + self._host_list = url.hosts - self._next_connection_host_num = random.randint( - 0, len(self._connection_host_param_list) - 1 + self._cur_connection_host_num = random.randint( + 0, len(self._host_list) - 1 ) # initializing 2 connection pools: 1st for connections without @@ -228,49 +191,37 @@ class PikaEngine(object): _PooledConnectionWithConfirmations ) - def _next_connection_num(self): - """Used for creating connections to different RabbitMQ nodes in - round robin order - - :return: next host number to create connection to - """ - with self._connection_lock: - cur_num = self._next_connection_host_num - self._next_connection_host_num += 1 - self._next_connection_host_num %= len( - self._connection_host_param_list - ) - return cur_num - def create_connection(self, for_listening=False): """Create and return connection to any available host. :return: created connection :raise: ConnectionException if all hosts are not reachable """ - host_count = len(self._connection_host_param_list) - connection_attempts = host_count - pika_next_connection_num = self._next_connection_num() + with self._connection_lock: + self._init_if_needed() - while connection_attempts > 0: - try: - return self.create_host_connection( - pika_next_connection_num, for_listening - ) - except pika_pool.Connection.connectivity_errors as e: - LOG.warning("Can't establish connection to host. %s", e) - except pika_drv_exc.HostConnectionNotAllowedException as e: - LOG.warning("Connection to host is not allowed. %s", e) + host_count = len(self._host_list) + connection_attempts = host_count - connection_attempts -= 1 - pika_next_connection_num += 1 - pika_next_connection_num %= host_count + while connection_attempts > 0: + self._cur_connection_host_num += 1 + self._cur_connection_host_num %= host_count + try: + return self.create_host_connection( + self._cur_connection_host_num, for_listening + ) + except pika_pool.Connection.connectivity_errors as e: + LOG.warning("Can't establish connection to host. %s", e) + except pika_drv_exc.HostConnectionNotAllowedException as e: + LOG.warning("Connection to host is not allowed. %s", e) - raise pika_drv_exc.EstablishConnectionException( - "Can not establish connection to any configured RabbitMQ host: " + - str(self._connection_host_param_list) - ) + connection_attempts -= 1 + + raise pika_drv_exc.EstablishConnectionException( + "Can not establish connection to any configured RabbitMQ " + "host: " + str(self._host_list) + ) def _set_tcp_user_timeout(self, s): if not self._tcp_user_timeout: @@ -285,28 +236,64 @@ class PikaEngine(object): "Whoops, this kernel doesn't seem to support TCP_USER_TIMEOUT." ) + def _init_if_needed(self): + global _PID + + cur_pid = os.getpid() + + if _PID != cur_pid: + if _PID: + LOG.warning("New pid is detected. Old: %s, new: %s. " + "Cleaning up...", _PID, cur_pid) + # Note(dukhlov): we need to force select poller usage in case when + # 'thread' module is monkey patched becase current eventlet + # implementation does not support patching of poll/epoll/kqueue + if pika_drv_cmns.is_eventlet_monkey_patched("thread"): + from pika.adapters import select_connection + select_connection.SELECT_TYPE = "select" + + _PID = cur_pid + def create_host_connection(self, host_index, for_listening=False): """Create new connection to host #host_index + :param host_index: Integer, number of host for connection establishing :param for_listening: Boolean, creates connection for listening - (enable heartbeats) if True + if True :return: New connection """ - - connection_params = pika.ConnectionParameters( - heartbeat_interval=( - self._heartbeat_interval if for_listening else None - ), - **self._connection_host_param_list[host_index] - ) - with self._connection_lock: + self._init_if_needed() + + host = self._host_list[host_index] + + connection_params = pika.ConnectionParameters( + host=host.hostname, + port=host.port, + credentials=pika_credentials.PlainCredentials( + host.username, host.password + ), + heartbeat_interval=( + self._heartbeat_interval if for_listening else None + ), + **self._common_pika_params + ) + cur_time = time.time() - last_success_time = self._connection_host_status_list[host_index][ + host_connection_status = self._connection_host_status.get(host) + + if host_connection_status is None: + host_connection_status = { + self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME: 0, + self.HOST_CONNECTION_LAST_TRY_TIME: 0 + } + self._connection_host_status[host] = host_connection_status + + last_success_time = host_connection_status[ self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME ] - last_time = self._connection_host_status_list[host_index][ + last_time = host_connection_status[ self.HOST_CONNECTION_LAST_TRY_TIME ] @@ -322,27 +309,25 @@ class PikaEngine(object): ) try: - connection = pika.BlockingConnection( - parameters=connection_params, - _impl_class=(_create_select_poller_connection_impl - if self._force_select_poller_use else None) - ) - - # It is needed for pika_pool library which expects that - # connections has params attribute defined in BaseConnection - # but BlockingConnection is not derived from BaseConnection - # and doesn't have it - connection.params = connection_params + if for_listening: + connection = pika_connection.ThreadSafePikaConnection( + params=connection_params + ) + else: + connection = pika.BlockingConnection( + parameters=connection_params + ) + connection.params = connection_params self._set_tcp_user_timeout(connection._impl.socket) - self._connection_host_status_list[host_index][ + self._connection_host_status[host][ self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME ] = cur_time return connection finally: - self._connection_host_status_list[host_index][ + self._connection_host_status[host][ self.HOST_CONNECTION_LAST_TRY_TIME ] = cur_time diff --git a/oslo_messaging/_drivers/pika_driver/pika_listener.py b/oslo_messaging/_drivers/pika_driver/pika_listener.py index 1739c7932..1942fcc1c 100644 --- a/oslo_messaging/_drivers/pika_driver/pika_listener.py +++ b/oslo_messaging/_drivers/pika_driver/pika_listener.py @@ -29,6 +29,7 @@ class RpcReplyPikaListener(object): """ def __init__(self, pika_engine): + super(RpcReplyPikaListener, self).__init__() self._pika_engine = pika_engine # preparing poller for listening replies @@ -39,7 +40,6 @@ class RpcReplyPikaListener(object): self._reply_consumer_initialized = False self._reply_consumer_initialization_lock = threading.Lock() - self._poller_thread = None self._shutdown = False def get_reply_qname(self): @@ -66,51 +66,31 @@ class RpcReplyPikaListener(object): # initialize reply poller if needed if self._reply_poller is None: self._reply_poller = pika_drv_poller.RpcReplyPikaPoller( - pika_engine=self._pika_engine, - exchange=self._pika_engine.rpc_reply_exchange, - queue=self._reply_queue, - prefetch_count=( - self._pika_engine.rpc_reply_listener_prefetch_count - ) + self._pika_engine, self._pika_engine.rpc_reply_exchange, + self._reply_queue, 1, None, + self._pika_engine.rpc_reply_listener_prefetch_count ) - self._reply_poller.start() - - # start reply poller job thread if needed - if self._poller_thread is None: - self._poller_thread = threading.Thread(target=self._poller) - self._poller_thread.daemon = True - - if not self._poller_thread.is_alive(): - self._poller_thread.start() - + self._reply_poller.start(self._on_incoming) self._reply_consumer_initialized = True return self._reply_queue - def _poller(self): + def _on_incoming(self, incoming): """Reply polling job. Poll replies in infinite loop and notify registered features """ - while True: + for message in incoming: try: - messages = self._reply_poller.poll() - if not messages and self._shutdown: - break - - for message in messages: - try: - message.acknowledge() - future = self._reply_waiting_futures.pop( - message.msg_id, None - ) - if future is not None: - future.set_result(message) - except Exception: - LOG.exception("Unexpected exception during processing" - "reply message") - except BaseException: - LOG.exception("Unexpected exception during reply polling") + message.acknowledge() + future = self._reply_waiting_futures.pop( + message.msg_id, None + ) + if future is not None: + future.set_result(message) + except Exception: + LOG.exception("Unexpected exception during processing" + "reply message") def register_reply_waiter(self, msg_id): """Register reply waiter. Should be called before message sending to @@ -140,9 +120,4 @@ class RpcReplyPikaListener(object): self._reply_poller.cleanup() self._reply_poller = None - if self._poller_thread: - if self._poller_thread.is_alive(): - self._poller_thread.join() - self._poller_thread = None - self._reply_queue = None diff --git a/oslo_messaging/_drivers/pika_driver/pika_poller.py b/oslo_messaging/_drivers/pika_driver/pika_poller.py index 35bf41191..d34ddfe77 100644 --- a/oslo_messaging/_drivers/pika_driver/pika_poller.py +++ b/oslo_messaging/_drivers/pika_driver/pika_poller.py @@ -13,11 +13,9 @@ # under the License. import threading -import time from oslo_log import log as logging -from oslo_utils import timeutils -import six +from oslo_service import loopingcall from oslo_messaging._drivers import base from oslo_messaging._drivers.pika_driver import pika_commons as pika_drv_cmns @@ -27,45 +25,188 @@ from oslo_messaging._drivers.pika_driver import pika_message as pika_drv_msg LOG = logging.getLogger(__name__) -class PikaPoller(base.PollStyleListener): +class PikaPoller(base.Listener): """Provides user friendly functionality for RabbitMQ message consuming, handles low level connectivity problems and restore connection if some connectivity related problem detected """ - def __init__(self, pika_engine, prefetch_count, incoming_message_class): + def __init__(self, pika_engine, batch_size, batch_timeout, prefetch_count, + incoming_message_class): """Initialize required fields :param pika_engine: PikaEngine, shared object with configuration and shared driver functionality + :param batch_size: desired number of messages passed to + single on_incoming_callback call + :param batch_timeout: defines how long should we wait for batch_size + messages if we already have some messages waiting for processing :param prefetch_count: Integer, maximum count of unacknowledged messages which RabbitMQ broker sends to this consumer :param incoming_message_class: PikaIncomingMessage, wrapper for consumed RabbitMQ message """ - super(PikaPoller, self).__init__(prefetch_count) + super(PikaPoller, self).__init__(batch_size, batch_timeout, + prefetch_count) self._pika_engine = pika_engine self._incoming_message_class = incoming_message_class self._connection = None self._channel = None - self._lock = threading.Lock() + self._recover_loopingcall = None + self._lock = threading.RLock() + + self._cur_batch_buffer = None + self._cur_batch_timeout_id = None self._started = False + self._closing_connection_by_poller = False self._queues_to_consume = None - self._message_queue = [] + def _on_connection_close(self, connection, reply_code, reply_text): + self._deliver_cur_batch() + if self._closing_connection_by_poller: + return + with self._lock: + self._connection = None + self._start_recover_consuming_task() - def _reconnect(self): + def _on_channel_close(self, channel, reply_code, reply_text): + if self._cur_batch_buffer: + self._cur_batch_buffer = [ + message for message in self._cur_batch_buffer + if not message.need_ack() + ] + if self._closing_connection_by_poller: + return + with self._lock: + self._channel = None + self._start_recover_consuming_task() + + def _on_consumer_cancel(self, method_frame): + with self._lock: + if self._queues_to_consume: + consumer_tag = method_frame.method.consumer_tag + for queue_info in self._queues_to_consume: + if queue_info["consumer_tag"] == consumer_tag: + queue_info["consumer_tag"] = None + + self._start_recover_consuming_task() + + def _on_message_no_ack_callback(self, unused, method, properties, body): + """Is called by Pika when message was received from queue listened with + no_ack=True mode + """ + incoming_message = self._incoming_message_class( + self._pika_engine, None, method, properties, body + ) + self._on_incoming_message(incoming_message) + + def _on_message_with_ack_callback(self, unused, method, properties, body): + """Is called by Pika when message was received from queue listened with + no_ack=False mode + """ + incoming_message = self._incoming_message_class( + self._pika_engine, self._channel, method, properties, body + ) + self._on_incoming_message(incoming_message) + + def _deliver_cur_batch(self): + if self._cur_batch_timeout_id is not None: + self._connection.remove_timeout(self._cur_batch_timeout_id) + self._cur_batch_timeout_id = None + if self._cur_batch_buffer: + buf_to_send = self._cur_batch_buffer + self._cur_batch_buffer = None + try: + self.on_incoming_callback(buf_to_send) + except Exception: + LOG.exception("Unexpected exception during incoming delivery") + + def _on_incoming_message(self, incoming_message): + if self._cur_batch_buffer is None: + self._cur_batch_buffer = [incoming_message] + else: + self._cur_batch_buffer.append(incoming_message) + + if len(self._cur_batch_buffer) >= self.batch_size: + self._deliver_cur_batch() + return + + if self._cur_batch_timeout_id is None: + self._cur_batch_timeout_id = self._connection.add_timeout( + self.batch_timeout, self._deliver_cur_batch) + + def _start_recover_consuming_task(self): + """Start async job for checking connection to the broker.""" + if self._recover_loopingcall is None and self._started: + self._recover_loopingcall = ( + loopingcall.DynamicLoopingCall( + self._try_recover_consuming + ) + ) + LOG.info("Starting recover consuming job for listener: %s", self) + self._recover_loopingcall.start() + + def _try_recover_consuming(self): + with self._lock: + try: + if self._started: + self._start_or_recover_consuming() + except pika_drv_exc.EstablishConnectionException as e: + LOG.warning( + "Problem during establishing connection for pika " + "poller %s", e, exc_info=True + ) + return self._pika_engine.host_connection_reconnect_delay + except pika_drv_exc.ConnectionException as e: + LOG.warning( + "Connectivity exception during starting/recovering pika " + "poller %s", e, exc_info=True + ) + except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as e: + LOG.warning( + "Connectivity exception during starting/recovering pika " + "poller %s", e, exc_info=True + ) + except BaseException: + # NOTE (dukhlov): I preffer to use here BaseException because + # if this method raise such exception LoopingCall stops + # execution Probably it should never happen and Exception + # should be enough but in case of programmer mistake it could + # be and it is potentially hard to catch problem if we will + # stop background task. It is better when it continue to work + # and write a lot of LOG with this error + LOG.exception("Unexpected exception during " + "starting/recovering pika poller") + else: + self._recover_loopingcall = None + LOG.info("Recover consuming job was finished for listener: %s", + self) + raise loopingcall.LoopingCallDone(True) + return 0 + + def _start_or_recover_consuming(self): """Performs reconnection to the broker. It is unsafe method for internal use only """ - self._connection = self._pika_engine.create_connection( - for_listening=True - ) - self._channel = self._connection.channel() - self._channel.basic_qos(prefetch_count=self.prefetch_size) + if self._connection is None or not self._connection.is_open: + self._connection = self._pika_engine.create_connection( + for_listening=True + ) + self._connection.add_on_close_callback(self._on_connection_close) + self._channel = None + + if self._channel is None or not self._channel.is_open: + if self._queues_to_consume: + for queue_info in self._queues_to_consume: + queue_info["consumer_tag"] = None + + self._channel = self._connection.channel() + self._channel.add_on_close_callback(self._on_channel_close) + self._channel.add_on_cancel_callback(self._on_consumer_cancel) + self._channel.basic_qos(prefetch_count=self.prefetch_size) if self._queues_to_consume is None: self._queues_to_consume = self._declare_queue_binding() @@ -92,6 +233,8 @@ class PikaPoller(base.PollStyleListener): try: for queue_info in self._queues_to_consume: + if queue_info["consumer_tag"] is not None: + continue no_ack = queue_info["no_ack"] on_message_callback = ( @@ -120,168 +263,95 @@ class PikaPoller(base.PollStyleListener): self._channel.basic_cancel(consumer_tag) queue_info["consumer_tag"] = None - def _on_message_no_ack_callback(self, unused, method, properties, body): - """Is called by Pika when message was received from queue listened with - no_ack=True mode - """ - self._message_queue.append( - self._incoming_message_class( - self._pika_engine, None, method, properties, body - ) - ) - - def _on_message_with_ack_callback(self, unused, method, properties, body): - """Is called by Pika when message was received from queue listened with - no_ack=False mode - """ - self._message_queue.append( - self._incoming_message_class( - self._pika_engine, self._channel, method, properties, body - ) - ) - - def _cleanup(self): - """Cleanup allocated resources (channel, connection, etc). It is unsafe - method for internal use only - """ - if self._connection: - try: - self._connection.close() - except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS: - # expected errors - pass - except Exception: - LOG.exception("Unexpected error during closing connection") - finally: - self._channel = None - self._connection = None - - for i in six.moves.range(len(self._message_queue) - 1, -1, -1): - message = self._message_queue[i] - if message.need_ack(): - del self._message_queue[i] - - @base.batch_poll_helper - def poll(self, timeout=None): - """Main method of this class - consumes message from RabbitMQ - - :param: timeout: float, seconds, timeout for waiting new incoming - message, None means wait forever - :return: list of PikaIncomingMessage, RabbitMQ messages - """ - - with timeutils.StopWatch(timeout) as stop_watch: - while True: - with self._lock: - if self._message_queue: - return self._message_queue.pop(0) - - if stop_watch.expired(): - return None - - try: - if self._started: - if self._channel is None: - self._reconnect() - # we need some time_limit here, not too small to - # avoid a lot of not needed iterations but not too - # large to release lock time to time and give a - # chance to perform another method waiting this - # lock - self._connection.process_data_events( - time_limit=0.25 - ) - else: - # consumer is stopped so we don't expect new - # messages, just process already sent events - if self._channel is not None: - self._connection.process_data_events( - time_limit=0 - ) - - # and return if we don't see new messages - if not self._message_queue: - return None - except pika_drv_exc.EstablishConnectionException as e: - LOG.warning( - "Problem during establishing connection for pika " - "poller %s", e, exc_info=True - ) - time.sleep( - self._pika_engine.host_connection_reconnect_delay - ) - except pika_drv_exc.ConnectionException: - self._cleanup() - raise - except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS: - self._cleanup() - raise - - def start(self): + def start(self, on_incoming_callback): """Starts poller. Should be called before polling to allow message consuming + + :param on_incoming_callback: callback function to be executed when + listener received messages. Messages should be processed and + acked/nacked by callback """ + super(PikaPoller, self).start(on_incoming_callback) + with self._lock: if self._started: return - + connected = False try: - self._reconnect() + self._start_or_recover_consuming() except pika_drv_exc.EstablishConnectionException as exc: LOG.warning( "Can not establish connection during pika poller's " - "start(). Connecting is required during first poll() " - "call. %s", exc, exc_info=True + "start(). %s", exc, exc_info=True ) except pika_drv_exc.ConnectionException as exc: - self._cleanup() LOG.warning( - "Connectivity problem during pika poller's start(). " - "Reconnecting required during first poll() call. %s", + "Connectivity problem during pika poller's start(). %s", exc, exc_info=True ) except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as exc: - self._cleanup() LOG.warning( - "Connectivity problem during pika poller's start(). " - "Reconnecting required during first poll() call. %s", + "Connectivity problem during pika poller's start(). %s", exc, exc_info=True ) + else: + connected = True + self._started = True + if not connected: + self._start_recover_consuming_task() def stop(self): """Stops poller. Should be called when polling is not needed anymore to stop new message consuming. After that it is necessary to poll already prefetched messages """ + super(PikaPoller, self).stop() + with self._lock: if not self._started: return - if self._queues_to_consume and self._channel: + if self._recover_loopingcall is not None: + self._recover_loopingcall.stop() + self._recover_loopingcall = None + + if (self._queues_to_consume and self._channel and + self._channel.is_open): try: self._stop_consuming() except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as exc: - self._cleanup() LOG.warning( "Connectivity problem detected during consumer " "cancellation. %s", exc, exc_info=True ) + self._deliver_cur_batch() self._started = False def cleanup(self): - """Safe version of _cleanup. Cleans up allocated resources (channel, - connection, etc). - """ + """Cleanup allocated resources (channel, connection, etc).""" with self._lock: - self._cleanup() + if self._connection and self._connection.is_open: + try: + self._closing_connection_by_poller = True + self._connection.close() + self._closing_connection_by_poller = False + except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS: + # expected errors + pass + except Exception: + LOG.exception("Unexpected error during closing connection") + finally: + self._channel = None + self._connection = None class RpcServicePikaPoller(PikaPoller): """PikaPoller implementation for polling RPC messages. Overrides base functionality according to RPC specific """ - def __init__(self, pika_engine, target, prefetch_count): + def __init__(self, pika_engine, target, batch_size, batch_timeout, + prefetch_count): """Adds target parameter for declaring RPC specific exchanges and queues @@ -289,14 +359,18 @@ class RpcServicePikaPoller(PikaPoller): shared driver functionality :param target: Target, oslo.messaging Target object which defines RPC endpoint + :param batch_size: desired number of messages passed to + single on_incoming_callback call + :param batch_timeout: defines how long should we wait for batch_size + messages if we already have some messages waiting for processing :param prefetch_count: Integer, maximum count of unacknowledged messages which RabbitMQ broker sends to this consumer """ self._target = target super(RpcServicePikaPoller, self).__init__( - pika_engine, prefetch_count=prefetch_count, - incoming_message_class=pika_drv_msg.RpcPikaIncomingMessage + pika_engine, batch_size, batch_timeout, prefetch_count, + pika_drv_msg.RpcPikaIncomingMessage ) def _declare_queue_binding(self): @@ -363,7 +437,8 @@ class RpcReplyPikaPoller(PikaPoller): """PikaPoller implementation for polling RPC reply messages. Overrides base functionality according to RPC reply specific """ - def __init__(self, pika_engine, exchange, queue, prefetch_count): + def __init__(self, pika_engine, exchange, queue, batch_size, batch_timeout, + prefetch_count): """Adds exchange and queue parameter for declaring exchange and queue used for RPC reply delivery @@ -371,6 +446,10 @@ class RpcReplyPikaPoller(PikaPoller): shared driver functionality :param exchange: String, exchange name used for RPC reply delivery :param queue: String, queue name used for RPC reply delivery + :param batch_size: desired number of messages passed to + single on_incoming_callback call + :param batch_timeout: defines how long should we wait for batch_size + messages if we already have some messages waiting for processing :param prefetch_count: Integer, maximum count of unacknowledged messages which RabbitMQ broker sends to this consumer """ @@ -378,8 +457,8 @@ class RpcReplyPikaPoller(PikaPoller): self._queue = queue super(RpcReplyPikaPoller, self).__init__( - pika_engine=pika_engine, prefetch_count=prefetch_count, - incoming_message_class=pika_drv_msg.RpcReplyPikaIncomingMessage + pika_engine, batch_size, batch_timeout, prefetch_count, + pika_drv_msg.RpcReplyPikaIncomingMessage ) def _declare_queue_binding(self): @@ -404,8 +483,8 @@ class NotificationPikaPoller(PikaPoller): """PikaPoller implementation for polling Notification messages. Overrides base functionality according to Notification specific """ - def __init__(self, pika_engine, targets_and_priorities, prefetch_count, - queue_name=None): + def __init__(self, pika_engine, targets_and_priorities, + batch_size, batch_timeout, prefetch_count, queue_name=None): """Adds targets_and_priorities and queue_name parameter for declaring exchanges and queues used for notification delivery @@ -413,6 +492,10 @@ class NotificationPikaPoller(PikaPoller): shared driver functionality :param targets_and_priorities: list of (target, priority), defines default queue names for corresponding notification types + :param batch_size: desired number of messages passed to + single on_incoming_callback call + :param batch_timeout: defines how long should we wait for batch_size + messages if we already have some messages waiting for processing :param prefetch_count: Integer, maximum count of unacknowledged messages which RabbitMQ broker sends to this consumer :param queue: String, alternative queue name used for this poller @@ -422,8 +505,8 @@ class NotificationPikaPoller(PikaPoller): self._queue_name = queue_name super(NotificationPikaPoller, self).__init__( - pika_engine, prefetch_count=prefetch_count, - incoming_message_class=pika_drv_msg.PikaIncomingMessage + pika_engine, batch_size, batch_timeout, prefetch_count, + pika_drv_msg.PikaIncomingMessage ) def _declare_queue_binding(self): @@ -442,7 +525,7 @@ class NotificationPikaPoller(PikaPoller): target.exchange or self._pika_engine.default_notification_exchange ), - queue = queue, + queue=queue, routing_key=routing_key, exchange_type='direct', queue_expiration=None, diff --git a/oslo_messaging/tests/drivers/pika/test_poller.py b/oslo_messaging/tests/drivers/pika/test_poller.py index 65492b4af..cfa69e720 100644 --- a/oslo_messaging/tests/drivers/pika/test_poller.py +++ b/oslo_messaging/tests/drivers/pika/test_poller.py @@ -13,9 +13,11 @@ # under the License. import socket +import threading import time import unittest +from concurrent import futures import mock from oslo_messaging._drivers.pika_driver import pika_poller @@ -32,26 +34,53 @@ class PikaPollerTestCase(unittest.TestCase): self._pika_engine.create_connection.return_value = ( self._poller_connection_mock ) + + self._executor = futures.ThreadPoolExecutor(1) + + def timer_task(timeout, callback): + time.sleep(timeout) + callback() + + self._poller_connection_mock.add_timeout.side_effect = ( + lambda *args: self._executor.submit(timer_task, *args) + ) + self._prefetch_count = 123 - def test_start_when_connection_unavailable(self): - incoming_message_class_mock = mock.Mock() + @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller." + "_declare_queue_binding") + def test_start(self, declare_queue_binding_mock): poller = pika_poller.PikaPoller( - self._pika_engine, self._prefetch_count, - incoming_message_class=incoming_message_class_mock + self._pika_engine, 1, None, self._prefetch_count, None + ) + + poller.start(None) + + self.assertTrue(self._pika_engine.create_connection.called) + self.assertTrue(self._poller_connection_mock.channel.called) + self.assertTrue(declare_queue_binding_mock.called) + + def test_start_when_connection_unavailable(self): + poller = pika_poller.PikaPoller( + self._pika_engine, 1, None, self._prefetch_count, None ) self._pika_engine.create_connection.side_effect = socket.timeout() # start() should not raise socket.timeout exception - poller.start() + poller.start(None) @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller." "_declare_queue_binding") - def test_poll(self, declare_queue_binding_mock): + def test_message_processing(self, declare_queue_binding_mock): + res = [] + + def on_incoming_callback(incoming): + res.append(incoming) + incoming_message_class_mock = mock.Mock() poller = pika_poller.PikaPoller( - self._pika_engine, self._prefetch_count, + self._pika_engine, 1, None, self._prefetch_count, incoming_message_class=incoming_message_class_mock ) unused = object() @@ -59,18 +88,14 @@ class PikaPollerTestCase(unittest.TestCase): properties = object() body = object() - self._poller_connection_mock.process_data_events.side_effect = ( - lambda time_limit: poller._on_message_with_ack_callback( - unused, method, properties, body - ) + poller.start(on_incoming_callback) + poller._on_message_with_ack_callback( + unused, method, properties, body ) - poller.start() - res = poller.poll() - self.assertEqual(len(res), 1) - self.assertEqual(res[0], incoming_message_class_mock.return_value) + self.assertEqual(res[0], [incoming_message_class_mock.return_value]) incoming_message_class_mock.assert_called_once_with( self._pika_engine, self._poller_channel_mock, method, properties, body @@ -83,92 +108,39 @@ class PikaPollerTestCase(unittest.TestCase): @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller." "_declare_queue_binding") - def test_poll_after_stop(self, declare_queue_binding_mock): + def test_message_processing_batch(self, declare_queue_binding_mock): incoming_message_class_mock = mock.Mock() - poller = pika_poller.PikaPoller( - self._pika_engine, self._prefetch_count, - incoming_message_class=incoming_message_class_mock - ) n = 10 params = [] - for i in range(n): - params.append((object(), object(), object(), object())) + res = [] - def f(time_limit): - if poller._started: - for k in range(n): - poller._on_message_no_ack_callback( - *params[k] - ) + def on_incoming_callback(incoming): + res.append(incoming) - self._poller_connection_mock.process_data_events.side_effect = f - - poller.start() - res = poller.poll(batch_size=1) - self.assertEqual(len(res), 1) - self.assertEqual(res[0], incoming_message_class_mock.return_value) - self.assertEqual( - incoming_message_class_mock.call_args_list[0][0], - (self._pika_engine, None) + params[0][1:] - ) - - poller.stop() - - res2 = poller.poll(batch_size=n) - - self.assertEqual(len(res2), n - 1) - self.assertEqual(incoming_message_class_mock.call_count, n) - - self.assertEqual( - self._poller_connection_mock.process_data_events.call_count, 2) - - for i in range(n - 1): - self.assertEqual(res2[i], incoming_message_class_mock.return_value) - self.assertEqual( - incoming_message_class_mock.call_args_list[i + 1][0], - (self._pika_engine, None) + params[i + 1][1:] - ) - - self.assertTrue(self._pika_engine.create_connection.called) - self.assertTrue(self._poller_connection_mock.channel.called) - - self.assertTrue(declare_queue_binding_mock.called) - - @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller." - "_declare_queue_binding") - def test_poll_batch(self, declare_queue_binding_mock): - incoming_message_class_mock = mock.Mock() poller = pika_poller.PikaPoller( - self._pika_engine, self._prefetch_count, + self._pika_engine, n, None, self._prefetch_count, incoming_message_class=incoming_message_class_mock ) - n = 10 - params = [] - for i in range(n): params.append((object(), object(), object(), object())) - index = [0] + poller.start(on_incoming_callback) - def f(time_limit): + for i in range(n): poller._on_message_with_ack_callback( - *params[index[0]] + *params[i] ) - index[0] += 1 - self._poller_connection_mock.process_data_events.side_effect = f - - poller.start() - res = poller.poll(batch_size=n) - - self.assertEqual(len(res), n) + self.assertEqual(len(res), 1) + self.assertEqual(len(res[0]), 10) self.assertEqual(incoming_message_class_mock.call_count, n) for i in range(n): - self.assertEqual(res[i], incoming_message_class_mock.return_value) + self.assertEqual(res[0][i], + incoming_message_class_mock.return_value) self.assertEqual( incoming_message_class_mock.call_args_list[i][0], (self._pika_engine, self._poller_channel_mock) + params[i][1:] @@ -181,42 +153,48 @@ class PikaPollerTestCase(unittest.TestCase): @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller." "_declare_queue_binding") - def test_poll_batch_with_timeout(self, declare_queue_binding_mock): + def test_message_processing_batch_with_timeout(self, + declare_queue_binding_mock): incoming_message_class_mock = mock.Mock() - poller = pika_poller.PikaPoller( - self._pika_engine, self._prefetch_count, - incoming_message_class=incoming_message_class_mock - ) n = 10 timeout = 1 - sleep_time = 0.2 + + res = [] + evt = threading.Event() + + def on_incoming_callback(incoming): + res.append(incoming) + evt.set() + + poller = pika_poller.PikaPoller( + self._pika_engine, n, timeout, self._prefetch_count, + incoming_message_class=incoming_message_class_mock + ) + params = [] success_count = 5 + poller.start(on_incoming_callback) + for i in range(n): params.append((object(), object(), object(), object())) - index = [0] - - def f(time_limit): - time.sleep(sleep_time) + for i in range(success_count): poller._on_message_with_ack_callback( - *params[index[0]] + *params[i] ) - index[0] += 1 - self._poller_connection_mock.process_data_events.side_effect = f + self.assertTrue(evt.wait(timeout * 2)) - poller.start() - res = poller.poll(batch_size=n, timeout=timeout) - - self.assertEqual(len(res), success_count) + self.assertEqual(len(res), 1) + self.assertEqual(len(res[0]), success_count) self.assertEqual(incoming_message_class_mock.call_count, success_count) for i in range(success_count): - self.assertEqual(res[i], incoming_message_class_mock.return_value) + self.assertEqual(res[0][i], + incoming_message_class_mock.return_value) self.assertEqual( incoming_message_class_mock.call_args_list[i][0], (self._pika_engine, self._poller_channel_mock) + params[i][1:] @@ -258,20 +236,11 @@ class RpcServicePikaPollerTestCase(unittest.TestCase): "RpcPikaIncomingMessage") def test_declare_rpc_queue_bindings(self, rpc_pika_incoming_message_mock): poller = pika_poller.RpcServicePikaPoller( - self._pika_engine, self._target, self._prefetch_count, - ) - self._poller_connection_mock.process_data_events.side_effect = ( - lambda time_limit: poller._on_message_with_ack_callback( - None, None, None, None - ) + self._pika_engine, self._target, 1, None, + self._prefetch_count ) - poller.start() - res = poller.poll() - - self.assertEqual(len(res), 1) - - self.assertEqual(res[0], rpc_pika_incoming_message_mock.return_value) + poller.start(None) self.assertTrue(self._pika_engine.create_connection.called) self.assertTrue(self._poller_connection_mock.channel.called) @@ -357,24 +326,14 @@ class RpcReplyServicePikaPollerTestCase(unittest.TestCase): self._pika_engine.rpc_queue_expiration = 12345 self._pika_engine.rpc_reply_retry_attempts = 3 - def test_start(self): - poller = pika_poller.RpcReplyPikaPoller( - self._pika_engine, self._exchange, self._queue, - self._prefetch_count, - ) - - poller.start() - - self.assertTrue(self._pika_engine.create_connection.called) - self.assertTrue(self._poller_connection_mock.channel.called) - def test_declare_rpc_reply_queue_binding(self): poller = pika_poller.RpcReplyPikaPoller( - self._pika_engine, self._exchange, self._queue, + self._pika_engine, self._exchange, self._queue, 1, None, self._prefetch_count, ) - poller.start() + poller.start(None) + poller.stop() declare_queue_binding_by_channel_mock = ( self._pika_engine.declare_queue_binding_by_channel @@ -419,26 +378,13 @@ class NotificationPikaPollerTestCase(unittest.TestCase): ) self._pika_engine.notification_persistence = object() - @mock.patch("oslo_messaging._drivers.pika_driver.pika_message." - "PikaIncomingMessage") - def test_declare_notification_queue_bindings_default_queue( - self, pika_incoming_message_mock): + def test_declare_notification_queue_bindings_default_queue(self): poller = pika_poller.NotificationPikaPoller( - self._pika_engine, self._target_and_priorities, + self._pika_engine, self._target_and_priorities, 1, None, self._prefetch_count, None ) - self._poller_connection_mock.process_data_events.side_effect = ( - lambda time_limit: poller._on_message_with_ack_callback( - None, None, None, None - ) - ) - poller.start() - res = poller.poll() - - self.assertEqual(len(res), 1) - - self.assertEqual(res[0], pika_incoming_message_mock.return_value) + poller.start(None) self.assertTrue(self._pika_engine.create_connection.called) self.assertTrue(self._poller_connection_mock.channel.called) @@ -481,26 +427,13 @@ class NotificationPikaPollerTestCase(unittest.TestCase): ) )) - @mock.patch("oslo_messaging._drivers.pika_driver.pika_message." - "PikaIncomingMessage") - def test_declare_notification_queue_bindings_custom_queue( - self, pika_incoming_message_mock): + def test_declare_notification_queue_bindings_custom_queue(self): poller = pika_poller.NotificationPikaPoller( - self._pika_engine, self._target_and_priorities, + self._pika_engine, self._target_and_priorities, 1, None, self._prefetch_count, "custom_queue_name" ) - self._poller_connection_mock.process_data_events.side_effect = ( - lambda time_limit: poller._on_message_with_ack_callback( - None, None, None, None - ) - ) - poller.start() - res = poller.poll() - - self.assertEqual(len(res), 1) - - self.assertEqual(res[0], pika_incoming_message_mock.return_value) + poller.start(None) self.assertTrue(self._pika_engine.create_connection.called) self.assertTrue(self._poller_connection_mock.channel.called)