diff --git a/taskflow/engines/worker_based/executor.py b/taskflow/engines/worker_based/executor.py index 69f659a6..033ffff7 100644 --- a/taskflow/engines/worker_based/executor.py +++ b/taskflow/engines/worker_based/executor.py @@ -85,6 +85,7 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): # publish waiting requests for request in self._requests_cache.get_waiting_requests(tasks): + request.set_pending() self._publish_request(request, topic) def _process_response(self, response, message): @@ -133,12 +134,19 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): """Submit task request to workers.""" request = pr.Request(task, task_uuid, action, arguments, progress_callback, timeout, **kwargs) - self._requests_cache.set(request.uuid, request) # Get task's topic and publish request if topic was found. topic = self._workers_cache.get_topic_by_task(request.task_cls) if topic is not None: + # NOTE(skudriashev): Make sure request is set to the PENDING state + # before putting it into the requests cache to prevent the notify + # processing thread get list of waiting requests and publish it + # before it is published here, so it wouldn't be published twice. + request.set_pending() + self._requests_cache.set(request.uuid, request) self._publish_request(request, topic) + else: + self._requests_cache.set(request.uuid, request) return request.result @@ -156,8 +164,6 @@ class WorkerTaskExecutor(executor.TaskExecutorBase): request) self._requests_cache.delete(request.uuid) request.set_result(failure) - else: - request.set_pending() def _notify_topics(self): """Cyclically publish notify message to each topic.""" diff --git a/taskflow/tests/unit/worker_based/test_executor.py b/taskflow/tests/unit/worker_based/test_executor.py index b763f673..85ccf753 100644 --- a/taskflow/tests/unit/worker_based/test_executor.py +++ b/taskflow/tests/unit/worker_based/test_executor.py @@ -226,11 +226,11 @@ class TestWorkerTaskExecutor(test.MockTestCase): expected_calls = [ mock.call.Request(self.task, self.task_uuid, 'execute', self.task_args, None, self.timeout), + mock.call.request.set_pending(), mock.call.proxy.publish(msg=self.request_inst_mock, routing_key=self.executor_topic, reply_to=self.executor_uuid, - correlation_id=self.task_uuid), - mock.call.request.set_pending() + correlation_id=self.task_uuid) ] self.assertEqual(self.master_mock.mock_calls, expected_calls) @@ -247,11 +247,11 @@ class TestWorkerTaskExecutor(test.MockTestCase): self.task_args, None, self.timeout, failures=self.task_failures, result=self.task_result), + mock.call.request.set_pending(), mock.call.proxy.publish(msg=self.request_inst_mock, routing_key=self.executor_topic, reply_to=self.executor_uuid, - correlation_id=self.task_uuid), - mock.call.request.set_pending() + correlation_id=self.task_uuid) ] self.assertEqual(self.master_mock.mock_calls, expected_calls) @@ -277,6 +277,7 @@ class TestWorkerTaskExecutor(test.MockTestCase): expected_calls = [ mock.call.Request(self.task, self.task_uuid, 'execute', self.task_args, None, self.timeout), + mock.call.request.set_pending(), mock.call.proxy.publish(msg=self.request_inst_mock, routing_key=self.executor_topic, reply_to=self.executor_uuid,