Rename remote task to request
* Renamed remote task -> request and moved to protocol.py;
* Used `to_dict` method instead of `request` property for request objects;
* Renamed `name` request property to `task_cls`;
* Corrected unit tests.

Change-Id: I6133748ab5064391480f031971c38a56cb7f4f9f
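Below is a minimal usage sketch of the renamed API (illustration only, not part of the commit); `StubTask` is a hypothetical stand-in for a real taskflow task:

    from taskflow.engines.worker_based import protocol as pr

    class StubTask(object):
        # Hypothetical placeholder task: Request.to_dict() only needs the
        # task object's `name` and `version` attributes.
        name = 'stub-task'
        version = (1, 0)

    # Build a request the way the executor now does and use the renamed API.
    request = pr.Request(StubTask(), 'some-uuid', 'execute', {'a': 'a'},
                         None, pr.REQUEST_TIMEOUT)
    print(request.task_cls)   # full class name (was the `name` property)
    print(request.to_dict())  # serializable body (was the `request` property)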
@@ -25,7 +25,6 @@ from taskflow.engines.action_engine import executor
 from taskflow.engines.worker_based import cache
 from taskflow.engines.worker_based import protocol as pr
 from taskflow.engines.worker_based import proxy
-from taskflow.engines.worker_based import remote_task as rt
 from taskflow import exceptions as exc
 from taskflow.utils import async_utils
 from taskflow.utils import misc
@@ -42,7 +41,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
         self._proxy = proxy.Proxy(uuid, exchange, self._on_message,
                                   self._on_wait, **kwargs)
         self._proxy_thread = None
-        self._remote_tasks_cache = cache.Cache()
+        self._requests_cache = cache.Cache()
 
         # TODO(skudriashev): This data should be collected from workers
         # using broadcast messages directly.
@@ -80,61 +79,61 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
 
     def _process_response(self, task_uuid, response):
         """Process response from remote side."""
-        remote_task = self._remote_tasks_cache.get(task_uuid)
-        if remote_task is not None:
+        request = self._requests_cache.get(task_uuid)
+        if request is not None:
             state = response.pop('state')
             if state == pr.RUNNING:
-                remote_task.set_running()
+                request.set_running()
             elif state == pr.PROGRESS:
-                remote_task.on_progress(**response)
+                request.on_progress(**response)
             elif state == pr.FAILURE:
                 response['result'] = pu.failure_from_dict(response['result'])
-                remote_task.set_result(**response)
-                self._remote_tasks_cache.delete(remote_task.uuid)
+                request.set_result(**response)
+                self._requests_cache.delete(request.uuid)
             elif state == pr.SUCCESS:
-                remote_task.set_result(**response)
-                self._remote_tasks_cache.delete(remote_task.uuid)
+                request.set_result(**response)
+                self._requests_cache.delete(request.uuid)
             else:
                 LOG.warning("Unexpected response status: '%s'", state)
         else:
-            LOG.debug("Remote task with id='%s' not found.", task_uuid)
+            LOG.debug("Request with id='%s' not found.", task_uuid)
 
     @staticmethod
-    def _handle_expired_remote_task(task):
-        LOG.debug("Remote task '%r' has expired.", task)
-        task.set_result(misc.Failure.from_exception(
-            exc.Timeout("Remote task '%r' has expired" % task)))
+    def _handle_expired_request(request):
+        LOG.debug("Request '%r' has expired.", request)
+        request.set_result(misc.Failure.from_exception(
+            exc.Timeout("Request '%r' has expired" % request)))
 
     def _on_wait(self):
         """This function is called cyclically between draining events."""
-        self._remote_tasks_cache.cleanup(self._handle_expired_remote_task)
+        self._requests_cache.cleanup(self._handle_expired_request)
 
     def _submit_task(self, task, task_uuid, action, arguments,
                      progress_callback, timeout=pr.REQUEST_TIMEOUT, **kwargs):
         """Submit task request to workers."""
-        remote_task = rt.RemoteTask(task, task_uuid, action, arguments,
-                                    progress_callback, timeout, **kwargs)
-        self._remote_tasks_cache.set(remote_task.uuid, remote_task)
+        request = pr.Request(task, task_uuid, action, arguments,
+                             progress_callback, timeout, **kwargs)
+        self._requests_cache.set(request.uuid, request)
         try:
             # get task's workers topic to send request to
             try:
-                topic = self._workers_info[remote_task.name]
+                topic = self._workers_info[request.task_cls]
             except KeyError:
                 raise exc.NotFound("Workers topic not found for the '%s'"
-                                   " task" % remote_task.name)
+                                   " task" % request.task_cls)
             else:
                 # publish request
-                request = remote_task.request
                 LOG.debug("Sending request: %s", request)
-                self._proxy.publish(request, routing_key=topic,
+                self._proxy.publish(request.to_dict(),
+                                    routing_key=topic,
                                     reply_to=self._uuid,
-                                    correlation_id=remote_task.uuid)
+                                    correlation_id=request.uuid)
         except Exception:
             with misc.capture_failure() as failure:
-                LOG.exception("Failed to submit the '%s' task", remote_task)
-                self._remote_tasks_cache.delete(remote_task.uuid)
-                remote_task.set_result(failure)
-                return remote_task.result
+                LOG.exception("Failed to submit the '%s' task", request)
+                self._requests_cache.delete(request.uuid)
+                request.set_result(failure)
+                return request.result
 
     def execute_task(self, task, task_uuid, arguments,
                      progress_callback=None):
@@ -14,7 +14,12 @@
 # License for the specific language governing permissions and limitations
 # under the License.
 
+from concurrent import futures
+
 from taskflow.engines.action_engine import executor
+from taskflow.utils import misc
+from taskflow.utils import persistence_utils as pu
+from taskflow.utils import reflection
 
 # NOTE(skudriashev): This is protocol events, not related to the task states.
 PENDING = 'PENDING'
@@ -43,3 +48,80 @@ REQUEST_TIMEOUT = 60
 # period is equal to the request timeout, once request is expired - queue is
 # no longer needed.
 QUEUE_EXPIRE_TIMEOUT = REQUEST_TIMEOUT
+
+
+class Request(object):
+    """Represents request with execution results. Every request is created in
+    the PENDING state and is expired within the given timeout.
+    """
+
+    def __init__(self, task, uuid, action, arguments, progress_callback,
+                 timeout, **kwargs):
+        self._task = task
+        self._task_cls = reflection.get_class_name(task)
+        self._uuid = uuid
+        self._action = action
+        self._event = ACTION_TO_EVENT[action]
+        self._arguments = arguments
+        self._progress_callback = progress_callback
+        self._kwargs = kwargs
+        self._watch = misc.StopWatch(duration=timeout).start()
+        self._state = PENDING
+        self.result = futures.Future()
+
+    def __repr__(self):
+        return "%s:%s" % (self._task_cls, self._action)
+
+    @property
+    def uuid(self):
+        return self._uuid
+
+    @property
+    def task_cls(self):
+        return self._task_cls
+
+    @property
+    def expired(self):
+        """Check if request is expired.
+
+        When new request is created its state is set to the PENDING, creation
+        time is stored and timeout is given via constructor arguments.
+
+        Request is considered to be expired when it is in the PENDING state
+        for more then the given timeout (it is not considered to be expired
+        in any other state). After request is expired - the `Timeout`
+        exception is raised and task is removed from the requests map.
+        """
+        if self._state == PENDING:
+            return self._watch.expired()
+        return False
+
+    def to_dict(self):
+        """Return json-serializable request, converting all `misc.Failure`
+        objects into dictionaries.
+        """
+        request = dict(task_cls=self._task_cls, task_name=self._task.name,
+                       task_version=self._task.version, action=self._action,
+                       arguments=self._arguments)
+        if 'result' in self._kwargs:
+            result = self._kwargs['result']
+            if isinstance(result, misc.Failure):
+                request['result'] = ('failure', pu.failure_to_dict(result))
+            else:
+                request['result'] = ('success', result)
+        if 'failures' in self._kwargs:
+            failures = self._kwargs['failures']
+            request['failures'] = {}
+            for task, failure in failures.items():
+                request['failures'][task] = pu.failure_to_dict(failure)
+        return request
+
+    def set_result(self, result):
+        self.result.set_result((self._task, self._event, result))
+
+    def set_running(self):
+        self._state = RUNNING
+        self._watch.stop()
+
+    def on_progress(self, event_data, progress):
+        self._progress_callback(self._task, event_data, progress)
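For illustration only (not part of the commit), the body produced by the new Request.to_dict() for a plain execute request has roughly this shape; the task name/version values here are hypothetical:

    request_body = {
        'task_cls': 'mypackage.tasks.MyTask',  # reflection.get_class_name(task)
        'task_name': 'my-task',                # task.name
        'task_version': (1, 0),                # task.version
        'action': 'execute',
        'arguments': {'a': 'a'},
        # when reverting, the request may also carry:
        #   'result': ('success', <result>) or ('failure', <failure dict>)
        #   'failures': {<task name>: <failure dict>, ...}
    }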
@@ -1,101 +0,0 @@
-# -*- coding: utf-8 -*-
-
-# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-from concurrent import futures
-
-from taskflow.engines.worker_based import protocol as pr
-from taskflow.utils import misc
-from taskflow.utils import persistence_utils as pu
-from taskflow.utils import reflection
-
-
-class RemoteTask(object):
-    """Represents remote task with its request data and execution results.
-    Every remote task is created in the PENDING state and will be expired
-    within the given timeout.
-    """
-
-    def __init__(self, task, uuid, action, arguments, progress_callback,
-                 timeout, **kwargs):
-        self._task = task
-        self._name = reflection.get_class_name(task)
-        self._uuid = uuid
-        self._action = action
-        self._event = pr.ACTION_TO_EVENT[action]
-        self._arguments = arguments
-        self._progress_callback = progress_callback
-        self._kwargs = kwargs
-        self._watch = misc.StopWatch(duration=timeout).start()
-        self._state = pr.PENDING
-        self.result = futures.Future()
-
-    def __repr__(self):
-        return "%s:%s" % (self._name, self._action)
-
-    @property
-    def uuid(self):
-        return self._uuid
-
-    @property
-    def name(self):
-        return self._name
-
-    @property
-    def request(self):
-        """Return json-serializable task request, converting all `misc.Failure`
-        objects into dictionaries.
-        """
-        request = dict(task=self._name, task_name=self._task.name,
-                       task_version=self._task.version, action=self._action,
-                       arguments=self._arguments)
-        if 'result' in self._kwargs:
-            result = self._kwargs['result']
-            if isinstance(result, misc.Failure):
-                request['result'] = ('failure', pu.failure_to_dict(result))
-            else:
-                request['result'] = ('success', result)
-        if 'failures' in self._kwargs:
-            failures = self._kwargs['failures']
-            request['failures'] = {}
-            for task, failure in failures.items():
-                request['failures'][task] = pu.failure_to_dict(failure)
-        return request
-
-    @property
-    def expired(self):
-        """Check if task is expired.
-
-        When new remote task is created its state is set to the PENDING,
-        creation time is stored and timeout is given via constructor arguments.
-
-        Remote task is considered to be expired when it is in the PENDING state
-        for more then the given timeout (task is not considered to be expired
-        in any other state). After remote task is expired - the `Timeout`
-        exception is raised and task is removed from the remote tasks map.
-        """
-        if self._state == pr.PENDING:
-            return self._watch.expired()
-        return False
-
-    def set_result(self, result):
-        self.result.set_result((self._task, self._event, result))
-
-    def set_running(self):
-        self._state = pr.RUNNING
-        self._watch.stop()
-
-    def on_progress(self, event_data, progress):
-        self._progress_callback(self._task, event_data, progress)
@@ -63,7 +63,7 @@ class Server(object):
             LOG.debug("AMQP message requeued.")
 
     @staticmethod
-    def _parse_request(task, task_name, action, arguments, result=None,
+    def _parse_request(task_cls, task_name, action, arguments, result=None,
                        failures=None, **kwargs):
         """Parse request before it can be processed. All `misc.Failure` objects
         that have been converted to dict on the remote side to be serializable
@@ -80,7 +80,7 @@ class Server(object):
             action_args['failures'] = {}
             for k, v in failures.items():
                 action_args['failures'][k] = pu.failure_from_dict(v)
-        return task, action, action_args
+        return task_cls, action, action_args
 
     @staticmethod
    def _parse_message(message):
@@ -132,7 +132,7 @@ class Server(object):
 
         # parse request to get task name, action and action arguments
         try:
-            task, action, action_args = self._parse_request(**request)
+            task_cls, action, action_args = self._parse_request(**request)
             action_args.update(task_uuid=task_uuid,
                                progress_callback=progress_callback)
         except ValueError:
@@ -143,10 +143,11 @@ class Server(object):
 
         # get task endpoint
         try:
-            endpoint = self._endpoints[task]
+            endpoint = self._endpoints[task_cls]
         except KeyError:
             with misc.capture_failure() as failure:
-                LOG.exception("The '%s' task endpoint does not exist", task)
+                LOG.exception("The '%s' task endpoint does not exist",
+                              task_cls)
                 reply_callback(result=pu.failure_to_dict(failure))
                 return
         else:
@@ -23,7 +23,6 @@ from kombu import exceptions as kombu_exc
 
 from taskflow.engines.worker_based import executor
 from taskflow.engines.worker_based import protocol as pr
-from taskflow.engines.worker_based import remote_task as rt
 from taskflow import test
 from taskflow.tests import utils
 from taskflow.utils import misc
@@ -36,11 +35,11 @@ class TestWorkerTaskExecutor(test.MockTestCase):
         super(TestWorkerTaskExecutor, self).setUp()
         self.task = utils.DummyTask()
         self.task_uuid = 'task-uuid'
-        self.task_args = {'context': 'context'}
+        self.task_args = {'a': 'a'}
         self.task_result = 'task-result'
         self.task_failures = {}
         self.timeout = 60
-        self.broker_url = 'test-url'
+        self.broker_url = 'broker-url'
         self.executor_uuid = 'executor-uuid'
         self.executor_exchange = 'executor-exchange'
         self.executor_topic = 'executor-topic'
@@ -58,7 +57,7 @@ class TestWorkerTaskExecutor(test.MockTestCase):
             'taskflow.engines.worker_based.executor.async_utils.wait_for_any')
         self.message_mock = mock.MagicMock(name='message')
         self.message_mock.properties = {'correlation_id': self.task_uuid}
-        self.remote_task_mock = mock.MagicMock(uuid=self.task_uuid)
+        self.request_mock = mock.MagicMock(uuid=self.task_uuid)
 
     def _fake_proxy_start(self):
         self.proxy_started_event.set()
@@ -80,19 +79,19 @@ class TestWorkerTaskExecutor(test.MockTestCase):
         return ex
 
     def request(self, **kwargs):
-        request = dict(task=self.task.name, task_name=self.task.name,
+        request_kwargs = dict(task=self.task, uuid=self.task_uuid,
+                              action='execute', arguments=self.task_args,
+                              progress_callback=None, timeout=self.timeout)
+        request_kwargs.update(kwargs)
+        return pr.Request(**request_kwargs)
+
+    def request_dict(self, **kwargs):
+        request = dict(task_cls=self.task.name, task_name=self.task.name,
                        task_version=self.task.version,
                        arguments=self.task_args)
         request.update(kwargs)
         return request
 
-    def remote_task(self, **kwargs):
-        remote_task_kwargs = dict(task=self.task, uuid=self.task_uuid,
-                                  action='execute', arguments=self.task_args,
-                                  progress_callback=None, timeout=self.timeout)
-        remote_task_kwargs.update(kwargs)
-        return rt.RemoteTask(**remote_task_kwargs)
-
     def test_creation(self):
         ex = self.executor(reset_master_mock=False)
 
@@ -105,20 +104,20 @@ class TestWorkerTaskExecutor(test.MockTestCase):
     def test_on_message_state_running(self):
         response = dict(state=pr.RUNNING)
         ex = self.executor()
-        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
+        ex._requests_cache.set(self.task_uuid, self.request_mock)
         ex._on_message(response, self.message_mock)
 
-        self.assertEqual(self.remote_task_mock.mock_calls,
+        self.assertEqual(self.request_mock.mock_calls,
                         [mock.call.set_running()])
         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
 
     def test_on_message_state_progress(self):
         response = dict(state=pr.PROGRESS, progress=1.0)
         ex = self.executor()
-        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
+        ex._requests_cache.set(self.task_uuid, self.request_mock)
         ex._on_message(response, self.message_mock)
 
-        self.assertEqual(self.remote_task_mock.mock_calls,
+        self.assertEqual(self.request_mock.mock_calls,
                         [mock.call.on_progress(progress=1.0)])
         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
 
@@ -127,11 +126,11 @@ class TestWorkerTaskExecutor(test.MockTestCase):
         failure_dict = pu.failure_to_dict(failure)
         response = dict(state=pr.FAILURE, result=failure_dict)
         ex = self.executor()
-        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
+        ex._requests_cache.set(self.task_uuid, self.request_mock)
         ex._on_message(response, self.message_mock)
 
-        self.assertEqual(len(ex._remote_tasks_cache._data), 0)
-        self.assertEqual(self.remote_task_mock.mock_calls, [
+        self.assertEqual(len(ex._requests_cache._data), 0)
+        self.assertEqual(self.request_mock.mock_calls, [
             mock.call.set_result(result=utils.FailureMatcher(failure))
         ])
         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
@@ -140,10 +139,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
         response = dict(state=pr.SUCCESS, result=self.task_result,
                         event='executed')
         ex = self.executor()
-        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
+        ex._requests_cache.set(self.task_uuid, self.request_mock)
         ex._on_message(response, self.message_mock)
 
-        self.assertEqual(self.remote_task_mock.mock_calls,
+        self.assertEqual(self.request_mock.mock_calls,
                         [mock.call.set_result(result=self.task_result,
                                               event='executed')])
         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
@@ -151,30 +150,30 @@ class TestWorkerTaskExecutor(test.MockTestCase):
     def test_on_message_unknown_state(self):
         response = dict(state='unknown')
         ex = self.executor()
-        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
+        ex._requests_cache.set(self.task_uuid, self.request_mock)
         ex._on_message(response, self.message_mock)
 
-        self.assertEqual(self.remote_task_mock.mock_calls, [])
+        self.assertEqual(self.request_mock.mock_calls, [])
         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
 
     def test_on_message_non_existent_task(self):
         self.message_mock.properties = {'correlation_id': 'non-existent'}
         response = dict(state=pr.RUNNING)
         ex = self.executor()
-        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
+        ex._requests_cache.set(self.task_uuid, self.request_mock)
         ex._on_message(response, self.message_mock)
 
-        self.assertEqual(self.remote_task_mock.mock_calls, [])
+        self.assertEqual(self.request_mock.mock_calls, [])
         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
 
     def test_on_message_no_correlation_id(self):
         self.message_mock.properties = {}
         response = dict(state=pr.RUNNING)
         ex = self.executor()
-        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task_mock)
+        ex._requests_cache.set(self.task_uuid, self.request_mock)
         ex._on_message(response, self.message_mock)
 
-        self.assertEqual(self.remote_task_mock.mock_calls, [])
+        self.assertEqual(self.request_mock.mock_calls, [])
         self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
 
     @mock.patch('taskflow.engines.worker_based.executor.LOG.exception')
@@ -183,46 +182,46 @@ class TestWorkerTaskExecutor(test.MockTestCase):
         self.executor()._on_message({}, self.message_mock)
         self.assertTrue(mocked_exception.called)
 
-    @mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
-    def test_on_wait_task_not_expired(self, mocked_time):
-        mocked_time.side_effect = [1, self.timeout]
+    @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
+    def test_on_wait_task_not_expired(self, mocked_wallclock):
+        mocked_wallclock.side_effect = [1, self.timeout]
         ex = self.executor()
-        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task())
+        ex._requests_cache.set(self.task_uuid, self.request())
 
-        self.assertEqual(len(ex._remote_tasks_cache._data), 1)
+        self.assertEqual(len(ex._requests_cache._data), 1)
         ex._on_wait()
-        self.assertEqual(len(ex._remote_tasks_cache._data), 1)
+        self.assertEqual(len(ex._requests_cache._data), 1)
 
-    @mock.patch('taskflow.engines.worker_based.remote_task.misc.wallclock')
+    @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
     def test_on_wait_task_expired(self, mocked_time):
         mocked_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
         ex = self.executor()
-        ex._remote_tasks_cache.set(self.task_uuid, self.remote_task())
+        ex._requests_cache.set(self.task_uuid, self.request())
 
-        self.assertEqual(len(ex._remote_tasks_cache._data), 1)
+        self.assertEqual(len(ex._requests_cache._data), 1)
         ex._on_wait()
-        self.assertEqual(len(ex._remote_tasks_cache._data), 0)
+        self.assertEqual(len(ex._requests_cache._data), 0)
 
     def test_remove_task_non_existent(self):
-        task = self.remote_task()
+        task = self.request()
         ex = self.executor()
-        ex._remote_tasks_cache.set(self.task_uuid, task)
+        ex._requests_cache.set(self.task_uuid, task)
 
-        self.assertEqual(len(ex._remote_tasks_cache._data), 1)
-        ex._remote_tasks_cache.delete(self.task_uuid)
-        self.assertEqual(len(ex._remote_tasks_cache._data), 0)
+        self.assertEqual(len(ex._requests_cache._data), 1)
+        ex._requests_cache.delete(self.task_uuid)
+        self.assertEqual(len(ex._requests_cache._data), 0)
 
         # remove non-existent
-        ex._remote_tasks_cache.delete(self.task_uuid)
-        self.assertEqual(len(ex._remote_tasks_cache._data), 0)
+        ex._requests_cache.delete(self.task_uuid)
+        self.assertEqual(len(ex._requests_cache._data), 0)
 
     def test_execute_task(self):
-        request = self.request(action='execute')
+        request_dict = self.request_dict(action='execute')
         ex = self.executor()
         result = ex.execute_task(self.task, self.task_uuid, self.task_args)
 
         expected_calls = [
-            mock.call.proxy.publish(request,
+            mock.call.proxy.publish(request_dict,
                                     routing_key=self.executor_topic,
                                     reply_to=self.executor_uuid,
                                     correlation_id=self.task_uuid)
@@ -231,15 +230,15 @@ class TestWorkerTaskExecutor(test.MockTestCase):
         self.assertIsInstance(result, futures.Future)
 
     def test_revert_task(self):
-        request = self.request(action='revert',
-                               result=('success', self.task_result),
-                               failures=self.task_failures)
+        request_dict = self.request_dict(action='revert',
+                                         result=('success', self.task_result),
+                                         failures=self.task_failures)
         ex = self.executor()
         result = ex.revert_task(self.task, self.task_uuid, self.task_args,
                                 self.task_result, self.task_failures)
 
         expected_calls = [
-            mock.call.proxy.publish(request,
+            mock.call.proxy.publish(request_dict,
                                     routing_key=self.executor_topic,
                                     reply_to=self.executor_uuid,
                                     correlation_id=self.task_uuid)
@@ -262,12 +261,12 @@ class TestWorkerTaskExecutor(test.MockTestCase):
 
     def test_execute_task_publish_error(self):
         self.proxy_inst_mock.publish.side_effect = Exception('Woot!')
-        request = self.request(action='execute')
+        request_dict = self.request_dict(action='execute')
         ex = self.executor()
         result = ex.execute_task(self.task, self.task_uuid, self.task_args)
 
         expected_calls = [
-            mock.call.proxy.publish(request,
+            mock.call.proxy.publish(request_dict,
                                     routing_key=self.executor_topic,
                                     reply_to=self.executor_uuid,
                                     correlation_id=self.task_uuid)
taskflow/tests/unit/worker_based/test_protocol.py (new file, 126 lines)
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import mock
+
+from concurrent import futures
+
+from taskflow.engines.worker_based import protocol as pr
+from taskflow import test
+from taskflow.tests import utils
+from taskflow.utils import misc
+from taskflow.utils import persistence_utils as pu
+
+
+class TestProtocol(test.TestCase):
+
+    def setUp(self):
+        super(TestProtocol, self).setUp()
+        self.task = utils.DummyTask()
+        self.task_uuid = 'task-uuid'
+        self.task_action = 'execute'
+        self.task_args = {'a': 'a'}
+        self.timeout = 60
+
+    def request(self, **kwargs):
+        request_kwargs = dict(task=self.task,
+                              uuid=self.task_uuid,
+                              action=self.task_action,
+                              arguments=self.task_args,
+                              progress_callback=None,
+                              timeout=self.timeout)
+        request_kwargs.update(kwargs)
+        return pr.Request(**request_kwargs)
+
+    def request_to_dict(self, **kwargs):
+        to_dict = dict(task_cls=self.task.name,
+                       task_name=self.task.name,
+                       task_version=self.task.version,
+                       action=self.task_action,
+                       arguments=self.task_args)
+        to_dict.update(kwargs)
+        return to_dict
+
+    def test_creation(self):
+        request = self.request()
+        self.assertEqual(request.uuid, self.task_uuid)
+        self.assertEqual(request.task_cls, self.task.name)
+        self.assertIsInstance(request.result, futures.Future)
+        self.assertFalse(request.result.done())
+
+    def test_repr(self):
+        expected_name = '%s:%s' % (self.task.name, self.task_action)
+        self.assertEqual(repr(self.request()), expected_name)
+
+    def test_to_dict_default(self):
+        self.assertEqual(self.request().to_dict(), self.request_to_dict())
+
+    def test_to_dict_with_result(self):
+        self.assertEqual(self.request(result=333).to_dict(),
+                         self.request_to_dict(result=('success', 333)))
+
+    def test_to_dict_with_result_none(self):
+        self.assertEqual(self.request(result=None).to_dict(),
+                         self.request_to_dict(result=('success', None)))
+
+    def test_to_dict_with_result_failure(self):
+        failure = misc.Failure.from_exception(RuntimeError('Woot!'))
+        expected = self.request_to_dict(
+            result=('failure', pu.failure_to_dict(failure)))
+        self.assertEqual(self.request(result=failure).to_dict(), expected)
+
+    def test_to_dict_with_failures(self):
+        failure = misc.Failure.from_exception(RuntimeError('Woot!'))
+        request = self.request(failures={self.task.name: failure})
+        expected = self.request_to_dict(
+            failures={self.task.name: pu.failure_to_dict(failure)})
+        self.assertEqual(request.to_dict(), expected)
+
+    @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
+    def test_pending_not_expired(self, mocked_wallclock):
+        mocked_wallclock.side_effect = [1, self.timeout]
+        self.assertFalse(self.request().expired)
+
+    @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
+    def test_pending_expired(self, mocked_wallclock):
+        mocked_wallclock.side_effect = [1, self.timeout + 2]
+        self.assertTrue(self.request().expired)
+
+    @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
+    def test_running_not_expired(self, mocked_wallclock):
+        mocked_wallclock.side_effect = [1, self.timeout + 2]
+        request = self.request()
+        request.set_running()
+        self.assertFalse(request.expired)
+
+    def test_set_result(self):
+        request = self.request()
+        request.set_result(111)
+        result = request.result.result()
+        self.assertEqual(result, (self.task, 'executed', 111))
+
+    def test_on_progress(self):
+        progress_callback = mock.MagicMock(name='progress_callback')
+        request = self.request(task=self.task,
+                               progress_callback=progress_callback)
+        request.on_progress('event_data', 0.0)
+        request.on_progress('event_data', 1.0)
+
+        expected_calls = [
+            mock.call(self.task, 'event_data', 0.0),
+            mock.call(self.task, 'event_data', 1.0)
+        ]
+        self.assertEqual(progress_callback.mock_calls, expected_calls)
@@ -1,133 +0,0 @@
-# -*- coding: utf-8 -*-
-
-# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-import mock
-
-from concurrent import futures
-
-from taskflow.engines.worker_based import remote_task as rt
-from taskflow import test
-from taskflow.tests import utils
-from taskflow.utils import misc
-from taskflow.utils import persistence_utils as pu
-
-
-class TestRemoteTask(test.TestCase):
-
-    def setUp(self):
-        super(TestRemoteTask, self).setUp()
-        self.task = utils.DummyTask()
-        self.task_uuid = 'task-uuid'
-        self.task_action = 'execute'
-        self.task_args = {'context': 'context'}
-        self.timeout = 60
-
-    def remote_task(self, **kwargs):
-        task_kwargs = dict(task=self.task,
-                           uuid=self.task_uuid,
-                           action=self.task_action,
-                           arguments=self.task_args,
-                           progress_callback=None,
-                           timeout=self.timeout)
-        task_kwargs.update(kwargs)
-        return rt.RemoteTask(**task_kwargs)
-
-    def remote_task_request(self, **kwargs):
-        request = dict(task=self.task.name,
-                       task_name=self.task.name,
-                       task_version=self.task.version,
-                       action=self.task_action,
-                       arguments=self.task_args)
-        request.update(kwargs)
-        return request
-
-    def test_creation(self):
-        remote_task = self.remote_task()
-        self.assertEqual(remote_task.uuid, self.task_uuid)
-        self.assertEqual(remote_task.name, self.task.name)
-        self.assertIsInstance(remote_task.result, futures.Future)
-        self.assertFalse(remote_task.result.done())
-
-    def test_repr(self):
-        expected_name = '%s:%s' % (self.task.name, self.task_action)
-        self.assertEqual(repr(self.remote_task()), expected_name)
-
-    def test_request(self):
-        remote_task = self.remote_task()
-        request = self.remote_task_request()
-        self.assertEqual(remote_task.request, request)
-
-    def test_request_with_result(self):
-        remote_task = self.remote_task(result=333)
-        request = self.remote_task_request(result=('success', 333))
-        self.assertEqual(remote_task.request, request)
-
-    def test_request_with_result_none(self):
-        remote_task = self.remote_task(result=None)
-        request = self.remote_task_request(result=('success', None))
-        self.assertEqual(remote_task.request, request)
-
-    def test_request_with_result_failure(self):
-        failure = misc.Failure.from_exception(RuntimeError('Woot!'))
-        remote_task = self.remote_task(result=failure)
-        request = self.remote_task_request(
-            result=('failure', pu.failure_to_dict(failure)))
-        self.assertEqual(remote_task.request, request)
-
-    def test_request_with_failures(self):
-        failure = misc.Failure.from_exception(RuntimeError('Woot!'))
-        remote_task = self.remote_task(failures={self.task.name: failure})
-        request = self.remote_task_request(
-            failures={self.task.name: pu.failure_to_dict(failure)})
-        self.assertEqual(remote_task.request, request)
-
-    @mock.patch('time.time')
-    def test_pending_not_expired(self, mock_time):
-        mock_time.side_effect = [1, self.timeout]
-        remote_task = self.remote_task()
-        self.assertFalse(remote_task.expired)
-
-    @mock.patch('time.time')
-    def test_pending_expired(self, mock_time):
-        mock_time.side_effect = [1, self.timeout + 2]
-        remote_task = self.remote_task()
-        self.assertTrue(remote_task.expired)
-
-    @mock.patch('time.time')
-    def test_running_not_expired(self, mock_time):
-        mock_time.side_effect = [1, self.timeout]
-        remote_task = self.remote_task()
-        remote_task.set_running()
-        self.assertFalse(remote_task.expired)
-
-    def test_set_result(self):
-        remote_task = self.remote_task()
-        remote_task.set_result(111)
-        result = remote_task.result.result()
-        self.assertEqual(result, (self.task, 'executed', 111))
-
-    def test_on_progress(self):
-        progress_callback = mock.MagicMock(name='progress_callback')
-        remote_task = self.remote_task(task=self.task,
-                                       progress_callback=progress_callback)
-        remote_task.on_progress('event_data', 0.0)
-        remote_task.on_progress('event_data', 1.0)
-
-        expected_calls = [
-            mock.call(self.task, 'event_data', 0.0),
-            mock.call(self.task, 'event_data', 1.0)
-        ]
-        self.assertEqual(progress_callback.mock_calls, expected_calls)
@@ -71,7 +71,7 @@ class TestServer(test.MockTestCase):
         return s
 
     def request(self, **kwargs):
-        request = dict(task=self.task_name,
+        request = dict(task_cls=self.task_name,
                        task_name=self.task_name,
                        action=self.task_action,
                        task_version=self.task_version,
@@ -164,18 +164,18 @@ class TestServer(test.MockTestCase):
 
     def test_parse_request(self):
         request = self.request()
-        task, action, task_args = server.Server._parse_request(**request)
+        task_cls, action, task_args = server.Server._parse_request(**request)
 
-        self.assertEqual((task, action, task_args),
+        self.assertEqual((task_cls, action, task_args),
                         (self.task_name, self.task_action,
                          dict(task_name=self.task_name,
                               arguments=self.task_args)))
 
     def test_parse_request_with_success_result(self):
         request = self.request(action='revert', result=('success', 1))
-        task, action, task_args = server.Server._parse_request(**request)
+        task_cls, action, task_args = server.Server._parse_request(**request)
 
-        self.assertEqual((task, action, task_args),
+        self.assertEqual((task_cls, action, task_args),
                         (self.task_name, 'revert',
                          dict(task_name=self.task_name,
                               arguments=self.task_args,
@@ -186,9 +186,9 @@ class TestServer(test.MockTestCase):
         failure_dict = pu.failure_to_dict(failure)
         request = self.request(action='revert',
                                result=('failure', failure_dict))
-        task, action, task_args = server.Server._parse_request(**request)
+        task_cls, action, task_args = server.Server._parse_request(**request)
 
-        self.assertEqual((task, action, task_args),
+        self.assertEqual((task_cls, action, task_args),
                         (self.task_name, 'revert',
                          dict(task_name=self.task_name,
                               arguments=self.task_args,
@@ -200,10 +200,10 @@ class TestServer(test.MockTestCase):
         failures_dict = dict((str(i), pu.failure_to_dict(f))
                              for i, f in enumerate(failures))
         request = self.request(action='revert', failures=failures_dict)
-        task, action, task_args = server.Server._parse_request(**request)
+        task_cls, action, task_args = server.Server._parse_request(**request)
 
         self.assertEqual(
-            (task, action, task_args),
+            (task_cls, action, task_args),
            (self.task_name, 'revert',
             dict(task_name=self.task_name,
                  arguments=self.task_args,
@@ -225,7 +225,7 @@ class TestServer(test.MockTestCase):
         self.assertTrue(mocked_exception.called)
 
     def test_on_update_progress(self):
-        request = self.request(task='taskflow.tests.utils.ProgressingTask',
+        request = self.request(task_cls='taskflow.tests.utils.ProgressingTask',
                                arguments={})
 
         # create server and process request
@@ -292,7 +292,7 @@ class TestServer(test.MockTestCase):
     def test_process_request_endpoint_not_found(self, pu_mock):
         failure_dict = 'failure_dict'
         pu_mock.failure_to_dict.return_value = failure_dict
-        request = self.request(task='<unknown>')
+        request = self.request(task_cls='<unknown>')
 
         # create server and process request
         s = self.server(reset_master_mock=True, endpoints=self.endpoints)