Refactor driver's listener interface

Current Listener interface has poll() method which return messages

To use it we need have poller thread which is located in MessageHandlerServer
But my investigations of existing driver's code shows that some implemetations have
its own thread inside for processing connection event loop. This event loop received
messages and store in queue object. And then our poller's thread reads this queue
This situation can be improved. we can remove poller's thread, remove queue object
and just call on_message server's callback from connection eventloop thread

This path provide posibility to do this for one of drivers and leave as is other drivers

Change-Id: I3e3d4369d8fdadcecf079d10af58b1e4f5616047
This commit is contained in:
Dmitriy Ukhlov 2016-04-02 14:58:29 +03:00
parent ee394d3c5b
commit 5d7d7253d1
21 changed files with 325 additions and 223 deletions

View File

@ -176,7 +176,7 @@ class ObsoleteReplyQueuesCache(object):
'msg_id': msg_id})
class AMQPListener(base.Listener):
class AMQPListener(base.PollStyleListener):
def __init__(self, driver, conn):
super(AMQPListener, self).__init__(driver.prefetch_size)
@ -473,7 +473,7 @@ class AMQPDriverBase(base.BaseDriver):
return self._send(target, ctxt, message,
envelope=(version == 2.0), notify=True, retry=retry)
def listen(self, target):
def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
conn = self._get_connection(rpc_common.PURPOSE_LISTEN)
listener = AMQPListener(self, conn)
@ -487,9 +487,12 @@ class AMQPDriverBase(base.BaseDriver):
callback=listener)
conn.declare_fanout_consumer(target.topic, listener)
return listener
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def listen_for_notifications(self, targets_and_priorities, pool):
def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
conn = self._get_connection(rpc_common.PURPOSE_LISTEN)
listener = AMQPListener(self, conn)
@ -498,7 +501,8 @@ class AMQPDriverBase(base.BaseDriver):
exchange_name=self._get_exchange(target),
topic='%s.%s' % (target.topic, priority),
callback=listener, queue_name=pool)
return listener
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def cleanup(self):
if self._connection_pool:

View File

@ -14,12 +14,12 @@
# under the License.
import abc
import time
import threading
from oslo_config import cfg
from oslo_utils import excutils
from oslo_utils import timeutils
import six
from six.moves import range as compat_range
from oslo_messaging import exceptions
@ -38,21 +38,33 @@ def batch_poll_helper(func):
This decorator helps driver that polls message one by one,
to returns a list of message.
"""
def wrapper(in_self, timeout=None, prefetch_size=1):
def wrapper(in_self, timeout=None, batch_size=1, batch_timeout=None):
incomings = []
driver_prefetch = in_self.prefetch_size
if driver_prefetch > 0:
prefetch_size = min(prefetch_size, driver_prefetch)
watch = timeutils.StopWatch(duration=timeout)
with watch:
for __ in compat_range(prefetch_size):
msg = func(in_self, timeout=watch.leftover(return_none=True))
batch_size = min(batch_size, driver_prefetch)
with timeutils.StopWatch(timeout) as timeout_watch:
# poll first message
msg = func(in_self, timeout=timeout_watch.leftover(True))
if msg is not None:
incomings.append(msg)
if batch_size == 1 or msg is None:
return incomings
# update batch_timeout according to timeout for whole operation
timeout_left = timeout_watch.leftover(True)
if timeout_left is not None and (
batch_timeout is None or timeout_left < batch_timeout):
batch_timeout = timeout_left
with timeutils.StopWatch(batch_timeout) as batch_timeout_watch:
# poll remained batch messages
while len(incomings) < batch_size and msg is not None:
msg = func(in_self, timeout=batch_timeout_watch.leftover(True))
if msg is not None:
incomings.append(msg)
else:
# timeout reached or listener stopped
break
time.sleep(0)
return incomings
return wrapper
@ -81,20 +93,22 @@ class RpcIncomingMessage(IncomingMessage):
@abc.abstractmethod
def reply(self, reply=None, failure=None, log_failure=True):
"Send a reply or failure back to the client."
"""Send a reply or failure back to the client."""
@six.add_metaclass(abc.ABCMeta)
class Listener(object):
class PollStyleListener(object):
def __init__(self, prefetch_size=-1):
self.prefetch_size = prefetch_size
@abc.abstractmethod
def poll(self, timeout=None, prefetch_size=1):
"""Blocking until 'prefetch_size' message is pending and return
def poll(self, timeout=None, batch_size=1, batch_timeout=None):
"""Blocking until 'batch_size' message is pending and return
[IncomingMessage].
Return None after timeout seconds if timeout is set and no message is
ending or if the listener have been stopped.
Waits for first message. Then waits for next batch_size-1 messages
during batch window defined by batch_timeout
This method block current thread until message comes, stop() is
executed by another thread or timemout is elapsed.
"""
def stop(self):
@ -112,6 +126,113 @@ class Listener(object):
pass
@six.add_metaclass(abc.ABCMeta)
class Listener(object):
def __init__(self, on_incoming_callback, batch_size, batch_timeout,
prefetch_size=-1):
"""Init Listener
:param on_incoming_callback: callback function to be executed when
listener received messages. Messages should be processed and
acked/nacked by callback
:param batch_size: desired number of messages passed to
single on_incoming_callback call
:param batch_timeout: defines how long should we wait for batch_size
messages if we already have some messages waiting for processing
:param prefetch_size: defines how many massages we want to prefetch
from backend (depend on driver type) by single request
"""
self.on_incoming_callback = on_incoming_callback
self.batch_timeout = batch_timeout
self.prefetch_size = prefetch_size
if prefetch_size > 0:
batch_size = min(batch_size, prefetch_size)
self.batch_size = batch_size
@abc.abstractmethod
def start(self):
"""Stop listener.
Stop the listener message polling
"""
@abc.abstractmethod
def wait(self):
"""Wait listener.
Wait for processing remained input after listener Stop
"""
@abc.abstractmethod
def stop(self):
"""Stop listener.
Stop the listener message polling
"""
@abc.abstractmethod
def cleanup(self):
"""Cleanup listener.
Close connection (socket) used by listener if any.
As this is listener specific method, overwrite it in to derived class
if cleanup of listener required.
"""
class PollStyleListenerAdapter(Listener):
def __init__(self, poll_style_listener, on_incoming_callback, batch_size,
batch_timeout):
super(PollStyleListenerAdapter, self).__init__(
on_incoming_callback, batch_size, batch_timeout,
poll_style_listener.prefetch_size
)
self._poll_style_listener = poll_style_listener
self._listen_thread = threading.Thread(target=self._runner)
self._listen_thread.daemon = True
self._started = False
def start(self):
"""Start listener.
Start the listener message polling
"""
self._started = True
self._listen_thread.start()
@excutils.forever_retry_uncaught_exceptions
def _runner(self):
while self._started:
incoming = self._poll_style_listener.poll(
batch_size=self.batch_size, batch_timeout=self.batch_timeout)
if incoming:
self.on_incoming_callback(incoming)
# listener is stopped but we need to process all already consumed
# messages
while True:
incoming = self._poll_style_listener.poll(
batch_size=self.batch_size, batch_timeout=self.batch_timeout)
if not incoming:
return
self.on_incoming_callback(incoming)
def stop(self):
"""Stop listener.
Stop the listener message polling
"""
self._started = False
self._poll_style_listener.stop()
def wait(self):
self._listen_thread.join()
def cleanup(self):
"""Cleanup listener.
Close connection (socket) used by listener if any.
As this is listener specific method, overwrite it in to derived class
if cleanup of listener required.
"""
self._poll_style_listener.cleanup()
@six.add_metaclass(abc.ABCMeta)
class BaseDriver(object):
prefetch_size = 0
@ -138,11 +259,13 @@ class BaseDriver(object):
"""Send a notification message to the given target."""
@abc.abstractmethod
def listen(self, target):
def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
"""Construct a Listener for the given target."""
@abc.abstractmethod
def listen_for_notifications(self, targets_and_priorities, pool):
def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
"""Construct a notification Listener for the given list of
tuple of (target, priority).
"""

View File

@ -39,7 +39,7 @@ class FakeIncomingMessage(base.RpcIncomingMessage):
self.requeue_callback()
class FakeListener(base.Listener):
class FakeListener(base.PollStyleListener):
def __init__(self, exchange_manager, targets, pool=None):
super(FakeListener, self).__init__()
@ -222,7 +222,7 @@ class FakeDriver(base.BaseDriver):
# transport always works
self._send(target, ctxt, message)
def listen(self, target):
def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
exchange = target.exchange or self._default_exchange
listener = FakeListener(self._exchange_manager,
[oslo_messaging.Target(
@ -232,9 +232,12 @@ class FakeDriver(base.BaseDriver):
oslo_messaging.Target(
topic=target.topic,
exchange=exchange)])
return listener
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def listen_for_notifications(self, targets_and_priorities, pool):
def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
targets = [
oslo_messaging.Target(
topic='%s.%s' % (target.topic, priority),
@ -242,7 +245,8 @@ class FakeDriver(base.BaseDriver):
for target, priority in targets_and_priorities]
listener = FakeListener(self._exchange_manager, targets, pool)
return listener
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def cleanup(self):
pass

View File

@ -247,7 +247,7 @@ class OsloKafkaMessage(base.RpcIncomingMessage):
LOG.warning(_LW("reply is not supported"))
class KafkaListener(base.Listener):
class KafkaListener(base.PollStyleListener):
def __init__(self, conn):
super(KafkaListener, self).__init__()
@ -342,7 +342,9 @@ class KafkaDriver(base.BaseDriver):
raise NotImplementedError(
'The RPC implementation for Kafka is not implemented')
def listen_for_notifications(self, targets_and_priorities, pool=None):
def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
"""Listen to a specified list of targets on Kafka brokers
:param targets_and_priorities: List of pairs (target, priority)
@ -361,7 +363,8 @@ class KafkaDriver(base.BaseDriver):
conn.declare_topic_consumer(topics, pool)
listener = KafkaListener(conn)
return listener
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def _get_connection(self, purpose):
return driver_common.ConnectionContext(self.connection_pool, purpose)

View File

@ -334,15 +334,18 @@ class PikaDriver(base.BaseDriver):
retrier=retrier
)
def listen(self, target):
def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
listener = pika_drv_poller.RpcServicePikaPoller(
self._pika_engine, target,
prefetch_count=self._pika_engine.rpc_listener_prefetch_count
)
listener.start()
return listener
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def listen_for_notifications(self, targets_and_priorities, pool):
def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback,
batch_size, batch_timeout):
listener = pika_drv_poller.NotificationPikaPoller(
self._pika_engine, targets_and_priorities,
prefetch_count=(
@ -351,7 +354,8 @@ class PikaDriver(base.BaseDriver):
queue_name=pool
)
listener.start()
return listener
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def cleanup(self):
self._reply_listener.cleanup()

View File

@ -251,15 +251,20 @@ class ZmqDriver(base.BaseDriver):
client = self.notifier.get()
client.send_notify(target, ctxt, message, version, retry)
def listen(self, target):
def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
"""Listen to a specified target on a server side
:param target: Message destination target
:type target: oslo_messaging.Target
"""
return zmq_server.ZmqServer(self, self.conf, self.matchmaker, target)
listener = zmq_server.ZmqServer(self, self.conf, self.matchmaker,
target)
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def listen_for_notifications(self, targets_and_priorities, pool):
def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
"""Listen to a specified list of targets on a server side
:param targets_and_priorities: List of pairs (target, priority)
@ -267,8 +272,10 @@ class ZmqDriver(base.BaseDriver):
:param pool: Not used for zmq implementation
:type pool: object
"""
return zmq_server.ZmqNotificationServer(
listener = zmq_server.ZmqNotificationServer(
self, self.conf, self.matchmaker, targets_and_priorities)
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def cleanup(self):
"""Cleanup all driver's connections finally

View File

@ -27,7 +27,7 @@ from oslo_messaging._drivers.pika_driver import pika_message as pika_drv_msg
LOG = logging.getLogger(__name__)
class PikaPoller(base.Listener):
class PikaPoller(base.PollStyleListener):
"""Provides user friendly functionality for RabbitMQ message consuming,
handles low level connectivity problems and restore connection if some
connectivity related problem detected
@ -43,8 +43,8 @@ class PikaPoller(base.Listener):
:param incoming_message_class: PikaIncomingMessage, wrapper for
consumed RabbitMQ message
"""
super(PikaPoller, self).__init__(prefetch_count)
self._pika_engine = pika_engine
self._prefetch_count = prefetch_count
self._incoming_message_class = incoming_message_class
self._connection = None
@ -65,7 +65,7 @@ class PikaPoller(base.Listener):
for_listening=True
)
self._channel = self._connection.channel()
self._channel.basic_qos(prefetch_count=self._prefetch_count)
self._channel.basic_qos(prefetch_count=self.prefetch_size)
if self._queues_to_consume is None:
self._queues_to_consume = self._declare_queue_binding()
@ -161,27 +161,23 @@ class PikaPoller(base.Listener):
if message.need_ack():
del self._message_queue[i]
def poll(self, timeout=None, prefetch_size=1):
@base.batch_poll_helper
def poll(self, timeout=None):
"""Main method of this class - consumes message from RabbitMQ
:param: timeout: float, seconds, timeout for waiting new incoming
message, None means wait forever
:param: prefetch_size: Integer, count of messages which we are want to
poll. It blocks until prefetch_size messages are consumed or until
timeout gets expired
:return: list of PikaIncomingMessage, RabbitMQ messages
"""
with timeutils.StopWatch(timeout) as stop_watch:
while True:
with self._lock:
last_queue_size = len(self._message_queue)
if self._message_queue:
return self._message_queue.pop(0)
if (last_queue_size >= prefetch_size
or stop_watch.expired()):
result = self._message_queue[:prefetch_size]
del self._message_queue[:prefetch_size]
return result
if stop_watch.expired():
return None
try:
if self._started:
@ -202,11 +198,10 @@ class PikaPoller(base.Listener):
self._connection.process_data_events(
time_limit=0
)
# and return result if we don't see new messages
if last_queue_size == len(self._message_queue):
result = self._message_queue[:prefetch_size]
del self._message_queue[:prefetch_size]
return result
# and return if we don't see new messages
if not self._message_queue:
return None
except pika_drv_exc.EstablishConnectionException as e:
LOG.warning(
"Problem during establishing connection for pika "

View File

@ -145,7 +145,7 @@ class Queue(object):
self._pop_wake_condition.notify_all()
class ProtonListener(base.Listener):
class ProtonListener(base.PollStyleListener):
def __init__(self, driver):
super(ProtonListener, self).__init__(driver.prefetch_size)
self.driver = driver
@ -266,15 +266,19 @@ class ProtonDriver(base.BaseDriver):
return self.send(target, ctxt, message, envelope=(version == 2.0))
@_ensure_connect_called
def listen(self, target):
def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
"""Construct a Listener for the given target."""
LOG.debug("Listen to %s", target)
listener = ProtonListener(self)
self._ctrl.add_task(drivertasks.ListenTask(target, listener))
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
return listener
@_ensure_connect_called
def listen_for_notifications(self, targets_and_priorities, pool):
def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
LOG.debug("Listen for notifications %s", targets_and_priorities)
if pool:
raise NotImplementedError('"pool" not implemented by '
@ -284,7 +288,8 @@ class ProtonDriver(base.BaseDriver):
topic = '%s.%s' % (target.topic, priority)
t = messaging_target.Target(topic=topic)
self._ctrl.add_task(drivertasks.ListenTask(t, listener, True))
return listener
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def cleanup(self):
"""Release all resources."""

View File

@ -28,7 +28,7 @@ LOG = logging.getLogger(__name__)
zmq = zmq_async.import_zmq()
class ZmqServer(base.Listener):
class ZmqServer(base.PollStyleListener):
def __init__(self, driver, conf, matchmaker, target, poller=None):
super(ZmqServer, self).__init__()
@ -47,7 +47,7 @@ class ZmqServer(base.Listener):
self.consumers.append(self.sub_consumer)
@base.batch_poll_helper
def poll(self, timeout=None, prefetch_size=1):
def poll(self, timeout=None):
message, socket = self.poller.poll(
timeout or self.conf.rpc_poll_timeout)
return message
@ -63,7 +63,7 @@ class ZmqServer(base.Listener):
consumer.cleanup()
class ZmqNotificationServer(base.Listener):
class ZmqNotificationServer(base.PollStyleListener):
def __init__(self, driver, conf, matchmaker, targets_and_priorities):
super(ZmqNotificationServer, self).__init__()
@ -82,7 +82,7 @@ class ZmqNotificationServer(base.Listener):
self.driver, self.conf, self.matchmaker, t, self.poller))
@base.batch_poll_helper
def poll(self, timeout=None, prefetch_size=1):
def poll(self, timeout=None):
message, socket = self.poller.poll(
timeout or self.conf.rpc_poll_timeout)
return message

View File

@ -127,10 +127,9 @@ class NotificationServer(msg_server.MessageHandlingServer):
)
def _create_listener(self):
return msg_server.SingleMessageListenerAdapter(
self.transport._listen_for_notifications(
self._targets_priorities, self._pool
)
return self.transport._listen_for_notifications(
self._targets_priorities, self._pool,
lambda incoming: self._on_incoming(incoming[0]), 1, None
)
def _process_incoming(self, incoming):
@ -163,12 +162,9 @@ class BatchNotificationServer(NotificationServer):
self._batch_timeout = batch_timeout
def _create_listener(self):
return msg_server.BatchMessageListenerAdapter(
self.transport._listen_for_notifications(
self._targets_priorities, self._pool
),
timeout=self._batch_timeout,
batch_size=self._batch_size
return self.transport._listen_for_notifications(
self._targets_priorities, self._pool, self._on_incoming,
self._batch_size, self._batch_timeout,
)
def _process_incoming(self, incoming):

View File

@ -118,8 +118,9 @@ class RPCServer(msg_server.MessageHandlingServer):
self._target = target
def _create_listener(self):
return msg_server.SingleMessageListenerAdapter(
self.transport._listen(self._target)
return self.transport._listen(
self._target,
lambda incoming: self._on_incoming(incoming[0]), 1, None
)
def _process_incoming(self, incoming):

View File

@ -33,7 +33,6 @@ import traceback
from oslo_config import cfg
from oslo_service import service
from oslo_utils import eventletutils
from oslo_utils import excutils
from oslo_utils import timeutils
import six
from stevedore import driver
@ -297,41 +296,6 @@ def ordered(after=None, reset_after=None):
return _ordered
@six.add_metaclass(abc.ABCMeta)
class MessageListenerAdapter(object):
def __init__(self, driver_listener, timeout=None):
self._driver_listener = driver_listener
self._timeout = timeout
@abc.abstractmethod
def poll(self):
"""Poll incoming and return incoming request"""
def stop(self):
self._driver_listener.stop()
def cleanup(self):
self._driver_listener.cleanup()
class SingleMessageListenerAdapter(MessageListenerAdapter):
def poll(self):
msgs = self._driver_listener.poll(prefetch_size=1,
timeout=self._timeout)
return msgs[0] if msgs else None
class BatchMessageListenerAdapter(MessageListenerAdapter):
def __init__(self, driver_listener, timeout=None, batch_size=1):
super(BatchMessageListenerAdapter, self).__init__(driver_listener,
timeout)
self._batch_size = batch_size
def poll(self):
return self._driver_listener.poll(prefetch_size=self._batch_size,
timeout=self._timeout)
@six.add_metaclass(abc.ABCMeta)
class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
"""Server for handling messages.
@ -377,15 +341,21 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
self._executor_cls = mgr.driver
self._work_executor = None
self._poll_executor = None
self._started = False
super(MessageHandlingServer, self).__init__()
def _on_incoming(self, incoming):
"""Hanles on_incoming event
:param incoming: incoming request.
"""
self._work_executor.submit(self._process_incoming, incoming)
@abc.abstractmethod
def _process_incoming(self, incoming):
"""Process incoming request
"""Perform processing incoming request
:param incoming: incoming request.
"""
@ -420,11 +390,6 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
'instantiate a new object.'))
self._started = True
try:
self.listener = self._create_listener()
except driver_base.TransportDriverError as ex:
raise ServerListenError(self.target, ex)
executor_opts = {}
if self.executor_type == "threading":
@ -440,9 +405,13 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
)
self._work_executor = self._executor_cls(**executor_opts)
self._poll_executor = self._executor_cls(**executor_opts)
return lambda: self._poll_executor.submit(self._runner)
try:
self.listener = self._create_listener()
except driver_base.TransportDriverError as ex:
raise ServerListenError(self.target, ex)
return self.listener.start
@ordered(after='start')
def stop(self):
@ -456,24 +425,6 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
self.listener.stop()
self._started = False
@excutils.forever_retry_uncaught_exceptions
def _runner(self):
while self._started:
incoming = self.listener.poll()
if incoming:
self._work_executor.submit(self._process_incoming, incoming)
# listener is stopped but we need to process all already consumed
# messages
while True:
incoming = self.listener.poll()
if incoming:
self._work_executor.submit(self._process_incoming, incoming)
else:
return
@ordered(after='stop')
def wait(self):
"""Wait for message processing to complete.
@ -485,7 +436,7 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
Once it's finished, the underlying driver resources associated to this
server are released (like closing useless network connections).
"""
self._poll_executor.shutdown(wait=True)
self.listener.wait()
self._work_executor.shutdown(wait=True)
# Close listener connection after processing all messages

View File

@ -106,7 +106,7 @@ class PikaPollerTestCase(unittest.TestCase):
self._poller_connection_mock.process_data_events.side_effect = f
poller.start()
res = poller.poll(prefetch_size=1)
res = poller.poll(batch_size=1)
self.assertEqual(len(res), 1)
self.assertEqual(res[0], incoming_message_class_mock.return_value)
self.assertEqual(
@ -116,7 +116,7 @@ class PikaPollerTestCase(unittest.TestCase):
poller.stop()
res2 = poller.poll(prefetch_size=n)
res2 = poller.poll(batch_size=n)
self.assertEqual(len(res2), n - 1)
self.assertEqual(incoming_message_class_mock.call_count, n)
@ -162,7 +162,7 @@ class PikaPollerTestCase(unittest.TestCase):
self._poller_connection_mock.process_data_events.side_effect = f
poller.start()
res = poller.poll(prefetch_size=n)
res = poller.poll(batch_size=n)
self.assertEqual(len(res), n)
self.assertEqual(incoming_message_class_mock.call_count, n)
@ -210,7 +210,7 @@ class PikaPollerTestCase(unittest.TestCase):
self._poller_connection_mock.process_data_events.side_effect = f
poller.start()
res = poller.poll(prefetch_size=n, timeout=timeout)
res = poller.poll(batch_size=n, timeout=timeout)
self.assertEqual(len(res), success_count)
self.assertEqual(incoming_message_class_mock.call_count, success_count)

View File

@ -203,7 +203,8 @@ class TestKafkaListener(test_utils.BaseTestCase):
def test_create_listener(self, fake_consumer, fake_ensure_connection):
fake_target = oslo_messaging.Target(topic='fake_topic')
fake_targets_and_priorities = [(fake_target, 'info')]
self.driver.listen_for_notifications(fake_targets_and_priorities)
self.driver.listen_for_notifications(fake_targets_and_priorities, None,
None, None, None)
self.assertEqual(1, len(fake_consumer.mock_calls))
@mock.patch.object(kafka_driver.Connection, '_ensure_connection')
@ -220,7 +221,8 @@ class TestKafkaListener(test_utils.BaseTestCase):
(oslo_messaging.Target(topic="fake_topic",
exchange="test3"), 'error'),
]
self.driver.listen_for_notifications(fake_targets_and_priorities)
self.driver.listen_for_notifications(fake_targets_and_priorities, None,
None, None, None)
self.assertEqual(1, len(fake_consumer.mock_calls))
fake_consumer.assert_called_once_with(set(['fake_topic.error',
'fake_topic.info']),
@ -232,7 +234,8 @@ class TestKafkaListener(test_utils.BaseTestCase):
fake_target = oslo_messaging.Target(topic='fake_topic')
fake_targets_and_priorities = [(fake_target, 'info')]
listener = self.driver.listen_for_notifications(
fake_targets_and_priorities)
fake_targets_and_priorities, None, None, None,
None)._poll_style_listener
listener.conn.consume = mock.MagicMock()
listener.conn.consume.return_value = (
iter([kafka.common.KafkaMessage(
@ -264,7 +267,8 @@ class TestWithRealKafkaBroker(test_utils.BaseTestCase):
targets_and_priorities = [(target, 'fake_info')]
listener = self.driver.listen_for_notifications(
targets_and_priorities)
targets_and_priorities, None, None, None,
None)._poll_style_listener
fake_context = {"fake_context_key": "fake_context_value"}
fake_message = {"fake_message_key": "fake_message_value"}
self.driver.send_notification(
@ -281,7 +285,8 @@ class TestWithRealKafkaBroker(test_utils.BaseTestCase):
targets_and_priorities = [(target, 'fake_info')]
listener = self.driver.listen_for_notifications(
targets_and_priorities)
targets_and_priorities, None, None, None,
None)._poll_style_listener
fake_context = {"fake_context_key": "fake_context_value"}
fake_message = {"fake_message_key": "fake_message_value"}
self.driver.send_notification(
@ -299,9 +304,10 @@ class TestWithRealKafkaBroker(test_utils.BaseTestCase):
targets_and_priorities = [(target, 'fake_info')]
listener = self.driver.listen_for_notifications(
targets_and_priorities)
targets_and_priorities, None, None, None,
None)._poll_style_listener
deadline = time.time() + 3
received_message = listener.poll(timeout=3)
received_message = listener.poll(batch_timeout=3)
self.assertEqual(0, int(deadline - time.time()))
self.assertEqual([], received_message)

View File

@ -435,7 +435,7 @@ class TestSendReceive(test_utils.BaseTestCase):
target = oslo_messaging.Target(topic='testtopic')
listener = driver.listen(target)
listener = driver.listen(target, None, None, None)._poll_style_listener
senders = []
replies = []
@ -525,7 +525,7 @@ class TestPollAsync(test_utils.BaseTestCase):
self.addCleanup(transport.cleanup)
driver = transport._driver
target = oslo_messaging.Target(topic='testtopic')
listener = driver.listen(target)
listener = driver.listen(target, None, None, None)._poll_style_listener
received = listener.poll(timeout=0.050)
self.assertEqual([], received)
@ -541,8 +541,7 @@ class TestRacyWaitForReply(test_utils.BaseTestCase):
target = oslo_messaging.Target(topic='testtopic')
listener = driver.listen(target)
listener = driver.listen(target, None, None, None)._poll_style_listener
senders = []
replies = []
msgs = []
@ -878,7 +877,7 @@ class TestReplyWireFormat(test_utils.BaseTestCase):
server=self.server,
fanout=self.fanout)
listener = driver.listen(target)
listener = driver.listen(target, None, None, None)._poll_style_listener
connection, producer = _create_producer(target)
self.addCleanup(connection.release)

View File

@ -42,7 +42,7 @@ class ZmqTestPortsRange(zmq_common.ZmqBaseTestCase):
for i in range(10):
try:
target = oslo_messaging.Target(topic='testtopic_' + str(i))
new_listener = self.driver.listen(target)
new_listener = self.driver.listen(target, None, None, None)
listeners.append(new_listener)
except zmq_socket.ZmqPortRangeExceededException:
pass

View File

@ -39,12 +39,14 @@ class TestServerListener(object):
self.message = None
def listen(self, target):
self.listener = self.driver.listen(target)
self.listener = self.driver.listen(target, None, None,
None)._poll_style_listener
self.executor.execute()
def listen_notifications(self, targets_and_priorities):
self.listener = self.driver.listen_for_notifications(
targets_and_priorities, {})
targets_and_priorities, None, None, None,
None)._poll_style_listener
self.executor.execute()
def _run(self):

View File

@ -29,7 +29,7 @@ load_tests = testscenarios.load_tests_apply_scenarios
class ServerSetupMixin(object):
class Server(threading.Thread):
class Server(object):
def __init__(self, transport, topic, server, endpoint, serializer):
self.controller = ServerSetupMixin.ServerController()
target = oslo_messaging.Target(topic=topic, server=server)
@ -39,9 +39,6 @@ class ServerSetupMixin(object):
self.controller],
serializer=serializer)
super(ServerSetupMixin.Server, self).__init__()
self.daemon = True
def wait(self):
# Wait for the executor to process the stop message, indicating all
# test messages have been processed
@ -52,7 +49,7 @@ class ServerSetupMixin(object):
self.server.stop()
self.server.wait()
def run(self):
def start(self):
self.server.start()
class ServerController(object):
@ -86,10 +83,7 @@ class ServerSetupMixin(object):
endpoint=endpoint,
serializer=self.serializer)
thread = threading.Thread(target=server.start)
thread.daemon = True
thread.start()
server.start()
return server
def _stop_server(self, client, server, topic=None):
@ -492,9 +486,9 @@ class TestMultipleServers(test_utils.BaseTestCase, ServerSetupMixin):
else:
endpoint1 = endpoint2 = TestEndpoint()
thread1 = self._setup_server(transport1, endpoint1,
server1 = self._setup_server(transport1, endpoint1,
topic=self.topic1, server=self.server1)
thread2 = self._setup_server(transport2, endpoint2,
server2 = self._setup_server(transport2, endpoint2,
topic=self.topic2, server=self.server2)
client1 = self._setup_client(transport1, topic=self.topic1)
@ -513,12 +507,10 @@ class TestMultipleServers(test_utils.BaseTestCase, ServerSetupMixin):
(client1.call if self.call1 else client1.cast)({}, 'ping', arg='1')
(client2.call if self.call2 else client2.cast)({}, 'ping', arg='2')
self.assertTrue(thread1.isAlive())
self._stop_server(client1.prepare(fanout=None),
thread1, topic=self.topic1)
self.assertTrue(thread2.isAlive())
server1, topic=self.topic1)
self._stop_server(client2.prepare(fanout=None),
thread2, topic=self.topic2)
server2, topic=self.topic2)
def check(pings, expect):
self.assertEqual(len(expect), len(pings))
@ -560,14 +552,13 @@ class TestServerLocking(test_utils.BaseTestCase):
class MessageHandlingServerImpl(oslo_messaging.MessageHandlingServer):
def _create_listener(self):
pass
return mock.Mock()
def _process_incoming(self, incoming):
pass
self.server = MessageHandlingServerImpl(mock.Mock(), mock.Mock())
self.server._executor_cls = FakeExecutor
self.server._create_listener = mock.Mock()
def test_start_stop_wait(self):
# Test a simple execution of start, stop, wait in order
@ -576,9 +567,8 @@ class TestServerLocking(test_utils.BaseTestCase):
self.server.stop()
self.server.wait()
self.assertEqual(len(self.executors), 2)
self.assertEqual(len(self.executors), 1)
self.assertEqual(self.executors[0]._calls, ['shutdown'])
self.assertEqual(self.executors[1]._calls, ['submit', 'shutdown'])
self.assertTrue(self.server.listener.cleanup.called)
def test_reversed_order(self):
@ -597,9 +587,8 @@ class TestServerLocking(test_utils.BaseTestCase):
self.server.wait()
self.assertEqual(len(self.executors), 2)
self.assertEqual(len(self.executors), 1)
self.assertEqual(self.executors[0]._calls, ['shutdown'])
self.assertEqual(self.executors[1]._calls, ['submit', 'shutdown'])
def test_wait_for_running_task(self):
# Test that if 2 threads call a method simultaneously, both will wait,
@ -660,9 +649,8 @@ class TestServerLocking(test_utils.BaseTestCase):
# Check that both threads have finished, start was only called once,
# and execute ran
self.assertTrue(waiter_finished.is_set())
self.assertEqual(2, len(self.executors))
self.assertEqual(1, len(self.executors))
self.assertEqual(self.executors[0]._calls, [])
self.assertEqual(self.executors[1]._calls, ['submit'])
def test_start_stop_wait_stop_wait(self):
# Test that we behave correctly when calling stop/wait more than once.
@ -674,9 +662,8 @@ class TestServerLocking(test_utils.BaseTestCase):
self.server.stop()
self.server.wait()
self.assertEqual(len(self.executors), 2)
self.assertEqual(len(self.executors), 1)
self.assertEqual(self.executors[0]._calls, ['shutdown'])
self.assertEqual(self.executors[1]._calls, ['submit', 'shutdown'])
self.assertTrue(self.server.listener.cleanup.called)
def test_state_wrapping(self):
@ -711,9 +698,8 @@ class TestServerLocking(test_utils.BaseTestCase):
complete_waiting_callback.wait()
# The server should have started, but stop should not have been called
self.assertEqual(2, len(self.executors))
self.assertEqual(1, len(self.executors))
self.assertEqual(self.executors[0]._calls, [])
self.assertEqual(self.executors[1]._calls, ['submit'])
self.assertFalse(thread1_finished.is_set())
self.server.stop()
@ -721,20 +707,17 @@ class TestServerLocking(test_utils.BaseTestCase):
# We should have gone through all the states, and thread1 should still
# be waiting
self.assertEqual(2, len(self.executors))
self.assertEqual(1, len(self.executors))
self.assertEqual(self.executors[0]._calls, ['shutdown'])
self.assertEqual(self.executors[1]._calls, ['submit', 'shutdown'])
self.assertFalse(thread1_finished.is_set())
# Start again
self.server.start()
# We should now record 4 executors (2 for each server)
self.assertEqual(4, len(self.executors))
self.assertEqual(2, len(self.executors))
self.assertEqual(self.executors[0]._calls, ['shutdown'])
self.assertEqual(self.executors[1]._calls, ['submit', 'shutdown'])
self.assertEqual(self.executors[2]._calls, [])
self.assertEqual(self.executors[3]._calls, ['submit'])
self.assertEqual(self.executors[1]._calls, [])
self.assertFalse(thread1_finished.is_set())
# Allow thread1 to complete
@ -743,11 +726,9 @@ class TestServerLocking(test_utils.BaseTestCase):
# thread1 should now have finished, and stop should not have been
# called again on either the first or second executor
self.assertEqual(4, len(self.executors))
self.assertEqual(2, len(self.executors))
self.assertEqual(self.executors[0]._calls, ['shutdown'])
self.assertEqual(self.executors[1]._calls, ['submit', 'shutdown'])
self.assertEqual(self.executors[2]._calls, [])
self.assertEqual(self.executors[3]._calls, ['submit'])
self.assertEqual(self.executors[1]._calls, [])
self.assertTrue(thread1_finished.is_set())
@mock.patch.object(server_module, 'DEFAULT_LOG_AFTER', 1)

View File

@ -131,14 +131,15 @@ class TestAmqpSend(_AmqpBrokerTestCase):
"""Verify unused listener can cleanly shutdown."""
driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
target = oslo_messaging.Target(topic="test-topic")
listener = driver.listen(target)
listener = driver.listen(target, None, None, None)._poll_style_listener
self.assertIsInstance(listener, amqp_driver.ProtonListener)
driver.cleanup()
def test_send_no_reply(self):
driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
target = oslo_messaging.Target(topic="test-topic")
listener = _ListenerThread(driver.listen(target), 1)
listener = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
rc = driver.send(target, {"context": True},
{"msg": "value"}, wait_for_reply=False)
self.assertIsNone(rc)
@ -150,9 +151,11 @@ class TestAmqpSend(_AmqpBrokerTestCase):
def test_send_exchange_with_reply(self):
driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
target1 = oslo_messaging.Target(topic="test-topic", exchange="e1")
listener1 = _ListenerThread(driver.listen(target1), 1)
listener1 = _ListenerThread(
driver.listen(target1, None, None, None)._poll_style_listener, 1)
target2 = oslo_messaging.Target(topic="test-topic", exchange="e2")
listener2 = _ListenerThread(driver.listen(target2), 1)
listener2 = _ListenerThread(
driver.listen(target2, None, None, None)._poll_style_listener, 1)
rc = driver.send(target1, {"context": "whatever"},
{"method": "echo", "id": "e1"},
@ -178,9 +181,11 @@ class TestAmqpSend(_AmqpBrokerTestCase):
"""Verify the direct, shared, and fanout message patterns work."""
driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
target1 = oslo_messaging.Target(topic="test-topic", server="server1")
listener1 = _ListenerThread(driver.listen(target1), 4)
listener1 = _ListenerThread(
driver.listen(target1, None, None, None)._poll_style_listener, 4)
target2 = oslo_messaging.Target(topic="test-topic", server="server2")
listener2 = _ListenerThread(driver.listen(target2), 3)
listener2 = _ListenerThread(
driver.listen(target2, None, None, None)._poll_style_listener, 3)
shared_target = oslo_messaging.Target(topic="test-topic")
fanout_target = oslo_messaging.Target(topic="test-topic",
@ -250,7 +255,8 @@ class TestAmqpSend(_AmqpBrokerTestCase):
"""Verify send timeout."""
driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
target = oslo_messaging.Target(topic="test-topic")
listener = _ListenerThread(driver.listen(target), 1)
listener = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
# the listener will drop this message:
try:
@ -276,7 +282,8 @@ class TestAmqpNotification(_AmqpBrokerTestCase):
notifications = [(oslo_messaging.Target(topic="topic-1"), 'info'),
(oslo_messaging.Target(topic="topic-1"), 'error'),
(oslo_messaging.Target(topic="topic-2"), 'debug')]
nl = driver.listen_for_notifications(notifications, None)
nl = driver.listen_for_notifications(
notifications, None, None, None, None)._poll_style_listener
# send one for each support version:
msg_count = len(notifications) * 2
@ -340,7 +347,8 @@ class TestAuthentication(test_utils.BaseTestCase):
url = oslo_messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url)
target = oslo_messaging.Target(topic="test-topic")
listener = _ListenerThread(driver.listen(target), 1)
listener = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
rc = driver.send(target, {"context": True},
{"method": "echo"}, wait_for_reply=True)
self.assertIsNotNone(rc)
@ -358,7 +366,8 @@ class TestAuthentication(test_utils.BaseTestCase):
url = oslo_messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url)
target = oslo_messaging.Target(topic="test-topic")
_ListenerThread(driver.listen(target), 1)
_ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
self.assertRaises(oslo_messaging.MessagingTimeout,
driver.send,
target, {"context": True},
@ -429,7 +438,8 @@ mech_list: ${mechs}
url = oslo_messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url)
target = oslo_messaging.Target(topic="test-topic")
listener = _ListenerThread(driver.listen(target), 1)
listener = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
rc = driver.send(target, {"context": True},
{"method": "echo"}, wait_for_reply=True)
self.assertIsNotNone(rc)
@ -447,7 +457,8 @@ mech_list: ${mechs}
url = oslo_messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url)
target = oslo_messaging.Target(topic="test-topic")
_ListenerThread(driver.listen(target), 1)
_ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
self.assertRaises(oslo_messaging.MessagingTimeout,
driver.send,
target, {"context": True},
@ -467,7 +478,8 @@ mech_list: ${mechs}
url = oslo_messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url)
target = oslo_messaging.Target(topic="test-topic")
_ListenerThread(driver.listen(target), 1)
_ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
self.assertRaises(oslo_messaging.MessagingTimeout,
driver.send,
target, {"context": True},
@ -487,7 +499,8 @@ mech_list: ${mechs}
url = oslo_messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url)
target = oslo_messaging.Target(topic="test-topic")
listener = _ListenerThread(driver.listen(target), 1)
listener = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
rc = driver.send(target, {"context": True},
{"method": "echo"}, wait_for_reply=True)
self.assertIsNotNone(rc)
@ -522,7 +535,8 @@ class TestFailover(test_utils.BaseTestCase):
driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
target = oslo_messaging.Target(topic="my-topic")
listener = _ListenerThread(driver.listen(target), 2)
listener = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 2)
# wait for listener links to come up
# 4 == 3 links per listener + 1 for the global reply queue
@ -608,8 +622,10 @@ class TestFailover(test_utils.BaseTestCase):
target = oslo_messaging.Target(topic="my-topic")
bcast = oslo_messaging.Target(topic="my-topic", fanout=True)
listener1 = _ListenerThread(driver.listen(target), 2)
listener2 = _ListenerThread(driver.listen(target), 2)
listener1 = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 2)
listener2 = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 2)
# wait for 7 sending links to become active on the broker.
# 7 = 3 links per Listener + 1 global reply link

View File

@ -38,7 +38,7 @@ class _FakeDriver(object):
def send_notification(self, *args, **kwargs):
pass
def listen(self, target):
def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
pass
@ -314,10 +314,10 @@ class TestTransportMethodArgs(test_utils.BaseTestCase):
t = transport.Transport(_FakeDriver(cfg.CONF))
self.mox.StubOutWithMock(t._driver, 'listen')
t._driver.listen(self._target)
t._driver.listen(self._target, None, 1, None)
self.mox.ReplayAll()
t._listen(self._target)
t._listen(self._target, None, 1, None)
class TestTransportUrlCustomisation(test_utils.BaseTestCase):

View File

@ -96,21 +96,26 @@ class Transport(object):
self._driver.send_notification(target, ctxt, message, version,
retry=retry)
def _listen(self, target):
def _listen(self, target, on_incoming_callback, batch_size, batch_timeout):
if not (target.topic and target.server):
raise exceptions.InvalidTarget('A server\'s target must have '
'topic and server names specified',
target)
return self._driver.listen(target)
return self._driver.listen(target, on_incoming_callback, batch_size,
batch_timeout)
def _listen_for_notifications(self, targets_and_priorities, pool):
def _listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
for target, priority in targets_and_priorities:
if not target.topic:
raise exceptions.InvalidTarget('A target must have '
'topic specified',
target)
return self._driver.listen_for_notifications(
targets_and_priorities, pool)
targets_and_priorities, pool, on_incoming_callback, batch_size,
batch_timeout
)
def cleanup(self):
"""Release all resources associated with this transport."""