diff --git a/taskflow/tests/unit/test_utils_threading_utils.py b/taskflow/tests/unit/test_utils_threading_utils.py index 974285fa..66ef2d09 100644 --- a/taskflow/tests/unit/test_utils_threading_utils.py +++ b/taskflow/tests/unit/test_utils_threading_utils.py @@ -15,6 +15,7 @@ # under the License. import collections +import functools import time from taskflow import test @@ -83,6 +84,51 @@ class TestThreadBundle(test.TestCase): self.assertEqual(self.thread_count, self.bundle.stop()) self.assertEqual(self.thread_count, len(self.bundle)) + def test_start_stop_order(self): + start_events = collections.deque() + death_events = collections.deque() + + def before_start(i, t): + start_events.append((i, 'bs')) + + def before_join(i, t): + death_events.append((i, 'bj')) + self.death.set() + + def after_start(i, t): + start_events.append((i, 'as')) + + def after_join(i, t): + death_events.append((i, 'aj')) + + for i in range(0, self.thread_count): + self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death), + before_join=functools.partial(before_join, i), + after_join=functools.partial(after_join, i), + before_start=functools.partial(before_start, i), + after_start=functools.partial(after_start, i)) + self.assertEqual(self.thread_count, self.bundle.start()) + self.assertEqual(self.thread_count, len(self.bundle)) + self.assertEqual(self.thread_count, self.bundle.stop()) + self.assertEqual(0, self.bundle.stop()) + self.assertTrue(self.death.is_set()) + + expected_start_events = [] + for i in range(0, self.thread_count): + expected_start_events.extend([ + (i, 'bs'), (i, 'as'), + ]) + self.assertEqual(expected_start_events, list(start_events)) + + expected_death_events = [] + j = self.thread_count - 1 + for _i in range(0, self.thread_count): + expected_death_events.extend([ + (j, 'bj'), (j, 'aj'), + ]) + j -= 1 + self.assertEqual(expected_death_events, list(death_events)) + def test_start_stop(self): events = collections.deque() diff --git a/taskflow/utils/threading_utils.py b/taskflow/utils/threading_utils.py index cea0760d..1f3186bf 100644 --- a/taskflow/utils/threading_utils.py +++ b/taskflow/utils/threading_utils.py @@ -22,6 +22,8 @@ import threading import six from six.moves import _thread +from taskflow.utils import misc + if sys.version_info[0:2] == (2, 6): # This didn't return that was/wasn't set in 2.6, since we actually care @@ -137,7 +139,8 @@ class ThreadBundle(object): """Creates & starts all associated threads (that are not running).""" count = 0 with self._lock: - for i, (builder, thread, started) in enumerate(self._threads): + it = enumerate(self._threads) + for i, (builder, thread, started) in it: if thread and started: continue if not thread: @@ -157,7 +160,8 @@ class ThreadBundle(object): """Stops & joins all associated threads (that have been started).""" count = 0 with self._lock: - for i, (builder, thread, started) in enumerate(self._threads): + it = misc.reverse_enumerate(self._threads) + for i, (builder, thread, started) in it: if not thread or not started: continue self._trigger_callback(builder.before_join, thread)