Merge "Use explicit WBE request state transitions"
This commit is contained in:
@@ -23,6 +23,7 @@ from taskflow.engines.worker_based import cache
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import proxy
|
||||
from taskflow import exceptions as exc
|
||||
from taskflow.openstack.common import timeutils
|
||||
from taskflow.types import timing as tt
|
||||
from taskflow.utils import async_utils
|
||||
from taskflow.utils import misc
|
||||
@@ -109,8 +110,8 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
|
||||
# publish waiting requests
|
||||
for request in self._requests_cache.get_waiting_requests(tasks):
|
||||
request.set_pending()
|
||||
self._publish_request(request, topic)
|
||||
if request.transition_log_error(pr.PENDING, logger=LOG):
|
||||
self._publish_request(request, topic)
|
||||
|
||||
def _process_response(self, response, message):
|
||||
"""Process response from remote side."""
|
||||
@@ -120,20 +121,23 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
except KeyError:
|
||||
LOG.warning("The 'correlation_id' message property is missing.")
|
||||
else:
|
||||
LOG.debug("Task uuid: '%s'", task_uuid)
|
||||
request = self._requests_cache.get(task_uuid)
|
||||
if request is not None:
|
||||
response = pr.Response.from_dict(response)
|
||||
if response.state == pr.RUNNING:
|
||||
request.set_running()
|
||||
request.transition_log_error(pr.RUNNING, logger=LOG)
|
||||
elif response.state == pr.PROGRESS:
|
||||
request.on_progress(**response.data)
|
||||
elif response.state in (pr.FAILURE, pr.SUCCESS):
|
||||
# NOTE(imelnikov): request should not be in cache when
|
||||
# another thread can see its result and schedule another
|
||||
# request with same uuid; so we remove it, then set result
|
||||
del self._requests_cache[request.uuid]
|
||||
request.set_result(**response.data)
|
||||
moved = request.transition_log_error(response.state,
|
||||
logger=LOG)
|
||||
if moved:
|
||||
# NOTE(imelnikov): request should not be in the
|
||||
# cache when another thread can see its result and
|
||||
# schedule another request with the same uuid; so
|
||||
# we remove it, then set the result...
|
||||
del self._requests_cache[request.uuid]
|
||||
request.set_result(**response.data)
|
||||
else:
|
||||
LOG.warning("Unexpected response status: '%s'",
|
||||
response.state)
|
||||
@@ -147,10 +151,21 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
When request has expired it is removed from the requests cache and
|
||||
the `RequestTimeout` exception is set as a request result.
|
||||
"""
|
||||
LOG.debug("Request '%r' has expired.", request)
|
||||
LOG.debug("The '%r' request has expired.", request)
|
||||
request.set_result(misc.Failure.from_exception(
|
||||
exc.RequestTimeout("The '%r' request has expired" % request)))
|
||||
if request.transition_log_error(pr.FAILURE, logger=LOG):
|
||||
# Raise an exception (and then catch it) so we get a nice
|
||||
# traceback that the request will get instead of it getting
|
||||
# just an exception with no traceback...
|
||||
try:
|
||||
request_age = timeutils.delta_seconds(request.created_on,
|
||||
timeutils.utcnow())
|
||||
raise exc.RequestTimeout(
|
||||
"Request '%s' has expired after waiting for %0.2f"
|
||||
" seconds for it to transition out of (%s) states"
|
||||
% (request, request_age, ", ".join(pr.WAITING_STATES)))
|
||||
except exc.RequestTimeout:
|
||||
with misc.capture_failure() as fail:
|
||||
LOG.debug(fail.exception_str)
|
||||
request.set_result(fail)
|
||||
|
||||
def _on_wait(self):
|
||||
"""This function is called cyclically between draining events."""
|
||||
@@ -169,9 +184,9 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
# before putting it into the requests cache to prevent the notify
|
||||
# processing thread get list of waiting requests and publish it
|
||||
# before it is published here, so it wouldn't be published twice.
|
||||
request.set_pending()
|
||||
self._requests_cache[request.uuid] = request
|
||||
self._publish_request(request, topic)
|
||||
if request.transition_log_error(pr.PENDING, logger=LOG):
|
||||
self._requests_cache[request.uuid] = request
|
||||
self._publish_request(request, topic)
|
||||
else:
|
||||
self._requests_cache[request.uuid] = request
|
||||
|
||||
@@ -187,8 +202,9 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
except Exception:
|
||||
with misc.capture_failure() as failure:
|
||||
LOG.exception("Failed to submit the '%s' request.", request)
|
||||
del self._requests_cache[request.uuid]
|
||||
request.set_result(failure)
|
||||
if request.transition_log_error(pr.FAILURE, logger=LOG):
|
||||
del self._requests_cache[request.uuid]
|
||||
request.set_result(failure)
|
||||
|
||||
def _notify_topics(self):
|
||||
"""Cyclically called to publish notify message to each topic."""
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
# under the License.
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from concurrent import futures
|
||||
import jsonschema
|
||||
@@ -23,7 +25,9 @@ import six
|
||||
|
||||
from taskflow.engines.action_engine import executor
|
||||
from taskflow import exceptions as excp
|
||||
from taskflow.openstack.common import timeutils
|
||||
from taskflow.types import timing as tt
|
||||
from taskflow.utils import lock_utils
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import reflection
|
||||
|
||||
@@ -36,7 +40,34 @@ SUCCESS = 'SUCCESS'
|
||||
FAILURE = 'FAILURE'
|
||||
PROGRESS = 'PROGRESS'
|
||||
|
||||
# During these states the expiry is active (once out of these states the expiry
|
||||
# no longer matters, since we have no way of knowing how long a task will run
|
||||
# for).
|
||||
WAITING_STATES = (WAITING, PENDING)
|
||||
|
||||
_ALL_STATES = (WAITING, PENDING, RUNNING, SUCCESS, FAILURE, PROGRESS)
|
||||
_STOP_TIMER_STATES = (RUNNING, SUCCESS, FAILURE)
|
||||
|
||||
# Transitions that a request state can go through.
|
||||
_ALLOWED_TRANSITIONS = (
|
||||
# Used when a executor starts to publish a request to a selected worker.
|
||||
(WAITING, PENDING),
|
||||
# When a request expires (isn't able to be processed by any worker).
|
||||
(WAITING, FAILURE),
|
||||
# Worker has started executing a request.
|
||||
(PENDING, RUNNING),
|
||||
# Worker failed to construct/process a request to run (either the worker
|
||||
# did not transition to RUNNING in the given timeout or the worker itself
|
||||
# had some type of failure before RUNNING started).
|
||||
#
|
||||
# Also used by the executor if the request was attempted to be published
|
||||
# but that did publishing process did not work out.
|
||||
(PENDING, FAILURE),
|
||||
# Execution failed due to some type of remote failure.
|
||||
(RUNNING, FAILURE),
|
||||
# Execution succeeded & has completed.
|
||||
(RUNNING, SUCCESS),
|
||||
)
|
||||
|
||||
# Remote task actions.
|
||||
EXECUTE = 'execute'
|
||||
@@ -73,6 +104,8 @@ _SCHEMA_TYPES = {
|
||||
'array': (list, tuple),
|
||||
}
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Message(object):
|
||||
@@ -143,8 +176,10 @@ class Request(Message):
|
||||
"""Represents request with execution results.
|
||||
|
||||
Every request is created in the WAITING state and is expired within the
|
||||
given timeout.
|
||||
given timeout if it does not transition out of the (WAITING, PENDING)
|
||||
states.
|
||||
"""
|
||||
|
||||
TYPE = REQUEST
|
||||
_SCHEMA = {
|
||||
"type": "object",
|
||||
@@ -196,11 +231,10 @@ class Request(Message):
|
||||
self._kwargs = kwargs
|
||||
self._watch = tt.StopWatch(duration=timeout).start()
|
||||
self._state = WAITING
|
||||
self._lock = threading.Lock()
|
||||
self._created_on = timeutils.utcnow()
|
||||
self.result = futures.Future()
|
||||
|
||||
def __repr__(self):
|
||||
return "%s:%s" % (self._task_cls, self._action)
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
return self._uuid
|
||||
@@ -213,6 +247,10 @@ class Request(Message):
|
||||
def state(self):
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def created_on(self):
|
||||
return self._created_on
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
"""Check if request has expired.
|
||||
@@ -224,7 +262,7 @@ class Request(Message):
|
||||
state for more then the given timeout (it is not considered to be
|
||||
expired in any other state).
|
||||
"""
|
||||
if self._state in (WAITING, PENDING):
|
||||
if self._state in WAITING_STATES:
|
||||
return self._watch.expired()
|
||||
return False
|
||||
|
||||
@@ -254,16 +292,43 @@ class Request(Message):
|
||||
def set_result(self, result):
|
||||
self.result.set_result((self._task, self._event, result))
|
||||
|
||||
def set_pending(self):
|
||||
self._state = PENDING
|
||||
|
||||
def set_running(self):
|
||||
self._state = RUNNING
|
||||
self._watch.stop()
|
||||
|
||||
def on_progress(self, event_data, progress):
|
||||
self._progress_callback(self._task, event_data, progress)
|
||||
|
||||
def transition_log_error(self, new_state, logger=None):
|
||||
if logger is None:
|
||||
logger = LOG
|
||||
moved = False
|
||||
try:
|
||||
moved = self.transition(new_state)
|
||||
except excp.InvalidState:
|
||||
logger.warn("Failed to transition '%s' to %s state.", self,
|
||||
new_state, exc_info=True)
|
||||
return moved
|
||||
|
||||
@lock_utils.locked
|
||||
def transition(self, new_state):
|
||||
"""Transitions the request to a new state.
|
||||
|
||||
If transition was performed, it returns True. If transition
|
||||
should was ignored, it returns False. If transition is not
|
||||
valid (and will not be performed), it raises an InvalidState
|
||||
exception.
|
||||
"""
|
||||
old_state = self._state
|
||||
if old_state == new_state:
|
||||
return False
|
||||
pair = (old_state, new_state)
|
||||
if pair not in _ALLOWED_TRANSITIONS:
|
||||
raise excp.InvalidState("Request transition from %s to %s is"
|
||||
" not allowed" % pair)
|
||||
if new_state in _STOP_TIMER_STATES:
|
||||
self._watch.stop()
|
||||
self._state = new_state
|
||||
LOG.debug("Transitioned '%s' from %s state to %s state", self,
|
||||
old_state, new_state)
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def validate(cls, data):
|
||||
try:
|
||||
@@ -292,6 +357,9 @@ class Response(Message):
|
||||
{
|
||||
"$ref": "#/definitions/completion",
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/empty",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
@@ -311,6 +379,12 @@ class Response(Message):
|
||||
"required": ["progress", 'event_data'],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
# Used when sending *only* request state changes (and no data is
|
||||
# expected).
|
||||
"empty": {
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"completion": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -22,6 +22,7 @@ import mock
|
||||
|
||||
from taskflow.engines.worker_based import executor
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.openstack.common import timeutils
|
||||
from taskflow import test
|
||||
from taskflow.tests import utils
|
||||
from taskflow.utils import misc
|
||||
@@ -95,8 +96,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
[mock.call.set_running()])
|
||||
expected_calls = [
|
||||
mock.call.transition_log_error(pr.RUNNING, logger=mock.ANY),
|
||||
]
|
||||
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
|
||||
|
||||
def test_on_message_response_state_progress(self):
|
||||
response = pr.Response(pr.PROGRESS, progress=1.0)
|
||||
@@ -116,9 +119,11 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(len(ex._requests_cache), 0)
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [
|
||||
expected_calls = [
|
||||
mock.call.transition_log_error(pr.FAILURE, logger=mock.ANY),
|
||||
mock.call.set_result(result=utils.FailureMatcher(failure))
|
||||
])
|
||||
]
|
||||
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
|
||||
|
||||
def test_on_message_response_state_success(self):
|
||||
response = pr.Response(pr.SUCCESS, result=self.task_result,
|
||||
@@ -127,9 +132,11 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
[mock.call.set_result(result=self.task_result,
|
||||
event='executed')])
|
||||
expected_calls = [
|
||||
mock.call.transition_log_error(pr.SUCCESS, logger=mock.ANY),
|
||||
mock.call.set_result(result=self.task_result, event='executed')
|
||||
]
|
||||
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
|
||||
|
||||
def test_on_message_response_unknown_state(self):
|
||||
response = pr.Response(state='<unknown>')
|
||||
@@ -166,7 +173,13 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.assertEqual(len(ex._requests_cache), 1)
|
||||
|
||||
def test_on_wait_task_expired(self):
|
||||
now = timeutils.utcnow()
|
||||
self.request_inst_mock.expired = True
|
||||
self.request_inst_mock.created_on = now
|
||||
timeutils.set_time_override(now)
|
||||
self.addCleanup(timeutils.clear_time_override)
|
||||
timeutils.advance_time_seconds(120)
|
||||
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
|
||||
@@ -199,13 +212,14 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
expected_calls = [
|
||||
mock.call.Request(self.task, self.task_uuid, 'execute',
|
||||
self.task_args, None, self.timeout),
|
||||
mock.call.request.set_pending(),
|
||||
mock.call.request.transition_log_error(pr.PENDING,
|
||||
logger=mock.ANY),
|
||||
mock.call.proxy.publish(msg=self.request_inst_mock,
|
||||
routing_key=self.executor_topic,
|
||||
reply_to=self.executor_uuid,
|
||||
correlation_id=self.task_uuid)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
self.assertEqual(expected_calls, self.master_mock.mock_calls)
|
||||
|
||||
def test_revert_task(self):
|
||||
self.message_mock.properties['type'] = pr.NOTIFY
|
||||
@@ -220,13 +234,14 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.task_args, None, self.timeout,
|
||||
failures=self.task_failures,
|
||||
result=self.task_result),
|
||||
mock.call.request.set_pending(),
|
||||
mock.call.request.transition_log_error(pr.PENDING,
|
||||
logger=mock.ANY),
|
||||
mock.call.proxy.publish(msg=self.request_inst_mock,
|
||||
routing_key=self.executor_topic,
|
||||
reply_to=self.executor_uuid,
|
||||
correlation_id=self.task_uuid)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
self.assertEqual(expected_calls, self.master_mock.mock_calls)
|
||||
|
||||
def test_execute_task_topic_not_found(self):
|
||||
workers_info = {self.executor_topic: ['<unknown>']}
|
||||
@@ -250,14 +265,17 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
expected_calls = [
|
||||
mock.call.Request(self.task, self.task_uuid, 'execute',
|
||||
self.task_args, None, self.timeout),
|
||||
mock.call.request.set_pending(),
|
||||
mock.call.request.transition_log_error(pr.PENDING,
|
||||
logger=mock.ANY),
|
||||
mock.call.proxy.publish(msg=self.request_inst_mock,
|
||||
routing_key=self.executor_topic,
|
||||
reply_to=self.executor_uuid,
|
||||
correlation_id=self.task_uuid),
|
||||
mock.call.request.transition_log_error(pr.FAILURE,
|
||||
logger=mock.ANY),
|
||||
mock.call.request.set_result(mock.ANY)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
self.assertEqual(expected_calls, self.master_mock.mock_calls)
|
||||
|
||||
def test_wait_for_any(self):
|
||||
fs = [futures.Future(), futures.Future()]
|
||||
|
||||
@@ -115,6 +115,18 @@ class TestProtocol(test.TestCase):
|
||||
to_dict.update(kwargs)
|
||||
return to_dict
|
||||
|
||||
def test_request_transitions(self):
|
||||
request = self.request()
|
||||
self.assertEqual(pr.WAITING, request.state)
|
||||
self.assertIn(request.state, pr.WAITING_STATES)
|
||||
self.assertRaises(excp.InvalidState, request.transition, pr.SUCCESS)
|
||||
self.assertFalse(request.transition(pr.WAITING))
|
||||
self.assertTrue(request.transition(pr.PENDING))
|
||||
self.assertTrue(request.transition(pr.RUNNING))
|
||||
self.assertTrue(request.transition(pr.SUCCESS))
|
||||
for s in (pr.PENDING, pr.WAITING):
|
||||
self.assertRaises(excp.InvalidState, request.transition, s)
|
||||
|
||||
def test_creation(self):
|
||||
request = self.request()
|
||||
self.assertEqual(request.uuid, self.task_uuid)
|
||||
@@ -122,15 +134,6 @@ class TestProtocol(test.TestCase):
|
||||
self.assertIsInstance(request.result, futures.Future)
|
||||
self.assertFalse(request.result.done())
|
||||
|
||||
def test_str(self):
|
||||
request = self.request()
|
||||
self.assertEqual(str(request),
|
||||
"<REQUEST> %s" % self.request_to_dict())
|
||||
|
||||
def test_repr(self):
|
||||
expected = '%s:%s' % (self.task.name, self.task_action)
|
||||
self.assertEqual(repr(self.request()), expected)
|
||||
|
||||
def test_to_dict_default(self):
|
||||
self.assertEqual(self.request().to_dict(), self.request_to_dict())
|
||||
|
||||
@@ -156,19 +159,20 @@ class TestProtocol(test.TestCase):
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
|
||||
def test_pending_not_expired(self, mocked_wallclock):
|
||||
mocked_wallclock.side_effect = [1, self.timeout]
|
||||
mocked_wallclock.side_effect = [0, self.timeout - 1]
|
||||
self.assertFalse(self.request().expired)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
|
||||
def test_pending_expired(self, mocked_wallclock):
|
||||
mocked_wallclock.side_effect = [1, self.timeout + 2]
|
||||
mocked_wallclock.side_effect = [0, self.timeout + 2]
|
||||
self.assertTrue(self.request().expired)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
|
||||
def test_running_not_expired(self, mocked_wallclock):
|
||||
mocked_wallclock.side_effect = [1, self.timeout + 2]
|
||||
mocked_wallclock.side_effect = [0, self.timeout + 2]
|
||||
request = self.request()
|
||||
request.set_running()
|
||||
request.transition(pr.PENDING)
|
||||
request.transition(pr.RUNNING)
|
||||
self.assertFalse(request.expired)
|
||||
|
||||
def test_set_result(self):
|
||||
|
||||
Reference in New Issue
Block a user