diff --git a/oslo_messaging/_drivers/impl_pika.py b/oslo_messaging/_drivers/impl_pika.py index 42686350d..12590abe9 100644 --- a/oslo_messaging/_drivers/impl_pika.py +++ b/oslo_messaging/_drivers/impl_pika.py @@ -335,26 +335,18 @@ class PikaDriver(base.BaseDriver): ) def listen(self, target, batch_size, batch_timeout): - listener = pika_drv_poller.RpcServicePikaPoller( - self._pika_engine, target, - prefetch_count=self._pika_engine.rpc_listener_prefetch_count + return pika_drv_poller.RpcServicePikaPoller( + self._pika_engine, target, batch_size, batch_timeout, + self._pika_engine.rpc_listener_prefetch_count ) - listener.start() - return base.PollStyleListenerAdapter(listener, batch_size, - batch_timeout) def listen_for_notifications(self, targets_and_priorities, pool, batch_size, batch_timeout): - listener = pika_drv_poller.NotificationPikaPoller( - self._pika_engine, targets_and_priorities, - prefetch_count=( - self._pika_engine.notification_listener_prefetch_count - ), - queue_name=pool + return pika_drv_poller.NotificationPikaPoller( + self._pika_engine, targets_and_priorities, batch_size, + batch_timeout, + self._pika_engine.notification_listener_prefetch_count, pool ) - listener.start() - return base.PollStyleListenerAdapter(listener, batch_size, - batch_timeout) def cleanup(self): self._reply_listener.cleanup() diff --git a/oslo_messaging/_drivers/pika_driver/pika_connection.py b/oslo_messaging/_drivers/pika_driver/pika_connection.py new file mode 100644 index 000000000..a2b1d00f8 --- /dev/null +++ b/oslo_messaging/_drivers/pika_driver/pika_connection.py @@ -0,0 +1,497 @@ +# Copyright 2016 Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import collections +import logging +import os +import threading + +import futurist +import pika +from pika.adapters import select_connection +from pika import exceptions as pika_exceptions +from pika import spec as pika_spec + +from oslo_utils import eventletutils + +current_thread = eventletutils.fetch_current_thread_functor() + +LOG = logging.getLogger(__name__) + + +class ThreadSafePikaConnection(object): + def __init__(self, params=None): + self.params = params + self._connection_lock = threading.Lock() + self._evt_closed = threading.Event() + self._task_queue = collections.deque() + self._pending_connection_futures = set() + + create_connection_future = self._register_pending_future() + + def on_open_error(conn, err): + create_connection_future.set_exception( + pika_exceptions.AMQPConnectionError(err) + ) + + self._impl = pika.SelectConnection( + parameters=params, + on_open_callback=create_connection_future.set_result, + on_open_error_callback=on_open_error, + on_close_callback=self._on_connection_close, + stop_ioloop_on_close=False, + ) + self._interrupt_pipein, self._interrupt_pipeout = os.pipe() + self._impl.ioloop.add_handler(self._interrupt_pipein, + self._impl.ioloop.read_interrupt, + select_connection.READ) + + self._thread = threading.Thread(target=self._process_io) + self._thread.daemon = True + self._thread_id = None + self._thread.start() + + create_connection_future.result() + + def _execute_task(self, func, *args, **kwargs): + if current_thread() == self._thread_id: + return func(*args, **kwargs) + + future = futurist.Future() + self._task_queue.append((func, args, kwargs, future)) + + if self._evt_closed.is_set(): + self._notify_all_futures_connection_close() + elif self._interrupt_pipeout is not None: + os.write(self._interrupt_pipeout, b'X') + + return future.result() + + def _register_pending_future(self): + future = futurist.Future() + self._pending_connection_futures.add(future) + + def on_done_callback(fut): + try: + self._pending_connection_futures.remove(fut) + except KeyError: + pass + + future.add_done_callback(on_done_callback) + + if self._evt_closed.is_set(): + self._notify_all_futures_connection_close() + return future + + def _notify_all_futures_connection_close(self): + while self._task_queue: + try: + method_res_future = self._task_queue.pop()[3] + except KeyError: + break + else: + method_res_future.set_exception( + pika_exceptions.ConnectionClosed() + ) + + while self._pending_connection_futures: + try: + pending_connection_future = ( + self._pending_connection_futures.pop() + ) + except KeyError: + break + else: + pending_connection_future.set_exception( + pika_exceptions.ConnectionClosed() + ) + + def _on_connection_close(self, conn, reply_code, reply_text): + self._evt_closed.set() + self._notify_all_futures_connection_close() + if self._interrupt_pipeout: + os.close(self._interrupt_pipeout) + os.close(self._interrupt_pipein) + + def add_on_close_callback(self, callback): + return self._execute_task(self._impl.add_on_close_callback, callback) + + def _do_process_io(self): + while self._task_queue: + func, args, kwargs, future = self._task_queue.pop() + try: + res = func(*args, **kwargs) + except BaseException as e: + LOG.exception(e) + future.set_exception(e) + else: + future.set_result(res) + + self._impl.ioloop.poll() + self._impl.ioloop.process_timeouts() + + def _process_io(self): + self._thread_id = current_thread() + while not self._evt_closed.is_set(): + try: + self._do_process_io() + except BaseException: + LOG.exception("Error during processing connection's IO") + + def close(self, *args, **kwargs): + res = self._execute_task(self._impl.close, *args, **kwargs) + + self._evt_closed.wait() + self._thread.join() + return res + + def channel(self, channel_number=None): + channel_opened_future = self._register_pending_future() + + impl_channel = self._execute_task( + self._impl.channel, + on_open_callback=channel_opened_future.set_result, + channel_number=channel_number + ) + + # Create our proxy channel + channel = ThreadSafePikaChannel(impl_channel, self) + + # Link implementation channel with our proxy channel + impl_channel._set_cookie(channel) + + channel_opened_future.result() + return channel + + def add_timeout(self, timeout, callback): + return self._execute_task(self._impl.add_timeout, timeout, callback) + + def remove_timeout(self, timeout_id): + return self._execute_task(self._impl.remove_timeout, timeout_id) + + @property + def is_closed(self): + return self._impl.is_closed + + @property + def is_closing(self): + return self._impl.is_closing + + @property + def is_open(self): + return self._impl.is_open + + +class ThreadSafePikaChannel(object): # pylint: disable=R0904,R0902 + + def __init__(self, channel_impl, connection): + self._impl = channel_impl + self._connection = connection + + self._delivery_confirmation = False + + self._message_returned = False + self._current_future = None + + self._evt_closed = threading.Event() + + self.add_on_close_callback(self._on_channel_close) + + def _execute_task(self, func, *args, **kwargs): + return self._connection._execute_task(func, *args, **kwargs) + + def _on_channel_close(self, channel, reply_code, reply_text): + self._evt_closed.set() + + if self._current_future: + self._current_future.set_exception( + pika_exceptions.ChannelClosed(reply_code, reply_text)) + + def _on_message_confirmation(self, frame): + self._current_future.set_result(frame) + + def add_on_close_callback(self, callback): + self._execute_task(self._impl.add_on_close_callback, callback) + + def add_on_cancel_callback(self, callback): + self._execute_task(self._impl.add_on_cancel_callback, callback) + + def __int__(self): + return self.channel_number + + @property + def channel_number(self): + return self._impl.channel_number + + @property + def is_closed(self): + return self._impl.is_closed + + @property + def is_closing(self): + return self._impl.is_closing + + @property + def is_open(self): + return self._impl.is_open + + def close(self, reply_code=0, reply_text="Normal Shutdown"): + self._impl.close(reply_code=reply_code, reply_text=reply_text) + self._evt_closed.wait() + + def flow(self, active): + self._current_future = futurist.Future() + self._execute_task( + self._impl.flow, callback=self._current_future.set_result, + active=active + ) + return self._current_future.result() + + def basic_consume(self, # pylint: disable=R0913 + consumer_callback, + queue, + no_ack=False, + exclusive=False, + consumer_tag=None, + arguments=None): + self._current_future = futurist.Future() + self._execute_task( + self._impl.add_callback, self._current_future.set_result, + replies=[pika_spec.Basic.ConsumeOk], one_shot=True + ) + + self._impl.add_callback(self._current_future.set_result, + replies=[pika_spec.Basic.ConsumeOk], + one_shot=True) + tag = self._execute_task( + self._impl.basic_consume, + consumer_callback=consumer_callback, + queue=queue, + no_ack=no_ack, + exclusive=exclusive, + consumer_tag=consumer_tag, + arguments=arguments + ) + + self._current_future.result() + return tag + + def basic_cancel(self, consumer_tag): + self._current_future = futurist.Future() + self._execute_task( + self._impl.basic_cancel, + callback=self._current_future.set_result, + consumer_tag=consumer_tag, + nowait=False) + self._current_future.result() + + def basic_ack(self, delivery_tag=0, multiple=False): + return self._execute_task( + self._impl.basic_ack, delivery_tag=delivery_tag, multiple=multiple) + + def basic_nack(self, delivery_tag=None, multiple=False, requeue=True): + return self._execute_task( + self._impl.basic_nack, delivery_tag=delivery_tag, + multiple=multiple, requeue=requeue + ) + + def publish(self, exchange, routing_key, body, # pylint: disable=R0913 + properties=None, mandatory=False, immediate=False): + + if self._delivery_confirmation: + # In publisher-acknowledgments mode + self._message_returned = False + self._current_future = futurist.Future() + + self._execute_task(self._impl.basic_publish, + exchange=exchange, + routing_key=routing_key, + body=body, + properties=properties, + mandatory=mandatory, + immediate=immediate) + + conf_method = self._current_future.result().method + + if isinstance(conf_method, pika_spec.Basic.Nack): + raise pika_exceptions.NackError((None,)) + else: + assert isinstance(conf_method, pika_spec.Basic.Ack), ( + conf_method) + + if self._message_returned: + raise pika_exceptions.UnroutableError((None,)) + else: + # In non-publisher-acknowledgments mode + self._execute_task(self._impl.basic_publish, + exchange=exchange, + routing_key=routing_key, + body=body, + properties=properties, + mandatory=mandatory, + immediate=immediate) + + def basic_qos(self, prefetch_size=0, prefetch_count=0, all_channels=False): + self._current_future = futurist.Future() + self._execute_task(self._impl.basic_qos, + callback=self._current_future.set_result, + prefetch_size=prefetch_size, + prefetch_count=prefetch_count, + all_channels=all_channels) + self._current_future.result() + + def basic_recover(self, requeue=False): + self._current_future = futurist.Future() + self._execute_task( + self._impl.basic_recover, + callback=lambda: self._current_future.set_result(None), + requeue=requeue + ) + self._current_future.result() + + def basic_reject(self, delivery_tag=None, requeue=True): + self._execute_task(self._impl.basic_reject, + delivery_tag=delivery_tag, + requeue=requeue) + + def _on_message_returned(self, *args, **kwargs): + self._message_returned = True + + def confirm_delivery(self): + self._current_future = futurist.Future() + self._execute_task(self._impl.add_callback, + callback=self._current_future.set_result, + replies=[pika_spec.Confirm.SelectOk], + one_shot=True) + self._execute_task(self._impl.confirm_delivery, + callback=self._on_message_confirmation, + nowait=False) + self._current_future.result() + + self._delivery_confirmation = True + self._execute_task(self._impl.add_on_return_callback, + self._on_message_returned) + + def exchange_declare(self, exchange=None, # pylint: disable=R0913 + exchange_type='direct', passive=False, durable=False, + auto_delete=False, internal=False, + arguments=None, **kwargs): + self._current_future = futurist.Future() + self._execute_task(self._impl.exchange_declare, + callback=self._current_future.set_result, + exchange=exchange, + exchange_type=exchange_type, + passive=passive, + durable=durable, + auto_delete=auto_delete, + internal=internal, + nowait=False, + arguments=arguments, + type=kwargs["type"] if kwargs else None) + + return self._current_future.result() + + def exchange_delete(self, exchange=None, if_unused=False): + self._current_future = futurist.Future() + self._execute_task(self._impl.exchange_delete, + callback=self._current_future.set_result, + exchange=exchange, + if_unused=if_unused, + nowait=False) + + return self._current_future.result() + + def exchange_bind(self, destination=None, source=None, routing_key='', + arguments=None): + self._current_future = futurist.Future() + self._execute_task(self._impl.exchange_bind, + callback=self._current_future.set_result, + destination=destination, + source=source, + routing_key=routing_key, + nowait=False, + arguments=arguments) + + return self._current_future.result() + + def exchange_unbind(self, destination=None, source=None, routing_key='', + arguments=None): + self._current_future = futurist.Future() + self._execute_task(self._impl.exchange_unbind, + callback=self._current_future.set_result, + destination=destination, + source=source, + routing_key=routing_key, + nowait=False, + arguments=arguments) + + return self._current_future.result() + + def queue_declare(self, queue='', passive=False, durable=False, + exclusive=False, auto_delete=False, + arguments=None): + self._current_future = futurist.Future() + self._execute_task(self._impl.queue_declare, + callback=self._current_future.set_result, + queue=queue, + passive=passive, + durable=durable, + exclusive=exclusive, + auto_delete=auto_delete, + nowait=False, + arguments=arguments) + + return self._current_future.result() + + def queue_delete(self, queue='', if_unused=False, if_empty=False): + self._current_future = futurist.Future() + self._execute_task(self._impl.queue_delete, + callback=self._current_future.set_result, + queue=queue, + if_unused=if_unused, + if_empty=if_empty, + nowait=False) + + return self._current_future.result() + + def queue_purge(self, queue=''): + self._current_future = futurist.Future() + self._execute_task(self._impl.queue_purge, + callback=self._current_future.set_result, + queue=queue, + nowait=False) + return self._current_future.result() + + def queue_bind(self, queue, exchange, routing_key=None, + arguments=None): + self._current_future = futurist.Future() + self._execute_task(self._impl.queue_bind, + callback=self._current_future.set_result, + queue=queue, + exchange=exchange, + routing_key=routing_key, + nowait=False, + arguments=arguments) + return self._current_future.result() + + def queue_unbind(self, queue='', exchange=None, routing_key=None, + arguments=None): + self._current_future = futurist.Future() + self._execute_task(self._impl.queue_unbind, + callback=self._current_future.set_result, + queue=queue, + exchange=exchange, + routing_key=routing_key, + arguments=arguments) + return self._current_future.result() diff --git a/oslo_messaging/_drivers/pika_driver/pika_engine.py b/oslo_messaging/_drivers/pika_driver/pika_engine.py index 3b762c327..65e87fc99 100644 --- a/oslo_messaging/_drivers/pika_driver/pika_engine.py +++ b/oslo_messaging/_drivers/pika_driver/pika_engine.py @@ -11,7 +11,7 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. - +import os import random import socket import threading @@ -19,38 +19,18 @@ import time from oslo_log import log as logging import pika -from pika.adapters import select_connection from pika import credentials as pika_credentials import pika_pool - import uuid from oslo_messaging._drivers.pika_driver import pika_commons as pika_drv_cmns +from oslo_messaging._drivers.pika_driver import pika_connection from oslo_messaging._drivers.pika_driver import pika_exceptions as pika_drv_exc LOG = logging.getLogger(__name__) - -def _create_select_poller_connection_impl( - parameters, on_open_callback, on_open_error_callback, - on_close_callback, stop_ioloop_on_close): - """Used for disabling autochoise of poller ('select', 'poll', 'epool', etc) - inside default 'SelectConnection.__init__(...)' logic. It is necessary to - force 'select' poller usage if eventlet is monkeypatched because eventlet - patches only 'select' system call - - Method signature is copied form 'SelectConnection.__init__(...)', because - it is used as replacement of 'SelectConnection' class to create instances - """ - return select_connection.SelectConnection( - parameters=parameters, - on_open_callback=on_open_callback, - on_open_error_callback=on_open_error_callback, - on_close_callback=on_close_callback, - stop_ioloop_on_close=stop_ioloop_on_close, - custom_ioloop=select_connection.SelectPoller() - ) +_PID = None class _PooledConnectionWithConfirmations(pika_pool.Connection): @@ -84,10 +64,6 @@ class PikaEngine(object): allowed_remote_exmods=None): self.conf = conf - self._force_select_poller_use = ( - pika_drv_cmns.is_eventlet_monkey_patched('select') - ) - # processing rpc options self.default_rpc_exchange = ( conf.oslo_messaging_pika.default_rpc_exchange @@ -168,7 +144,7 @@ class PikaEngine(object): ) # initializing connection parameters for configured RabbitMQ hosts - common_pika_params = { + self._common_pika_params = { 'virtual_host': url.virtual_host, 'channel_max': self.conf.oslo_messaging_pika.channel_max, 'frame_max': self.conf.oslo_messaging_pika.frame_max, @@ -177,31 +153,18 @@ class PikaEngine(object): 'socket_timeout': self.conf.oslo_messaging_pika.socket_timeout, } - self._connection_lock = threading.Lock() + self._connection_lock = threading.RLock() + self._pid = None - self._connection_host_param_list = [] - self._connection_host_status_list = [] + self._connection_host_status = {} if not url.hosts: raise ValueError("You should provide at least one RabbitMQ host") - for transport_host in url.hosts: - pika_params = common_pika_params.copy() - pika_params.update( - host=transport_host.hostname, - port=transport_host.port, - credentials=pika_credentials.PlainCredentials( - transport_host.username, transport_host.password - ), - ) - self._connection_host_param_list.append(pika_params) - self._connection_host_status_list.append({ - self.HOST_CONNECTION_LAST_TRY_TIME: 0, - self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME: 0 - }) + self._host_list = url.hosts - self._next_connection_host_num = random.randint( - 0, len(self._connection_host_param_list) - 1 + self._cur_connection_host_num = random.randint( + 0, len(self._host_list) - 1 ) # initializing 2 connection pools: 1st for connections without @@ -228,49 +191,37 @@ class PikaEngine(object): _PooledConnectionWithConfirmations ) - def _next_connection_num(self): - """Used for creating connections to different RabbitMQ nodes in - round robin order - - :return: next host number to create connection to - """ - with self._connection_lock: - cur_num = self._next_connection_host_num - self._next_connection_host_num += 1 - self._next_connection_host_num %= len( - self._connection_host_param_list - ) - return cur_num - def create_connection(self, for_listening=False): """Create and return connection to any available host. :return: created connection :raise: ConnectionException if all hosts are not reachable """ - host_count = len(self._connection_host_param_list) - connection_attempts = host_count - pika_next_connection_num = self._next_connection_num() + with self._connection_lock: + self._init_if_needed() - while connection_attempts > 0: - try: - return self.create_host_connection( - pika_next_connection_num, for_listening - ) - except pika_pool.Connection.connectivity_errors as e: - LOG.warning("Can't establish connection to host. %s", e) - except pika_drv_exc.HostConnectionNotAllowedException as e: - LOG.warning("Connection to host is not allowed. %s", e) + host_count = len(self._host_list) + connection_attempts = host_count - connection_attempts -= 1 - pika_next_connection_num += 1 - pika_next_connection_num %= host_count + while connection_attempts > 0: + self._cur_connection_host_num += 1 + self._cur_connection_host_num %= host_count + try: + return self.create_host_connection( + self._cur_connection_host_num, for_listening + ) + except pika_pool.Connection.connectivity_errors as e: + LOG.warning("Can't establish connection to host. %s", e) + except pika_drv_exc.HostConnectionNotAllowedException as e: + LOG.warning("Connection to host is not allowed. %s", e) - raise pika_drv_exc.EstablishConnectionException( - "Can not establish connection to any configured RabbitMQ host: " + - str(self._connection_host_param_list) - ) + connection_attempts -= 1 + + raise pika_drv_exc.EstablishConnectionException( + "Can not establish connection to any configured RabbitMQ " + "host: " + str(self._host_list) + ) def _set_tcp_user_timeout(self, s): if not self._tcp_user_timeout: @@ -285,28 +236,64 @@ class PikaEngine(object): "Whoops, this kernel doesn't seem to support TCP_USER_TIMEOUT." ) + def _init_if_needed(self): + global _PID + + cur_pid = os.getpid() + + if _PID != cur_pid: + if _PID: + LOG.warning("New pid is detected. Old: %s, new: %s. " + "Cleaning up...", _PID, cur_pid) + # Note(dukhlov): we need to force select poller usage in case when + # 'thread' module is monkey patched becase current eventlet + # implementation does not support patching of poll/epoll/kqueue + if pika_drv_cmns.is_eventlet_monkey_patched("thread"): + from pika.adapters import select_connection + select_connection.SELECT_TYPE = "select" + + _PID = cur_pid + def create_host_connection(self, host_index, for_listening=False): """Create new connection to host #host_index + :param host_index: Integer, number of host for connection establishing :param for_listening: Boolean, creates connection for listening - (enable heartbeats) if True + if True :return: New connection """ - - connection_params = pika.ConnectionParameters( - heartbeat_interval=( - self._heartbeat_interval if for_listening else None - ), - **self._connection_host_param_list[host_index] - ) - with self._connection_lock: + self._init_if_needed() + + host = self._host_list[host_index] + + connection_params = pika.ConnectionParameters( + host=host.hostname, + port=host.port, + credentials=pika_credentials.PlainCredentials( + host.username, host.password + ), + heartbeat_interval=( + self._heartbeat_interval if for_listening else None + ), + **self._common_pika_params + ) + cur_time = time.time() - last_success_time = self._connection_host_status_list[host_index][ + host_connection_status = self._connection_host_status.get(host) + + if host_connection_status is None: + host_connection_status = { + self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME: 0, + self.HOST_CONNECTION_LAST_TRY_TIME: 0 + } + self._connection_host_status[host] = host_connection_status + + last_success_time = host_connection_status[ self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME ] - last_time = self._connection_host_status_list[host_index][ + last_time = host_connection_status[ self.HOST_CONNECTION_LAST_TRY_TIME ] @@ -322,27 +309,25 @@ class PikaEngine(object): ) try: - connection = pika.BlockingConnection( - parameters=connection_params, - _impl_class=(_create_select_poller_connection_impl - if self._force_select_poller_use else None) - ) - - # It is needed for pika_pool library which expects that - # connections has params attribute defined in BaseConnection - # but BlockingConnection is not derived from BaseConnection - # and doesn't have it - connection.params = connection_params + if for_listening: + connection = pika_connection.ThreadSafePikaConnection( + params=connection_params + ) + else: + connection = pika.BlockingConnection( + parameters=connection_params + ) + connection.params = connection_params self._set_tcp_user_timeout(connection._impl.socket) - self._connection_host_status_list[host_index][ + self._connection_host_status[host][ self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME ] = cur_time return connection finally: - self._connection_host_status_list[host_index][ + self._connection_host_status[host][ self.HOST_CONNECTION_LAST_TRY_TIME ] = cur_time diff --git a/oslo_messaging/_drivers/pika_driver/pika_listener.py b/oslo_messaging/_drivers/pika_driver/pika_listener.py index 1739c7932..1942fcc1c 100644 --- a/oslo_messaging/_drivers/pika_driver/pika_listener.py +++ b/oslo_messaging/_drivers/pika_driver/pika_listener.py @@ -29,6 +29,7 @@ class RpcReplyPikaListener(object): """ def __init__(self, pika_engine): + super(RpcReplyPikaListener, self).__init__() self._pika_engine = pika_engine # preparing poller for listening replies @@ -39,7 +40,6 @@ class RpcReplyPikaListener(object): self._reply_consumer_initialized = False self._reply_consumer_initialization_lock = threading.Lock() - self._poller_thread = None self._shutdown = False def get_reply_qname(self): @@ -66,51 +66,31 @@ class RpcReplyPikaListener(object): # initialize reply poller if needed if self._reply_poller is None: self._reply_poller = pika_drv_poller.RpcReplyPikaPoller( - pika_engine=self._pika_engine, - exchange=self._pika_engine.rpc_reply_exchange, - queue=self._reply_queue, - prefetch_count=( - self._pika_engine.rpc_reply_listener_prefetch_count - ) + self._pika_engine, self._pika_engine.rpc_reply_exchange, + self._reply_queue, 1, None, + self._pika_engine.rpc_reply_listener_prefetch_count ) - self._reply_poller.start() - - # start reply poller job thread if needed - if self._poller_thread is None: - self._poller_thread = threading.Thread(target=self._poller) - self._poller_thread.daemon = True - - if not self._poller_thread.is_alive(): - self._poller_thread.start() - + self._reply_poller.start(self._on_incoming) self._reply_consumer_initialized = True return self._reply_queue - def _poller(self): + def _on_incoming(self, incoming): """Reply polling job. Poll replies in infinite loop and notify registered features """ - while True: + for message in incoming: try: - messages = self._reply_poller.poll() - if not messages and self._shutdown: - break - - for message in messages: - try: - message.acknowledge() - future = self._reply_waiting_futures.pop( - message.msg_id, None - ) - if future is not None: - future.set_result(message) - except Exception: - LOG.exception("Unexpected exception during processing" - "reply message") - except BaseException: - LOG.exception("Unexpected exception during reply polling") + message.acknowledge() + future = self._reply_waiting_futures.pop( + message.msg_id, None + ) + if future is not None: + future.set_result(message) + except Exception: + LOG.exception("Unexpected exception during processing" + "reply message") def register_reply_waiter(self, msg_id): """Register reply waiter. Should be called before message sending to @@ -140,9 +120,4 @@ class RpcReplyPikaListener(object): self._reply_poller.cleanup() self._reply_poller = None - if self._poller_thread: - if self._poller_thread.is_alive(): - self._poller_thread.join() - self._poller_thread = None - self._reply_queue = None diff --git a/oslo_messaging/_drivers/pika_driver/pika_poller.py b/oslo_messaging/_drivers/pika_driver/pika_poller.py index 35bf41191..d34ddfe77 100644 --- a/oslo_messaging/_drivers/pika_driver/pika_poller.py +++ b/oslo_messaging/_drivers/pika_driver/pika_poller.py @@ -13,11 +13,9 @@ # under the License. import threading -import time from oslo_log import log as logging -from oslo_utils import timeutils -import six +from oslo_service import loopingcall from oslo_messaging._drivers import base from oslo_messaging._drivers.pika_driver import pika_commons as pika_drv_cmns @@ -27,45 +25,188 @@ from oslo_messaging._drivers.pika_driver import pika_message as pika_drv_msg LOG = logging.getLogger(__name__) -class PikaPoller(base.PollStyleListener): +class PikaPoller(base.Listener): """Provides user friendly functionality for RabbitMQ message consuming, handles low level connectivity problems and restore connection if some connectivity related problem detected """ - def __init__(self, pika_engine, prefetch_count, incoming_message_class): + def __init__(self, pika_engine, batch_size, batch_timeout, prefetch_count, + incoming_message_class): """Initialize required fields :param pika_engine: PikaEngine, shared object with configuration and shared driver functionality + :param batch_size: desired number of messages passed to + single on_incoming_callback call + :param batch_timeout: defines how long should we wait for batch_size + messages if we already have some messages waiting for processing :param prefetch_count: Integer, maximum count of unacknowledged messages which RabbitMQ broker sends to this consumer :param incoming_message_class: PikaIncomingMessage, wrapper for consumed RabbitMQ message """ - super(PikaPoller, self).__init__(prefetch_count) + super(PikaPoller, self).__init__(batch_size, batch_timeout, + prefetch_count) self._pika_engine = pika_engine self._incoming_message_class = incoming_message_class self._connection = None self._channel = None - self._lock = threading.Lock() + self._recover_loopingcall = None + self._lock = threading.RLock() + + self._cur_batch_buffer = None + self._cur_batch_timeout_id = None self._started = False + self._closing_connection_by_poller = False self._queues_to_consume = None - self._message_queue = [] + def _on_connection_close(self, connection, reply_code, reply_text): + self._deliver_cur_batch() + if self._closing_connection_by_poller: + return + with self._lock: + self._connection = None + self._start_recover_consuming_task() - def _reconnect(self): + def _on_channel_close(self, channel, reply_code, reply_text): + if self._cur_batch_buffer: + self._cur_batch_buffer = [ + message for message in self._cur_batch_buffer + if not message.need_ack() + ] + if self._closing_connection_by_poller: + return + with self._lock: + self._channel = None + self._start_recover_consuming_task() + + def _on_consumer_cancel(self, method_frame): + with self._lock: + if self._queues_to_consume: + consumer_tag = method_frame.method.consumer_tag + for queue_info in self._queues_to_consume: + if queue_info["consumer_tag"] == consumer_tag: + queue_info["consumer_tag"] = None + + self._start_recover_consuming_task() + + def _on_message_no_ack_callback(self, unused, method, properties, body): + """Is called by Pika when message was received from queue listened with + no_ack=True mode + """ + incoming_message = self._incoming_message_class( + self._pika_engine, None, method, properties, body + ) + self._on_incoming_message(incoming_message) + + def _on_message_with_ack_callback(self, unused, method, properties, body): + """Is called by Pika when message was received from queue listened with + no_ack=False mode + """ + incoming_message = self._incoming_message_class( + self._pika_engine, self._channel, method, properties, body + ) + self._on_incoming_message(incoming_message) + + def _deliver_cur_batch(self): + if self._cur_batch_timeout_id is not None: + self._connection.remove_timeout(self._cur_batch_timeout_id) + self._cur_batch_timeout_id = None + if self._cur_batch_buffer: + buf_to_send = self._cur_batch_buffer + self._cur_batch_buffer = None + try: + self.on_incoming_callback(buf_to_send) + except Exception: + LOG.exception("Unexpected exception during incoming delivery") + + def _on_incoming_message(self, incoming_message): + if self._cur_batch_buffer is None: + self._cur_batch_buffer = [incoming_message] + else: + self._cur_batch_buffer.append(incoming_message) + + if len(self._cur_batch_buffer) >= self.batch_size: + self._deliver_cur_batch() + return + + if self._cur_batch_timeout_id is None: + self._cur_batch_timeout_id = self._connection.add_timeout( + self.batch_timeout, self._deliver_cur_batch) + + def _start_recover_consuming_task(self): + """Start async job for checking connection to the broker.""" + if self._recover_loopingcall is None and self._started: + self._recover_loopingcall = ( + loopingcall.DynamicLoopingCall( + self._try_recover_consuming + ) + ) + LOG.info("Starting recover consuming job for listener: %s", self) + self._recover_loopingcall.start() + + def _try_recover_consuming(self): + with self._lock: + try: + if self._started: + self._start_or_recover_consuming() + except pika_drv_exc.EstablishConnectionException as e: + LOG.warning( + "Problem during establishing connection for pika " + "poller %s", e, exc_info=True + ) + return self._pika_engine.host_connection_reconnect_delay + except pika_drv_exc.ConnectionException as e: + LOG.warning( + "Connectivity exception during starting/recovering pika " + "poller %s", e, exc_info=True + ) + except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as e: + LOG.warning( + "Connectivity exception during starting/recovering pika " + "poller %s", e, exc_info=True + ) + except BaseException: + # NOTE (dukhlov): I preffer to use here BaseException because + # if this method raise such exception LoopingCall stops + # execution Probably it should never happen and Exception + # should be enough but in case of programmer mistake it could + # be and it is potentially hard to catch problem if we will + # stop background task. It is better when it continue to work + # and write a lot of LOG with this error + LOG.exception("Unexpected exception during " + "starting/recovering pika poller") + else: + self._recover_loopingcall = None + LOG.info("Recover consuming job was finished for listener: %s", + self) + raise loopingcall.LoopingCallDone(True) + return 0 + + def _start_or_recover_consuming(self): """Performs reconnection to the broker. It is unsafe method for internal use only """ - self._connection = self._pika_engine.create_connection( - for_listening=True - ) - self._channel = self._connection.channel() - self._channel.basic_qos(prefetch_count=self.prefetch_size) + if self._connection is None or not self._connection.is_open: + self._connection = self._pika_engine.create_connection( + for_listening=True + ) + self._connection.add_on_close_callback(self._on_connection_close) + self._channel = None + + if self._channel is None or not self._channel.is_open: + if self._queues_to_consume: + for queue_info in self._queues_to_consume: + queue_info["consumer_tag"] = None + + self._channel = self._connection.channel() + self._channel.add_on_close_callback(self._on_channel_close) + self._channel.add_on_cancel_callback(self._on_consumer_cancel) + self._channel.basic_qos(prefetch_count=self.prefetch_size) if self._queues_to_consume is None: self._queues_to_consume = self._declare_queue_binding() @@ -92,6 +233,8 @@ class PikaPoller(base.PollStyleListener): try: for queue_info in self._queues_to_consume: + if queue_info["consumer_tag"] is not None: + continue no_ack = queue_info["no_ack"] on_message_callback = ( @@ -120,168 +263,95 @@ class PikaPoller(base.PollStyleListener): self._channel.basic_cancel(consumer_tag) queue_info["consumer_tag"] = None - def _on_message_no_ack_callback(self, unused, method, properties, body): - """Is called by Pika when message was received from queue listened with - no_ack=True mode - """ - self._message_queue.append( - self._incoming_message_class( - self._pika_engine, None, method, properties, body - ) - ) - - def _on_message_with_ack_callback(self, unused, method, properties, body): - """Is called by Pika when message was received from queue listened with - no_ack=False mode - """ - self._message_queue.append( - self._incoming_message_class( - self._pika_engine, self._channel, method, properties, body - ) - ) - - def _cleanup(self): - """Cleanup allocated resources (channel, connection, etc). It is unsafe - method for internal use only - """ - if self._connection: - try: - self._connection.close() - except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS: - # expected errors - pass - except Exception: - LOG.exception("Unexpected error during closing connection") - finally: - self._channel = None - self._connection = None - - for i in six.moves.range(len(self._message_queue) - 1, -1, -1): - message = self._message_queue[i] - if message.need_ack(): - del self._message_queue[i] - - @base.batch_poll_helper - def poll(self, timeout=None): - """Main method of this class - consumes message from RabbitMQ - - :param: timeout: float, seconds, timeout for waiting new incoming - message, None means wait forever - :return: list of PikaIncomingMessage, RabbitMQ messages - """ - - with timeutils.StopWatch(timeout) as stop_watch: - while True: - with self._lock: - if self._message_queue: - return self._message_queue.pop(0) - - if stop_watch.expired(): - return None - - try: - if self._started: - if self._channel is None: - self._reconnect() - # we need some time_limit here, not too small to - # avoid a lot of not needed iterations but not too - # large to release lock time to time and give a - # chance to perform another method waiting this - # lock - self._connection.process_data_events( - time_limit=0.25 - ) - else: - # consumer is stopped so we don't expect new - # messages, just process already sent events - if self._channel is not None: - self._connection.process_data_events( - time_limit=0 - ) - - # and return if we don't see new messages - if not self._message_queue: - return None - except pika_drv_exc.EstablishConnectionException as e: - LOG.warning( - "Problem during establishing connection for pika " - "poller %s", e, exc_info=True - ) - time.sleep( - self._pika_engine.host_connection_reconnect_delay - ) - except pika_drv_exc.ConnectionException: - self._cleanup() - raise - except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS: - self._cleanup() - raise - - def start(self): + def start(self, on_incoming_callback): """Starts poller. Should be called before polling to allow message consuming + + :param on_incoming_callback: callback function to be executed when + listener received messages. Messages should be processed and + acked/nacked by callback """ + super(PikaPoller, self).start(on_incoming_callback) + with self._lock: if self._started: return - + connected = False try: - self._reconnect() + self._start_or_recover_consuming() except pika_drv_exc.EstablishConnectionException as exc: LOG.warning( "Can not establish connection during pika poller's " - "start(). Connecting is required during first poll() " - "call. %s", exc, exc_info=True + "start(). %s", exc, exc_info=True ) except pika_drv_exc.ConnectionException as exc: - self._cleanup() LOG.warning( - "Connectivity problem during pika poller's start(). " - "Reconnecting required during first poll() call. %s", + "Connectivity problem during pika poller's start(). %s", exc, exc_info=True ) except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as exc: - self._cleanup() LOG.warning( - "Connectivity problem during pika poller's start(). " - "Reconnecting required during first poll() call. %s", + "Connectivity problem during pika poller's start(). %s", exc, exc_info=True ) + else: + connected = True + self._started = True + if not connected: + self._start_recover_consuming_task() def stop(self): """Stops poller. Should be called when polling is not needed anymore to stop new message consuming. After that it is necessary to poll already prefetched messages """ + super(PikaPoller, self).stop() + with self._lock: if not self._started: return - if self._queues_to_consume and self._channel: + if self._recover_loopingcall is not None: + self._recover_loopingcall.stop() + self._recover_loopingcall = None + + if (self._queues_to_consume and self._channel and + self._channel.is_open): try: self._stop_consuming() except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as exc: - self._cleanup() LOG.warning( "Connectivity problem detected during consumer " "cancellation. %s", exc, exc_info=True ) + self._deliver_cur_batch() self._started = False def cleanup(self): - """Safe version of _cleanup. Cleans up allocated resources (channel, - connection, etc). - """ + """Cleanup allocated resources (channel, connection, etc).""" with self._lock: - self._cleanup() + if self._connection and self._connection.is_open: + try: + self._closing_connection_by_poller = True + self._connection.close() + self._closing_connection_by_poller = False + except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS: + # expected errors + pass + except Exception: + LOG.exception("Unexpected error during closing connection") + finally: + self._channel = None + self._connection = None class RpcServicePikaPoller(PikaPoller): """PikaPoller implementation for polling RPC messages. Overrides base functionality according to RPC specific """ - def __init__(self, pika_engine, target, prefetch_count): + def __init__(self, pika_engine, target, batch_size, batch_timeout, + prefetch_count): """Adds target parameter for declaring RPC specific exchanges and queues @@ -289,14 +359,18 @@ class RpcServicePikaPoller(PikaPoller): shared driver functionality :param target: Target, oslo.messaging Target object which defines RPC endpoint + :param batch_size: desired number of messages passed to + single on_incoming_callback call + :param batch_timeout: defines how long should we wait for batch_size + messages if we already have some messages waiting for processing :param prefetch_count: Integer, maximum count of unacknowledged messages which RabbitMQ broker sends to this consumer """ self._target = target super(RpcServicePikaPoller, self).__init__( - pika_engine, prefetch_count=prefetch_count, - incoming_message_class=pika_drv_msg.RpcPikaIncomingMessage + pika_engine, batch_size, batch_timeout, prefetch_count, + pika_drv_msg.RpcPikaIncomingMessage ) def _declare_queue_binding(self): @@ -363,7 +437,8 @@ class RpcReplyPikaPoller(PikaPoller): """PikaPoller implementation for polling RPC reply messages. Overrides base functionality according to RPC reply specific """ - def __init__(self, pika_engine, exchange, queue, prefetch_count): + def __init__(self, pika_engine, exchange, queue, batch_size, batch_timeout, + prefetch_count): """Adds exchange and queue parameter for declaring exchange and queue used for RPC reply delivery @@ -371,6 +446,10 @@ class RpcReplyPikaPoller(PikaPoller): shared driver functionality :param exchange: String, exchange name used for RPC reply delivery :param queue: String, queue name used for RPC reply delivery + :param batch_size: desired number of messages passed to + single on_incoming_callback call + :param batch_timeout: defines how long should we wait for batch_size + messages if we already have some messages waiting for processing :param prefetch_count: Integer, maximum count of unacknowledged messages which RabbitMQ broker sends to this consumer """ @@ -378,8 +457,8 @@ class RpcReplyPikaPoller(PikaPoller): self._queue = queue super(RpcReplyPikaPoller, self).__init__( - pika_engine=pika_engine, prefetch_count=prefetch_count, - incoming_message_class=pika_drv_msg.RpcReplyPikaIncomingMessage + pika_engine, batch_size, batch_timeout, prefetch_count, + pika_drv_msg.RpcReplyPikaIncomingMessage ) def _declare_queue_binding(self): @@ -404,8 +483,8 @@ class NotificationPikaPoller(PikaPoller): """PikaPoller implementation for polling Notification messages. Overrides base functionality according to Notification specific """ - def __init__(self, pika_engine, targets_and_priorities, prefetch_count, - queue_name=None): + def __init__(self, pika_engine, targets_and_priorities, + batch_size, batch_timeout, prefetch_count, queue_name=None): """Adds targets_and_priorities and queue_name parameter for declaring exchanges and queues used for notification delivery @@ -413,6 +492,10 @@ class NotificationPikaPoller(PikaPoller): shared driver functionality :param targets_and_priorities: list of (target, priority), defines default queue names for corresponding notification types + :param batch_size: desired number of messages passed to + single on_incoming_callback call + :param batch_timeout: defines how long should we wait for batch_size + messages if we already have some messages waiting for processing :param prefetch_count: Integer, maximum count of unacknowledged messages which RabbitMQ broker sends to this consumer :param queue: String, alternative queue name used for this poller @@ -422,8 +505,8 @@ class NotificationPikaPoller(PikaPoller): self._queue_name = queue_name super(NotificationPikaPoller, self).__init__( - pika_engine, prefetch_count=prefetch_count, - incoming_message_class=pika_drv_msg.PikaIncomingMessage + pika_engine, batch_size, batch_timeout, prefetch_count, + pika_drv_msg.PikaIncomingMessage ) def _declare_queue_binding(self): @@ -442,7 +525,7 @@ class NotificationPikaPoller(PikaPoller): target.exchange or self._pika_engine.default_notification_exchange ), - queue = queue, + queue=queue, routing_key=routing_key, exchange_type='direct', queue_expiration=None, diff --git a/oslo_messaging/tests/drivers/pika/test_poller.py b/oslo_messaging/tests/drivers/pika/test_poller.py index 65492b4af..cfa69e720 100644 --- a/oslo_messaging/tests/drivers/pika/test_poller.py +++ b/oslo_messaging/tests/drivers/pika/test_poller.py @@ -13,9 +13,11 @@ # under the License. import socket +import threading import time import unittest +from concurrent import futures import mock from oslo_messaging._drivers.pika_driver import pika_poller @@ -32,26 +34,53 @@ class PikaPollerTestCase(unittest.TestCase): self._pika_engine.create_connection.return_value = ( self._poller_connection_mock ) + + self._executor = futures.ThreadPoolExecutor(1) + + def timer_task(timeout, callback): + time.sleep(timeout) + callback() + + self._poller_connection_mock.add_timeout.side_effect = ( + lambda *args: self._executor.submit(timer_task, *args) + ) + self._prefetch_count = 123 - def test_start_when_connection_unavailable(self): - incoming_message_class_mock = mock.Mock() + @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller." + "_declare_queue_binding") + def test_start(self, declare_queue_binding_mock): poller = pika_poller.PikaPoller( - self._pika_engine, self._prefetch_count, - incoming_message_class=incoming_message_class_mock + self._pika_engine, 1, None, self._prefetch_count, None + ) + + poller.start(None) + + self.assertTrue(self._pika_engine.create_connection.called) + self.assertTrue(self._poller_connection_mock.channel.called) + self.assertTrue(declare_queue_binding_mock.called) + + def test_start_when_connection_unavailable(self): + poller = pika_poller.PikaPoller( + self._pika_engine, 1, None, self._prefetch_count, None ) self._pika_engine.create_connection.side_effect = socket.timeout() # start() should not raise socket.timeout exception - poller.start() + poller.start(None) @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller." "_declare_queue_binding") - def test_poll(self, declare_queue_binding_mock): + def test_message_processing(self, declare_queue_binding_mock): + res = [] + + def on_incoming_callback(incoming): + res.append(incoming) + incoming_message_class_mock = mock.Mock() poller = pika_poller.PikaPoller( - self._pika_engine, self._prefetch_count, + self._pika_engine, 1, None, self._prefetch_count, incoming_message_class=incoming_message_class_mock ) unused = object() @@ -59,18 +88,14 @@ class PikaPollerTestCase(unittest.TestCase): properties = object() body = object() - self._poller_connection_mock.process_data_events.side_effect = ( - lambda time_limit: poller._on_message_with_ack_callback( - unused, method, properties, body - ) + poller.start(on_incoming_callback) + poller._on_message_with_ack_callback( + unused, method, properties, body ) - poller.start() - res = poller.poll() - self.assertEqual(len(res), 1) - self.assertEqual(res[0], incoming_message_class_mock.return_value) + self.assertEqual(res[0], [incoming_message_class_mock.return_value]) incoming_message_class_mock.assert_called_once_with( self._pika_engine, self._poller_channel_mock, method, properties, body @@ -83,92 +108,39 @@ class PikaPollerTestCase(unittest.TestCase): @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller." "_declare_queue_binding") - def test_poll_after_stop(self, declare_queue_binding_mock): + def test_message_processing_batch(self, declare_queue_binding_mock): incoming_message_class_mock = mock.Mock() - poller = pika_poller.PikaPoller( - self._pika_engine, self._prefetch_count, - incoming_message_class=incoming_message_class_mock - ) n = 10 params = [] - for i in range(n): - params.append((object(), object(), object(), object())) + res = [] - def f(time_limit): - if poller._started: - for k in range(n): - poller._on_message_no_ack_callback( - *params[k] - ) + def on_incoming_callback(incoming): + res.append(incoming) - self._poller_connection_mock.process_data_events.side_effect = f - - poller.start() - res = poller.poll(batch_size=1) - self.assertEqual(len(res), 1) - self.assertEqual(res[0], incoming_message_class_mock.return_value) - self.assertEqual( - incoming_message_class_mock.call_args_list[0][0], - (self._pika_engine, None) + params[0][1:] - ) - - poller.stop() - - res2 = poller.poll(batch_size=n) - - self.assertEqual(len(res2), n - 1) - self.assertEqual(incoming_message_class_mock.call_count, n) - - self.assertEqual( - self._poller_connection_mock.process_data_events.call_count, 2) - - for i in range(n - 1): - self.assertEqual(res2[i], incoming_message_class_mock.return_value) - self.assertEqual( - incoming_message_class_mock.call_args_list[i + 1][0], - (self._pika_engine, None) + params[i + 1][1:] - ) - - self.assertTrue(self._pika_engine.create_connection.called) - self.assertTrue(self._poller_connection_mock.channel.called) - - self.assertTrue(declare_queue_binding_mock.called) - - @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller." - "_declare_queue_binding") - def test_poll_batch(self, declare_queue_binding_mock): - incoming_message_class_mock = mock.Mock() poller = pika_poller.PikaPoller( - self._pika_engine, self._prefetch_count, + self._pika_engine, n, None, self._prefetch_count, incoming_message_class=incoming_message_class_mock ) - n = 10 - params = [] - for i in range(n): params.append((object(), object(), object(), object())) - index = [0] + poller.start(on_incoming_callback) - def f(time_limit): + for i in range(n): poller._on_message_with_ack_callback( - *params[index[0]] + *params[i] ) - index[0] += 1 - self._poller_connection_mock.process_data_events.side_effect = f - - poller.start() - res = poller.poll(batch_size=n) - - self.assertEqual(len(res), n) + self.assertEqual(len(res), 1) + self.assertEqual(len(res[0]), 10) self.assertEqual(incoming_message_class_mock.call_count, n) for i in range(n): - self.assertEqual(res[i], incoming_message_class_mock.return_value) + self.assertEqual(res[0][i], + incoming_message_class_mock.return_value) self.assertEqual( incoming_message_class_mock.call_args_list[i][0], (self._pika_engine, self._poller_channel_mock) + params[i][1:] @@ -181,42 +153,48 @@ class PikaPollerTestCase(unittest.TestCase): @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller." "_declare_queue_binding") - def test_poll_batch_with_timeout(self, declare_queue_binding_mock): + def test_message_processing_batch_with_timeout(self, + declare_queue_binding_mock): incoming_message_class_mock = mock.Mock() - poller = pika_poller.PikaPoller( - self._pika_engine, self._prefetch_count, - incoming_message_class=incoming_message_class_mock - ) n = 10 timeout = 1 - sleep_time = 0.2 + + res = [] + evt = threading.Event() + + def on_incoming_callback(incoming): + res.append(incoming) + evt.set() + + poller = pika_poller.PikaPoller( + self._pika_engine, n, timeout, self._prefetch_count, + incoming_message_class=incoming_message_class_mock + ) + params = [] success_count = 5 + poller.start(on_incoming_callback) + for i in range(n): params.append((object(), object(), object(), object())) - index = [0] - - def f(time_limit): - time.sleep(sleep_time) + for i in range(success_count): poller._on_message_with_ack_callback( - *params[index[0]] + *params[i] ) - index[0] += 1 - self._poller_connection_mock.process_data_events.side_effect = f + self.assertTrue(evt.wait(timeout * 2)) - poller.start() - res = poller.poll(batch_size=n, timeout=timeout) - - self.assertEqual(len(res), success_count) + self.assertEqual(len(res), 1) + self.assertEqual(len(res[0]), success_count) self.assertEqual(incoming_message_class_mock.call_count, success_count) for i in range(success_count): - self.assertEqual(res[i], incoming_message_class_mock.return_value) + self.assertEqual(res[0][i], + incoming_message_class_mock.return_value) self.assertEqual( incoming_message_class_mock.call_args_list[i][0], (self._pika_engine, self._poller_channel_mock) + params[i][1:] @@ -258,20 +236,11 @@ class RpcServicePikaPollerTestCase(unittest.TestCase): "RpcPikaIncomingMessage") def test_declare_rpc_queue_bindings(self, rpc_pika_incoming_message_mock): poller = pika_poller.RpcServicePikaPoller( - self._pika_engine, self._target, self._prefetch_count, - ) - self._poller_connection_mock.process_data_events.side_effect = ( - lambda time_limit: poller._on_message_with_ack_callback( - None, None, None, None - ) + self._pika_engine, self._target, 1, None, + self._prefetch_count ) - poller.start() - res = poller.poll() - - self.assertEqual(len(res), 1) - - self.assertEqual(res[0], rpc_pika_incoming_message_mock.return_value) + poller.start(None) self.assertTrue(self._pika_engine.create_connection.called) self.assertTrue(self._poller_connection_mock.channel.called) @@ -357,24 +326,14 @@ class RpcReplyServicePikaPollerTestCase(unittest.TestCase): self._pika_engine.rpc_queue_expiration = 12345 self._pika_engine.rpc_reply_retry_attempts = 3 - def test_start(self): - poller = pika_poller.RpcReplyPikaPoller( - self._pika_engine, self._exchange, self._queue, - self._prefetch_count, - ) - - poller.start() - - self.assertTrue(self._pika_engine.create_connection.called) - self.assertTrue(self._poller_connection_mock.channel.called) - def test_declare_rpc_reply_queue_binding(self): poller = pika_poller.RpcReplyPikaPoller( - self._pika_engine, self._exchange, self._queue, + self._pika_engine, self._exchange, self._queue, 1, None, self._prefetch_count, ) - poller.start() + poller.start(None) + poller.stop() declare_queue_binding_by_channel_mock = ( self._pika_engine.declare_queue_binding_by_channel @@ -419,26 +378,13 @@ class NotificationPikaPollerTestCase(unittest.TestCase): ) self._pika_engine.notification_persistence = object() - @mock.patch("oslo_messaging._drivers.pika_driver.pika_message." - "PikaIncomingMessage") - def test_declare_notification_queue_bindings_default_queue( - self, pika_incoming_message_mock): + def test_declare_notification_queue_bindings_default_queue(self): poller = pika_poller.NotificationPikaPoller( - self._pika_engine, self._target_and_priorities, + self._pika_engine, self._target_and_priorities, 1, None, self._prefetch_count, None ) - self._poller_connection_mock.process_data_events.side_effect = ( - lambda time_limit: poller._on_message_with_ack_callback( - None, None, None, None - ) - ) - poller.start() - res = poller.poll() - - self.assertEqual(len(res), 1) - - self.assertEqual(res[0], pika_incoming_message_mock.return_value) + poller.start(None) self.assertTrue(self._pika_engine.create_connection.called) self.assertTrue(self._poller_connection_mock.channel.called) @@ -481,26 +427,13 @@ class NotificationPikaPollerTestCase(unittest.TestCase): ) )) - @mock.patch("oslo_messaging._drivers.pika_driver.pika_message." - "PikaIncomingMessage") - def test_declare_notification_queue_bindings_custom_queue( - self, pika_incoming_message_mock): + def test_declare_notification_queue_bindings_custom_queue(self): poller = pika_poller.NotificationPikaPoller( - self._pika_engine, self._target_and_priorities, + self._pika_engine, self._target_and_priorities, 1, None, self._prefetch_count, "custom_queue_name" ) - self._poller_connection_mock.process_data_events.side_effect = ( - lambda time_limit: poller._on_message_with_ack_callback( - None, None, None, None - ) - ) - poller.start() - res = poller.poll() - - self.assertEqual(len(res), 1) - - self.assertEqual(res[0], pika_incoming_message_mock.return_value) + poller.start(None) self.assertTrue(self._pika_engine.create_connection.called) self.assertTrue(self._poller_connection_mock.channel.called)