diff --git a/taskflow/engines/action_engine/engine.py b/taskflow/engines/action_engine/engine.py index 054c2ccf..ca8a80a6 100644 --- a/taskflow/engines/action_engine/engine.py +++ b/taskflow/engines/action_engine/engine.py @@ -216,9 +216,12 @@ class MultiThreadedActionEngine(ActionEngine): _storage_factory = atom_storage.MultiThreadedStorage def _task_executor_factory(self): - return executor.ParallelTaskExecutor(self._executor) + return executor.ParallelTaskExecutor(executor=self._executor, + max_workers=self._max_workers) - def __init__(self, flow, flow_detail, backend, conf, **kwargs): + def __init__(self, flow, flow_detail, backend, conf, + executor=None, max_workers=None): super(MultiThreadedActionEngine, self).__init__( flow, flow_detail, backend, conf) - self._executor = kwargs.get('executor') + self._executor = executor + self._max_workers = max_workers diff --git a/taskflow/engines/action_engine/executor.py b/taskflow/engines/action_engine/executor.py index e28e863b..b2bdbdae 100644 --- a/taskflow/engines/action_engine/executor.py +++ b/taskflow/engines/action_engine/executor.py @@ -111,13 +111,14 @@ class SerialTaskExecutor(TaskExecutorBase): class ParallelTaskExecutor(TaskExecutorBase): """Executes tasks in parallel. - Submits tasks to executor which should provide interface similar + Submits tasks to an executor which should provide an interface similar to concurrent.Futures.Executor. """ - def __init__(self, executor=None): + def __init__(self, executor=None, max_workers=None): self._executor = executor - self._own_executor = executor is None + self._max_workers = max_workers + self._create_executor = executor is None def execute_task(self, task, task_uuid, arguments, progress_callback=None): return self._executor.submit( @@ -133,11 +134,14 @@ class ParallelTaskExecutor(TaskExecutorBase): return async_utils.wait_for_any(fs, timeout) def start(self): - if self._own_executor: - thread_count = threading_utils.get_optimal_thread_count() - self._executor = futures.ThreadPoolExecutor(thread_count) + if self._create_executor: + if self._max_workers is not None: + max_workers = self._max_workers + else: + max_workers = threading_utils.get_optimal_thread_count() + self._executor = futures.ThreadPoolExecutor(max_workers) def stop(self): - if self._own_executor: + if self._create_executor: self._executor.shutdown(wait=True) self._executor = None