diff --git a/taskflow/engines/worker_based/executor.py b/taskflow/engines/worker_based/executor.py index 7758ca097..c8d696dee 100644 --- a/taskflow/engines/worker_based/executor.py +++ b/taskflow/engines/worker_based/executor.py @@ -25,7 +25,6 @@ from taskflow.engines.action_engine import executor from taskflow.engines.worker_based import cache from taskflow.engines.worker_based import protocol as pr from taskflow.engines.worker_based import proxy -from taskflow.engines.worker_based import remote_task as rt from taskflow import exceptions as exc from taskflow.utils import async_utils from taskflow.utils import misc @@ -42,7 +41,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): self._proxy = proxy.Proxy(uuid, exchange, self._on_message, self._on_wait, **kwargs) self._proxy_thread = None - self._remote_tasks_cache = cache.Cache() + self._requests_cache = cache.Cache() # TODO(skudriashev): This data should be collected from workers # using broadcast messages directly. @@ -80,61 +79,61 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): def _process_response(self, task_uuid, response): """Process response from remote side.""" - remote_task = self._remote_tasks_cache.get(task_uuid) - if remote_task is not None: + request = self._requests_cache.get(task_uuid) + if request is not None: state = response.pop('state') if state == pr.RUNNING: - remote_task.set_running() + request.set_running() elif state == pr.PROGRESS: - remote_task.on_progress(**response) + request.on_progress(**response) elif state == pr.FAILURE: response['result'] = pu.failure_from_dict(response['result']) - remote_task.set_result(**response) - self._remote_tasks_cache.delete(remote_task.uuid) + request.set_result(**response) + self._requests_cache.delete(request.uuid) elif state == pr.SUCCESS: - remote_task.set_result(**response) - self._remote_tasks_cache.delete(remote_task.uuid) + request.set_result(**response) + self._requests_cache.delete(request.uuid) else: LOG.warning("Unexpected response status: '%s'", state) else: - LOG.debug("Remote task with id='%s' not found.", task_uuid) + LOG.debug("Request with id='%s' not found.", task_uuid) @staticmethod - def _handle_expired_remote_task(task): - LOG.debug("Remote task '%r' has expired.", task) - task.set_result(misc.Failure.from_exception( - exc.Timeout("Remote task '%r' has expired" % task))) + def _handle_expired_request(request): + LOG.debug("Request '%r' has expired.", request) + request.set_result(misc.Failure.from_exception( + exc.Timeout("Request '%r' has expired" % request))) def _on_wait(self): """This function is called cyclically between draining events.""" - self._remote_tasks_cache.cleanup(self._handle_expired_remote_task) + self._requests_cache.cleanup(self._handle_expired_request) def _submit_task(self, task, task_uuid, action, arguments, progress_callback, timeout=pr.REQUEST_TIMEOUT, **kwargs): """Submit task request to workers.""" - remote_task = rt.RemoteTask(task, task_uuid, action, arguments, - progress_callback, timeout, **kwargs) - self._remote_tasks_cache.set(remote_task.uuid, remote_task) + request = pr.Request(task, task_uuid, action, arguments, + progress_callback, timeout, **kwargs) + self._requests_cache.set(request.uuid, request) try: # get task's workers topic to send request to try: - topic = self._workers_info[remote_task.name] + topic = self._workers_info[request.task_cls] except KeyError: raise exc.NotFound("Workers topic not found for the '%s'" - " task" % remote_task.name) + " task" % request.task_cls) else: # publish request - request = remote_task.request LOG.debug("Sending request: %s", request) - self._proxy.publish(request, routing_key=topic, + self._proxy.publish(request.to_dict(), + routing_key=topic, reply_to=self._uuid, - correlation_id=remote_task.uuid) + correlation_id=request.uuid) except Exception: with misc.capture_failure() as failure: - LOG.exception("Failed to submit the '%s' task", remote_task) - self._remote_tasks_cache.delete(remote_task.uuid) - remote_task.set_result(failure) - return remote_task.result + LOG.exception("Failed to submit the '%s' task", request) + self._requests_cache.delete(request.uuid) + request.set_result(failure) + return request.result def execute_task(self, task, task_uuid, arguments, progress_callback=None): diff --git a/taskflow/engines/worker_based/protocol.py b/taskflow/engines/worker_based/protocol.py index 0ab877d01..b778bae3f 100644 --- a/taskflow/engines/worker_based/protocol.py +++ b/taskflow/engines/worker_based/protocol.py @@ -14,7 +14,12 @@ # License for the specific language governing permissions and limitations # under the License. +from concurrent import futures + from taskflow.engines.action_engine import executor +from taskflow.utils import misc +from taskflow.utils import persistence_utils as pu +from taskflow.utils import reflection # NOTE(skudriashev): This is protocol events, not related to the task states. PENDING = 'PENDING' @@ -43,3 +48,80 @@ REQUEST_TIMEOUT = 60 # period is equal to the request timeout, once request is expired - queue is # no longer needed. QUEUE_EXPIRE_TIMEOUT = REQUEST_TIMEOUT + + +class Request(object): + """Represents request with execution results. Every request is created in + the PENDING state and is expired within the given timeout. + """ + + def __init__(self, task, uuid, action, arguments, progress_callback, + timeout, **kwargs): + self._task = task + self._task_cls = reflection.get_class_name(task) + self._uuid = uuid + self._action = action + self._event = ACTION_TO_EVENT[action] + self._arguments = arguments + self._progress_callback = progress_callback + self._kwargs = kwargs + self._watch = misc.StopWatch(duration=timeout).start() + self._state = PENDING + self.result = futures.Future() + + def __repr__(self): + return "%s:%s" % (self._task_cls, self._action) + + @property + def uuid(self): + return self._uuid + + @property + def task_cls(self): + return self._task_cls + + @property + def expired(self): + """Check if request is expired. + + When new request is created its state is set to the PENDING, creation + time is stored and timeout is given via constructor arguments. + + Request is considered to be expired when it is in the PENDING state + for more then the given timeout (it is not considered to be expired + in any other state). After request is expired - the `Timeout` + exception is raised and task is removed from the requests map. + """ + if self._state == PENDING: + return self._watch.expired() + return False + + def to_dict(self): + """Return json-serializable request, converting all `misc.Failure` + objects into dictionaries. + """ + request = dict(task_cls=self._task_cls, task_name=self._task.name, + task_version=self._task.version, action=self._action, + arguments=self._arguments) + if 'result' in self._kwargs: + result = self._kwargs['result'] + if isinstance(result, misc.Failure): + request['result'] = ('failure', pu.failure_to_dict(result)) + else: + request['result'] = ('success', result) + if 'failures' in self._kwargs: + failures = self._kwargs['failures'] + request['failures'] = {} + for task, failure in failures.items(): + request['failures'][task] = pu.failure_to_dict(failure) + return request + + def set_result(self, result): + self.result.set_result((self._task, self._event, result)) + + def set_running(self): + self._state = RUNNING + self._watch.stop() + + def on_progress(self, event_data, progress): + self._progress_callback(self._task, event_data, progress) diff --git a/taskflow/engines/worker_based/remote_task.py b/taskflow/engines/worker_based/remote_task.py deleted file mode 100644 index 8211fc327..000000000 --- a/taskflow/engines/worker_based/remote_task.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -from concurrent import futures - -from taskflow.engines.worker_based import protocol as pr -from taskflow.utils import misc -from taskflow.utils import persistence_utils as pu -from taskflow.utils import reflection - - -class RemoteTask(object): - """Represents remote task with its request data and execution results. - Every remote task is created in the PENDING state and will be expired - within the given timeout. - """ - - def __init__(self, task, uuid, action, arguments, progress_callback, - timeout, **kwargs): - self._task = task - self._name = reflection.get_class_name(task) - self._uuid = uuid - self._action = action - self._event = pr.ACTION_TO_EVENT[action] - self._arguments = arguments - self._progress_callback = progress_callback - self._kwargs = kwargs - self._watch = misc.StopWatch(duration=timeout).start() - self._state = pr.PENDING - self.result = futures.Future() - - def __repr__(self): - return "%s:%s" % (self._name, self._action) - - @property - def uuid(self): - return self._uuid - - @property - def name(self): - return self._name - - @property - def request(self): - """Return json-serializable task request, converting all `misc.Failure` - objects into dictionaries. - """ - request = dict(task=self._name, task_name=self._task.name, - task_version=self._task.version, action=self._action, - arguments=self._arguments) - if 'result' in self._kwargs: - result = self._kwargs['result'] - if isinstance(result, misc.Failure): - request['result'] = ('failure', pu.failure_to_dict(result)) - else: - request['result'] = ('success', result) - if 'failures' in self._kwargs: - failures = self._kwargs['failures'] - request['failures'] = {} - for task, failure in failures.items(): - request['failures'][task] = pu.failure_to_dict(failure) - return request - - @property - def expired(self): - """Check if task is expired. - - When new remote task is created its state is set to the PENDING, - creation time is stored and timeout is given via constructor arguments. - - Remote task is considered to be expired when it is in the PENDING state - for more then the given timeout (task is not considered to be expired - in any other state). After remote task is expired - the `Timeout` - exception is raised and task is removed from the remote tasks map. - """ - if self._state == pr.PENDING: - return self._watch.expired() - return False - - def set_result(self, result): - self.result.set_result((self._task, self._event, result)) - - def set_running(self): - self._state = pr.RUNNING - self._watch.stop() - - def on_progress(self, event_data, progress): - self._progress_callback(self._task, event_data, progress) diff --git a/taskflow/engines/worker_based/server.py b/taskflow/engines/worker_based/server.py index 8cf3e597c..ff630df99 100644 --- a/taskflow/engines/worker_based/server.py +++ b/taskflow/engines/worker_based/server.py @@ -63,7 +63,7 @@ class Server(object): LOG.debug("AMQP message requeued.") @staticmethod - def _parse_request(task, task_name, action, arguments, result=None, + def _parse_request(task_cls, task_name, action, arguments, result=None, failures=None, **kwargs): """Parse request before it can be processed. All `misc.Failure` objects that have been converted to dict on the remote side to be serializable @@ -80,7 +80,7 @@ class Server(object): action_args['failures'] = {} for k, v in failures.items(): action_args['failures'][k] = pu.failure_from_dict(v) - return task, action, action_args + return task_cls, action, action_args @staticmethod def _parse_message(message): @@ -132,7 +132,7 @@ class Server(object): # parse request to get task name, action and action arguments try: - task, action, action_args = self._parse_request(**request) + task_cls, action, action_args = self._parse_request(**request) action_args.update(task_uuid=task_uuid, progress_callback=progress_callback) except ValueError: @@ -143,10 +143,11 @@ class Server(object): # get task endpoint try: - endpoint = self._endpoints[task] + endpoint = self._endpoints[task_cls] except KeyError: with misc.capture_failure() as failure: - LOG.exception("The '%s' task endpoint does not exist", task) + LOG.exception("The '%s' task endpoint does not exist", + task_cls) reply_callback(result=pu.failure_to_dict(failure)) return else: diff --git a/taskflow/tests/unit/worker_based/test_executor.py b/taskflow/tests/unit/worker_based/test_executor.py index 950f817dd..ff206660b 100644 --- a/taskflow/tests/unit/worker_based/test_executor.py +++ b/taskflow/tests/unit/worker_based/test_executor.py @@ -23,7 +23,6 @@ from kombu import exceptions as kombu_exc from taskflow.engines.worker_based import executor from taskflow.engines.worker_based import protocol as pr -from taskflow.engines.worker_based import remote_task as rt from taskflow import test from taskflow.tests import utils from taskflow.utils import misc @@ -36,11 +35,11 @@ class TestWorkerTaskExecutor(test.MockTestCase): super(TestWorkerTaskExecutor, self).setUp() self.task = utils.DummyTask() self.task_uuid = 'task-uuid' - self.task_args = {'context': 'context'} + self.task_args = {'a': 'a'} self.task_result = 'task-result' self.task_failures = {} self.timeout = 60 - self.broker_url = 'test-url' + self.broker_url = 'broker-url' self.executor_uuid = 'executor-uuid' self.executor_exchange = 'executor-exchange' self.executor_topic = 'executor-topic' @@ -58,7 +57,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): 'taskflow.engines.worker_based.executor.async_utils.wait_for_any') self.message_mock = mock.MagicMock(name='message') self.message_mock.properties = {'correlation_id': self.task_uuid} - self.remote_task_mock = mock.MagicMock(uuid=self.task_uuid) + self.request_mock = mock.MagicMock(uuid=self.task_uuid) def _fake_proxy_start(self): self.proxy_started_event.set() @@ -80,19 +79,19 @@ class TestWorkerTaskExecutor(test.MockTestCase): return ex def request(self, **kwargs): - request = dict(task=self.task.name, task_name=self.task.name, + request_kwargs = dict(task=self.task, uuid=self.task_uuid, + action='execute', arguments=self.task_args, + progress_callback=None, timeout=self.timeout) + request_kwargs.update(kwargs) + return pr.Request(**request_kwargs) + + def request_dict(self, **kwargs): + request = dict(task_cls=self.task.name, task_name=self.task.name, task_version=self.task.version, arguments=self.task_args) request.update(kwargs) return request - def remote_task(self, **kwargs): - remote_task_kwargs = dict(task=self.task, uuid=self.task_uuid, - action='execute', arguments=self.task_args, - progress_callback=None, timeout=self.timeout) - remote_task_kwargs.update(kwargs) - return rt.RemoteTask(**remote_task_kwargs) - def test_creation(self): ex = self.executor(reset_master_mock=False) @@ -105,20 +104,20 @@ class TestWorkerTaskExecutor(test.MockTestCase): def test_on_message_state_running(self): response = dict(state=pr.RUNNING) ex = self.executor() - ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) + ex._requests_cache.set(self.task_uuid, self.request_mock) ex._on_message(response, self.message_mock) - self.assertEqual(self.remote_task_mock.mock_calls, + self.assertEqual(self.request_mock.mock_calls, [mock.call.set_running()]) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) def test_on_message_state_progress(self): response = dict(state=pr.PROGRESS, progress=1.0) ex = self.executor() - ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) + ex._requests_cache.set(self.task_uuid, self.request_mock) ex._on_message(response, self.message_mock) - self.assertEqual(self.remote_task_mock.mock_calls, + self.assertEqual(self.request_mock.mock_calls, [mock.call.on_progress(progress=1.0)]) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) @@ -127,11 +126,11 @@ class TestWorkerTaskExecutor(test.MockTestCase): failure_dict = pu.failure_to_dict(failure) response = dict(state=pr.FAILURE, result=failure_dict) ex = self.executor() - ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) + ex._requests_cache.set(self.task_uuid, self.request_mock) ex._on_message(response, self.message_mock) - self.assertEqual(len(ex._remote_tasks_cache._data), 0) - self.assertEqual(self.remote_task_mock.mock_calls, [ + self.assertEqual(len(ex._requests_cache._data), 0) + self.assertEqual(self.request_mock.mock_calls, [ mock.call.set_result(result=utils.FailureMatcher(failure)) ]) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) @@ -140,10 +139,10 @@ class TestWorkerTaskExecutor(test.MockTestCase): response = dict(state=pr.SUCCESS, result=self.task_result, event='executed') ex = self.executor() - ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) + ex._requests_cache.set(self.task_uuid, self.request_mock) ex._on_message(response, self.message_mock) - self.assertEqual(self.remote_task_mock.mock_calls, + self.assertEqual(self.request_mock.mock_calls, [mock.call.set_result(result=self.task_result, event='executed')]) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) @@ -151,30 +150,30 @@ class TestWorkerTaskExecutor(test.MockTestCase): def test_on_message_unknown_state(self): response = dict(state='unknown') ex = self.executor() - ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) + ex._requests_cache.set(self.task_uuid, self.request_mock) ex._on_message(response, self.message_mock) - self.assertEqual(self.remote_task_mock.mock_calls, []) + self.assertEqual(self.request_mock.mock_calls, []) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) def test_on_message_non_existent_task(self): self.message_mock.properties = {'correlation_id': 'non-existent'} response = dict(state=pr.RUNNING) ex = self.executor() - ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) + ex._requests_cache.set(self.task_uuid, self.request_mock) ex._on_message(response, self.message_mock) - self.assertEqual(self.remote_task_mock.mock_calls, []) + self.assertEqual(self.request_mock.mock_calls, []) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) def test_on_message_no_correlation_id(self): self.message_mock.properties = {} response = dict(state=pr.RUNNING) ex = self.executor() - ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock) + ex._requests_cache.set(self.task_uuid, self.request_mock) ex._on_message(response, self.message_mock) - self.assertEqual(self.remote_task_mock.mock_calls, []) + self.assertEqual(self.request_mock.mock_calls, []) self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()]) @mock.patch('taskflow.engines.worker_based.executor.LOG.exception') @@ -183,46 +182,46 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.executor()._on_message({}, self.message_mock) self.assertTrue(mocked_exception.called) - @mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock') - def test_on_wait_task_not_expired(self, mocked_time): - mocked_time.side_effect = [1, self.timeout] + @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') + def test_on_wait_task_not_expired(self, mocked_wallclock): + mocked_wallclock.side_effect = [1, self.timeout] ex = self.executor() - ex._remote_tasks_cache.set(self.task_uuid, self.remote_task()) + ex._requests_cache.set(self.task_uuid, self.request()) - self.assertEqual(len(ex._remote_tasks_cache._data), 1) + self.assertEqual(len(ex._requests_cache._data), 1) ex._on_wait() - self.assertEqual(len(ex._remote_tasks_cache._data), 1) + self.assertEqual(len(ex._requests_cache._data), 1) - @mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock') + @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') def test_on_wait_task_expired(self, mocked_time): mocked_time.side_effect = [1, self.timeout + 2, self.timeout * 2] ex = self.executor() - ex._remote_tasks_cache.set(self.task_uuid, self.remote_task()) + ex._requests_cache.set(self.task_uuid, self.request()) - self.assertEqual(len(ex._remote_tasks_cache._data), 1) + self.assertEqual(len(ex._requests_cache._data), 1) ex._on_wait() - self.assertEqual(len(ex._remote_tasks_cache._data), 0) + self.assertEqual(len(ex._requests_cache._data), 0) def test_remove_task_non_existent(self): - task = self.remote_task() + task = self.request() ex = self.executor() - ex._remote_tasks_cache.set(self.task_uuid, task) + ex._requests_cache.set(self.task_uuid, task) - self.assertEqual(len(ex._remote_tasks_cache._data), 1) - ex._remote_tasks_cache.delete(self.task_uuid) - self.assertEqual(len(ex._remote_tasks_cache._data), 0) + self.assertEqual(len(ex._requests_cache._data), 1) + ex._requests_cache.delete(self.task_uuid) + self.assertEqual(len(ex._requests_cache._data), 0) # remove non-existent - ex._remote_tasks_cache.delete(self.task_uuid) - self.assertEqual(len(ex._remote_tasks_cache._data), 0) + ex._requests_cache.delete(self.task_uuid) + self.assertEqual(len(ex._requests_cache._data), 0) def test_execute_task(self): - request = self.request(action='execute') + request_dict = self.request_dict(action='execute') ex = self.executor() result = ex.execute_task(self.task, self.task_uuid, self.task_args) expected_calls = [ - mock.call.proxy.publish(request, + mock.call.proxy.publish(request_dict, routing_key=self.executor_topic, reply_to=self.executor_uuid, correlation_id=self.task_uuid) @@ -231,15 +230,15 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.assertIsInstance(result, futures.Future) def test_revert_task(self): - request = self.request(action='revert', - result=('success', self.task_result), - failures=self.task_failures) + request_dict = self.request_dict(action='revert', + result=('success', self.task_result), + failures=self.task_failures) ex = self.executor() result = ex.revert_task(self.task, self.task_uuid, self.task_args, self.task_result, self.task_failures) expected_calls = [ - mock.call.proxy.publish(request, + mock.call.proxy.publish(request_dict, routing_key=self.executor_topic, reply_to=self.executor_uuid, correlation_id=self.task_uuid) @@ -262,12 +261,12 @@ class TestWorkerTaskExecutor(test.MockTestCase): def test_execute_task_publish_error(self): self.proxy_inst_mock.publish.side_effect = Exception('Woot!') - request = self.request(action='execute') + request_dict = self.request_dict(action='execute') ex = self.executor() result = ex.execute_task(self.task, self.task_uuid, self.task_args) expected_calls = [ - mock.call.proxy.publish(request, + mock.call.proxy.publish(request_dict, routing_key=self.executor_topic, reply_to=self.executor_uuid, correlation_id=self.task_uuid) diff --git a/taskflow/tests/unit/worker_based/test_protocol.py b/taskflow/tests/unit/worker_based/test_protocol.py new file mode 100644 index 000000000..ae805621d --- /dev/null +++ b/taskflow/tests/unit/worker_based/test_protocol.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- + +# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock + +from concurrent import futures + +from taskflow.engines.worker_based import protocol as pr +from taskflow import test +from taskflow.tests import utils +from taskflow.utils import misc +from taskflow.utils import persistence_utils as pu + + +class TestProtocol(test.TestCase): + + def setUp(self): + super(TestProtocol, self).setUp() + self.task = utils.DummyTask() + self.task_uuid = 'task-uuid' + self.task_action = 'execute' + self.task_args = {'a': 'a'} + self.timeout = 60 + + def request(self, **kwargs): + request_kwargs = dict(task=self.task, + uuid=self.task_uuid, + action=self.task_action, + arguments=self.task_args, + progress_callback=None, + timeout=self.timeout) + request_kwargs.update(kwargs) + return pr.Request(**request_kwargs) + + def request_to_dict(self, **kwargs): + to_dict = dict(task_cls=self.task.name, + task_name=self.task.name, + task_version=self.task.version, + action=self.task_action, + arguments=self.task_args) + to_dict.update(kwargs) + return to_dict + + def test_creation(self): + request = self.request() + self.assertEqual(request.uuid, self.task_uuid) + self.assertEqual(request.task_cls, self.task.name) + self.assertIsInstance(request.result, futures.Future) + self.assertFalse(request.result.done()) + + def test_repr(self): + expected_name = '%s:%s' % (self.task.name, self.task_action) + self.assertEqual(repr(self.request()), expected_name) + + def test_to_dict_default(self): + self.assertEqual(self.request().to_dict(), self.request_to_dict()) + + def test_to_dict_with_result(self): + self.assertEqual(self.request(result=333).to_dict(), + self.request_to_dict(result=('success', 333))) + + def test_to_dict_with_result_none(self): + self.assertEqual(self.request(result=None).to_dict(), + self.request_to_dict(result=('success', None))) + + def test_to_dict_with_result_failure(self): + failure = misc.Failure.from_exception(RuntimeError('Woot!')) + expected = self.request_to_dict( + result=('failure', pu.failure_to_dict(failure))) + self.assertEqual(self.request(result=failure).to_dict(), expected) + + def test_to_dict_with_failures(self): + failure = misc.Failure.from_exception(RuntimeError('Woot!')) + request = self.request(failures={self.task.name: failure}) + expected = self.request_to_dict( + failures={self.task.name: pu.failure_to_dict(failure)}) + self.assertEqual(request.to_dict(), expected) + + @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') + def test_pending_not_expired(self, mocked_wallclock): + mocked_wallclock.side_effect = [1, self.timeout] + self.assertFalse(self.request().expired) + + @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') + def test_pending_expired(self, mocked_wallclock): + mocked_wallclock.side_effect = [1, self.timeout + 2] + self.assertTrue(self.request().expired) + + @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') + def test_running_not_expired(self, mocked_wallclock): + mocked_wallclock.side_effect = [1, self.timeout + 2] + request = self.request() + request.set_running() + self.assertFalse(request.expired) + + def test_set_result(self): + request = self.request() + request.set_result(111) + result = request.result.result() + self.assertEqual(result, (self.task, 'executed', 111)) + + def test_on_progress(self): + progress_callback = mock.MagicMock(name='progress_callback') + request = self.request(task=self.task, + progress_callback=progress_callback) + request.on_progress('event_data', 0.0) + request.on_progress('event_data', 1.0) + + expected_calls = [ + mock.call(self.task, 'event_data', 0.0), + mock.call(self.task, 'event_data', 1.0) + ] + self.assertEqual(progress_callback.mock_calls, expected_calls) diff --git a/taskflow/tests/unit/worker_based/test_remote_task.py b/taskflow/tests/unit/worker_based/test_remote_task.py deleted file mode 100644 index 24345189d..000000000 --- a/taskflow/tests/unit/worker_based/test_remote_task.py +++ /dev/null @@ -1,133 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import mock - -from concurrent import futures - -from taskflow.engines.worker_based import remote_task as rt -from taskflow import test -from taskflow.tests import utils -from taskflow.utils import misc -from taskflow.utils import persistence_utils as pu - - -class TestRemoteTask(test.TestCase): - - def setUp(self): - super(TestRemoteTask, self).setUp() - self.task = utils.DummyTask() - self.task_uuid = 'task-uuid' - self.task_action = 'execute' - self.task_args = {'context': 'context'} - self.timeout = 60 - - def remote_task(self, **kwargs): - task_kwargs = dict(task=self.task, - uuid=self.task_uuid, - action=self.task_action, - arguments=self.task_args, - progress_callback=None, - timeout=self.timeout) - task_kwargs.update(kwargs) - return rt.RemoteTask(**task_kwargs) - - def remote_task_request(self, **kwargs): - request = dict(task=self.task.name, - task_name=self.task.name, - task_version=self.task.version, - action=self.task_action, - arguments=self.task_args) - request.update(kwargs) - return request - - def test_creation(self): - remote_task = self.remote_task() - self.assertEqual(remote_task.uuid, self.task_uuid) - self.assertEqual(remote_task.name, self.task.name) - self.assertIsInstance(remote_task.result, futures.Future) - self.assertFalse(remote_task.result.done()) - - def test_repr(self): - expected_name = '%s:%s' % (self.task.name, self.task_action) - self.assertEqual(repr(self.remote_task()), expected_name) - - def test_request(self): - remote_task = self.remote_task() - request = self.remote_task_request() - self.assertEqual(remote_task.request, request) - - def test_request_with_result(self): - remote_task = self.remote_task(result=333) - request = self.remote_task_request(result=('success', 333)) - self.assertEqual(remote_task.request, request) - - def test_request_with_result_none(self): - remote_task = self.remote_task(result=None) - request = self.remote_task_request(result=('success', None)) - self.assertEqual(remote_task.request, request) - - def test_request_with_result_failure(self): - failure = misc.Failure.from_exception(RuntimeError('Woot!')) - remote_task = self.remote_task(result=failure) - request = self.remote_task_request( - result=('failure', pu.failure_to_dict(failure))) - self.assertEqual(remote_task.request, request) - - def test_request_with_failures(self): - failure = misc.Failure.from_exception(RuntimeError('Woot!')) - remote_task = self.remote_task(failures={self.task.name: failure}) - request = self.remote_task_request( - failures={self.task.name: pu.failure_to_dict(failure)}) - self.assertEqual(remote_task.request, request) - - @mock.patch('time.time') - def test_pending_not_expired(self, mock_time): - mock_time.side_effect = [1, self.timeout] - remote_task = self.remote_task() - self.assertFalse(remote_task.expired) - - @mock.patch('time.time') - def test_pending_expired(self, mock_time): - mock_time.side_effect = [1, self.timeout + 2] - remote_task = self.remote_task() - self.assertTrue(remote_task.expired) - - @mock.patch('time.time') - def test_running_not_expired(self, mock_time): - mock_time.side_effect = [1, self.timeout] - remote_task = self.remote_task() - remote_task.set_running() - self.assertFalse(remote_task.expired) - - def test_set_result(self): - remote_task = self.remote_task() - remote_task.set_result(111) - result = remote_task.result.result() - self.assertEqual(result, (self.task, 'executed', 111)) - - def test_on_progress(self): - progress_callback = mock.MagicMock(name='progress_callback') - remote_task = self.remote_task(task=self.task, - progress_callback=progress_callback) - remote_task.on_progress('event_data', 0.0) - remote_task.on_progress('event_data', 1.0) - - expected_calls = [ - mock.call(self.task, 'event_data', 0.0), - mock.call(self.task, 'event_data', 1.0) - ] - self.assertEqual(progress_callback.mock_calls, expected_calls) diff --git a/taskflow/tests/unit/worker_based/test_server.py b/taskflow/tests/unit/worker_based/test_server.py index 3d2df789a..af7abc05d 100644 --- a/taskflow/tests/unit/worker_based/test_server.py +++ b/taskflow/tests/unit/worker_based/test_server.py @@ -71,7 +71,7 @@ class TestServer(test.MockTestCase): return s def request(self, **kwargs): - request = dict(task=self.task_name, + request = dict(task_cls=self.task_name, task_name=self.task_name, action=self.task_action, task_version=self.task_version, @@ -164,18 +164,18 @@ class TestServer(test.MockTestCase): def test_parse_request(self): request = self.request() - task, action, task_args = server.Server._parse_request(**request) + task_cls, action, task_args = server.Server._parse_request(**request) - self.assertEqual((task, action, task_args), + self.assertEqual((task_cls, action, task_args), (self.task_name, self.task_action, dict(task_name=self.task_name, arguments=self.task_args))) def test_parse_request_with_success_result(self): request = self.request(action='revert', result=('success', 1)) - task, action, task_args = server.Server._parse_request(**request) + task_cls, action, task_args = server.Server._parse_request(**request) - self.assertEqual((task, action, task_args), + self.assertEqual((task_cls, action, task_args), (self.task_name, 'revert', dict(task_name=self.task_name, arguments=self.task_args, @@ -186,9 +186,9 @@ class TestServer(test.MockTestCase): failure_dict = pu.failure_to_dict(failure) request = self.request(action='revert', result=('failure', failure_dict)) - task, action, task_args = server.Server._parse_request(**request) + task_cls, action, task_args = server.Server._parse_request(**request) - self.assertEqual((task, action, task_args), + self.assertEqual((task_cls, action, task_args), (self.task_name, 'revert', dict(task_name=self.task_name, arguments=self.task_args, @@ -200,10 +200,10 @@ class TestServer(test.MockTestCase): failures_dict = dict((str(i), pu.failure_to_dict(f)) for i, f in enumerate(failures)) request = self.request(action='revert', failures=failures_dict) - task, action, task_args = server.Server._parse_request(**request) + task_cls, action, task_args = server.Server._parse_request(**request) self.assertEqual( - (task, action, task_args), + (task_cls, action, task_args), (self.task_name, 'revert', dict(task_name=self.task_name, arguments=self.task_args, @@ -225,7 +225,7 @@ class TestServer(test.MockTestCase): self.assertTrue(mocked_exception.called) def test_on_update_progress(self): - request = self.request(task='taskflow.tests.utils.ProgressingTask', + request = self.request(task_cls='taskflow.tests.utils.ProgressingTask', arguments={}) # create server and process request @@ -292,7 +292,7 @@ class TestServer(test.MockTestCase): def test_process_request_endpoint_not_found(self, pu_mock): failure_dict = 'failure_dict' pu_mock.failure_to_dict.return_value = failure_dict - request = self.request(task='') + request = self.request(task_cls='') # create server and process request s = self.server(reset_master_mock=True, endpoints=self.endpoints)