Rename remote task to request

* Renamed remote task -> request and moved to protocol.py;
* Used `to_dict` method instead of `request` property for request objects;
* Renamed `name` request property to `task_cls`;
* Corrected unit tests.

Change-Id: I6133748ab5064391480f031971c38a56cb7f4f9f
This commit is contained in:
Stanislav Kudriashev
2014-03-11 19:04:29 +02:00
parent 8d3c00ccfd
commit af4e009064
8 changed files with 303 additions and 330 deletions

View File

@@ -25,7 +25,6 @@ from taskflow.engines.action_engine import executor
from taskflow.engines.worker_based import cache
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import proxy
from taskflow.engines.worker_based import remote_task as rt
from taskflow import exceptions as exc
from taskflow.utils import async_utils
from taskflow.utils import misc
@@ -42,7 +41,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
self._proxy = proxy.Proxy(uuid, exchange, self._on_message,
self._on_wait, **kwargs)
self._proxy_thread = None
self._remote_tasks_cache = cache.Cache()
self._requests_cache = cache.Cache()
# TODO(skudriashev): This data should be collected from workers
# using broadcast messages directly.
@@ -80,61 +79,61 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
def _process_response(self, task_uuid, response):
"""Process response from remote side."""
remote_task = self._remote_tasks_cache.get(task_uuid)
if remote_task is not None:
request = self._requests_cache.get(task_uuid)
if request is not None:
state = response.pop('state')
if state == pr.RUNNING:
remote_task.set_running()
request.set_running()
elif state == pr.PROGRESS:
remote_task.on_progress(**response)
request.on_progress(**response)
elif state == pr.FAILURE:
response['result'] = pu.failure_from_dict(response['result'])
remote_task.set_result(**response)
self._remote_tasks_cache.delete(remote_task.uuid)
request.set_result(**response)
self._requests_cache.delete(request.uuid)
elif state == pr.SUCCESS:
remote_task.set_result(**response)
self._remote_tasks_cache.delete(remote_task.uuid)
request.set_result(**response)
self._requests_cache.delete(request.uuid)
else:
LOG.warning("Unexpected response status: '%s'", state)
else:
LOG.debug("Remote task with id='%s' not found.", task_uuid)
LOG.debug("Request with id='%s' not found.", task_uuid)
@staticmethod
def _handle_expired_remote_task(task):
LOG.debug("Remote task '%r' has expired.", task)
task.set_result(misc.Failure.from_exception(
exc.Timeout("Remote task '%r' has expired" % task)))
def _handle_expired_request(request):
LOG.debug("Request '%r' has expired.", request)
request.set_result(misc.Failure.from_exception(
exc.Timeout("Request '%r' has expired" % request)))
def _on_wait(self):
"""This function is called cyclically between draining events."""
self._remote_tasks_cache.cleanup(self._handle_expired_remote_task)
self._requests_cache.cleanup(self._handle_expired_request)
def _submit_task(self, task, task_uuid, action, arguments,
progress_callback, timeout=pr.REQUEST_TIMEOUT, **kwargs):
"""Submit task request to workers."""
remote_task = rt.RemoteTask(task, task_uuid, action, arguments,
progress_callback, timeout, **kwargs)
self._remote_tasks_cache.set(remote_task.uuid, remote_task)
request = pr.Request(task, task_uuid, action, arguments,
progress_callback, timeout, **kwargs)
self._requests_cache.set(request.uuid, request)
try:
# get task's workers topic to send request to
try:
topic = self._workers_info[remote_task.name]
topic = self._workers_info[request.task_cls]
except KeyError:
raise exc.NotFound("Workers topic not found for the '%s'"
" task" % remote_task.name)
" task" % request.task_cls)
else:
# publish request
request = remote_task.request
LOG.debug("Sending request: %s", request)
self._proxy.publish(request, routing_key=topic,
self._proxy.publish(request.to_dict(),
routing_key=topic,
reply_to=self._uuid,
correlation_id=remote_task.uuid)
correlation_id=request.uuid)
except Exception:
with misc.capture_failure() as failure:
LOG.exception("Failed to submit the '%s' task", remote_task)
self._remote_tasks_cache.delete(remote_task.uuid)
remote_task.set_result(failure)
return remote_task.result
LOG.exception("Failed to submit the '%s' task", request)
self._requests_cache.delete(request.uuid)
request.set_result(failure)
return request.result
def execute_task(self, task, task_uuid, arguments,
progress_callback=None):

View File

@@ -14,7 +14,12 @@
# License for the specific language governing permissions and limitations
# under the License.
from concurrent import futures
from taskflow.engines.action_engine import executor
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu
from taskflow.utils import reflection
# NOTE(skudriashev): This is protocol events, not related to the task states.
PENDING = 'PENDING'
@@ -43,3 +48,80 @@ REQUEST_TIMEOUT = 60
# period is equal to the request timeout, once request is expired - queue is
# no longer needed.
QUEUE_EXPIRE_TIMEOUT = REQUEST_TIMEOUT
class Request(object):
"""Represents request with execution results. Every request is created in
the PENDING state and is expired within the given timeout.
"""
def __init__(self, task, uuid, action, arguments, progress_callback,
timeout, **kwargs):
self._task = task
self._task_cls = reflection.get_class_name(task)
self._uuid = uuid
self._action = action
self._event = ACTION_TO_EVENT[action]
self._arguments = arguments
self._progress_callback = progress_callback
self._kwargs = kwargs
self._watch = misc.StopWatch(duration=timeout).start()
self._state = PENDING
self.result = futures.Future()
def __repr__(self):
return "%s:%s" % (self._task_cls, self._action)
@property
def uuid(self):
return self._uuid
@property
def task_cls(self):
return self._task_cls
@property
def expired(self):
"""Check if request is expired.
When new request is created its state is set to the PENDING, creation
time is stored and timeout is given via constructor arguments.
Request is considered to be expired when it is in the PENDING state
for more then the given timeout (it is not considered to be expired
in any other state). After request is expired - the `Timeout`
exception is raised and task is removed from the requests map.
"""
if self._state == PENDING:
return self._watch.expired()
return False
def to_dict(self):
"""Return json-serializable request, converting all `misc.Failure`
objects into dictionaries.
"""
request = dict(task_cls=self._task_cls, task_name=self._task.name,
task_version=self._task.version, action=self._action,
arguments=self._arguments)
if 'result' in self._kwargs:
result = self._kwargs['result']
if isinstance(result, misc.Failure):
request['result'] = ('failure', pu.failure_to_dict(result))
else:
request['result'] = ('success', result)
if 'failures' in self._kwargs:
failures = self._kwargs['failures']
request['failures'] = {}
for task, failure in failures.items():
request['failures'][task] = pu.failure_to_dict(failure)
return request
def set_result(self, result):
self.result.set_result((self._task, self._event, result))
def set_running(self):
self._state = RUNNING
self._watch.stop()
def on_progress(self, event_data, progress):
self._progress_callback(self._task, event_data, progress)

View File

@@ -1,101 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from concurrent import futures
from taskflow.engines.worker_based import protocol as pr
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu
from taskflow.utils import reflection
class RemoteTask(object):
"""Represents remote task with its request data and execution results.
Every remote task is created in the PENDING state and will be expired
within the given timeout.
"""
def __init__(self, task, uuid, action, arguments, progress_callback,
timeout, **kwargs):
self._task = task
self._name = reflection.get_class_name(task)
self._uuid = uuid
self._action = action
self._event = pr.ACTION_TO_EVENT[action]
self._arguments = arguments
self._progress_callback = progress_callback
self._kwargs = kwargs
self._watch = misc.StopWatch(duration=timeout).start()
self._state = pr.PENDING
self.result = futures.Future()
def __repr__(self):
return "%s:%s" % (self._name, self._action)
@property
def uuid(self):
return self._uuid
@property
def name(self):
return self._name
@property
def request(self):
"""Return json-serializable task request, converting all `misc.Failure`
objects into dictionaries.
"""
request = dict(task=self._name, task_name=self._task.name,
task_version=self._task.version, action=self._action,
arguments=self._arguments)
if 'result' in self._kwargs:
result = self._kwargs['result']
if isinstance(result, misc.Failure):
request['result'] = ('failure', pu.failure_to_dict(result))
else:
request['result'] = ('success', result)
if 'failures' in self._kwargs:
failures = self._kwargs['failures']
request['failures'] = {}
for task, failure in failures.items():
request['failures'][task] = pu.failure_to_dict(failure)
return request
@property
def expired(self):
"""Check if task is expired.
When new remote task is created its state is set to the PENDING,
creation time is stored and timeout is given via constructor arguments.
Remote task is considered to be expired when it is in the PENDING state
for more then the given timeout (task is not considered to be expired
in any other state). After remote task is expired - the `Timeout`
exception is raised and task is removed from the remote tasks map.
"""
if self._state == pr.PENDING:
return self._watch.expired()
return False
def set_result(self, result):
self.result.set_result((self._task, self._event, result))
def set_running(self):
self._state = pr.RUNNING
self._watch.stop()
def on_progress(self, event_data, progress):
self._progress_callback(self._task, event_data, progress)

View File

@@ -63,7 +63,7 @@ class Server(object):
LOG.debug("AMQP message requeued.")
@staticmethod
def _parse_request(task, task_name, action, arguments, result=None,
def _parse_request(task_cls, task_name, action, arguments, result=None,
failures=None, **kwargs):
"""Parse request before it can be processed. All `misc.Failure` objects
that have been converted to dict on the remote side to be serializable
@@ -80,7 +80,7 @@ class Server(object):
action_args['failures'] = {}
for k, v in failures.items():
action_args['failures'][k] = pu.failure_from_dict(v)
return task, action, action_args
return task_cls, action, action_args
@staticmethod
def _parse_message(message):
@@ -132,7 +132,7 @@ class Server(object):
# parse request to get task name, action and action arguments
try:
task, action, action_args = self._parse_request(**request)
task_cls, action, action_args = self._parse_request(**request)
action_args.update(task_uuid=task_uuid,
progress_callback=progress_callback)
except ValueError:
@@ -143,10 +143,11 @@ class Server(object):
# get task endpoint
try:
endpoint = self._endpoints[task]
endpoint = self._endpoints[task_cls]
except KeyError:
with misc.capture_failure() as failure:
LOG.exception("The '%s' task endpoint does not exist", task)
LOG.exception("The '%s' task endpoint does not exist",
task_cls)
reply_callback(result=pu.failure_to_dict(failure))
return
else:

View File

@@ -23,7 +23,6 @@ from kombu import exceptions as kombu_exc
from taskflow.engines.worker_based import executor
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import remote_task as rt
from taskflow import test
from taskflow.tests import utils
from taskflow.utils import misc
@@ -36,11 +35,11 @@ class TestWorkerTaskExecutor(test.MockTestCase):
super(TestWorkerTaskExecutor, self).setUp()
self.task = utils.DummyTask()
self.task_uuid = 'task-uuid'
self.task_args = {'context': 'context'}
self.task_args = {'a': 'a'}
self.task_result = 'task-result'
self.task_failures = {}
self.timeout = 60
self.broker_url = 'test-url'
self.broker_url = 'broker-url'
self.executor_uuid = 'executor-uuid'
self.executor_exchange = 'executor-exchange'
self.executor_topic = 'executor-topic'
@@ -58,7 +57,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
'taskflow.engines.worker_based.executor.async_utils.wait_for_any')
self.message_mock = mock.MagicMock(name='message')
self.message_mock.properties = {'correlation_id': self.task_uuid}
self.remote_task_mock = mock.MagicMock(uuid=self.task_uuid)
self.request_mock = mock.MagicMock(uuid=self.task_uuid)
def _fake_proxy_start(self):
self.proxy_started_event.set()
@@ -80,19 +79,19 @@ class TestWorkerTaskExecutor(test.MockTestCase):
return ex
def request(self, **kwargs):
request = dict(task=self.task.name, task_name=self.task.name,
request_kwargs = dict(task=self.task, uuid=self.task_uuid,
action='execute', arguments=self.task_args,
progress_callback=None, timeout=self.timeout)
request_kwargs.update(kwargs)
return pr.Request(**request_kwargs)
def request_dict(self, **kwargs):
request = dict(task_cls=self.task.name, task_name=self.task.name,
task_version=self.task.version,
arguments=self.task_args)
request.update(kwargs)
return request
def remote_task(self, **kwargs):
remote_task_kwargs = dict(task=self.task, uuid=self.task_uuid,
action='execute', arguments=self.task_args,
progress_callback=None, timeout=self.timeout)
remote_task_kwargs.update(kwargs)
return rt.RemoteTask(**remote_task_kwargs)
def test_creation(self):
ex = self.executor(reset_master_mock=False)
@@ -105,20 +104,20 @@ class TestWorkerTaskExecutor(test.MockTestCase):
def test_on_message_state_running(self):
response = dict(state=pr.RUNNING)
ex = self.executor()
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls,
self.assertEqual(self.request_mock.mock_calls,
[mock.call.set_running()])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_state_progress(self):
response = dict(state=pr.PROGRESS, progress=1.0)
ex = self.executor()
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls,
self.assertEqual(self.request_mock.mock_calls,
[mock.call.on_progress(progress=1.0)])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
@@ -127,11 +126,11 @@ class TestWorkerTaskExecutor(test.MockTestCase):
failure_dict = pu.failure_to_dict(failure)
response = dict(state=pr.FAILURE, result=failure_dict)
ex = self.executor()
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
self.assertEqual(self.remote_task_mock.mock_calls, [
self.assertEqual(len(ex._requests_cache._data), 0)
self.assertEqual(self.request_mock.mock_calls, [
mock.call.set_result(result=utils.FailureMatcher(failure))
])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
@@ -140,10 +139,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
response = dict(state=pr.SUCCESS, result=self.task_result,
event='executed')
ex = self.executor()
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls,
self.assertEqual(self.request_mock.mock_calls,
[mock.call.set_result(result=self.task_result,
event='executed')])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
@@ -151,30 +150,30 @@ class TestWorkerTaskExecutor(test.MockTestCase):
def test_on_message_unknown_state(self):
response = dict(state='unknown')
ex = self.executor()
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls, [])
self.assertEqual(self.request_mock.mock_calls, [])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_non_existent_task(self):
self.message_mock.properties = {'correlation_id': 'non-existent'}
response = dict(state=pr.RUNNING)
ex = self.executor()
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls, [])
self.assertEqual(self.request_mock.mock_calls, [])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_no_correlation_id(self):
self.message_mock.properties = {}
response = dict(state=pr.RUNNING)
ex = self.executor()
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
self.assertEqual(self.remote_task_mock.mock_calls, [])
self.assertEqual(self.request_mock.mock_calls, [])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
@mock.patch('taskflow.engines.worker_based.executor.LOG.exception')
@@ -183,46 +182,46 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.executor()._on_message({}, self.message_mock)
self.assertTrue(mocked_exception.called)
@mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
def test_on_wait_task_not_expired(self, mocked_time):
mocked_time.side_effect = [1, self.timeout]
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
def test_on_wait_task_not_expired(self, mocked_wallclock):
mocked_wallclock.side_effect = [1, self.timeout]
ex = self.executor()
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task())
ex._requests_cache.set(self.task_uuid, self.request())
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
self.assertEqual(len(ex._requests_cache._data), 1)
ex._on_wait()
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
self.assertEqual(len(ex._requests_cache._data), 1)
@mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
def test_on_wait_task_expired(self, mocked_time):
mocked_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
ex = self.executor()
ex._remote_tasks_cache.set(self.task_uuid, self.remote_task())
ex._requests_cache.set(self.task_uuid, self.request())
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
self.assertEqual(len(ex._requests_cache._data), 1)
ex._on_wait()
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
self.assertEqual(len(ex._requests_cache._data), 0)
def test_remove_task_non_existent(self):
task = self.remote_task()
task = self.request()
ex = self.executor()
ex._remote_tasks_cache.set(self.task_uuid, task)
ex._requests_cache.set(self.task_uuid, task)
self.assertEqual(len(ex._remote_tasks_cache._data), 1)
ex._remote_tasks_cache.delete(self.task_uuid)
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
self.assertEqual(len(ex._requests_cache._data), 1)
ex._requests_cache.delete(self.task_uuid)
self.assertEqual(len(ex._requests_cache._data), 0)
# remove non-existent
ex._remote_tasks_cache.delete(self.task_uuid)
self.assertEqual(len(ex._remote_tasks_cache._data), 0)
ex._requests_cache.delete(self.task_uuid)
self.assertEqual(len(ex._requests_cache._data), 0)
def test_execute_task(self):
request = self.request(action='execute')
request_dict = self.request_dict(action='execute')
ex = self.executor()
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
expected_calls = [
mock.call.proxy.publish(request,
mock.call.proxy.publish(request_dict,
routing_key=self.executor_topic,
reply_to=self.executor_uuid,
correlation_id=self.task_uuid)
@@ -231,15 +230,15 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.assertIsInstance(result, futures.Future)
def test_revert_task(self):
request = self.request(action='revert',
result=('success', self.task_result),
failures=self.task_failures)
request_dict = self.request_dict(action='revert',
result=('success', self.task_result),
failures=self.task_failures)
ex = self.executor()
result = ex.revert_task(self.task, self.task_uuid, self.task_args,
self.task_result, self.task_failures)
expected_calls = [
mock.call.proxy.publish(request,
mock.call.proxy.publish(request_dict,
routing_key=self.executor_topic,
reply_to=self.executor_uuid,
correlation_id=self.task_uuid)
@@ -262,12 +261,12 @@ class TestWorkerTaskExecutor(test.MockTestCase):
def test_execute_task_publish_error(self):
self.proxy_inst_mock.publish.side_effect = Exception('Woot!')
request = self.request(action='execute')
request_dict = self.request_dict(action='execute')
ex = self.executor()
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
expected_calls = [
mock.call.proxy.publish(request,
mock.call.proxy.publish(request_dict,
routing_key=self.executor_topic,
reply_to=self.executor_uuid,
correlation_id=self.task_uuid)

View File

@@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from concurrent import futures
from taskflow.engines.worker_based import protocol as pr
from taskflow import test
from taskflow.tests import utils
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu
class TestProtocol(test.TestCase):
def setUp(self):
super(TestProtocol, self).setUp()
self.task = utils.DummyTask()
self.task_uuid = 'task-uuid'
self.task_action = 'execute'
self.task_args = {'a': 'a'}
self.timeout = 60
def request(self, **kwargs):
request_kwargs = dict(task=self.task,
uuid=self.task_uuid,
action=self.task_action,
arguments=self.task_args,
progress_callback=None,
timeout=self.timeout)
request_kwargs.update(kwargs)
return pr.Request(**request_kwargs)
def request_to_dict(self, **kwargs):
to_dict = dict(task_cls=self.task.name,
task_name=self.task.name,
task_version=self.task.version,
action=self.task_action,
arguments=self.task_args)
to_dict.update(kwargs)
return to_dict
def test_creation(self):
request = self.request()
self.assertEqual(request.uuid, self.task_uuid)
self.assertEqual(request.task_cls, self.task.name)
self.assertIsInstance(request.result, futures.Future)
self.assertFalse(request.result.done())
def test_repr(self):
expected_name = '%s:%s' % (self.task.name, self.task_action)
self.assertEqual(repr(self.request()), expected_name)
def test_to_dict_default(self):
self.assertEqual(self.request().to_dict(), self.request_to_dict())
def test_to_dict_with_result(self):
self.assertEqual(self.request(result=333).to_dict(),
self.request_to_dict(result=('success', 333)))
def test_to_dict_with_result_none(self):
self.assertEqual(self.request(result=None).to_dict(),
self.request_to_dict(result=('success', None)))
def test_to_dict_with_result_failure(self):
failure = misc.Failure.from_exception(RuntimeError('Woot!'))
expected = self.request_to_dict(
result=('failure', pu.failure_to_dict(failure)))
self.assertEqual(self.request(result=failure).to_dict(), expected)
def test_to_dict_with_failures(self):
failure = misc.Failure.from_exception(RuntimeError('Woot!'))
request = self.request(failures={self.task.name: failure})
expected = self.request_to_dict(
failures={self.task.name: pu.failure_to_dict(failure)})
self.assertEqual(request.to_dict(), expected)
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
def test_pending_not_expired(self, mocked_wallclock):
mocked_wallclock.side_effect = [1, self.timeout]
self.assertFalse(self.request().expired)
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
def test_pending_expired(self, mocked_wallclock):
mocked_wallclock.side_effect = [1, self.timeout + 2]
self.assertTrue(self.request().expired)
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
def test_running_not_expired(self, mocked_wallclock):
mocked_wallclock.side_effect = [1, self.timeout + 2]
request = self.request()
request.set_running()
self.assertFalse(request.expired)
def test_set_result(self):
request = self.request()
request.set_result(111)
result = request.result.result()
self.assertEqual(result, (self.task, 'executed', 111))
def test_on_progress(self):
progress_callback = mock.MagicMock(name='progress_callback')
request = self.request(task=self.task,
progress_callback=progress_callback)
request.on_progress('event_data', 0.0)
request.on_progress('event_data', 1.0)
expected_calls = [
mock.call(self.task, 'event_data', 0.0),
mock.call(self.task, 'event_data', 1.0)
]
self.assertEqual(progress_callback.mock_calls, expected_calls)

View File

@@ -1,133 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from concurrent import futures
from taskflow.engines.worker_based import remote_task as rt
from taskflow import test
from taskflow.tests import utils
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu
class TestRemoteTask(test.TestCase):
def setUp(self):
super(TestRemoteTask, self).setUp()
self.task = utils.DummyTask()
self.task_uuid = 'task-uuid'
self.task_action = 'execute'
self.task_args = {'context': 'context'}
self.timeout = 60
def remote_task(self, **kwargs):
task_kwargs = dict(task=self.task,
uuid=self.task_uuid,
action=self.task_action,
arguments=self.task_args,
progress_callback=None,
timeout=self.timeout)
task_kwargs.update(kwargs)
return rt.RemoteTask(**task_kwargs)
def remote_task_request(self, **kwargs):
request = dict(task=self.task.name,
task_name=self.task.name,
task_version=self.task.version,
action=self.task_action,
arguments=self.task_args)
request.update(kwargs)
return request
def test_creation(self):
remote_task = self.remote_task()
self.assertEqual(remote_task.uuid, self.task_uuid)
self.assertEqual(remote_task.name, self.task.name)
self.assertIsInstance(remote_task.result, futures.Future)
self.assertFalse(remote_task.result.done())
def test_repr(self):
expected_name = '%s:%s' % (self.task.name, self.task_action)
self.assertEqual(repr(self.remote_task()), expected_name)
def test_request(self):
remote_task = self.remote_task()
request = self.remote_task_request()
self.assertEqual(remote_task.request, request)
def test_request_with_result(self):
remote_task = self.remote_task(result=333)
request = self.remote_task_request(result=('success', 333))
self.assertEqual(remote_task.request, request)
def test_request_with_result_none(self):
remote_task = self.remote_task(result=None)
request = self.remote_task_request(result=('success', None))
self.assertEqual(remote_task.request, request)
def test_request_with_result_failure(self):
failure = misc.Failure.from_exception(RuntimeError('Woot!'))
remote_task = self.remote_task(result=failure)
request = self.remote_task_request(
result=('failure', pu.failure_to_dict(failure)))
self.assertEqual(remote_task.request, request)
def test_request_with_failures(self):
failure = misc.Failure.from_exception(RuntimeError('Woot!'))
remote_task = self.remote_task(failures={self.task.name: failure})
request = self.remote_task_request(
failures={self.task.name: pu.failure_to_dict(failure)})
self.assertEqual(remote_task.request, request)
@mock.patch('time.time')
def test_pending_not_expired(self, mock_time):
mock_time.side_effect = [1, self.timeout]
remote_task = self.remote_task()
self.assertFalse(remote_task.expired)
@mock.patch('time.time')
def test_pending_expired(self, mock_time):
mock_time.side_effect = [1, self.timeout + 2]
remote_task = self.remote_task()
self.assertTrue(remote_task.expired)
@mock.patch('time.time')
def test_running_not_expired(self, mock_time):
mock_time.side_effect = [1, self.timeout]
remote_task = self.remote_task()
remote_task.set_running()
self.assertFalse(remote_task.expired)
def test_set_result(self):
remote_task = self.remote_task()
remote_task.set_result(111)
result = remote_task.result.result()
self.assertEqual(result, (self.task, 'executed', 111))
def test_on_progress(self):
progress_callback = mock.MagicMock(name='progress_callback')
remote_task = self.remote_task(task=self.task,
progress_callback=progress_callback)
remote_task.on_progress('event_data', 0.0)
remote_task.on_progress('event_data', 1.0)
expected_calls = [
mock.call(self.task, 'event_data', 0.0),
mock.call(self.task, 'event_data', 1.0)
]
self.assertEqual(progress_callback.mock_calls, expected_calls)

View File

@@ -71,7 +71,7 @@ class TestServer(test.MockTestCase):
return s
def request(self, **kwargs):
request = dict(task=self.task_name,
request = dict(task_cls=self.task_name,
task_name=self.task_name,
action=self.task_action,
task_version=self.task_version,
@@ -164,18 +164,18 @@ class TestServer(test.MockTestCase):
def test_parse_request(self):
request = self.request()
task, action, task_args = server.Server._parse_request(**request)
task_cls, action, task_args = server.Server._parse_request(**request)
self.assertEqual((task, action, task_args),
self.assertEqual((task_cls, action, task_args),
(self.task_name, self.task_action,
dict(task_name=self.task_name,
arguments=self.task_args)))
def test_parse_request_with_success_result(self):
request = self.request(action='revert', result=('success', 1))
task, action, task_args = server.Server._parse_request(**request)
task_cls, action, task_args = server.Server._parse_request(**request)
self.assertEqual((task, action, task_args),
self.assertEqual((task_cls, action, task_args),
(self.task_name, 'revert',
dict(task_name=self.task_name,
arguments=self.task_args,
@@ -186,9 +186,9 @@ class TestServer(test.MockTestCase):
failure_dict = pu.failure_to_dict(failure)
request = self.request(action='revert',
result=('failure', failure_dict))
task, action, task_args = server.Server._parse_request(**request)
task_cls, action, task_args = server.Server._parse_request(**request)
self.assertEqual((task, action, task_args),
self.assertEqual((task_cls, action, task_args),
(self.task_name, 'revert',
dict(task_name=self.task_name,
arguments=self.task_args,
@@ -200,10 +200,10 @@ class TestServer(test.MockTestCase):
failures_dict = dict((str(i), pu.failure_to_dict(f))
for i, f in enumerate(failures))
request = self.request(action='revert', failures=failures_dict)
task, action, task_args = server.Server._parse_request(**request)
task_cls, action, task_args = server.Server._parse_request(**request)
self.assertEqual(
(task, action, task_args),
(task_cls, action, task_args),
(self.task_name, 'revert',
dict(task_name=self.task_name,
arguments=self.task_args,
@@ -225,7 +225,7 @@ class TestServer(test.MockTestCase):
self.assertTrue(mocked_exception.called)
def test_on_update_progress(self):
request = self.request(task='taskflow.tests.utils.ProgressingTask',
request = self.request(task_cls='taskflow.tests.utils.ProgressingTask',
arguments={})
# create server and process request
@@ -292,7 +292,7 @@ class TestServer(test.MockTestCase):
def test_process_request_endpoint_not_found(self, pu_mock):
failure_dict = 'failure_dict'
pu_mock.failure_to_dict.return_value = failure_dict
request = self.request(task='<unknown>')
request = self.request(task_cls='<unknown>')
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)