From 8a0b972c7871646e10270db29f6eeacb5a5f2699 Mon Sep 17 00:00:00 2001 From: Stanislav Kudriashev Date: Tue, 11 Mar 2014 20:14:06 +0200 Subject: [PATCH] Introduce message types for WBE protocol * Abstract message class added - all messages types are derived from it now and have to implement the `to_dict` method, so it can be serialized and transferred with broker; * Implemented the `Response` message type, that restores failures from dictionary on creation; * Corrected and improved unit tests; Change-Id: I10e017a613f0422420d0244b9f8786f988863107 --- taskflow/engines/worker_based/executor.py | 62 +++--- taskflow/engines/worker_based/protocol.py | 53 ++++- taskflow/engines/worker_based/proxy.py | 3 +- taskflow/engines/worker_based/server.py | 27 ++- .../tests/unit/worker_based/test_executor.py | 187 +++++++++--------- .../tests/unit/worker_based/test_protocol.py | 8 +- .../tests/unit/worker_based/test_proxy.py | 11 +- .../tests/unit/worker_based/test_server.py | 168 ++++++++-------- .../tests/unit/worker_based/test_worker.py | 13 ++ 9 files changed, 312 insertions(+), 220 deletions(-) diff --git a/taskflow/engines/worker_based/executor.py b/taskflow/engines/worker_based/executor.py index c8d696dee..0ef75118b 100644 --- a/taskflow/engines/worker_based/executor.py +++ b/taskflow/engines/worker_based/executor.py @@ -28,7 +28,6 @@ from taskflow.engines.worker_based import proxy from taskflow import exceptions as exc from taskflow.utils import async_utils from taskflow.utils import misc -from taskflow.utils import persistence_utils as pu LOG = logging.getLogger(__name__) @@ -58,45 +57,50 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): proxy_thread.daemon = True return proxy_thread - def _on_message(self, response, message): - """This method is called on incoming response.""" - LOG.debug("Got response: %s", response) + def _on_message(self, data, message): + """This method is called on incoming message.""" + LOG.debug("Got message: %s", data) try: - # acknowledge message before processing. + # acknowledge message before processing message.ack() except kombu_exc.MessageStateError: LOG.exception("Failed to acknowledge AMQP message.") else: LOG.debug("AMQP message acknowledged.") - # get task uuid from message correlation id parameter try: - task_uuid = message.properties['correlation_id'] + msg_type = message.properties['type'] except KeyError: - LOG.warning("Got message with no 'correlation_id' property.") + LOG.warning("The 'type' message property is missing.") else: - LOG.debug("Task uuid: '%s'", task_uuid) - self._process_response(task_uuid, response) + if msg_type == pr.RESPONSE: + self._process_response(data, message) + else: + LOG.warning("Unexpected message type: %s", msg_type) - def _process_response(self, task_uuid, response): + def _process_response(self, response, message): """Process response from remote side.""" - request = self._requests_cache.get(task_uuid) - if request is not None: - state = response.pop('state') - if state == pr.RUNNING: - request.set_running() - elif state == pr.PROGRESS: - request.on_progress(**response) - elif state == pr.FAILURE: - response['result'] = pu.failure_from_dict(response['result']) - request.set_result(**response) - self._requests_cache.delete(request.uuid) - elif state == pr.SUCCESS: - request.set_result(**response) - self._requests_cache.delete(request.uuid) - else: - LOG.warning("Unexpected response status: '%s'", state) + LOG.debug("Start processing response message.") + try: + task_uuid = message.properties['correlation_id'] + except KeyError: + LOG.warning("The 'correlation_id' message property is missing.") else: - LOG.debug("Request with id='%s' not found.", task_uuid) + LOG.debug("Task uuid: '%s'", task_uuid) + request = self._requests_cache.get(task_uuid) + if request is not None: + response = pr.Response.from_dict(response) + if response.state == pr.RUNNING: + request.set_running() + elif response.state == pr.PROGRESS: + request.on_progress(**response.data) + elif response.state in (pr.FAILURE, pr.SUCCESS): + request.set_result(**response.data) + self._requests_cache.delete(request.uuid) + else: + LOG.warning("Unexpected response status: '%s'", + response.state) + else: + LOG.debug("Request with id='%s' not found.", task_uuid) @staticmethod def _handle_expired_request(request): @@ -124,7 +128,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): else: # publish request LOG.debug("Sending request: %s", request) - self._proxy.publish(request.to_dict(), + self._proxy.publish(request, routing_key=topic, reply_to=self._uuid, correlation_id=request.uuid) diff --git a/taskflow/engines/worker_based/protocol.py b/taskflow/engines/worker_based/protocol.py index b778bae3f..0c9eca7a4 100644 --- a/taskflow/engines/worker_based/protocol.py +++ b/taskflow/engines/worker_based/protocol.py @@ -14,6 +14,10 @@ # License for the specific language governing permissions and limitations # under the License. +import abc + +import six + from concurrent import futures from taskflow.engines.action_engine import executor @@ -49,11 +53,28 @@ REQUEST_TIMEOUT = 60 # no longer needed. QUEUE_EXPIRE_TIMEOUT = REQUEST_TIMEOUT +# Message types. +REQUEST = 'REQUEST' +RESPONSE = 'RESPONSE' -class Request(object): + +@six.add_metaclass(abc.ABCMeta) +class Message(object): + """Base class for all message types.""" + + def __str__(self): + return str(self.to_dict()) + + @abc.abstractmethod + def to_dict(self): + """Return json-serializable message representation.""" + + +class Request(Message): """Represents request with execution results. Every request is created in the PENDING state and is expired within the given timeout. """ + TYPE = REQUEST def __init__(self, task, uuid, action, arguments, progress_callback, timeout, **kwargs): @@ -112,7 +133,7 @@ class Request(object): if 'failures' in self._kwargs: failures = self._kwargs['failures'] request['failures'] = {} - for task, failure in failures.items(): + for task, failure in six.iteritems(failures): request['failures'][task] = pu.failure_to_dict(failure) return request @@ -125,3 +146,31 @@ class Request(object): def on_progress(self, event_data, progress): self._progress_callback(self._task, event_data, progress) + + +class Response(Message): + """Represents response message type.""" + TYPE = RESPONSE + + def __init__(self, state, **data): + self._state = state + self._data = data + + @classmethod + def from_dict(cls, data): + state = data['state'] + data = data['data'] + if state == FAILURE and 'result' in data: + data['result'] = pu.failure_from_dict(data['result']) + return cls(state, **data) + + @property + def state(self): + return self._state + + @property + def data(self): + return self._data + + def to_dict(self): + return dict(state=self._state, data=self._data) diff --git a/taskflow/engines/worker_based/proxy.py b/taskflow/engines/worker_based/proxy.py index 1fab0f110..c2203f300 100644 --- a/taskflow/engines/worker_based/proxy.py +++ b/taskflow/engines/worker_based/proxy.py @@ -69,10 +69,11 @@ class Proxy(object): """Publish message to the named exchange with routing key.""" with kombu.producers[self._conn].acquire(block=True) as producer: queue = self._make_queue(routing_key, self._exchange) - producer.publish(body=msg, + producer.publish(body=msg.to_dict(), routing_key=routing_key, exchange=self._exchange, declare=[queue], + type=msg.TYPE, **kwargs) def start(self): diff --git a/taskflow/engines/worker_based/server.py b/taskflow/engines/worker_based/server.py index ff630df99..a018f71cd 100644 --- a/taskflow/engines/worker_based/server.py +++ b/taskflow/engines/worker_based/server.py @@ -36,23 +36,32 @@ class Server(object): self._endpoints = dict([(endpoint.name, endpoint) for endpoint in endpoints]) - def _on_message(self, request, message): - """This method is called on incoming request.""" - LOG.debug("Got request: %s", request) - # NOTE(skudriashev): Process all incoming requests only if proxy is + def _on_message(self, data, message): + """This method is called on incoming message.""" + LOG.debug("Got message: %s", data) + # NOTE(skudriashev): Process all incoming messages only if proxy is # running, otherwise requeue them. if self._proxy.is_running: # NOTE(skudriashev): Process request only if message has been # acknowledged successfully. try: - # acknowledge message + # acknowledge message before processing message.ack() except kombu_exc.MessageStateError: LOG.exception("Failed to acknowledge AMQP message.") else: LOG.debug("AMQP message acknowledged.") - # spawn new thread to process request - self._executor.submit(self._process_request, request, message) + try: + msg_type = message.properties['type'] + except KeyError: + LOG.warning("The 'type' message property is missing.") + else: + if msg_type == pr.REQUEST: + # spawn new thread to process request + self._executor.submit(self._process_request, data, + message) + else: + LOG.warning("Unexpected message type: %s", msg_type) else: try: # requeue message @@ -100,7 +109,7 @@ class Server(object): def _reply(self, reply_to, task_uuid, state=pr.FAILURE, **kwargs): """Send reply to the `reply_to` queue.""" - response = dict(state=state, **kwargs) + response = pr.Response(state, **kwargs) LOG.debug("Sending reply: %s", response) try: self._proxy.publish(response, reply_to, correlation_id=task_uuid) @@ -115,7 +124,7 @@ class Server(object): def _process_request(self, request, message): """Process request in separate thread and reply back.""" - # NOTE(skudriashev): parse broker message first to get the `reply_to` + # NOTE(skudriashev): Parse broker message first to get the `reply_to` # and the `task_uuid` parameters to have possibility to reply back. try: reply_to, task_uuid = self._parse_message(message) diff --git a/taskflow/tests/unit/worker_based/test_executor.py b/taskflow/tests/unit/worker_based/test_executor.py index ff206660b..63f987ddc 100644 --- a/taskflow/tests/unit/worker_based/test_executor.py +++ b/taskflow/tests/unit/worker_based/test_executor.py @@ -49,15 +49,20 @@ class TestWorkerTaskExecutor(test.MockTestCase): # patch classes self.proxy_mock, self.proxy_inst_mock = self._patch_class( executor.proxy, 'Proxy') + self.request_mock, self.request_inst_mock = self._patch_class( + executor.pr, 'Request', autospec=False) # other mocking self.proxy_inst_mock.start.side_effect = self._fake_proxy_start self.proxy_inst_mock.stop.side_effect = self._fake_proxy_stop + self.request_inst_mock.uuid = self.task_uuid + self.request_inst_mock.expired = False + self.request_inst_mock.task_cls = self.task.name self.wait_for_any_mock = self._patch( 'taskflow.engines.worker_based.executor.async_utils.wait_for_any') self.message_mock = mock.MagicMock(name='message') - self.message_mock.properties = {'correlation_id': self.task_uuid} - self.request_mock = mock.MagicMock(uuid=self.task_uuid) + self.message_mock.properties = {'correlation_id': self.task_uuid, + 'type': pr.RESPONSE} def _fake_proxy_start(self): self.proxy_started_event.set() @@ -78,20 +83,6 @@ class TestWorkerTaskExecutor(test.MockTestCase): self._reset_master_mock() return ex - def request(self, **kwargs): - request_kwargs = dict(task=self.task, uuid=self.task_uuid, - action='execute', arguments=self.task_args, - progress_callback=None, timeout=self.timeout) - request_kwargs.update(kwargs) - return pr.Request(**request_kwargs) - - def request_dict(self, **kwargs): - request = dict(task_cls=self.task.name, task_name=self.task.name, - task_version=self.task.version, - arguments=self.task_args) - request.update(kwargs) - return request - def test_creation(self): ex = self.executor(reset_master_mock=False) @@ -101,184 +92,190 @@ class TestWorkerTaskExecutor(test.MockTestCase): ] self.assertEqual(self.master_mock.mock_calls, master_mock_calls) - def test_on_message_state_running(self): - response = dict(state=pr.RUNNING) + def test_on_message_response_state_running(self): + response = pr.Response(pr.RUNNING) ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_mock) - ex._on_message(response, self.message_mock) + ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._on_message(response.to_dict(), self.message_mock) - self.assertEqual(self.request_mock.mock_calls, + self.assertEqual(self.request_inst_mock.mock_calls, [mock.call.set_running()]) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) - def test_on_message_state_progress(self): - response = dict(state=pr.PROGRESS, progress=1.0) + def test_on_message_response_state_progress(self): + response = pr.Response(pr.PROGRESS, progress=1.0) ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_mock) - ex._on_message(response, self.message_mock) + ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._on_message(response.to_dict(), self.message_mock) - self.assertEqual(self.request_mock.mock_calls, + self.assertEqual(self.request_inst_mock.mock_calls, [mock.call.on_progress(progress=1.0)]) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) - def test_on_message_state_failure(self): + def test_on_message_response_state_failure(self): failure = misc.Failure.from_exception(Exception('test')) failure_dict = pu.failure_to_dict(failure) - response = dict(state=pr.FAILURE, result=failure_dict) + response = pr.Response(pr.FAILURE, result=failure_dict) ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_mock) - ex._on_message(response, self.message_mock) + ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._on_message(response.to_dict(), self.message_mock) self.assertEqual(len(ex._requests_cache._data), 0) - self.assertEqual(self.request_mock.mock_calls, [ + self.assertEqual(self.request_inst_mock.mock_calls, [ mock.call.set_result(result=utils.FailureMatcher(failure)) ]) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) - def test_on_message_state_success(self): - response = dict(state=pr.SUCCESS, result=self.task_result, - event='executed') + def test_on_message_response_state_success(self): + response = pr.Response(pr.SUCCESS, result=self.task_result, + event='executed') ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_mock) - ex._on_message(response, self.message_mock) + ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._on_message(response.to_dict(), self.message_mock) - self.assertEqual(self.request_mock.mock_calls, + self.assertEqual(self.request_inst_mock.mock_calls, [mock.call.set_result(result=self.task_result, event='executed')]) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) - def test_on_message_unknown_state(self): - response = dict(state='unknown') + def test_on_message_response_unknown_state(self): + response = pr.Response(state='') ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_mock) - ex._on_message(response, self.message_mock) + ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._on_message(response.to_dict(), self.message_mock) - self.assertEqual(self.request_mock.mock_calls, []) + self.assertEqual(self.request_inst_mock.mock_calls, []) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) - def test_on_message_non_existent_task(self): - self.message_mock.properties = {'correlation_id': 'non-existent'} - response = dict(state=pr.RUNNING) + def test_on_message_response_unknown_task(self): + self.message_mock.properties['correlation_id'] = '' + response = pr.Response(pr.RUNNING) ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_mock) - ex._on_message(response, self.message_mock) + ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._on_message(response.to_dict(), self.message_mock) - self.assertEqual(self.request_mock.mock_calls, []) + self.assertEqual(self.request_inst_mock.mock_calls, []) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) - def test_on_message_no_correlation_id(self): - self.message_mock.properties = {} - response = dict(state=pr.RUNNING) + def test_on_message_response_no_correlation_id(self): + self.message_mock.properties = {'type': pr.RESPONSE} + response = pr.Response(pr.RUNNING) ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request_mock) - ex._on_message(response, self.message_mock) + ex._requests_cache.set(self.task_uuid, self.request_inst_mock) + ex._on_message(response.to_dict(), self.message_mock) - self.assertEqual(self.request_mock.mock_calls, []) + self.assertEqual(self.request_inst_mock.mock_calls, []) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) + @mock.patch('taskflow.engines.worker_based.executor.LOG.warning') + def test_on_message_unknown_type(self, mocked_warning): + self.message_mock.properties = {'correlation_id': self.task_uuid, + 'type': ''} + ex = self.executor() + ex._on_message({}, self.message_mock) + self.assertTrue(mocked_warning.called) + + @mock.patch('taskflow.engines.worker_based.executor.LOG.warning') + def test_on_message_no_type(self, mocked_warning): + self.message_mock.properties = {'correlation_id': self.task_uuid} + ex = self.executor() + ex._on_message({}, self.message_mock) + self.assertTrue(mocked_warning.called) + @mock.patch('taskflow.engines.worker_based.executor.LOG.exception') def test_on_message_acknowledge_raises(self, mocked_exception): self.message_mock.ack.side_effect = kombu_exc.MessageStateError() self.executor()._on_message({}, self.message_mock) self.assertTrue(mocked_exception.called) - @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') - def test_on_wait_task_not_expired(self, mocked_wallclock): - mocked_wallclock.side_effect = [1, self.timeout] + def test_on_wait_task_not_expired(self): ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request()) + ex._requests_cache.set(self.task_uuid, self.request_inst_mock) self.assertEqual(len(ex._requests_cache._data), 1) ex._on_wait() self.assertEqual(len(ex._requests_cache._data), 1) - @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') - def test_on_wait_task_expired(self, mocked_time): - mocked_time.side_effect = [1, self.timeout + 2, self.timeout * 2] + def test_on_wait_task_expired(self): + self.request_inst_mock.expired = True ex = self.executor() - ex._requests_cache.set(self.task_uuid, self.request()) + ex._requests_cache.set(self.task_uuid, self.request_inst_mock) self.assertEqual(len(ex._requests_cache._data), 1) ex._on_wait() self.assertEqual(len(ex._requests_cache._data), 0) def test_remove_task_non_existent(self): - task = self.request() ex = self.executor() - ex._requests_cache.set(self.task_uuid, task) + ex._requests_cache.set(self.task_uuid, self.request_inst_mock) self.assertEqual(len(ex._requests_cache._data), 1) ex._requests_cache.delete(self.task_uuid) self.assertEqual(len(ex._requests_cache._data), 0) - # remove non-existent + # delete non-existent ex._requests_cache.delete(self.task_uuid) self.assertEqual(len(ex._requests_cache._data), 0) def test_execute_task(self): - request_dict = self.request_dict(action='execute') ex = self.executor() - result = ex.execute_task(self.task, self.task_uuid, self.task_args) + ex.execute_task(self.task, self.task_uuid, self.task_args) expected_calls = [ - mock.call.proxy.publish(request_dict, + mock.call.Request(self.task, self.task_uuid, 'execute', + self.task_args, None, self.timeout), + mock.call.proxy.publish(self.request_inst_mock, routing_key=self.executor_topic, reply_to=self.executor_uuid, correlation_id=self.task_uuid) ] self.assertEqual(self.master_mock.mock_calls, expected_calls) - self.assertIsInstance(result, futures.Future) def test_revert_task(self): - request_dict = self.request_dict(action='revert', - result=('success', self.task_result), - failures=self.task_failures) ex = self.executor() - result = ex.revert_task(self.task, self.task_uuid, self.task_args, - self.task_result, self.task_failures) + ex.revert_task(self.task, self.task_uuid, self.task_args, + self.task_result, self.task_failures) expected_calls = [ - mock.call.proxy.publish(request_dict, + mock.call.Request(self.task, self.task_uuid, 'revert', + self.task_args, None, self.timeout, + failures=self.task_failures, + result=self.task_result), + mock.call.proxy.publish(self.request_inst_mock, routing_key=self.executor_topic, reply_to=self.executor_uuid, correlation_id=self.task_uuid) ] self.assertEqual(self.master_mock.mock_calls, expected_calls) - self.assertIsInstance(result, futures.Future) def test_execute_task_topic_not_found(self): - workers_info = {self.executor_topic: ['non-existent-task']} + workers_info = {self.executor_topic: ['']} ex = self.executor(workers_info=workers_info) - result = ex.execute_task(self.task, self.task_uuid, self.task_args) + ex.execute_task(self.task, self.task_uuid, self.task_args) - self.assertFalse(self.proxy_inst_mock.publish.called) - - # check execute result - task, event, res = result.result() - self.assertEqual(task, self.task) - self.assertEqual(event, 'executed') - self.assertIsInstance(res, misc.Failure) + expected_calls = [ + mock.call.Request(self.task, self.task_uuid, 'execute', + self.task_args, None, self.timeout), + mock.call.request.set_result(mock.ANY) + ] + self.assertEqual(self.master_mock.mock_calls, expected_calls) def test_execute_task_publish_error(self): self.proxy_inst_mock.publish.side_effect = Exception('Woot!') - request_dict = self.request_dict(action='execute') ex = self.executor() - result = ex.execute_task(self.task, self.task_uuid, self.task_args) + ex.execute_task(self.task, self.task_uuid, self.task_args) expected_calls = [ - mock.call.proxy.publish(request_dict, + mock.call.Request(self.task, self.task_uuid, 'execute', + self.task_args, None, self.timeout), + mock.call.proxy.publish(self.request_inst_mock, routing_key=self.executor_topic, reply_to=self.executor_uuid, - correlation_id=self.task_uuid) + correlation_id=self.task_uuid), + mock.call.request.set_result(mock.ANY) ] self.assertEqual(self.master_mock.mock_calls, expected_calls) - # check execute result - task, event, res = result.result() - self.assertEqual(task, self.task) - self.assertEqual(event, 'executed') - self.assertIsInstance(res, misc.Failure) - def test_wait_for_any(self): fs = [futures.Future(), futures.Future()] ex = self.executor() diff --git a/taskflow/tests/unit/worker_based/test_protocol.py b/taskflow/tests/unit/worker_based/test_protocol.py index ae805621d..1f31da0e6 100644 --- a/taskflow/tests/unit/worker_based/test_protocol.py +++ b/taskflow/tests/unit/worker_based/test_protocol.py @@ -61,9 +61,13 @@ class TestProtocol(test.TestCase): self.assertIsInstance(request.result, futures.Future) self.assertFalse(request.result.done()) + def test_str(self): + request = self.request() + self.assertEqual(str(request), str(request.to_dict())) + def test_repr(self): - expected_name = '%s:%s' % (self.task.name, self.task_action) - self.assertEqual(repr(self.request()), expected_name) + expected = '%s:%s' % (self.task.name, self.task_action) + self.assertEqual(repr(self.request()), expected) def test_to_dict_default(self): self.assertEqual(self.request().to_dict(), self.request_to_dict()) diff --git a/taskflow/tests/unit/worker_based/test_proxy.py b/taskflow/tests/unit/worker_based/test_proxy.py index d181ed1ed..8876c23cd 100644 --- a/taskflow/tests/unit/worker_based/test_proxy.py +++ b/taskflow/tests/unit/worker_based/test_proxy.py @@ -128,13 +128,15 @@ class TestProxy(test.MockTestCase): self.assertEqual(self.master_mock.mock_calls, master_mock_calls) def test_publish(self): - task_data = 'task-data' - task_uuid = 'task-uuid' + msg_mock = mock.MagicMock() + msg_data = 'msg-data' + msg_mock.to_dict.return_value = msg_data routing_key = 'routing-key' + task_uuid = 'task-uuid' kwargs = dict(a='a', b='b') self.proxy(reset_master_mock=True).publish( - task_data, routing_key, correlation_id=task_uuid, **kwargs) + msg_mock, routing_key, correlation_id=task_uuid, **kwargs) master_mock_calls = [ mock.call.Queue(name=self._queue_name(routing_key), @@ -142,11 +144,12 @@ class TestProxy(test.MockTestCase): routing_key=routing_key, durable=False, auto_delete=True), - mock.call.producer.publish(body=task_data, + mock.call.producer.publish(body=msg_data, routing_key=routing_key, exchange=self.exchange_inst_mock, correlation_id=task_uuid, declare=[self.queue_inst_mock], + type=msg_mock.TYPE, **kwargs) ] self.master_mock.assert_has_calls(master_mock_calls) diff --git a/taskflow/tests/unit/worker_based/test_server.py b/taskflow/tests/unit/worker_based/test_server.py index af7abc05d..2ad94eaaa 100644 --- a/taskflow/tests/unit/worker_based/test_server.py +++ b/taskflow/tests/unit/worker_based/test_server.py @@ -16,6 +16,8 @@ import mock +import six + from kombu import exceptions as exc from taskflow.engines.worker_based import endpoint as ep @@ -24,7 +26,6 @@ from taskflow.engines.worker_based import server from taskflow import test from taskflow.tests import utils from taskflow.utils import misc -from taskflow.utils import persistence_utils as pu class TestServer(test.MockTestCase): @@ -34,27 +35,28 @@ class TestServer(test.MockTestCase): self.server_topic = 'server-topic' self.server_exchange = 'server-exchange' self.broker_url = 'test-url' + self.task = utils.TaskOneArgOneReturn() self.task_uuid = 'task-uuid' self.task_args = {'x': 1} self.task_action = 'execute' - self.task_name = 'taskflow.tests.utils.TaskOneArgOneReturn' - self.task_version = (1, 0) self.reply_to = 'reply-to' self.endpoints = [ep.Endpoint(task_cls=utils.TaskOneArgOneReturn), ep.Endpoint(task_cls=utils.TaskWithFailure), ep.Endpoint(task_cls=utils.ProgressingTask)] - self.resp_running = dict(state=pr.RUNNING) # patch classes self.proxy_mock, self.proxy_inst_mock = self._patch_class( server.proxy, 'Proxy') + self.response_mock, self.response_inst_mock = self._patch_class( + server.pr, 'Response') # other mocking self.proxy_inst_mock.is_running = True self.executor_mock = mock.MagicMock(name='executor') self.message_mock = mock.MagicMock(name='message') self.message_mock.properties = {'correlation_id': self.task_uuid, - 'reply_to': self.reply_to} + 'reply_to': self.reply_to, + 'type': pr.REQUEST} self.master_mock.attach_mock(self.executor_mock, 'executor') self.master_mock.attach_mock(self.message_mock, 'message') @@ -70,28 +72,15 @@ class TestServer(test.MockTestCase): self._reset_master_mock() return s - def request(self, **kwargs): - request = dict(task_cls=self.task_name, - task_name=self.task_name, - action=self.task_action, - task_version=self.task_version, - arguments=self.task_args) - request.update(kwargs) - return request - - @staticmethod - def resp_progress(progress): - return dict(state=pr.PROGRESS, progress=progress, event_data={}) - - @staticmethod - def resp_success(result): - return dict(state=pr.SUCCESS, result=result) - - @staticmethod - def resp_failure(result, **kwargs): - response = dict(state=pr.FAILURE, result=result) - response.update(kwargs) - return response + def make_request(self, **kwargs): + request_kwargs = dict(task=self.task, + uuid=self.task_uuid, + action=self.task_action, + arguments=self.task_args, + progress_callback=None, + timeout=60) + request_kwargs.update(kwargs) + return pr.Request(**request_kwargs).to_dict() def test_creation(self): s = self.server() @@ -116,7 +105,7 @@ class TestServer(test.MockTestCase): self.assertEqual(len(s._endpoints), len(self.endpoints)) def test_on_message_proxy_running_ack_success(self): - request = self.request() + request = self.make_request() s = self.server(reset_master_mock=True) s._on_message(request, self.message_mock) @@ -162,99 +151,115 @@ class TestServer(test.MockTestCase): ] self.assertEqual(self.master_mock.mock_calls, master_mock_calls) + @mock.patch('taskflow.engines.worker_based.server.LOG.warning') + def test_on_message_unknown_type(self, mocked_warning): + self.message_mock.properties['type'] = '' + s = self.server() + s._on_message({}, self.message_mock) + self.assertTrue(mocked_warning.called) + + @mock.patch('taskflow.engines.worker_based.server.LOG.warning') + def test_on_message_no_type(self, mocked_warning): + self.message_mock.properties = {} + s = self.server() + s._on_message({}, self.message_mock) + self.assertTrue(mocked_warning.called) + def test_parse_request(self): - request = self.request() + request = self.make_request() task_cls, action, task_args = server.Server._parse_request(**request) self.assertEqual((task_cls, action, task_args), - (self.task_name, self.task_action, - dict(task_name=self.task_name, + (self.task.name, self.task_action, + dict(task_name=self.task.name, arguments=self.task_args))) def test_parse_request_with_success_result(self): - request = self.request(action='revert', result=('success', 1)) + request = self.make_request(action='revert', result=1) task_cls, action, task_args = server.Server._parse_request(**request) self.assertEqual((task_cls, action, task_args), - (self.task_name, 'revert', - dict(task_name=self.task_name, + (self.task.name, 'revert', + dict(task_name=self.task.name, arguments=self.task_args, result=1))) def test_parse_request_with_failure_result(self): failure = misc.Failure.from_exception(Exception('test')) - failure_dict = pu.failure_to_dict(failure) - request = self.request(action='revert', - result=('failure', failure_dict)) + request = self.make_request(action='revert', result=failure) task_cls, action, task_args = server.Server._parse_request(**request) self.assertEqual((task_cls, action, task_args), - (self.task_name, 'revert', - dict(task_name=self.task_name, + (self.task.name, 'revert', + dict(task_name=self.task.name, arguments=self.task_args, result=utils.FailureMatcher(failure)))) def test_parse_request_with_failures(self): - failures = [misc.Failure.from_exception(Exception('test1')), - misc.Failure.from_exception(Exception('test2'))] - failures_dict = dict((str(i), pu.failure_to_dict(f)) - for i, f in enumerate(failures)) - request = self.request(action='revert', failures=failures_dict) + failures = {'0': misc.Failure.from_exception(Exception('test1')), + '1': misc.Failure.from_exception(Exception('test2'))} + request = self.make_request(action='revert', failures=failures) task_cls, action, task_args = server.Server._parse_request(**request) self.assertEqual( (task_cls, action, task_args), - (self.task_name, 'revert', - dict(task_name=self.task_name, + (self.task.name, 'revert', + dict(task_name=self.task.name, arguments=self.task_args, - failures=dict((str(i), utils.FailureMatcher(f)) - for i, f in enumerate(failures))))) + failures=dict((i, utils.FailureMatcher(f)) + for i, f in six.iteritems(failures))))) @mock.patch("taskflow.engines.worker_based.server.LOG.exception") def test_reply_publish_failure(self, mocked_exception): self.proxy_inst_mock.publish.side_effect = RuntimeError('Woot!') # create server and process request - s = self.server(reset_master_mock=True, endpoints=self.endpoints) + s = self.server(reset_master_mock=True) s._reply(self.reply_to, self.task_uuid) self.assertEqual(self.master_mock.mock_calls, [ - mock.call.proxy.publish({'state': 'FAILURE'}, self.reply_to, + mock.call.Response(pr.FAILURE), + mock.call.proxy.publish(self.response_inst_mock, self.reply_to, correlation_id=self.task_uuid) ]) self.assertTrue(mocked_exception.called) def test_on_update_progress(self): - request = self.request(task_cls='taskflow.tests.utils.ProgressingTask', - arguments={}) + request = self.make_request(task=utils.ProgressingTask(), arguments={}) # create server and process request - s = self.server(reset_master_mock=True, endpoints=self.endpoints) + s = self.server(reset_master_mock=True) s._process_request(request, self.message_mock) # check calls master_mock_calls = [ - mock.call.proxy.publish(self.resp_running, self.reply_to, + mock.call.Response(pr.RUNNING), + mock.call.proxy.publish(self.response_inst_mock, self.reply_to, correlation_id=self.task_uuid), - mock.call.proxy.publish(self.resp_progress(0.0), self.reply_to, + mock.call.Response(pr.PROGRESS, progress=0.0, event_data={}), + mock.call.proxy.publish(self.response_inst_mock, self.reply_to, correlation_id=self.task_uuid), - mock.call.proxy.publish(self.resp_progress(1.0), self.reply_to, + mock.call.Response(pr.PROGRESS, progress=1.0, event_data={}), + mock.call.proxy.publish(self.response_inst_mock, self.reply_to, correlation_id=self.task_uuid), - mock.call.proxy.publish(self.resp_success(5), self.reply_to, + mock.call.Response(pr.SUCCESS, result=5), + mock.call.proxy.publish(self.response_inst_mock, self.reply_to, correlation_id=self.task_uuid) ] self.assertEqual(self.master_mock.mock_calls, master_mock_calls) def test_process_request(self): # create server and process request - s = self.server(reset_master_mock=True, endpoints=self.endpoints) - s._process_request(self.request(), self.message_mock) + s = self.server(reset_master_mock=True) + s._process_request(self.make_request(), self.message_mock) # check calls master_mock_calls = [ - mock.call.proxy.publish(self.resp_running, self.reply_to, + mock.call.Response(pr.RUNNING), + mock.call.proxy.publish(self.response_inst_mock, self.reply_to, correlation_id=self.task_uuid), - mock.call.proxy.publish(self.resp_success(1), self.reply_to, + mock.call.Response(pr.SUCCESS, result=1), + mock.call.proxy.publish(self.response_inst_mock, self.reply_to, correlation_id=self.task_uuid) ] self.assertEqual(self.master_mock.mock_calls, master_mock_calls) @@ -262,7 +267,7 @@ class TestServer(test.MockTestCase): @mock.patch("taskflow.engines.worker_based.server.LOG.exception") def test_process_request_parse_message_failure(self, mocked_exception): self.message_mock.properties = {} - request = self.request() + request = self.make_request() s = self.server(reset_master_mock=True) s._process_request(request, self.message_mock) @@ -270,19 +275,21 @@ class TestServer(test.MockTestCase): self.assertTrue(mocked_exception.called) @mock.patch('taskflow.engines.worker_based.server.pu') - def test_process_request_parse_failure(self, pu_mock): + def test_process_request_parse_request_failure(self, pu_mock): failure_dict = 'failure_dict' + failure = misc.Failure.from_exception(RuntimeError('Woot!')) pu_mock.failure_to_dict.return_value = failure_dict pu_mock.failure_from_dict.side_effect = ValueError('Woot!') - request = self.request(result=('failure', 1)) + request = self.make_request(result=failure) # create server and process request - s = self.server(reset_master_mock=True, endpoints=self.endpoints) + s = self.server(reset_master_mock=True) s._process_request(request, self.message_mock) # check calls master_mock_calls = [ - mock.call.proxy.publish(self.resp_failure(failure_dict), + mock.call.Response(pr.FAILURE, result=failure_dict), + mock.call.proxy.publish(self.response_inst_mock, self.reply_to, correlation_id=self.task_uuid) ] @@ -292,15 +299,16 @@ class TestServer(test.MockTestCase): def test_process_request_endpoint_not_found(self, pu_mock): failure_dict = 'failure_dict' pu_mock.failure_to_dict.return_value = failure_dict - request = self.request(task_cls='') + request = self.make_request(task=mock.MagicMock(name='')) # create server and process request - s = self.server(reset_master_mock=True, endpoints=self.endpoints) + s = self.server(reset_master_mock=True) s._process_request(request, self.message_mock) # check calls master_mock_calls = [ - mock.call.proxy.publish(self.resp_failure(failure_dict), + mock.call.Response(pr.FAILURE, result=failure_dict), + mock.call.proxy.publish(self.response_inst_mock, self.reply_to, correlation_id=self.task_uuid) ] @@ -310,17 +318,20 @@ class TestServer(test.MockTestCase): def test_process_request_execution_failure(self, pu_mock): failure_dict = 'failure_dict' pu_mock.failure_to_dict.return_value = failure_dict - request = self.request(action='') + request = self.make_request() + request['action'] = '' # create server and process request - s = self.server(reset_master_mock=True, endpoints=self.endpoints) + s = self.server(reset_master_mock=True) s._process_request(request, self.message_mock) # check calls master_mock_calls = [ - mock.call.proxy.publish(self.resp_running, self.reply_to, + mock.call.Response(pr.RUNNING), + mock.call.proxy.publish(self.response_inst_mock, self.reply_to, correlation_id=self.task_uuid), - mock.call.proxy.publish(self.resp_failure(failure_dict), + mock.call.Response(pr.FAILURE, result=failure_dict), + mock.call.proxy.publish(self.response_inst_mock, self.reply_to, correlation_id=self.task_uuid) ] @@ -330,18 +341,19 @@ class TestServer(test.MockTestCase): def test_process_request_task_failure(self, pu_mock): failure_dict = 'failure_dict' pu_mock.failure_to_dict.return_value = failure_dict - request = self.request(task='taskflow.tests.utils.TaskWithFailure', - arguments={}) + request = self.make_request(task=utils.TaskWithFailure(), arguments={}) # create server and process request - s = self.server(reset_master_mock=True, endpoints=self.endpoints) + s = self.server(reset_master_mock=True) s._process_request(request, self.message_mock) # check calls master_mock_calls = [ - mock.call.proxy.publish(self.resp_running, self.reply_to, + mock.call.Response(pr.RUNNING), + mock.call.proxy.publish(self.response_inst_mock, self.reply_to, correlation_id=self.task_uuid), - mock.call.proxy.publish(self.resp_failure(failure_dict), + mock.call.Response(pr.FAILURE, result=failure_dict), + mock.call.proxy.publish(self.response_inst_mock, self.reply_to, correlation_id=self.task_uuid) ] diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py index c2d6dcfc8..658c98f4d 100644 --- a/taskflow/tests/unit/worker_based/test_worker.py +++ b/taskflow/tests/unit/worker_based/test_worker.py @@ -77,6 +77,9 @@ class TestWorker(test.MockTestCase): ] self.assertEqual(self.master_mock.mock_calls, master_mock_calls) + def test_creation_with_negative_threads_count(self): + self.assertRaises(ValueError, self.worker, threads_count=-10) + def test_creation_with_custom_executor(self): executor_mock = mock.MagicMock(name='executor') self.worker(executor=executor_mock) @@ -104,6 +107,16 @@ class TestWorker(test.MockTestCase): ] self.assertEqual(self.master_mock.mock_calls, master_mock_calls) + def test_run_with_custom_executor(self): + executor_mock = mock.MagicMock(name='executor') + self.worker(reset_master_mock=True, + executor=executor_mock).run() + + master_mock_calls = [ + mock.call.server.start() + ] + self.assertEqual(self.master_mock.mock_calls, master_mock_calls) + def test_wait(self): w = self.worker(reset_master_mock=True) w.run()