diff --git a/taskflow/engines/action_engine/executor.py b/taskflow/engines/action_engine/executor.py index 5d063c5b..b1e969d0 100644 --- a/taskflow/engines/action_engine/executor.py +++ b/taskflow/engines/action_engine/executor.py @@ -17,6 +17,7 @@ import abc import functools import multiprocessing +from multiprocessing import managers import os import pickle import threading @@ -120,17 +121,20 @@ class _JoinedWorkItem(object): self._args = args self._kwargs = kwargs + def _on_finish(self): + w = timing.StopWatch() + w.start() + self._queue.join() + LOG.blather("Waited %0.2f seconds until task '%s' emitted" + " notifications were depleted", w.elapsed(), self._task) + def __call__(self): args = self._args kwargs = self._kwargs try: return self._func(self._task, *args, **kwargs) finally: - w = timing.StopWatch().start() - self._queue.join() - LOG.blather("Waited %0.2f seconds until task '%s' emitted" - " notifications were depleted", w.elapsed(), - self._task) + self._on_finish() class _EventSender(object): @@ -231,6 +235,8 @@ class _EventDispatcher(object): def reset(self): self._stop_when_empty = False + while self._targets: + self.deregister(self._targets.pop()) self._dead.clear() def interrupt(self): @@ -396,16 +402,25 @@ class ParallelProcessTaskExecutor(ParallelTaskExecutor): super(ParallelProcessTaskExecutor, self).__init__( executor=executor, max_workers=max_workers) self._manager = multiprocessing.Manager() - self._queue_factory = lambda: self._manager.JoinableQueue() self._dispatcher = _EventDispatcher( dispatch_periodicity=dispatch_periodicity) self._worker = None + def _queue_factory(self): + return self._manager.JoinableQueue() + def _create_executor(self, max_workers=None): return futures.ProcessPoolExecutor(max_workers=max_workers) def start(self): super(ParallelProcessTaskExecutor, self).start() + # TODO(harlowja): do something else here besides accessing a state + # of the manager internals (it doesn't seem to expose any way to know + # this information)... + if self._manager._state.value == managers.State.SHUTDOWN: + self._manager = multiprocessing.Manager() + if self._manager._state.value == managers.State.INITIAL: + self._manager.start() if not threading_utils.is_alive(self._worker): self._dispatcher.reset() self._worker = threading_utils.daemon_thread(self._dispatcher.run) @@ -418,6 +433,8 @@ class ParallelProcessTaskExecutor(ParallelTaskExecutor): self._worker.join() self._worker = None self._dispatcher.reset() + self._manager.shutdown() + self._manager.join() def _rebind_task(self, task, clone, queue, progress_callback=None): # Creates and binds proxies for all events the task could receive