diff --git a/taskflow/engines/worker_based/server.py b/taskflow/engines/worker_based/server.py index a72ed0c5..ff92603b 100644 --- a/taskflow/engines/worker_based/server.py +++ b/taskflow/engines/worker_based/server.py @@ -203,4 +203,3 @@ class Server(object): def stop(self): """Stop processing incoming requests.""" self._proxy.stop() - self._executor.shutdown() diff --git a/taskflow/engines/worker_based/worker.py b/taskflow/engines/worker_based/worker.py index 1f133e12..0b3d50dd 100644 --- a/taskflow/engines/worker_based/worker.py +++ b/taskflow/engines/worker_based/worker.py @@ -66,6 +66,7 @@ class Worker(object): def __init__(self, exchange, topic, tasks, executor=None, **kwargs): self._topic = topic self._executor = executor + self._owns_executor = False self._threads_count = -1 if self._executor is None: if 'threads_count' in kwargs: @@ -75,6 +76,7 @@ class Worker(object): else: self._threads_count = tu.get_optimal_thread_count() self._executor = futures.ThreadPoolExecutor(self._threads_count) + self._owns_executor = True self._endpoints = self._derive_endpoints(tasks) self._server = server.Server(topic, exchange, self._executor, self._endpoints, **kwargs) @@ -105,3 +107,5 @@ class Worker(object): def stop(self): """Stop worker.""" self._server.stop() + if self._owns_executor: + self._executor.shutdown() diff --git a/taskflow/tests/unit/worker_based/test_server.py b/taskflow/tests/unit/worker_based/test_server.py index 2ad94eaa..93234c31 100644 --- a/taskflow/tests/unit/worker_based/test_server.py +++ b/taskflow/tests/unit/worker_based/test_server.py @@ -385,7 +385,6 @@ class TestServer(test.MockTestCase): # check calls master_mock_calls = [ - mock.call.proxy.stop(), - mock.call.executor.shutdown() + mock.call.proxy.stop() ] self.assertEqual(self.master_mock.mock_calls, master_mock_calls) diff --git a/taskflow/tests/unit/worker_based/test_worker.py b/taskflow/tests/unit/worker_based/test_worker.py index 4e9d321b..c66255f9 100644 --- a/taskflow/tests/unit/worker_based/test_worker.py +++ b/taskflow/tests/unit/worker_based/test_worker.py @@ -132,7 +132,8 @@ class TestWorker(test.MockTestCase): self.worker(reset_master_mock=True).stop() master_mock_calls = [ - mock.call.server.stop() + mock.call.server.stop(), + mock.call.executor.shutdown() ] self.assertEqual(self.master_mock.mock_calls, master_mock_calls)