Merge "Introduce message types for WBE protocol"
This commit is contained in:
commit
4aa9a4c0df
@ -28,7 +28,6 @@ from taskflow.engines.worker_based import proxy
|
||||
from taskflow import exceptions as exc
|
||||
from taskflow.utils import async_utils
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import persistence_utils as pu
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -58,43 +57,48 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
proxy_thread.daemon = True
|
||||
return proxy_thread
|
||||
|
||||
def _on_message(self, response, message):
|
||||
"""This method is called on incoming response."""
|
||||
LOG.debug("Got response: %s", response)
|
||||
def _on_message(self, data, message):
|
||||
"""This method is called on incoming message."""
|
||||
LOG.debug("Got message: %s", data)
|
||||
try:
|
||||
# acknowledge message before processing.
|
||||
# acknowledge message before processing
|
||||
message.ack()
|
||||
except kombu_exc.MessageStateError:
|
||||
LOG.exception("Failed to acknowledge AMQP message.")
|
||||
else:
|
||||
LOG.debug("AMQP message acknowledged.")
|
||||
# get task uuid from message correlation id parameter
|
||||
try:
|
||||
msg_type = message.properties['type']
|
||||
except KeyError:
|
||||
LOG.warning("The 'type' message property is missing.")
|
||||
else:
|
||||
if msg_type == pr.RESPONSE:
|
||||
self._process_response(data, message)
|
||||
else:
|
||||
LOG.warning("Unexpected message type: %s", msg_type)
|
||||
|
||||
def _process_response(self, response, message):
|
||||
"""Process response from remote side."""
|
||||
LOG.debug("Start processing response message.")
|
||||
try:
|
||||
task_uuid = message.properties['correlation_id']
|
||||
except KeyError:
|
||||
LOG.warning("Got message with no 'correlation_id' property.")
|
||||
LOG.warning("The 'correlation_id' message property is missing.")
|
||||
else:
|
||||
LOG.debug("Task uuid: '%s'", task_uuid)
|
||||
self._process_response(task_uuid, response)
|
||||
|
||||
def _process_response(self, task_uuid, response):
|
||||
"""Process response from remote side."""
|
||||
request = self._requests_cache.get(task_uuid)
|
||||
if request is not None:
|
||||
state = response.pop('state')
|
||||
if state == pr.RUNNING:
|
||||
response = pr.Response.from_dict(response)
|
||||
if response.state == pr.RUNNING:
|
||||
request.set_running()
|
||||
elif state == pr.PROGRESS:
|
||||
request.on_progress(**response)
|
||||
elif state == pr.FAILURE:
|
||||
response['result'] = pu.failure_from_dict(response['result'])
|
||||
request.set_result(**response)
|
||||
self._requests_cache.delete(request.uuid)
|
||||
elif state == pr.SUCCESS:
|
||||
request.set_result(**response)
|
||||
elif response.state == pr.PROGRESS:
|
||||
request.on_progress(**response.data)
|
||||
elif response.state in (pr.FAILURE, pr.SUCCESS):
|
||||
request.set_result(**response.data)
|
||||
self._requests_cache.delete(request.uuid)
|
||||
else:
|
||||
LOG.warning("Unexpected response status: '%s'", state)
|
||||
LOG.warning("Unexpected response status: '%s'",
|
||||
response.state)
|
||||
else:
|
||||
LOG.debug("Request with id='%s' not found.", task_uuid)
|
||||
|
||||
@ -129,7 +133,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
else:
|
||||
# publish request
|
||||
LOG.debug("Sending request: %s", request)
|
||||
self._proxy.publish(request.to_dict(),
|
||||
self._proxy.publish(request,
|
||||
routing_key=topic,
|
||||
reply_to=self._uuid,
|
||||
correlation_id=request.uuid)
|
||||
|
@ -14,6 +14,10 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import abc
|
||||
|
||||
import six
|
||||
|
||||
from concurrent import futures
|
||||
|
||||
from taskflow.engines.action_engine import executor
|
||||
@ -49,11 +53,28 @@ REQUEST_TIMEOUT = 60
|
||||
# no longer needed.
|
||||
QUEUE_EXPIRE_TIMEOUT = REQUEST_TIMEOUT
|
||||
|
||||
# Message types.
|
||||
REQUEST = 'REQUEST'
|
||||
RESPONSE = 'RESPONSE'
|
||||
|
||||
class Request(object):
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Message(object):
|
||||
"""Base class for all message types."""
|
||||
|
||||
def __str__(self):
|
||||
return str(self.to_dict())
|
||||
|
||||
@abc.abstractmethod
|
||||
def to_dict(self):
|
||||
"""Return json-serializable message representation."""
|
||||
|
||||
|
||||
class Request(Message):
|
||||
"""Represents request with execution results. Every request is created in
|
||||
the PENDING state and is expired within the given timeout.
|
||||
"""
|
||||
TYPE = REQUEST
|
||||
|
||||
def __init__(self, task, uuid, action, arguments, progress_callback,
|
||||
timeout, **kwargs):
|
||||
@ -111,7 +132,7 @@ class Request(object):
|
||||
if 'failures' in self._kwargs:
|
||||
failures = self._kwargs['failures']
|
||||
request['failures'] = {}
|
||||
for task, failure in failures.items():
|
||||
for task, failure in six.iteritems(failures):
|
||||
request['failures'][task] = pu.failure_to_dict(failure)
|
||||
return request
|
||||
|
||||
@ -124,3 +145,31 @@ class Request(object):
|
||||
|
||||
def on_progress(self, event_data, progress):
|
||||
self._progress_callback(self._task, event_data, progress)
|
||||
|
||||
|
||||
class Response(Message):
|
||||
"""Represents response message type."""
|
||||
TYPE = RESPONSE
|
||||
|
||||
def __init__(self, state, **data):
|
||||
self._state = state
|
||||
self._data = data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data):
|
||||
state = data['state']
|
||||
data = data['data']
|
||||
if state == FAILURE and 'result' in data:
|
||||
data['result'] = pu.failure_from_dict(data['result'])
|
||||
return cls(state, **data)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._data
|
||||
|
||||
def to_dict(self):
|
||||
return dict(state=self._state, data=self._data)
|
||||
|
@ -69,10 +69,11 @@ class Proxy(object):
|
||||
"""Publish message to the named exchange with routing key."""
|
||||
with kombu.producers[self._conn].acquire(block=True) as producer:
|
||||
queue = self._make_queue(routing_key, self._exchange)
|
||||
producer.publish(body=msg,
|
||||
producer.publish(body=msg.to_dict(),
|
||||
routing_key=routing_key,
|
||||
exchange=self._exchange,
|
||||
declare=[queue],
|
||||
type=msg.TYPE,
|
||||
**kwargs)
|
||||
|
||||
def start(self):
|
||||
|
@ -36,23 +36,32 @@ class Server(object):
|
||||
self._endpoints = dict([(endpoint.name, endpoint)
|
||||
for endpoint in endpoints])
|
||||
|
||||
def _on_message(self, request, message):
|
||||
"""This method is called on incoming request."""
|
||||
LOG.debug("Got request: %s", request)
|
||||
# NOTE(skudriashev): Process all incoming requests only if proxy is
|
||||
def _on_message(self, data, message):
|
||||
"""This method is called on incoming message."""
|
||||
LOG.debug("Got message: %s", data)
|
||||
# NOTE(skudriashev): Process all incoming messages only if proxy is
|
||||
# running, otherwise requeue them.
|
||||
if self._proxy.is_running:
|
||||
# NOTE(skudriashev): Process request only if message has been
|
||||
# acknowledged successfully.
|
||||
try:
|
||||
# acknowledge message
|
||||
# acknowledge message before processing
|
||||
message.ack()
|
||||
except kombu_exc.MessageStateError:
|
||||
LOG.exception("Failed to acknowledge AMQP message.")
|
||||
else:
|
||||
LOG.debug("AMQP message acknowledged.")
|
||||
try:
|
||||
msg_type = message.properties['type']
|
||||
except KeyError:
|
||||
LOG.warning("The 'type' message property is missing.")
|
||||
else:
|
||||
if msg_type == pr.REQUEST:
|
||||
# spawn new thread to process request
|
||||
self._executor.submit(self._process_request, request, message)
|
||||
self._executor.submit(self._process_request, data,
|
||||
message)
|
||||
else:
|
||||
LOG.warning("Unexpected message type: %s", msg_type)
|
||||
else:
|
||||
try:
|
||||
# requeue message
|
||||
@ -100,7 +109,7 @@ class Server(object):
|
||||
|
||||
def _reply(self, reply_to, task_uuid, state=pr.FAILURE, **kwargs):
|
||||
"""Send reply to the `reply_to` queue."""
|
||||
response = dict(state=state, **kwargs)
|
||||
response = pr.Response(state, **kwargs)
|
||||
LOG.debug("Sending reply: %s", response)
|
||||
try:
|
||||
self._proxy.publish(response, reply_to, correlation_id=task_uuid)
|
||||
@ -115,7 +124,7 @@ class Server(object):
|
||||
|
||||
def _process_request(self, request, message):
|
||||
"""Process request in separate thread and reply back."""
|
||||
# NOTE(skudriashev): parse broker message first to get the `reply_to`
|
||||
# NOTE(skudriashev): Parse broker message first to get the `reply_to`
|
||||
# and the `task_uuid` parameters to have possibility to reply back.
|
||||
try:
|
||||
reply_to, task_uuid = self._parse_message(message)
|
||||
|
@ -49,15 +49,20 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
# patch classes
|
||||
self.proxy_mock, self.proxy_inst_mock = self._patch_class(
|
||||
executor.proxy, 'Proxy')
|
||||
self.request_mock, self.request_inst_mock = self._patch_class(
|
||||
executor.pr, 'Request', autospec=False)
|
||||
|
||||
# other mocking
|
||||
self.proxy_inst_mock.start.side_effect = self._fake_proxy_start
|
||||
self.proxy_inst_mock.stop.side_effect = self._fake_proxy_stop
|
||||
self.request_inst_mock.uuid = self.task_uuid
|
||||
self.request_inst_mock.expired = False
|
||||
self.request_inst_mock.task_cls = self.task.name
|
||||
self.wait_for_any_mock = self._patch(
|
||||
'taskflow.engines.worker_based.executor.async_utils.wait_for_any')
|
||||
self.message_mock = mock.MagicMock(name='message')
|
||||
self.message_mock.properties = {'correlation_id': self.task_uuid}
|
||||
self.request_mock = mock.MagicMock(uuid=self.task_uuid)
|
||||
self.message_mock.properties = {'correlation_id': self.task_uuid,
|
||||
'type': pr.RESPONSE}
|
||||
|
||||
def _fake_proxy_start(self):
|
||||
self.proxy_started_event.set()
|
||||
@ -78,20 +83,6 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self._reset_master_mock()
|
||||
return ex
|
||||
|
||||
def request(self, **kwargs):
|
||||
request_kwargs = dict(task=self.task, uuid=self.task_uuid,
|
||||
action='execute', arguments=self.task_args,
|
||||
progress_callback=None, timeout=self.timeout)
|
||||
request_kwargs.update(kwargs)
|
||||
return pr.Request(**request_kwargs)
|
||||
|
||||
def request_dict(self, **kwargs):
|
||||
request = dict(task_cls=self.task.name, task_name=self.task.name,
|
||||
task_version=self.task.version,
|
||||
arguments=self.task_args)
|
||||
request.update(kwargs)
|
||||
return request
|
||||
|
||||
def test_creation(self):
|
||||
ex = self.executor(reset_master_mock=False)
|
||||
|
||||
@ -101,184 +92,190 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_on_message_state_running(self):
|
||||
response = dict(state=pr.RUNNING)
|
||||
def test_on_message_response_state_running(self):
|
||||
response = pr.Response(pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_mock.mock_calls,
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
[mock.call.set_running()])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_state_progress(self):
|
||||
response = dict(state=pr.PROGRESS, progress=1.0)
|
||||
def test_on_message_response_state_progress(self):
|
||||
response = pr.Response(pr.PROGRESS, progress=1.0)
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_mock.mock_calls,
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
[mock.call.on_progress(progress=1.0)])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_state_failure(self):
|
||||
def test_on_message_response_state_failure(self):
|
||||
failure = misc.Failure.from_exception(Exception('test'))
|
||||
failure_dict = pu.failure_to_dict(failure)
|
||||
response = dict(state=pr.FAILURE, result=failure_dict)
|
||||
response = pr.Response(pr.FAILURE, result=failure_dict)
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(len(ex._requests_cache._data), 0)
|
||||
self.assertEqual(self.request_mock.mock_calls, [
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [
|
||||
mock.call.set_result(result=utils.FailureMatcher(failure))
|
||||
])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_state_success(self):
|
||||
response = dict(state=pr.SUCCESS, result=self.task_result,
|
||||
def test_on_message_response_state_success(self):
|
||||
response = pr.Response(pr.SUCCESS, result=self.task_result,
|
||||
event='executed')
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_mock.mock_calls,
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
[mock.call.set_result(result=self.task_result,
|
||||
event='executed')])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_unknown_state(self):
|
||||
response = dict(state='unknown')
|
||||
def test_on_message_response_unknown_state(self):
|
||||
response = pr.Response(state='<unknown>')
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_mock.mock_calls, [])
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_non_existent_task(self):
|
||||
self.message_mock.properties = {'correlation_id': 'non-existent'}
|
||||
response = dict(state=pr.RUNNING)
|
||||
def test_on_message_response_unknown_task(self):
|
||||
self.message_mock.properties['correlation_id'] = '<unknown>'
|
||||
response = pr.Response(pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_mock.mock_calls, [])
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
def test_on_message_no_correlation_id(self):
|
||||
self.message_mock.properties = {}
|
||||
response = dict(state=pr.RUNNING)
|
||||
def test_on_message_response_no_correlation_id(self):
|
||||
self.message_mock.properties = {'type': pr.RESPONSE}
|
||||
response = pr.Response(pr.RUNNING)
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request_mock)
|
||||
ex._on_message(response, self.message_mock)
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
ex._on_message(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_mock.mock_calls, [])
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [])
|
||||
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.executor.LOG.warning')
|
||||
def test_on_message_unknown_type(self, mocked_warning):
|
||||
self.message_mock.properties = {'correlation_id': self.task_uuid,
|
||||
'type': '<unknown>'}
|
||||
ex = self.executor()
|
||||
ex._on_message({}, self.message_mock)
|
||||
self.assertTrue(mocked_warning.called)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.executor.LOG.warning')
|
||||
def test_on_message_no_type(self, mocked_warning):
|
||||
self.message_mock.properties = {'correlation_id': self.task_uuid}
|
||||
ex = self.executor()
|
||||
ex._on_message({}, self.message_mock)
|
||||
self.assertTrue(mocked_warning.called)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.executor.LOG.exception')
|
||||
def test_on_message_acknowledge_raises(self, mocked_exception):
|
||||
self.message_mock.ack.side_effect = kombu_exc.MessageStateError()
|
||||
self.executor()._on_message({}, self.message_mock)
|
||||
self.assertTrue(mocked_exception.called)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
|
||||
def test_on_wait_task_not_expired(self, mocked_wallclock):
|
||||
mocked_wallclock.side_effect = [1, self.timeout]
|
||||
def test_on_wait_task_not_expired(self):
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request())
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
|
||||
self.assertEqual(len(ex._requests_cache._data), 1)
|
||||
ex._on_wait()
|
||||
self.assertEqual(len(ex._requests_cache._data), 1)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
|
||||
def test_on_wait_task_expired(self, mocked_time):
|
||||
mocked_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
|
||||
def test_on_wait_task_expired(self):
|
||||
self.request_inst_mock.expired = True
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, self.request())
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
|
||||
self.assertEqual(len(ex._requests_cache._data), 1)
|
||||
ex._on_wait()
|
||||
self.assertEqual(len(ex._requests_cache._data), 0)
|
||||
|
||||
def test_remove_task_non_existent(self):
|
||||
task = self.request()
|
||||
ex = self.executor()
|
||||
ex._requests_cache.set(self.task_uuid, task)
|
||||
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
|
||||
|
||||
self.assertEqual(len(ex._requests_cache._data), 1)
|
||||
ex._requests_cache.delete(self.task_uuid)
|
||||
self.assertEqual(len(ex._requests_cache._data), 0)
|
||||
|
||||
# remove non-existent
|
||||
# delete non-existent
|
||||
ex._requests_cache.delete(self.task_uuid)
|
||||
self.assertEqual(len(ex._requests_cache._data), 0)
|
||||
|
||||
def test_execute_task(self):
|
||||
request_dict = self.request_dict(action='execute')
|
||||
ex = self.executor()
|
||||
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||
ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||
|
||||
expected_calls = [
|
||||
mock.call.proxy.publish(request_dict,
|
||||
mock.call.Request(self.task, self.task_uuid, 'execute',
|
||||
self.task_args, None, self.timeout),
|
||||
mock.call.proxy.publish(self.request_inst_mock,
|
||||
routing_key=self.executor_topic,
|
||||
reply_to=self.executor_uuid,
|
||||
correlation_id=self.task_uuid)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
self.assertIsInstance(result, futures.Future)
|
||||
|
||||
def test_revert_task(self):
|
||||
request_dict = self.request_dict(action='revert',
|
||||
result=('success', self.task_result),
|
||||
failures=self.task_failures)
|
||||
ex = self.executor()
|
||||
result = ex.revert_task(self.task, self.task_uuid, self.task_args,
|
||||
ex.revert_task(self.task, self.task_uuid, self.task_args,
|
||||
self.task_result, self.task_failures)
|
||||
|
||||
expected_calls = [
|
||||
mock.call.proxy.publish(request_dict,
|
||||
mock.call.Request(self.task, self.task_uuid, 'revert',
|
||||
self.task_args, None, self.timeout,
|
||||
failures=self.task_failures,
|
||||
result=self.task_result),
|
||||
mock.call.proxy.publish(self.request_inst_mock,
|
||||
routing_key=self.executor_topic,
|
||||
reply_to=self.executor_uuid,
|
||||
correlation_id=self.task_uuid)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
self.assertIsInstance(result, futures.Future)
|
||||
|
||||
def test_execute_task_topic_not_found(self):
|
||||
workers_info = {self.executor_topic: ['non-existent-task']}
|
||||
workers_info = {self.executor_topic: ['<unknown>']}
|
||||
ex = self.executor(workers_info=workers_info)
|
||||
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||
ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||
|
||||
self.assertFalse(self.proxy_inst_mock.publish.called)
|
||||
|
||||
# check execute result
|
||||
task, event, res = result.result()
|
||||
self.assertEqual(task, self.task)
|
||||
self.assertEqual(event, 'executed')
|
||||
self.assertIsInstance(res, misc.Failure)
|
||||
expected_calls = [
|
||||
mock.call.Request(self.task, self.task_uuid, 'execute',
|
||||
self.task_args, None, self.timeout),
|
||||
mock.call.request.set_result(mock.ANY)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
|
||||
def test_execute_task_publish_error(self):
|
||||
self.proxy_inst_mock.publish.side_effect = Exception('Woot!')
|
||||
request_dict = self.request_dict(action='execute')
|
||||
ex = self.executor()
|
||||
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||
ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||
|
||||
expected_calls = [
|
||||
mock.call.proxy.publish(request_dict,
|
||||
mock.call.Request(self.task, self.task_uuid, 'execute',
|
||||
self.task_args, None, self.timeout),
|
||||
mock.call.proxy.publish(self.request_inst_mock,
|
||||
routing_key=self.executor_topic,
|
||||
reply_to=self.executor_uuid,
|
||||
correlation_id=self.task_uuid)
|
||||
correlation_id=self.task_uuid),
|
||||
mock.call.request.set_result(mock.ANY)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
|
||||
# check execute result
|
||||
task, event, res = result.result()
|
||||
self.assertEqual(task, self.task)
|
||||
self.assertEqual(event, 'executed')
|
||||
self.assertIsInstance(res, misc.Failure)
|
||||
|
||||
def test_wait_for_any(self):
|
||||
fs = [futures.Future(), futures.Future()]
|
||||
ex = self.executor()
|
||||
|
@ -61,9 +61,13 @@ class TestProtocol(test.TestCase):
|
||||
self.assertIsInstance(request.result, futures.Future)
|
||||
self.assertFalse(request.result.done())
|
||||
|
||||
def test_str(self):
|
||||
request = self.request()
|
||||
self.assertEqual(str(request), str(request.to_dict()))
|
||||
|
||||
def test_repr(self):
|
||||
expected_name = '%s:%s' % (self.task.name, self.task_action)
|
||||
self.assertEqual(repr(self.request()), expected_name)
|
||||
expected = '%s:%s' % (self.task.name, self.task_action)
|
||||
self.assertEqual(repr(self.request()), expected)
|
||||
|
||||
def test_to_dict_default(self):
|
||||
self.assertEqual(self.request().to_dict(), self.request_to_dict())
|
||||
|
@ -128,13 +128,15 @@ class TestProxy(test.MockTestCase):
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_publish(self):
|
||||
task_data = 'task-data'
|
||||
task_uuid = 'task-uuid'
|
||||
msg_mock = mock.MagicMock()
|
||||
msg_data = 'msg-data'
|
||||
msg_mock.to_dict.return_value = msg_data
|
||||
routing_key = 'routing-key'
|
||||
task_uuid = 'task-uuid'
|
||||
kwargs = dict(a='a', b='b')
|
||||
|
||||
self.proxy(reset_master_mock=True).publish(
|
||||
task_data, routing_key, correlation_id=task_uuid, **kwargs)
|
||||
msg_mock, routing_key, correlation_id=task_uuid, **kwargs)
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.Queue(name=self._queue_name(routing_key),
|
||||
@ -142,11 +144,12 @@ class TestProxy(test.MockTestCase):
|
||||
routing_key=routing_key,
|
||||
durable=False,
|
||||
auto_delete=True),
|
||||
mock.call.producer.publish(body=task_data,
|
||||
mock.call.producer.publish(body=msg_data,
|
||||
routing_key=routing_key,
|
||||
exchange=self.exchange_inst_mock,
|
||||
correlation_id=task_uuid,
|
||||
declare=[self.queue_inst_mock],
|
||||
type=msg_mock.TYPE,
|
||||
**kwargs)
|
||||
]
|
||||
self.master_mock.assert_has_calls(master_mock_calls)
|
||||
|
@ -16,6 +16,8 @@
|
||||
|
||||
import mock
|
||||
|
||||
import six
|
||||
|
||||
from kombu import exceptions as exc
|
||||
|
||||
from taskflow.engines.worker_based import endpoint as ep
|
||||
@ -24,7 +26,6 @@ from taskflow.engines.worker_based import server
|
||||
from taskflow import test
|
||||
from taskflow.tests import utils
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import persistence_utils as pu
|
||||
|
||||
|
||||
class TestServer(test.MockTestCase):
|
||||
@ -34,27 +35,28 @@ class TestServer(test.MockTestCase):
|
||||
self.server_topic = 'server-topic'
|
||||
self.server_exchange = 'server-exchange'
|
||||
self.broker_url = 'test-url'
|
||||
self.task = utils.TaskOneArgOneReturn()
|
||||
self.task_uuid = 'task-uuid'
|
||||
self.task_args = {'x': 1}
|
||||
self.task_action = 'execute'
|
||||
self.task_name = 'taskflow.tests.utils.TaskOneArgOneReturn'
|
||||
self.task_version = (1, 0)
|
||||
self.reply_to = 'reply-to'
|
||||
self.endpoints = [ep.Endpoint(task_cls=utils.TaskOneArgOneReturn),
|
||||
ep.Endpoint(task_cls=utils.TaskWithFailure),
|
||||
ep.Endpoint(task_cls=utils.ProgressingTask)]
|
||||
self.resp_running = dict(state=pr.RUNNING)
|
||||
|
||||
# patch classes
|
||||
self.proxy_mock, self.proxy_inst_mock = self._patch_class(
|
||||
server.proxy, 'Proxy')
|
||||
self.response_mock, self.response_inst_mock = self._patch_class(
|
||||
server.pr, 'Response')
|
||||
|
||||
# other mocking
|
||||
self.proxy_inst_mock.is_running = True
|
||||
self.executor_mock = mock.MagicMock(name='executor')
|
||||
self.message_mock = mock.MagicMock(name='message')
|
||||
self.message_mock.properties = {'correlation_id': self.task_uuid,
|
||||
'reply_to': self.reply_to}
|
||||
'reply_to': self.reply_to,
|
||||
'type': pr.REQUEST}
|
||||
self.master_mock.attach_mock(self.executor_mock, 'executor')
|
||||
self.master_mock.attach_mock(self.message_mock, 'message')
|
||||
|
||||
@ -70,28 +72,15 @@ class TestServer(test.MockTestCase):
|
||||
self._reset_master_mock()
|
||||
return s
|
||||
|
||||
def request(self, **kwargs):
|
||||
request = dict(task_cls=self.task_name,
|
||||
task_name=self.task_name,
|
||||
def make_request(self, **kwargs):
|
||||
request_kwargs = dict(task=self.task,
|
||||
uuid=self.task_uuid,
|
||||
action=self.task_action,
|
||||
task_version=self.task_version,
|
||||
arguments=self.task_args)
|
||||
request.update(kwargs)
|
||||
return request
|
||||
|
||||
@staticmethod
|
||||
def resp_progress(progress):
|
||||
return dict(state=pr.PROGRESS, progress=progress, event_data={})
|
||||
|
||||
@staticmethod
|
||||
def resp_success(result):
|
||||
return dict(state=pr.SUCCESS, result=result)
|
||||
|
||||
@staticmethod
|
||||
def resp_failure(result, **kwargs):
|
||||
response = dict(state=pr.FAILURE, result=result)
|
||||
response.update(kwargs)
|
||||
return response
|
||||
arguments=self.task_args,
|
||||
progress_callback=None,
|
||||
timeout=60)
|
||||
request_kwargs.update(kwargs)
|
||||
return pr.Request(**request_kwargs).to_dict()
|
||||
|
||||
def test_creation(self):
|
||||
s = self.server()
|
||||
@ -116,7 +105,7 @@ class TestServer(test.MockTestCase):
|
||||
self.assertEqual(len(s._endpoints), len(self.endpoints))
|
||||
|
||||
def test_on_message_proxy_running_ack_success(self):
|
||||
request = self.request()
|
||||
request = self.make_request()
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._on_message(request, self.message_mock)
|
||||
|
||||
@ -162,99 +151,115 @@ class TestServer(test.MockTestCase):
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.server.LOG.warning')
|
||||
def test_on_message_unknown_type(self, mocked_warning):
|
||||
self.message_mock.properties['type'] = '<unknown>'
|
||||
s = self.server()
|
||||
s._on_message({}, self.message_mock)
|
||||
self.assertTrue(mocked_warning.called)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.server.LOG.warning')
|
||||
def test_on_message_no_type(self, mocked_warning):
|
||||
self.message_mock.properties = {}
|
||||
s = self.server()
|
||||
s._on_message({}, self.message_mock)
|
||||
self.assertTrue(mocked_warning.called)
|
||||
|
||||
def test_parse_request(self):
|
||||
request = self.request()
|
||||
request = self.make_request()
|
||||
task_cls, action, task_args = server.Server._parse_request(**request)
|
||||
|
||||
self.assertEqual((task_cls, action, task_args),
|
||||
(self.task_name, self.task_action,
|
||||
dict(task_name=self.task_name,
|
||||
(self.task.name, self.task_action,
|
||||
dict(task_name=self.task.name,
|
||||
arguments=self.task_args)))
|
||||
|
||||
def test_parse_request_with_success_result(self):
|
||||
request = self.request(action='revert', result=('success', 1))
|
||||
request = self.make_request(action='revert', result=1)
|
||||
task_cls, action, task_args = server.Server._parse_request(**request)
|
||||
|
||||
self.assertEqual((task_cls, action, task_args),
|
||||
(self.task_name, 'revert',
|
||||
dict(task_name=self.task_name,
|
||||
(self.task.name, 'revert',
|
||||
dict(task_name=self.task.name,
|
||||
arguments=self.task_args,
|
||||
result=1)))
|
||||
|
||||
def test_parse_request_with_failure_result(self):
|
||||
failure = misc.Failure.from_exception(Exception('test'))
|
||||
failure_dict = pu.failure_to_dict(failure)
|
||||
request = self.request(action='revert',
|
||||
result=('failure', failure_dict))
|
||||
request = self.make_request(action='revert', result=failure)
|
||||
task_cls, action, task_args = server.Server._parse_request(**request)
|
||||
|
||||
self.assertEqual((task_cls, action, task_args),
|
||||
(self.task_name, 'revert',
|
||||
dict(task_name=self.task_name,
|
||||
(self.task.name, 'revert',
|
||||
dict(task_name=self.task.name,
|
||||
arguments=self.task_args,
|
||||
result=utils.FailureMatcher(failure))))
|
||||
|
||||
def test_parse_request_with_failures(self):
|
||||
failures = [misc.Failure.from_exception(Exception('test1')),
|
||||
misc.Failure.from_exception(Exception('test2'))]
|
||||
failures_dict = dict((str(i), pu.failure_to_dict(f))
|
||||
for i, f in enumerate(failures))
|
||||
request = self.request(action='revert', failures=failures_dict)
|
||||
failures = {'0': misc.Failure.from_exception(Exception('test1')),
|
||||
'1': misc.Failure.from_exception(Exception('test2'))}
|
||||
request = self.make_request(action='revert', failures=failures)
|
||||
task_cls, action, task_args = server.Server._parse_request(**request)
|
||||
|
||||
self.assertEqual(
|
||||
(task_cls, action, task_args),
|
||||
(self.task_name, 'revert',
|
||||
dict(task_name=self.task_name,
|
||||
(self.task.name, 'revert',
|
||||
dict(task_name=self.task.name,
|
||||
arguments=self.task_args,
|
||||
failures=dict((str(i), utils.FailureMatcher(f))
|
||||
for i, f in enumerate(failures)))))
|
||||
failures=dict((i, utils.FailureMatcher(f))
|
||||
for i, f in six.iteritems(failures)))))
|
||||
|
||||
@mock.patch("taskflow.engines.worker_based.server.LOG.exception")
|
||||
def test_reply_publish_failure(self, mocked_exception):
|
||||
self.proxy_inst_mock.publish.side_effect = RuntimeError('Woot!')
|
||||
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._reply(self.reply_to, self.task_uuid)
|
||||
|
||||
self.assertEqual(self.master_mock.mock_calls, [
|
||||
mock.call.proxy.publish({'state': 'FAILURE'}, self.reply_to,
|
||||
mock.call.Response(pr.FAILURE),
|
||||
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
|
||||
correlation_id=self.task_uuid)
|
||||
])
|
||||
self.assertTrue(mocked_exception.called)
|
||||
|
||||
def test_on_update_progress(self):
|
||||
request = self.request(task_cls='taskflow.tests.utils.ProgressingTask',
|
||||
arguments={})
|
||||
request = self.make_request(task=utils.ProgressingTask(), arguments={})
|
||||
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._process_request(request, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.publish(self.resp_running, self.reply_to,
|
||||
mock.call.Response(pr.RUNNING),
|
||||
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
|
||||
correlation_id=self.task_uuid),
|
||||
mock.call.proxy.publish(self.resp_progress(0.0), self.reply_to,
|
||||
mock.call.Response(pr.PROGRESS, progress=0.0, event_data={}),
|
||||
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
|
||||
correlation_id=self.task_uuid),
|
||||
mock.call.proxy.publish(self.resp_progress(1.0), self.reply_to,
|
||||
mock.call.Response(pr.PROGRESS, progress=1.0, event_data={}),
|
||||
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
|
||||
correlation_id=self.task_uuid),
|
||||
mock.call.proxy.publish(self.resp_success(5), self.reply_to,
|
||||
mock.call.Response(pr.SUCCESS, result=5),
|
||||
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
|
||||
correlation_id=self.task_uuid)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_process_request(self):
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s._process_request(self.request(), self.message_mock)
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._process_request(self.make_request(), self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.publish(self.resp_running, self.reply_to,
|
||||
mock.call.Response(pr.RUNNING),
|
||||
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
|
||||
correlation_id=self.task_uuid),
|
||||
mock.call.proxy.publish(self.resp_success(1), self.reply_to,
|
||||
mock.call.Response(pr.SUCCESS, result=1),
|
||||
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
|
||||
correlation_id=self.task_uuid)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
@ -262,7 +267,7 @@ class TestServer(test.MockTestCase):
|
||||
@mock.patch("taskflow.engines.worker_based.server.LOG.exception")
|
||||
def test_process_request_parse_message_failure(self, mocked_exception):
|
||||
self.message_mock.properties = {}
|
||||
request = self.request()
|
||||
request = self.make_request()
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._process_request(request, self.message_mock)
|
||||
|
||||
@ -270,19 +275,21 @@ class TestServer(test.MockTestCase):
|
||||
self.assertTrue(mocked_exception.called)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.server.pu')
|
||||
def test_process_request_parse_failure(self, pu_mock):
|
||||
def test_process_request_parse_request_failure(self, pu_mock):
|
||||
failure_dict = 'failure_dict'
|
||||
failure = misc.Failure.from_exception(RuntimeError('Woot!'))
|
||||
pu_mock.failure_to_dict.return_value = failure_dict
|
||||
pu_mock.failure_from_dict.side_effect = ValueError('Woot!')
|
||||
request = self.request(result=('failure', 1))
|
||||
request = self.make_request(result=failure)
|
||||
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._process_request(request, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.publish(self.resp_failure(failure_dict),
|
||||
mock.call.Response(pr.FAILURE, result=failure_dict),
|
||||
mock.call.proxy.publish(self.response_inst_mock,
|
||||
self.reply_to,
|
||||
correlation_id=self.task_uuid)
|
||||
]
|
||||
@ -292,15 +299,16 @@ class TestServer(test.MockTestCase):
|
||||
def test_process_request_endpoint_not_found(self, pu_mock):
|
||||
failure_dict = 'failure_dict'
|
||||
pu_mock.failure_to_dict.return_value = failure_dict
|
||||
request = self.request(task_cls='<unknown>')
|
||||
request = self.make_request(task=mock.MagicMock(name='<unknown>'))
|
||||
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._process_request(request, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.publish(self.resp_failure(failure_dict),
|
||||
mock.call.Response(pr.FAILURE, result=failure_dict),
|
||||
mock.call.proxy.publish(self.response_inst_mock,
|
||||
self.reply_to,
|
||||
correlation_id=self.task_uuid)
|
||||
]
|
||||
@ -310,17 +318,20 @@ class TestServer(test.MockTestCase):
|
||||
def test_process_request_execution_failure(self, pu_mock):
|
||||
failure_dict = 'failure_dict'
|
||||
pu_mock.failure_to_dict.return_value = failure_dict
|
||||
request = self.request(action='<unknown>')
|
||||
request = self.make_request()
|
||||
request['action'] = '<unknown>'
|
||||
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._process_request(request, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.publish(self.resp_running, self.reply_to,
|
||||
mock.call.Response(pr.RUNNING),
|
||||
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
|
||||
correlation_id=self.task_uuid),
|
||||
mock.call.proxy.publish(self.resp_failure(failure_dict),
|
||||
mock.call.Response(pr.FAILURE, result=failure_dict),
|
||||
mock.call.proxy.publish(self.response_inst_mock,
|
||||
self.reply_to,
|
||||
correlation_id=self.task_uuid)
|
||||
]
|
||||
@ -330,18 +341,19 @@ class TestServer(test.MockTestCase):
|
||||
def test_process_request_task_failure(self, pu_mock):
|
||||
failure_dict = 'failure_dict'
|
||||
pu_mock.failure_to_dict.return_value = failure_dict
|
||||
request = self.request(task='taskflow.tests.utils.TaskWithFailure',
|
||||
arguments={})
|
||||
request = self.make_request(task=utils.TaskWithFailure(), arguments={})
|
||||
|
||||
# create server and process request
|
||||
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
|
||||
s = self.server(reset_master_mock=True)
|
||||
s._process_request(request, self.message_mock)
|
||||
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.proxy.publish(self.resp_running, self.reply_to,
|
||||
mock.call.Response(pr.RUNNING),
|
||||
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
|
||||
correlation_id=self.task_uuid),
|
||||
mock.call.proxy.publish(self.resp_failure(failure_dict),
|
||||
mock.call.Response(pr.FAILURE, result=failure_dict),
|
||||
mock.call.proxy.publish(self.response_inst_mock,
|
||||
self.reply_to,
|
||||
correlation_id=self.task_uuid)
|
||||
]
|
||||
|
@ -77,6 +77,9 @@ class TestWorker(test.MockTestCase):
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_creation_with_negative_threads_count(self):
|
||||
self.assertRaises(ValueError, self.worker, threads_count=-10)
|
||||
|
||||
def test_creation_with_custom_executor(self):
|
||||
executor_mock = mock.MagicMock(name='executor')
|
||||
self.worker(executor=executor_mock)
|
||||
@ -104,6 +107,16 @@ class TestWorker(test.MockTestCase):
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_run_with_custom_executor(self):
|
||||
executor_mock = mock.MagicMock(name='executor')
|
||||
self.worker(reset_master_mock=True,
|
||||
executor=executor_mock).run()
|
||||
|
||||
master_mock_calls = [
|
||||
mock.call.server.start()
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
def test_wait(self):
|
||||
w = self.worker(reset_master_mock=True)
|
||||
w.run()
|
||||
|
Loading…
Reference in New Issue
Block a user