Merge "Introduce message types for WBE protocol"

This commit is contained in:
Jenkins 2014-03-17 08:29:02 +00:00 committed by Gerrit Code Review
commit 4aa9a4c0df
9 changed files with 312 additions and 220 deletions

View File

@ -28,7 +28,6 @@ from taskflow.engines.worker_based import proxy
from taskflow import exceptions as exc
from taskflow.utils import async_utils
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu
LOG = logging.getLogger(__name__)
@ -58,45 +57,50 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
proxy_thread.daemon = True
return proxy_thread
def _on_message(self, response, message):
"""This method is called on incoming response."""
LOG.debug("Got response: %s", response)
def _on_message(self, data, message):
"""This method is called on incoming message."""
LOG.debug("Got message: %s", data)
try:
# acknowledge message before processing.
# acknowledge message before processing
message.ack()
except kombu_exc.MessageStateError:
LOG.exception("Failed to acknowledge AMQP message.")
else:
LOG.debug("AMQP message acknowledged.")
# get task uuid from message correlation id parameter
try:
task_uuid = message.properties['correlation_id']
msg_type = message.properties['type']
except KeyError:
LOG.warning("Got message with no 'correlation_id' property.")
LOG.warning("The 'type' message property is missing.")
else:
LOG.debug("Task uuid: '%s'", task_uuid)
self._process_response(task_uuid, response)
if msg_type == pr.RESPONSE:
self._process_response(data, message)
else:
LOG.warning("Unexpected message type: %s", msg_type)
def _process_response(self, task_uuid, response):
def _process_response(self, response, message):
"""Process response from remote side."""
request = self._requests_cache.get(task_uuid)
if request is not None:
state = response.pop('state')
if state == pr.RUNNING:
request.set_running()
elif state == pr.PROGRESS:
request.on_progress(**response)
elif state == pr.FAILURE:
response['result'] = pu.failure_from_dict(response['result'])
request.set_result(**response)
self._requests_cache.delete(request.uuid)
elif state == pr.SUCCESS:
request.set_result(**response)
self._requests_cache.delete(request.uuid)
else:
LOG.warning("Unexpected response status: '%s'", state)
LOG.debug("Start processing response message.")
try:
task_uuid = message.properties['correlation_id']
except KeyError:
LOG.warning("The 'correlation_id' message property is missing.")
else:
LOG.debug("Request with id='%s' not found.", task_uuid)
LOG.debug("Task uuid: '%s'", task_uuid)
request = self._requests_cache.get(task_uuid)
if request is not None:
response = pr.Response.from_dict(response)
if response.state == pr.RUNNING:
request.set_running()
elif response.state == pr.PROGRESS:
request.on_progress(**response.data)
elif response.state in (pr.FAILURE, pr.SUCCESS):
request.set_result(**response.data)
self._requests_cache.delete(request.uuid)
else:
LOG.warning("Unexpected response status: '%s'",
response.state)
else:
LOG.debug("Request with id='%s' not found.", task_uuid)
@staticmethod
def _handle_expired_request(request):
@ -129,7 +133,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
else:
# publish request
LOG.debug("Sending request: %s", request)
self._proxy.publish(request.to_dict(),
self._proxy.publish(request,
routing_key=topic,
reply_to=self._uuid,
correlation_id=request.uuid)

View File

@ -14,6 +14,10 @@
# License for the specific language governing permissions and limitations
# under the License.
import abc
import six
from concurrent import futures
from taskflow.engines.action_engine import executor
@ -49,11 +53,28 @@ REQUEST_TIMEOUT = 60
# no longer needed.
QUEUE_EXPIRE_TIMEOUT = REQUEST_TIMEOUT
# Message types.
REQUEST = 'REQUEST'
RESPONSE = 'RESPONSE'
class Request(object):
@six.add_metaclass(abc.ABCMeta)
class Message(object):
"""Base class for all message types."""
def __str__(self):
return str(self.to_dict())
@abc.abstractmethod
def to_dict(self):
"""Return json-serializable message representation."""
class Request(Message):
"""Represents request with execution results. Every request is created in
the PENDING state and is expired within the given timeout.
"""
TYPE = REQUEST
def __init__(self, task, uuid, action, arguments, progress_callback,
timeout, **kwargs):
@ -111,7 +132,7 @@ class Request(object):
if 'failures' in self._kwargs:
failures = self._kwargs['failures']
request['failures'] = {}
for task, failure in failures.items():
for task, failure in six.iteritems(failures):
request['failures'][task] = pu.failure_to_dict(failure)
return request
@ -124,3 +145,31 @@ class Request(object):
def on_progress(self, event_data, progress):
self._progress_callback(self._task, event_data, progress)
class Response(Message):
"""Represents response message type."""
TYPE = RESPONSE
def __init__(self, state, **data):
self._state = state
self._data = data
@classmethod
def from_dict(cls, data):
state = data['state']
data = data['data']
if state == FAILURE and 'result' in data:
data['result'] = pu.failure_from_dict(data['result'])
return cls(state, **data)
@property
def state(self):
return self._state
@property
def data(self):
return self._data
def to_dict(self):
return dict(state=self._state, data=self._data)

View File

@ -69,10 +69,11 @@ class Proxy(object):
"""Publish message to the named exchange with routing key."""
with kombu.producers[self._conn].acquire(block=True) as producer:
queue = self._make_queue(routing_key, self._exchange)
producer.publish(body=msg,
producer.publish(body=msg.to_dict(),
routing_key=routing_key,
exchange=self._exchange,
declare=[queue],
type=msg.TYPE,
**kwargs)
def start(self):

View File

@ -36,23 +36,32 @@ class Server(object):
self._endpoints = dict([(endpoint.name, endpoint)
for endpoint in endpoints])
def _on_message(self, request, message):
"""This method is called on incoming request."""
LOG.debug("Got request: %s", request)
# NOTE(skudriashev): Process all incoming requests only if proxy is
def _on_message(self, data, message):
"""This method is called on incoming message."""
LOG.debug("Got message: %s", data)
# NOTE(skudriashev): Process all incoming messages only if proxy is
# running, otherwise requeue them.
if self._proxy.is_running:
# NOTE(skudriashev): Process request only if message has been
# acknowledged successfully.
try:
# acknowledge message
# acknowledge message before processing
message.ack()
except kombu_exc.MessageStateError:
LOG.exception("Failed to acknowledge AMQP message.")
else:
LOG.debug("AMQP message acknowledged.")
# spawn new thread to process request
self._executor.submit(self._process_request, request, message)
try:
msg_type = message.properties['type']
except KeyError:
LOG.warning("The 'type' message property is missing.")
else:
if msg_type == pr.REQUEST:
# spawn new thread to process request
self._executor.submit(self._process_request, data,
message)
else:
LOG.warning("Unexpected message type: %s", msg_type)
else:
try:
# requeue message
@ -100,7 +109,7 @@ class Server(object):
def _reply(self, reply_to, task_uuid, state=pr.FAILURE, **kwargs):
"""Send reply to the `reply_to` queue."""
response = dict(state=state, **kwargs)
response = pr.Response(state, **kwargs)
LOG.debug("Sending reply: %s", response)
try:
self._proxy.publish(response, reply_to, correlation_id=task_uuid)
@ -115,7 +124,7 @@ class Server(object):
def _process_request(self, request, message):
"""Process request in separate thread and reply back."""
# NOTE(skudriashev): parse broker message first to get the `reply_to`
# NOTE(skudriashev): Parse broker message first to get the `reply_to`
# and the `task_uuid` parameters to have possibility to reply back.
try:
reply_to, task_uuid = self._parse_message(message)

View File

@ -49,15 +49,20 @@ class TestWorkerTaskExecutor(test.MockTestCase):
# patch classes
self.proxy_mock, self.proxy_inst_mock = self._patch_class(
executor.proxy, 'Proxy')
self.request_mock, self.request_inst_mock = self._patch_class(
executor.pr, 'Request', autospec=False)
# other mocking
self.proxy_inst_mock.start.side_effect = self._fake_proxy_start
self.proxy_inst_mock.stop.side_effect = self._fake_proxy_stop
self.request_inst_mock.uuid = self.task_uuid
self.request_inst_mock.expired = False
self.request_inst_mock.task_cls = self.task.name
self.wait_for_any_mock = self._patch(
'taskflow.engines.worker_based.executor.async_utils.wait_for_any')
self.message_mock = mock.MagicMock(name='message')
self.message_mock.properties = {'correlation_id': self.task_uuid}
self.request_mock = mock.MagicMock(uuid=self.task_uuid)
self.message_mock.properties = {'correlation_id': self.task_uuid,
'type': pr.RESPONSE}
def _fake_proxy_start(self):
self.proxy_started_event.set()
@ -78,20 +83,6 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self._reset_master_mock()
return ex
def request(self, **kwargs):
request_kwargs = dict(task=self.task, uuid=self.task_uuid,
action='execute', arguments=self.task_args,
progress_callback=None, timeout=self.timeout)
request_kwargs.update(kwargs)
return pr.Request(**request_kwargs)
def request_dict(self, **kwargs):
request = dict(task_cls=self.task.name, task_name=self.task.name,
task_version=self.task.version,
arguments=self.task_args)
request.update(kwargs)
return request
def test_creation(self):
ex = self.executor(reset_master_mock=False)
@ -101,184 +92,190 @@ class TestWorkerTaskExecutor(test.MockTestCase):
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_on_message_state_running(self):
response = dict(state=pr.RUNNING)
def test_on_message_response_state_running(self):
response = pr.Response(pr.RUNNING)
ex = self.executor()
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
ex._on_message(response.to_dict(), self.message_mock)
self.assertEqual(self.request_mock.mock_calls,
self.assertEqual(self.request_inst_mock.mock_calls,
[mock.call.set_running()])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_state_progress(self):
response = dict(state=pr.PROGRESS, progress=1.0)
def test_on_message_response_state_progress(self):
response = pr.Response(pr.PROGRESS, progress=1.0)
ex = self.executor()
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
ex._on_message(response.to_dict(), self.message_mock)
self.assertEqual(self.request_mock.mock_calls,
self.assertEqual(self.request_inst_mock.mock_calls,
[mock.call.on_progress(progress=1.0)])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_state_failure(self):
def test_on_message_response_state_failure(self):
failure = misc.Failure.from_exception(Exception('test'))
failure_dict = pu.failure_to_dict(failure)
response = dict(state=pr.FAILURE, result=failure_dict)
response = pr.Response(pr.FAILURE, result=failure_dict)
ex = self.executor()
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
ex._on_message(response.to_dict(), self.message_mock)
self.assertEqual(len(ex._requests_cache._data), 0)
self.assertEqual(self.request_mock.mock_calls, [
self.assertEqual(self.request_inst_mock.mock_calls, [
mock.call.set_result(result=utils.FailureMatcher(failure))
])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_state_success(self):
response = dict(state=pr.SUCCESS, result=self.task_result,
event='executed')
def test_on_message_response_state_success(self):
response = pr.Response(pr.SUCCESS, result=self.task_result,
event='executed')
ex = self.executor()
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
ex._on_message(response.to_dict(), self.message_mock)
self.assertEqual(self.request_mock.mock_calls,
self.assertEqual(self.request_inst_mock.mock_calls,
[mock.call.set_result(result=self.task_result,
event='executed')])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_unknown_state(self):
response = dict(state='unknown')
def test_on_message_response_unknown_state(self):
response = pr.Response(state='<unknown>')
ex = self.executor()
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
ex._on_message(response.to_dict(), self.message_mock)
self.assertEqual(self.request_mock.mock_calls, [])
self.assertEqual(self.request_inst_mock.mock_calls, [])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_non_existent_task(self):
self.message_mock.properties = {'correlation_id': 'non-existent'}
response = dict(state=pr.RUNNING)
def test_on_message_response_unknown_task(self):
self.message_mock.properties['correlation_id'] = '<unknown>'
response = pr.Response(pr.RUNNING)
ex = self.executor()
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
ex._on_message(response.to_dict(), self.message_mock)
self.assertEqual(self.request_mock.mock_calls, [])
self.assertEqual(self.request_inst_mock.mock_calls, [])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
def test_on_message_no_correlation_id(self):
self.message_mock.properties = {}
response = dict(state=pr.RUNNING)
def test_on_message_response_no_correlation_id(self):
self.message_mock.properties = {'type': pr.RESPONSE}
response = pr.Response(pr.RUNNING)
ex = self.executor()
ex._requests_cache.set(self.task_uuid, self.request_mock)
ex._on_message(response, self.message_mock)
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
ex._on_message(response.to_dict(), self.message_mock)
self.assertEqual(self.request_mock.mock_calls, [])
self.assertEqual(self.request_inst_mock.mock_calls, [])
self.assertEqual(self.message_mock.mock_calls, [mock.call.ack()])
@mock.patch('taskflow.engines.worker_based.executor.LOG.warning')
def test_on_message_unknown_type(self, mocked_warning):
self.message_mock.properties = {'correlation_id': self.task_uuid,
'type': '<unknown>'}
ex = self.executor()
ex._on_message({}, self.message_mock)
self.assertTrue(mocked_warning.called)
@mock.patch('taskflow.engines.worker_based.executor.LOG.warning')
def test_on_message_no_type(self, mocked_warning):
self.message_mock.properties = {'correlation_id': self.task_uuid}
ex = self.executor()
ex._on_message({}, self.message_mock)
self.assertTrue(mocked_warning.called)
@mock.patch('taskflow.engines.worker_based.executor.LOG.exception')
def test_on_message_acknowledge_raises(self, mocked_exception):
self.message_mock.ack.side_effect = kombu_exc.MessageStateError()
self.executor()._on_message({}, self.message_mock)
self.assertTrue(mocked_exception.called)
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
def test_on_wait_task_not_expired(self, mocked_wallclock):
mocked_wallclock.side_effect = [1, self.timeout]
def test_on_wait_task_not_expired(self):
ex = self.executor()
ex._requests_cache.set(self.task_uuid, self.request())
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
self.assertEqual(len(ex._requests_cache._data), 1)
ex._on_wait()
self.assertEqual(len(ex._requests_cache._data), 1)
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
def test_on_wait_task_expired(self, mocked_time):
mocked_time.side_effect = [1, self.timeout + 2, self.timeout * 2]
def test_on_wait_task_expired(self):
self.request_inst_mock.expired = True
ex = self.executor()
ex._requests_cache.set(self.task_uuid, self.request())
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
self.assertEqual(len(ex._requests_cache._data), 1)
ex._on_wait()
self.assertEqual(len(ex._requests_cache._data), 0)
def test_remove_task_non_existent(self):
task = self.request()
ex = self.executor()
ex._requests_cache.set(self.task_uuid, task)
ex._requests_cache.set(self.task_uuid, self.request_inst_mock)
self.assertEqual(len(ex._requests_cache._data), 1)
ex._requests_cache.delete(self.task_uuid)
self.assertEqual(len(ex._requests_cache._data), 0)
# remove non-existent
# delete non-existent
ex._requests_cache.delete(self.task_uuid)
self.assertEqual(len(ex._requests_cache._data), 0)
def test_execute_task(self):
request_dict = self.request_dict(action='execute')
ex = self.executor()
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
ex.execute_task(self.task, self.task_uuid, self.task_args)
expected_calls = [
mock.call.proxy.publish(request_dict,
mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, None, self.timeout),
mock.call.proxy.publish(self.request_inst_mock,
routing_key=self.executor_topic,
reply_to=self.executor_uuid,
correlation_id=self.task_uuid)
]
self.assertEqual(self.master_mock.mock_calls, expected_calls)
self.assertIsInstance(result, futures.Future)
def test_revert_task(self):
request_dict = self.request_dict(action='revert',
result=('success', self.task_result),
failures=self.task_failures)
ex = self.executor()
result = ex.revert_task(self.task, self.task_uuid, self.task_args,
self.task_result, self.task_failures)
ex.revert_task(self.task, self.task_uuid, self.task_args,
self.task_result, self.task_failures)
expected_calls = [
mock.call.proxy.publish(request_dict,
mock.call.Request(self.task, self.task_uuid, 'revert',
self.task_args, None, self.timeout,
failures=self.task_failures,
result=self.task_result),
mock.call.proxy.publish(self.request_inst_mock,
routing_key=self.executor_topic,
reply_to=self.executor_uuid,
correlation_id=self.task_uuid)
]
self.assertEqual(self.master_mock.mock_calls, expected_calls)
self.assertIsInstance(result, futures.Future)
def test_execute_task_topic_not_found(self):
workers_info = {self.executor_topic: ['non-existent-task']}
workers_info = {self.executor_topic: ['<unknown>']}
ex = self.executor(workers_info=workers_info)
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
ex.execute_task(self.task, self.task_uuid, self.task_args)
self.assertFalse(self.proxy_inst_mock.publish.called)
# check execute result
task, event, res = result.result()
self.assertEqual(task, self.task)
self.assertEqual(event, 'executed')
self.assertIsInstance(res, misc.Failure)
expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, None, self.timeout),
mock.call.request.set_result(mock.ANY)
]
self.assertEqual(self.master_mock.mock_calls, expected_calls)
def test_execute_task_publish_error(self):
self.proxy_inst_mock.publish.side_effect = Exception('Woot!')
request_dict = self.request_dict(action='execute')
ex = self.executor()
result = ex.execute_task(self.task, self.task_uuid, self.task_args)
ex.execute_task(self.task, self.task_uuid, self.task_args)
expected_calls = [
mock.call.proxy.publish(request_dict,
mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, None, self.timeout),
mock.call.proxy.publish(self.request_inst_mock,
routing_key=self.executor_topic,
reply_to=self.executor_uuid,
correlation_id=self.task_uuid)
correlation_id=self.task_uuid),
mock.call.request.set_result(mock.ANY)
]
self.assertEqual(self.master_mock.mock_calls, expected_calls)
# check execute result
task, event, res = result.result()
self.assertEqual(task, self.task)
self.assertEqual(event, 'executed')
self.assertIsInstance(res, misc.Failure)
def test_wait_for_any(self):
fs = [futures.Future(), futures.Future()]
ex = self.executor()

View File

@ -61,9 +61,13 @@ class TestProtocol(test.TestCase):
self.assertIsInstance(request.result, futures.Future)
self.assertFalse(request.result.done())
def test_str(self):
request = self.request()
self.assertEqual(str(request), str(request.to_dict()))
def test_repr(self):
expected_name = '%s:%s' % (self.task.name, self.task_action)
self.assertEqual(repr(self.request()), expected_name)
expected = '%s:%s' % (self.task.name, self.task_action)
self.assertEqual(repr(self.request()), expected)
def test_to_dict_default(self):
self.assertEqual(self.request().to_dict(), self.request_to_dict())

View File

@ -128,13 +128,15 @@ class TestProxy(test.MockTestCase):
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_publish(self):
task_data = 'task-data'
task_uuid = 'task-uuid'
msg_mock = mock.MagicMock()
msg_data = 'msg-data'
msg_mock.to_dict.return_value = msg_data
routing_key = 'routing-key'
task_uuid = 'task-uuid'
kwargs = dict(a='a', b='b')
self.proxy(reset_master_mock=True).publish(
task_data, routing_key, correlation_id=task_uuid, **kwargs)
msg_mock, routing_key, correlation_id=task_uuid, **kwargs)
master_mock_calls = [
mock.call.Queue(name=self._queue_name(routing_key),
@ -142,11 +144,12 @@ class TestProxy(test.MockTestCase):
routing_key=routing_key,
durable=False,
auto_delete=True),
mock.call.producer.publish(body=task_data,
mock.call.producer.publish(body=msg_data,
routing_key=routing_key,
exchange=self.exchange_inst_mock,
correlation_id=task_uuid,
declare=[self.queue_inst_mock],
type=msg_mock.TYPE,
**kwargs)
]
self.master_mock.assert_has_calls(master_mock_calls)

View File

@ -16,6 +16,8 @@
import mock
import six
from kombu import exceptions as exc
from taskflow.engines.worker_based import endpoint as ep
@ -24,7 +26,6 @@ from taskflow.engines.worker_based import server
from taskflow import test
from taskflow.tests import utils
from taskflow.utils import misc
from taskflow.utils import persistence_utils as pu
class TestServer(test.MockTestCase):
@ -34,27 +35,28 @@ class TestServer(test.MockTestCase):
self.server_topic = 'server-topic'
self.server_exchange = 'server-exchange'
self.broker_url = 'test-url'
self.task = utils.TaskOneArgOneReturn()
self.task_uuid = 'task-uuid'
self.task_args = {'x': 1}
self.task_action = 'execute'
self.task_name = 'taskflow.tests.utils.TaskOneArgOneReturn'
self.task_version = (1, 0)
self.reply_to = 'reply-to'
self.endpoints = [ep.Endpoint(task_cls=utils.TaskOneArgOneReturn),
ep.Endpoint(task_cls=utils.TaskWithFailure),
ep.Endpoint(task_cls=utils.ProgressingTask)]
self.resp_running = dict(state=pr.RUNNING)
# patch classes
self.proxy_mock, self.proxy_inst_mock = self._patch_class(
server.proxy, 'Proxy')
self.response_mock, self.response_inst_mock = self._patch_class(
server.pr, 'Response')
# other mocking
self.proxy_inst_mock.is_running = True
self.executor_mock = mock.MagicMock(name='executor')
self.message_mock = mock.MagicMock(name='message')
self.message_mock.properties = {'correlation_id': self.task_uuid,
'reply_to': self.reply_to}
'reply_to': self.reply_to,
'type': pr.REQUEST}
self.master_mock.attach_mock(self.executor_mock, 'executor')
self.master_mock.attach_mock(self.message_mock, 'message')
@ -70,28 +72,15 @@ class TestServer(test.MockTestCase):
self._reset_master_mock()
return s
def request(self, **kwargs):
request = dict(task_cls=self.task_name,
task_name=self.task_name,
action=self.task_action,
task_version=self.task_version,
arguments=self.task_args)
request.update(kwargs)
return request
@staticmethod
def resp_progress(progress):
return dict(state=pr.PROGRESS, progress=progress, event_data={})
@staticmethod
def resp_success(result):
return dict(state=pr.SUCCESS, result=result)
@staticmethod
def resp_failure(result, **kwargs):
response = dict(state=pr.FAILURE, result=result)
response.update(kwargs)
return response
def make_request(self, **kwargs):
request_kwargs = dict(task=self.task,
uuid=self.task_uuid,
action=self.task_action,
arguments=self.task_args,
progress_callback=None,
timeout=60)
request_kwargs.update(kwargs)
return pr.Request(**request_kwargs).to_dict()
def test_creation(self):
s = self.server()
@ -116,7 +105,7 @@ class TestServer(test.MockTestCase):
self.assertEqual(len(s._endpoints), len(self.endpoints))
def test_on_message_proxy_running_ack_success(self):
request = self.request()
request = self.make_request()
s = self.server(reset_master_mock=True)
s._on_message(request, self.message_mock)
@ -162,99 +151,115 @@ class TestServer(test.MockTestCase):
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
@mock.patch('taskflow.engines.worker_based.server.LOG.warning')
def test_on_message_unknown_type(self, mocked_warning):
self.message_mock.properties['type'] = '<unknown>'
s = self.server()
s._on_message({}, self.message_mock)
self.assertTrue(mocked_warning.called)
@mock.patch('taskflow.engines.worker_based.server.LOG.warning')
def test_on_message_no_type(self, mocked_warning):
self.message_mock.properties = {}
s = self.server()
s._on_message({}, self.message_mock)
self.assertTrue(mocked_warning.called)
def test_parse_request(self):
request = self.request()
request = self.make_request()
task_cls, action, task_args = server.Server._parse_request(**request)
self.assertEqual((task_cls, action, task_args),
(self.task_name, self.task_action,
dict(task_name=self.task_name,
(self.task.name, self.task_action,
dict(task_name=self.task.name,
arguments=self.task_args)))
def test_parse_request_with_success_result(self):
request = self.request(action='revert', result=('success', 1))
request = self.make_request(action='revert', result=1)
task_cls, action, task_args = server.Server._parse_request(**request)
self.assertEqual((task_cls, action, task_args),
(self.task_name, 'revert',
dict(task_name=self.task_name,
(self.task.name, 'revert',
dict(task_name=self.task.name,
arguments=self.task_args,
result=1)))
def test_parse_request_with_failure_result(self):
failure = misc.Failure.from_exception(Exception('test'))
failure_dict = pu.failure_to_dict(failure)
request = self.request(action='revert',
result=('failure', failure_dict))
request = self.make_request(action='revert', result=failure)
task_cls, action, task_args = server.Server._parse_request(**request)
self.assertEqual((task_cls, action, task_args),
(self.task_name, 'revert',
dict(task_name=self.task_name,
(self.task.name, 'revert',
dict(task_name=self.task.name,
arguments=self.task_args,
result=utils.FailureMatcher(failure))))
def test_parse_request_with_failures(self):
failures = [misc.Failure.from_exception(Exception('test1')),
misc.Failure.from_exception(Exception('test2'))]
failures_dict = dict((str(i), pu.failure_to_dict(f))
for i, f in enumerate(failures))
request = self.request(action='revert', failures=failures_dict)
failures = {'0': misc.Failure.from_exception(Exception('test1')),
'1': misc.Failure.from_exception(Exception('test2'))}
request = self.make_request(action='revert', failures=failures)
task_cls, action, task_args = server.Server._parse_request(**request)
self.assertEqual(
(task_cls, action, task_args),
(self.task_name, 'revert',
dict(task_name=self.task_name,
(self.task.name, 'revert',
dict(task_name=self.task.name,
arguments=self.task_args,
failures=dict((str(i), utils.FailureMatcher(f))
for i, f in enumerate(failures)))))
failures=dict((i, utils.FailureMatcher(f))
for i, f in six.iteritems(failures)))))
@mock.patch("taskflow.engines.worker_based.server.LOG.exception")
def test_reply_publish_failure(self, mocked_exception):
self.proxy_inst_mock.publish.side_effect = RuntimeError('Woot!')
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s = self.server(reset_master_mock=True)
s._reply(self.reply_to, self.task_uuid)
self.assertEqual(self.master_mock.mock_calls, [
mock.call.proxy.publish({'state': 'FAILURE'}, self.reply_to,
mock.call.Response(pr.FAILURE),
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
correlation_id=self.task_uuid)
])
self.assertTrue(mocked_exception.called)
def test_on_update_progress(self):
request = self.request(task_cls='taskflow.tests.utils.ProgressingTask',
arguments={})
request = self.make_request(task=utils.ProgressingTask(), arguments={})
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s = self.server(reset_master_mock=True)
s._process_request(request, self.message_mock)
# check calls
master_mock_calls = [
mock.call.proxy.publish(self.resp_running, self.reply_to,
mock.call.Response(pr.RUNNING),
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
correlation_id=self.task_uuid),
mock.call.proxy.publish(self.resp_progress(0.0), self.reply_to,
mock.call.Response(pr.PROGRESS, progress=0.0, event_data={}),
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
correlation_id=self.task_uuid),
mock.call.proxy.publish(self.resp_progress(1.0), self.reply_to,
mock.call.Response(pr.PROGRESS, progress=1.0, event_data={}),
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
correlation_id=self.task_uuid),
mock.call.proxy.publish(self.resp_success(5), self.reply_to,
mock.call.Response(pr.SUCCESS, result=5),
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
correlation_id=self.task_uuid)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_process_request(self):
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s._process_request(self.request(), self.message_mock)
s = self.server(reset_master_mock=True)
s._process_request(self.make_request(), self.message_mock)
# check calls
master_mock_calls = [
mock.call.proxy.publish(self.resp_running, self.reply_to,
mock.call.Response(pr.RUNNING),
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
correlation_id=self.task_uuid),
mock.call.proxy.publish(self.resp_success(1), self.reply_to,
mock.call.Response(pr.SUCCESS, result=1),
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
correlation_id=self.task_uuid)
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
@ -262,7 +267,7 @@ class TestServer(test.MockTestCase):
@mock.patch("taskflow.engines.worker_based.server.LOG.exception")
def test_process_request_parse_message_failure(self, mocked_exception):
self.message_mock.properties = {}
request = self.request()
request = self.make_request()
s = self.server(reset_master_mock=True)
s._process_request(request, self.message_mock)
@ -270,19 +275,21 @@ class TestServer(test.MockTestCase):
self.assertTrue(mocked_exception.called)
@mock.patch('taskflow.engines.worker_based.server.pu')
def test_process_request_parse_failure(self, pu_mock):
def test_process_request_parse_request_failure(self, pu_mock):
failure_dict = 'failure_dict'
failure = misc.Failure.from_exception(RuntimeError('Woot!'))
pu_mock.failure_to_dict.return_value = failure_dict
pu_mock.failure_from_dict.side_effect = ValueError('Woot!')
request = self.request(result=('failure', 1))
request = self.make_request(result=failure)
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s = self.server(reset_master_mock=True)
s._process_request(request, self.message_mock)
# check calls
master_mock_calls = [
mock.call.proxy.publish(self.resp_failure(failure_dict),
mock.call.Response(pr.FAILURE, result=failure_dict),
mock.call.proxy.publish(self.response_inst_mock,
self.reply_to,
correlation_id=self.task_uuid)
]
@ -292,15 +299,16 @@ class TestServer(test.MockTestCase):
def test_process_request_endpoint_not_found(self, pu_mock):
failure_dict = 'failure_dict'
pu_mock.failure_to_dict.return_value = failure_dict
request = self.request(task_cls='<unknown>')
request = self.make_request(task=mock.MagicMock(name='<unknown>'))
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s = self.server(reset_master_mock=True)
s._process_request(request, self.message_mock)
# check calls
master_mock_calls = [
mock.call.proxy.publish(self.resp_failure(failure_dict),
mock.call.Response(pr.FAILURE, result=failure_dict),
mock.call.proxy.publish(self.response_inst_mock,
self.reply_to,
correlation_id=self.task_uuid)
]
@ -310,17 +318,20 @@ class TestServer(test.MockTestCase):
def test_process_request_execution_failure(self, pu_mock):
failure_dict = 'failure_dict'
pu_mock.failure_to_dict.return_value = failure_dict
request = self.request(action='<unknown>')
request = self.make_request()
request['action'] = '<unknown>'
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s = self.server(reset_master_mock=True)
s._process_request(request, self.message_mock)
# check calls
master_mock_calls = [
mock.call.proxy.publish(self.resp_running, self.reply_to,
mock.call.Response(pr.RUNNING),
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
correlation_id=self.task_uuid),
mock.call.proxy.publish(self.resp_failure(failure_dict),
mock.call.Response(pr.FAILURE, result=failure_dict),
mock.call.proxy.publish(self.response_inst_mock,
self.reply_to,
correlation_id=self.task_uuid)
]
@ -330,18 +341,19 @@ class TestServer(test.MockTestCase):
def test_process_request_task_failure(self, pu_mock):
failure_dict = 'failure_dict'
pu_mock.failure_to_dict.return_value = failure_dict
request = self.request(task='taskflow.tests.utils.TaskWithFailure',
arguments={})
request = self.make_request(task=utils.TaskWithFailure(), arguments={})
# create server and process request
s = self.server(reset_master_mock=True, endpoints=self.endpoints)
s = self.server(reset_master_mock=True)
s._process_request(request, self.message_mock)
# check calls
master_mock_calls = [
mock.call.proxy.publish(self.resp_running, self.reply_to,
mock.call.Response(pr.RUNNING),
mock.call.proxy.publish(self.response_inst_mock, self.reply_to,
correlation_id=self.task_uuid),
mock.call.proxy.publish(self.resp_failure(failure_dict),
mock.call.Response(pr.FAILURE, result=failure_dict),
mock.call.proxy.publish(self.response_inst_mock,
self.reply_to,
correlation_id=self.task_uuid)
]

View File

@ -77,6 +77,9 @@ class TestWorker(test.MockTestCase):
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_creation_with_negative_threads_count(self):
self.assertRaises(ValueError, self.worker, threads_count=-10)
def test_creation_with_custom_executor(self):
executor_mock = mock.MagicMock(name='executor')
self.worker(executor=executor_mock)
@ -104,6 +107,16 @@ class TestWorker(test.MockTestCase):
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_run_with_custom_executor(self):
executor_mock = mock.MagicMock(name='executor')
self.worker(reset_master_mock=True,
executor=executor_mock).run()
master_mock_calls = [
mock.call.server.start()
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
def test_wait(self):
w = self.worker(reset_master_mock=True)
w.run()