Use explicit WBE request state transitions

Instead of having an implicit state machine for a requests
lifecycle move toward an explicit state model and transition
set that is validated and transitioned in a more easy to
understand/reason about manner.

This also fixes a bug that was found due to a response validation
not taking into account a transition that was found due to this
stricter transition checking.

Includes a few tiny related/affected commits:

* Remove testing of request repr() and str() as these types of
  tests are not useful and we removed the repr() version of the
  request message as the base classes is good enough.
* Raise and capture a better exception and save its associated
  failure object when a request has expired (this gives expired
  requests better failure objects and associated details).

Fixes bug 1356658

Partially fixes bug 1357117

Change-Id: Ie1386cca13a2da7265e22447b4c111a0a0074201
This commit is contained in:
Joshua Harlow
2014-08-13 16:37:56 -07:00
committed by Joshua Harlow
parent 7c3332e49b
commit e720bdb8af
4 changed files with 168 additions and 56 deletions

View File

@@ -23,6 +23,7 @@ from taskflow.engines.worker_based import cache
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import proxy
from taskflow import exceptions as exc
from taskflow.openstack.common import timeutils
from taskflow.types import timing as tt
from taskflow.utils import async_utils
from taskflow.utils import misc
@@ -109,8 +110,8 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
# publish waiting requests
for request in self._requests_cache.get_waiting_requests(tasks):
request.set_pending()
self._publish_request(request, topic)
if request.transition_log_error(pr.PENDING, logger=LOG):
self._publish_request(request, topic)
def _process_response(self, response, message):
"""Process response from remote side."""
@@ -120,20 +121,23 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
except KeyError:
LOG.warning("The 'correlation_id' message property is missing.")
else:
LOG.debug("Task uuid: '%s'", task_uuid)
request = self._requests_cache.get(task_uuid)
if request is not None:
response = pr.Response.from_dict(response)
if response.state == pr.RUNNING:
request.set_running()
request.transition_log_error(pr.RUNNING, logger=LOG)
elif response.state == pr.PROGRESS:
request.on_progress(**response.data)
elif response.state in (pr.FAILURE, pr.SUCCESS):
# NOTE(imelnikov): request should not be in cache when
# another thread can see its result and schedule another
# request with same uuid; so we remove it, then set result
del self._requests_cache[request.uuid]
request.set_result(**response.data)
moved = request.transition_log_error(response.state,
logger=LOG)
if moved:
# NOTE(imelnikov): request should not be in the
# cache when another thread can see its result and
# schedule another request with the same uuid; so
# we remove it, then set the result...
del self._requests_cache[request.uuid]
request.set_result(**response.data)
else:
LOG.warning("Unexpected response status: '%s'",
response.state)
@@ -147,10 +151,21 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
When request has expired it is removed from the requests cache and
the `RequestTimeout` exception is set as a request result.
"""
LOG.debug("Request '%r' has expired.", request)
LOG.debug("The '%r' request has expired.", request)
request.set_result(misc.Failure.from_exception(
exc.RequestTimeout("The '%r' request has expired" % request)))
if request.transition_log_error(pr.FAILURE, logger=LOG):
# Raise an exception (and then catch it) so we get a nice
# traceback that the request will get instead of it getting
# just an exception with no traceback...
try:
request_age = timeutils.delta_seconds(request.created_on,
timeutils.utcnow())
raise exc.RequestTimeout(
"Request '%s' has expired after waiting for %0.2f"
" seconds for it to transition out of (%s) states"
% (request, request_age, ", ".join(pr.WAITING_STATES)))
except exc.RequestTimeout:
with misc.capture_failure() as fail:
LOG.debug(fail.exception_str)
request.set_result(fail)
def _on_wait(self):
"""This function is called cyclically between draining events."""
@@ -169,9 +184,9 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
# before putting it into the requests cache to prevent the notify
# processing thread get list of waiting requests and publish it
# before it is published here, so it wouldn't be published twice.
request.set_pending()
self._requests_cache[request.uuid] = request
self._publish_request(request, topic)
if request.transition_log_error(pr.PENDING, logger=LOG):
self._requests_cache[request.uuid] = request
self._publish_request(request, topic)
else:
self._requests_cache[request.uuid] = request
@@ -187,8 +202,9 @@ class WorkerTaskExecutor(executor.TaskExecutorBase):
except Exception:
with misc.capture_failure() as failure:
LOG.exception("Failed to submit the '%s' request.", request)
del self._requests_cache[request.uuid]
request.set_result(failure)
if request.transition_log_error(pr.FAILURE, logger=LOG):
del self._requests_cache[request.uuid]
request.set_result(failure)
def _notify_topics(self):
"""Cyclically called to publish notify message to each topic."""

View File

@@ -15,6 +15,8 @@
# under the License.
import abc
import logging
import threading
from concurrent import futures
import jsonschema
@@ -23,7 +25,9 @@ import six
from taskflow.engines.action_engine import executor
from taskflow import exceptions as excp
from taskflow.openstack.common import timeutils
from taskflow.types import timing as tt
from taskflow.utils import lock_utils
from taskflow.utils import misc
from taskflow.utils import reflection
@@ -36,7 +40,34 @@ SUCCESS = 'SUCCESS'
FAILURE = 'FAILURE'
PROGRESS = 'PROGRESS'
# During these states the expiry is active (once out of these states the expiry
# no longer matters, since we have no way of knowing how long a task will run
# for).
WAITING_STATES = (WAITING, PENDING)
_ALL_STATES = (WAITING, PENDING, RUNNING, SUCCESS, FAILURE, PROGRESS)
_STOP_TIMER_STATES = (RUNNING, SUCCESS, FAILURE)
# Transitions that a request state can go through.
_ALLOWED_TRANSITIONS = (
# Used when a executor starts to publish a request to a selected worker.
(WAITING, PENDING),
# When a request expires (isn't able to be processed by any worker).
(WAITING, FAILURE),
# Worker has started executing a request.
(PENDING, RUNNING),
# Worker failed to construct/process a request to run (either the worker
# did not transition to RUNNING in the given timeout or the worker itself
# had some type of failure before RUNNING started).
#
# Also used by the executor if the request was attempted to be published
# but that did publishing process did not work out.
(PENDING, FAILURE),
# Execution failed due to some type of remote failure.
(RUNNING, FAILURE),
# Execution succeeded & has completed.
(RUNNING, SUCCESS),
)
# Remote task actions.
EXECUTE = 'execute'
@@ -73,6 +104,8 @@ _SCHEMA_TYPES = {
'array': (list, tuple),
}
LOG = logging.getLogger(__name__)
@six.add_metaclass(abc.ABCMeta)
class Message(object):
@@ -143,8 +176,10 @@ class Request(Message):
"""Represents request with execution results.
Every request is created in the WAITING state and is expired within the
given timeout.
given timeout if it does not transition out of the (WAITING, PENDING)
states.
"""
TYPE = REQUEST
_SCHEMA = {
"type": "object",
@@ -196,11 +231,10 @@ class Request(Message):
self._kwargs = kwargs
self._watch = tt.StopWatch(duration=timeout).start()
self._state = WAITING
self._lock = threading.Lock()
self._created_on = timeutils.utcnow()
self.result = futures.Future()
def __repr__(self):
return "%s:%s" % (self._task_cls, self._action)
@property
def uuid(self):
return self._uuid
@@ -213,6 +247,10 @@ class Request(Message):
def state(self):
return self._state
@property
def created_on(self):
return self._created_on
@property
def expired(self):
"""Check if request has expired.
@@ -224,7 +262,7 @@ class Request(Message):
state for more then the given timeout (it is not considered to be
expired in any other state).
"""
if self._state in (WAITING, PENDING):
if self._state in WAITING_STATES:
return self._watch.expired()
return False
@@ -254,16 +292,43 @@ class Request(Message):
def set_result(self, result):
self.result.set_result((self._task, self._event, result))
def set_pending(self):
self._state = PENDING
def set_running(self):
self._state = RUNNING
self._watch.stop()
def on_progress(self, event_data, progress):
self._progress_callback(self._task, event_data, progress)
def transition_log_error(self, new_state, logger=None):
if logger is None:
logger = LOG
moved = False
try:
moved = self.transition(new_state)
except excp.InvalidState:
logger.warn("Failed to transition '%s' to %s state.", self,
new_state, exc_info=True)
return moved
@lock_utils.locked
def transition(self, new_state):
"""Transitions the request to a new state.
If transition was performed, it returns True. If transition
should was ignored, it returns False. If transition is not
valid (and will not be performed), it raises an InvalidState
exception.
"""
old_state = self._state
if old_state == new_state:
return False
pair = (old_state, new_state)
if pair not in _ALLOWED_TRANSITIONS:
raise excp.InvalidState("Request transition from %s to %s is"
" not allowed" % pair)
if new_state in _STOP_TIMER_STATES:
self._watch.stop()
self._state = new_state
LOG.debug("Transitioned '%s' from %s state to %s state", self,
old_state, new_state)
return True
@classmethod
def validate(cls, data):
try:
@@ -292,6 +357,9 @@ class Response(Message):
{
"$ref": "#/definitions/completion",
},
{
"$ref": "#/definitions/empty",
},
],
},
},
@@ -311,6 +379,12 @@ class Response(Message):
"required": ["progress", 'event_data'],
"additionalProperties": False,
},
# Used when sending *only* request state changes (and no data is
# expected).
"empty": {
"type": "object",
"additionalProperties": False,
},
"completion": {
"type": "object",
"properties": {

View File

@@ -22,6 +22,7 @@ import mock
from taskflow.engines.worker_based import executor
from taskflow.engines.worker_based import protocol as pr
from taskflow.openstack.common import timeutils
from taskflow import test
from taskflow.tests import utils
from taskflow.utils import misc
@@ -95,8 +96,10 @@ class TestWorkerTaskExecutor(test.MockTestCase):
ex._requests_cache[self.task_uuid] = self.request_inst_mock
ex._process_response(response.to_dict(), self.message_mock)
self.assertEqual(self.request_inst_mock.mock_calls,
[mock.call.set_running()])
expected_calls = [
mock.call.transition_log_error(pr.RUNNING, logger=mock.ANY),
]
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
def test_on_message_response_state_progress(self):
response = pr.Response(pr.PROGRESS, progress=1.0)
@@ -116,9 +119,11 @@ class TestWorkerTaskExecutor(test.MockTestCase):
ex._process_response(response.to_dict(), self.message_mock)
self.assertEqual(len(ex._requests_cache), 0)
self.assertEqual(self.request_inst_mock.mock_calls, [
expected_calls = [
mock.call.transition_log_error(pr.FAILURE, logger=mock.ANY),
mock.call.set_result(result=utils.FailureMatcher(failure))
])
]
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
def test_on_message_response_state_success(self):
response = pr.Response(pr.SUCCESS, result=self.task_result,
@@ -127,9 +132,11 @@ class TestWorkerTaskExecutor(test.MockTestCase):
ex._requests_cache[self.task_uuid] = self.request_inst_mock
ex._process_response(response.to_dict(), self.message_mock)
self.assertEqual(self.request_inst_mock.mock_calls,
[mock.call.set_result(result=self.task_result,
event='executed')])
expected_calls = [
mock.call.transition_log_error(pr.SUCCESS, logger=mock.ANY),
mock.call.set_result(result=self.task_result, event='executed')
]
self.assertEqual(expected_calls, self.request_inst_mock.mock_calls)
def test_on_message_response_unknown_state(self):
response = pr.Response(state='<unknown>')
@@ -166,7 +173,13 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.assertEqual(len(ex._requests_cache), 1)
def test_on_wait_task_expired(self):
now = timeutils.utcnow()
self.request_inst_mock.expired = True
self.request_inst_mock.created_on = now
timeutils.set_time_override(now)
self.addCleanup(timeutils.clear_time_override)
timeutils.advance_time_seconds(120)
ex = self.executor()
ex._requests_cache[self.task_uuid] = self.request_inst_mock
@@ -199,13 +212,14 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, None, self.timeout),
mock.call.request.set_pending(),
mock.call.request.transition_log_error(pr.PENDING,
logger=mock.ANY),
mock.call.proxy.publish(msg=self.request_inst_mock,
routing_key=self.executor_topic,
reply_to=self.executor_uuid,
correlation_id=self.task_uuid)
]
self.assertEqual(self.master_mock.mock_calls, expected_calls)
self.assertEqual(expected_calls, self.master_mock.mock_calls)
def test_revert_task(self):
self.message_mock.properties['type'] = pr.NOTIFY
@@ -220,13 +234,14 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.task_args, None, self.timeout,
failures=self.task_failures,
result=self.task_result),
mock.call.request.set_pending(),
mock.call.request.transition_log_error(pr.PENDING,
logger=mock.ANY),
mock.call.proxy.publish(msg=self.request_inst_mock,
routing_key=self.executor_topic,
reply_to=self.executor_uuid,
correlation_id=self.task_uuid)
]
self.assertEqual(self.master_mock.mock_calls, expected_calls)
self.assertEqual(expected_calls, self.master_mock.mock_calls)
def test_execute_task_topic_not_found(self):
workers_info = {self.executor_topic: ['<unknown>']}
@@ -250,14 +265,17 @@ class TestWorkerTaskExecutor(test.MockTestCase):
expected_calls = [
mock.call.Request(self.task, self.task_uuid, 'execute',
self.task_args, None, self.timeout),
mock.call.request.set_pending(),
mock.call.request.transition_log_error(pr.PENDING,
logger=mock.ANY),
mock.call.proxy.publish(msg=self.request_inst_mock,
routing_key=self.executor_topic,
reply_to=self.executor_uuid,
correlation_id=self.task_uuid),
mock.call.request.transition_log_error(pr.FAILURE,
logger=mock.ANY),
mock.call.request.set_result(mock.ANY)
]
self.assertEqual(self.master_mock.mock_calls, expected_calls)
self.assertEqual(expected_calls, self.master_mock.mock_calls)
def test_wait_for_any(self):
fs = [futures.Future(), futures.Future()]

View File

@@ -115,6 +115,18 @@ class TestProtocol(test.TestCase):
to_dict.update(kwargs)
return to_dict
def test_request_transitions(self):
request = self.request()
self.assertEqual(pr.WAITING, request.state)
self.assertIn(request.state, pr.WAITING_STATES)
self.assertRaises(excp.InvalidState, request.transition, pr.SUCCESS)
self.assertFalse(request.transition(pr.WAITING))
self.assertTrue(request.transition(pr.PENDING))
self.assertTrue(request.transition(pr.RUNNING))
self.assertTrue(request.transition(pr.SUCCESS))
for s in (pr.PENDING, pr.WAITING):
self.assertRaises(excp.InvalidState, request.transition, s)
def test_creation(self):
request = self.request()
self.assertEqual(request.uuid, self.task_uuid)
@@ -122,15 +134,6 @@ class TestProtocol(test.TestCase):
self.assertIsInstance(request.result, futures.Future)
self.assertFalse(request.result.done())
def test_str(self):
request = self.request()
self.assertEqual(str(request),
"<REQUEST> %s" % self.request_to_dict())
def test_repr(self):
expected = '%s:%s' % (self.task.name, self.task_action)
self.assertEqual(repr(self.request()), expected)
def test_to_dict_default(self):
self.assertEqual(self.request().to_dict(), self.request_to_dict())
@@ -156,19 +159,20 @@ class TestProtocol(test.TestCase):
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
def test_pending_not_expired(self, mocked_wallclock):
mocked_wallclock.side_effect = [1, self.timeout]
mocked_wallclock.side_effect = [0, self.timeout - 1]
self.assertFalse(self.request().expired)
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
def test_pending_expired(self, mocked_wallclock):
mocked_wallclock.side_effect = [1, self.timeout + 2]
mocked_wallclock.side_effect = [0, self.timeout + 2]
self.assertTrue(self.request().expired)
@mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock')
def test_running_not_expired(self, mocked_wallclock):
mocked_wallclock.side_effect = [1, self.timeout + 2]
mocked_wallclock.side_effect = [0, self.timeout + 2]
request = self.request()
request.set_running()
request.transition(pr.PENDING)
request.transition(pr.RUNNING)
self.assertFalse(request.expired)
def test_set_result(self):