diff --git a/taskflow/engines/worker_based/executor.py b/taskflow/engines/worker_based/executor.py index f5afc72e3..3b5a0355f 100644 --- a/taskflow/engines/worker_based/executor.py +++ b/taskflow/engines/worker_based/executor.py @@ -23,6 +23,7 @@ from taskflow.engines.worker_based import cache from taskflow.engines.worker_based import protocol as pr from taskflow.engines.worker_based import proxy from taskflow import exceptions as exc +from taskflow.openstack.common import timeutils from taskflow.types import timing as tt from taskflow.utils import async_utils from taskflow.utils import misc @@ -109,8 +110,8 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): # publish waiting requests for request in self._requests_cache.get_waiting_requests(tasks): - request.set_pending() - self._publish_request(request, topic) + if request.transition_log_error(pr.PENDING, logger=LOG): + self._publish_request(request, topic) def _process_response(self, response, message): """Process response from remote side.""" @@ -120,20 +121,23 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): except KeyError: LOG.warning("The 'correlation_id' message property is missing.") else: - LOG.debug("Task uuid: '%s'", task_uuid) request = self._requests_cache.get(task_uuid) if request is not None: response = pr.Response.from_dict(response) if response.state == pr.RUNNING: - request.set_running() + request.transition_log_error(pr.RUNNING, logger=LOG) elif response.state == pr.PROGRESS: request.on_progress(**response.data) elif response.state in (pr.FAILURE, pr.SUCCESS): - # NOTE(imelnikov): request should not be in cache when - # another thread can see its result and schedule another - # request with same uuid; so we remove it, then set result - del self._requests_cache[request.uuid] - request.set_result(**response.data) + moved = request.transition_log_error(response.state, + logger=LOG) + if moved: + # NOTE(imelnikov): request should not be in the + # cache when another thread can see its result and + # schedule another request with the same uuid; so + # we remove it, then set the result... + del self._requests_cache[request.uuid] + request.set_result(**response.data) else: LOG.warning("Unexpected response status: '%s'", response.state) @@ -147,10 +151,21 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): When request has expired it is removed from the requests cache and the `RequestTimeout` exception is set as a request result. """ - LOG.debug("Request '%r' has expired.", request) - LOG.debug("The '%r' request has expired.", request) - request.set_result(misc.Failure.from_exception( - exc.RequestTimeout("The '%r' request has expired" % request))) + if request.transition_log_error(pr.FAILURE, logger=LOG): + # Raise an exception (and then catch it) so we get a nice + # traceback that the request will get instead of it getting + # just an exception with no traceback... + try: + request_age = timeutils.delta_seconds(request.created_on, + timeutils.utcnow()) + raise exc.RequestTimeout( + "Request '%s' has expired after waiting for %0.2f" + " seconds for it to transition out of (%s) states" + % (request, request_age, ", ".join(pr.WAITING_STATES))) + except exc.RequestTimeout: + with misc.capture_failure() as fail: + LOG.debug(fail.exception_str) + request.set_result(fail) def _on_wait(self): """This function is called cyclically between draining events.""" @@ -169,9 +184,9 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): # before putting it into the requests cache to prevent the notify # processing thread get list of waiting requests and publish it # before it is published here, so it wouldn't be published twice. - request.set_pending() - self._requests_cache[request.uuid] = request - self._publish_request(request, topic) + if request.transition_log_error(pr.PENDING, logger=LOG): + self._requests_cache[request.uuid] = request + self._publish_request(request, topic) else: self._requests_cache[request.uuid] = request @@ -187,8 +202,9 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): except Exception: with misc.capture_failure() as failure: LOG.exception("Failed to submit the '%s' request.", request) - del self._requests_cache[request.uuid] - request.set_result(failure) + if request.transition_log_error(pr.FAILURE, logger=LOG): + del self._requests_cache[request.uuid] + request.set_result(failure) def _notify_topics(self): """Cyclically called to publish notify message to each topic.""" diff --git a/taskflow/engines/worker_based/protocol.py b/taskflow/engines/worker_based/protocol.py index ea9942727..334c1d93f 100644 --- a/taskflow/engines/worker_based/protocol.py +++ b/taskflow/engines/worker_based/protocol.py @@ -15,6 +15,8 @@ # under the License. import abc +import logging +import threading from concurrent import futures import jsonschema @@ -23,7 +25,9 @@ import six from taskflow.engines.action_engine import executor from taskflow import exceptions as excp +from taskflow.openstack.common import timeutils from taskflow.types import timing as tt +from taskflow.utils import lock_utils from taskflow.utils import misc from taskflow.utils import reflection @@ -36,7 +40,34 @@ SUCCESS = 'SUCCESS' FAILURE = 'FAILURE' PROGRESS = 'PROGRESS' +# During these states the expiry is active (once out of these states the expiry +# no longer matters, since we have no way of knowing how long a task will run +# for). +WAITING_STATES = (WAITING, PENDING) + _ALL_STATES = (WAITING, PENDING, RUNNING, SUCCESS, FAILURE, PROGRESS) +_STOP_TIMER_STATES = (RUNNING, SUCCESS, FAILURE) + +# Transitions that a request state can go through. +_ALLOWED_TRANSITIONS = ( + # Used when a executor starts to publish a request to a selected worker. + (WAITING, PENDING), + # When a request expires (isn't able to be processed by any worker). + (WAITING, FAILURE), + # Worker has started executing a request. + (PENDING, RUNNING), + # Worker failed to construct/process a request to run (either the worker + # did not transition to RUNNING in the given timeout or the worker itself + # had some type of failure before RUNNING started). + # + # Also used by the executor if the request was attempted to be published + # but that did publishing process did not work out. + (PENDING, FAILURE), + # Execution failed due to some type of remote failure. + (RUNNING, FAILURE), + # Execution succeeded & has completed. + (RUNNING, SUCCESS), +) # Remote task actions. EXECUTE = 'execute' @@ -73,6 +104,8 @@ _SCHEMA_TYPES = { 'array': (list, tuple), } +LOG = logging.getLogger(__name__) + @six.add_metaclass(abc.ABCMeta) class Message(object): @@ -143,8 +176,10 @@ class Request(Message): """Represents request with execution results. Every request is created in the WAITING state and is expired within the - given timeout. + given timeout if it does not transition out of the (WAITING, PENDING) + states. """ + TYPE = REQUEST _SCHEMA = { "type": "object", @@ -196,11 +231,10 @@ class Request(Message): self._kwargs = kwargs self._watch = tt.StopWatch(duration=timeout).start() self._state = WAITING + self._lock = threading.Lock() + self._created_on = timeutils.utcnow() self.result = futures.Future() - def __repr__(self): - return "%s:%s" % (self._task_cls, self._action) - @property def uuid(self): return self._uuid @@ -213,6 +247,10 @@ class Request(Message): def state(self): return self._state + @property + def created_on(self): + return self._created_on + @property def expired(self): """Check if request has expired. @@ -224,7 +262,7 @@ class Request(Message): state for more then the given timeout (it is not considered to be expired in any other state). """ - if self._state in (WAITING, PENDING): + if self._state in WAITING_STATES: return self._watch.expired() return False @@ -254,16 +292,43 @@ class Request(Message): def set_result(self, result): self.result.set_result((self._task, self._event, result)) - def set_pending(self): - self._state = PENDING - - def set_running(self): - self._state = RUNNING - self._watch.stop() - def on_progress(self, event_data, progress): self._progress_callback(self._task, event_data, progress) + def transition_log_error(self, new_state, logger=None): + if logger is None: + logger = LOG + moved = False + try: + moved = self.transition(new_state) + except excp.InvalidState: + logger.warn("Failed to transition '%s' to %s state.", self, + new_state, exc_info=True) + return moved + + @lock_utils.locked + def transition(self, new_state): + """Transitions the request to a new state. + + If transition was performed, it returns True. If transition + should was ignored, it returns False. If transition is not + valid (and will not be performed), it raises an InvalidState + exception. + """ + old_state = self._state + if old_state == new_state: + return False + pair = (old_state, new_state) + if pair not in _ALLOWED_TRANSITIONS: + raise excp.InvalidState("Request transition from %s to %s is" + " not allowed" % pair) + if new_state in _STOP_TIMER_STATES: + self._watch.stop() + self._state = new_state + LOG.debug("Transitioned '%s' from %s state to %s state", self, + old_state, new_state) + return True + @classmethod def validate(cls, data): try: @@ -292,6 +357,9 @@ class Response(Message): { "$ref": "#/definitions/completion", }, + { + "$ref": "#/definitions/empty", + }, ], }, }, @@ -311,6 +379,12 @@ class Response(Message): "required": ["progress", 'event_data'], "additionalProperties": False, }, + # Used when sending *only* request state changes (and no data is + # expected). + "empty": { + "type": "object", + "additionalProperties": False, + }, "completion": { "type": "object", "properties": { diff --git a/taskflow/tests/unit/worker_based/test_executor.py b/taskflow/tests/unit/worker_based/test_executor.py index f837394ca..aa236fcfc 100644 --- a/taskflow/tests/unit/worker_based/test_executor.py +++ b/taskflow/tests/unit/worker_based/test_executor.py @@ -22,6 +22,7 @@ import mock from taskflow.engines.worker_based import executor from taskflow.engines.worker_based import protocol as pr +from taskflow.openstack.common import timeutils from taskflow import test from taskflow.tests import utils from taskflow.utils import misc @@ -95,8 +96,10 @@ class TestWorkerTaskExecutor(test.MockTestCase): ex._requests_cache[self.task_uuid] = self.request_inst_mock ex._process_response(response.to_dict(), self.message_mock) - self.assertEqual(self.request_inst_mock.mock_calls, - [mock.call.set_running()]) + expected_calls = [ + mock.call.transition_log_error(pr.RUNNING, logger=mock.ANY), + ] + self.assertEqual(expected_calls, self.request_inst_mock.mock_calls) def test_on_message_response_state_progress(self): response = pr.Response(pr.PROGRESS, progress=1.0) @@ -116,9 +119,11 @@ class TestWorkerTaskExecutor(test.MockTestCase): ex._process_response(response.to_dict(), self.message_mock) self.assertEqual(len(ex._requests_cache), 0) - self.assertEqual(self.request_inst_mock.mock_calls, [ + expected_calls = [ + mock.call.transition_log_error(pr.FAILURE, logger=mock.ANY), mock.call.set_result(result=utils.FailureMatcher(failure)) - ]) + ] + self.assertEqual(expected_calls, self.request_inst_mock.mock_calls) def test_on_message_response_state_success(self): response = pr.Response(pr.SUCCESS, result=self.task_result, @@ -127,9 +132,11 @@ class TestWorkerTaskExecutor(test.MockTestCase): ex._requests_cache[self.task_uuid] = self.request_inst_mock ex._process_response(response.to_dict(), self.message_mock) - self.assertEqual(self.request_inst_mock.mock_calls, - [mock.call.set_result(result=self.task_result, - event='executed')]) + expected_calls = [ + mock.call.transition_log_error(pr.SUCCESS, logger=mock.ANY), + mock.call.set_result(result=self.task_result, event='executed') + ] + self.assertEqual(expected_calls, self.request_inst_mock.mock_calls) def test_on_message_response_unknown_state(self): response = pr.Response(state='') @@ -166,7 +173,13 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.assertEqual(len(ex._requests_cache), 1) def test_on_wait_task_expired(self): + now = timeutils.utcnow() self.request_inst_mock.expired = True + self.request_inst_mock.created_on = now + timeutils.set_time_override(now) + self.addCleanup(timeutils.clear_time_override) + timeutils.advance_time_seconds(120) + ex = self.executor() ex._requests_cache[self.task_uuid] = self.request_inst_mock @@ -199,13 +212,14 @@ class TestWorkerTaskExecutor(test.MockTestCase): expected_calls = [ mock.call.Request(self.task, self.task_uuid, 'execute', self.task_args, None, self.timeout), - mock.call.request.set_pending(), + mock.call.request.transition_log_error(pr.PENDING, + logger=mock.ANY), mock.call.proxy.publish(msg=self.request_inst_mock, routing_key=self.executor_topic, reply_to=self.executor_uuid, correlation_id=self.task_uuid) ] - self.assertEqual(self.master_mock.mock_calls, expected_calls) + self.assertEqual(expected_calls, self.master_mock.mock_calls) def test_revert_task(self): self.message_mock.properties['type'] = pr.NOTIFY @@ -220,13 +234,14 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.task_args, None, self.timeout, failures=self.task_failures, result=self.task_result), - mock.call.request.set_pending(), + mock.call.request.transition_log_error(pr.PENDING, + logger=mock.ANY), mock.call.proxy.publish(msg=self.request_inst_mock, routing_key=self.executor_topic, reply_to=self.executor_uuid, correlation_id=self.task_uuid) ] - self.assertEqual(self.master_mock.mock_calls, expected_calls) + self.assertEqual(expected_calls, self.master_mock.mock_calls) def test_execute_task_topic_not_found(self): workers_info = {self.executor_topic: ['']} @@ -250,14 +265,17 @@ class TestWorkerTaskExecutor(test.MockTestCase): expected_calls = [ mock.call.Request(self.task, self.task_uuid, 'execute', self.task_args, None, self.timeout), - mock.call.request.set_pending(), + mock.call.request.transition_log_error(pr.PENDING, + logger=mock.ANY), mock.call.proxy.publish(msg=self.request_inst_mock, routing_key=self.executor_topic, reply_to=self.executor_uuid, correlation_id=self.task_uuid), + mock.call.request.transition_log_error(pr.FAILURE, + logger=mock.ANY), mock.call.request.set_result(mock.ANY) ] - self.assertEqual(self.master_mock.mock_calls, expected_calls) + self.assertEqual(expected_calls, self.master_mock.mock_calls) def test_wait_for_any(self): fs = [futures.Future(), futures.Future()] diff --git a/taskflow/tests/unit/worker_based/test_protocol.py b/taskflow/tests/unit/worker_based/test_protocol.py index f94d03d41..7d51da31c 100644 --- a/taskflow/tests/unit/worker_based/test_protocol.py +++ b/taskflow/tests/unit/worker_based/test_protocol.py @@ -115,6 +115,18 @@ class TestProtocol(test.TestCase): to_dict.update(kwargs) return to_dict + def test_request_transitions(self): + request = self.request() + self.assertEqual(pr.WAITING, request.state) + self.assertIn(request.state, pr.WAITING_STATES) + self.assertRaises(excp.InvalidState, request.transition, pr.SUCCESS) + self.assertFalse(request.transition(pr.WAITING)) + self.assertTrue(request.transition(pr.PENDING)) + self.assertTrue(request.transition(pr.RUNNING)) + self.assertTrue(request.transition(pr.SUCCESS)) + for s in (pr.PENDING, pr.WAITING): + self.assertRaises(excp.InvalidState, request.transition, s) + def test_creation(self): request = self.request() self.assertEqual(request.uuid, self.task_uuid) @@ -122,15 +134,6 @@ class TestProtocol(test.TestCase): self.assertIsInstance(request.result, futures.Future) self.assertFalse(request.result.done()) - def test_str(self): - request = self.request() - self.assertEqual(str(request), - " %s" % self.request_to_dict()) - - def test_repr(self): - expected = '%s:%s' % (self.task.name, self.task_action) - self.assertEqual(repr(self.request()), expected) - def test_to_dict_default(self): self.assertEqual(self.request().to_dict(), self.request_to_dict()) @@ -156,19 +159,20 @@ class TestProtocol(test.TestCase): @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') def test_pending_not_expired(self, mocked_wallclock): - mocked_wallclock.side_effect = [1, self.timeout] + mocked_wallclock.side_effect = [0, self.timeout - 1] self.assertFalse(self.request().expired) @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') def test_pending_expired(self, mocked_wallclock): - mocked_wallclock.side_effect = [1, self.timeout + 2] + mocked_wallclock.side_effect = [0, self.timeout + 2] self.assertTrue(self.request().expired) @mock.patch('taskflow.engines.worker_based.protocol.misc.wallclock') def test_running_not_expired(self, mocked_wallclock): - mocked_wallclock.side_effect = [1, self.timeout + 2] + mocked_wallclock.side_effect = [0, self.timeout + 2] request = self.request() - request.set_running() + request.transition(pr.PENDING) + request.transition(pr.RUNNING) self.assertFalse(request.expired) def test_set_result(self):