Merge "Ensure the thread bundle stops in last to first order"
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
# under the License.
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import time
|
||||
|
||||
from taskflow import test
|
||||
@@ -83,6 +84,51 @@ class TestThreadBundle(test.TestCase):
|
||||
self.assertEqual(self.thread_count, self.bundle.stop())
|
||||
self.assertEqual(self.thread_count, len(self.bundle))
|
||||
|
||||
def test_start_stop_order(self):
|
||||
start_events = collections.deque()
|
||||
death_events = collections.deque()
|
||||
|
||||
def before_start(i, t):
|
||||
start_events.append((i, 'bs'))
|
||||
|
||||
def before_join(i, t):
|
||||
death_events.append((i, 'bj'))
|
||||
self.death.set()
|
||||
|
||||
def after_start(i, t):
|
||||
start_events.append((i, 'as'))
|
||||
|
||||
def after_join(i, t):
|
||||
death_events.append((i, 'aj'))
|
||||
|
||||
for i in range(0, self.thread_count):
|
||||
self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death),
|
||||
before_join=functools.partial(before_join, i),
|
||||
after_join=functools.partial(after_join, i),
|
||||
before_start=functools.partial(before_start, i),
|
||||
after_start=functools.partial(after_start, i))
|
||||
self.assertEqual(self.thread_count, self.bundle.start())
|
||||
self.assertEqual(self.thread_count, len(self.bundle))
|
||||
self.assertEqual(self.thread_count, self.bundle.stop())
|
||||
self.assertEqual(0, self.bundle.stop())
|
||||
self.assertTrue(self.death.is_set())
|
||||
|
||||
expected_start_events = []
|
||||
for i in range(0, self.thread_count):
|
||||
expected_start_events.extend([
|
||||
(i, 'bs'), (i, 'as'),
|
||||
])
|
||||
self.assertEqual(expected_start_events, list(start_events))
|
||||
|
||||
expected_death_events = []
|
||||
j = self.thread_count - 1
|
||||
for _i in range(0, self.thread_count):
|
||||
expected_death_events.extend([
|
||||
(j, 'bj'), (j, 'aj'),
|
||||
])
|
||||
j -= 1
|
||||
self.assertEqual(expected_death_events, list(death_events))
|
||||
|
||||
def test_start_stop(self):
|
||||
events = collections.deque()
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ import threading
|
||||
import six
|
||||
from six.moves import _thread
|
||||
|
||||
from taskflow.utils import misc
|
||||
|
||||
|
||||
if sys.version_info[0:2] == (2, 6):
|
||||
# This didn't return that was/wasn't set in 2.6, since we actually care
|
||||
@@ -137,7 +139,8 @@ class ThreadBundle(object):
|
||||
"""Creates & starts all associated threads (that are not running)."""
|
||||
count = 0
|
||||
with self._lock:
|
||||
for i, (builder, thread, started) in enumerate(self._threads):
|
||||
it = enumerate(self._threads)
|
||||
for i, (builder, thread, started) in it:
|
||||
if thread and started:
|
||||
continue
|
||||
if not thread:
|
||||
@@ -157,7 +160,8 @@ class ThreadBundle(object):
|
||||
"""Stops & joins all associated threads (that have been started)."""
|
||||
count = 0
|
||||
with self._lock:
|
||||
for i, (builder, thread, started) in enumerate(self._threads):
|
||||
it = misc.reverse_enumerate(self._threads)
|
||||
for i, (builder, thread, started) in it:
|
||||
if not thread or not started:
|
||||
continue
|
||||
self._trigger_callback(builder.before_join, thread)
|
||||
|
||||
Reference in New Issue
Block a user