Adds unit tests for pika_poll module

Change-Id: I69cc0e0302382ab45ba464bb5993300d44679106
This commit is contained in:
Dmitriy Ukhlov 2015-12-14 11:36:28 +02:00 committed by dukhlov
parent 5149461fd2
commit 83a08d4b7e
5 changed files with 605 additions and 74 deletions

View File

@ -200,6 +200,9 @@ class PikaEngine(object):
self._connection_host_param_list = [] self._connection_host_param_list = []
self._connection_host_status_list = [] self._connection_host_status_list = []
if not url.hosts:
raise ValueError("You should provide at least one RabbitMQ host")
for transport_host in url.hosts: for transport_host in url.hosts:
pika_params = common_pika_params.copy() pika_params = common_pika_params.copy()
pika_params.update( pika_params.update(

View File

@ -72,18 +72,17 @@ class PikaIncomingMessage(object):
information from RabbitMQ message and provide access to it information from RabbitMQ message and provide access to it
""" """
def __init__(self, pika_engine, channel, method, properties, body, no_ack): def __init__(self, pika_engine, channel, method, properties, body):
"""Parse RabbitMQ message """Parse RabbitMQ message
:param pika_engine: PikaEngine, shared object with configuration and :param pika_engine: PikaEngine, shared object with configuration and
shared driver functionality shared driver functionality
:param channel: Channel, RabbitMQ channel which was used for :param channel: Channel, RabbitMQ channel which was used for
this message delivery this message delivery, used for sending ack back.
If None - ack is not required
:param method: Method, RabbitMQ message method :param method: Method, RabbitMQ message method
:param properties: Properties, RabbitMQ message properties :param properties: Properties, RabbitMQ message properties
:param body: Bytes, RabbitMQ message body :param body: Bytes, RabbitMQ message body
:param no_ack: Boolean, defines should this message be acked by
consumer or not
""" """
headers = getattr(properties, "headers", {}) headers = getattr(properties, "headers", {})
version = headers.get(_VERSION_HEADER, None) version = headers.get(_VERSION_HEADER, None)
@ -93,7 +92,6 @@ class PikaIncomingMessage(object):
"{}".format(version, _VERSION)) "{}".format(version, _VERSION))
self._pika_engine = pika_engine self._pika_engine = pika_engine
self._no_ack = no_ack
self._channel = channel self._channel = channel
self._delivery_tag = method.delivery_tag self._delivery_tag = method.delivery_tag
@ -128,12 +126,15 @@ class PikaIncomingMessage(object):
self.message = message_dict self.message = message_dict
self.ctxt = context_dict self.ctxt = context_dict
def need_ack(self):
return self._channel is not None
def acknowledge(self): def acknowledge(self):
"""Ack the message. Should be called by message processing logic when """Ack the message. Should be called by message processing logic when
it considered as consumed (means that we don't need redelivery of this it considered as consumed (means that we don't need redelivery of this
message anymore) message anymore)
""" """
if not self._no_ack: if self.need_ack():
self._channel.basic_ack(delivery_tag=self._delivery_tag) self._channel.basic_ack(delivery_tag=self._delivery_tag)
def requeue(self): def requeue(self):
@ -141,7 +142,7 @@ class PikaIncomingMessage(object):
when it can not process the message right now and should be redelivered when it can not process the message right now and should be redelivered
later if it is possible later if it is possible
""" """
if not self._no_ack: if self.need_ack():
return self._channel.basic_nack(delivery_tag=self._delivery_tag, return self._channel.basic_nack(delivery_tag=self._delivery_tag,
requeue=True) requeue=True)
@ -152,22 +153,21 @@ class RpcPikaIncomingMessage(PikaIncomingMessage):
method added to allow consumer to send RPC reply back to the RPC client method added to allow consumer to send RPC reply back to the RPC client
""" """
def __init__(self, pika_engine, channel, method, properties, body, no_ack): def __init__(self, pika_engine, channel, method, properties, body):
"""Defines default values of msg_id and reply_q fields and just call """Defines default values of msg_id and reply_q fields and just call
super.__init__ method super.__init__ method
:param pika_engine: PikaEngine, shared object with configuration and :param pika_engine: PikaEngine, shared object with configuration and
shared driver functionality shared driver functionality
:param channel: Channel, RabbitMQ channel which was used for :param channel: Channel, RabbitMQ channel which was used for
this message delivery this message delivery, used for sending ack back.
If None - ack is not required
:param method: Method, RabbitMQ message method :param method: Method, RabbitMQ message method
:param properties: Properties, RabbitMQ message properties :param properties: Properties, RabbitMQ message properties
:param body: Bytes, RabbitMQ message body :param body: Bytes, RabbitMQ message body
:param no_ack: Boolean, defines should this message be acked by
consumer or not
""" """
super(RpcPikaIncomingMessage, self).__init__( super(RpcPikaIncomingMessage, self).__init__(
pika_engine, channel, method, properties, body, no_ack pika_engine, channel, method, properties, body
) )
self.reply_q = properties.reply_to self.reply_q = properties.reply_to
self.msg_id = properties.correlation_id self.msg_id = properties.correlation_id
@ -231,7 +231,7 @@ class RpcReplyPikaIncomingMessage(PikaIncomingMessage):
"""PikaIncomingMessage implementation for RPC reply messages. It expects """PikaIncomingMessage implementation for RPC reply messages. It expects
extra RPC reply related fields in message body (result and failure). extra RPC reply related fields in message body (result and failure).
""" """
def __init__(self, pika_engine, channel, method, properties, body, no_ack): def __init__(self, pika_engine, channel, method, properties, body):
"""Defines default values of result and failure fields, call """Defines default values of result and failure fields, call
super.__init__ method and then construct Exception object if failure is super.__init__ method and then construct Exception object if failure is
not None not None
@ -239,15 +239,14 @@ class RpcReplyPikaIncomingMessage(PikaIncomingMessage):
:param pika_engine: PikaEngine, shared object with configuration and :param pika_engine: PikaEngine, shared object with configuration and
shared driver functionality shared driver functionality
:param channel: Channel, RabbitMQ channel which was used for :param channel: Channel, RabbitMQ channel which was used for
this message delivery this message delivery, used for sending ack back.
If None - ack is not required
:param method: Method, RabbitMQ message method :param method: Method, RabbitMQ message method
:param properties: Properties, RabbitMQ message properties :param properties: Properties, RabbitMQ message properties
:param body: Bytes, RabbitMQ message body :param body: Bytes, RabbitMQ message body
:param no_ack: Boolean, defines should this message be acked by
consumer or not
""" """
super(RpcReplyPikaIncomingMessage, self).__init__( super(RpcReplyPikaIncomingMessage, self).__init__(
pika_engine, channel, method, properties, body, no_ack pika_engine, channel, method, properties, body
) )
self.msg_id = properties.correlation_id self.msg_id = properties.correlation_id

View File

@ -31,8 +31,7 @@ class PikaPoller(object):
connectivity related problem detected connectivity related problem detected
""" """
def __init__(self, pika_engine, prefetch_count, def __init__(self, pika_engine, prefetch_count, incoming_message_class):
incoming_message_class=pika_drv_msg.PikaIncomingMessage):
"""Initialize required fields """Initialize required fields
:param pika_engine: PikaEngine, shared object with configuration and :param pika_engine: PikaEngine, shared object with configuration and
@ -110,8 +109,7 @@ class PikaPoller(object):
""" """
self._message_queue.append( self._message_queue.append(
self._incoming_message_class( self._incoming_message_class(
self._pika_engine, self._channel, method, properties, body, self._pika_engine, None, method, properties, body
True
) )
) )
@ -121,8 +119,7 @@ class PikaPoller(object):
""" """
self._message_queue.append( self._message_queue.append(
self._incoming_message_class( self._incoming_message_class(
self._pika_engine, self._channel, method, properties, body, self._pika_engine, self._channel, method, properties, body
False
) )
) )
@ -146,6 +143,11 @@ class PikaPoller(object):
LOG.exception("Unexpected error during closing connection") LOG.exception("Unexpected error during closing connection")
self._connection = None self._connection = None
for i in xrange(len(self._message_queue) - 1, -1, -1):
message = self._message_queue[i]
if message.need_ack():
del self._message_queue[i]
def poll(self, timeout=None, prefetch_size=1): def poll(self, timeout=None, prefetch_size=1):
"""Main method of this class - consumes message from RabbitMQ """Main method of this class - consumes message from RabbitMQ
@ -158,32 +160,29 @@ class PikaPoller(object):
""" """
expiration_time = time.time() + timeout if timeout else None expiration_time = time.time() + timeout if timeout else None
while len(self._message_queue) < prefetch_size: while True:
with self._lock: with self._lock:
if not self._started: if timeout is not None:
return None timeout = expiration_time - time.time()
if (len(self._message_queue) < prefetch_size and
try: self._started and ((timeout is None) or timeout > 0)):
if self._channel is None: try:
self._reconnect() if self._channel is None:
# we need some time_limit here, not too small to avoid a self._reconnect()
# lot of not needed iterations but not too large to release # we need some time_limit here, not too small to avoid
# lock time to time and give a chance to perform another # a lot of not needed iterations but not too large to
# method waiting this lock # release lock time to time and give a chance to
self._connection.process_data_events( # perform another method waiting this lock
time_limit=0.25 self._connection.process_data_events(
) time_limit=0.25
except Exception as e: )
LOG.warn("Exception during consuming message. " + str(e)) except pika_pool.Connection.connectivity_errors:
self._cleanup() self._cleanup()
if timeout is not None: raise
timeout = expiration_time - time.time() else:
if timeout <= 0: result = self._message_queue[:prefetch_size]
break del self._message_queue[:prefetch_size]
return result
result = self._message_queue[:prefetch_size]
self._message_queue = self._message_queue[prefetch_size:]
return result
def start(self): def start(self):
"""Starts poller. Should be called before polling to allow message """Starts poller. Should be called before polling to allow message
@ -201,7 +200,6 @@ class PikaPoller(object):
return return
self._started = False self._started = False
self._cleanup()
def reconnect(self): def reconnect(self):
"""Safe version of _reconnect. Performs reconnection to the broker.""" """Safe version of _reconnect. Performs reconnection to the broker."""
@ -249,9 +247,7 @@ class RpcServicePikaPoller(PikaPoller):
:return Dictionary, declared_queue_name -> no_ack_mode :return Dictionary, declared_queue_name -> no_ack_mode
""" """
queue_expiration = ( queue_expiration = self._pika_engine.rpc_queue_expiration
self._pika_engine.conf.oslo_messaging_pika.rpc_queue_expiration
)
queues_to_consume = {} queues_to_consume = {}
@ -319,15 +315,11 @@ class RpcReplyPikaPoller(PikaPoller):
:return Dictionary, declared_queue_name -> no_ack_mode :return Dictionary, declared_queue_name -> no_ack_mode
""" """
queue_expiration = (
self._pika_engine.conf.oslo_messaging_pika.rpc_queue_expiration
)
self._pika_engine.declare_queue_binding_by_channel( self._pika_engine.declare_queue_binding_by_channel(
channel=self._channel, channel=self._channel,
exchange=self._exchange, queue=self._queue, exchange=self._exchange, queue=self._queue,
routing_key=self._queue, exchange_type='direct', routing_key=self._queue, exchange_type='direct',
queue_expiration=queue_expiration, queue_expiration=self._pika_engine.rpc_queue_expiration,
durable=False durable=False
) )
@ -363,8 +355,8 @@ class NotificationPikaPoller(PikaPoller):
""" """
def __init__(self, pika_engine, targets_and_priorities, def __init__(self, pika_engine, targets_and_priorities,
queue_name=None, prefetch_count=100): queue_name=None, prefetch_count=100):
"""Adds exchange and queue parameter for declaring exchange and queue """Adds targets_and_priorities and queue_name parameter
used for RPC reply delivery for declaring exchanges and queues used for notification delivery
:param pika_engine: PikaEngine, shared object with configuration and :param pika_engine: PikaEngine, shared object with configuration and
shared driver functionality shared driver functionality
@ -379,7 +371,8 @@ class NotificationPikaPoller(PikaPoller):
self._queue_name = queue_name self._queue_name = queue_name
super(NotificationPikaPoller, self).__init__( super(NotificationPikaPoller, self).__init__(
pika_engine, prefetch_count=prefetch_count pika_engine, prefetch_count=prefetch_count,
incoming_message_class=pika_drv_msg.PikaIncomingMessage
) )
def _declare_queue_binding(self): def _declare_queue_binding(self):

View File

@ -46,7 +46,7 @@ class PikaIncomingMessageTestCase(unittest.TestCase):
def test_message_body_parsing(self): def test_message_body_parsing(self):
message = pika_drv_msg.PikaIncomingMessage( message = pika_drv_msg.PikaIncomingMessage(
self._pika_engine, self._channel, self._method, self._properties, self._pika_engine, self._channel, self._method, self._properties,
self._body, True self._body
) )
self.assertEqual(message.ctxt.get("key_context", None), self.assertEqual(message.ctxt.get("key_context", None),
@ -57,7 +57,7 @@ class PikaIncomingMessageTestCase(unittest.TestCase):
def test_message_acknowledge(self): def test_message_acknowledge(self):
message = pika_drv_msg.PikaIncomingMessage( message = pika_drv_msg.PikaIncomingMessage(
self._pika_engine, self._channel, self._method, self._properties, self._pika_engine, self._channel, self._method, self._properties,
self._body, False self._body
) )
message.acknowledge() message.acknowledge()
@ -68,8 +68,8 @@ class PikaIncomingMessageTestCase(unittest.TestCase):
def test_message_acknowledge_no_ack(self): def test_message_acknowledge_no_ack(self):
message = pika_drv_msg.PikaIncomingMessage( message = pika_drv_msg.PikaIncomingMessage(
self._pika_engine, self._channel, self._method, self._properties, self._pika_engine, None, self._method, self._properties,
self._body, True self._body
) )
message.acknowledge() message.acknowledge()
@ -79,7 +79,7 @@ class PikaIncomingMessageTestCase(unittest.TestCase):
def test_message_requeue(self): def test_message_requeue(self):
message = pika_drv_msg.PikaIncomingMessage( message = pika_drv_msg.PikaIncomingMessage(
self._pika_engine, self._channel, self._method, self._properties, self._pika_engine, self._channel, self._method, self._properties,
self._body, False self._body
) )
message.requeue() message.requeue()
@ -90,8 +90,8 @@ class PikaIncomingMessageTestCase(unittest.TestCase):
def test_message_requeue_no_ack(self): def test_message_requeue_no_ack(self):
message = pika_drv_msg.PikaIncomingMessage( message = pika_drv_msg.PikaIncomingMessage(
self._pika_engine, self._channel, self._method, self._properties, self._pika_engine, None, self._method, self._properties,
self._body, True self._body
) )
message.requeue() message.requeue()
@ -126,7 +126,7 @@ class RpcPikaIncomingMessageTestCase(unittest.TestCase):
message = pika_drv_msg.RpcPikaIncomingMessage( message = pika_drv_msg.RpcPikaIncomingMessage(
self._pika_engine, self._channel, self._method, self._properties, self._pika_engine, self._channel, self._method, self._properties,
self._body, True self._body
) )
self.assertEqual(message.ctxt.get("key_context", None), self.assertEqual(message.ctxt.get("key_context", None),
@ -140,7 +140,7 @@ class RpcPikaIncomingMessageTestCase(unittest.TestCase):
def test_cast_message_body_parsing(self): def test_cast_message_body_parsing(self):
message = pika_drv_msg.RpcPikaIncomingMessage( message = pika_drv_msg.RpcPikaIncomingMessage(
self._pika_engine, self._channel, self._method, self._properties, self._pika_engine, self._channel, self._method, self._properties,
self._body, True self._body
) )
self.assertEqual(message.ctxt.get("key_context", None), self.assertEqual(message.ctxt.get("key_context", None),
@ -156,7 +156,7 @@ class RpcPikaIncomingMessageTestCase(unittest.TestCase):
def test_reply_for_cast_message(self, send_reply_mock): def test_reply_for_cast_message(self, send_reply_mock):
message = pika_drv_msg.RpcPikaIncomingMessage( message = pika_drv_msg.RpcPikaIncomingMessage(
self._pika_engine, self._channel, self._method, self._properties, self._pika_engine, self._channel, self._method, self._properties,
self._body, True self._body
) )
self.assertEqual(message.ctxt.get("key_context", None), self.assertEqual(message.ctxt.get("key_context", None),
@ -182,7 +182,7 @@ class RpcPikaIncomingMessageTestCase(unittest.TestCase):
message = pika_drv_msg.RpcPikaIncomingMessage( message = pika_drv_msg.RpcPikaIncomingMessage(
self._pika_engine, self._channel, self._method, self._properties, self._pika_engine, self._channel, self._method, self._properties,
self._body, True self._body
) )
self.assertEqual(message.ctxt.get("key_context", None), self.assertEqual(message.ctxt.get("key_context", None),
@ -218,7 +218,7 @@ class RpcPikaIncomingMessageTestCase(unittest.TestCase):
message = pika_drv_msg.RpcPikaIncomingMessage( message = pika_drv_msg.RpcPikaIncomingMessage(
self._pika_engine, self._channel, self._method, self._properties, self._pika_engine, self._channel, self._method, self._properties,
self._body, True self._body
) )
self.assertEqual(message.ctxt.get("key_context", None), self.assertEqual(message.ctxt.get("key_context", None),
@ -274,7 +274,7 @@ class RpcReplyPikaIncomingMessageTestCase(unittest.TestCase):
message = pika_drv_msg.RpcReplyPikaIncomingMessage( message = pika_drv_msg.RpcReplyPikaIncomingMessage(
self._pika_engine, self._channel, self._method, self._properties, self._pika_engine, self._channel, self._method, self._properties,
body, True body
) )
self.assertEqual(message.msg_id, 123456789) self.assertEqual(message.msg_id, 123456789)
@ -294,7 +294,7 @@ class RpcReplyPikaIncomingMessageTestCase(unittest.TestCase):
message = pika_drv_msg.RpcReplyPikaIncomingMessage( message = pika_drv_msg.RpcReplyPikaIncomingMessage(
self._pika_engine, self._channel, self._method, self._properties, self._pika_engine, self._channel, self._method, self._properties,
body, True body
) )
self.assertEqual(message.msg_id, 123456789) self.assertEqual(message.msg_id, 123456789)

View File

@ -0,0 +1,536 @@
# Copyright 2015 Mirantis, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import time
import unittest
import mock
from oslo_messaging._drivers.pika_driver import pika_poller
class PikaPollerTestCase(unittest.TestCase):
def setUp(self):
self._pika_engine = mock.Mock()
self._poller_connection_mock = mock.Mock()
self._poller_channel_mock = mock.Mock()
self._poller_connection_mock.channel.return_value = (
self._poller_channel_mock
)
self._pika_engine.create_connection.return_value = (
self._poller_connection_mock
)
self._prefetch_count = 123
@mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
"_declare_queue_binding")
def test_poll(self, declare_queue_binding_mock):
incoming_message_class_mock = mock.Mock()
poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count,
incoming_message_class=incoming_message_class_mock
)
unused = object()
method = object()
properties = object()
body = object()
self._poller_connection_mock.process_data_events.side_effect = (
lambda time_limit: poller._on_message_with_ack_callback(
unused, method, properties, body
)
)
poller.start()
res = poller.poll()
self.assertEqual(len(res), 1)
self.assertEqual(res[0], incoming_message_class_mock.return_value)
incoming_message_class_mock.assert_called_once_with(
self._pika_engine, self._poller_channel_mock, method, properties,
body
)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
self.assertTrue(declare_queue_binding_mock.called)
@mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
"_declare_queue_binding")
def test_poll_after_stop(self, declare_queue_binding_mock):
incoming_message_class_mock = mock.Mock()
poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count,
incoming_message_class=incoming_message_class_mock
)
n = 10
params = []
for i in range(n):
params.append((object(), object(), object(), object()))
index = [0]
def f(time_limit):
for i in range(10):
poller._on_message_no_ack_callback(
*params[index[0]]
)
index[0] += 1
self._poller_connection_mock.process_data_events.side_effect = f
poller.start()
res = poller.poll(prefetch_size=1)
self.assertEqual(len(res), 1)
self.assertEqual(res[0], incoming_message_class_mock.return_value)
self.assertEqual(
incoming_message_class_mock.call_args_list[0][0],
(self._pika_engine, None) + params[0][1:]
)
poller.stop()
res2 = poller.poll(prefetch_size=n)
self.assertEqual(len(res2), n-1)
self.assertEqual(incoming_message_class_mock.call_count, n)
self.assertEqual(
self._poller_connection_mock.process_data_events.call_count, 1)
for i in range(n-1):
self.assertEqual(res2[i], incoming_message_class_mock.return_value)
self.assertEqual(
incoming_message_class_mock.call_args_list[i+1][0],
(self._pika_engine, None) + params[i+1][1:]
)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
self.assertTrue(declare_queue_binding_mock.called)
@mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
"_declare_queue_binding")
def test_poll_batch(self, declare_queue_binding_mock):
incoming_message_class_mock = mock.Mock()
poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count,
incoming_message_class=incoming_message_class_mock
)
n = 10
params = []
for i in range(n):
params.append((object(), object(), object(), object()))
index = [0]
def f(time_limit):
poller._on_message_with_ack_callback(
*params[index[0]]
)
index[0] += 1
self._poller_connection_mock.process_data_events.side_effect = f
poller.start()
res = poller.poll(prefetch_size=n)
self.assertEqual(len(res), n)
self.assertEqual(incoming_message_class_mock.call_count, n)
for i in range(n):
self.assertEqual(res[i], incoming_message_class_mock.return_value)
self.assertEqual(
incoming_message_class_mock.call_args_list[i][0],
(self._pika_engine, self._poller_channel_mock) + params[i][1:]
)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
self.assertTrue(declare_queue_binding_mock.called)
@mock.patch("oslo_messaging._drivers.pika_driver.pika_poller.PikaPoller."
"_declare_queue_binding")
def test_poll_batch_with_timeout(self, declare_queue_binding_mock):
incoming_message_class_mock = mock.Mock()
poller = pika_poller.PikaPoller(
self._pika_engine, self._prefetch_count,
incoming_message_class=incoming_message_class_mock
)
n = 10
timeout = 1
sleep_time = 0.2
params = []
success_count = 5
for i in range(n):
params.append((object(), object(), object(), object()))
index = [0]
def f(time_limit):
time.sleep(sleep_time)
poller._on_message_with_ack_callback(
*params[index[0]]
)
index[0] += 1
self._poller_connection_mock.process_data_events.side_effect = f
poller.start()
res = poller.poll(prefetch_size=n, timeout=timeout)
self.assertEqual(len(res), success_count)
self.assertEqual(incoming_message_class_mock.call_count, success_count)
for i in range(success_count):
self.assertEqual(res[i], incoming_message_class_mock.return_value)
self.assertEqual(
incoming_message_class_mock.call_args_list[i][0],
(self._pika_engine, self._poller_channel_mock) + params[i][1:]
)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
self.assertTrue(declare_queue_binding_mock.called)
class RpcServicePikaPollerTestCase(unittest.TestCase):
def setUp(self):
self._pika_engine = mock.Mock()
self._poller_connection_mock = mock.Mock()
self._poller_channel_mock = mock.Mock()
self._poller_connection_mock.channel.return_value = (
self._poller_channel_mock
)
self._pika_engine.create_connection.return_value = (
self._poller_connection_mock
)
self._pika_engine.get_rpc_queue_name.side_effect = (
lambda topic, server, no_ack: "_".join(
[topic, str(server), str(no_ack)]
)
)
self._pika_engine.get_rpc_exchange_name.side_effect = (
lambda exchange, topic, fanout, no_ack: "_".join(
[exchange, topic, str(fanout), str(no_ack)]
)
)
self._prefetch_count = 123
self._target = mock.Mock(exchange="exchange", topic="topic",
server="server")
self._pika_engine.rpc_queue_expiration = 12345
@mock.patch("oslo_messaging._drivers.pika_driver.pika_message."
"RpcPikaIncomingMessage")
def test_declare_rpc_queue_bindings(self, rpc_pika_incoming_message_mock):
poller = pika_poller.RpcServicePikaPoller(
self._pika_engine, self._target, self._prefetch_count,
)
self._poller_connection_mock.process_data_events.side_effect = (
lambda time_limit: poller._on_message_with_ack_callback(
None, None, None, None
)
)
poller.start()
res = poller.poll()
self.assertEqual(len(res), 1)
self.assertEqual(res[0], rpc_pika_incoming_message_mock.return_value)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
declare_queue_binding_by_channel_mock = (
self._pika_engine.declare_queue_binding_by_channel
)
self.assertEqual(
declare_queue_binding_by_channel_mock.call_count, 6
)
declare_queue_binding_by_channel_mock.assert_has_calls((
mock.call(
channel=self._poller_channel_mock, durable=False,
exchange="exchange_topic_False_True",
exchange_type='direct',
queue="topic_None_True",
queue_expiration=12345,
routing_key="topic_None_True"
),
mock.call(
channel=self._poller_channel_mock, durable=False,
exchange="exchange_topic_False_True",
exchange_type='direct',
queue="topic_server_True",
queue_expiration=12345,
routing_key="topic_server_True"
),
mock.call(
channel=self._poller_channel_mock, durable=False,
exchange="exchange_topic_True_True",
exchange_type='fanout',
queue="topic_server_True",
queue_expiration=12345,
routing_key=''
),
mock.call(
channel=self._poller_channel_mock, durable=False,
exchange="exchange_topic_False_False",
exchange_type='direct',
queue="topic_None_False",
queue_expiration=12345,
routing_key="topic_None_False"
),
mock.call(
channel=self._poller_channel_mock, durable=False,
exchange="exchange_topic_False_False",
exchange_type='direct',
queue="topic_server_False",
queue_expiration=12345,
routing_key="topic_server_False"
),
mock.call(
channel=self._poller_channel_mock, durable=False,
exchange="exchange_topic_True_False",
exchange_type='fanout',
queue="topic_server_False",
queue_expiration=12345,
routing_key=''
),
))
class RpcReplyServicePikaPollerTestCase(unittest.TestCase):
def setUp(self):
self._pika_engine = mock.Mock()
self._poller_connection_mock = mock.Mock()
self._poller_channel_mock = mock.Mock()
self._poller_connection_mock.channel.return_value = (
self._poller_channel_mock
)
self._pika_engine.create_connection.return_value = (
self._poller_connection_mock
)
self._prefetch_count = 123
self._exchange = "rpc_reply_exchange"
self._queue = "rpc_reply_queue"
self._pika_engine.rpc_reply_retry_delay = 12132543456
self._pika_engine.rpc_queue_expiration = 12345
self._pika_engine.rpc_reply_retry_attempts = 3
def test_start(self):
poller = pika_poller.RpcReplyPikaPoller(
self._pika_engine, self._exchange, self._queue,
self._prefetch_count,
)
poller.start()
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
def test_declare_rpc_reply_queue_binding(self):
poller = pika_poller.RpcReplyPikaPoller(
self._pika_engine, self._exchange, self._queue,
self._prefetch_count,
)
poller.start()
declare_queue_binding_by_channel_mock = (
self._pika_engine.declare_queue_binding_by_channel
)
self.assertEqual(
declare_queue_binding_by_channel_mock.call_count, 1
)
declare_queue_binding_by_channel_mock.assert_called_once_with(
channel=self._poller_channel_mock, durable=False,
exchange='rpc_reply_exchange', exchange_type='direct',
queue='rpc_reply_queue', queue_expiration=12345,
routing_key='rpc_reply_queue'
)
class NotificationPikaPollerTestCase(unittest.TestCase):
def setUp(self):
self._pika_engine = mock.Mock()
self._poller_connection_mock = mock.Mock()
self._poller_channel_mock = mock.Mock()
self._poller_connection_mock.channel.return_value = (
self._poller_channel_mock
)
self._pika_engine.create_connection.return_value = (
self._poller_connection_mock
)
self._prefetch_count = 123
self._target_and_priorities = (
(
mock.Mock(exchange="exchange1", topic="topic1",
server="server1"), 1
),
(
mock.Mock(exchange="exchange1", topic="topic1"), 2
),
(
mock.Mock(exchange="exchange2", topic="topic2",), 1
),
)
self._pika_engine.notification_persistence = object()
@mock.patch("oslo_messaging._drivers.pika_driver.pika_message."
"PikaIncomingMessage")
def test_declare_notification_queue_bindings_default_queue(
self, pika_incoming_message_mock):
poller = pika_poller.NotificationPikaPoller(
self._pika_engine, self._target_and_priorities, None,
self._prefetch_count,
)
self._poller_connection_mock.process_data_events.side_effect = (
lambda time_limit: poller._on_message_with_ack_callback(
None, None, None, None
)
)
poller.start()
res = poller.poll()
self.assertEqual(len(res), 1)
self.assertEqual(res[0], pika_incoming_message_mock.return_value)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
declare_queue_binding_by_channel_mock = (
self._pika_engine.declare_queue_binding_by_channel
)
self.assertEqual(
declare_queue_binding_by_channel_mock.call_count, 3
)
declare_queue_binding_by_channel_mock.assert_has_calls((
mock.call(
channel=self._poller_channel_mock,
durable=self._pika_engine.notification_persistence,
exchange="exchange1",
exchange_type='direct',
queue="topic1.1",
queue_expiration=None,
routing_key="topic1.1"
),
mock.call(
channel=self._poller_channel_mock,
durable=self._pika_engine.notification_persistence,
exchange="exchange1",
exchange_type='direct',
queue="topic1.2",
queue_expiration=None,
routing_key="topic1.2"
),
mock.call(
channel=self._poller_channel_mock,
durable=self._pika_engine.notification_persistence,
exchange="exchange2",
exchange_type='direct',
queue="topic2.1",
queue_expiration=None,
routing_key="topic2.1"
)
))
@mock.patch("oslo_messaging._drivers.pika_driver.pika_message."
"PikaIncomingMessage")
def test_declare_notification_queue_bindings_custom_queue(
self, pika_incoming_message_mock):
poller = pika_poller.NotificationPikaPoller(
self._pika_engine, self._target_and_priorities,
"custom_queue_name", self._prefetch_count
)
self._poller_connection_mock.process_data_events.side_effect = (
lambda time_limit: poller._on_message_with_ack_callback(
None, None, None, None
)
)
poller.start()
res = poller.poll()
self.assertEqual(len(res), 1)
self.assertEqual(res[0], pika_incoming_message_mock.return_value)
self.assertTrue(self._pika_engine.create_connection.called)
self.assertTrue(self._poller_connection_mock.channel.called)
declare_queue_binding_by_channel_mock = (
self._pika_engine.declare_queue_binding_by_channel
)
self.assertEqual(
declare_queue_binding_by_channel_mock.call_count, 3
)
declare_queue_binding_by_channel_mock.assert_has_calls((
mock.call(
channel=self._poller_channel_mock,
durable=self._pika_engine.notification_persistence,
exchange="exchange1",
exchange_type='direct',
queue="custom_queue_name",
queue_expiration=None,
routing_key="topic1.1"
),
mock.call(
channel=self._poller_channel_mock,
durable=self._pika_engine.notification_persistence,
exchange="exchange1",
exchange_type='direct',
queue="custom_queue_name",
queue_expiration=None,
routing_key="topic1.2"
),
mock.call(
channel=self._poller_channel_mock,
durable=self._pika_engine.notification_persistence,
exchange="exchange2",
exchange_type='direct',
queue="custom_queue_name",
queue_expiration=None,
routing_key="topic2.1"
)
))