Merge "batch notification listener"

This commit is contained in:
Jenkins 2015-12-11 04:16:57 +00:00 committed by Gerrit Code Review
commit 213176657d
25 changed files with 450 additions and 127 deletions

View File

@ -203,6 +203,7 @@ class AMQPListener(base.Listener):
ctxt.reply_q, ctxt.reply_q,
self._obsolete_reply_queues)) self._obsolete_reply_queues))
@base.batch_poll_helper
def poll(self, timeout=None): def poll(self, timeout=None):
while not self._stopped.is_set(): while not self._stopped.is_set():
if self.incoming: if self.incoming:

View File

@ -15,9 +15,12 @@
import abc import abc
import six
from oslo_config import cfg from oslo_config import cfg
from oslo_utils import timeutils
import six
from six.moves import range as compat_range
from oslo_messaging import exceptions from oslo_messaging import exceptions
base_opts = [ base_opts = [
@ -28,6 +31,27 @@ base_opts = [
] ]
def batch_poll_helper(func):
"""Decorator to poll messages in batch
This decorator helps driver that polls message one by one,
to returns a list of message.
"""
def wrapper(in_self, timeout=None, prefetch_size=1):
incomings = []
watch = timeutils.StopWatch(duration=timeout)
with watch:
for __ in compat_range(prefetch_size):
msg = func(in_self, timeout=watch.leftover(return_none=True))
if msg is not None:
incomings.append(msg)
else:
# timeout reached or listener stopped
break
return incomings
return wrapper
class TransportDriverError(exceptions.MessagingException): class TransportDriverError(exceptions.MessagingException):
"""Base class for transport driver specific exceptions.""" """Base class for transport driver specific exceptions."""
@ -61,8 +85,9 @@ class Listener(object):
self.driver = driver self.driver = driver
@abc.abstractmethod @abc.abstractmethod
def poll(self, timeout=None): def poll(self, timeout=None, prefetch_size=1):
"""Blocking until a message is pending and return IncomingMessage. """Blocking until 'prefetch_size' message is pending and return
[IncomingMessage].
Return None after timeout seconds if timeout is set and no message is Return None after timeout seconds if timeout is set and no message is
ending or if the listener have been stopped. ending or if the listener have been stopped.
""" """

View File

@ -54,6 +54,7 @@ class FakeListener(base.Listener):
exchange = self._exchange_manager.get_exchange(target.exchange) exchange = self._exchange_manager.get_exchange(target.exchange)
exchange.ensure_queue(target, pool) exchange.ensure_queue(target, pool)
@base.batch_poll_helper
def poll(self, timeout=None): def poll(self, timeout=None):
if timeout is not None: if timeout is not None:
deadline = time.time() + timeout deadline = time.time() + timeout

View File

@ -252,6 +252,7 @@ class KafkaListener(base.Listener):
self.conn = conn self.conn = conn
self.incoming_queue = [] self.incoming_queue = []
@base.batch_poll_helper
def poll(self, timeout=None): def poll(self, timeout=None):
while not self._stopped.is_set(): while not self._stopped.is_set():
if self.incoming_queue: if self.incoming_queue:

View File

@ -859,6 +859,7 @@ class Connection(object):
raise rpc_common.Timeout() raise rpc_common.Timeout()
def _recoverable_error_callback(exc): def _recoverable_error_callback(exc):
if not isinstance(exc, rpc_common.Timeout):
self._new_consumers = self._consumers self._new_consumers = self._consumers
timer.check_return(_raise_timeout, exc) timer.check_return(_raise_timeout, exc)

View File

@ -117,8 +117,12 @@ class ProtonListener(base.Listener):
super(ProtonListener, self).__init__(driver) super(ProtonListener, self).__init__(driver)
self.incoming = moves.queue.Queue() self.incoming = moves.queue.Queue()
def poll(self): @base.batch_poll_helper
message = self.incoming.get() def poll(self, timeout=None):
try:
message = self.incoming.get(True, timeout)
except moves.queue.Empty:
return
request, ctxt = unmarshal_request(message) request, ctxt = unmarshal_request(message)
LOG.debug("Returning incoming message") LOG.debug("Returning incoming message")
return ProtonIncomingMessage(self, ctxt, request, message) return ProtonIncomingMessage(self, ctxt, request, message)

View File

@ -40,6 +40,7 @@ class ZmqServer(base.Listener):
self.notify_consumer = self.rpc_consumer self.notify_consumer = self.rpc_consumer
self.consumers = [self.rpc_consumer] self.consumers = [self.rpc_consumer]
@base.batch_poll_helper
def poll(self, timeout=None): def poll(self, timeout=None):
message, socket = self.poller.poll( message, socket = self.poller.poll(
timeout or self.conf.rpc_poll_timeout) timeout or self.conf.rpc_poll_timeout)

View File

@ -93,8 +93,11 @@ class PooledExecutor(base.ExecutorBase):
@excutils.forever_retry_uncaught_exceptions @excutils.forever_retry_uncaught_exceptions
def _runner(self): def _runner(self):
while not self._tombstone.is_set(): while not self._tombstone.is_set():
incoming = self.listener.poll() incoming = self.listener.poll(
if incoming is None: timeout=self.dispatcher.batch_timeout,
prefetch_size=self.dispatcher.batch_size)
if not incoming:
continue continue
callback = self.dispatcher(incoming, self._executor_callback) callback = self.dispatcher(incoming, self._executor_callback)
was_submitted = self._do_submit(callback) was_submitted = self._do_submit(callback)

View File

@ -79,6 +79,12 @@ class DispatcherExecutorContext(object):
class DispatcherBase(object): class DispatcherBase(object):
"Base class for dispatcher" "Base class for dispatcher"
batch_size = 1
"Number of messages to wait before calling endpoints callacks"
batch_timeout = None
"Number of seconds to wait before calling endpoints callacks"
@abc.abstractmethod @abc.abstractmethod
def _listen(self, transport): def _listen(self, transport):
"""Initiate the driver Listener """Initiate the driver Listener
@ -98,7 +104,7 @@ class DispatcherBase(object):
def __call__(self, incoming, executor_callback=None): def __call__(self, incoming, executor_callback=None):
"""Called by the executor to get the DispatcherExecutorContext """Called by the executor to get the DispatcherExecutorContext
:param incoming: message or list of messages :param incoming: list of messages
:type incoming: oslo_messging._drivers.base.IncomingMessage :type incoming: oslo_messging._drivers.base.IncomingMessage
:returns: DispatcherExecutorContext :returns: DispatcherExecutorContext
:rtype: DispatcherExecutorContext :rtype: DispatcherExecutorContext

View File

@ -17,6 +17,7 @@ __all__ = ['Notifier',
'LoggingNotificationHandler', 'LoggingNotificationHandler',
'get_notification_transport', 'get_notification_transport',
'get_notification_listener', 'get_notification_listener',
'get_batch_notification_listener',
'NotificationResult', 'NotificationResult',
'NotificationFilter', 'NotificationFilter',
'PublishErrorsHandler', 'PublishErrorsHandler',

View File

@ -16,7 +16,8 @@
import itertools import itertools
import logging import logging
import sys
import six
from oslo_messaging import dispatcher from oslo_messaging import dispatcher
from oslo_messaging import localcontext from oslo_messaging import localcontext
@ -33,17 +34,7 @@ class NotificationResult(object):
REQUEUE = 'requeue' REQUEUE = 'requeue'
class NotificationDispatcher(dispatcher.DispatcherBase): class _NotificationDispatcherBase(dispatcher.DispatcherBase):
"""A message dispatcher which understands Notification messages.
A MessageHandlingServer is constructed by passing a callable dispatcher
which is invoked with context and message dictionaries each time a message
is received.
NotifcationDispatcher is one such dispatcher which pass a raw notification
message to the endpoints
"""
def __init__(self, targets, endpoints, serializer, allow_requeue, def __init__(self, targets, endpoints, serializer, allow_requeue,
pool=None): pool=None):
self.targets = targets self.targets = targets
@ -74,12 +65,15 @@ class NotificationDispatcher(dispatcher.DispatcherBase):
executor_callback=executor_callback, executor_callback=executor_callback,
post=self._post_dispatch) post=self._post_dispatch)
@staticmethod def _post_dispatch(self, incoming, requeues):
def _post_dispatch(incoming, result): for m in incoming:
if result == NotificationResult.HANDLED: try:
incoming.acknowledge() if requeues and m in requeues:
m.requeue()
else: else:
incoming.requeue() m.acknowledge()
except Exception:
LOG.error("Fail to ack/requeue message", exc_info=True)
def _dispatch_and_handle_error(self, incoming, executor_callback): def _dispatch_and_handle_error(self, incoming, executor_callback):
"""Dispatch a notification message to the appropriate endpoint method. """Dispatch a notification message to the appropriate endpoint method.
@ -88,24 +82,59 @@ class NotificationDispatcher(dispatcher.DispatcherBase):
:type ctxt: IncomingMessage :type ctxt: IncomingMessage
""" """
try: try:
return self._dispatch(incoming.ctxt, incoming.message, return self._dispatch(incoming, executor_callback)
executor_callback)
except Exception: except Exception:
# sys.exc_info() is deleted by LOG.exception(). LOG.error('Exception during message handling', exc_info=True)
exc_info = sys.exc_info()
LOG.error('Exception during message handling',
exc_info=exc_info)
return NotificationResult.HANDLED
def _dispatch(self, ctxt, message, executor_callback=None): def _dispatch(self, incoming, executor_callback=None):
"""Dispatch an RPC message to the appropriate endpoint method. """Dispatch notification messages to the appropriate endpoint method.
:param ctxt: the request context
:type ctxt: dict
:param message: the message payload
:type message: dict
""" """
ctxt = self.serializer.deserialize_context(ctxt)
messages_grouped = itertools.groupby((
self._extract_user_message(m)
for m in incoming), lambda x: x[0])
requeues = set()
for priority, messages in messages_grouped:
__, raw_messages, messages = six.moves.zip(*messages)
raw_messages = list(raw_messages)
messages = list(messages)
if priority not in PRIORITIES:
LOG.warning('Unknown priority "%s"', priority)
continue
for screen, callback in self._callbacks_by_priority.get(priority,
[]):
if screen:
filtered_messages = [message for message in messages
if screen.match(
message["ctxt"],
message["publisher_id"],
message["event_type"],
message["metadata"],
message["payload"])]
else:
filtered_messages = messages
if not filtered_messages:
continue
ret = self._exec_callback(executor_callback, callback,
filtered_messages)
if self.allow_requeue and ret == NotificationResult.REQUEUE:
requeues.update(raw_messages)
break
return requeues
def _exec_callback(self, executor_callback, callback, *args):
if executor_callback:
ret = executor_callback(callback, *args)
else:
ret = callback(*args)
return NotificationResult.HANDLED if ret is None else ret
def _extract_user_message(self, incoming):
ctxt = self.serializer.deserialize_context(incoming.ctxt)
message = incoming.message
publisher_id = message.get('publisher_id') publisher_id = message.get('publisher_id')
event_type = message.get('event_type') event_type = message.get('event_type')
@ -114,28 +143,50 @@ class NotificationDispatcher(dispatcher.DispatcherBase):
'timestamp': message.get('timestamp') 'timestamp': message.get('timestamp')
} }
priority = message.get('priority', '').lower() priority = message.get('priority', '').lower()
if priority not in PRIORITIES:
LOG.warning('Unknown priority "%s"', priority)
return
payload = self.serializer.deserialize_entity(ctxt, payload = self.serializer.deserialize_entity(ctxt,
message.get('payload')) message.get('payload'))
return priority, incoming, dict(ctxt=ctxt,
publisher_id=publisher_id,
event_type=event_type,
payload=payload,
metadata=metadata)
for screen, callback in self._callbacks_by_priority.get(priority, []):
if screen and not screen.match(ctxt, publisher_id, event_type, class NotificationDispatcher(_NotificationDispatcherBase):
metadata, payload): """A message dispatcher which understands Notification messages.
continue
localcontext._set_local_context(ctxt) A MessageHandlingServer is constructed by passing a callable dispatcher
which is invoked with context and message dictionaries each time a message
is received.
"""
def _exec_callback(self, executor_callback, callback, messages):
localcontext._set_local_context(
messages[0]["ctxt"])
try: try:
if executor_callback: return super(NotificationDispatcher, self)._exec_callback(
ret = executor_callback(callback, ctxt, publisher_id, executor_callback, callback,
event_type, payload, metadata) messages[0]["ctxt"],
else: messages[0]["publisher_id"],
ret = callback(ctxt, publisher_id, event_type, payload, messages[0]["event_type"],
metadata) messages[0]["payload"],
ret = NotificationResult.HANDLED if ret is None else ret messages[0]["metadata"])
if self.allow_requeue and ret == NotificationResult.REQUEUE:
return ret
finally: finally:
localcontext._clear_local_context() localcontext._clear_local_context()
return NotificationResult.HANDLED
class BatchNotificationDispatcher(_NotificationDispatcherBase):
"""A message dispatcher which understands Notification messages.
A MessageHandlingServer is constructed by passing a callable dispatcher
which is invoked with a list of message dictionaries each time 'batch_size'
messages are received or 'batch_timeout' seconds is reached.
"""
def __init__(self, targets, endpoints, serializer, allow_requeue,
pool=None, batch_size=None, batch_timeout=None):
super(BatchNotificationDispatcher, self).__init__(targets, endpoints,
serializer,
allow_requeue,
pool)
self.batch_size = batch_size
self.batch_timeout = batch_timeout

View File

@ -142,3 +142,46 @@ def get_notification_listener(transport, targets, endpoints,
serializer, serializer,
allow_requeue, pool) allow_requeue, pool)
return msg_server.MessageHandlingServer(transport, dispatcher, executor) return msg_server.MessageHandlingServer(transport, dispatcher, executor)
def get_batch_notification_listener(transport, targets, endpoints,
executor='blocking', serializer=None,
allow_requeue=False, pool=None,
batch_size=None, batch_timeout=None):
"""Construct a batch notification listener
The executor parameter controls how incoming messages will be received and
dispatched. By default, the most simple executor is used - the blocking
executor.
If the eventlet executor is used, the threading and time library need to be
monkeypatched.
:param transport: the messaging transport
:type transport: Transport
:param targets: the exchanges and topics to listen on
:type targets: list of Target
:param endpoints: a list of endpoint objects
:type endpoints: list
:param executor: name of a message executor - for example
'eventlet', 'blocking'
:type executor: str
:param serializer: an optional entity serializer
:type serializer: Serializer
:param allow_requeue: whether NotificationResult.REQUEUE support is needed
:type allow_requeue: bool
:param pool: the pool name
:type pool: str
:param batch_size: number of messages to wait before calling
endpoints callacks
:type batch_size: int
:param batch_timeout: number of seconds to wait before calling
endpoints callacks
:type batch_timeout: int
:raises: NotImplementedError
"""
transport._require_driver_features(requeue=allow_requeue)
dispatcher = notify_dispatcher.BatchNotificationDispatcher(
targets, endpoints, serializer, allow_requeue, pool,
batch_size, batch_timeout)
return msg_server.MessageHandlingServer(transport, dispatcher, executor)

View File

@ -131,9 +131,9 @@ class RPCDispatcher(dispatcher.DispatcherBase):
return self.serializer.serialize_entity(ctxt, result) return self.serializer.serialize_entity(ctxt, result)
def __call__(self, incoming, executor_callback=None): def __call__(self, incoming, executor_callback=None):
incoming.acknowledge() incoming[0].acknowledge()
return dispatcher.DispatcherExecutorContext( return dispatcher.DispatcherExecutorContext(
incoming, self._dispatch_and_reply, incoming[0], self._dispatch_and_reply,
executor_callback=executor_callback) executor_callback=executor_callback)
def _dispatch_and_reply(self, incoming, executor_callback): def _dispatch_and_reply(self, incoming, executor_callback):

View File

@ -226,7 +226,7 @@ class TestKafkaListener(test_utils.BaseTestCase):
listener.stop() listener.stop()
fake_response = listener.poll() fake_response = listener.poll()
self.assertEqual(1, len(listener.conn.consume.mock_calls)) self.assertEqual(1, len(listener.conn.consume.mock_calls))
self.assertEqual(fake_response, None) self.assertEqual([], fake_response)
class TestWithRealKafkaBroker(test_utils.BaseTestCase): class TestWithRealKafkaBroker(test_utils.BaseTestCase):
@ -251,7 +251,7 @@ class TestWithRealKafkaBroker(test_utils.BaseTestCase):
self.driver.send_notification( self.driver.send_notification(
target, fake_context, fake_message, None) target, fake_context, fake_message, None)
received_message = listener.poll() received_message = listener.poll()[0]
self.assertEqual(fake_context, received_message.ctxt) self.assertEqual(fake_context, received_message.ctxt)
self.assertEqual(fake_message, received_message.message) self.assertEqual(fake_message, received_message.message)
@ -268,7 +268,7 @@ class TestWithRealKafkaBroker(test_utils.BaseTestCase):
self.driver.send_notification( self.driver.send_notification(
target, fake_context, fake_message, None) target, fake_context, fake_message, None)
received_message = listener.poll() received_message = listener.poll()[0]
self.assertEqual(fake_context, received_message.ctxt) self.assertEqual(fake_context, received_message.ctxt)
self.assertEqual(fake_message, received_message.message) self.assertEqual(fake_message, received_message.message)

View File

@ -423,7 +423,7 @@ class TestSendReceive(test_utils.BaseTestCase):
for i in range(len(senders)): for i in range(len(senders)):
senders[i].start() senders[i].start()
received = listener.poll() received = listener.poll()[0]
self.assertIsNotNone(received) self.assertIsNotNone(received)
self.assertEqual(self.ctxt, received.ctxt) self.assertEqual(self.ctxt, received.ctxt)
self.assertEqual({'tx_id': i}, received.message) self.assertEqual({'tx_id': i}, received.message)
@ -501,7 +501,7 @@ class TestPollAsync(test_utils.BaseTestCase):
target = oslo_messaging.Target(topic='testtopic') target = oslo_messaging.Target(topic='testtopic')
listener = driver.listen(target) listener = driver.listen(target)
received = listener.poll(timeout=0.050) received = listener.poll(timeout=0.050)
self.assertIsNone(received) self.assertEqual([], received)
class TestRacyWaitForReply(test_utils.BaseTestCase): class TestRacyWaitForReply(test_utils.BaseTestCase):
@ -561,13 +561,13 @@ class TestRacyWaitForReply(test_utils.BaseTestCase):
senders[0].start() senders[0].start()
notify_condition.wait() notify_condition.wait()
msgs.append(listener.poll()) msgs.extend(listener.poll())
self.assertEqual({'tx_id': 0}, msgs[-1].message) self.assertEqual({'tx_id': 0}, msgs[-1].message)
# Start the second guy, receive his message # Start the second guy, receive his message
senders[1].start() senders[1].start()
msgs.append(listener.poll()) msgs.extend(listener.poll())
self.assertEqual({'tx_id': 1}, msgs[-1].message) self.assertEqual({'tx_id': 1}, msgs[-1].message)
# Reply to both in order, making the second thread queue # Reply to both in order, making the second thread queue
@ -581,7 +581,7 @@ class TestRacyWaitForReply(test_utils.BaseTestCase):
# Start the 3rd guy, receive his message # Start the 3rd guy, receive his message
senders[2].start() senders[2].start()
msgs.append(listener.poll()) msgs.extend(listener.poll())
self.assertEqual({'tx_id': 2}, msgs[-1].message) self.assertEqual({'tx_id': 2}, msgs[-1].message)
# Verify the _send_reply was not invoked by driver: # Verify the _send_reply was not invoked by driver:
@ -862,7 +862,7 @@ class TestReplyWireFormat(test_utils.BaseTestCase):
producer.publish(msg) producer.publish(msg)
received = listener.poll() received = listener.poll()[0]
self.assertIsNotNone(received) self.assertIsNotNone(received)
self.assertEqual(self.expected_ctxt, received.ctxt) self.assertEqual(self.expected_ctxt, received.ctxt)
self.assertEqual(self.expected, received.message) self.assertEqual(self.expected, received.message)

View File

@ -52,7 +52,8 @@ class TestServerListener(object):
def _run(self): def _run(self):
try: try:
message = self.listener.poll() message = self.listener.poll()
if message is not None: if message:
message = message[0]
message.acknowledge() message.acknowledge()
self._received.set() self._received.set()
self.message = message self.message = message

View File

@ -132,11 +132,14 @@ class TestExecutor(test_utils.BaseTestCase):
endpoint = mock.MagicMock(return_value='result') endpoint = mock.MagicMock(return_value='result')
event = None event = None
class Dispatcher(object): class Dispatcher(dispatcher_base.DispatcherBase):
def __init__(self, endpoint): def __init__(self, endpoint):
self.endpoint = endpoint self.endpoint = endpoint
self.result = "not set" self.result = "not set"
def _listen(self, transport):
pass
def callback(self, incoming, executor_callback): def callback(self, incoming, executor_callback):
if executor_callback is None: if executor_callback is None:
result = self.endpoint(incoming.ctxt, result = self.endpoint(incoming.ctxt,
@ -152,7 +155,7 @@ class TestExecutor(test_utils.BaseTestCase):
def __call__(self, incoming, executor_callback=None): def __call__(self, incoming, executor_callback=None):
return dispatcher_base.DispatcherExecutorContext( return dispatcher_base.DispatcherExecutorContext(
incoming, self.callback, executor_callback) incoming[0], self.callback, executor_callback)
return Dispatcher(endpoint), endpoint, event, run_executor return Dispatcher(endpoint), endpoint, event, run_executor
@ -162,7 +165,7 @@ class TestExecutor(test_utils.BaseTestCase):
executor = self.executor(self.conf, listener, dispatcher) executor = self.executor(self.conf, listener, dispatcher)
incoming_message = mock.MagicMock(ctxt={}, message={'payload': 'data'}) incoming_message = mock.MagicMock(ctxt={}, message={'payload': 'data'})
def fake_poll(timeout=None): def fake_poll(timeout=None, prefetch_size=1):
time.sleep(0.1) time.sleep(0.1)
if listener.poll.call_count == 10: if listener.poll.call_count == 10:
if event is not None: if event is not None:
@ -190,9 +193,9 @@ class TestExecutor(test_utils.BaseTestCase):
executor = self.executor(self.conf, listener, dispatcher) executor = self.executor(self.conf, listener, dispatcher)
incoming_message = mock.MagicMock(ctxt={}, message={'payload': 'data'}) incoming_message = mock.MagicMock(ctxt={}, message={'payload': 'data'})
def fake_poll(timeout=None): def fake_poll(timeout=None, prefetch_size=1):
if listener.poll.call_count == 1: if listener.poll.call_count == 1:
return incoming_message return [incoming_message]
if event is not None: if event is not None:
event.wait() event.wait()
executor.stop() executor.stop()

View File

@ -16,6 +16,7 @@ import uuid
import concurrent.futures import concurrent.futures
from oslo_config import cfg from oslo_config import cfg
import six.moves
from testtools import matchers from testtools import matchers
import oslo_messaging import oslo_messaging
@ -324,3 +325,18 @@ class NotifyTestCase(utils.SkipIfNoTransportURL):
self.assertEqual(expected[1], actual[0]) self.assertEqual(expected[1], actual[0])
self.assertEqual(expected[2], actual[1]) self.assertEqual(expected[2], actual[1])
self.assertEqual(expected[3], actual[2]) self.assertEqual(expected[3], actual[2])
def test_simple_batch(self):
listener = self.useFixture(
utils.BatchNotificationFixture(self.conf, self.url,
['test_simple_batch'],
batch_size=100, batch_timeout=2))
notifier = listener.notifier('abc')
for i in six.moves.range(0, 205):
notifier.info({}, 'test%s' % i, 'Hello World!')
events = listener.get_events(timeout=3)
self.assertEqual(3, len(events), events)
self.assertEqual(100, len(events[0][1]))
self.assertEqual(100, len(events[1][1]))
self.assertEqual(5, len(events[2][1]))

View File

@ -293,13 +293,14 @@ class SkipIfNoTransportURL(test_utils.BaseTestCase):
class NotificationFixture(fixtures.Fixture): class NotificationFixture(fixtures.Fixture):
def __init__(self, conf, url, topics): def __init__(self, conf, url, topics, batch=None):
super(NotificationFixture, self).__init__() super(NotificationFixture, self).__init__()
self.conf = conf self.conf = conf
self.url = url self.url = url
self.topics = topics self.topics = topics
self.events = moves.queue.Queue() self.events = moves.queue.Queue()
self.name = str(id(self)) self.name = str(id(self))
self.batch = batch
def setUp(self): def setUp(self):
super(NotificationFixture, self).setUp() super(NotificationFixture, self).setUp()
@ -307,10 +308,7 @@ class NotificationFixture(fixtures.Fixture):
# add a special topic for internal notifications # add a special topic for internal notifications
targets.append(oslo_messaging.Target(topic=self.name)) targets.append(oslo_messaging.Target(topic=self.name))
transport = self.useFixture(TransportFixture(self.conf, self.url)) transport = self.useFixture(TransportFixture(self.conf, self.url))
self.server = oslo_messaging.get_notification_listener( self.server = self._get_server(transport, targets)
transport.transport,
targets,
[self], 'eventlet')
self._ctrl = self.notifier('internal', topic=self.name) self._ctrl = self.notifier('internal', topic=self.name)
self._start() self._start()
transport.wait() transport.wait()
@ -319,6 +317,12 @@ class NotificationFixture(fixtures.Fixture):
self._stop() self._stop()
super(NotificationFixture, self).cleanUp() super(NotificationFixture, self).cleanUp()
def _get_server(self, transport, targets):
return oslo_messaging.get_notification_listener(
transport.transport,
targets,
[self], 'eventlet')
def _start(self): def _start(self):
self.thread = test_utils.ServerThreadHelper(self.server) self.thread = test_utils.ServerThreadHelper(self.server)
self.thread.start() self.thread.start()
@ -366,3 +370,39 @@ class NotificationFixture(fixtures.Fixture):
except moves.queue.Empty: except moves.queue.Empty:
pass pass
return results return results
class BatchNotificationFixture(NotificationFixture):
def __init__(self, conf, url, topics, batch_size=5, batch_timeout=2):
super(BatchNotificationFixture, self).__init__(conf, url, topics)
self.batch_size = batch_size
self.batch_timeout = batch_timeout
def _get_server(self, transport, targets):
return oslo_messaging.get_batch_notification_listener(
transport.transport,
targets,
[self], 'eventlet',
batch_timeout=self.batch_timeout,
batch_size=self.batch_size)
def debug(self, messages):
self.events.put(['debug', messages])
def audit(self, messages):
self.events.put(['audit', messages])
def info(self, messages):
self.events.put(['info', messages])
def warn(self, messages):
self.events.put(['warn', messages])
def error(self, messages):
self.events.put(['error', messages])
def critical(self, messages):
self.events.put(['critical', messages])
def sample(self, messages):
pass # Just used for internal shutdown control

View File

@ -107,7 +107,7 @@ class TestDispatcher(test_utils.BaseTestCase):
sorted(dispatcher._targets_priorities)) sorted(dispatcher._targets_priorities))
incoming = mock.Mock(ctxt={}, message=msg) incoming = mock.Mock(ctxt={}, message=msg)
callback = dispatcher(incoming) callback = dispatcher([incoming])
callback.run() callback.run()
callback.done() callback.done()
@ -144,7 +144,7 @@ class TestDispatcher(test_utils.BaseTestCase):
msg['priority'] = 'what???' msg['priority'] = 'what???'
dispatcher = notify_dispatcher.NotificationDispatcher( dispatcher = notify_dispatcher.NotificationDispatcher(
[mock.Mock()], [mock.Mock()], None, allow_requeue=True, pool=None) [mock.Mock()], [mock.Mock()], None, allow_requeue=True, pool=None)
callback = dispatcher(mock.Mock(ctxt={}, message=msg)) callback = dispatcher([mock.Mock(ctxt={}, message=msg)])
callback.run() callback.run()
callback.done() callback.done()
mylog.warning.assert_called_once_with('Unknown priority "%s"', mylog.warning.assert_called_once_with('Unknown priority "%s"',
@ -246,7 +246,7 @@ class TestDispatcherFilter(test_utils.BaseTestCase):
'timestamp': '2014-03-03 18:21:04.369234', 'timestamp': '2014-03-03 18:21:04.369234',
'message_id': '99863dda-97f0-443a-a0c1-6ed317b7fd45'} 'message_id': '99863dda-97f0-443a-a0c1-6ed317b7fd45'}
incoming = mock.Mock(ctxt=self.context, message=message) incoming = mock.Mock(ctxt=self.context, message=message)
callback = dispatcher(incoming) callback = dispatcher([incoming])
callback.run() callback.run()
callback.done() callback.done()

View File

@ -23,6 +23,7 @@ import oslo_messaging
from oslo_messaging.notify import dispatcher from oslo_messaging.notify import dispatcher
from oslo_messaging.notify import notifier as msg_notifier from oslo_messaging.notify import notifier as msg_notifier
from oslo_messaging.tests import utils as test_utils from oslo_messaging.tests import utils as test_utils
import six
from six.moves import mock from six.moves import mock
load_tests = testscenarios.load_tests_apply_scenarios load_tests = testscenarios.load_tests_apply_scenarios
@ -56,7 +57,7 @@ class ListenerSetupMixin(object):
self.threads = [] self.threads = []
self.lock = threading.Condition() self.lock = threading.Condition()
def info(self, ctxt, publisher_id, event_type, payload, metadata): def info(self, *args, **kwargs):
# NOTE(sileht): this run into an other thread # NOTE(sileht): this run into an other thread
with self.lock: with self.lock:
self._received_msgs += 1 self._received_msgs += 1
@ -86,7 +87,7 @@ class ListenerSetupMixin(object):
self.trackers = {} self.trackers = {}
def _setup_listener(self, transport, endpoints, def _setup_listener(self, transport, endpoints,
targets=None, pool=None): targets=None, pool=None, batch=False):
if pool is None: if pool is None:
tracker_name = '__default__' tracker_name = '__default__'
@ -98,6 +99,12 @@ class ListenerSetupMixin(object):
tracker = self.trackers.setdefault( tracker = self.trackers.setdefault(
tracker_name, self.ThreadTracker()) tracker_name, self.ThreadTracker())
if batch:
listener = oslo_messaging.get_batch_notification_listener(
transport, targets=targets, endpoints=[tracker] + endpoints,
allow_requeue=True, pool=pool, executor='eventlet',
batch_size=batch[0], batch_timeout=batch[1])
else:
listener = oslo_messaging.get_notification_listener( listener = oslo_messaging.get_notification_listener(
transport, targets=targets, endpoints=[tracker] + endpoints, transport, targets=targets, endpoints=[tracker] + endpoints,
allow_requeue=True, pool=pool, executor='eventlet') allow_requeue=True, pool=pool, executor='eventlet')
@ -170,6 +177,82 @@ class TestNotifyListener(test_utils.BaseTestCase, ListenerSetupMixin):
else: else:
self.assertTrue(False) self.assertTrue(False)
def test_batch_timeout(self):
transport = oslo_messaging.get_transport(self.conf, url='fake:')
endpoint = mock.Mock()
endpoint.info.return_value = None
listener_thread = self._setup_listener(transport, [endpoint],
batch=(5, 1))
notifier = self._setup_notifier(transport)
for i in six.moves.range(12):
notifier.info({}, 'an_event.start', 'test message')
self.wait_for_messages(3)
self.assertFalse(listener_thread.stop())
messages = [dict(ctxt={},
publisher_id='testpublisher',
event_type='an_event.start',
payload='test message',
metadata={'message_id': mock.ANY,
'timestamp': mock.ANY})]
endpoint.info.assert_has_calls([mock.call(messages * 5),
mock.call(messages * 5),
mock.call(messages * 2)])
def test_batch_size(self):
transport = oslo_messaging.get_transport(self.conf, url='fake:')
endpoint = mock.Mock()
endpoint.info.return_value = None
listener_thread = self._setup_listener(transport, [endpoint],
batch=(5, None))
notifier = self._setup_notifier(transport)
for i in six.moves.range(10):
notifier.info({}, 'an_event.start', 'test message')
self.wait_for_messages(2)
self.assertFalse(listener_thread.stop())
messages = [dict(ctxt={},
publisher_id='testpublisher',
event_type='an_event.start',
payload='test message',
metadata={'message_id': mock.ANY,
'timestamp': mock.ANY})]
endpoint.info.assert_has_calls([mock.call(messages * 5),
mock.call(messages * 5)])
def test_batch_size_exception_path(self):
transport = oslo_messaging.get_transport(self.conf, url='fake:')
endpoint = mock.Mock()
endpoint.info.side_effect = [None, Exception('boom!')]
listener_thread = self._setup_listener(transport, [endpoint],
batch=(5, None))
notifier = self._setup_notifier(transport)
for i in six.moves.range(10):
notifier.info({}, 'an_event.start', 'test message')
self.wait_for_messages(2)
self.assertFalse(listener_thread.stop())
messages = [dict(ctxt={},
publisher_id='testpublisher',
event_type='an_event.start',
payload='test message',
metadata={'message_id': mock.ANY,
'timestamp': mock.ANY})]
endpoint.info.assert_has_calls([mock.call(messages * 5)])
def test_one_topic(self): def test_one_topic(self):
transport = msg_notifier.get_notification_transport( transport = msg_notifier.get_notification_transport(
self.conf, url='fake:') self.conf, url='fake:')

View File

@ -133,7 +133,7 @@ class TestDispatcher(test_utils.BaseTestCase):
incoming = mock.Mock(ctxt=self.ctxt, message=self.msg) incoming = mock.Mock(ctxt=self.ctxt, message=self.msg)
incoming.reply.side_effect = check_reply incoming.reply.side_effect = check_reply
callback = dispatcher(incoming) callback = dispatcher([incoming])
callback.run() callback.run()
callback.done() callback.done()

View File

@ -60,7 +60,7 @@ class _ListenerThread(threading.Thread):
def run(self): def run(self):
LOG.debug("Listener started") LOG.debug("Listener started")
while self.msg_count > 0: while self.msg_count > 0:
in_msg = self.listener.poll() in_msg = self.listener.poll()[0]
self.messages.put(in_msg) self.messages.put(in_msg)
self.msg_count -= 1 self.msg_count -= 1
if in_msg.message.get('method') == 'echo': if in_msg.message.get('method') == 'echo':

View File

@ -79,14 +79,34 @@ class LoggingNoParsingFilter(logging.Filter):
return True return True
class NotifyEndpoint(object): class Monitor(object):
def __init__(self): def __init__(self, show_stats=False, *args, **kwargs):
self._count = self._prev_count = 0
self.show_stats = show_stats
if self.show_stats:
self._monitor()
def _monitor(self):
threading.Timer(1.0, self._monitor).start()
print ("%d msg was received per second"
% (self._count - self._prev_count))
self._prev_count = self._count
def info(self, *args, **kwargs):
self._count += 1
class NotifyEndpoint(Monitor):
def __init__(self, *args, **kwargs):
super(NotifyEndpoint, self).__init__(*args, **kwargs)
self.cache = [] self.cache = []
def info(self, ctxt, publisher_id, event_type, payload, metadata): def info(self, ctxt, publisher_id, event_type, payload, metadata):
super(NotifyEndpoint, self).info(ctxt, publisher_id, event_type,
payload, metadata)
LOG.info('msg rcv') LOG.info('msg rcv')
LOG.info("%s %s %s %s" % (ctxt, publisher_id, event_type, payload)) LOG.info("%s %s %s %s" % (ctxt, publisher_id, event_type, payload))
if payload not in self.cache: if not self.show_stats and payload not in self.cache:
LOG.info('requeue msg') LOG.info('requeue msg')
self.cache.append(payload) self.cache.append(payload)
for i in range(15): for i in range(15):
@ -97,8 +117,8 @@ class NotifyEndpoint(object):
return messaging.NotificationResult.HANDLED return messaging.NotificationResult.HANDLED
def notify_server(transport): def notify_server(transport, show_stats):
endpoints = [NotifyEndpoint()] endpoints = [NotifyEndpoint(show_stats)]
target = messaging.Target(topic='n-t1') target = messaging.Target(topic='n-t1')
server = notify.get_notification_listener(transport, [target], server = notify.get_notification_listener(transport, [target],
endpoints, executor='eventlet') endpoints, executor='eventlet')
@ -106,8 +126,41 @@ def notify_server(transport):
server.wait() server.wait()
class RpcEndpoint(object): class BatchNotifyEndpoint(Monitor):
def __init__(self, wait_before_answer): def __init__(self, *args, **kwargs):
super(BatchNotifyEndpoint, self).__init__(*args, **kwargs)
self.cache = []
def info(self, messages):
super(BatchNotifyEndpoint, self).info(messages)
self._count += len(messages) - 1
LOG.info('msg rcv')
LOG.info("%s" % messages)
if not self.show_stats and messages not in self.cache:
LOG.info('requeue msg')
self.cache.append(messages)
for i in range(15):
eventlet.sleep(1)
return messaging.NotificationResult.REQUEUE
else:
LOG.info('ack msg')
return messaging.NotificationResult.HANDLED
def batch_notify_server(transport, show_stats):
endpoints = [BatchNotifyEndpoint(show_stats)]
target = messaging.Target(topic='n-t1')
server = notify.get_batch_notification_listener(
transport, [target],
endpoints, executor='eventlet',
batch_size=1000, batch_time=5)
server.start()
server.wait()
class RpcEndpoint(Monitor):
def __init__(self, wait_before_answer, show_stats):
self.count = None self.count = None
self.wait_before_answer = wait_before_answer self.wait_before_answer = wait_before_answer
@ -126,27 +179,8 @@ class RpcEndpoint(object):
return "OK: %s" % message return "OK: %s" % message
class RpcEndpointMonitor(RpcEndpoint):
def __init__(self, *args, **kwargs):
super(RpcEndpointMonitor, self).__init__(*args, **kwargs)
self._count = self._prev_count = 0
self._monitor()
def _monitor(self):
threading.Timer(1.0, self._monitor).start()
print ("%d msg was received per second"
% (self._count - self._prev_count))
self._prev_count = self._count
def info(self, *args, **kwargs):
self._count += 1
super(RpcEndpointMonitor, self).info(*args, **kwargs)
def rpc_server(transport, target, wait_before_answer, executor, show_stats): def rpc_server(transport, target, wait_before_answer, executor, show_stats):
endpoint_cls = RpcEndpointMonitor if show_stats else RpcEndpoint endpoints = [RpcEndpoint(wait_before_answer, show_stats)]
endpoints = [endpoint_cls(wait_before_answer)]
server = rpc.get_rpc_server(transport, target, endpoints, server = rpc.get_rpc_server(transport, target, endpoints,
executor=executor) executor=executor)
server.start() server.start()
@ -244,6 +278,11 @@ def main():
help='notify/rpc server/client mode') help='notify/rpc server/client mode')
server = subparsers.add_parser('notify-server') server = subparsers.add_parser('notify-server')
server.add_argument('--show-stats', dest='show_stats',
type=bool, default=True)
server = subparsers.add_parser('batch-notify-server')
server.add_argument('--show-stats', dest='show_stats',
type=bool, default=True)
client = subparsers.add_parser('notify-client') client = subparsers.add_parser('notify-client')
client.add_argument('-p', dest='threads', type=int, default=1, client.add_argument('-p', dest='threads', type=int, default=1,
help='number of client threads') help='number of client threads')
@ -302,7 +341,9 @@ def main():
rpc_server(transport, target, args.wait_before_answer, args.executor, rpc_server(transport, target, args.wait_before_answer, args.executor,
args.show_stats) args.show_stats)
elif args.mode == 'notify-server': elif args.mode == 'notify-server':
notify_server(transport) notify_server(transport, args.show_stats)
elif args.mode == 'batch-notify-server':
batch_notify_server(transport, args.show_stats)
elif args.mode == 'notify-client': elif args.mode == 'notify-client':
threads_spawner(args.threads, notifier, transport, args.messages, threads_spawner(args.threads, notifier, transport, args.messages,
args.wait_after_msg, args.timeout) args.wait_after_msg, args.timeout)

View File

@ -5,6 +5,7 @@ envlist = py34,py27,pep8,bandit
setenv = setenv =
VIRTUAL_ENV={envdir} VIRTUAL_ENV={envdir}
OS_TEST_TIMEOUT=30 OS_TEST_TIMEOUT=30
passend = OS_*
deps = -r{toxinidir}/test-requirements.txt deps = -r{toxinidir}/test-requirements.txt
commands = python setup.py testr --slowest --testr-args='{posargs}' commands = python setup.py testr --slowest --testr-args='{posargs}'