Rename remote task to request
* Renamed remote task -> request and moved to protocol.py; * Used `to_dict` method instead of `request` property for request objects; * Renamed `name` request property to `task_cls`; * Corrected unit tests. Change-Id: I6133748ab5064391480f031971c38a56cb7f4f9f
This commit is contained in:
		| @@ -25,7 +25,6 @@ from taskflow.engines.action_engine import executor | ||||
| from taskflow.engines.worker_based import cache | ||||
| from taskflow.engines.worker_based import protocol as pr | ||||
| from taskflow.engines.worker_based import proxy | ||||
| from taskflow.engines.worker_based import remote_task as rt | ||||
| from taskflow import exceptions as exc | ||||
| from taskflow.utils import async_utils | ||||
| from taskflow.utils import misc | ||||
| @@ -42,7 +41,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): | ||||
|         self._proxy = proxy.Proxy(uuid, exchange, self._on_message, | ||||
|                                   self._on_wait, **kwargs) | ||||
|         self._proxy_thread = None | ||||
|         self._remote_tasks_cache = cache.Cache() | ||||
|         self._requests_cache = cache.Cache() | ||||
|  | ||||
|         # TODO(skudriashev): This data should be collected from workers | ||||
|         # using broadcast messages directly. | ||||
| @@ -80,61 +79,61 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): | ||||
|  | ||||
|     def _process_response(self, task_uuid, response): | ||||
|         """Process response from remote side.""" | ||||
|         remote_task = self._remote_tasks_cache.get(task_uuid) | ||||
|         if remote_task is not None: | ||||
|         request = self._requests_cache.get(task_uuid) | ||||
|         if request is not None: | ||||
|             state = response.pop('state') | ||||
|             if state == pr.RUNNING: | ||||
|                 remote_task.set_running() | ||||
|                 request.set_running() | ||||
|             elif state == pr.PROGRESS: | ||||
|                 remote_task.on_progress(**response) | ||||
|                 request.on_progress(**response) | ||||
|             elif state == pr.FAILURE: | ||||
|                 response['result'] = pu.failure_from_dict(response['result']) | ||||
|                 remote_task.set_result(**response) | ||||
|                 self._remote_tasks_cache.delete(remote_task.uuid) | ||||
|                 request.set_result(**response) | ||||
|                 self._requests_cache.delete(request.uuid) | ||||
|             elif state == pr.SUCCESS: | ||||
|                 remote_task.set_result(**response) | ||||
|                 self._remote_tasks_cache.delete(remote_task.uuid) | ||||
|                 request.set_result(**response) | ||||
|                 self._requests_cache.delete(request.uuid) | ||||
|             else: | ||||
|                 LOG.warning("Unexpected response status: '%s'", state) | ||||
|         else: | ||||
|             LOG.debug("Remote task with id='%s' not found.", task_uuid) | ||||
|             LOG.debug("Request with id='%s' not found.", task_uuid) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _handle_expired_remote_task(task): | ||||
|         LOG.debug("Remote task '%r' has expired.", task) | ||||
|         task.set_result(misc.Failure.from_exception( | ||||
|             exc.Timeout("Remote task '%r' has expired" % task))) | ||||
|     def _handle_expired_request(request): | ||||
|         LOG.debug("Request '%r' has expired.", request) | ||||
|         request.set_result(misc.Failure.from_exception( | ||||
|             exc.Timeout("Request '%r' has expired" % request))) | ||||
|  | ||||
|     def _on_wait(self): | ||||
|         """This function is called cyclically between draining events.""" | ||||
|         self._remote_tasks_cache.cleanup(self._handle_expired_remote_task) | ||||
|         self._requests_cache.cleanup(self._handle_expired_request) | ||||
|  | ||||
|     def _submit_task(self, task, task_uuid, action, arguments, | ||||
|                      progress_callback, timeout=pr.REQUEST_TIMEOUT, **kwargs): | ||||
|         """Submit task request to workers.""" | ||||
|         remote_task = rt.RemoteTask(task, task_uuid, action, arguments, | ||||
|         request = pr.Request(task, task_uuid, action, arguments, | ||||
|                              progress_callback, timeout, **kwargs) | ||||
|         self._remote_tasks_cache.set(remote_task.uuid, remote_task) | ||||
|         self._requests_cache.set(request.uuid, request) | ||||
|         try: | ||||
|             # get task's workers topic to send request to | ||||
|             try: | ||||
|                 topic = self._workers_info[remote_task.name] | ||||
|                 topic = self._workers_info[request.task_cls] | ||||
|             except KeyError: | ||||
|                 raise exc.NotFound("Workers topic not found for the '%s'" | ||||
|                                    " task" % remote_task.name) | ||||
|                                    " task" % request.task_cls) | ||||
|             else: | ||||
|                 # publish request | ||||
|                 request = remote_task.request | ||||
|                 LOG.debug("Sending request: %s", request) | ||||
|                 self._proxy.publish(request, routing_key=topic, | ||||
|                 self._proxy.publish(request.to_dict(), | ||||
|                                     routing_key=topic, | ||||
|                                     reply_to=self._uuid, | ||||
|                                     correlation_id=remote_task.uuid) | ||||
|                                     correlation_id=request.uuid) | ||||
|         except Exception: | ||||
|             with misc.capture_failure() as failure: | ||||
|                 LOG.exception("Failed to submit the '%s' task", remote_task) | ||||
|                 self._remote_tasks_cache.delete(remote_task.uuid) | ||||
|                 remote_task.set_result(failure) | ||||
|         return remote_task.result | ||||
|                 LOG.exception("Failed to submit the '%s' task", request) | ||||
|                 self._requests_cache.delete(request.uuid) | ||||
|                 request.set_result(failure) | ||||
|         return request.result | ||||
|  | ||||
|     def execute_task(self, task, task_uuid, arguments, | ||||
|                      progress_callback=None): | ||||
|   | ||||
| @@ -14,7 +14,12 @@ | ||||
| #    License for the specific language governing permissions and limitations | ||||
| #    under the License. | ||||
|  | ||||
| from concurrent import futures | ||||
|  | ||||
| from taskflow.engines.action_engine import executor | ||||
| from taskflow.utils import misc | ||||
| from taskflow.utils import persistence_utils as pu | ||||
| from taskflow.utils import reflection | ||||
|  | ||||
| # NOTE(skudriashev): This is protocol events, not related to the task states. | ||||
| PENDING = 'PENDING' | ||||
| @@ -43,3 +48,80 @@ REQUEST_TIMEOUT = 60 | ||||
| # period is equal to the request timeout, once request is expired - queue is | ||||
| # no longer needed. | ||||
| QUEUE_EXPIRE_TIMEOUT = REQUEST_TIMEOUT | ||||
|  | ||||
|  | ||||
| class Request(object): | ||||
|     """Represents request with execution results. Every request is created in | ||||
|     the PENDING state and is expired within the given timeout. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, task, uuid, action, arguments, progress_callback, | ||||
|                  timeout, **kwargs): | ||||
|         self._task = task | ||||
|         self._task_cls = reflection.get_class_name(task) | ||||
|         self._uuid = uuid | ||||
|         self._action = action | ||||
|         self._event = ACTION_TO_EVENT[action] | ||||
|         self._arguments = arguments | ||||
|         self._progress_callback = progress_callback | ||||
|         self._kwargs = kwargs | ||||
|         self._watch = misc.StopWatch(duration=timeout).start() | ||||
|         self._state = PENDING | ||||
|         self.result = futures.Future() | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "%s:%s" % (self._task_cls, self._action) | ||||
|  | ||||
|     @property | ||||
|     def uuid(self): | ||||
|         return self._uuid | ||||
|  | ||||
|     @property | ||||
|     def task_cls(self): | ||||
|         return self._task_cls | ||||
|  | ||||
|     @property | ||||
|     def expired(self): | ||||
|         """Check if request is expired. | ||||
|  | ||||
|         When new request is created its state is set to the PENDING, creation | ||||
|         time is stored and timeout is given via constructor arguments. | ||||
|  | ||||
|         Request is considered to be expired when it is in the PENDING state | ||||
|         for more then the given timeout (it is not considered to be expired | ||||
|         in any other state). After request is expired - the `Timeout` | ||||
|         exception is raised and task is removed from the requests map. | ||||
|         """ | ||||
|         if self._state == PENDING: | ||||
|             return self._watch.expired() | ||||
|         return False | ||||
|  | ||||
|     def to_dict(self): | ||||
|         """Return json-serializable request, converting all `misc.Failure` | ||||
|         objects into dictionaries. | ||||
|         """ | ||||
|         request = dict(task_cls=self._task_cls, task_name=self._task.name, | ||||
|                        task_version=self._task.version, action=self._action, | ||||
|                        arguments=self._arguments) | ||||
|         if 'result' in self._kwargs: | ||||
|             result = self._kwargs['result'] | ||||
|             if isinstance(result, misc.Failure): | ||||
|                 request['result'] = ('failure', pu.failure_to_dict(result)) | ||||
|             else: | ||||
|                 request['result'] = ('success', result) | ||||
|         if 'failures' in self._kwargs: | ||||
|             failures = self._kwargs['failures'] | ||||
|             request['failures'] = {} | ||||
|             for task, failure in failures.items(): | ||||
|                 request['failures'][task] = pu.failure_to_dict(failure) | ||||
|         return request | ||||
|  | ||||
|     def set_result(self, result): | ||||
|         self.result.set_result((self._task, self._event, result)) | ||||
|  | ||||
|     def set_running(self): | ||||
|         self._state = RUNNING | ||||
|         self._watch.stop() | ||||
|  | ||||
|     def on_progress(self, event_data, progress): | ||||
|         self._progress_callback(self._task, event_data, progress) | ||||
|   | ||||
| @@ -1,101 +0,0 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
|  | ||||
| #    Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. | ||||
| # | ||||
| #    Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| #    not use this file except in compliance with the License. You may obtain | ||||
| #    a copy of the License at | ||||
| # | ||||
| #         http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| #    Unless required by applicable law or agreed to in writing, software | ||||
| #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
| #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
| #    License for the specific language governing permissions and limitations | ||||
| #    under the License. | ||||
|  | ||||
| from concurrent import futures | ||||
|  | ||||
| from taskflow.engines.worker_based import protocol as pr | ||||
| from taskflow.utils import misc | ||||
| from taskflow.utils import persistence_utils as pu | ||||
| from taskflow.utils import reflection | ||||
|  | ||||
|  | ||||
| class RemoteTask(object): | ||||
|     """Represents remote task with its request data and execution results. | ||||
|     Every remote task is created in the PENDING state and will be expired | ||||
|     within the given timeout. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, task, uuid, action, arguments, progress_callback, | ||||
|                  timeout, **kwargs): | ||||
|         self._task = task | ||||
|         self._name = reflection.get_class_name(task) | ||||
|         self._uuid = uuid | ||||
|         self._action = action | ||||
|         self._event = pr.ACTION_TO_EVENT[action] | ||||
|         self._arguments = arguments | ||||
|         self._progress_callback = progress_callback | ||||
|         self._kwargs = kwargs | ||||
|         self._watch = misc.StopWatch(duration=timeout).start() | ||||
|         self._state = pr.PENDING | ||||
|         self.result = futures.Future() | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "%s:%s" % (self._name, self._action) | ||||
|  | ||||
|     @property | ||||
|     def uuid(self): | ||||
|         return self._uuid | ||||
|  | ||||
|     @property | ||||
|     def name(self): | ||||
|         return self._name | ||||
|  | ||||
|     @property | ||||
|     def request(self): | ||||
|         """Return json-serializable task request, converting all `misc.Failure` | ||||
|         objects into dictionaries. | ||||
|         """ | ||||
|         request = dict(task=self._name, task_name=self._task.name, | ||||
|                        task_version=self._task.version, action=self._action, | ||||
|                        arguments=self._arguments) | ||||
|         if 'result' in self._kwargs: | ||||
|             result = self._kwargs['result'] | ||||
|             if isinstance(result, misc.Failure): | ||||
|                 request['result'] = ('failure', pu.failure_to_dict(result)) | ||||
|             else: | ||||
|                 request['result'] = ('success', result) | ||||
|         if 'failures' in self._kwargs: | ||||
|             failures = self._kwargs['failures'] | ||||
|             request['failures'] = {} | ||||
|             for task, failure in failures.items(): | ||||
|                 request['failures'][task] = pu.failure_to_dict(failure) | ||||
|         return request | ||||
|  | ||||
|     @property | ||||
|     def expired(self): | ||||
|         """Check if task is expired. | ||||
|  | ||||
|         When new remote task is created its state is set to the PENDING, | ||||
|         creation time is stored and timeout is given via constructor arguments. | ||||
|  | ||||
|         Remote task is considered to be expired when it is in the PENDING state | ||||
|         for more then the given timeout (task is not considered to be expired | ||||
|         in any other state). After remote task is expired - the `Timeout` | ||||
|         exception is raised and task is removed from the remote tasks map. | ||||
|         """ | ||||
|         if self._state == pr.PENDING: | ||||
|             return self._watch.expired() | ||||
|         return False | ||||
|  | ||||
|     def set_result(self, result): | ||||
|         self.result.set_result((self._task, self._event, result)) | ||||
|  | ||||
|     def set_running(self): | ||||
|         self._state = pr.RUNNING | ||||
|         self._watch.stop() | ||||
|  | ||||
|     def on_progress(self, event_data, progress): | ||||
|         self._progress_callback(self._task, event_data, progress) | ||||
| @@ -63,7 +63,7 @@ class Server(object): | ||||
|                 LOG.debug("AMQP message requeued.") | ||||
|  | ||||
|     @staticmethod | ||||
|     def _parse_request(task, task_name, action, arguments, result=None, | ||||
|     def _parse_request(task_cls, task_name, action, arguments, result=None, | ||||
|                        failures=None, **kwargs): | ||||
|         """Parse request before it can be processed. All `misc.Failure` objects | ||||
|         that have been converted to dict on the remote side to be serializable | ||||
| @@ -80,7 +80,7 @@ class Server(object): | ||||
|             action_args['failures'] = {} | ||||
|             for k, v in failures.items(): | ||||
|                 action_args['failures'][k] = pu.failure_from_dict(v) | ||||
|         return task, action, action_args | ||||
|         return task_cls, action, action_args | ||||
|  | ||||
|     @staticmethod | ||||
|     def _parse_message(message): | ||||
| @@ -132,7 +132,7 @@ class Server(object): | ||||
|  | ||||
|         # parse request to get task name, action and action arguments | ||||
|         try: | ||||
|             task, action, action_args = self._parse_request(**request) | ||||
|             task_cls, action, action_args = self._parse_request(**request) | ||||
|             action_args.update(task_uuid=task_uuid, | ||||
|                                progress_callback=progress_callback) | ||||
|         except ValueError: | ||||
| @@ -143,10 +143,11 @@ class Server(object): | ||||
|  | ||||
|         # get task endpoint | ||||
|         try: | ||||
|             endpoint = self._endpoints[task] | ||||
|             endpoint = self._endpoints[task_cls] | ||||
|         except KeyError: | ||||
|             with misc.capture_failure() as failure: | ||||
|                 LOG.exception("The '%s' task endpoint does not exist", task) | ||||
|                 LOG.exception("The '%s' task endpoint does not exist", | ||||
|                               task_cls) | ||||
|                 reply_callback(result=pu.failure_to_dict(failure)) | ||||
|                 return | ||||
|         else: | ||||
|   | ||||
| @@ -23,7 +23,6 @@ from kombu import exceptions as kombu_exc | ||||
|  | ||||
| from taskflow.engines.worker_based import executor | ||||
| from taskflow.engines.worker_based import protocol as pr | ||||
| from taskflow.engines.worker_based import remote_task as rt | ||||
| from taskflow import test | ||||
| from taskflow.tests import utils | ||||
| from taskflow.utils import misc | ||||
| @@ -36,11 +35,11 @@ class TestWorkerTaskExecutor(test.MockTestCase): | ||||
|         super(TestWorkerTaskExecutor, self).setUp() | ||||
|         self.task = utils.DummyTask() | ||||
|         self.task_uuid = 'task-uuid' | ||||
|         self.task_args = {'context': 'context'} | ||||
|         self.task_args = {'a': 'a'} | ||||
|         self.task_result = 'task-result' | ||||
|         self.task_failures = {} | ||||
|         self.timeout = 60 | ||||
|         self.broker_url = 'test-url' | ||||
|         self.broker_url = 'broker-url' | ||||
|         self.executor_uuid = 'executor-uuid' | ||||
|         self.executor_exchange = 'executor-exchange' | ||||
|         self.executor_topic = 'executor-topic' | ||||
| @@ -58,7 +57,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): | ||||
|             'taskflow.engines.worker_based.executor.async_utils.wait_for_any') | ||||
|         self.message_mock = mock.MagicMock(name='message') | ||||
|         self.message_mock.properties = {'correlation_id': self.task_uuid} | ||||
|         self.remote_task_mock = mock.MagicMock(uuid=self.task_uuid) | ||||
|         self.request_mock = mock.MagicMock(uuid=self.task_uuid) | ||||
|  | ||||
|     def _fake_proxy_start(self): | ||||
|         self.proxy_started_event.set() | ||||
| @@ -80,19 +79,19 @@ class TestWorkerTaskExecutor(test.MockTestCase): | ||||
|         return ex | ||||
|  | ||||
|     def request(self, **kwargs): | ||||
|         request = dict(task=self.task.name, task_name=self.task.name, | ||||
|         request_kwargs = dict(task=self.task, uuid=self.task_uuid, | ||||
|                               action='execute', arguments=self.task_args, | ||||
|                               progress_callback=None, timeout=self.timeout) | ||||
|         request_kwargs.update(kwargs) | ||||
|         return pr.Request(**request_kwargs) | ||||
|  | ||||
|     def request_dict(self, **kwargs): | ||||
|         request = dict(task_cls=self.task.name, task_name=self.task.name, | ||||
|                        task_version=self.task.version, | ||||
|                        arguments=self.task_args) | ||||
|         request.update(kwargs) | ||||
|         return request | ||||
|  | ||||
|     def remote_task(self, **kwargs): | ||||
|         remote_task_kwargs = dict(task=self.task, uuid=self.task_uuid, | ||||
|                                   action='execute', arguments=self.task_args, | ||||
|                                   progress_callback=None, timeout=self.timeout) | ||||
|         remote_task_kwargs.update(kwargs) | ||||
|         return rt.RemoteTask(**remote_task_kwargs) | ||||
|  | ||||
|     def test_creation(self): | ||||
|         ex = self.executor(reset_master_mock=False) | ||||
|  | ||||
| @@ -105,20 +104,20 @@ class TestWorkerTaskExecutor(test.MockTestCase): | ||||
|     def test_on_message_state_running(self): | ||||
|         response = dict(state=pr.RUNNING) | ||||
|         ex = self.executor() | ||||
|         ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) | ||||
|         ex._requests_cache.set(self.task_uuid, self.request_mock) | ||||
|         ex._on_message(response, self.message_mock) | ||||
|  | ||||
|         self.assertEqual(self.remote_task_mock.mock_calls, | ||||
|         self.assertEqual(self.request_mock.mock_calls, | ||||
|                          [mock.call.set_running()]) | ||||
|         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) | ||||
|  | ||||
|     def test_on_message_state_progress(self): | ||||
|         response = dict(state=pr.PROGRESS, progress=1.0) | ||||
|         ex = self.executor() | ||||
|         ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) | ||||
|         ex._requests_cache.set(self.task_uuid, self.request_mock) | ||||
|         ex._on_message(response, self.message_mock) | ||||
|  | ||||
|         self.assertEqual(self.remote_task_mock.mock_calls, | ||||
|         self.assertEqual(self.request_mock.mock_calls, | ||||
|                          [mock.call.on_progress(progress=1.0)]) | ||||
|         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) | ||||
|  | ||||
| @@ -127,11 +126,11 @@ class TestWorkerTaskExecutor(test.MockTestCase): | ||||
|         failure_dict = pu.failure_to_dict(failure) | ||||
|         response = dict(state=pr.FAILURE, result=failure_dict) | ||||
|         ex = self.executor() | ||||
|         ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) | ||||
|         ex._requests_cache.set(self.task_uuid, self.request_mock) | ||||
|         ex._on_message(response, self.message_mock) | ||||
|  | ||||
|         self.assertEqual(len(ex._remote_tasks_cache._data), 0) | ||||
|         self.assertEqual(self.remote_task_mock.mock_calls, [ | ||||
|         self.assertEqual(len(ex._requests_cache._data), 0) | ||||
|         self.assertEqual(self.request_mock.mock_calls, [ | ||||
|             mock.call.set_result(result=utils.FailureMatcher(failure)) | ||||
|         ]) | ||||
|         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) | ||||
| @@ -140,10 +139,10 @@ class TestWorkerTaskExecutor(test.MockTestCase): | ||||
|         response = dict(state=pr.SUCCESS, result=self.task_result, | ||||
|                         event='executed') | ||||
|         ex = self.executor() | ||||
|         ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) | ||||
|         ex._requests_cache.set(self.task_uuid, self.request_mock) | ||||
|         ex._on_message(response, self.message_mock) | ||||
|  | ||||
|         self.assertEqual(self.remote_task_mock.mock_calls, | ||||
|         self.assertEqual(self.request_mock.mock_calls, | ||||
|                          [mock.call.set_result(result=self.task_result, | ||||
|                                                event='executed')]) | ||||
|         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) | ||||
| @@ -151,30 +150,30 @@ class TestWorkerTaskExecutor(test.MockTestCase): | ||||
|     def test_on_message_unknown_state(self): | ||||
|         response = dict(state='unknown') | ||||
|         ex = self.executor() | ||||
|         ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) | ||||
|         ex._requests_cache.set(self.task_uuid, self.request_mock) | ||||
|         ex._on_message(response, self.message_mock) | ||||
|  | ||||
|         self.assertEqual(self.remote_task_mock.mock_calls, []) | ||||
|         self.assertEqual(self.request_mock.mock_calls, []) | ||||
|         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) | ||||
|  | ||||
|     def test_on_message_non_existent_task(self): | ||||
|         self.message_mock.properties = {'correlation_id': 'non-existent'} | ||||
|         response = dict(state=pr.RUNNING) | ||||
|         ex = self.executor() | ||||
|         ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) | ||||
|         ex._requests_cache.set(self.task_uuid, self.request_mock) | ||||
|         ex._on_message(response, self.message_mock) | ||||
|  | ||||
|         self.assertEqual(self.remote_task_mock.mock_calls, []) | ||||
|         self.assertEqual(self.request_mock.mock_calls, []) | ||||
|         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) | ||||
|  | ||||
|     def test_on_message_no_correlation_id(self): | ||||
|         self.message_mock.properties = {} | ||||
|         response = dict(state=pr.RUNNING) | ||||
|         ex = self.executor() | ||||
|         ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) | ||||
|         ex._requests_cache.set(self.task_uuid, self.request_mock) | ||||
|         ex._on_message(response, self.message_mock) | ||||
|  | ||||
|         self.assertEqual(self.remote_task_mock.mock_calls, []) | ||||
|         self.assertEqual(self.request_mock.mock_calls, []) | ||||
|         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) | ||||
|  | ||||
|     @mock.patch('taskflow.engines.worker_based.executor.LOG.exception') | ||||
| @@ -183,46 +182,46 @@ class TestWorkerTaskExecutor(test.MockTestCase): | ||||
|         self.executor()._on_message({}, self.message_mock) | ||||
|         self.assertTrue(mocked_exception.called) | ||||
|  | ||||
|     @mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock') | ||||
|     def test_on_wait_task_not_expired(self, mocked_time): | ||||
|         mocked_time.side_effect = [1, self.timeout] | ||||
|     @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') | ||||
|     def test_on_wait_task_not_expired(self, mocked_wallclock): | ||||
|         mocked_wallclock.side_effect = [1, self.timeout] | ||||
|         ex = self.executor() | ||||
|         ex._remote_tasks_cache.set(self.task_uuid, self.remote_task()) | ||||
|         ex._requests_cache.set(self.task_uuid, self.request()) | ||||
|  | ||||
|         self.assertEqual(len(ex._remote_tasks_cache._data), 1) | ||||
|         self.assertEqual(len(ex._requests_cache._data), 1) | ||||
|         ex._on_wait() | ||||
|         self.assertEqual(len(ex._remote_tasks_cache._data), 1) | ||||
|         self.assertEqual(len(ex._requests_cache._data), 1) | ||||
|  | ||||
|     @mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock') | ||||
|     @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') | ||||
|     def test_on_wait_task_expired(self, mocked_time): | ||||
|         mocked_time.side_effect = [1, self.timeout + 2, self.timeout * 2] | ||||
|         ex = self.executor() | ||||
|         ex._remote_tasks_cache.set(self.task_uuid, self.remote_task()) | ||||
|         ex._requests_cache.set(self.task_uuid, self.request()) | ||||
|  | ||||
|         self.assertEqual(len(ex._remote_tasks_cache._data), 1) | ||||
|         self.assertEqual(len(ex._requests_cache._data), 1) | ||||
|         ex._on_wait() | ||||
|         self.assertEqual(len(ex._remote_tasks_cache._data), 0) | ||||
|         self.assertEqual(len(ex._requests_cache._data), 0) | ||||
|  | ||||
|     def test_remove_task_non_existent(self): | ||||
|         task = self.remote_task() | ||||
|         task = self.request() | ||||
|         ex = self.executor() | ||||
|         ex._remote_tasks_cache.set(self.task_uuid, task) | ||||
|         ex._requests_cache.set(self.task_uuid, task) | ||||
|  | ||||
|         self.assertEqual(len(ex._remote_tasks_cache._data), 1) | ||||
|         ex._remote_tasks_cache.delete(self.task_uuid) | ||||
|         self.assertEqual(len(ex._remote_tasks_cache._data), 0) | ||||
|         self.assertEqual(len(ex._requests_cache._data), 1) | ||||
|         ex._requests_cache.delete(self.task_uuid) | ||||
|         self.assertEqual(len(ex._requests_cache._data), 0) | ||||
|  | ||||
|         # remove non-existent | ||||
|         ex._remote_tasks_cache.delete(self.task_uuid) | ||||
|         self.assertEqual(len(ex._remote_tasks_cache._data), 0) | ||||
|         ex._requests_cache.delete(self.task_uuid) | ||||
|         self.assertEqual(len(ex._requests_cache._data), 0) | ||||
|  | ||||
|     def test_execute_task(self): | ||||
|         request = self.request(action='execute') | ||||
|         request_dict = self.request_dict(action='execute') | ||||
|         ex = self.executor() | ||||
|         result = ex.execute_task(self.task, self.task_uuid, self.task_args) | ||||
|  | ||||
|         expected_calls = [ | ||||
|             mock.call.proxy.publish(request, | ||||
|             mock.call.proxy.publish(request_dict, | ||||
|                                     routing_key=self.executor_topic, | ||||
|                                     reply_to=self.executor_uuid, | ||||
|                                     correlation_id=self.task_uuid) | ||||
| @@ -231,7 +230,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): | ||||
|         self.assertIsInstance(result, futures.Future) | ||||
|  | ||||
|     def test_revert_task(self): | ||||
|         request = self.request(action='revert', | ||||
|         request_dict = self.request_dict(action='revert', | ||||
|                                          result=('success', self.task_result), | ||||
|                                          failures=self.task_failures) | ||||
|         ex = self.executor() | ||||
| @@ -239,7 +238,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): | ||||
|                                 self.task_result, self.task_failures) | ||||
|  | ||||
|         expected_calls = [ | ||||
|             mock.call.proxy.publish(request, | ||||
|             mock.call.proxy.publish(request_dict, | ||||
|                                     routing_key=self.executor_topic, | ||||
|                                     reply_to=self.executor_uuid, | ||||
|                                     correlation_id=self.task_uuid) | ||||
| @@ -262,12 +261,12 @@ class TestWorkerTaskExecutor(test.MockTestCase): | ||||
|  | ||||
|     def test_execute_task_publish_error(self): | ||||
|         self.proxy_inst_mock.publish.side_effect = Exception('Woot!') | ||||
|         request = self.request(action='execute') | ||||
|         request_dict = self.request_dict(action='execute') | ||||
|         ex = self.executor() | ||||
|         result = ex.execute_task(self.task, self.task_uuid, self.task_args) | ||||
|  | ||||
|         expected_calls = [ | ||||
|             mock.call.proxy.publish(request, | ||||
|             mock.call.proxy.publish(request_dict, | ||||
|                                     routing_key=self.executor_topic, | ||||
|                                     reply_to=self.executor_uuid, | ||||
|                                     correlation_id=self.task_uuid) | ||||
|   | ||||
							
								
								
									
										126
									
								
								taskflow/tests/unit/worker_based/test_protocol.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								taskflow/tests/unit/worker_based/test_protocol.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,126 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
|  | ||||
| #    Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. | ||||
| # | ||||
| #    Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| #    not use this file except in compliance with the License. You may obtain | ||||
| #    a copy of the License at | ||||
| # | ||||
| #         http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| #    Unless required by applicable law or agreed to in writing, software | ||||
| #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
| #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
| #    License for the specific language governing permissions and limitations | ||||
| #    under the License. | ||||
|  | ||||
| import mock | ||||
|  | ||||
| from concurrent import futures | ||||
|  | ||||
| from taskflow.engines.worker_based import protocol as pr | ||||
| from taskflow import test | ||||
| from taskflow.tests import utils | ||||
| from taskflow.utils import misc | ||||
| from taskflow.utils import persistence_utils as pu | ||||
|  | ||||
|  | ||||
| class TestProtocol(test.TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         super(TestProtocol, self).setUp() | ||||
|         self.task = utils.DummyTask() | ||||
|         self.task_uuid = 'task-uuid' | ||||
|         self.task_action = 'execute' | ||||
|         self.task_args = {'a': 'a'} | ||||
|         self.timeout = 60 | ||||
|  | ||||
|     def request(self, **kwargs): | ||||
|         request_kwargs = dict(task=self.task, | ||||
|                               uuid=self.task_uuid, | ||||
|                               action=self.task_action, | ||||
|                               arguments=self.task_args, | ||||
|                               progress_callback=None, | ||||
|                               timeout=self.timeout) | ||||
|         request_kwargs.update(kwargs) | ||||
|         return pr.Request(**request_kwargs) | ||||
|  | ||||
|     def request_to_dict(self, **kwargs): | ||||
|         to_dict = dict(task_cls=self.task.name, | ||||
|                        task_name=self.task.name, | ||||
|                        task_version=self.task.version, | ||||
|                        action=self.task_action, | ||||
|                        arguments=self.task_args) | ||||
|         to_dict.update(kwargs) | ||||
|         return to_dict | ||||
|  | ||||
|     def test_creation(self): | ||||
|         request = self.request() | ||||
|         self.assertEqual(request.uuid, self.task_uuid) | ||||
|         self.assertEqual(request.task_cls, self.task.name) | ||||
|         self.assertIsInstance(request.result, futures.Future) | ||||
|         self.assertFalse(request.result.done()) | ||||
|  | ||||
|     def test_repr(self): | ||||
|         expected_name = '%s:%s' % (self.task.name, self.task_action) | ||||
|         self.assertEqual(repr(self.request()), expected_name) | ||||
|  | ||||
|     def test_to_dict_default(self): | ||||
|         self.assertEqual(self.request().to_dict(), self.request_to_dict()) | ||||
|  | ||||
|     def test_to_dict_with_result(self): | ||||
|         self.assertEqual(self.request(result=333).to_dict(), | ||||
|                          self.request_to_dict(result=('success', 333))) | ||||
|  | ||||
|     def test_to_dict_with_result_none(self): | ||||
|         self.assertEqual(self.request(result=None).to_dict(), | ||||
|                          self.request_to_dict(result=('success', None))) | ||||
|  | ||||
|     def test_to_dict_with_result_failure(self): | ||||
|         failure = misc.Failure.from_exception(RuntimeError('Woot!')) | ||||
|         expected = self.request_to_dict( | ||||
|             result=('failure', pu.failure_to_dict(failure))) | ||||
|         self.assertEqual(self.request(result=failure).to_dict(), expected) | ||||
|  | ||||
|     def test_to_dict_with_failures(self): | ||||
|         failure = misc.Failure.from_exception(RuntimeError('Woot!')) | ||||
|         request = self.request(failures={self.task.name: failure}) | ||||
|         expected = self.request_to_dict( | ||||
|             failures={self.task.name: pu.failure_to_dict(failure)}) | ||||
|         self.assertEqual(request.to_dict(), expected) | ||||
|  | ||||
|     @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') | ||||
|     def test_pending_not_expired(self, mocked_wallclock): | ||||
|         mocked_wallclock.side_effect = [1, self.timeout] | ||||
|         self.assertFalse(self.request().expired) | ||||
|  | ||||
|     @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') | ||||
|     def test_pending_expired(self, mocked_wallclock): | ||||
|         mocked_wallclock.side_effect = [1, self.timeout + 2] | ||||
|         self.assertTrue(self.request().expired) | ||||
|  | ||||
|     @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') | ||||
|     def test_running_not_expired(self, mocked_wallclock): | ||||
|         mocked_wallclock.side_effect = [1, self.timeout + 2] | ||||
|         request = self.request() | ||||
|         request.set_running() | ||||
|         self.assertFalse(request.expired) | ||||
|  | ||||
|     def test_set_result(self): | ||||
|         request = self.request() | ||||
|         request.set_result(111) | ||||
|         result = request.result.result() | ||||
|         self.assertEqual(result, (self.task, 'executed', 111)) | ||||
|  | ||||
|     def test_on_progress(self): | ||||
|         progress_callback = mock.MagicMock(name='progress_callback') | ||||
|         request = self.request(task=self.task, | ||||
|                                progress_callback=progress_callback) | ||||
|         request.on_progress('event_data', 0.0) | ||||
|         request.on_progress('event_data', 1.0) | ||||
|  | ||||
|         expected_calls = [ | ||||
|             mock.call(self.task, 'event_data', 0.0), | ||||
|             mock.call(self.task, 'event_data', 1.0) | ||||
|         ] | ||||
|         self.assertEqual(progress_callback.mock_calls, expected_calls) | ||||
| @@ -1,133 +0,0 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
|  | ||||
| #    Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. | ||||
| # | ||||
| #    Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| #    not use this file except in compliance with the License. You may obtain | ||||
| #    a copy of the License at | ||||
| # | ||||
| #         http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| #    Unless required by applicable law or agreed to in writing, software | ||||
| #    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
| #    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
| #    License for the specific language governing permissions and limitations | ||||
| #    under the License. | ||||
|  | ||||
| import mock | ||||
|  | ||||
| from concurrent import futures | ||||
|  | ||||
| from taskflow.engines.worker_based import remote_task as rt | ||||
| from taskflow import test | ||||
| from taskflow.tests import utils | ||||
| from taskflow.utils import misc | ||||
| from taskflow.utils import persistence_utils as pu | ||||
|  | ||||
|  | ||||
| class TestRemoteTask(test.TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         super(TestRemoteTask, self).setUp() | ||||
|         self.task = utils.DummyTask() | ||||
|         self.task_uuid = 'task-uuid' | ||||
|         self.task_action = 'execute' | ||||
|         self.task_args = {'context': 'context'} | ||||
|         self.timeout = 60 | ||||
|  | ||||
|     def remote_task(self, **kwargs): | ||||
|         task_kwargs = dict(task=self.task, | ||||
|                            uuid=self.task_uuid, | ||||
|                            action=self.task_action, | ||||
|                            arguments=self.task_args, | ||||
|                            progress_callback=None, | ||||
|                            timeout=self.timeout) | ||||
|         task_kwargs.update(kwargs) | ||||
|         return rt.RemoteTask(**task_kwargs) | ||||
|  | ||||
|     def remote_task_request(self, **kwargs): | ||||
|         request = dict(task=self.task.name, | ||||
|                        task_name=self.task.name, | ||||
|                        task_version=self.task.version, | ||||
|                        action=self.task_action, | ||||
|                        arguments=self.task_args) | ||||
|         request.update(kwargs) | ||||
|         return request | ||||
|  | ||||
|     def test_creation(self): | ||||
|         remote_task = self.remote_task() | ||||
|         self.assertEqual(remote_task.uuid, self.task_uuid) | ||||
|         self.assertEqual(remote_task.name, self.task.name) | ||||
|         self.assertIsInstance(remote_task.result, futures.Future) | ||||
|         self.assertFalse(remote_task.result.done()) | ||||
|  | ||||
|     def test_repr(self): | ||||
|         expected_name = '%s:%s' % (self.task.name, self.task_action) | ||||
|         self.assertEqual(repr(self.remote_task()), expected_name) | ||||
|  | ||||
|     def test_request(self): | ||||
|         remote_task = self.remote_task() | ||||
|         request = self.remote_task_request() | ||||
|         self.assertEqual(remote_task.request, request) | ||||
|  | ||||
|     def test_request_with_result(self): | ||||
|         remote_task = self.remote_task(result=333) | ||||
|         request = self.remote_task_request(result=('success', 333)) | ||||
|         self.assertEqual(remote_task.request, request) | ||||
|  | ||||
|     def test_request_with_result_none(self): | ||||
|         remote_task = self.remote_task(result=None) | ||||
|         request = self.remote_task_request(result=('success', None)) | ||||
|         self.assertEqual(remote_task.request, request) | ||||
|  | ||||
|     def test_request_with_result_failure(self): | ||||
|         failure = misc.Failure.from_exception(RuntimeError('Woot!')) | ||||
|         remote_task = self.remote_task(result=failure) | ||||
|         request = self.remote_task_request( | ||||
|             result=('failure', pu.failure_to_dict(failure))) | ||||
|         self.assertEqual(remote_task.request, request) | ||||
|  | ||||
|     def test_request_with_failures(self): | ||||
|         failure = misc.Failure.from_exception(RuntimeError('Woot!')) | ||||
|         remote_task = self.remote_task(failures={self.task.name: failure}) | ||||
|         request = self.remote_task_request( | ||||
|             failures={self.task.name: pu.failure_to_dict(failure)}) | ||||
|         self.assertEqual(remote_task.request, request) | ||||
|  | ||||
|     @mock.patch('time.time') | ||||
|     def test_pending_not_expired(self, mock_time): | ||||
|         mock_time.side_effect = [1, self.timeout] | ||||
|         remote_task = self.remote_task() | ||||
|         self.assertFalse(remote_task.expired) | ||||
|  | ||||
|     @mock.patch('time.time') | ||||
|     def test_pending_expired(self, mock_time): | ||||
|         mock_time.side_effect = [1, self.timeout + 2] | ||||
|         remote_task = self.remote_task() | ||||
|         self.assertTrue(remote_task.expired) | ||||
|  | ||||
|     @mock.patch('time.time') | ||||
|     def test_running_not_expired(self, mock_time): | ||||
|         mock_time.side_effect = [1, self.timeout] | ||||
|         remote_task = self.remote_task() | ||||
|         remote_task.set_running() | ||||
|         self.assertFalse(remote_task.expired) | ||||
|  | ||||
|     def test_set_result(self): | ||||
|         remote_task = self.remote_task() | ||||
|         remote_task.set_result(111) | ||||
|         result = remote_task.result.result() | ||||
|         self.assertEqual(result, (self.task, 'executed', 111)) | ||||
|  | ||||
|     def test_on_progress(self): | ||||
|         progress_callback = mock.MagicMock(name='progress_callback') | ||||
|         remote_task = self.remote_task(task=self.task, | ||||
|                                        progress_callback=progress_callback) | ||||
|         remote_task.on_progress('event_data', 0.0) | ||||
|         remote_task.on_progress('event_data', 1.0) | ||||
|  | ||||
|         expected_calls = [ | ||||
|             mock.call(self.task, 'event_data', 0.0), | ||||
|             mock.call(self.task, 'event_data', 1.0) | ||||
|         ] | ||||
|         self.assertEqual(progress_callback.mock_calls, expected_calls) | ||||
| @@ -71,7 +71,7 @@ class TestServer(test.MockTestCase): | ||||
|         return s | ||||
|  | ||||
|     def request(self, **kwargs): | ||||
|         request = dict(task=self.task_name, | ||||
|         request = dict(task_cls=self.task_name, | ||||
|                        task_name=self.task_name, | ||||
|                        action=self.task_action, | ||||
|                        task_version=self.task_version, | ||||
| @@ -164,18 +164,18 @@ class TestServer(test.MockTestCase): | ||||
|  | ||||
|     def test_parse_request(self): | ||||
|         request = self.request() | ||||
|         task, action, task_args = server.Server._parse_request(**request) | ||||
|         task_cls, action, task_args = server.Server._parse_request(**request) | ||||
|  | ||||
|         self.assertEqual((task, action, task_args), | ||||
|         self.assertEqual((task_cls, action, task_args), | ||||
|                          (self.task_name, self.task_action, | ||||
|                           dict(task_name=self.task_name, | ||||
|                                arguments=self.task_args))) | ||||
|  | ||||
|     def test_parse_request_with_success_result(self): | ||||
|         request = self.request(action='revert', result=('success', 1)) | ||||
|         task, action, task_args = server.Server._parse_request(**request) | ||||
|         task_cls, action, task_args = server.Server._parse_request(**request) | ||||
|  | ||||
|         self.assertEqual((task, action, task_args), | ||||
|         self.assertEqual((task_cls, action, task_args), | ||||
|                          (self.task_name, 'revert', | ||||
|                           dict(task_name=self.task_name, | ||||
|                                arguments=self.task_args, | ||||
| @@ -186,9 +186,9 @@ class TestServer(test.MockTestCase): | ||||
|         failure_dict = pu.failure_to_dict(failure) | ||||
|         request = self.request(action='revert', | ||||
|                                result=('failure', failure_dict)) | ||||
|         task, action, task_args = server.Server._parse_request(**request) | ||||
|         task_cls, action, task_args = server.Server._parse_request(**request) | ||||
|  | ||||
|         self.assertEqual((task, action, task_args), | ||||
|         self.assertEqual((task_cls, action, task_args), | ||||
|                          (self.task_name, 'revert', | ||||
|                           dict(task_name=self.task_name, | ||||
|                                arguments=self.task_args, | ||||
| @@ -200,10 +200,10 @@ class TestServer(test.MockTestCase): | ||||
|         failures_dict = dict((str(i), pu.failure_to_dict(f)) | ||||
|                              for i, f in enumerate(failures)) | ||||
|         request = self.request(action='revert', failures=failures_dict) | ||||
|         task, action, task_args = server.Server._parse_request(**request) | ||||
|         task_cls, action, task_args = server.Server._parse_request(**request) | ||||
|  | ||||
|         self.assertEqual( | ||||
|             (task, action, task_args), | ||||
|             (task_cls, action, task_args), | ||||
|             (self.task_name, 'revert', | ||||
|              dict(task_name=self.task_name, | ||||
|                   arguments=self.task_args, | ||||
| @@ -225,7 +225,7 @@ class TestServer(test.MockTestCase): | ||||
|         self.assertTrue(mocked_exception.called) | ||||
|  | ||||
|     def test_on_update_progress(self): | ||||
|         request = self.request(task='taskflow.tests.utils.ProgressingTask', | ||||
|         request = self.request(task_cls='taskflow.tests.utils.ProgressingTask', | ||||
|                                arguments={}) | ||||
|  | ||||
|         # create server and process request | ||||
| @@ -292,7 +292,7 @@ class TestServer(test.MockTestCase): | ||||
|     def test_process_request_endpoint_not_found(self, pu_mock): | ||||
|         failure_dict = 'failure_dict' | ||||
|         pu_mock.failure_to_dict.return_value = failure_dict | ||||
|         request = self.request(task='<unknown>') | ||||
|         request = self.request(task_cls='<unknown>') | ||||
|  | ||||
|         # create server and process request | ||||
|         s = self.server(reset_master_mock=True, endpoints=self.endpoints) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Stanislav Kudriashev
					Stanislav Kudriashev