diff --git a/taskflow/utils/async_utils.py b/taskflow/utils/async_utils.py index b055a27b..2fa3b5f6 100644 --- a/taskflow/utils/async_utils.py +++ b/taskflow/utils/async_utils.py @@ -54,8 +54,9 @@ def wait_for_any(fs, timeout=None): """ green_fs = sum(1 for f in fs if isinstance(f, futures.GreenFuture)) if not green_fs: - return tuple(_futures.wait(fs, timeout=timeout, - return_when=_futures.FIRST_COMPLETED)) + return _futures.wait(fs, + timeout=timeout, + return_when=_futures.FIRST_COMPLETED) else: non_green_fs = len(fs) - green_fs if non_green_fs: @@ -81,23 +82,24 @@ class _GreenWaiter(object): self.event.set() +def _partition_futures(fs): + done = set() + not_done = set() + for f in fs: + if f._state in _DONE_STATES: + done.add(f) + else: + not_done.add(f) + return done, not_done + + def _wait_for_any_green(fs, timeout=None): assert EVENTLET_AVAILABLE, 'eventlet is needed to wait on green futures' - def _partition_futures(fs): - done = set() - not_done = set() - for f in fs: - if f._state in _DONE_STATES: - done.add(f) - else: - not_done.add(f) - return (done, not_done) - with _base._AcquireFutures(fs): - (done, not_done) = _partition_futures(fs) + done, not_done = _partition_futures(fs) if done: - return (done, not_done) + return _base.DoneAndNotDoneFutures(done, not_done) waiter = _GreenWaiter() for f in fs: f._waiters.append(waiter) @@ -107,4 +109,5 @@ def _wait_for_any_green(fs, timeout=None): f._waiters.remove(waiter) with _base._AcquireFutures(fs): - return _partition_futures(fs) + done, not_done = _partition_futures(fs) + return _base.DoneAndNotDoneFutures(done, not_done)