Merge "Implements pika thread safe connection"

This commit is contained in:
Jenkins 2016-04-27 20:00:25 +00:00 committed by Gerrit Code Review
commit bedd400b6a
6 changed files with 919 additions and 454 deletions

View File

@ -335,26 +335,18 @@ class PikaDriver(base.BaseDriver):
) )
def listen(self, target, batch_size, batch_timeout): def listen(self, target, batch_size, batch_timeout):
listener = pika_drv_poller.RpcServicePikaPoller( return pika_drv_poller.RpcServicePikaPoller(
self._pika_engine, target, self._pika_engine, target, batch_size, batch_timeout,
prefetch_count=self._pika_engine.rpc_listener_prefetch_count self._pika_engine.rpc_listener_prefetch_count
) )
listener.start()
return base.PollStyleListenerAdapter(listener, batch_size,
batch_timeout)
def listen_for_notifications(self, targets_and_priorities, pool, def listen_for_notifications(self, targets_and_priorities, pool,
batch_size, batch_timeout): batch_size, batch_timeout):
listener = pika_drv_poller.NotificationPikaPoller( return pika_drv_poller.NotificationPikaPoller(
self._pika_engine, targets_and_priorities, self._pika_engine, targets_and_priorities, batch_size,
prefetch_count=( batch_timeout,
self._pika_engine.notification_listener_prefetch_count self._pika_engine.notification_listener_prefetch_count, pool
),
queue_name=pool
) )
listener.start()
return base.PollStyleListenerAdapter(listener, batch_size,
batch_timeout)
def cleanup(self): def cleanup(self):
self._reply_listener.cleanup() self._reply_listener.cleanup()

View File

@ -0,0 +1,497 @@
# Copyright 2016 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import collections
import logging
import os
import threading
import futurist
import pika
from pika.adapters import select_connection
from pika import exceptions as pika_exceptions
from pika import spec as pika_spec
from oslo_utils import eventletutils
current_thread = eventletutils.fetch_current_thread_functor()
LOG = logging.getLogger(__name__)
class ThreadSafePikaConnection(object):
def __init__(self, params=None):
self.params = params
self._connection_lock = threading.Lock()
self._evt_closed = threading.Event()
self._task_queue = collections.deque()
self._pending_connection_futures = set()
create_connection_future = self._register_pending_future()
def on_open_error(conn, err):
create_connection_future.set_exception(
pika_exceptions.AMQPConnectionError(err)
)
self._impl = pika.SelectConnection(
parameters=params,
on_open_callback=create_connection_future.set_result,
on_open_error_callback=on_open_error,
on_close_callback=self._on_connection_close,
stop_ioloop_on_close=False,
)
self._interrupt_pipein, self._interrupt_pipeout = os.pipe()
self._impl.ioloop.add_handler(self._interrupt_pipein,
self._impl.ioloop.read_interrupt,
select_connection.READ)
self._thread = threading.Thread(target=self._process_io)
self._thread.daemon = True
self._thread_id = None
self._thread.start()
create_connection_future.result()
def _execute_task(self, func, *args, **kwargs):
if current_thread() == self._thread_id:
return func(*args, **kwargs)
future = futurist.Future()
self._task_queue.append((func, args, kwargs, future))
if self._evt_closed.is_set():
self._notify_all_futures_connection_close()
elif self._interrupt_pipeout is not None:
os.write(self._interrupt_pipeout, b'X')
return future.result()
def _register_pending_future(self):
future = futurist.Future()
self._pending_connection_futures.add(future)
def on_done_callback(fut):
try:
self._pending_connection_futures.remove(fut)
except KeyError:
pass
future.add_done_callback(on_done_callback)
if self._evt_closed.is_set():
self._notify_all_futures_connection_close()
return future
def _notify_all_futures_connection_close(self):
while self._task_queue:
try:
method_res_future = self._task_queue.pop()[3]
except KeyError:
break
else:
method_res_future.set_exception(
pika_exceptions.ConnectionClosed()
)
while self._pending_connection_futures:
try:
pending_connection_future = (
self._pending_connection_futures.pop()
)
except KeyError:
break
else:
pending_connection_future.set_exception(
pika_exceptions.ConnectionClosed()
)
def _on_connection_close(self, conn, reply_code, reply_text):
self._evt_closed.set()
self._notify_all_futures_connection_close()
if self._interrupt_pipeout:
os.close(self._interrupt_pipeout)
os.close(self._interrupt_pipein)
def add_on_close_callback(self, callback):
return self._execute_task(self._impl.add_on_close_callback, callback)
def _do_process_io(self):
while self._task_queue:
func, args, kwargs, future = self._task_queue.pop()
try:
res = func(*args, **kwargs)
except BaseException as e:
LOG.exception(e)
future.set_exception(e)
else:
future.set_result(res)
self._impl.ioloop.poll()
self._impl.ioloop.process_timeouts()
def _process_io(self):
self._thread_id = current_thread()
while not self._evt_closed.is_set():
try:
self._do_process_io()
except BaseException:
LOG.exception("Error during processing connection's IO")
def close(self, *args, **kwargs):
res = self._execute_task(self._impl.close, *args, **kwargs)
self._evt_closed.wait()
self._thread.join()
return res
def channel(self, channel_number=None):
channel_opened_future = self._register_pending_future()
impl_channel = self._execute_task(
self._impl.channel,
on_open_callback=channel_opened_future.set_result,
channel_number=channel_number
)
# Create our proxy channel
channel = ThreadSafePikaChannel(impl_channel, self)
# Link implementation channel with our proxy channel
impl_channel._set_cookie(channel)
channel_opened_future.result()
return channel
def add_timeout(self, timeout, callback):
return self._execute_task(self._impl.add_timeout, timeout, callback)
def remove_timeout(self, timeout_id):
return self._execute_task(self._impl.remove_timeout, timeout_id)
@property
def is_closed(self):
return self._impl.is_closed
@property
def is_closing(self):
return self._impl.is_closing
@property
def is_open(self):
return self._impl.is_open
class ThreadSafePikaChannel(object): # pylint: disable=R0904,R0902
def __init__(self, channel_impl, connection):
self._impl = channel_impl
self._connection = connection
self._delivery_confirmation = False
self._message_returned = False
self._current_future = None
self._evt_closed = threading.Event()
self.add_on_close_callback(self._on_channel_close)
def _execute_task(self, func, *args, **kwargs):
return self._connection._execute_task(func, *args, **kwargs)
def _on_channel_close(self, channel, reply_code, reply_text):
self._evt_closed.set()
if self._current_future:
self._current_future.set_exception(
pika_exceptions.ChannelClosed(reply_code, reply_text))
def _on_message_confirmation(self, frame):
self._current_future.set_result(frame)
def add_on_close_callback(self, callback):
self._execute_task(self._impl.add_on_close_callback, callback)
def add_on_cancel_callback(self, callback):
self._execute_task(self._impl.add_on_cancel_callback, callback)
def __int__(self):
return self.channel_number
@property
def channel_number(self):
return self._impl.channel_number
@property
def is_closed(self):
return self._impl.is_closed
@property
def is_closing(self):
return self._impl.is_closing
@property
def is_open(self):
return self._impl.is_open
def close(self, reply_code=0, reply_text="Normal Shutdown"):
self._impl.close(reply_code=reply_code, reply_text=reply_text)
self._evt_closed.wait()
def flow(self, active):
self._current_future = futurist.Future()
self._execute_task(
self._impl.flow, callback=self._current_future.set_result,
active=active
)
return self._current_future.result()
def basic_consume(self, # pylint: disable=R0913
consumer_callback,
queue,
no_ack=False,
exclusive=False,
consumer_tag=None,
arguments=None):
self._current_future = futurist.Future()
self._execute_task(
self._impl.add_callback, self._current_future.set_result,
replies=[pika_spec.Basic.ConsumeOk], one_shot=True
)
self._impl.add_callback(self._current_future.set_result,
replies=[pika_spec.Basic.ConsumeOk],
one_shot=True)
tag = self._execute_task(
self._impl.basic_consume,
consumer_callback=consumer_callback,
queue=queue,
no_ack=no_ack,
exclusive=exclusive,
consumer_tag=consumer_tag,
arguments=arguments
)
self._current_future.result()
return tag
def basic_cancel(self, consumer_tag):
self._current_future = futurist.Future()
self._execute_task(
self._impl.basic_cancel,
callback=self._current_future.set_result,
consumer_tag=consumer_tag,
nowait=False)
self._current_future.result()
def basic_ack(self, delivery_tag=0, multiple=False):
return self._execute_task(
self._impl.basic_ack, delivery_tag=delivery_tag, multiple=multiple)
def basic_nack(self, delivery_tag=None, multiple=False, requeue=True):
return self._execute_task(
self._impl.basic_nack, delivery_tag=delivery_tag,
multiple=multiple, requeue=requeue
)
def publish(self, exchange, routing_key, body, # pylint: disable=R0913
properties=None, mandatory=False, immediate=False):
if self._delivery_confirmation:
# In publisher-acknowledgments mode
self._message_returned = False
self._current_future = futurist.Future()
self._execute_task(self._impl.basic_publish,
exchange=exchange,
routing_key=routing_key,
body=body,
properties=properties,
mandatory=mandatory,
immediate=immediate)
conf_method = self._current_future.result().method
if isinstance(conf_method, pika_spec.Basic.Nack):
raise pika_exceptions.NackError((None,))
else:
assert isinstance(conf_method, pika_spec.Basic.Ack), (
conf_method)
if self._message_returned:
raise pika_exceptions.UnroutableError((None,))
else:
# In non-publisher-acknowledgments mode
self._execute_task(self._impl.basic_publish,
exchange=exchange,
routing_key=routing_key,
body=body,
properties=properties,
mandatory=mandatory,
immediate=immediate)
def basic_qos(self, prefetch_size=0, prefetch_count=0, all_channels=False):
self._current_future = futurist.Future()
self._execute_task(self._impl.basic_qos,
callback=self._current_future.set_result,
prefetch_size=prefetch_size,
prefetch_count=prefetch_count,
all_channels=all_channels)
self._current_future.result()
def basic_recover(self, requeue=False):
self._current_future = futurist.Future()
self._execute_task(
self._impl.basic_recover,
callback=lambda: self._current_future.set_result(None),
requeue=requeue
)
self._current_future.result()
def basic_reject(self, delivery_tag=None, requeue=True):
self._execute_task(self._impl.basic_reject,
delivery_tag=delivery_tag,
requeue=requeue)
def _on_message_returned(self, *args, **kwargs):
self._message_returned = True
def confirm_delivery(self):
self._current_future = futurist.Future()
self._execute_task(self._impl.add_callback,
callback=self._current_future.set_result,
replies=[pika_spec.Confirm.SelectOk],
one_shot=True)
self._execute_task(self._impl.confirm_delivery,
callback=self._on_message_confirmation,
nowait=False)
self._current_future.result()
self._delivery_confirmation = True
self._execute_task(self._impl.add_on_return_callback,
self._on_message_returned)
def exchange_declare(self, exchange=None, # pylint: disable=R0913
exchange_type='direct', passive=False, durable=False,
auto_delete=False, internal=False,
arguments=None, **kwargs):
self._current_future = futurist.Future()
self._execute_task(self._impl.exchange_declare,
callback=self._current_future.set_result,
exchange=exchange,
exchange_type=exchange_type,
passive=passive,
durable=durable,
auto_delete=auto_delete,
internal=internal,
nowait=False,
arguments=arguments,
type=kwargs["type"] if kwargs else None)
return self._current_future.result()
def exchange_delete(self, exchange=None, if_unused=False):
self._current_future = futurist.Future()
self._execute_task(self._impl.exchange_delete,
callback=self._current_future.set_result,
exchange=exchange,
if_unused=if_unused,
nowait=False)
return self._current_future.result()
def exchange_bind(self, destination=None, source=None, routing_key='',
arguments=None):
self._current_future = futurist.Future()
self._execute_task(self._impl.exchange_bind,
callback=self._current_future.set_result,
destination=destination,
source=source,
routing_key=routing_key,
nowait=False,
arguments=arguments)
return self._current_future.result()
def exchange_unbind(self, destination=None, source=None, routing_key='',
arguments=None):
self._current_future = futurist.Future()
self._execute_task(self._impl.exchange_unbind,
callback=self._current_future.set_result,
destination=destination,
source=source,
routing_key=routing_key,
nowait=False,
arguments=arguments)
return self._current_future.result()
def queue_declare(self, queue='', passive=False, durable=False,
exclusive=False, auto_delete=False,
arguments=None):
self._current_future = futurist.Future()
self._execute_task(self._impl.queue_declare,
callback=self._current_future.set_result,
queue=queue,
passive=passive,
durable=durable,
exclusive=exclusive,
auto_delete=auto_delete,
nowait=False,
arguments=arguments)
return self._current_future.result()
def queue_delete(self, queue='', if_unused=False, if_empty=False):
self._current_future = futurist.Future()
self._execute_task(self._impl.queue_delete,
callback=self._current_future.set_result,
queue=queue,
if_unused=if_unused,
if_empty=if_empty,
nowait=False)
return self._current_future.result()
def queue_purge(self, queue=''):
self._current_future = futurist.Future()
self._execute_task(self._impl.queue_purge,
callback=self._current_future.set_result,
queue=queue,
nowait=False)
return self._current_future.result()
def queue_bind(self, queue, exchange, routing_key=None,
arguments=None):
self._current_future = futurist.Future()
self._execute_task(self._impl.queue_bind,
callback=self._current_future.set_result,
queue=queue,
exchange=exchange,
routing_key=routing_key,
nowait=False,
arguments=arguments)
return self._current_future.result()
def queue_unbind(self, queue='', exchange=None, routing_key=None,
arguments=None):
self._current_future = futurist.Future()
self._execute_task(self._impl.queue_unbind,
callback=self._current_future.set_result,
queue=queue,
exchange=exchange,
routing_key=routing_key,
arguments=arguments)
return self._current_future.result()

View File

@ -11,7 +11,7 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import os
import random import random
import socket import socket
import threading import threading
@ -19,38 +19,18 @@ import time
from oslo_log import log as logging from oslo_log import log as logging
import pika import pika
from pika.adapters import select_connection
from pika import credentials as pika_credentials from pika import credentials as pika_credentials
import pika_pool import pika_pool
import uuid import uuid
from oslo_messaging._drivers.pika_driver import pika_commons as pika_drv_cmns from oslo_messaging._drivers.pika_driver import pika_commons as pika_drv_cmns
from oslo_messaging._drivers.pika_driver import pika_connection
from oslo_messaging._drivers.pika_driver import pika_exceptions as pika_drv_exc from oslo_messaging._drivers.pika_driver import pika_exceptions as pika_drv_exc
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
_PID = None
def _create_select_poller_connection_impl(
parameters, on_open_callback, on_open_error_callback,
on_close_callback, stop_ioloop_on_close):
"""Used for disabling autochoise of poller ('select', 'poll', 'epool', etc)
inside default 'SelectConnection.__init__(...)' logic. It is necessary to
force 'select' poller usage if eventlet is monkeypatched because eventlet
patches only 'select' system call
Method signature is copied form 'SelectConnection.__init__(...)', because
it is used as replacement of 'SelectConnection' class to create instances
"""
return select_connection.SelectConnection(
parameters=parameters,
on_open_callback=on_open_callback,
on_open_error_callback=on_open_error_callback,
on_close_callback=on_close_callback,
stop_ioloop_on_close=stop_ioloop_on_close,
custom_ioloop=select_connection.SelectPoller()
)
class _PooledConnectionWithConfirmations(pika_pool.Connection): class _PooledConnectionWithConfirmations(pika_pool.Connection):
@ -84,10 +64,6 @@ class PikaEngine(object):
allowed_remote_exmods=None): allowed_remote_exmods=None):
self.conf = conf self.conf = conf
self._force_select_poller_use = (
pika_drv_cmns.is_eventlet_monkey_patched('select')
)
# processing rpc options # processing rpc options
self.default_rpc_exchange = ( self.default_rpc_exchange = (
conf.oslo_messaging_pika.default_rpc_exchange conf.oslo_messaging_pika.default_rpc_exchange
@ -168,7 +144,7 @@ class PikaEngine(object):
) )
# initializing connection parameters for configured RabbitMQ hosts # initializing connection parameters for configured RabbitMQ hosts
common_pika_params = { self._common_pika_params = {
'virtual_host': url.virtual_host, 'virtual_host': url.virtual_host,
'channel_max': self.conf.oslo_messaging_pika.channel_max, 'channel_max': self.conf.oslo_messaging_pika.channel_max,
'frame_max': self.conf.oslo_messaging_pika.frame_max, 'frame_max': self.conf.oslo_messaging_pika.frame_max,
@ -177,31 +153,18 @@ class PikaEngine(object):
'socket_timeout': self.conf.oslo_messaging_pika.socket_timeout, 'socket_timeout': self.conf.oslo_messaging_pika.socket_timeout,
} }
self._connection_lock = threading.Lock() self._connection_lock = threading.RLock()
self._pid = None
self._connection_host_param_list = [] self._connection_host_status = {}
self._connection_host_status_list = []
if not url.hosts: if not url.hosts:
raise ValueError("You should provide at least one RabbitMQ host") raise ValueError("You should provide at least one RabbitMQ host")
for transport_host in url.hosts: self._host_list = url.hosts
pika_params = common_pika_params.copy()
pika_params.update(
host=transport_host.hostname,
port=transport_host.port,
credentials=pika_credentials.PlainCredentials(
transport_host.username, transport_host.password
),
)
self._connection_host_param_list.append(pika_params)
self._connection_host_status_list.append({
self.HOST_CONNECTION_LAST_TRY_TIME: 0,
self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME: 0
})
self._next_connection_host_num = random.randint( self._cur_connection_host_num = random.randint(
0, len(self._connection_host_param_list) - 1 0, len(self._host_list) - 1
) )
# initializing 2 connection pools: 1st for connections without # initializing 2 connection pools: 1st for connections without
@ -228,49 +191,37 @@ class PikaEngine(object):
_PooledConnectionWithConfirmations _PooledConnectionWithConfirmations
) )
def _next_connection_num(self):
"""Used for creating connections to different RabbitMQ nodes in
round robin order
:return: next host number to create connection to
"""
with self._connection_lock:
cur_num = self._next_connection_host_num
self._next_connection_host_num += 1
self._next_connection_host_num %= len(
self._connection_host_param_list
)
return cur_num
def create_connection(self, for_listening=False): def create_connection(self, for_listening=False):
"""Create and return connection to any available host. """Create and return connection to any available host.
:return: created connection :return: created connection
:raise: ConnectionException if all hosts are not reachable :raise: ConnectionException if all hosts are not reachable
""" """
host_count = len(self._connection_host_param_list)
connection_attempts = host_count
pika_next_connection_num = self._next_connection_num() with self._connection_lock:
self._init_if_needed()
while connection_attempts > 0: host_count = len(self._host_list)
try: connection_attempts = host_count
return self.create_host_connection(
pika_next_connection_num, for_listening
)
except pika_pool.Connection.connectivity_errors as e:
LOG.warning("Can't establish connection to host. %s", e)
except pika_drv_exc.HostConnectionNotAllowedException as e:
LOG.warning("Connection to host is not allowed. %s", e)
connection_attempts -= 1 while connection_attempts > 0:
pika_next_connection_num += 1 self._cur_connection_host_num += 1
pika_next_connection_num %= host_count self._cur_connection_host_num %= host_count
try:
return self.create_host_connection(
self._cur_connection_host_num, for_listening
)
except pika_pool.Connection.connectivity_errors as e:
LOG.warning("Can't establish connection to host. %s", e)
except pika_drv_exc.HostConnectionNotAllowedException as e:
LOG.warning("Connection to host is not allowed. %s", e)
raise pika_drv_exc.EstablishConnectionException( connection_attempts -= 1
"Can not establish connection to any configured RabbitMQ host: " +
str(self._connection_host_param_list) raise pika_drv_exc.EstablishConnectionException(
) "Can not establish connection to any configured RabbitMQ "
"host: " + str(self._host_list)
)
def _set_tcp_user_timeout(self, s): def _set_tcp_user_timeout(self, s):
if not self._tcp_user_timeout: if not self._tcp_user_timeout:
@ -285,28 +236,64 @@ class PikaEngine(object):
"Whoops, this kernel doesn't seem to support TCP_USER_TIMEOUT." "Whoops, this kernel doesn't seem to support TCP_USER_TIMEOUT."
) )
def _init_if_needed(self):
global _PID
cur_pid = os.getpid()
if _PID != cur_pid:
if _PID:
LOG.warning("New pid is detected. Old: %s, new: %s. "
"Cleaning up...", _PID, cur_pid)
# Note(dukhlov): we need to force select poller usage in case when
# 'thread' module is monkey patched becase current eventlet
# implementation does not support patching of poll/epoll/kqueue
if pika_drv_cmns.is_eventlet_monkey_patched("thread"):
from pika.adapters import select_connection
select_connection.SELECT_TYPE = "select"
_PID = cur_pid
def create_host_connection(self, host_index, for_listening=False): def create_host_connection(self, host_index, for_listening=False):
"""Create new connection to host #host_index """Create new connection to host #host_index
:param host_index: Integer, number of host for connection establishing :param host_index: Integer, number of host for connection establishing
:param for_listening: Boolean, creates connection for listening :param for_listening: Boolean, creates connection for listening
(enable heartbeats) if True if True
:return: New connection :return: New connection
""" """
connection_params = pika.ConnectionParameters(
heartbeat_interval=(
self._heartbeat_interval if for_listening else None
),
**self._connection_host_param_list[host_index]
)
with self._connection_lock: with self._connection_lock:
self._init_if_needed()
host = self._host_list[host_index]
connection_params = pika.ConnectionParameters(
host=host.hostname,
port=host.port,
credentials=pika_credentials.PlainCredentials(
host.username, host.password
),
heartbeat_interval=(
self._heartbeat_interval if for_listening else None
),
**self._common_pika_params
)
cur_time = time.time() cur_time = time.time()
last_success_time = self._connection_host_status_list[host_index][ host_connection_status = self._connection_host_status.get(host)
if host_connection_status is None:
host_connection_status = {
self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME: 0,
self.HOST_CONNECTION_LAST_TRY_TIME: 0
}
self._connection_host_status[host] = host_connection_status
last_success_time = host_connection_status[
self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME
] ]
last_time = self._connection_host_status_list[host_index][ last_time = host_connection_status[
self.HOST_CONNECTION_LAST_TRY_TIME self.HOST_CONNECTION_LAST_TRY_TIME
] ]
@ -322,27 +309,25 @@ class PikaEngine(object):
) )
try: try:
connection = pika.BlockingConnection( if for_listening:
parameters=connection_params, connection = pika_connection.ThreadSafePikaConnection(
_impl_class=(_create_select_poller_connection_impl params=connection_params
if self._force_select_poller_use else None) )
) else:
connection = pika.BlockingConnection(
# It is needed for pika_pool library which expects that parameters=connection_params
# connections has params attribute defined in BaseConnection )
# but BlockingConnection is not derived from BaseConnection connection.params = connection_params
# and doesn't have it
connection.params = connection_params
self._set_tcp_user_timeout(connection._impl.socket) self._set_tcp_user_timeout(connection._impl.socket)
self._connection_host_status_list[host_index][ self._connection_host_status[host][
self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME
] = cur_time ] = cur_time
return connection return connection
finally: finally:
self._connection_host_status_list[host_index][ self._connection_host_status[host][
self.HOST_CONNECTION_LAST_TRY_TIME self.HOST_CONNECTION_LAST_TRY_TIME
] = cur_time ] = cur_time

View File

@ -29,6 +29,7 @@ class RpcReplyPikaListener(object):
""" """
def __init__(self, pika_engine): def __init__(self, pika_engine):
super(RpcReplyPikaListener, self).__init__()
self._pika_engine = pika_engine self._pika_engine = pika_engine
# preparing poller for listening replies # preparing poller for listening replies
@ -39,7 +40,6 @@ class RpcReplyPikaListener(object):
self._reply_consumer_initialized = False self._reply_consumer_initialized = False
self._reply_consumer_initialization_lock = threading.Lock() self._reply_consumer_initialization_lock = threading.Lock()
self._poller_thread = None
self._shutdown = False self._shutdown = False
def get_reply_qname(self): def get_reply_qname(self):
@ -66,51 +66,31 @@ class RpcReplyPikaListener(object):
# initialize reply poller if needed # initialize reply poller if needed
if self._reply_poller is None: if self._reply_poller is None:
self._reply_poller = pika_drv_poller.RpcReplyPikaPoller( self._reply_poller = pika_drv_poller.RpcReplyPikaPoller(
pika_engine=self._pika_engine, self._pika_engine, self._pika_engine.rpc_reply_exchange,
exchange=self._pika_engine.rpc_reply_exchange, self._reply_queue, 1, None,
queue=self._reply_queue, self._pika_engine.rpc_reply_listener_prefetch_count
prefetch_count=(
self._pika_engine.rpc_reply_listener_prefetch_count
)
) )
self._reply_poller.start() self._reply_poller.start(self._on_incoming)
# start reply poller job thread if needed
if self._poller_thread is None:
self._poller_thread = threading.Thread(target=self._poller)
self._poller_thread.daemon = True
if not self._poller_thread.is_alive():
self._poller_thread.start()
self._reply_consumer_initialized = True self._reply_consumer_initialized = True
return self._reply_queue return self._reply_queue
def _poller(self): def _on_incoming(self, incoming):
"""Reply polling job. Poll replies in infinite loop and notify """Reply polling job. Poll replies in infinite loop and notify
registered features registered features
""" """
while True: for message in incoming:
try: try:
messages = self._reply_poller.poll() message.acknowledge()
if not messages and self._shutdown: future = self._reply_waiting_futures.pop(
break message.msg_id, None
)
for message in messages: if future is not None:
try: future.set_result(message)
message.acknowledge() except Exception:
future = self._reply_waiting_futures.pop( LOG.exception("Unexpected exception during processing"
message.msg_id, None "reply message")
)
if future is not None:
future.set_result(message)
except Exception:
LOG.exception("Unexpected exception during processing"
"reply message")
except BaseException:
LOG.exception("Unexpected exception during reply polling")
def register_reply_waiter(self, msg_id): def register_reply_waiter(self, msg_id):
"""Register reply waiter. Should be called before message sending to """Register reply waiter. Should be called before message sending to
@ -140,9 +120,4 @@ class RpcReplyPikaListener(object):
self._reply_poller.cleanup() self._reply_poller.cleanup()
self._reply_poller = None self._reply_poller = None
if self._poller_thread:
if self._poller_thread.is_alive():
self._poller_thread.join()
self._poller_thread = None
self._reply_queue = None self._reply_queue = None

View File

@ -13,11 +13,9 @@
# under the License. # under the License.
import threading import threading
import time
from oslo_log import log as logging from oslo_log import log as logging
from oslo_utils import timeutils from oslo_service import loopingcall
import six
from oslo_messaging._drivers import base from oslo_messaging._drivers import base
from oslo_messaging._drivers.pika_driver import pika_commons as pika_drv_cmns from oslo_messaging._drivers.pika_driver import pika_commons as pika_drv_cmns
@ -27,45 +25,188 @@ from oslo_messaging._drivers.pika_driver import pika_message as pika_drv_msg
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class PikaPoller(base.PollStyleListener): class PikaPoller(base.Listener):
"""Provides user friendly functionality for RabbitMQ message consuming, """Provides user friendly functionality for RabbitMQ message consuming,
handles low level connectivity problems and restore connection if some handles low level connectivity problems and restore connection if some
connectivity related problem detected connectivity related problem detected
""" """
def __init__(self, pika_engine, prefetch_count, incoming_message_class): def __init__(self, pika_engine, batch_size, batch_timeout, prefetch_count,
incoming_message_class):
"""Initialize required fields """Initialize required fields
:param pika_engine: PikaEngine, shared object with configuration and :param pika_engine: PikaEngine, shared object with configuration and
shared driver functionality shared driver functionality
:param batch_size: desired number of messages passed to
single on_incoming_callback call
:param batch_timeout: defines how long should we wait for batch_size
messages if we already have some messages waiting for processing
:param prefetch_count: Integer, maximum count of unacknowledged :param prefetch_count: Integer, maximum count of unacknowledged
messages which RabbitMQ broker sends to this consumer messages which RabbitMQ broker sends to this consumer
:param incoming_message_class: PikaIncomingMessage, wrapper for :param incoming_message_class: PikaIncomingMessage, wrapper for
consumed RabbitMQ message consumed RabbitMQ message
""" """
super(PikaPoller, self).__init__(prefetch_count) super(PikaPoller, self).__init__(batch_size, batch_timeout,
prefetch_count)
self._pika_engine = pika_engine self._pika_engine = pika_engine
self._incoming_message_class = incoming_message_class self._incoming_message_class = incoming_message_class
self._connection = None self._connection = None
self._channel = None self._channel = None
self._lock = threading.Lock() self._recover_loopingcall = None
self._lock = threading.RLock()
self._cur_batch_buffer = None
self._cur_batch_timeout_id = None
self._started = False self._started = False
self._closing_connection_by_poller = False
self._queues_to_consume = None self._queues_to_consume = None
self._message_queue = [] def _on_connection_close(self, connection, reply_code, reply_text):
self._deliver_cur_batch()
if self._closing_connection_by_poller:
return
with self._lock:
self._connection = None
self._start_recover_consuming_task()
def _reconnect(self): def _on_channel_close(self, channel, reply_code, reply_text):
if self._cur_batch_buffer:
self._cur_batch_buffer = [
message for message in self._cur_batch_buffer
if not message.need_ack()
]
if self._closing_connection_by_poller:
return
with self._lock:
self._channel = None
self._start_recover_consuming_task()
def _on_consumer_cancel(self, method_frame):
with self._lock:
if self._queues_to_consume:
consumer_tag = method_frame.method.consumer_tag
for queue_info in self._queues_to_consume:
if queue_info["consumer_tag"] == consumer_tag:
queue_info["consumer_tag"] = None
self._start_recover_consuming_task()
def _on_message_no_ack_callback(self, unused, method, properties, body):
"""Is called by Pika when message was received from queue listened with
no_ack=True mode
"""
incoming_message = self._incoming_message_class(
self._pika_engine, None, method, properties, body
)
self._on_incoming_message(incoming_message)
def _on_message_with_ack_callback(self, unused, method, properties, body):
"""Is called by Pika when message was received from queue listened with
no_ack=False mode
"""
incoming_message = self._incoming_message_class(
self._pika_engine, self._channel, method, properties, body
)
self._on_incoming_message(incoming_message)
def _deliver_cur_batch(self):
if self._cur_batch_timeout_id is not None:
self._connection.remove_timeout(self._cur_batch_timeout_id)
self._cur_batch_timeout_id = None
if self._cur_batch_buffer:
buf_to_send = self._cur_batch_buffer
self._cur_batch_buffer = None
try:
self.on_incoming_callback(buf_to_send)
except Exception:
LOG.exception("Unexpected exception during incoming delivery")
def _on_incoming_message(self, incoming_message):
if self._cur_batch_buffer is None:
self._cur_batch_buffer = [incoming_message]
else:
self._cur_batch_buffer.append(incoming_message)
if len(self._cur_batch_buffer) >= self.batch_size:
self._deliver_cur_batch()
return
if self._cur_batch_timeout_id is None:
self._cur_batch_timeout_id = self._connection.add_timeout(
self.batch_timeout, self._deliver_cur_batch)
def _start_recover_consuming_task(self):
"""Start async job for checking connection to the broker."""
if self._recover_loopingcall is None and self._started:
self._recover_loopingcall = (
loopingcall.DynamicLoopingCall(
self._try_recover_consuming
)
)
LOG.info("Starting recover consuming job for listener: %s", self)
self._recover_loopingcall.start()
def _try_recover_consuming(self):
with self._lock:
try:
if self._started:
self._start_or_recover_consuming()
except pika_drv_exc.EstablishConnectionException as e:
LOG.warning(
"Problem during establishing connection for pika "
"poller %s", e, exc_info=True
)
return self._pika_engine.host_connection_reconnect_delay
except pika_drv_exc.ConnectionException as e:
LOG.warning(
"Connectivity exception during starting/recovering pika "
"poller %s", e, exc_info=True
)
except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as e:
LOG.warning(
"Connectivity exception during starting/recovering pika "
"poller %s", e, exc_info=True
)
except BaseException:
# NOTE (dukhlov): I preffer to use here BaseException because
# if this method raise such exception LoopingCall stops
# execution Probably it should never happen and Exception
# should be enough but in case of programmer mistake it could
# be and it is potentially hard to catch problem if we will
# stop background task. It is better when it continue to work
# and write a lot of LOG with this error
LOG.exception("Unexpected exception during "
"starting/recovering pika poller")
else:
self._recover_loopingcall = None
LOG.info("Recover consuming job was finished for listener: %s",
self)
raise loopingcall.LoopingCallDone(True)
return 0
def _start_or_recover_consuming(self):
"""Performs reconnection to the broker. It is unsafe method for """Performs reconnection to the broker. It is unsafe method for
internal use only internal use only
""" """
self._connection = self._pika_engine.create_connection( if self._connection is None or not self._connection.is_open:
for_listening=True self._connection = self._pika_engine.create_connection(
) for_listening=True
self._channel = self._connection.channel() )
self._channel.basic_qos(prefetch_count=self.prefetch_size) self._connection.add_on_close_callback(self._on_connection_close)
self._channel = None
if self._channel is None or not self._channel.is_open:
if self._queues_to_consume:
for queue_info in self._queues_to_consume:
queue_info["consumer_tag"] = None
self._channel = self._connection.channel()
self._channel.add_on_close_callback(self._on_channel_close)
self._channel.add_on_cancel_callback(self._on_consumer_cancel)
self._channel.basic_qos(prefetch_count=self.prefetch_size)
if self._queues_to_consume is None: if self._queues_to_consume is None:
self._queues_to_consume = self._declare_queue_binding() self._queues_to_consume = self._declare_queue_binding()
@ -92,6 +233,8 @@ class PikaPoller(base.PollStyleListener):
try: try:
for queue_info in self._queues_to_consume: for queue_info in self._queues_to_consume:
if queue_info["consumer_tag"] is not None:
continue
no_ack = queue_info["no_ack"] no_ack = queue_info["no_ack"]
on_message_callback = ( on_message_callback = (
@ -120,168 +263,95 @@ class PikaPoller(base.PollStyleListener):
self._channel.basic_cancel(consumer_tag) self._channel.basic_cancel(consumer_tag)
queue_info["consumer_tag"] = None queue_info["consumer_tag"] = None
def _on_message_no_ack_callback(self, unused, method, properties, body): def start(self, on_incoming_callback):
"""Is called by Pika when message was received from queue listened with
no_ack=True mode
"""
self._message_queue.append(
self._incoming_message_class(
self._pika_engine, None, method, properties, body
)
)
def _on_message_with_ack_callback(self, unused, method, properties, body):
"""Is called by Pika when message was received from queue listened with
no_ack=False mode
"""
self._message_queue.append(
self._incoming_message_class(
self._pika_engine, self._channel, method, properties, body
)
)
def _cleanup(self):
"""Cleanup allocated resources (channel, connection, etc). It is unsafe
method for internal use only
"""
if self._connection:
try:
self._connection.close()
except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS:
# expected errors
pass
except Exception:
LOG.exception("Unexpected error during closing connection")
finally:
self._channel = None
self._connection = None
for i in six.moves.range(len(self._message_queue) - 1, -1, -1):
message = self._message_queue[i]
if message.need_ack():
del self._message_queue[i]
@base.batch_poll_helper
def poll(self, timeout=None):
"""Main method of this class - consumes message from RabbitMQ
:param: timeout: float, seconds, timeout for waiting new incoming
message, None means wait forever
:return: list of PikaIncomingMessage, RabbitMQ messages
"""
with timeutils.StopWatch(timeout) as stop_watch:
while True:
with self._lock:
if self._message_queue:
return self._message_queue.pop(0)
if stop_watch.expired():
return None
try:
if self._started:
if self._channel is None:
self._reconnect()
# we need some time_limit here, not too small to
# avoid a lot of not needed iterations but not too
# large to release lock time to time and give a
# chance to perform another method waiting this
# lock
self._connection.process_data_events(
time_limit=0.25
)
else:
# consumer is stopped so we don't expect new
# messages, just process already sent events
if self._channel is not None:
self._connection.process_data_events(
time_limit=0
)
# and return if we don't see new messages
if not self._message_queue:
return None
except pika_drv_exc.EstablishConnectionException as e:
LOG.warning(
"Problem during establishing connection for pika "
"poller %s", e, exc_info=True
)
time.sleep(
self._pika_engine.host_connection_reconnect_delay
)
except pika_drv_exc.ConnectionException:
self._cleanup()
raise
except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS:
self._cleanup()
raise
def start(self):
"""Starts poller. Should be called before polling to allow message """Starts poller. Should be called before polling to allow message
consuming consuming
:param on_incoming_callback: callback function to be executed when
listener received messages. Messages should be processed and
acked/nacked by callback
""" """
super(PikaPoller, self).start(on_incoming_callback)
with self._lock: with self._lock:
if self._started: if self._started:
return return
connected = False
try: try:
self._reconnect() self._start_or_recover_consuming()
except pika_drv_exc.EstablishConnectionException as exc: except pika_drv_exc.EstablishConnectionException as exc:
LOG.warning( LOG.warning(
"Can not establish connection during pika poller's " "Can not establish connection during pika poller's "
"start(). Connecting is required during first poll() " "start(). %s", exc, exc_info=True
"call. %s", exc, exc_info=True
) )
except pika_drv_exc.ConnectionException as exc: except pika_drv_exc.ConnectionException as exc:
self._cleanup()
LOG.warning( LOG.warning(
"Connectivity problem during pika poller's start(). " "Connectivity problem during pika poller's start(). %s",
"Reconnecting required during first poll() call. %s",
exc, exc_info=True exc, exc_info=True
) )
except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as exc: except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as exc:
self._cleanup()
LOG.warning( LOG.warning(
"Connectivity problem during pika poller's start(). " "Connectivity problem during pika poller's start(). %s",
"Reconnecting required during first poll() call. %s",
exc, exc_info=True exc, exc_info=True
) )
else:
connected = True
self._started = True self._started = True
if not connected:
self._start_recover_consuming_task()
def stop(self): def stop(self):
"""Stops poller. Should be called when polling is not needed anymore to """Stops poller. Should be called when polling is not needed anymore to
stop new message consuming. After that it is necessary to poll already stop new message consuming. After that it is necessary to poll already
prefetched messages prefetched messages
""" """
super(PikaPoller, self).stop()
with self._lock: with self._lock:
if not self._started: if not self._started:
return return
if self._queues_to_consume and self._channel: if self._recover_loopingcall is not None:
self._recover_loopingcall.stop()
self._recover_loopingcall = None
if (self._queues_to_consume and self._channel and
self._channel.is_open):
try: try:
self._stop_consuming() self._stop_consuming()
except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as exc: except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as exc:
self._cleanup()
LOG.warning( LOG.warning(
"Connectivity problem detected during consumer " "Connectivity problem detected during consumer "
"cancellation. %s", exc, exc_info=True "cancellation. %s", exc, exc_info=True
) )
self._deliver_cur_batch()
self._started = False self._started = False
def cleanup(self): def cleanup(self):
"""Safe version of _cleanup. Cleans up allocated resources (channel, """Cleanup allocated resources (channel, connection, etc)."""
connection, etc).
"""
with self._lock: with self._lock:
self._cleanup() if self._connection and self._connection.is_open:
try:
self._closing_connection_by_poller = True
self._connection.close()
self._closing_connection_by_poller = False
except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS:
# expected errors
pass
except Exception:
LOG.exception("Unexpected error during closing connection")
finally:
self._channel = None
self._connection = None
class RpcServicePikaPoller(PikaPoller): class RpcServicePikaPoller(PikaPoller):
"""PikaPoller implementation for polling RPC messages. Overrides base """PikaPoller implementation for polling RPC messages. Overrides base
functionality according to RPC specific functionality according to RPC specific
""" """
def __init__(self, pika_engine, target, prefetch_count): def __init__(self, pika_engine, target, batch_size, batch_timeout,
prefetch_count):
"""Adds target parameter for declaring RPC specific exchanges and """Adds target parameter for declaring RPC specific exchanges and
queues queues
@ -289,14 +359,18 @@ class RpcServicePikaPoller(PikaPoller):
shared driver functionality shared driver functionality
:param target: Target, oslo.messaging Target object which defines RPC :param target: Target, oslo.messaging Target object which defines RPC
endpoint endpoint
:param batch_size: desired number of messages passed to
single on_incoming_callback call
:param batch_timeout: defines how long should we wait for batch_size
messages if we already have some messages waiting for processing
:param prefetch_count: Integer, maximum count of unacknowledged :param prefetch_count: Integer, maximum count of unacknowledged
messages which RabbitMQ broker sends to this consumer messages which RabbitMQ broker sends to this consumer
""" """
self._target = target self._target = target
super(RpcServicePikaPoller, self).__init__( super(RpcServicePikaPoller, self).__init__(
pika_engine, prefetch_count=prefetch_count, pika_engine, batch_size, batch_timeout, prefetch_count,
incoming_message_class=pika_drv_msg.RpcPikaIncomingMessage pika_drv_msg.RpcPikaIncomingMessage
) )
def _declare_queue_binding(self): def _declare_queue_binding(self):
@ -363,7 +437,8 @@ class RpcReplyPikaPoller(PikaPoller):
"""PikaPoller implementation for polling RPC reply messages. Overrides """PikaPoller implementation for polling RPC reply messages. Overrides
base functionality according to RPC reply specific base functionality according to RPC reply specific
""" """
def __init__(self, pika_engine, exchange, queue, prefetch_count): def __init__(self, pika_engine, exchange, queue, batch_size, batch_timeout,
prefetch_count):
"""Adds exchange and queue parameter for declaring exchange and queue """Adds exchange and queue parameter for declaring exchange and queue
used for RPC reply delivery used for RPC reply delivery
@ -371,6 +446,10 @@ class RpcReplyPikaPoller(PikaPoller):
shared driver functionality shared driver functionality
:param exchange: String, exchange name used for RPC reply delivery :param exchange: String, exchange name used for RPC reply delivery
:param queue: String, queue name used for RPC reply delivery :param queue: String, queue name used for RPC reply delivery
:param batch_size: desired number of messages passed to
single on_incoming_callback call
:param batch_timeout: defines how long should we wait for batch_size
messages if we already have some messages waiting for processing
:param prefetch_count: Integer, maximum count of unacknowledged :param prefetch_count: Integer, maximum count of unacknowledged
messages which RabbitMQ broker sends to this consumer messages which RabbitMQ broker sends to this consumer
""" """
@ -378,8 +457,8 @@ class RpcReplyPikaPoller(PikaPoller):
self._queue = queue self._queue = queue
super(RpcReplyPikaPoller, self).__init__( super(RpcReplyPikaPoller, self).__init__(
pika_engine=pika_engine, prefetch_count=prefetch_count, pika_engine, batch_size, batch_timeout, prefetch_count,
incoming_message_class=pika_drv_msg.RpcReplyPikaIncomingMessage pika_drv_msg.RpcReplyPikaIncomingMessage
) )
def _declare_queue_binding(self): def _declare_queue_binding(self):
@ -404,8 +483,8 @@ class NotificationPikaPoller(PikaPoller):
"""PikaPoller implementation for polling Notification messages. Overrides """PikaPoller implementation for polling Notification messages. Overrides
base functionality according to Notification specific base functionality according to Notification specific
""" """
def __init__(self, pika_engine, targets_and_priorities, prefetch_count, def __init__(self, pika_engine, targets_and_priorities,
queue_name=None): batch_size, batch_timeout, prefetch_count, queue_name=None):
"""Adds targets_and_priorities and queue_name parameter """Adds targets_and_priorities and queue_name parameter
for declaring exchanges and queues used for notification delivery for declaring exchanges and queues used for notification delivery
@ -413,6 +492,10 @@ class NotificationPikaPoller(PikaPoller):
shared driver functionality shared driver functionality
:param targets_and_priorities: list of (target, priority), defines :param targets_and_priorities: list of (target, priority), defines
default queue names for corresponding notification types default queue names for corresponding notification types
:param batch_size: desired number of messages passed to
single on_incoming_callback call
:param batch_timeout: defines how long should we wait for batch_size
messages if we already have some messages waiting for processing
:param prefetch_count: Integer, maximum count of unacknowledged :param prefetch_count: Integer, maximum count of unacknowledged
messages which RabbitMQ broker sends to this consumer messages which RabbitMQ broker sends to this consumer
:param queue: String, alternative queue name used for this poller :param queue: String, alternative queue name used for this poller
@ -422,8 +505,8 @@ class NotificationPikaPoller(PikaPoller):
self._queue_name = queue_name self._queue_name = queue_name
super(NotificationPikaPoller, self).__init__( super(NotificationPikaPoller, self).__init__(
pika_engine, prefetch_count=prefetch_count, pika_engine, batch_size, batch_timeout, prefetch_count,
incoming_message_class=pika_drv_msg.PikaIncomingMessage pika_drv_msg.PikaIncomingMessage
) )
def _declare_queue_binding(self): def _declare_queue_binding(self):
@ -442,7 +525,7 @@ class NotificationPikaPoller(PikaPoller):
target.exchange or target.exchange or
self._pika_engine.default_notification_exchange self._pika_engine.default_notification_exchange
), ),
queue = queue, queue=queue,
routing_key=routing_key, routing_key=routing_key,
exchange_type='direct', exchange_type='direct',
queue_expiration=None, queue_expiration=None,

View File

@ -13,9 +13,11 @@
# under the License. # under the License.
import socket import socket
import threading
import time import time
import unittest import unittest
from concurrent import futures
import mock import mock
from oslo_messaging._drivers.pika_driver import pika_poller from oslo_messaging._drivers.pika_driver import pika_poller
@ -32,26 +34,53 @@ class PikaPollerTestCase(unittest.TestCase):
self._pika_engine.create_connection.return_value = ( self._pika_engine.create_connection.return_value = (
self._poller_connection_mock self._poller_connection_mock
) )
self._executor = futures.ThreadPoolExecutor(1)
def timer_task(timeout, callback):
time.sleep(timeout)
callback()
self._poller_connection_mock.add_timeout.side_effect = (
lambda *args: self._executor.submit(timer_task, *args)
)
self._prefetch_count = 123 self._prefetch_count = 123
def test_start_when_connection_unavailable(self): @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
incoming_message_class_mock = mock.Mock() "_declare_queue_binding")
def test_start(self, declare_queue_binding_mock):
poller = pika_poller.PikaPoller( poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count, self._pika_engine, 1, None, self._prefetch_count, None
incoming_message_class=incoming_message_class_mock )
poller.start(None)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
self.assertTrue(declare_queue_binding_mock.called)
def test_start_when_connection_unavailable(self):
poller = pika_poller.PikaPoller(
self._pika_engine, 1, None, self._prefetch_count, None
) )
self._pika_engine.create_connection.side_effect = socket.timeout() self._pika_engine.create_connection.side_effect = socket.timeout()
# start() should not raise socket.timeout exception # start() should not raise socket.timeout exception
poller.start() poller.start(None)
@mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller." @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
"_declare_queue_binding") "_declare_queue_binding")
def test_poll(self, declare_queue_binding_mock): def test_message_processing(self, declare_queue_binding_mock):
res = []
def on_incoming_callback(incoming):
res.append(incoming)
incoming_message_class_mock = mock.Mock() incoming_message_class_mock = mock.Mock()
poller = pika_poller.PikaPoller( poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count, self._pika_engine, 1, None, self._prefetch_count,
incoming_message_class=incoming_message_class_mock incoming_message_class=incoming_message_class_mock
) )
unused = object() unused = object()
@ -59,18 +88,14 @@ class PikaPollerTestCase(unittest.TestCase):
properties = object() properties = object()
body = object() body = object()
self._poller_connection_mock.process_data_events.side_effect = ( poller.start(on_incoming_callback)
lambda time_limit: poller._on_message_with_ack_callback( poller._on_message_with_ack_callback(
unused, method, properties, body unused, method, properties, body
)
) )
poller.start()
res = poller.poll()
self.assertEqual(len(res), 1) self.assertEqual(len(res), 1)
self.assertEqual(res[0], incoming_message_class_mock.return_value) self.assertEqual(res[0], [incoming_message_class_mock.return_value])
incoming_message_class_mock.assert_called_once_with( incoming_message_class_mock.assert_called_once_with(
self._pika_engine, self._poller_channel_mock, method, properties, self._pika_engine, self._poller_channel_mock, method, properties,
body body
@ -83,92 +108,39 @@ class PikaPollerTestCase(unittest.TestCase):
@mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller." @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
"_declare_queue_binding") "_declare_queue_binding")
def test_poll_after_stop(self, declare_queue_binding_mock): def test_message_processing_batch(self, declare_queue_binding_mock):
incoming_message_class_mock = mock.Mock() incoming_message_class_mock = mock.Mock()
poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count,
incoming_message_class=incoming_message_class_mock
)
n = 10 n = 10
params = [] params = []
for i in range(n): res = []
params.append((object(), object(), object(), object()))
def f(time_limit): def on_incoming_callback(incoming):
if poller._started: res.append(incoming)
for k in range(n):
poller._on_message_no_ack_callback(
*params[k]
)
self._poller_connection_mock.process_data_events.side_effect = f
poller.start()
res = poller.poll(batch_size=1)
self.assertEqual(len(res), 1)
self.assertEqual(res[0], incoming_message_class_mock.return_value)
self.assertEqual(
incoming_message_class_mock.call_args_list[0][0],
(self._pika_engine, None) + params[0][1:]
)
poller.stop()
res2 = poller.poll(batch_size=n)
self.assertEqual(len(res2), n - 1)
self.assertEqual(incoming_message_class_mock.call_count, n)
self.assertEqual(
self._poller_connection_mock.process_data_events.call_count, 2)
for i in range(n - 1):
self.assertEqual(res2[i], incoming_message_class_mock.return_value)
self.assertEqual(
incoming_message_class_mock.call_args_list[i + 1][0],
(self._pika_engine, None) + params[i + 1][1:]
)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
self.assertTrue(declare_queue_binding_mock.called)
@mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
"_declare_queue_binding")
def test_poll_batch(self, declare_queue_binding_mock):
incoming_message_class_mock = mock.Mock()
poller = pika_poller.PikaPoller( poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count, self._pika_engine, n, None, self._prefetch_count,
incoming_message_class=incoming_message_class_mock incoming_message_class=incoming_message_class_mock
) )
n = 10
params = []
for i in range(n): for i in range(n):
params.append((object(), object(), object(), object())) params.append((object(), object(), object(), object()))
index = [0] poller.start(on_incoming_callback)
def f(time_limit): for i in range(n):
poller._on_message_with_ack_callback( poller._on_message_with_ack_callback(
*params[index[0]] *params[i]
) )
index[0] += 1
self._poller_connection_mock.process_data_events.side_effect = f self.assertEqual(len(res), 1)
self.assertEqual(len(res[0]), 10)
poller.start()
res = poller.poll(batch_size=n)
self.assertEqual(len(res), n)
self.assertEqual(incoming_message_class_mock.call_count, n) self.assertEqual(incoming_message_class_mock.call_count, n)
for i in range(n): for i in range(n):
self.assertEqual(res[i], incoming_message_class_mock.return_value) self.assertEqual(res[0][i],
incoming_message_class_mock.return_value)
self.assertEqual( self.assertEqual(
incoming_message_class_mock.call_args_list[i][0], incoming_message_class_mock.call_args_list[i][0],
(self._pika_engine, self._poller_channel_mock) + params[i][1:] (self._pika_engine, self._poller_channel_mock) + params[i][1:]
@ -181,42 +153,48 @@ class PikaPollerTestCase(unittest.TestCase):
@mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller." @mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
"_declare_queue_binding") "_declare_queue_binding")
def test_poll_batch_with_timeout(self, declare_queue_binding_mock): def test_message_processing_batch_with_timeout(self,
declare_queue_binding_mock):
incoming_message_class_mock = mock.Mock() incoming_message_class_mock = mock.Mock()
poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count,
incoming_message_class=incoming_message_class_mock
)
n = 10 n = 10
timeout = 1 timeout = 1
sleep_time = 0.2
res = []
evt = threading.Event()
def on_incoming_callback(incoming):
res.append(incoming)
evt.set()
poller = pika_poller.PikaPoller(
self._pika_engine, n, timeout, self._prefetch_count,
incoming_message_class=incoming_message_class_mock
)
params = [] params = []
success_count = 5 success_count = 5
poller.start(on_incoming_callback)
for i in range(n): for i in range(n):
params.append((object(), object(), object(), object())) params.append((object(), object(), object(), object()))
index = [0] for i in range(success_count):
def f(time_limit):
time.sleep(sleep_time)
poller._on_message_with_ack_callback( poller._on_message_with_ack_callback(
*params[index[0]] *params[i]
) )
index[0] += 1
self._poller_connection_mock.process_data_events.side_effect = f self.assertTrue(evt.wait(timeout * 2))
poller.start() self.assertEqual(len(res), 1)
res = poller.poll(batch_size=n, timeout=timeout) self.assertEqual(len(res[0]), success_count)
self.assertEqual(len(res), success_count)
self.assertEqual(incoming_message_class_mock.call_count, success_count) self.assertEqual(incoming_message_class_mock.call_count, success_count)
for i in range(success_count): for i in range(success_count):
self.assertEqual(res[i], incoming_message_class_mock.return_value) self.assertEqual(res[0][i],
incoming_message_class_mock.return_value)
self.assertEqual( self.assertEqual(
incoming_message_class_mock.call_args_list[i][0], incoming_message_class_mock.call_args_list[i][0],
(self._pika_engine, self._poller_channel_mock) + params[i][1:] (self._pika_engine, self._poller_channel_mock) + params[i][1:]
@ -258,20 +236,11 @@ class RpcServicePikaPollerTestCase(unittest.TestCase):
"RpcPikaIncomingMessage") "RpcPikaIncomingMessage")
def test_declare_rpc_queue_bindings(self, rpc_pika_incoming_message_mock): def test_declare_rpc_queue_bindings(self, rpc_pika_incoming_message_mock):
poller = pika_poller.RpcServicePikaPoller( poller = pika_poller.RpcServicePikaPoller(
self._pika_engine, self._target, self._prefetch_count, self._pika_engine, self._target, 1, None,
) self._prefetch_count
self._poller_connection_mock.process_data_events.side_effect = (
lambda time_limit: poller._on_message_with_ack_callback(
None, None, None, None
)
) )
poller.start() poller.start(None)
res = poller.poll()
self.assertEqual(len(res), 1)
self.assertEqual(res[0], rpc_pika_incoming_message_mock.return_value)
self.assertTrue(self._pika_engine.create_connection.called) self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called) self.assertTrue(self._poller_connection_mock.channel.called)
@ -357,24 +326,14 @@ class RpcReplyServicePikaPollerTestCase(unittest.TestCase):
self._pika_engine.rpc_queue_expiration = 12345 self._pika_engine.rpc_queue_expiration = 12345
self._pika_engine.rpc_reply_retry_attempts = 3 self._pika_engine.rpc_reply_retry_attempts = 3
def test_start(self):
poller = pika_poller.RpcReplyPikaPoller(
self._pika_engine, self._exchange, self._queue,
self._prefetch_count,
)
poller.start()
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
def test_declare_rpc_reply_queue_binding(self): def test_declare_rpc_reply_queue_binding(self):
poller = pika_poller.RpcReplyPikaPoller( poller = pika_poller.RpcReplyPikaPoller(
self._pika_engine, self._exchange, self._queue, self._pika_engine, self._exchange, self._queue, 1, None,
self._prefetch_count, self._prefetch_count,
) )
poller.start() poller.start(None)
poller.stop()
declare_queue_binding_by_channel_mock = ( declare_queue_binding_by_channel_mock = (
self._pika_engine.declare_queue_binding_by_channel self._pika_engine.declare_queue_binding_by_channel
@ -419,26 +378,13 @@ class NotificationPikaPollerTestCase(unittest.TestCase):
) )
self._pika_engine.notification_persistence = object() self._pika_engine.notification_persistence = object()
@mock.patch("oslo_messaging._drivers.pika_driver.pika_message." def test_declare_notification_queue_bindings_default_queue(self):
"PikaIncomingMessage")
def test_declare_notification_queue_bindings_default_queue(
self, pika_incoming_message_mock):
poller = pika_poller.NotificationPikaPoller( poller = pika_poller.NotificationPikaPoller(
self._pika_engine, self._target_and_priorities, self._pika_engine, self._target_and_priorities, 1, None,
self._prefetch_count, None self._prefetch_count, None
) )
self._poller_connection_mock.process_data_events.side_effect = (
lambda time_limit: poller._on_message_with_ack_callback(
None, None, None, None
)
)
poller.start() poller.start(None)
res = poller.poll()
self.assertEqual(len(res), 1)
self.assertEqual(res[0], pika_incoming_message_mock.return_value)
self.assertTrue(self._pika_engine.create_connection.called) self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called) self.assertTrue(self._poller_connection_mock.channel.called)
@ -481,26 +427,13 @@ class NotificationPikaPollerTestCase(unittest.TestCase):
) )
)) ))
@mock.patch("oslo_messaging._drivers.pika_driver.pika_message." def test_declare_notification_queue_bindings_custom_queue(self):
"PikaIncomingMessage")
def test_declare_notification_queue_bindings_custom_queue(
self, pika_incoming_message_mock):
poller = pika_poller.NotificationPikaPoller( poller = pika_poller.NotificationPikaPoller(
self._pika_engine, self._target_and_priorities, self._pika_engine, self._target_and_priorities, 1, None,
self._prefetch_count, "custom_queue_name" self._prefetch_count, "custom_queue_name"
) )
self._poller_connection_mock.process_data_events.side_effect = (
lambda time_limit: poller._on_message_with_ack_callback(
None, None, None, None
)
)
poller.start() poller.start(None)
res = poller.poll()
self.assertEqual(len(res), 1)
self.assertEqual(res[0], pika_incoming_message_mock.return_value)
self.assertTrue(self._pika_engine.create_connection.called) self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called) self.assertTrue(self._poller_connection_mock.channel.called)