Use a common message dispatcher
Instead of recreating a dispatcher in the server
and executor objects use a common dispatcher that
is shared between them. It will dispatch based on
the message type received into a provided dict of
dispatch handler callbacks.
It also can generically requeue messages and can
reject messages if they are missing key required
message properties ('type' in the current case).
Part of blueprint wbe-message-validation
Change-Id: I8320f4707183f36e6a69f0552cf62f99a5467b7e
This commit is contained in:
96
taskflow/engines/worker_based/dispatcher.py
Normal file
96
taskflow/engines/worker_based/dispatcher.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from kombu import exceptions as kombu_exc
|
||||
import six
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TypeDispatcher(object):
|
||||
"""Receives messages and dispatches to type specific handlers."""
|
||||
|
||||
def __init__(self, type_handlers):
|
||||
self._handlers = dict(type_handlers)
|
||||
self._requeue_filters = []
|
||||
|
||||
def add_requeue_filter(self, callback):
|
||||
"""Add a callback that can *request* message requeuing.
|
||||
|
||||
The callback will be activated before the message has been acked and
|
||||
it can be used to instruct the dispatcher to requeue the message
|
||||
instead of processing it.
|
||||
"""
|
||||
assert six.callable(callback), "Callback must be callable"
|
||||
self._requeue_filters.append(callback)
|
||||
|
||||
def _collect_requeue_votes(self, data, message):
|
||||
# Returns how many of the filters asked for the message to be requeued.
|
||||
requeue_votes = 0
|
||||
for f in self._requeue_filters:
|
||||
try:
|
||||
if f(data, message):
|
||||
requeue_votes += 1
|
||||
except Exception:
|
||||
LOG.exception("Failed calling requeue filter to determine"
|
||||
" if message %r should be requeued.",
|
||||
message.delivery_tag)
|
||||
return requeue_votes
|
||||
|
||||
def _requeue_log_error(self, message, errors):
|
||||
# TODO(harlowja): Remove when http://github.com/celery/kombu/pull/372
|
||||
# is merged and a version is released with this change...
|
||||
try:
|
||||
message.requeue()
|
||||
except errors as exc:
|
||||
# This was taken from how kombu is formatting its messages
|
||||
# when its reject_log_error or ack_log_error functions are
|
||||
# used so that we have a similar error format for requeuing.
|
||||
LOG.critical("Couldn't requeue %r, reason:%r",
|
||||
message.delivery_tag, exc, exc_info=True)
|
||||
else:
|
||||
LOG.debug("AMQP message %r requeued.", message.delivery_tag)
|
||||
|
||||
def on_message(self, data, message):
|
||||
"""This method is called on incoming messages."""
|
||||
LOG.debug("Got message: %r", message.delivery_tag)
|
||||
if self._collect_requeue_votes(data, message):
|
||||
self._requeue_log_error(message,
|
||||
errors=(kombu_exc.MessageStateError,))
|
||||
else:
|
||||
try:
|
||||
msg_type = message.properties['type']
|
||||
except KeyError:
|
||||
message.reject_log_error(
|
||||
logger=LOG, errors=(kombu_exc.MessageStateError,))
|
||||
LOG.warning("The 'type' message property is missing"
|
||||
" in message %r", message.delivery_tag)
|
||||
else:
|
||||
handler = self._handlers.get(msg_type)
|
||||
if handler is None:
|
||||
message.reject_log_error(
|
||||
logger=LOG, errors=(kombu_exc.MessageStateError,))
|
||||
LOG.warning("Unexpected message type: '%s' in message"
|
||||
" %r", msg_type, message.delivery_tag)
|
||||
else:
|
||||
message.ack_log_error(
|
||||
logger=LOG, errors=(kombu_exc.MessageStateError,))
|
||||
if message.acknowledged:
|
||||
LOG.debug("AMQP message %r acknowledged.",
|
||||
message.delivery_tag)
|
||||
handler(data, message)
|
||||
@@ -16,8 +16,6 @@
|
||||
|
||||
import logging
|
||||
|
||||
from kombu import exceptions as kombu_exc
|
||||
|
||||
from taskflow.engines.action_engine import executor
|
||||
from taskflow.engines.worker_based import cache
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
@@ -75,36 +73,18 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
self._topics = topics
|
||||
self._requests_cache = cache.RequestsCache()
|
||||
self._workers_cache = cache.WorkersCache()
|
||||
self._proxy = proxy.Proxy(uuid, exchange, self._on_message,
|
||||
handlers = {
|
||||
pr.NOTIFY: self._process_notify,
|
||||
pr.RESPONSE: self._process_response,
|
||||
}
|
||||
self._proxy = proxy.Proxy(uuid, exchange, handlers,
|
||||
self._on_wait, **kwargs)
|
||||
self._proxy_thread = None
|
||||
self._periodic = PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD),
|
||||
[self._notify_topics])
|
||||
self._periodic_thread = None
|
||||
|
||||
def _on_message(self, data, message):
|
||||
"""This method is called on incoming message."""
|
||||
LOG.debug("Got message: %s", data)
|
||||
try:
|
||||
# acknowledge message before processing
|
||||
message.ack()
|
||||
except kombu_exc.MessageStateError:
|
||||
LOG.exception("Failed to acknowledge AMQP message.")
|
||||
else:
|
||||
LOG.debug("AMQP message acknowledged.")
|
||||
try:
|
||||
msg_type = message.properties['type']
|
||||
except KeyError:
|
||||
LOG.warning("The 'type' message property is missing.")
|
||||
else:
|
||||
if msg_type == pr.NOTIFY:
|
||||
self._process_notify(data)
|
||||
elif msg_type == pr.RESPONSE:
|
||||
self._process_response(data, message)
|
||||
else:
|
||||
LOG.warning("Unexpected message type: %s", msg_type)
|
||||
|
||||
def _process_notify(self, notify):
|
||||
def _process_notify(self, notify, message):
|
||||
"""Process notify message from remote side."""
|
||||
LOG.debug("Start processing notify message.")
|
||||
topic = notify['topic']
|
||||
|
||||
@@ -21,6 +21,9 @@ import threading
|
||||
import kombu
|
||||
import six
|
||||
|
||||
from taskflow.engines.worker_based import dispatcher
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
# NOTE(skudriashev): A timeout of 1 is often used in environments where
|
||||
@@ -31,17 +34,20 @@ DRAIN_EVENTS_PERIOD = 1
|
||||
class Proxy(object):
|
||||
"""A proxy processes messages from/to the named exchange."""
|
||||
|
||||
def __init__(self, topic, exchange_name, on_message, on_wait=None,
|
||||
def __init__(self, topic, exchange_name, type_handlers, on_wait=None,
|
||||
**kwargs):
|
||||
self._topic = topic
|
||||
self._exchange_name = exchange_name
|
||||
self._on_message = on_message
|
||||
self._on_wait = on_wait
|
||||
self._running = threading.Event()
|
||||
self._url = kwargs.get('url')
|
||||
self._transport = kwargs.get('transport')
|
||||
self._transport_opts = kwargs.get('transport_options')
|
||||
|
||||
self._dispatcher = dispatcher.TypeDispatcher(type_handlers)
|
||||
self._dispatcher.add_requeue_filter(
|
||||
# NOTE(skudriashev): Process all incoming messages only if proxy is
|
||||
# running, otherwise requeue them.
|
||||
lambda data, message: not self.is_running)
|
||||
self._drain_events_timeout = DRAIN_EVENTS_PERIOD
|
||||
if self._transport == 'memory' and self._transport_opts:
|
||||
polling_interval = self._transport_opts.get('polling_interval')
|
||||
@@ -95,7 +101,7 @@ class Proxy(object):
|
||||
with kombu.connections[self._conn].acquire(block=True) as conn:
|
||||
queue = self._make_queue(self._topic, self._exchange, channel=conn)
|
||||
with conn.Consumer(queues=queue,
|
||||
callbacks=[self._on_message]):
|
||||
callbacks=[self._dispatcher.on_message]):
|
||||
self._running.set()
|
||||
while self.is_running:
|
||||
try:
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import functools
|
||||
import logging
|
||||
|
||||
from kombu import exceptions as kombu_exc
|
||||
import six
|
||||
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import proxy
|
||||
@@ -26,54 +26,35 @@ from taskflow.utils import misc
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def delayed(executor):
|
||||
"""Wraps & runs the function using a futures compatible executor."""
|
||||
|
||||
def decorator(f):
|
||||
|
||||
@six.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
return executor.submit(f, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Server(object):
|
||||
"""Server implementation that waits for incoming tasks requests."""
|
||||
|
||||
def __init__(self, topic, exchange, executor, endpoints, **kwargs):
|
||||
self._proxy = proxy.Proxy(topic, exchange, self._on_message, **kwargs)
|
||||
handlers = {
|
||||
pr.NOTIFY: delayed(executor)(self._process_notify),
|
||||
pr.REQUEST: delayed(executor)(self._process_request),
|
||||
}
|
||||
self._proxy = proxy.Proxy(topic, exchange, handlers,
|
||||
on_wait=None, **kwargs)
|
||||
self._topic = topic
|
||||
self._executor = executor
|
||||
self._endpoints = dict([(endpoint.name, endpoint)
|
||||
for endpoint in endpoints])
|
||||
|
||||
def _on_message(self, data, message):
|
||||
"""This method is called on incoming message."""
|
||||
LOG.debug("Got message: %s", data)
|
||||
# NOTE(skudriashev): Process all incoming messages only if proxy is
|
||||
# running, otherwise requeue them.
|
||||
if self._proxy.is_running:
|
||||
# NOTE(skudriashev): Process request only if message has been
|
||||
# acknowledged successfully.
|
||||
try:
|
||||
# acknowledge message before processing
|
||||
message.ack()
|
||||
except kombu_exc.MessageStateError:
|
||||
LOG.exception("Failed to acknowledge AMQP message.")
|
||||
else:
|
||||
LOG.debug("AMQP message acknowledged.")
|
||||
try:
|
||||
msg_type = message.properties['type']
|
||||
except KeyError:
|
||||
LOG.warning("The 'type' message property is missing.")
|
||||
else:
|
||||
if msg_type == pr.NOTIFY:
|
||||
handler = self._process_notify
|
||||
elif msg_type == pr.REQUEST:
|
||||
handler = self._process_request
|
||||
else:
|
||||
LOG.warning("Unexpected message type: %s", msg_type)
|
||||
return
|
||||
# spawn new thread to process request
|
||||
self._executor.submit(handler, data, message)
|
||||
else:
|
||||
try:
|
||||
# requeue message
|
||||
message.requeue()
|
||||
except kombu_exc.MessageStateError:
|
||||
LOG.exception("Failed to requeue AMQP message.")
|
||||
else:
|
||||
LOG.debug("AMQP message requeued.")
|
||||
|
||||
@staticmethod
|
||||
def _parse_request(task_cls, task_name, action, arguments, result=None,
|
||||
failures=None, **kwargs):
|
||||
|
||||
77
taskflow/tests/unit/worker_based/test_dispatcher.py
Normal file
77
taskflow/tests/unit/worker_based/test_dispatcher.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from kombu import message
|
||||
import mock
|
||||
|
||||
from taskflow.engines.worker_based import dispatcher
|
||||
from taskflow import test
|
||||
|
||||
|
||||
def mock_acked_message(ack_ok=True, **kwargs):
|
||||
msg = mock.create_autospec(message.Message, spec_set=True, instance=True,
|
||||
channel=None, **kwargs)
|
||||
|
||||
def ack_side_effect(*args, **kwargs):
|
||||
msg.acknowledged = True
|
||||
|
||||
if ack_ok:
|
||||
msg.ack_log_error.side_effect = ack_side_effect
|
||||
msg.acknowledged = False
|
||||
return msg
|
||||
|
||||
|
||||
class TestDispatcher(test.MockTestCase):
|
||||
def test_creation(self):
|
||||
on_hello = mock.MagicMock()
|
||||
handlers = {'hello': on_hello}
|
||||
dispatcher.TypeDispatcher(handlers)
|
||||
|
||||
def test_on_message(self):
|
||||
on_hello = mock.MagicMock()
|
||||
handlers = {'hello': on_hello}
|
||||
d = dispatcher.TypeDispatcher(handlers)
|
||||
msg = mock_acked_message(properties={'type': 'hello'})
|
||||
d.on_message("", msg)
|
||||
self.assertTrue(on_hello.called)
|
||||
self.assertTrue(msg.ack_log_error.called)
|
||||
self.assertTrue(msg.acknowledged)
|
||||
|
||||
def test_on_rejected_message(self):
|
||||
d = dispatcher.TypeDispatcher({})
|
||||
msg = mock_acked_message(properties={'type': 'hello'})
|
||||
d.on_message("", msg)
|
||||
self.assertTrue(msg.reject_log_error.called)
|
||||
self.assertFalse(msg.acknowledged)
|
||||
|
||||
def test_on_requeue_message(self):
|
||||
d = dispatcher.TypeDispatcher({})
|
||||
d.add_requeue_filter(lambda data, message: True)
|
||||
msg = mock_acked_message()
|
||||
d.on_message("", msg)
|
||||
self.assertTrue(msg.requeue.called)
|
||||
self.assertFalse(msg.acknowledged)
|
||||
|
||||
def test_failed_ack(self):
|
||||
on_hello = mock.MagicMock()
|
||||
handlers = {'hello': on_hello}
|
||||
d = dispatcher.TypeDispatcher(handlers)
|
||||
msg = mock_acked_message(ack_ok=False,
|
||||
properties={'type': 'hello'})
|
||||
d.on_message("", msg)
|
||||
self.assertTrue(msg.ack_log_error.called)
|
||||
self.assertFalse(msg.acknowledged)
|
||||
self.assertFalse(on_hello.called)
|
||||
@@ -18,7 +18,6 @@ import threading
|
||||
import time
|
||||
|
||||
from concurrent import futures
|
||||
from kombu import exceptions as kombu_exc
|
||||
import mock
|
||||
|
||||
from taskflow.engines.worker_based import executor
|
||||
@@ -86,7 +85,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.Proxy(self.executor_uuid, self.executor_exchange,
|
||||
ex._on_message, ex._on_wait, url=self.broker_url)
|
||||
mock.ANY, ex._on_wait, url=self.broker_url)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
@@ -94,21 +93,19 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
response = pr.Response(pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
[mock.call.set_running()])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_response_state_progress(self):
|
||||
response = pr.Response(pr.PROGRESS, progress=1.0)
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
[mock.call.on_progress(progress=1.0)])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_response_state_failure(self):
|
||||
failure = misc.Failure.from_exception(Exception('test'))
|
||||
@@ -116,75 +113,49 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
response = pr.Response(pr.FAILURE, result=failure_dict)
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(len(ex._requests_cache), 0)
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [
|
||||
mock.call.set_result(result=utils.FailureMatcher(failure))
|
||||
])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_response_state_success(self):
|
||||
response = pr.Response(pr.SUCCESS, result=self.task_result,
|
||||
event='executed')
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
[mock.call.set_result(result=self.task_result,
|
||||
event='executed')])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_response_unknown_state(self):
|
||||
response = pr.Response(state='<unknown>')
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_response_unknown_task(self):
|
||||
self.message_mock.properties['correlation_id'] = '<unknown>'
|
||||
response = pr.Response(pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_response_no_correlation_id(self):
|
||||
self.message_mock.properties = {'type': pr.RESPONSE}
|
||||
response = pr.Response(pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.executor.LOG.warning')
|
||||
def test_on_message_unknown_type(self, mocked_warning):
|
||||
self.message_mock.properties = {'correlation_id': self.task_uuid,
|
||||
'type': '<unknown>'}
|
||||
ex = self.executor()
|
||||
ex._on_message({}, self.message_mock)
|
||||
self.assertTrue(mocked_warning.called)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.executor.LOG.warning')
|
||||
def test_on_message_no_type(self, mocked_warning):
|
||||
self.message_mock.properties = {'correlation_id': self.task_uuid}
|
||||
ex = self.executor()
|
||||
ex._on_message({}, self.message_mock)
|
||||
self.assertTrue(mocked_warning.called)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.executor.LOG.exception')
|
||||
def test_on_message_acknowledge_raises(self, mocked_exception):
|
||||
self.message_mock.ack.side_effect = kombu_exc.MessageStateError()
|
||||
self.executor()._on_message({}, self.message_mock)
|
||||
self.assertTrue(mocked_exception.called)
|
||||
|
||||
def test_on_wait_task_not_expired(self):
|
||||
ex = self.executor()
|
||||
@@ -222,7 +193,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.message_mock.properties['type'] = pr.NOTIFY
|
||||
notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name])
|
||||
ex = self.executor()
|
||||
ex._on_message(notify.to_dict(), self.message_mock)
|
||||
ex._process_notify(notify.to_dict(), self.message_mock)
|
||||
ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||
|
||||
expected_calls = [
|
||||
@@ -240,7 +211,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.message_mock.properties['type'] = pr.NOTIFY
|
||||
notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name])
|
||||
ex = self.executor()
|
||||
ex._on_message(notify.to_dict(), self.message_mock)
|
||||
ex._process_notify(notify.to_dict(), self.message_mock)
|
||||
ex.revert_task(self.task, self.task_uuid, self.task_args,
|
||||
self.task_result, self.task_failures)
|
||||
|
||||
@@ -273,7 +244,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.proxy_inst_mock.publish.side_effect = Exception('Woot!')
|
||||
notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name])
|
||||
ex = self.executor()
|
||||
ex._on_message(notify.to_dict(), self.message_mock)
|
||||
ex._process_notify(notify.to_dict(), self.message_mock)
|
||||
ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||
|
||||
expected_calls = [
|
||||
|
||||
80
taskflow/tests/unit/worker_based/test_message_pump.py
Normal file
80
taskflow/tests/unit/worker_based/test_message_pump.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import threading
|
||||
|
||||
import mock
|
||||
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import proxy
|
||||
from taskflow import test
|
||||
|
||||
|
||||
class TestMessagePump(test.MockTestCase):
|
||||
def test_notify(self):
|
||||
barrier = threading.Event()
|
||||
|
||||
on_notify = mock.MagicMock()
|
||||
on_notify.side_effect = lambda *args, **kwargs: barrier.set()
|
||||
|
||||
handlers = {pr.NOTIFY: on_notify}
|
||||
p = proxy.Proxy("test", "test", handlers,
|
||||
transport='memory',
|
||||
transport_options={
|
||||
'polling_interval': 0.01,
|
||||
})
|
||||
|
||||
t = threading.Thread(target=p.start)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
p.wait()
|
||||
p.publish(pr.Notify(), 'test')
|
||||
|
||||
barrier.wait(1.0)
|
||||
self.assertTrue(barrier.is_set())
|
||||
p.stop()
|
||||
t.join()
|
||||
|
||||
self.assertTrue(on_notify.called)
|
||||
on_notify.assert_called_with({}, mock.ANY)
|
||||
|
||||
def test_response(self):
|
||||
barrier = threading.Event()
|
||||
|
||||
on_response = mock.MagicMock()
|
||||
on_response.side_effect = lambda *args, **kwargs: barrier.set()
|
||||
|
||||
handlers = {pr.RESPONSE: on_response}
|
||||
p = proxy.Proxy("test", "test", handlers,
|
||||
transport='memory',
|
||||
transport_options={
|
||||
'polling_interval': 0.01,
|
||||
})
|
||||
|
||||
t = threading.Thread(target=p.start)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
p.wait()
|
||||
resp = pr.Response(pr.RUNNING)
|
||||
p.publish(resp, 'test')
|
||||
|
||||
barrier.wait(1.0)
|
||||
self.assertTrue(barrier.is_set())
|
||||
p.stop()
|
||||
t.join()
|
||||
|
||||
self.assertTrue(on_response.called)
|
||||
on_response.assert_called_with(resp.to_dict(), mock.ANY)
|
||||
@@ -66,7 +66,6 @@ class TestProxy(test.MockTestCase):
|
||||
self.conn_inst_mock.Consumer.return_value.__exit__ = mock.MagicMock()
|
||||
|
||||
# other mocking
|
||||
self.on_message_mock = mock.MagicMock(name='on_message')
|
||||
self.on_wait_mock = mock.MagicMock(name='on_wait')
|
||||
self.master_mock.attach_mock(self.on_wait_mock, 'on_wait')
|
||||
|
||||
@@ -85,7 +84,7 @@ class TestProxy(test.MockTestCase):
|
||||
auto_delete=True,
|
||||
channel=self.conn_inst_mock),
|
||||
mock.call.connection.Consumer(queues=self.queue_inst_mock,
|
||||
callbacks=[self.on_message_mock]),
|
||||
callbacks=[mock.ANY]),
|
||||
mock.call.connection.Consumer().__enter__(),
|
||||
] + calls + [
|
||||
mock.call.connection.Consumer().__exit__(exc_type, mock.ANY,
|
||||
@@ -95,8 +94,8 @@ class TestProxy(test.MockTestCase):
|
||||
def proxy(self, reset_master_mock=False, **kwargs):
|
||||
proxy_kwargs = dict(topic=self.topic,
|
||||
exchange_name=self.exchange_name,
|
||||
on_message=self.on_message_mock,
|
||||
url=self.broker_url)
|
||||
url=self.broker_url,
|
||||
type_handlers={})
|
||||
proxy_kwargs.update(kwargs)
|
||||
p = proxy.Proxy(**proxy_kwargs)
|
||||
if reset_master_mock:
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from kombu import exceptions as exc
|
||||
import mock
|
||||
import six
|
||||
|
||||
@@ -86,9 +85,9 @@ class TestServer(test.MockTestCase):
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.Proxy(self.server_topic, self.server_exchange,
|
||||
s._on_message, url=self.broker_url)
|
||||
mock.ANY, url=self.broker_url, on_wait=mock.ANY)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
self.master_mock.assert_has_calls(master_mock_calls)
|
||||
self.assertEqual(len(s._endpoints), 3)
|
||||
|
||||
def test_creation_with_endpoints(self):
|
||||
@@ -97,72 +96,11 @@ class TestServer(test.MockTestCase):
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.Proxy(self.server_topic, self.server_exchange,
|
||||
s._on_message, url=self.broker_url)
|
||||
mock.ANY, url=self.broker_url, on_wait=mock.ANY)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
self.master_mock.assert_has_calls(master_mock_calls)
|
||||
self.assertEqual(len(s._endpoints), len(self.endpoints))
|
||||
|
||||
def test_on_message_proxy_running_ack_success(self):
|
||||
request = self.make_request()
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._on_message(request, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.message.ack(),
|
||||
mock.call.executor.submit(s._process_request, request,
|
||||
self.message_mock)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_on_message_proxy_running_ack_failure(self):
|
||||
self.message_mock.ack.side_effect = exc.MessageStateError('Woot!')
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._on_message({}, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.message.ack()
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_on_message_proxy_not_running_requeue_success(self):
|
||||
self.proxy_inst_mock.is_running = False
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._on_message({}, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.message.requeue()
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_on_message_proxy_not_running_requeue_failure(self):
|
||||
self.message_mock.requeue.side_effect = exc.MessageStateError('Woot!')
|
||||
self.proxy_inst_mock.is_running = False
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._on_message({}, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.message.requeue()
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.server.LOG.warning')
|
||||
def test_on_message_unknown_type(self, mocked_warning):
|
||||
self.message_mock.properties['type'] = '<unknown>'
|
||||
s = self.server()
|
||||
s._on_message({}, self.message_mock)
|
||||
self.assertTrue(mocked_warning.called)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.server.LOG.warning')
|
||||
def test_on_message_no_type(self, mocked_warning):
|
||||
self.message_mock.properties = {}
|
||||
s = self.server()
|
||||
s._on_message({}, self.message_mock)
|
||||
self.assertTrue(mocked_warning.called)
|
||||
|
||||
def test_parse_request(self):
|
||||
request = self.make_request()
|
||||
task_cls, action, task_args = server.Server._parse_request(**request)
|
||||
|
||||
Reference in New Issue
Block a user