158 lines
5.4 KiB
Python
158 lines
5.4 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
# not use this file except in compliance with the License. You may obtain
|
|
# a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import collections
|
|
import functools
|
|
import threading
|
|
import time
|
|
|
|
from taskflow import test
|
|
from taskflow.utils import threading_utils as tu
|
|
|
|
|
|
def _spinner(death):
|
|
while not death.is_set():
|
|
time.sleep(0.1)
|
|
|
|
|
|
class TestThreadHelpers(test.TestCase):
|
|
def test_alive_thread_falsey(self):
|
|
for v in [False, 0, None, ""]:
|
|
self.assertFalse(tu.is_alive(v))
|
|
|
|
def test_alive_thread(self):
|
|
death = threading.Event()
|
|
t = tu.daemon_thread(_spinner, death)
|
|
self.assertFalse(tu.is_alive(t))
|
|
t.start()
|
|
self.assertTrue(tu.is_alive(t))
|
|
death.set()
|
|
t.join()
|
|
self.assertFalse(tu.is_alive(t))
|
|
|
|
def test_daemon_thread(self):
|
|
death = threading.Event()
|
|
t = tu.daemon_thread(_spinner, death)
|
|
self.assertTrue(t.daemon)
|
|
|
|
|
|
class TestThreadBundle(test.TestCase):
|
|
thread_count = 5
|
|
|
|
def setUp(self):
|
|
super(TestThreadBundle, self).setUp()
|
|
self.bundle = tu.ThreadBundle()
|
|
self.death = threading.Event()
|
|
self.addCleanup(self.bundle.stop)
|
|
self.addCleanup(self.death.set)
|
|
|
|
def test_bind_invalid(self):
|
|
self.assertRaises(ValueError, self.bundle.bind, 1)
|
|
for k in ['after_start', 'before_start',
|
|
'before_join', 'after_join']:
|
|
kwargs = {
|
|
k: 1,
|
|
}
|
|
self.assertRaises(ValueError, self.bundle.bind,
|
|
lambda: tu.daemon_thread(_spinner, self.death),
|
|
**kwargs)
|
|
|
|
def test_bundle_length(self):
|
|
self.assertEqual(0, len(self.bundle))
|
|
for i in range(0, self.thread_count):
|
|
self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death))
|
|
self.assertEqual(1, self.bundle.start())
|
|
self.assertEqual(i + 1, len(self.bundle))
|
|
self.death.set()
|
|
self.assertEqual(self.thread_count, self.bundle.stop())
|
|
self.assertEqual(self.thread_count, len(self.bundle))
|
|
|
|
def test_start_stop_order(self):
|
|
start_events = collections.deque()
|
|
death_events = collections.deque()
|
|
|
|
def before_start(i, t):
|
|
start_events.append((i, 'bs'))
|
|
|
|
def before_join(i, t):
|
|
death_events.append((i, 'bj'))
|
|
self.death.set()
|
|
|
|
def after_start(i, t):
|
|
start_events.append((i, 'as'))
|
|
|
|
def after_join(i, t):
|
|
death_events.append((i, 'aj'))
|
|
|
|
for i in range(0, self.thread_count):
|
|
self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death),
|
|
before_join=functools.partial(before_join, i),
|
|
after_join=functools.partial(after_join, i),
|
|
before_start=functools.partial(before_start, i),
|
|
after_start=functools.partial(after_start, i))
|
|
self.assertEqual(self.thread_count, self.bundle.start())
|
|
self.assertEqual(self.thread_count, len(self.bundle))
|
|
self.assertEqual(self.thread_count, self.bundle.stop())
|
|
self.assertEqual(0, self.bundle.stop())
|
|
self.assertTrue(self.death.is_set())
|
|
|
|
expected_start_events = []
|
|
for i in range(0, self.thread_count):
|
|
expected_start_events.extend([
|
|
(i, 'bs'), (i, 'as'),
|
|
])
|
|
self.assertEqual(expected_start_events, list(start_events))
|
|
|
|
expected_death_events = []
|
|
j = self.thread_count - 1
|
|
for _i in range(0, self.thread_count):
|
|
expected_death_events.extend([
|
|
(j, 'bj'), (j, 'aj'),
|
|
])
|
|
j -= 1
|
|
self.assertEqual(expected_death_events, list(death_events))
|
|
|
|
def test_start_stop(self):
|
|
events = collections.deque()
|
|
|
|
def before_start(t):
|
|
events.append('bs')
|
|
|
|
def before_join(t):
|
|
events.append('bj')
|
|
self.death.set()
|
|
|
|
def after_start(t):
|
|
events.append('as')
|
|
|
|
def after_join(t):
|
|
events.append('aj')
|
|
|
|
for _i in range(0, self.thread_count):
|
|
self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death),
|
|
before_join=before_join,
|
|
after_join=after_join,
|
|
before_start=before_start,
|
|
after_start=after_start)
|
|
self.assertEqual(self.thread_count, self.bundle.start())
|
|
self.assertEqual(self.thread_count, len(self.bundle))
|
|
self.assertEqual(self.thread_count, self.bundle.stop())
|
|
for event in ['as', 'bs', 'bj', 'aj']:
|
|
self.assertEqual(self.thread_count,
|
|
len([e for e in events if e == event]))
|
|
self.assertEqual(0, self.bundle.stop())
|
|
self.assertTrue(self.death.is_set())
|