Merge "Move server related logic from dispatchers"

This commit is contained in:
Jenkins 2016-03-30 15:36:02 +00:00 committed by Gerrit Code Review
commit d4e8ac42b5
12 changed files with 348 additions and 317 deletions

View File

@ -192,8 +192,11 @@ class PikaDriver(base.BaseDriver):
# exchange which is not exists, we get ChannelClosed exception # exchange which is not exists, we get ChannelClosed exception
# and need to reconnect # and need to reconnect
try: try:
self._declare_rpc_exchange(exchange, self._declare_rpc_exchange(
expiration_time - time.time()) exchange,
None if expiration_time is None else
expiration_time - time.time()
)
except pika_drv_exc.ConnectionException as e: except pika_drv_exc.ConnectionException as e:
LOG.warning("Problem during declaring exchange. %s", e) LOG.warning("Problem during declaring exchange. %s", e)
return True return True

View File

@ -16,96 +16,22 @@ import logging
import six import six
from oslo_messaging._i18n import _
__all__ = [ __all__ = [
"DispatcherBase", "DispatcherBase"
"DispatcherExecutorContext"
] ]
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class DispatcherExecutorContext(object):
"""Dispatcher executor context helper
A dispatcher can have work to do before and after the dispatch of the
request in the main server thread while the dispatcher itself can be
done in its own thread.
The executor can use the helper like this:
callback = dispatcher(incoming)
callback.prepare()
thread = MyWhateverThread()
thread.on_done(callback.done)
thread.run(callback.run)
"""
def __init__(self, incoming, dispatch, post=None):
self._result = None
self._incoming = incoming
self._dispatch = dispatch
self._post = post
def run(self):
"""The incoming message dispath itself
Can be run in an other thread/greenlet/corotine if the executor is
able to do it.
"""
try:
self._result = self._dispatch(self._incoming)
except Exception:
msg = _('The dispatcher method must catches all exceptions')
LOG.exception(msg)
raise RuntimeError(msg)
def done(self):
"""Callback after the incoming message have been dispathed
Should be ran in the main executor thread/greenlet/corotine
"""
# FIXME(sileht): this is not currently true, this works only because
# the driver connection used for polling write on the wire only to
# ack/requeue message, but what if one day, the driver do something
# else
if self._post is not None:
self._post(self._incoming, self._result)
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
class DispatcherBase(object): class DispatcherBase(object):
"Base class for dispatcher" "Base class for dispatcher"
batch_size = 1
"Number of messages to wait before calling endpoints callacks"
batch_timeout = None
"Number of seconds to wait before calling endpoints callacks"
@abc.abstractmethod @abc.abstractmethod
def _listen(self, transport): def dispatch(self, incoming):
"""Initiate the driver Listener """Dispatch incoming messages to the endpoints and return result
Usually the driver Listener is start with the transport helper methods: :param incoming: incoming object for dispatching to the endpoint
:type incoming: object, depends on endpoint type
* transport._listen()
* transport._listen_for_notifications()
:param transport: the transport object
:type transport: oslo_messaging.transport.Transport
:returns: a driver Listener object
:rtype: oslo_messaging._drivers.base.Listener
"""
@abc.abstractmethod
def __call__(self, incoming):
"""Called by the executor to get the DispatcherExecutorContext
:param incoming: list of messages
:type incoming: oslo_messging._drivers.base.IncomingMessage
:returns: DispatcherExecutorContext
:rtype: DispatcherExecutorContext
""" """

View File

@ -19,7 +19,7 @@ import logging
import six import six
from oslo_messaging._i18n import _LE, _LW from oslo_messaging._i18n import _LW
from oslo_messaging import dispatcher from oslo_messaging import dispatcher
from oslo_messaging import localcontext from oslo_messaging import localcontext
from oslo_messaging import serializer as msg_serializer from oslo_messaging import serializer as msg_serializer
@ -35,14 +35,11 @@ class NotificationResult(object):
REQUEUE = 'requeue' REQUEUE = 'requeue'
class _NotificationDispatcherBase(dispatcher.DispatcherBase): class NotificationDispatcher(dispatcher.DispatcherBase):
def __init__(self, targets, endpoints, serializer, allow_requeue, def __init__(self, endpoints, serializer):
pool=None):
self.targets = targets
self.endpoints = endpoints self.endpoints = endpoints
self.serializer = serializer or msg_serializer.NoOpSerializer() self.serializer = serializer or msg_serializer.NoOpSerializer()
self.allow_requeue = allow_requeue
self.pool = pool
self._callbacks_by_priority = {} self._callbacks_by_priority = {}
for endpoint, prio in itertools.product(endpoints, PRIORITIES): for endpoint, prio in itertools.product(endpoints, PRIORITIES):
@ -52,42 +49,77 @@ class _NotificationDispatcherBase(dispatcher.DispatcherBase):
self._callbacks_by_priority.setdefault(prio, []).append( self._callbacks_by_priority.setdefault(prio, []).append(
(screen, method)) (screen, method))
priorities = self._callbacks_by_priority.keys() @property
self._targets_priorities = set(itertools.product(self.targets, def supported_priorities(self):
priorities)) return self._callbacks_by_priority.keys()
def _listen(self, transport): def dispatch(self, incoming):
transport._require_driver_features(requeue=self.allow_requeue) """Dispatch notification messages to the appropriate endpoint method.
return transport._listen_for_notifications(self._targets_priorities,
pool=self.pool)
def __call__(self, incoming):
return dispatcher.DispatcherExecutorContext(
incoming, self._dispatch_and_handle_error,
post=self._post_dispatch)
def _post_dispatch(self, incoming, requeues):
for m in incoming:
try:
if requeues and m in requeues:
m.requeue()
else:
m.acknowledge()
except Exception:
LOG.error(_LE("Fail to ack/requeue message"), exc_info=True)
def _dispatch_and_handle_error(self, incoming):
"""Dispatch a notification message to the appropriate endpoint method.
:param incoming: the incoming notification message
:type ctxt: IncomingMessage
""" """
try: priority, raw_message, message = self._extract_user_message(incoming)
return self._dispatch(incoming)
except Exception:
LOG.error(_LE('Exception during message handling'), exc_info=True)
def _dispatch(self, incoming): if priority not in PRIORITIES:
LOG.warning(_LW('Unknown priority "%s"'), priority)
return
for screen, callback in self._callbacks_by_priority.get(priority,
[]):
if screen and not screen.match(message["ctxt"],
message["publisher_id"],
message["event_type"],
message["metadata"],
message["payload"]):
continue
ret = self._exec_callback(callback, message)
if ret == NotificationResult.REQUEUE:
return ret
return NotificationResult.HANDLED
def _exec_callback(self, callback, message):
localcontext._set_local_context(message["ctxt"])
try:
return callback(message["ctxt"],
message["publisher_id"],
message["event_type"],
message["payload"],
message["metadata"])
except Exception:
LOG.exception("Callback raised an exception.")
return NotificationResult.REQUEUE
finally:
localcontext._clear_local_context()
def _extract_user_message(self, incoming):
ctxt = self.serializer.deserialize_context(incoming.ctxt)
message = incoming.message
publisher_id = message.get('publisher_id')
event_type = message.get('event_type')
metadata = {
'message_id': message.get('message_id'),
'timestamp': message.get('timestamp')
}
priority = message.get('priority', '').lower()
payload = self.serializer.deserialize_entity(ctxt,
message.get('payload'))
return priority, incoming, dict(ctxt=ctxt,
publisher_id=publisher_id,
event_type=event_type,
payload=payload,
metadata=metadata)
class BatchNotificationDispatcher(NotificationDispatcher):
"""A message dispatcher which understands Notification messages.
A MessageHandlingServer is constructed by passing a callable dispatcher
which is invoked with a list of message dictionaries each time 'batch_size'
messages are received or 'batch_timeout' seconds is reached.
"""
def dispatch(self, incoming):
"""Dispatch notification messages to the appropriate endpoint method. """Dispatch notification messages to the appropriate endpoint method.
""" """
@ -120,70 +152,14 @@ class _NotificationDispatcherBase(dispatcher.DispatcherBase):
continue continue
ret = self._exec_callback(callback, filtered_messages) ret = self._exec_callback(callback, filtered_messages)
if self.allow_requeue and ret == NotificationResult.REQUEUE: if ret == NotificationResult.REQUEUE:
requeues.update(raw_messages) requeues.update(raw_messages)
break break
return requeues return requeues
def _exec_callback(self, callback, *args):
ret = callback(*args)
return NotificationResult.HANDLED if ret is None else ret
def _extract_user_message(self, incoming):
ctxt = self.serializer.deserialize_context(incoming.ctxt)
message = incoming.message
publisher_id = message.get('publisher_id')
event_type = message.get('event_type')
metadata = {
'message_id': message.get('message_id'),
'timestamp': message.get('timestamp')
}
priority = message.get('priority', '').lower()
payload = self.serializer.deserialize_entity(ctxt,
message.get('payload'))
return priority, incoming, dict(ctxt=ctxt,
publisher_id=publisher_id,
event_type=event_type,
payload=payload,
metadata=metadata)
class NotificationDispatcher(_NotificationDispatcherBase):
"""A message dispatcher which understands Notification messages.
A MessageHandlingServer is constructed by passing a callable dispatcher
which is invoked with context and message dictionaries each time a message
is received.
"""
def _exec_callback(self, callback, messages): def _exec_callback(self, callback, messages):
localcontext._set_local_context(
messages[0]["ctxt"])
try: try:
return super(NotificationDispatcher, self)._exec_callback( return callback(messages)
callback, except Exception:
messages[0]["ctxt"], LOG.exception("Callback raised an exception.")
messages[0]["publisher_id"], return NotificationResult.REQUEUE
messages[0]["event_type"],
messages[0]["payload"],
messages[0]["metadata"])
finally:
localcontext._clear_local_context()
class BatchNotificationDispatcher(_NotificationDispatcherBase):
"""A message dispatcher which understands Notification messages.
A MessageHandlingServer is constructed by passing a callable dispatcher
which is invoked with a list of message dictionaries each time 'batch_size'
messages are received or 'batch_timeout' seconds is reached.
"""
def __init__(self, targets, endpoints, serializer, allow_requeue,
pool=None, batch_size=None, batch_timeout=None):
super(BatchNotificationDispatcher, self).__init__(targets, endpoints,
serializer,
allow_requeue,
pool)
self.batch_size = batch_size
self.batch_timeout = batch_timeout

View File

@ -103,10 +103,89 @@ by passing allow_requeue=True to get_notification_listener(). If the driver
does not support requeueing, it will raise NotImplementedError at this point. does not support requeueing, it will raise NotImplementedError at this point.
""" """
import itertools
import logging
from oslo_messaging._i18n import _LE
from oslo_messaging.notify import dispatcher as notify_dispatcher from oslo_messaging.notify import dispatcher as notify_dispatcher
from oslo_messaging import server as msg_server from oslo_messaging import server as msg_server
LOG = logging.getLogger(__name__)
class NotificationServer(msg_server.MessageHandlingServer):
def __init__(self, transport, targets, dispatcher, executor='blocking',
allow_requeue=True, pool=None):
super(NotificationServer, self).__init__(transport, dispatcher,
executor)
self._allow_requeue = allow_requeue
self._pool = pool
self.targets = targets
self._targets_priorities = set(
itertools.product(self.targets,
self.dispatcher.supported_priorities)
)
def _create_listener(self):
return msg_server.SingleMessageListenerAdapter(
self.transport._listen_for_notifications(
self._targets_priorities, self._pool
)
)
def _process_incoming(self, incoming):
res = notify_dispatcher.NotificationResult.REQUEUE
try:
res = self.dispatcher.dispatch(incoming)
except Exception:
LOG.error(_LE('Exception during message handling'), exc_info=True)
try:
if (res == notify_dispatcher.NotificationResult.REQUEUE and
self._allow_requeue):
incoming.requeue()
else:
incoming.acknowledge()
except Exception:
LOG.error(_LE("Fail to ack/requeue message"), exc_info=True)
class BatchNotificationServer(NotificationServer):
def __init__(self, transport, targets, dispatcher, executor='blocking',
allow_requeue=True, pool=None, batch_size=1,
batch_timeout=None):
super(BatchNotificationServer, self).__init__(
transport=transport, targets=targets, dispatcher=dispatcher,
executor=executor, allow_requeue=allow_requeue, pool=pool
)
self._batch_size = batch_size
self._batch_timeout = batch_timeout
def _create_listener(self):
return msg_server.BatchMessageListenerAdapter(
self.transport._listen_for_notifications(
self._targets_priorities, self._pool
),
timeout=self._batch_timeout,
batch_size=self._batch_size
)
def _process_incoming(self, incoming):
try:
not_processed_messages = self.dispatcher.dispatch(incoming)
except Exception:
not_processed_messages = set(incoming)
LOG.error(_LE('Exception during message handling'), exc_info=True)
for m in incoming:
try:
if m in not_processed_messages and self._allow_requeue:
m.requeue()
else:
m.acknowledge()
except Exception:
LOG.error(_LE("Fail to ack/requeue message"), exc_info=True)
def get_notification_listener(transport, targets, endpoints, def get_notification_listener(transport, targets, endpoints,
executor='blocking', serializer=None, executor='blocking', serializer=None,
@ -137,10 +216,10 @@ def get_notification_listener(transport, targets, endpoints,
:type pool: str :type pool: str
:raises: NotImplementedError :raises: NotImplementedError
""" """
dispatcher = notify_dispatcher.NotificationDispatcher(targets, endpoints, dispatcher = notify_dispatcher.NotificationDispatcher(endpoints,
serializer, serializer)
return NotificationServer(transport, targets, dispatcher, executor,
allow_requeue, pool) allow_requeue, pool)
return msg_server.MessageHandlingServer(transport, dispatcher, executor)
def get_batch_notification_listener(transport, targets, endpoints, def get_batch_notification_listener(transport, targets, endpoints,
@ -180,6 +259,8 @@ def get_batch_notification_listener(transport, targets, endpoints,
:raises: NotImplementedError :raises: NotImplementedError
""" """
dispatcher = notify_dispatcher.BatchNotificationDispatcher( dispatcher = notify_dispatcher.BatchNotificationDispatcher(
targets, endpoints, serializer, allow_requeue, pool, endpoints, serializer)
batch_size, batch_timeout) return BatchNotificationServer(
return msg_server.MessageHandlingServer(transport, dispatcher, executor) transport, targets, dispatcher, executor, allow_requeue, pool,
batch_size, batch_timeout
)

View File

@ -29,7 +29,6 @@ import sys
import six import six
from oslo_messaging._i18n import _LE
from oslo_messaging import _utils as utils from oslo_messaging import _utils as utils
from oslo_messaging import dispatcher from oslo_messaging import dispatcher
from oslo_messaging import localcontext from oslo_messaging import localcontext
@ -94,20 +93,16 @@ class RPCDispatcher(dispatcher.DispatcherBase):
""" """
def __init__(self, target, endpoints, serializer): def __init__(self, endpoints, serializer):
"""Construct a rpc server dispatcher. """Construct a rpc server dispatcher.
:param target: the exchange, topic and server to listen on :param endpoints: list of endpoint objects for dispatching to
:type target: Target :param serializer: optional message serializer
""" """
self.endpoints = endpoints self.endpoints = endpoints
self.serializer = serializer or msg_serializer.NoOpSerializer() self.serializer = serializer or msg_serializer.NoOpSerializer()
self._default_target = msg_target.Target() self._default_target = msg_target.Target()
self._target = target
def _listen(self, transport):
return transport._listen(self._target)
@staticmethod @staticmethod
def _is_namespace(target, namespace): def _is_namespace(target, namespace):
@ -127,43 +122,16 @@ class RPCDispatcher(dispatcher.DispatcherBase):
result = func(ctxt, **new_args) result = func(ctxt, **new_args)
return self.serializer.serialize_entity(ctxt, result) return self.serializer.serialize_entity(ctxt, result)
def __call__(self, incoming): def dispatch(self, incoming):
incoming[0].acknowledge()
return dispatcher.DispatcherExecutorContext(
incoming[0], self._dispatch_and_reply)
def _dispatch_and_reply(self, incoming):
try:
incoming.reply(self._dispatch(incoming.ctxt,
incoming.message))
except ExpectedException as e:
LOG.debug(u'Expected exception during message handling (%s)',
e.exc_info[1])
incoming.reply(failure=e.exc_info, log_failure=False)
except Exception as e:
# current sys.exc_info() content can be overriden
# by another exception raise by a log handler during
# LOG.exception(). So keep a copy and delete it later.
exc_info = sys.exc_info()
try:
LOG.error(_LE('Exception during message handling: %s'), e,
exc_info=exc_info)
incoming.reply(failure=exc_info)
finally:
# NOTE(dhellmann): Remove circular object reference
# between the current stack frame and the traceback in
# exc_info.
del exc_info
def _dispatch(self, ctxt, message):
"""Dispatch an RPC message to the appropriate endpoint method. """Dispatch an RPC message to the appropriate endpoint method.
:param ctxt: the request context :param incoming: incoming message
:type ctxt: dict :type incoming: IncomingMessage
:param message: the message payload
:type message: dict
:raises: NoSuchMethod, UnsupportedVersion :raises: NoSuchMethod, UnsupportedVersion
""" """
message = incoming.message
ctxt = incoming.ctxt
method = message.get('method') method = message.get('method')
args = message.get('args', {}) args = message.get('args', {})
namespace = message.get('namespace') namespace = message.get('namespace')

View File

@ -102,9 +102,53 @@ __all__ = [
'expected_exceptions', 'expected_exceptions',
] ]
import logging
import sys
from oslo_messaging._i18n import _LE
from oslo_messaging.rpc import dispatcher as rpc_dispatcher from oslo_messaging.rpc import dispatcher as rpc_dispatcher
from oslo_messaging import server as msg_server from oslo_messaging import server as msg_server
LOG = logging.getLogger(__name__)
class RPCServer(msg_server.MessageHandlingServer):
def __init__(self, transport, target, dispatcher, executor='blocking'):
super(RPCServer, self).__init__(transport, dispatcher, executor)
self._target = target
def _create_listener(self):
return msg_server.SingleMessageListenerAdapter(
self.transport._listen(self._target)
)
def _process_incoming(self, incoming):
incoming.acknowledge()
try:
res = self.dispatcher.dispatch(incoming)
except rpc_dispatcher.ExpectedException as e:
LOG.debug(u'Expected exception during message handling (%s)',
e.exc_info[1])
incoming.reply(failure=e.exc_info)
except Exception as e:
# current sys.exc_info() content can be overriden
# by another exception raise by a log handler during
# LOG.exception(). So keep a copy and delete it later.
exc_info = sys.exc_info()
try:
LOG.exception(_LE('Exception during message handling: %s'), e)
incoming.reply(failure=exc_info)
finally:
# NOTE(dhellmann): Remove circular object reference
# between the current stack frame and the traceback in
# exc_info.
del exc_info
else:
try:
incoming.reply(res)
except Exception:
LOG.Exception("Can not send reply for message %s", incoming)
def get_rpc_server(transport, target, endpoints, def get_rpc_server(transport, target, endpoints,
executor='blocking', serializer=None): executor='blocking', serializer=None):
@ -129,8 +173,8 @@ def get_rpc_server(transport, target, endpoints,
:param serializer: an optional entity serializer :param serializer: an optional entity serializer
:type serializer: Serializer :type serializer: Serializer
""" """
dispatcher = rpc_dispatcher.RPCDispatcher(target, endpoints, serializer) dispatcher = rpc_dispatcher.RPCDispatcher(endpoints, serializer)
return msg_server.MessageHandlingServer(transport, dispatcher, executor) return RPCServer(transport, target, dispatcher, executor)
def expected_exceptions(*exceptions): def expected_exceptions(*exceptions):

View File

@ -23,6 +23,7 @@ __all__ = [
'ServerListenError', 'ServerListenError',
] ]
import abc
import functools import functools
import inspect import inspect
import logging import logging
@ -34,6 +35,7 @@ from oslo_service import service
from oslo_utils import eventletutils from oslo_utils import eventletutils
from oslo_utils import excutils from oslo_utils import excutils
from oslo_utils import timeutils from oslo_utils import timeutils
import six
from stevedore import driver from stevedore import driver
from oslo_messaging._drivers import base as driver_base from oslo_messaging._drivers import base as driver_base
@ -295,6 +297,42 @@ def ordered(after=None, reset_after=None):
return _ordered return _ordered
@six.add_metaclass(abc.ABCMeta)
class MessageListenerAdapter(object):
def __init__(self, driver_listener, timeout=None):
self._driver_listener = driver_listener
self._timeout = timeout
@abc.abstractmethod
def poll(self):
"""Poll incoming and return incoming request"""
def stop(self):
self._driver_listener.stop()
def cleanup(self):
self._driver_listener.cleanup()
class SingleMessageListenerAdapter(MessageListenerAdapter):
def poll(self):
msgs = self._driver_listener.poll(prefetch_size=1,
timeout=self._timeout)
return msgs[0] if msgs else None
class BatchMessageListenerAdapter(MessageListenerAdapter):
def __init__(self, driver_listener, timeout=None, batch_size=1):
super(BatchMessageListenerAdapter, self).__init__(driver_listener,
timeout)
self._batch_size = batch_size
def poll(self):
return self._driver_listener.poll(prefetch_size=self._batch_size,
timeout=self._timeout)
@six.add_metaclass(abc.ABCMeta)
class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner): class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
"""Server for handling messages. """Server for handling messages.
@ -345,9 +383,18 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
super(MessageHandlingServer, self).__init__() super(MessageHandlingServer, self).__init__()
def _submit_work(self, callback): @abc.abstractmethod
fut = self._work_executor.submit(callback.run) def _process_incoming(self, incoming):
fut.add_done_callback(lambda f: callback.done()) """Process incoming request
:param incoming: incoming request.
"""
@abc.abstractmethod
def _create_listener(self):
"""Creates listener object for polling requests
:return: MessageListenerAdapter
"""
@ordered(reset_after='stop') @ordered(reset_after='stop')
def start(self, override_pool_size=None): def start(self, override_pool_size=None):
@ -374,7 +421,7 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
self._started = True self._started = True
try: try:
self.listener = self.dispatcher._listen(self.transport) self.listener = self._create_listener()
except driver_base.TransportDriverError as ex: except driver_base.TransportDriverError as ex:
raise ServerListenError(self.target, ex) raise ServerListenError(self.target, ex)
@ -412,22 +459,18 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
@excutils.forever_retry_uncaught_exceptions @excutils.forever_retry_uncaught_exceptions
def _runner(self): def _runner(self):
while self._started: while self._started:
incoming = self.listener.poll( incoming = self.listener.poll()
timeout=self.dispatcher.batch_timeout,
prefetch_size=self.dispatcher.batch_size)
if incoming: if incoming:
self._submit_work(self.dispatcher(incoming)) self._work_executor.submit(self._process_incoming, incoming)
# listener is stopped but we need to process all already consumed # listener is stopped but we need to process all already consumed
# messages # messages
while True: while True:
incoming = self.listener.poll( incoming = self.listener.poll()
timeout=self.dispatcher.batch_timeout,
prefetch_size=self.dispatcher.batch_size)
if incoming: if incoming:
self._submit_work(self.dispatcher(incoming)) self._work_executor.submit(self._process_incoming, incoming)
else: else:
return return

View File

@ -13,8 +13,6 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import itertools
from oslo_utils import timeutils from oslo_utils import timeutils
import testscenarios import testscenarios
@ -25,7 +23,6 @@ from six.moves import mock
load_tests = testscenarios.load_tests_apply_scenarios load_tests = testscenarios.load_tests_apply_scenarios
notification_msg = dict( notification_msg = dict(
publisher_id="publisher_id", publisher_id="publisher_id",
event_type="compute.start", event_type="compute.start",
@ -96,20 +93,21 @@ class TestDispatcher(test_utils.BaseTestCase):
msg = notification_msg.copy() msg = notification_msg.copy()
msg['priority'] = self.priority msg['priority'] = self.priority
targets = [oslo_messaging.Target(topic='notifications')] dispatcher = notify_dispatcher.NotificationDispatcher(endpoints, None)
dispatcher = notify_dispatcher.NotificationDispatcher(
targets, endpoints, None, allow_requeue=True, pool=None)
# check it listen on wanted topics
self.assertEqual(sorted(set((targets[0], prio)
for prio in itertools.chain.from_iterable(
self.endpoints))),
sorted(dispatcher._targets_priorities))
incoming = mock.Mock(ctxt={}, message=msg) incoming = mock.Mock(ctxt={}, message=msg)
callback = dispatcher([incoming])
callback.run() res = dispatcher.dispatch(incoming)
callback.done()
expected_res = (
notify_dispatcher.NotificationResult.REQUEUE
if (self.return_value ==
notify_dispatcher.NotificationResult.REQUEUE or
self.ex is not None)
else notify_dispatcher.NotificationResult.HANDLED
)
self.assertEqual(expected_res, res)
# check endpoint callbacks are called or not # check endpoint callbacks are called or not
for i, endpoint_methods in enumerate(self.endpoints): for i, endpoint_methods in enumerate(self.endpoints):
@ -127,26 +125,14 @@ class TestDispatcher(test_utils.BaseTestCase):
else: else:
self.assertEqual(0, endpoints[i].call_count) self.assertEqual(0, endpoints[i].call_count)
if self.ex:
self.assertEqual(1, incoming.acknowledge.call_count)
self.assertEqual(0, incoming.requeue.call_count)
elif self.return_value == oslo_messaging.NotificationResult.HANDLED \
or self.return_value is None:
self.assertEqual(1, incoming.acknowledge.call_count)
self.assertEqual(0, incoming.requeue.call_count)
elif self.return_value == oslo_messaging.NotificationResult.REQUEUE:
self.assertEqual(0, incoming.acknowledge.call_count)
self.assertEqual(1, incoming.requeue.call_count)
@mock.patch('oslo_messaging.notify.dispatcher.LOG') @mock.patch('oslo_messaging.notify.dispatcher.LOG')
def test_dispatcher_unknown_prio(self, mylog): def test_dispatcher_unknown_prio(self, mylog):
msg = notification_msg.copy() msg = notification_msg.copy()
msg['priority'] = 'what???' msg['priority'] = 'what???'
dispatcher = notify_dispatcher.NotificationDispatcher( dispatcher = notify_dispatcher.NotificationDispatcher(
[mock.Mock()], [mock.Mock()], None, allow_requeue=True, pool=None) [mock.Mock()], None)
callback = dispatcher([mock.Mock(ctxt={}, message=msg)]) res = dispatcher.dispatch(mock.Mock(ctxt={}, message=msg))
callback.run() self.assertEqual(None, res)
callback.done()
mylog.warning.assert_called_once_with('Unknown priority "%s"', mylog.warning.assert_called_once_with('Unknown priority "%s"',
'what???') 'what???')
@ -236,9 +222,8 @@ class TestDispatcherFilter(test_utils.BaseTestCase):
**self.filter_rule) **self.filter_rule)
endpoint = mock.Mock(spec=['info'], filter_rule=notification_filter) endpoint = mock.Mock(spec=['info'], filter_rule=notification_filter)
targets = [oslo_messaging.Target(topic='notifications')]
dispatcher = notify_dispatcher.NotificationDispatcher( dispatcher = notify_dispatcher.NotificationDispatcher(
targets, [endpoint], serializer=None, allow_requeue=True) [endpoint], serializer=None)
message = {'payload': {'state': 'active'}, message = {'payload': {'state': 'active'},
'priority': 'info', 'priority': 'info',
'publisher_id': self.publisher_id, 'publisher_id': self.publisher_id,
@ -246,9 +231,7 @@ class TestDispatcherFilter(test_utils.BaseTestCase):
'timestamp': '2014-03-03 18:21:04.369234', 'timestamp': '2014-03-03 18:21:04.369234',
'message_id': '99863dda-97f0-443a-a0c1-6ed317b7fd45'} 'message_id': '99863dda-97f0-443a-a0c1-6ed317b7fd45'}
incoming = mock.Mock(ctxt=self.context, message=message) incoming = mock.Mock(ctxt=self.context, message=message)
callback = dispatcher([incoming]) dispatcher.dispatch(incoming)
callback.run()
callback.done()
if self.match: if self.match:
self.assertEqual(1, endpoint.info.call_count) self.assertEqual(1, endpoint.info.call_count)

View File

@ -351,8 +351,7 @@ class TestLogNotifier(test_utils.BaseTestCase):
logger = mock.MagicMock() logger = mock.MagicMock()
logger.info = mock.MagicMock() logger.info = mock.MagicMock()
message = {'password': 'passw0rd', 'event_type': 'foo'} message = {'password': 'passw0rd', 'event_type': 'foo'}
json_str = jsonutils.dumps(message) mask_str = jsonutils.dumps(strutils.mask_dict_password(message))
mask_str = strutils.mask_password(json_str)
with mock.patch.object(logging, 'getLogger') as gl: with mock.patch.object(logging, 'getLogger') as gl:
gl.return_value = logger gl.return_value = logger

View File

@ -109,13 +109,15 @@ class TestDispatcher(test_utils.BaseTestCase):
for e in self.endpoints] for e in self.endpoints]
serializer = None serializer = None
target = oslo_messaging.Target() dispatcher = oslo_messaging.RPCDispatcher(endpoints, serializer)
dispatcher = oslo_messaging.RPCDispatcher(target, endpoints,
serializer)
def check_reply(reply=None, failure=None, log_failure=True): incoming = mock.Mock(ctxt=self.ctxt, message=self.msg)
if self.ex and failure is not None:
ex = failure[1] res = None
try:
res = dispatcher.dispatch(incoming)
except Exception as ex:
self.assertFalse(self.success, ex) self.assertFalse(self.success, ex)
self.assertIsNotNone(self.ex, ex) self.assertIsNotNone(self.ex, ex)
self.assertIsInstance(ex, self.ex, ex) self.assertIsInstance(ex, self.ex, ex)
@ -127,15 +129,9 @@ class TestDispatcher(test_utils.BaseTestCase):
if ex.method: if ex.method:
self.assertEqual(self.msg.get('method'), ex.method) self.assertEqual(self.msg.get('method'), ex.method)
else: else:
self.assertTrue(self.success, failure) self.assertTrue(self.success,
self.assertIsNone(failure) "Not expected success of operation durung testing")
self.assertIsNotNone(res)
incoming = mock.Mock(ctxt=self.ctxt, message=self.msg)
incoming.reply.side_effect = check_reply
callback = dispatcher([incoming])
callback.run()
callback.done()
for n, endpoint in enumerate(endpoints): for n, endpoint in enumerate(endpoints):
for method_name in ['foo', 'bar']: for method_name in ['foo', 'bar']:
@ -147,8 +143,6 @@ class TestDispatcher(test_utils.BaseTestCase):
else: else:
self.assertEqual(0, method.call_count) self.assertEqual(0, method.call_count)
self.assertEqual(1, incoming.reply.call_count)
class TestSerializer(test_utils.BaseTestCase): class TestSerializer(test_utils.BaseTestCase):
@ -165,9 +159,7 @@ class TestSerializer(test_utils.BaseTestCase):
def test_serializer(self): def test_serializer(self):
endpoint = _FakeEndpoint() endpoint = _FakeEndpoint()
serializer = msg_serializer.NoOpSerializer() serializer = msg_serializer.NoOpSerializer()
target = oslo_messaging.Target() dispatcher = oslo_messaging.RPCDispatcher([endpoint], serializer)
dispatcher = oslo_messaging.RPCDispatcher(target, [endpoint],
serializer)
self.mox.StubOutWithMock(endpoint, 'foo') self.mox.StubOutWithMock(endpoint, 'foo')
args = dict([(k, 'd' + v) for k, v in self.args.items()]) args = dict([(k, 'd' + v) for k, v in self.args.items()])
@ -187,7 +179,9 @@ class TestSerializer(test_utils.BaseTestCase):
self.mox.ReplayAll() self.mox.ReplayAll()
retval = dispatcher._dispatch(self.ctxt, dict(method='foo', incoming = mock.Mock()
args=self.args)) incoming.ctxt = self.ctxt
incoming.message = dict(method='foo', args=self.args)
retval = dispatcher.dispatch(incoming)
if self.retval is not None: if self.retval is not None:
self.assertEqual('s' + self.retval, retval) self.assertEqual('s' + self.retval, retval)

View File

@ -149,7 +149,7 @@ class TestRPCServer(test_utils.BaseTestCase, ServerSetupMixin):
serializer=serializer) serializer=serializer)
# Mocking executor # Mocking executor
server._executor_cls = MagicMockIgnoreArgs server._executor_cls = MagicMockIgnoreArgs
server.listener = MagicMockIgnoreArgs() server._create_listener = MagicMockIgnoreArgs()
server.dispatcher = MagicMockIgnoreArgs() server.dispatcher = MagicMockIgnoreArgs()
# Here assigning executor's listener object to listener variable # Here assigning executor's listener object to listener variable
# before calling wait method, because in wait method we are # before calling wait method, because in wait method we are
@ -551,7 +551,6 @@ class TestServerLocking(test_utils.BaseTestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._lock = threading.Lock() self._lock = threading.Lock()
self._calls = [] self._calls = []
self.listener = mock.MagicMock()
executors.append(self) executors.append(self)
submit = _logmethod('submit') submit = _logmethod('submit')
@ -559,9 +558,16 @@ class TestServerLocking(test_utils.BaseTestCase):
self.executors = executors self.executors = executors
self.server = oslo_messaging.MessageHandlingServer(mock.Mock(), class MessageHandlingServerImpl(oslo_messaging.MessageHandlingServer):
mock.Mock()) def _create_listener(self):
pass
def _process_incoming(self, incoming):
pass
self.server = MessageHandlingServerImpl(mock.Mock(), mock.Mock())
self.server._executor_cls = FakeExecutor self.server._executor_cls = FakeExecutor
self.server._create_listener = mock.Mock()
def test_start_stop_wait(self): def test_start_stop_wait(self):
# Test a simple execution of start, stop, wait in order # Test a simple execution of start, stop, wait in order

View File

@ -61,6 +61,14 @@ class OptsTestCase(test_utils.BaseTestCase):
def test_defaults(self): def test_defaults(self):
transport = mock.Mock() transport = mock.Mock()
transport.conf = self.conf transport.conf = self.conf
server.MessageHandlingServer(transport, mock.Mock())
class MessageHandlingServerImpl(server.MessageHandlingServer):
def _create_listener(self):
pass
def _process_incoming(self, incoming):
pass
MessageHandlingServerImpl(transport, mock.Mock())
opts.set_defaults(self.conf, executor_thread_pool_size=100) opts.set_defaults(self.conf, executor_thread_pool_size=100)
self.assertEqual(100, self.conf.executor_thread_pool_size) self.assertEqual(100, self.conf.executor_thread_pool_size)