Merge "Implements pika thread safe connection"

This commit is contained in:
Jenkins 2016-04-27 20:00:25 +00:00 committed by Gerrit Code Review
commit bedd400b6a
6 changed files with 919 additions and 454 deletions

View File

@ -335,26 +335,18 @@ class PikaDriver(base.BaseDriver):
)
def listen(self, target, batch_size, batch_timeout):
listener = pika_drv_poller.RpcServicePikaPoller(
self._pika_engine, target,
prefetch_count=self._pika_engine.rpc_listener_prefetch_count
return pika_drv_poller.RpcServicePikaPoller(
self._pika_engine, target, batch_size, batch_timeout,
self._pika_engine.rpc_listener_prefetch_count
)
listener.start()
return base.PollStyleListenerAdapter(listener, batch_size,
batch_timeout)
def listen_for_notifications(self, targets_and_priorities, pool,
batch_size, batch_timeout):
listener = pika_drv_poller.NotificationPikaPoller(
self._pika_engine, targets_and_priorities,
prefetch_count=(
self._pika_engine.notification_listener_prefetch_count
),
queue_name=pool
return pika_drv_poller.NotificationPikaPoller(
self._pika_engine, targets_and_priorities, batch_size,
batch_timeout,
self._pika_engine.notification_listener_prefetch_count, pool
)
listener.start()
return base.PollStyleListenerAdapter(listener, batch_size,
batch_timeout)
def cleanup(self):
self._reply_listener.cleanup()

View File

@ -0,0 +1,497 @@
# Copyright 2016 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import collections
import logging
import os
import threading
import futurist
import pika
from pika.adapters import select_connection
from pika import exceptions as pika_exceptions
from pika import spec as pika_spec
from oslo_utils import eventletutils
current_thread = eventletutils.fetch_current_thread_functor()
LOG = logging.getLogger(__name__)
class ThreadSafePikaConnection(object):
def __init__(self, params=None):
self.params = params
self._connection_lock = threading.Lock()
self._evt_closed = threading.Event()
self._task_queue = collections.deque()
self._pending_connection_futures = set()
create_connection_future = self._register_pending_future()
def on_open_error(conn, err):
create_connection_future.set_exception(
pika_exceptions.AMQPConnectionError(err)
)
self._impl = pika.SelectConnection(
parameters=params,
on_open_callback=create_connection_future.set_result,
on_open_error_callback=on_open_error,
on_close_callback=self._on_connection_close,
stop_ioloop_on_close=False,
)
self._interrupt_pipein, self._interrupt_pipeout = os.pipe()
self._impl.ioloop.add_handler(self._interrupt_pipein,
self._impl.ioloop.read_interrupt,
select_connection.READ)
self._thread = threading.Thread(target=self._process_io)
self._thread.daemon = True
self._thread_id = None
self._thread.start()
create_connection_future.result()
def _execute_task(self, func, *args, **kwargs):
if current_thread() == self._thread_id:
return func(*args, **kwargs)
future = futurist.Future()
self._task_queue.append((func, args, kwargs, future))
if self._evt_closed.is_set():
self._notify_all_futures_connection_close()
elif self._interrupt_pipeout is not None:
os.write(self._interrupt_pipeout, b'X')
return future.result()
def _register_pending_future(self):
future = futurist.Future()
self._pending_connection_futures.add(future)
def on_done_callback(fut):
try:
self._pending_connection_futures.remove(fut)
except KeyError:
pass
future.add_done_callback(on_done_callback)
if self._evt_closed.is_set():
self._notify_all_futures_connection_close()
return future
def _notify_all_futures_connection_close(self):
while self._task_queue:
try:
method_res_future = self._task_queue.pop()[3]
except KeyError:
break
else:
method_res_future.set_exception(
pika_exceptions.ConnectionClosed()
)
while self._pending_connection_futures:
try:
pending_connection_future = (
self._pending_connection_futures.pop()
)
except KeyError:
break
else:
pending_connection_future.set_exception(
pika_exceptions.ConnectionClosed()
)
def _on_connection_close(self, conn, reply_code, reply_text):
self._evt_closed.set()
self._notify_all_futures_connection_close()
if self._interrupt_pipeout:
os.close(self._interrupt_pipeout)
os.close(self._interrupt_pipein)
def add_on_close_callback(self, callback):
return self._execute_task(self._impl.add_on_close_callback, callback)
def _do_process_io(self):
while self._task_queue:
func, args, kwargs, future = self._task_queue.pop()
try:
res = func(*args, **kwargs)
except BaseException as e:
LOG.exception(e)
future.set_exception(e)
else:
future.set_result(res)
self._impl.ioloop.poll()
self._impl.ioloop.process_timeouts()
def _process_io(self):
self._thread_id = current_thread()
while not self._evt_closed.is_set():
try:
self._do_process_io()
except BaseException:
LOG.exception("Error during processing connection's IO")
def close(self, *args, **kwargs):
res = self._execute_task(self._impl.close, *args, **kwargs)
self._evt_closed.wait()
self._thread.join()
return res
def channel(self, channel_number=None):
channel_opened_future = self._register_pending_future()
impl_channel = self._execute_task(
self._impl.channel,
on_open_callback=channel_opened_future.set_result,
channel_number=channel_number
)
# Create our proxy channel
channel = ThreadSafePikaChannel(impl_channel, self)
# Link implementation channel with our proxy channel
impl_channel._set_cookie(channel)
channel_opened_future.result()
return channel
def add_timeout(self, timeout, callback):
return self._execute_task(self._impl.add_timeout, timeout, callback)
def remove_timeout(self, timeout_id):
return self._execute_task(self._impl.remove_timeout, timeout_id)
@property
def is_closed(self):
return self._impl.is_closed
@property
def is_closing(self):
return self._impl.is_closing
@property
def is_open(self):
return self._impl.is_open
class ThreadSafePikaChannel(object): # pylint: disable=R0904,R0902
def __init__(self, channel_impl, connection):
self._impl = channel_impl
self._connection = connection
self._delivery_confirmation = False
self._message_returned = False
self._current_future = None
self._evt_closed = threading.Event()
self.add_on_close_callback(self._on_channel_close)
def _execute_task(self, func, *args, **kwargs):
return self._connection._execute_task(func, *args, **kwargs)
def _on_channel_close(self, channel, reply_code, reply_text):
self._evt_closed.set()
if self._current_future:
self._current_future.set_exception(
pika_exceptions.ChannelClosed(reply_code, reply_text))
def _on_message_confirmation(self, frame):
self._current_future.set_result(frame)
def add_on_close_callback(self, callback):
self._execute_task(self._impl.add_on_close_callback, callback)
def add_on_cancel_callback(self, callback):
self._execute_task(self._impl.add_on_cancel_callback, callback)
def __int__(self):
return self.channel_number
@property
def channel_number(self):
return self._impl.channel_number
@property
def is_closed(self):
return self._impl.is_closed
@property
def is_closing(self):
return self._impl.is_closing
@property
def is_open(self):
return self._impl.is_open
def close(self, reply_code=0, reply_text="Normal Shutdown"):
self._impl.close(reply_code=reply_code, reply_text=reply_text)
self._evt_closed.wait()
def flow(self, active):
self._current_future = futurist.Future()
self._execute_task(
self._impl.flow, callback=self._current_future.set_result,
active=active
)
return self._current_future.result()
def basic_consume(self, # pylint: disable=R0913
consumer_callback,
queue,
no_ack=False,
exclusive=False,
consumer_tag=None,
arguments=None):
self._current_future = futurist.Future()
self._execute_task(
self._impl.add_callback, self._current_future.set_result,
replies=[pika_spec.Basic.ConsumeOk], one_shot=True
)
self._impl.add_callback(self._current_future.set_result,
replies=[pika_spec.Basic.ConsumeOk],
one_shot=True)
tag = self._execute_task(
self._impl.basic_consume,
consumer_callback=consumer_callback,
queue=queue,
no_ack=no_ack,
exclusive=exclusive,
consumer_tag=consumer_tag,
arguments=arguments
)
self._current_future.result()
return tag
def basic_cancel(self, consumer_tag):
self._current_future = futurist.Future()
self._execute_task(
self._impl.basic_cancel,
callback=self._current_future.set_result,
consumer_tag=consumer_tag,
nowait=False)
self._current_future.result()
def basic_ack(self, delivery_tag=0, multiple=False):
return self._execute_task(
self._impl.basic_ack, delivery_tag=delivery_tag, multiple=multiple)
def basic_nack(self, delivery_tag=None, multiple=False, requeue=True):
return self._execute_task(
self._impl.basic_nack, delivery_tag=delivery_tag,
multiple=multiple, requeue=requeue
)
def publish(self, exchange, routing_key, body, # pylint: disable=R0913
properties=None, mandatory=False, immediate=False):
if self._delivery_confirmation:
# In publisher-acknowledgments mode
self._message_returned = False
self._current_future = futurist.Future()
self._execute_task(self._impl.basic_publish,
exchange=exchange,
routing_key=routing_key,
body=body,
properties=properties,
mandatory=mandatory,
immediate=immediate)
conf_method = self._current_future.result().method
if isinstance(conf_method, pika_spec.Basic.Nack):
raise pika_exceptions.NackError((None,))
else:
assert isinstance(conf_method, pika_spec.Basic.Ack), (
conf_method)
if self._message_returned:
raise pika_exceptions.UnroutableError((None,))
else:
# In non-publisher-acknowledgments mode
self._execute_task(self._impl.basic_publish,
exchange=exchange,
routing_key=routing_key,
body=body,
properties=properties,
mandatory=mandatory,
immediate=immediate)
def basic_qos(self, prefetch_size=0, prefetch_count=0, all_channels=False):
self._current_future = futurist.Future()
self._execute_task(self._impl.basic_qos,
callback=self._current_future.set_result,
prefetch_size=prefetch_size,
prefetch_count=prefetch_count,
all_channels=all_channels)
self._current_future.result()
def basic_recover(self, requeue=False):
self._current_future = futurist.Future()
self._execute_task(
self._impl.basic_recover,
callback=lambda: self._current_future.set_result(None),
requeue=requeue
)
self._current_future.result()
def basic_reject(self, delivery_tag=None, requeue=True):
self._execute_task(self._impl.basic_reject,
delivery_tag=delivery_tag,
requeue=requeue)
def _on_message_returned(self, *args, **kwargs):
self._message_returned = True
def confirm_delivery(self):
self._current_future = futurist.Future()
self._execute_task(self._impl.add_callback,
callback=self._current_future.set_result,
replies=[pika_spec.Confirm.SelectOk],
one_shot=True)
self._execute_task(self._impl.confirm_delivery,
callback=self._on_message_confirmation,
nowait=False)
self._current_future.result()
self._delivery_confirmation = True
self._execute_task(self._impl.add_on_return_callback,
self._on_message_returned)
def exchange_declare(self, exchange=None, # pylint: disable=R0913
exchange_type='direct', passive=False, durable=False,
auto_delete=False, internal=False,
arguments=None, **kwargs):
self._current_future = futurist.Future()
self._execute_task(self._impl.exchange_declare,
callback=self._current_future.set_result,
exchange=exchange,
exchange_type=exchange_type,
passive=passive,
durable=durable,
auto_delete=auto_delete,
internal=internal,
nowait=False,
arguments=arguments,
type=kwargs["type"] if kwargs else None)
return self._current_future.result()
def exchange_delete(self, exchange=None, if_unused=False):
self._current_future = futurist.Future()
self._execute_task(self._impl.exchange_delete,
callback=self._current_future.set_result,
exchange=exchange,
if_unused=if_unused,
nowait=False)
return self._current_future.result()
def exchange_bind(self, destination=None, source=None, routing_key='',
arguments=None):
self._current_future = futurist.Future()
self._execute_task(self._impl.exchange_bind,
callback=self._current_future.set_result,
destination=destination,
source=source,
routing_key=routing_key,
nowait=False,
arguments=arguments)
return self._current_future.result()
def exchange_unbind(self, destination=None, source=None, routing_key='',
arguments=None):
self._current_future = futurist.Future()
self._execute_task(self._impl.exchange_unbind,
callback=self._current_future.set_result,
destination=destination,
source=source,
routing_key=routing_key,
nowait=False,
arguments=arguments)
return self._current_future.result()
def queue_declare(self, queue='', passive=False, durable=False,
exclusive=False, auto_delete=False,
arguments=None):
self._current_future = futurist.Future()
self._execute_task(self._impl.queue_declare,
callback=self._current_future.set_result,
queue=queue,
passive=passive,
durable=durable,
exclusive=exclusive,
auto_delete=auto_delete,
nowait=False,
arguments=arguments)
return self._current_future.result()
def queue_delete(self, queue='', if_unused=False, if_empty=False):
self._current_future = futurist.Future()
self._execute_task(self._impl.queue_delete,
callback=self._current_future.set_result,
queue=queue,
if_unused=if_unused,
if_empty=if_empty,
nowait=False)
return self._current_future.result()
def queue_purge(self, queue=''):
self._current_future = futurist.Future()
self._execute_task(self._impl.queue_purge,
callback=self._current_future.set_result,
queue=queue,
nowait=False)
return self._current_future.result()
def queue_bind(self, queue, exchange, routing_key=None,
arguments=None):
self._current_future = futurist.Future()
self._execute_task(self._impl.queue_bind,
callback=self._current_future.set_result,
queue=queue,
exchange=exchange,
routing_key=routing_key,
nowait=False,
arguments=arguments)
return self._current_future.result()
def queue_unbind(self, queue='', exchange=None, routing_key=None,
arguments=None):
self._current_future = futurist.Future()
self._execute_task(self._impl.queue_unbind,
callback=self._current_future.set_result,
queue=queue,
exchange=exchange,
routing_key=routing_key,
arguments=arguments)
return self._current_future.result()

View File

@ -11,7 +11,7 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import random
import socket
import threading
@ -19,38 +19,18 @@ import time
from oslo_log import log as logging
import pika
from pika.adapters import select_connection
from pika import credentials as pika_credentials
import pika_pool
import uuid
from oslo_messaging._drivers.pika_driver import pika_commons as pika_drv_cmns
from oslo_messaging._drivers.pika_driver import pika_connection
from oslo_messaging._drivers.pika_driver import pika_exceptions as pika_drv_exc
LOG = logging.getLogger(__name__)
def _create_select_poller_connection_impl(
parameters, on_open_callback, on_open_error_callback,
on_close_callback, stop_ioloop_on_close):
"""Used for disabling autochoise of poller ('select', 'poll', 'epool', etc)
inside default 'SelectConnection.__init__(...)' logic. It is necessary to
force 'select' poller usage if eventlet is monkeypatched because eventlet
patches only 'select' system call
Method signature is copied form 'SelectConnection.__init__(...)', because
it is used as replacement of 'SelectConnection' class to create instances
"""
return select_connection.SelectConnection(
parameters=parameters,
on_open_callback=on_open_callback,
on_open_error_callback=on_open_error_callback,
on_close_callback=on_close_callback,
stop_ioloop_on_close=stop_ioloop_on_close,
custom_ioloop=select_connection.SelectPoller()
)
_PID = None
class _PooledConnectionWithConfirmations(pika_pool.Connection):
@ -84,10 +64,6 @@ class PikaEngine(object):
allowed_remote_exmods=None):
self.conf = conf
self._force_select_poller_use = (
pika_drv_cmns.is_eventlet_monkey_patched('select')
)
# processing rpc options
self.default_rpc_exchange = (
conf.oslo_messaging_pika.default_rpc_exchange
@ -168,7 +144,7 @@ class PikaEngine(object):
)
# initializing connection parameters for configured RabbitMQ hosts
common_pika_params = {
self._common_pika_params = {
'virtual_host': url.virtual_host,
'channel_max': self.conf.oslo_messaging_pika.channel_max,
'frame_max': self.conf.oslo_messaging_pika.frame_max,
@ -177,31 +153,18 @@ class PikaEngine(object):
'socket_timeout': self.conf.oslo_messaging_pika.socket_timeout,
}
self._connection_lock = threading.Lock()
self._connection_lock = threading.RLock()
self._pid = None
self._connection_host_param_list = []
self._connection_host_status_list = []
self._connection_host_status = {}
if not url.hosts:
raise ValueError("You should provide at least one RabbitMQ host")
for transport_host in url.hosts:
pika_params = common_pika_params.copy()
pika_params.update(
host=transport_host.hostname,
port=transport_host.port,
credentials=pika_credentials.PlainCredentials(
transport_host.username, transport_host.password
),
)
self._connection_host_param_list.append(pika_params)
self._connection_host_status_list.append({
self.HOST_CONNECTION_LAST_TRY_TIME: 0,
self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME: 0
})
self._host_list = url.hosts
self._next_connection_host_num = random.randint(
0, len(self._connection_host_param_list) - 1
self._cur_connection_host_num = random.randint(
0, len(self._host_list) - 1
)
# initializing 2 connection pools: 1st for connections without
@ -228,49 +191,37 @@ class PikaEngine(object):
_PooledConnectionWithConfirmations
)
def _next_connection_num(self):
"""Used for creating connections to different RabbitMQ nodes in
round robin order
:return: next host number to create connection to
"""
with self._connection_lock:
cur_num = self._next_connection_host_num
self._next_connection_host_num += 1
self._next_connection_host_num %= len(
self._connection_host_param_list
)
return cur_num
def create_connection(self, for_listening=False):
"""Create and return connection to any available host.
:return: created connection
:raise: ConnectionException if all hosts are not reachable
"""
host_count = len(self._connection_host_param_list)
connection_attempts = host_count
pika_next_connection_num = self._next_connection_num()
with self._connection_lock:
self._init_if_needed()
while connection_attempts > 0:
try:
return self.create_host_connection(
pika_next_connection_num, for_listening
)
except pika_pool.Connection.connectivity_errors as e:
LOG.warning("Can't establish connection to host. %s", e)
except pika_drv_exc.HostConnectionNotAllowedException as e:
LOG.warning("Connection to host is not allowed. %s", e)
host_count = len(self._host_list)
connection_attempts = host_count
connection_attempts -= 1
pika_next_connection_num += 1
pika_next_connection_num %= host_count
while connection_attempts > 0:
self._cur_connection_host_num += 1
self._cur_connection_host_num %= host_count
try:
return self.create_host_connection(
self._cur_connection_host_num, for_listening
)
except pika_pool.Connection.connectivity_errors as e:
LOG.warning("Can't establish connection to host. %s", e)
except pika_drv_exc.HostConnectionNotAllowedException as e:
LOG.warning("Connection to host is not allowed. %s", e)
raise pika_drv_exc.EstablishConnectionException(
"Can not establish connection to any configured RabbitMQ host: " +
str(self._connection_host_param_list)
)
connection_attempts -= 1
raise pika_drv_exc.EstablishConnectionException(
"Can not establish connection to any configured RabbitMQ "
"host: " + str(self._host_list)
)
def _set_tcp_user_timeout(self, s):
if not self._tcp_user_timeout:
@ -285,28 +236,64 @@ class PikaEngine(object):
"Whoops, this kernel doesn't seem to support TCP_USER_TIMEOUT."
)
def _init_if_needed(self):
global _PID
cur_pid = os.getpid()
if _PID != cur_pid:
if _PID:
LOG.warning("New pid is detected. Old: %s, new: %s. "
"Cleaning up...", _PID, cur_pid)
# Note(dukhlov): we need to force select poller usage in case when
# 'thread' module is monkey patched becase current eventlet
# implementation does not support patching of poll/epoll/kqueue
if pika_drv_cmns.is_eventlet_monkey_patched("thread"):
from pika.adapters import select_connection
select_connection.SELECT_TYPE = "select"
_PID = cur_pid
def create_host_connection(self, host_index, for_listening=False):
"""Create new connection to host #host_index
:param host_index: Integer, number of host for connection establishing
:param for_listening: Boolean, creates connection for listening
(enable heartbeats) if True
if True
:return: New connection
"""
connection_params = pika.ConnectionParameters(
heartbeat_interval=(
self._heartbeat_interval if for_listening else None
),
**self._connection_host_param_list[host_index]
)
with self._connection_lock:
self._init_if_needed()
host = self._host_list[host_index]
connection_params = pika.ConnectionParameters(
host=host.hostname,
port=host.port,
credentials=pika_credentials.PlainCredentials(
host.username, host.password
),
heartbeat_interval=(
self._heartbeat_interval if for_listening else None
),
**self._common_pika_params
)
cur_time = time.time()
last_success_time = self._connection_host_status_list[host_index][
host_connection_status = self._connection_host_status.get(host)
if host_connection_status is None:
host_connection_status = {
self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME: 0,
self.HOST_CONNECTION_LAST_TRY_TIME: 0
}
self._connection_host_status[host] = host_connection_status
last_success_time = host_connection_status[
self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME
]
last_time = self._connection_host_status_list[host_index][
last_time = host_connection_status[
self.HOST_CONNECTION_LAST_TRY_TIME
]
@ -322,27 +309,25 @@ class PikaEngine(object):
)
try:
connection = pika.BlockingConnection(
parameters=connection_params,
_impl_class=(_create_select_poller_connection_impl
if self._force_select_poller_use else None)
)
# It is needed for pika_pool library which expects that
# connections has params attribute defined in BaseConnection
# but BlockingConnection is not derived from BaseConnection
# and doesn't have it
connection.params = connection_params
if for_listening:
connection = pika_connection.ThreadSafePikaConnection(
params=connection_params
)
else:
connection = pika.BlockingConnection(
parameters=connection_params
)
connection.params = connection_params
self._set_tcp_user_timeout(connection._impl.socket)
self._connection_host_status_list[host_index][
self._connection_host_status[host][
self.HOST_CONNECTION_LAST_SUCCESS_TRY_TIME
] = cur_time
return connection
finally:
self._connection_host_status_list[host_index][
self._connection_host_status[host][
self.HOST_CONNECTION_LAST_TRY_TIME
] = cur_time

View File

@ -29,6 +29,7 @@ class RpcReplyPikaListener(object):
"""
def __init__(self, pika_engine):
super(RpcReplyPikaListener, self).__init__()
self._pika_engine = pika_engine
# preparing poller for listening replies
@ -39,7 +40,6 @@ class RpcReplyPikaListener(object):
self._reply_consumer_initialized = False
self._reply_consumer_initialization_lock = threading.Lock()
self._poller_thread = None
self._shutdown = False
def get_reply_qname(self):
@ -66,51 +66,31 @@ class RpcReplyPikaListener(object):
# initialize reply poller if needed
if self._reply_poller is None:
self._reply_poller = pika_drv_poller.RpcReplyPikaPoller(
pika_engine=self._pika_engine,
exchange=self._pika_engine.rpc_reply_exchange,
queue=self._reply_queue,
prefetch_count=(
self._pika_engine.rpc_reply_listener_prefetch_count
)
self._pika_engine, self._pika_engine.rpc_reply_exchange,
self._reply_queue, 1, None,
self._pika_engine.rpc_reply_listener_prefetch_count
)
self._reply_poller.start()
# start reply poller job thread if needed
if self._poller_thread is None:
self._poller_thread = threading.Thread(target=self._poller)
self._poller_thread.daemon = True
if not self._poller_thread.is_alive():
self._poller_thread.start()
self._reply_poller.start(self._on_incoming)
self._reply_consumer_initialized = True
return self._reply_queue
def _poller(self):
def _on_incoming(self, incoming):
"""Reply polling job. Poll replies in infinite loop and notify
registered features
"""
while True:
for message in incoming:
try:
messages = self._reply_poller.poll()
if not messages and self._shutdown:
break
for message in messages:
try:
message.acknowledge()
future = self._reply_waiting_futures.pop(
message.msg_id, None
)
if future is not None:
future.set_result(message)
except Exception:
LOG.exception("Unexpected exception during processing"
"reply message")
except BaseException:
LOG.exception("Unexpected exception during reply polling")
message.acknowledge()
future = self._reply_waiting_futures.pop(
message.msg_id, None
)
if future is not None:
future.set_result(message)
except Exception:
LOG.exception("Unexpected exception during processing"
"reply message")
def register_reply_waiter(self, msg_id):
"""Register reply waiter. Should be called before message sending to
@ -140,9 +120,4 @@ class RpcReplyPikaListener(object):
self._reply_poller.cleanup()
self._reply_poller = None
if self._poller_thread:
if self._poller_thread.is_alive():
self._poller_thread.join()
self._poller_thread = None
self._reply_queue = None

View File

@ -13,11 +13,9 @@
# under the License.
import threading
import time
from oslo_log import log as logging
from oslo_utils import timeutils
import six
from oslo_service import loopingcall
from oslo_messaging._drivers import base
from oslo_messaging._drivers.pika_driver import pika_commons as pika_drv_cmns
@ -27,45 +25,188 @@ from oslo_messaging._drivers.pika_driver import pika_message as pika_drv_msg
LOG = logging.getLogger(__name__)
class PikaPoller(base.PollStyleListener):
class PikaPoller(base.Listener):
"""Provides user friendly functionality for RabbitMQ message consuming,
handles low level connectivity problems and restore connection if some
connectivity related problem detected
"""
def __init__(self, pika_engine, prefetch_count, incoming_message_class):
def __init__(self, pika_engine, batch_size, batch_timeout, prefetch_count,
incoming_message_class):
"""Initialize required fields
:param pika_engine: PikaEngine, shared object with configuration and
shared driver functionality
:param batch_size: desired number of messages passed to
single on_incoming_callback call
:param batch_timeout: defines how long should we wait for batch_size
messages if we already have some messages waiting for processing
:param prefetch_count: Integer, maximum count of unacknowledged
messages which RabbitMQ broker sends to this consumer
:param incoming_message_class: PikaIncomingMessage, wrapper for
consumed RabbitMQ message
"""
super(PikaPoller, self).__init__(prefetch_count)
super(PikaPoller, self).__init__(batch_size, batch_timeout,
prefetch_count)
self._pika_engine = pika_engine
self._incoming_message_class = incoming_message_class
self._connection = None
self._channel = None
self._lock = threading.Lock()
self._recover_loopingcall = None
self._lock = threading.RLock()
self._cur_batch_buffer = None
self._cur_batch_timeout_id = None
self._started = False
self._closing_connection_by_poller = False
self._queues_to_consume = None
self._message_queue = []
def _on_connection_close(self, connection, reply_code, reply_text):
self._deliver_cur_batch()
if self._closing_connection_by_poller:
return
with self._lock:
self._connection = None
self._start_recover_consuming_task()
def _reconnect(self):
def _on_channel_close(self, channel, reply_code, reply_text):
if self._cur_batch_buffer:
self._cur_batch_buffer = [
message for message in self._cur_batch_buffer
if not message.need_ack()
]
if self._closing_connection_by_poller:
return
with self._lock:
self._channel = None
self._start_recover_consuming_task()
def _on_consumer_cancel(self, method_frame):
with self._lock:
if self._queues_to_consume:
consumer_tag = method_frame.method.consumer_tag
for queue_info in self._queues_to_consume:
if queue_info["consumer_tag"] == consumer_tag:
queue_info["consumer_tag"] = None
self._start_recover_consuming_task()
def _on_message_no_ack_callback(self, unused, method, properties, body):
"""Is called by Pika when message was received from queue listened with
no_ack=True mode
"""
incoming_message = self._incoming_message_class(
self._pika_engine, None, method, properties, body
)
self._on_incoming_message(incoming_message)
def _on_message_with_ack_callback(self, unused, method, properties, body):
"""Is called by Pika when message was received from queue listened with
no_ack=False mode
"""
incoming_message = self._incoming_message_class(
self._pika_engine, self._channel, method, properties, body
)
self._on_incoming_message(incoming_message)
def _deliver_cur_batch(self):
if self._cur_batch_timeout_id is not None:
self._connection.remove_timeout(self._cur_batch_timeout_id)
self._cur_batch_timeout_id = None
if self._cur_batch_buffer:
buf_to_send = self._cur_batch_buffer
self._cur_batch_buffer = None
try:
self.on_incoming_callback(buf_to_send)
except Exception:
LOG.exception("Unexpected exception during incoming delivery")
def _on_incoming_message(self, incoming_message):
if self._cur_batch_buffer is None:
self._cur_batch_buffer = [incoming_message]
else:
self._cur_batch_buffer.append(incoming_message)
if len(self._cur_batch_buffer) >= self.batch_size:
self._deliver_cur_batch()
return
if self._cur_batch_timeout_id is None:
self._cur_batch_timeout_id = self._connection.add_timeout(
self.batch_timeout, self._deliver_cur_batch)
def _start_recover_consuming_task(self):
"""Start async job for checking connection to the broker."""
if self._recover_loopingcall is None and self._started:
self._recover_loopingcall = (
loopingcall.DynamicLoopingCall(
self._try_recover_consuming
)
)
LOG.info("Starting recover consuming job for listener: %s", self)
self._recover_loopingcall.start()
def _try_recover_consuming(self):
with self._lock:
try:
if self._started:
self._start_or_recover_consuming()
except pika_drv_exc.EstablishConnectionException as e:
LOG.warning(
"Problem during establishing connection for pika "
"poller %s", e, exc_info=True
)
return self._pika_engine.host_connection_reconnect_delay
except pika_drv_exc.ConnectionException as e:
LOG.warning(
"Connectivity exception during starting/recovering pika "
"poller %s", e, exc_info=True
)
except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as e:
LOG.warning(
"Connectivity exception during starting/recovering pika "
"poller %s", e, exc_info=True
)
except BaseException:
# NOTE (dukhlov): I preffer to use here BaseException because
# if this method raise such exception LoopingCall stops
# execution Probably it should never happen and Exception
# should be enough but in case of programmer mistake it could
# be and it is potentially hard to catch problem if we will
# stop background task. It is better when it continue to work
# and write a lot of LOG with this error
LOG.exception("Unexpected exception during "
"starting/recovering pika poller")
else:
self._recover_loopingcall = None
LOG.info("Recover consuming job was finished for listener: %s",
self)
raise loopingcall.LoopingCallDone(True)
return 0
def _start_or_recover_consuming(self):
"""Performs reconnection to the broker. It is unsafe method for
internal use only
"""
self._connection = self._pika_engine.create_connection(
for_listening=True
)
self._channel = self._connection.channel()
self._channel.basic_qos(prefetch_count=self.prefetch_size)
if self._connection is None or not self._connection.is_open:
self._connection = self._pika_engine.create_connection(
for_listening=True
)
self._connection.add_on_close_callback(self._on_connection_close)
self._channel = None
if self._channel is None or not self._channel.is_open:
if self._queues_to_consume:
for queue_info in self._queues_to_consume:
queue_info["consumer_tag"] = None
self._channel = self._connection.channel()
self._channel.add_on_close_callback(self._on_channel_close)
self._channel.add_on_cancel_callback(self._on_consumer_cancel)
self._channel.basic_qos(prefetch_count=self.prefetch_size)
if self._queues_to_consume is None:
self._queues_to_consume = self._declare_queue_binding()
@ -92,6 +233,8 @@ class PikaPoller(base.PollStyleListener):
try:
for queue_info in self._queues_to_consume:
if queue_info["consumer_tag"] is not None:
continue
no_ack = queue_info["no_ack"]
on_message_callback = (
@ -120,168 +263,95 @@ class PikaPoller(base.PollStyleListener):
self._channel.basic_cancel(consumer_tag)
queue_info["consumer_tag"] = None
def _on_message_no_ack_callback(self, unused, method, properties, body):
"""Is called by Pika when message was received from queue listened with
no_ack=True mode
"""
self._message_queue.append(
self._incoming_message_class(
self._pika_engine, None, method, properties, body
)
)
def _on_message_with_ack_callback(self, unused, method, properties, body):
"""Is called by Pika when message was received from queue listened with
no_ack=False mode
"""
self._message_queue.append(
self._incoming_message_class(
self._pika_engine, self._channel, method, properties, body
)
)
def _cleanup(self):
"""Cleanup allocated resources (channel, connection, etc). It is unsafe
method for internal use only
"""
if self._connection:
try:
self._connection.close()
except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS:
# expected errors
pass
except Exception:
LOG.exception("Unexpected error during closing connection")
finally:
self._channel = None
self._connection = None
for i in six.moves.range(len(self._message_queue) - 1, -1, -1):
message = self._message_queue[i]
if message.need_ack():
del self._message_queue[i]
@base.batch_poll_helper
def poll(self, timeout=None):
"""Main method of this class - consumes message from RabbitMQ
:param: timeout: float, seconds, timeout for waiting new incoming
message, None means wait forever
:return: list of PikaIncomingMessage, RabbitMQ messages
"""
with timeutils.StopWatch(timeout) as stop_watch:
while True:
with self._lock:
if self._message_queue:
return self._message_queue.pop(0)
if stop_watch.expired():
return None
try:
if self._started:
if self._channel is None:
self._reconnect()
# we need some time_limit here, not too small to
# avoid a lot of not needed iterations but not too
# large to release lock time to time and give a
# chance to perform another method waiting this
# lock
self._connection.process_data_events(
time_limit=0.25
)
else:
# consumer is stopped so we don't expect new
# messages, just process already sent events
if self._channel is not None:
self._connection.process_data_events(
time_limit=0
)
# and return if we don't see new messages
if not self._message_queue:
return None
except pika_drv_exc.EstablishConnectionException as e:
LOG.warning(
"Problem during establishing connection for pika "
"poller %s", e, exc_info=True
)
time.sleep(
self._pika_engine.host_connection_reconnect_delay
)
except pika_drv_exc.ConnectionException:
self._cleanup()
raise
except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS:
self._cleanup()
raise
def start(self):
def start(self, on_incoming_callback):
"""Starts poller. Should be called before polling to allow message
consuming
:param on_incoming_callback: callback function to be executed when
listener received messages. Messages should be processed and
acked/nacked by callback
"""
super(PikaPoller, self).start(on_incoming_callback)
with self._lock:
if self._started:
return
connected = False
try:
self._reconnect()
self._start_or_recover_consuming()
except pika_drv_exc.EstablishConnectionException as exc:
LOG.warning(
"Can not establish connection during pika poller's "
"start(). Connecting is required during first poll() "
"call. %s", exc, exc_info=True
"start(). %s", exc, exc_info=True
)
except pika_drv_exc.ConnectionException as exc:
self._cleanup()
LOG.warning(
"Connectivity problem during pika poller's start(). "
"Reconnecting required during first poll() call. %s",
"Connectivity problem during pika poller's start(). %s",
exc, exc_info=True
)
except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as exc:
self._cleanup()
LOG.warning(
"Connectivity problem during pika poller's start(). "
"Reconnecting required during first poll() call. %s",
"Connectivity problem during pika poller's start(). %s",
exc, exc_info=True
)
else:
connected = True
self._started = True
if not connected:
self._start_recover_consuming_task()
def stop(self):
"""Stops poller. Should be called when polling is not needed anymore to
stop new message consuming. After that it is necessary to poll already
prefetched messages
"""
super(PikaPoller, self).stop()
with self._lock:
if not self._started:
return
if self._queues_to_consume and self._channel:
if self._recover_loopingcall is not None:
self._recover_loopingcall.stop()
self._recover_loopingcall = None
if (self._queues_to_consume and self._channel and
self._channel.is_open):
try:
self._stop_consuming()
except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS as exc:
self._cleanup()
LOG.warning(
"Connectivity problem detected during consumer "
"cancellation. %s", exc, exc_info=True
)
self._deliver_cur_batch()
self._started = False
def cleanup(self):
"""Safe version of _cleanup. Cleans up allocated resources (channel,
connection, etc).
"""
"""Cleanup allocated resources (channel, connection, etc)."""
with self._lock:
self._cleanup()
if self._connection and self._connection.is_open:
try:
self._closing_connection_by_poller = True
self._connection.close()
self._closing_connection_by_poller = False
except pika_drv_cmns.PIKA_CONNECTIVITY_ERRORS:
# expected errors
pass
except Exception:
LOG.exception("Unexpected error during closing connection")
finally:
self._channel = None
self._connection = None
class RpcServicePikaPoller(PikaPoller):
"""PikaPoller implementation for polling RPC messages. Overrides base
functionality according to RPC specific
"""
def __init__(self, pika_engine, target, prefetch_count):
def __init__(self, pika_engine, target, batch_size, batch_timeout,
prefetch_count):
"""Adds target parameter for declaring RPC specific exchanges and
queues
@ -289,14 +359,18 @@ class RpcServicePikaPoller(PikaPoller):
shared driver functionality
:param target: Target, oslo.messaging Target object which defines RPC
endpoint
:param batch_size: desired number of messages passed to
single on_incoming_callback call
:param batch_timeout: defines how long should we wait for batch_size
messages if we already have some messages waiting for processing
:param prefetch_count: Integer, maximum count of unacknowledged
messages which RabbitMQ broker sends to this consumer
"""
self._target = target
super(RpcServicePikaPoller, self).__init__(
pika_engine, prefetch_count=prefetch_count,
incoming_message_class=pika_drv_msg.RpcPikaIncomingMessage
pika_engine, batch_size, batch_timeout, prefetch_count,
pika_drv_msg.RpcPikaIncomingMessage
)
def _declare_queue_binding(self):
@ -363,7 +437,8 @@ class RpcReplyPikaPoller(PikaPoller):
"""PikaPoller implementation for polling RPC reply messages. Overrides
base functionality according to RPC reply specific
"""
def __init__(self, pika_engine, exchange, queue, prefetch_count):
def __init__(self, pika_engine, exchange, queue, batch_size, batch_timeout,
prefetch_count):
"""Adds exchange and queue parameter for declaring exchange and queue
used for RPC reply delivery
@ -371,6 +446,10 @@ class RpcReplyPikaPoller(PikaPoller):
shared driver functionality
:param exchange: String, exchange name used for RPC reply delivery
:param queue: String, queue name used for RPC reply delivery
:param batch_size: desired number of messages passed to
single on_incoming_callback call
:param batch_timeout: defines how long should we wait for batch_size
messages if we already have some messages waiting for processing
:param prefetch_count: Integer, maximum count of unacknowledged
messages which RabbitMQ broker sends to this consumer
"""
@ -378,8 +457,8 @@ class RpcReplyPikaPoller(PikaPoller):
self._queue = queue
super(RpcReplyPikaPoller, self).__init__(
pika_engine=pika_engine, prefetch_count=prefetch_count,
incoming_message_class=pika_drv_msg.RpcReplyPikaIncomingMessage
pika_engine, batch_size, batch_timeout, prefetch_count,
pika_drv_msg.RpcReplyPikaIncomingMessage
)
def _declare_queue_binding(self):
@ -404,8 +483,8 @@ class NotificationPikaPoller(PikaPoller):
"""PikaPoller implementation for polling Notification messages. Overrides
base functionality according to Notification specific
"""
def __init__(self, pika_engine, targets_and_priorities, prefetch_count,
queue_name=None):
def __init__(self, pika_engine, targets_and_priorities,
batch_size, batch_timeout, prefetch_count, queue_name=None):
"""Adds targets_and_priorities and queue_name parameter
for declaring exchanges and queues used for notification delivery
@ -413,6 +492,10 @@ class NotificationPikaPoller(PikaPoller):
shared driver functionality
:param targets_and_priorities: list of (target, priority), defines
default queue names for corresponding notification types
:param batch_size: desired number of messages passed to
single on_incoming_callback call
:param batch_timeout: defines how long should we wait for batch_size
messages if we already have some messages waiting for processing
:param prefetch_count: Integer, maximum count of unacknowledged
messages which RabbitMQ broker sends to this consumer
:param queue: String, alternative queue name used for this poller
@ -422,8 +505,8 @@ class NotificationPikaPoller(PikaPoller):
self._queue_name = queue_name
super(NotificationPikaPoller, self).__init__(
pika_engine, prefetch_count=prefetch_count,
incoming_message_class=pika_drv_msg.PikaIncomingMessage
pika_engine, batch_size, batch_timeout, prefetch_count,
pika_drv_msg.PikaIncomingMessage
)
def _declare_queue_binding(self):
@ -442,7 +525,7 @@ class NotificationPikaPoller(PikaPoller):
target.exchange or
self._pika_engine.default_notification_exchange
),
queue = queue,
queue=queue,
routing_key=routing_key,
exchange_type='direct',
queue_expiration=None,

View File

@ -13,9 +13,11 @@
# under the License.
import socket
import threading
import time
import unittest
from concurrent import futures
import mock
from oslo_messaging._drivers.pika_driver import pika_poller
@ -32,26 +34,53 @@ class PikaPollerTestCase(unittest.TestCase):
self._pika_engine.create_connection.return_value = (
self._poller_connection_mock
)
self._executor = futures.ThreadPoolExecutor(1)
def timer_task(timeout, callback):
time.sleep(timeout)
callback()
self._poller_connection_mock.add_timeout.side_effect = (
lambda *args: self._executor.submit(timer_task, *args)
)
self._prefetch_count = 123
def test_start_when_connection_unavailable(self):
incoming_message_class_mock = mock.Mock()
@mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
"_declare_queue_binding")
def test_start(self, declare_queue_binding_mock):
poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count,
incoming_message_class=incoming_message_class_mock
self._pika_engine, 1, None, self._prefetch_count, None
)
poller.start(None)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
self.assertTrue(declare_queue_binding_mock.called)
def test_start_when_connection_unavailable(self):
poller = pika_poller.PikaPoller(
self._pika_engine, 1, None, self._prefetch_count, None
)
self._pika_engine.create_connection.side_effect = socket.timeout()
# start() should not raise socket.timeout exception
poller.start()
poller.start(None)
@mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
"_declare_queue_binding")
def test_poll(self, declare_queue_binding_mock):
def test_message_processing(self, declare_queue_binding_mock):
res = []
def on_incoming_callback(incoming):
res.append(incoming)
incoming_message_class_mock = mock.Mock()
poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count,
self._pika_engine, 1, None, self._prefetch_count,
incoming_message_class=incoming_message_class_mock
)
unused = object()
@ -59,18 +88,14 @@ class PikaPollerTestCase(unittest.TestCase):
properties = object()
body = object()
self._poller_connection_mock.process_data_events.side_effect = (
lambda time_limit: poller._on_message_with_ack_callback(
unused, method, properties, body
)
poller.start(on_incoming_callback)
poller._on_message_with_ack_callback(
unused, method, properties, body
)
poller.start()
res = poller.poll()
self.assertEqual(len(res), 1)
self.assertEqual(res[0], incoming_message_class_mock.return_value)
self.assertEqual(res[0], [incoming_message_class_mock.return_value])
incoming_message_class_mock.assert_called_once_with(
self._pika_engine, self._poller_channel_mock, method, properties,
body
@ -83,92 +108,39 @@ class PikaPollerTestCase(unittest.TestCase):
@mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
"_declare_queue_binding")
def test_poll_after_stop(self, declare_queue_binding_mock):
def test_message_processing_batch(self, declare_queue_binding_mock):
incoming_message_class_mock = mock.Mock()
poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count,
incoming_message_class=incoming_message_class_mock
)
n = 10
params = []
for i in range(n):
params.append((object(), object(), object(), object()))
res = []
def f(time_limit):
if poller._started:
for k in range(n):
poller._on_message_no_ack_callback(
*params[k]
)
def on_incoming_callback(incoming):
res.append(incoming)
self._poller_connection_mock.process_data_events.side_effect = f
poller.start()
res = poller.poll(batch_size=1)
self.assertEqual(len(res), 1)
self.assertEqual(res[0], incoming_message_class_mock.return_value)
self.assertEqual(
incoming_message_class_mock.call_args_list[0][0],
(self._pika_engine, None) + params[0][1:]
)
poller.stop()
res2 = poller.poll(batch_size=n)
self.assertEqual(len(res2), n - 1)
self.assertEqual(incoming_message_class_mock.call_count, n)
self.assertEqual(
self._poller_connection_mock.process_data_events.call_count, 2)
for i in range(n - 1):
self.assertEqual(res2[i], incoming_message_class_mock.return_value)
self.assertEqual(
incoming_message_class_mock.call_args_list[i + 1][0],
(self._pika_engine, None) + params[i + 1][1:]
)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
self.assertTrue(declare_queue_binding_mock.called)
@mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
"_declare_queue_binding")
def test_poll_batch(self, declare_queue_binding_mock):
incoming_message_class_mock = mock.Mock()
poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count,
self._pika_engine, n, None, self._prefetch_count,
incoming_message_class=incoming_message_class_mock
)
n = 10
params = []
for i in range(n):
params.append((object(), object(), object(), object()))
index = [0]
poller.start(on_incoming_callback)
def f(time_limit):
for i in range(n):
poller._on_message_with_ack_callback(
*params[index[0]]
*params[i]
)
index[0] += 1
self._poller_connection_mock.process_data_events.side_effect = f
poller.start()
res = poller.poll(batch_size=n)
self.assertEqual(len(res), n)
self.assertEqual(len(res), 1)
self.assertEqual(len(res[0]), 10)
self.assertEqual(incoming_message_class_mock.call_count, n)
for i in range(n):
self.assertEqual(res[i], incoming_message_class_mock.return_value)
self.assertEqual(res[0][i],
incoming_message_class_mock.return_value)
self.assertEqual(
incoming_message_class_mock.call_args_list[i][0],
(self._pika_engine, self._poller_channel_mock) + params[i][1:]
@ -181,42 +153,48 @@ class PikaPollerTestCase(unittest.TestCase):
@mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
"_declare_queue_binding")
def test_poll_batch_with_timeout(self, declare_queue_binding_mock):
def test_message_processing_batch_with_timeout(self,
declare_queue_binding_mock):
incoming_message_class_mock = mock.Mock()
poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count,
incoming_message_class=incoming_message_class_mock
)
n = 10
timeout = 1
sleep_time = 0.2
res = []
evt = threading.Event()
def on_incoming_callback(incoming):
res.append(incoming)
evt.set()
poller = pika_poller.PikaPoller(
self._pika_engine, n, timeout, self._prefetch_count,
incoming_message_class=incoming_message_class_mock
)
params = []
success_count = 5
poller.start(on_incoming_callback)
for i in range(n):
params.append((object(), object(), object(), object()))
index = [0]
def f(time_limit):
time.sleep(sleep_time)
for i in range(success_count):
poller._on_message_with_ack_callback(
*params[index[0]]
*params[i]
)
index[0] += 1
self._poller_connection_mock.process_data_events.side_effect = f
self.assertTrue(evt.wait(timeout * 2))
poller.start()
res = poller.poll(batch_size=n, timeout=timeout)
self.assertEqual(len(res), success_count)
self.assertEqual(len(res), 1)
self.assertEqual(len(res[0]), success_count)
self.assertEqual(incoming_message_class_mock.call_count, success_count)
for i in range(success_count):
self.assertEqual(res[i], incoming_message_class_mock.return_value)
self.assertEqual(res[0][i],
incoming_message_class_mock.return_value)
self.assertEqual(
incoming_message_class_mock.call_args_list[i][0],
(self._pika_engine, self._poller_channel_mock) + params[i][1:]
@ -258,20 +236,11 @@ class RpcServicePikaPollerTestCase(unittest.TestCase):
"RpcPikaIncomingMessage")
def test_declare_rpc_queue_bindings(self, rpc_pika_incoming_message_mock):
poller = pika_poller.RpcServicePikaPoller(
self._pika_engine, self._target, self._prefetch_count,
)
self._poller_connection_mock.process_data_events.side_effect = (
lambda time_limit: poller._on_message_with_ack_callback(
None, None, None, None
)
self._pika_engine, self._target, 1, None,
self._prefetch_count
)
poller.start()
res = poller.poll()
self.assertEqual(len(res), 1)
self.assertEqual(res[0], rpc_pika_incoming_message_mock.return_value)
poller.start(None)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
@ -357,24 +326,14 @@ class RpcReplyServicePikaPollerTestCase(unittest.TestCase):
self._pika_engine.rpc_queue_expiration = 12345
self._pika_engine.rpc_reply_retry_attempts = 3
def test_start(self):
poller = pika_poller.RpcReplyPikaPoller(
self._pika_engine, self._exchange, self._queue,
self._prefetch_count,
)
poller.start()
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
def test_declare_rpc_reply_queue_binding(self):
poller = pika_poller.RpcReplyPikaPoller(
self._pika_engine, self._exchange, self._queue,
self._pika_engine, self._exchange, self._queue, 1, None,
self._prefetch_count,
)
poller.start()
poller.start(None)
poller.stop()
declare_queue_binding_by_channel_mock = (
self._pika_engine.declare_queue_binding_by_channel
@ -419,26 +378,13 @@ class NotificationPikaPollerTestCase(unittest.TestCase):
)
self._pika_engine.notification_persistence = object()
@mock.patch("oslo_messaging._drivers.pika_driver.pika_message."
"PikaIncomingMessage")
def test_declare_notification_queue_bindings_default_queue(
self, pika_incoming_message_mock):
def test_declare_notification_queue_bindings_default_queue(self):
poller = pika_poller.NotificationPikaPoller(
self._pika_engine, self._target_and_priorities,
self._pika_engine, self._target_and_priorities, 1, None,
self._prefetch_count, None
)
self._poller_connection_mock.process_data_events.side_effect = (
lambda time_limit: poller._on_message_with_ack_callback(
None, None, None, None
)
)
poller.start()
res = poller.poll()
self.assertEqual(len(res), 1)
self.assertEqual(res[0], pika_incoming_message_mock.return_value)
poller.start(None)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
@ -481,26 +427,13 @@ class NotificationPikaPollerTestCase(unittest.TestCase):
)
))
@mock.patch("oslo_messaging._drivers.pika_driver.pika_message."
"PikaIncomingMessage")
def test_declare_notification_queue_bindings_custom_queue(
self, pika_incoming_message_mock):
def test_declare_notification_queue_bindings_custom_queue(self):
poller = pika_poller.NotificationPikaPoller(
self._pika_engine, self._target_and_priorities,
self._pika_engine, self._target_and_priorities, 1, None,
self._prefetch_count, "custom_queue_name"
)
self._poller_connection_mock.process_data_events.side_effect = (
lambda time_limit: poller._on_message_with_ack_callback(
None, None, None, None
)
)
poller.start()
res = poller.poll()
self.assertEqual(len(res), 1)
self.assertEqual(res[0], pika_incoming_message_mock.return_value)
poller.start(None)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)