Use explicit WBE request state transitions
Instead of having an implicit state machine for a request's lifecycle, move toward an explicit state model and transition set that is validated and transitioned in a manner that is easier to understand and reason about. This also fixes a bug, found thanks to this stricter transition checking, where response validation did not take a transition into account. Includes a few tiny related/affected commits: * Remove testing of request repr() and str() as these types of tests are not useful and we removed the repr() version of the request message as the base class version is good enough. * Raise and capture a better exception and save its associated failure object when a request has expired (this gives expired requests better failure objects and associated details). Fixes bug 1356658 Partially fixes bug 1357117 Change-Id: Ie1386cca13a2da7265e22447b4c111a0a0074201
This commit is contained in:
committed by
Joshua Harlow
parent
7c3332e49b
commit
e720bdb8af
@@ -23,6 +23,7 @@ from taskflow.engines.worker_based import cache
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import proxy
|
||||
from taskflow import exceptions as exc
|
||||
from taskflow.openstack.common import timeutils
|
||||
from taskflow.types import timing as tt
|
||||
from taskflow.utils import async_utils
|
||||
from taskflow.utils import misc
|
||||
@@ -109,8 +110,8 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
|
||||
# publish waiting requests
|
||||
for request in self._requests_cache.get_waiting_requests(tasks):
|
||||
request.set_pending()
|
||||
self._publish_request(request, topic)
|
||||
if request.transition_log_error(pr.PENDING, logger=LOG):
|
||||
self._publish_request(request, topic)
|
||||
|
||||
def _process_response(self, response, message):
|
||||
"""Process response from remote side."""
|
||||
@@ -120,20 +121,23 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
except KeyError:
|
||||
LOG.warning("The 'correlation_id' message property is missing.")
|
||||
else:
|
||||
LOG.debug("Task uuid: '%s'", task_uuid)
|
||||
request = self._requests_cache.get(task_uuid)
|
||||
if request is not None:
|
||||
response = pr.Response.from_dict(response)
|
||||
if response.state == pr.RUNNING:
|
||||
request.set_running()
|
||||
request.transition_log_error(pr.RUNNING, logger=LOG)
|
||||
elif response.state == pr.PROGRESS:
|
||||
request.on_progress(**response.data)
|
||||
elif response.state in (pr.FAILURE, pr.SUCCESS):
|
||||
# NOTE(imelnikov): request should not be in cache when
|
||||
# another thread can see its result and schedule another
|
||||
# request with same uuid; so we remove it, then set result
|
||||
del self._requests_cache[request.uuid]
|
||||
request.set_result(**response.data)
|
||||
moved = request.transition_log_error(response.state,
|
||||
logger=LOG)
|
||||
if moved:
|
||||
# NOTE(imelnikov): request should not be in the
|
||||
# cache when another thread can see its result and
|
||||
# schedule another request with the same uuid; so
|
||||
# we remove it, then set the result...
|
||||
del self._requests_cache[request.uuid]
|
||||
request.set_result(**response.data)
|
||||
else:
|
||||
LOG.warning("Unexpected response status: '%s'",
|
||||
response.state)
|
||||
@@ -147,10 +151,21 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
When request has expired it is removed from the requests cache and
|
||||
the `RequestTimeout` exception is set as a request result.
|
||||
"""
|
||||
LOG.debug("Request '%r' has expired.", request)
|
||||
LOG.debug("The '%r' request has expired.", request)
|
||||
request.set_result(misc.Failure.from_exception(
|
||||
exc.RequestTimeout("The '%r' request has expired" % request)))
|
||||
if request.transition_log_error(pr.FAILURE, logger=LOG):
|
||||
# Raise an exception (and then catch it) so we get a nice
|
||||
# traceback that the request will get instead of it getting
|
||||
# just an exception with no traceback...
|
||||
try:
|
||||
request_age = timeutils.delta_seconds(request.created_on,
|
||||
timeutils.utcnow())
|
||||
raise exc.RequestTimeout(
|
||||
"Request '%s' has expired after waiting for %0.2f"
|
||||
" seconds for it to transition out of (%s) states"
|
||||
% (request, request_age, ", ".join(pr.WAITING_STATES)))
|
||||
except exc.RequestTimeout:
|
||||
with misc.capture_failure() as fail:
|
||||
LOG.debug(fail.exception_str)
|
||||
request.set_result(fail)
|
||||
|
||||
def _on_wait(self):
|
||||
"""This function is called cyclically between draining events."""
|
||||
@@ -169,9 +184,9 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
# before putting it into the requests cache to prevent the notify
|
||||
# processing thread from getting the list of waiting requests and publishing it
|
||||
# before it is published here, so it wouldn't be published twice.
|
||||
request.set_pending()
|
||||
self._requests_cache[request.uuid] = request
|
||||
self._publish_request(request, topic)
|
||||
if request.transition_log_error(pr.PENDING, logger=LOG):
|
||||
self._requests_cache[request.uuid] = request
|
||||
self._publish_request(request, topic)
|
||||
else:
|
||||
self._requests_cache[request.uuid] = request
|
||||
|
||||
@@ -187,8 +202,9 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
|
||||
except Exception:
|
||||
with misc.capture_failure() as failure:
|
||||
LOG.exception("Failed to submit the '%s' request.", request)
|
||||
del self._requests_cache[request.uuid]
|
||||
request.set_result(failure)
|
||||
if request.transition_log_error(pr.FAILURE, logger=LOG):
|
||||
del self._requests_cache[request.uuid]
|
||||
request.set_result(failure)
|
||||
|
||||
def _notify_topics(self):
|
||||
"""Cyclically called to publish notify message to each topic."""
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
# under the License.
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from concurrent import futures
|
||||
import jsonschema
|
||||
@@ -23,7 +25,9 @@ import six
|
||||
|
||||
from taskflow.engines.action_engine import executor
|
||||
from taskflow import exceptions as excp
|
||||
from taskflow.openstack.common import timeutils
|
||||
from taskflow.types import timing as tt
|
||||
from taskflow.utils import lock_utils
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import reflection
|
||||
|
||||
@@ -36,7 +40,34 @@ SUCCESS = 'SUCCESS'
|
||||
FAILURE = 'FAILURE'
|
||||
PROGRESS = 'PROGRESS'
|
||||
|
||||
# During these states the expiry is active (once out of these states the expiry
|
||||
# no longer matters, since we have no way of knowing how long a task will run
|
||||
# for).
|
||||
WAITING_STATES = (WAITING, PENDING)
|
||||
|
||||
_ALL_STATES = (WAITING, PENDING, RUNNING, SUCCESS, FAILURE, PROGRESS)
|
||||
_STOP_TIMER_STATES = (RUNNING, SUCCESS, FAILURE)
|
||||
|
||||
# Transitions that a request state can go through.
|
||||
_ALLOWED_TRANSITIONS = (
|
||||
# Used when an executor starts to publish a request to a selected worker.
|
||||
(WAITING, PENDING),
|
||||
# When a request expires (isn't able to be processed by any worker).
|
||||
(WAITING, FAILURE),
|
||||
# Worker has started executing a request.
|
||||
(PENDING, RUNNING),
|
||||
# Worker failed to construct/process a request to run (either the worker
|
||||
# did not transition to RUNNING in the given timeout or the worker itself
|
||||
# had some type of failure before RUNNING started).
|
||||
#
|
||||
# Also used by the executor if the request was attempted to be published
|
||||
# but that publishing process did not work out.
|
||||
(PENDING, FAILURE),
|
||||
# Execution failed due to some type of remote failure.
|
||||
(RUNNING, FAILURE),
|
||||
# Execution succeeded & has completed.
|
||||
(RUNNING, SUCCESS),
|
||||
)
|
||||
|
||||
# Remote task actions.
|
||||
EXECUTE = 'execute'
|
||||
@@ -73,6 +104,8 @@ _SCHEMA_TYPES = {
|
||||
'array': (list, tuple),
|
||||
}
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Message(object):
|
||||
@@ -143,8 +176,10 @@ class Request(Message):
|
||||
"""Represents request with execution results.
|
||||
|
||||
Every request is created in the WAITING state and is expired within the
|
||||
given timeout.
|
||||
given timeout if it does not transition out of the (WAITING, PENDING)
|
||||
states.
|
||||
"""
|
||||
|
||||
TYPE = REQUEST
|
||||
_SCHEMA = {
|
||||
"type": "object",
|
||||
@@ -196,11 +231,10 @@ class Request(Message):
|
||||
self._kwargs = kwargs
|
||||
self._watch = tt.StopWatch(duration=timeout).start()
|
||||
self._state = WAITING
|
||||
self._lock = threading.Lock()
|
||||
self._created_on = timeutils.utcnow()
|
||||
self.result = futures.Future()
|
||||
|
||||
def __repr__(self):
|
||||
return "%s:%s" % (self._task_cls, self._action)
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
return self._uuid
|
||||
@@ -213,6 +247,10 @@ class Request(Message):
|
||||
def state(self):
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def created_on(self):
|
||||
return self._created_on
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
"""Check if request has expired.
|
||||
@@ -224,7 +262,7 @@ class Request(Message):
|
||||
state for more than the given timeout (it is not considered to be
|
||||
expired in any other state).
|
||||
"""
|
||||
if self._state in (WAITING, PENDING):
|
||||
if self._state in WAITING_STATES:
|
||||
return self._watch.expired()
|
||||
return False
|
||||
|
||||
@@ -254,16 +292,43 @@ class Request(Message):
|
||||
def set_result(self, result):
|
||||
self.result.set_result((self._task, self._event, result))
|
||||
|
||||
def set_pending(self):
|
||||
self._state = PENDING
|
||||
|
||||
def set_running(self):
|
||||
self._state = RUNNING
|
||||
self._watch.stop()
|
||||
|
||||
def on_progress(self, event_data, progress):
|
||||
self._progress_callback(self._task, event_data, progress)
|
||||
|
||||
def transition_log_error(self, new_state, logger=None):
    """Attempt a state transition, logging instead of raising on rejection.

    :param new_state: state to move this request into.
    :param logger: logger used to report a rejected transition; the
                   module-level ``LOG`` is used when this is None.
    :returns: whatever ``transition()`` returned, or False when
              ``transition()`` raised ``InvalidState``.
    """
    if logger is None:
        logger = LOG
    moved = False
    try:
        moved = self.transition(new_state)
    except excp.InvalidState:
        # ``Logger.warn`` is a deprecated alias of ``Logger.warning``, so
        # the non-deprecated spelling is used here; exc_info=True keeps
        # the rejected transition's traceback in the log record.
        logger.warning("Failed to transition '%s' to %s state.", self,
                       new_state, exc_info=True)
    return moved
|
||||
|
||||
@lock_utils.locked
def transition(self, new_state):
    """Move this request into ``new_state``.

    Returns True when the state was changed and False when the request
    is already in ``new_state`` (the call is ignored). Raises an
    ``InvalidState`` exception, without changing anything, when going
    from the current state to ``new_state`` is not one of the allowed
    transitions.
    """
    current = self._state
    if current == new_state:
        # Already in the requested state; nothing to do.
        return False
    if (current, new_state) not in _ALLOWED_TRANSITIONS:
        raise excp.InvalidState("Request transition from %s to %s is"
                                " not allowed" % (current, new_state))
    if new_state in _STOP_TIMER_STATES:
        # Entering one of these states stops the request's stop watch.
        self._watch.stop()
    self._state = new_state
    LOG.debug("Transitioned '%s' from %s state to %s state", self,
              current, new_state)
    return True
|
||||
|
||||
@classmethod
|
||||
def validate(cls, data):
|
||||
try:
|
||||
@@ -292,6 +357,9 @@ class Response(Message):
|
||||
{
|
||||
"$ref": "#/definitions/completion",
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/empty",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
@@ -311,6 +379,12 @@ class Response(Message):
|
||||
"required": ["progress", 'event_data'],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
# Used when sending *only* request state changes (and no data is
|
||||
# expected).
|
||||
"empty": {
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"completion": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -22,6 +22,7 @@ import mock
|
||||
|
||||
from taskflow.engines.worker_based import executor
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.openstack.common import timeutils
|
||||
from taskflow import test
|
||||
from taskflow.tests import utils
|
||||
from taskflow.utils import misc
|
||||
@@ -95,8 +96,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
[mock.call.set_running()])
|
||||
expected_calls = [
|
||||
mock.call.transition_log_error(pr.RUNNING, logger=mock.ANY),
|
||||
]
|
||||
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
|
||||
|
||||
def test_on_message_response_state_progress(self):
|
||||
response = pr.Response(pr.PROGRESS, progress=1.0)
|
||||
@@ -116,9 +119,11 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(len(ex._requests_cache), 0)
|
||||
self.assertEqual(self.request_inst_mock.mock_calls, [
|
||||
expected_calls = [
|
||||
mock.call.transition_log_error(pr.FAILURE, logger=mock.ANY),
|
||||
mock.call.set_result(result=utils.FailureMatcher(failure))
|
||||
])
|
||||
]
|
||||
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
|
||||
|
||||
def test_on_message_response_state_success(self):
|
||||
response = pr.Response(pr.SUCCESS, result=self.task_result,
|
||||
@@ -127,9 +132,11 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
ex._process_response(response.to_dict(), self.message_mock)
|
||||
|
||||
self.assertEqual(self.request_inst_mock.mock_calls,
|
||||
[mock.call.set_result(result=self.task_result,
|
||||
event='executed')])
|
||||
expected_calls = [
|
||||
mock.call.transition_log_error(pr.SUCCESS, logger=mock.ANY),
|
||||
mock.call.set_result(result=self.task_result, event='executed')
|
||||
]
|
||||
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
|
||||
|
||||
def test_on_message_response_unknown_state(self):
|
||||
response = pr.Response(state='<unknown>')
|
||||
@@ -166,7 +173,13 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.assertEqual(len(ex._requests_cache), 1)
|
||||
|
||||
def test_on_wait_task_expired(self):
|
||||
now = timeutils.utcnow()
|
||||
self.request_inst_mock.expired = True
|
||||
self.request_inst_mock.created_on = now
|
||||
timeutils.set_time_override(now)
|
||||
self.addCleanup(timeutils.clear_time_override)
|
||||
timeutils.advance_time_seconds(120)
|
||||
|
||||
ex = self.executor()
|
||||
ex._requests_cache[self.task_uuid] = self.request_inst_mock
|
||||
|
||||
@@ -199,13 +212,14 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
expected_calls = [
|
||||
mock.call.Request(self.task, self.task_uuid, 'execute',
|
||||
self.task_args, None, self.timeout),
|
||||
mock.call.request.set_pending(),
|
||||
mock.call.request.transition_log_error(pr.PENDING,
|
||||
logger=mock.ANY),
|
||||
mock.call.proxy.publish(msg=self.request_inst_mock,
|
||||
routing_key=self.executor_topic,
|
||||
reply_to=self.executor_uuid,
|
||||
correlation_id=self.task_uuid)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
self.assertEqual(expected_calls, self.master_mock.mock_calls)
|
||||
|
||||
def test_revert_task(self):
|
||||
self.message_mock.properties['type'] = pr.NOTIFY
|
||||
@@ -220,13 +234,14 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.task_args, None, self.timeout,
|
||||
failures=self.task_failures,
|
||||
result=self.task_result),
|
||||
mock.call.request.set_pending(),
|
||||
mock.call.request.transition_log_error(pr.PENDING,
|
||||
logger=mock.ANY),
|
||||
mock.call.proxy.publish(msg=self.request_inst_mock,
|
||||
routing_key=self.executor_topic,
|
||||
reply_to=self.executor_uuid,
|
||||
correlation_id=self.task_uuid)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
self.assertEqual(expected_calls, self.master_mock.mock_calls)
|
||||
|
||||
def test_execute_task_topic_not_found(self):
|
||||
workers_info = {self.executor_topic: ['<unknown>']}
|
||||
@@ -250,14 +265,17 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
expected_calls = [
|
||||
mock.call.Request(self.task, self.task_uuid, 'execute',
|
||||
self.task_args, None, self.timeout),
|
||||
mock.call.request.set_pending(),
|
||||
mock.call.request.transition_log_error(pr.PENDING,
|
||||
logger=mock.ANY),
|
||||
mock.call.proxy.publish(msg=self.request_inst_mock,
|
||||
routing_key=self.executor_topic,
|
||||
reply_to=self.executor_uuid,
|
||||
correlation_id=self.task_uuid),
|
||||
mock.call.request.transition_log_error(pr.FAILURE,
|
||||
logger=mock.ANY),
|
||||
mock.call.request.set_result(mock.ANY)
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
self.assertEqual(expected_calls, self.master_mock.mock_calls)
|
||||
|
||||
def test_wait_for_any(self):
|
||||
fs = [futures.Future(), futures.Future()]
|
||||
|
||||
@@ -115,6 +115,18 @@ class TestProtocol(test.TestCase):
|
||||
to_dict.update(kwargs)
|
||||
return to_dict
|
||||
|
||||
def test_request_transitions(self):
|
||||
request = self.request()
|
||||
self.assertEqual(pr.WAITING, request.state)
|
||||
self.assertIn(request.state, pr.WAITING_STATES)
|
||||
self.assertRaises(excp.InvalidState, request.transition, pr.SUCCESS)
|
||||
self.assertFalse(request.transition(pr.WAITING))
|
||||
self.assertTrue(request.transition(pr.PENDING))
|
||||
self.assertTrue(request.transition(pr.RUNNING))
|
||||
self.assertTrue(request.transition(pr.SUCCESS))
|
||||
for s in (pr.PENDING, pr.WAITING):
|
||||
self.assertRaises(excp.InvalidState, request.transition, s)
|
||||
|
||||
def test_creation(self):
|
||||
request = self.request()
|
||||
self.assertEqual(request.uuid, self.task_uuid)
|
||||
@@ -122,15 +134,6 @@ class TestProtocol(test.TestCase):
|
||||
self.assertIsInstance(request.result, futures.Future)
|
||||
self.assertFalse(request.result.done())
|
||||
|
||||
def test_str(self):
|
||||
request = self.request()
|
||||
self.assertEqual(str(request),
|
||||
"<REQUEST> %s" % self.request_to_dict())
|
||||
|
||||
def test_repr(self):
|
||||
expected = '%s:%s' % (self.task.name, self.task_action)
|
||||
self.assertEqual(repr(self.request()), expected)
|
||||
|
||||
def test_to_dict_default(self):
|
||||
self.assertEqual(self.request().to_dict(), self.request_to_dict())
|
||||
|
||||
@@ -156,19 +159,20 @@ class TestProtocol(test.TestCase):
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
|
||||
def test_pending_not_expired(self, mocked_wallclock):
|
||||
mocked_wallclock.side_effect = [1, self.timeout]
|
||||
mocked_wallclock.side_effect = [0, self.timeout - 1]
|
||||
self.assertFalse(self.request().expired)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
|
||||
def test_pending_expired(self, mocked_wallclock):
|
||||
mocked_wallclock.side_effect = [1, self.timeout + 2]
|
||||
mocked_wallclock.side_effect = [0, self.timeout + 2]
|
||||
self.assertTrue(self.request().expired)
|
||||
|
||||
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
|
||||
def test_running_not_expired(self, mocked_wallclock):
|
||||
mocked_wallclock.side_effect = [1, self.timeout + 2]
|
||||
mocked_wallclock.side_effect = [0, self.timeout + 2]
|
||||
request = self.request()
|
||||
request.set_running()
|
||||
request.transition(pr.PENDING)
|
||||
request.transition(pr.RUNNING)
|
||||
self.assertFalse(request.expired)
|
||||
|
||||
def test_set_result(self):
|
||||
|
||||
Reference in New Issue
Block a user