taskflow/taskflow/tests/unit/test_utils_threading_utils.py

158 lines
5.4 KiB
Python

# -*- coding: utf-8 -*-
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import collections
import functools
import threading
import time
from taskflow import test
from taskflow.utils import threading_utils as tu
def _spinner(death):
    """Keep the calling thread busy until the ``death`` event is set.

    Serves as the target of the threads created by the tests in this
    module, so that a thread stays alive until a test releases it.

    :param death: event-like object exposing ``is_set()``; the loop ends
                  once it reports true.
    """
    # Poll instead of blocking forever: the flag is rechecked after every
    # 0.1 second sleep.
    while not death.is_set():
        time.sleep(0.1)
class TestThreadHelpers(test.TestCase):
    """Checks for the ``is_alive`` and ``daemon_thread`` helpers."""

    def test_alive_thread_falsey(self):
        # Falsey stand-ins for a thread must all be reported as not alive.
        for v in (False, 0, None, ""):
            self.assertFalse(tu.is_alive(v))

    def test_alive_thread(self):
        death = threading.Event()
        # Release the spinner even when an assertion below fails first, so
        # a failing run does not leave the thread polling in the background.
        self.addCleanup(death.set)
        t = tu.daemon_thread(_spinner, death)
        # Created but not started yet.
        self.assertFalse(tu.is_alive(t))
        t.start()
        self.assertTrue(tu.is_alive(t))
        # Let the spinner loop exit, then wait for the thread to finish.
        death.set()
        t.join()
        self.assertFalse(tu.is_alive(t))

    def test_daemon_thread(self):
        # The thread is only constructed, never started; the check is on
        # its ``daemon`` flag alone.
        death = threading.Event()
        t = tu.daemon_thread(_spinner, death)
        self.assertTrue(t.daemon)
class TestThreadBundle(test.TestCase):
    """Exercises binding, starting and stopping of ``tu.ThreadBundle``."""

    # Number of spinner threads each test binds into the bundle.
    thread_count = 5

    def setUp(self):
        super(TestThreadBundle, self).setUp()
        self.bundle = tu.ThreadBundle()
        self.death = threading.Event()
        # NOTE(review): assuming unittest-style LIFO cleanups, the event is
        # set before the bundle is stopped, so spinner threads can exit
        # and be joined -- confirm against the base test case.
        self.addCleanup(self.bundle.stop)
        self.addCleanup(self.death.set)

    def test_bind_invalid(self):
        # Binding the integer 1 in place of a thread factory must fail.
        self.assertRaises(ValueError, self.bundle.bind, 1)
        # The same goes for passing 1 as any of the four hook arguments.
        for k in ('after_start', 'before_start',
                  'before_join', 'after_join'):
            kwargs = {k: 1}
            self.assertRaises(ValueError, self.bundle.bind,
                              lambda: tu.daemon_thread(_spinner, self.death),
                              **kwargs)

    def test_bundle_length(self):
        self.assertEqual(0, len(self.bundle))
        for i in range(self.thread_count):
            self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death))
            # Every start() call is expected to report exactly one thread
            # (the binding just added), while the bundle length keeps
            # growing by one per iteration.
            self.assertEqual(1, self.bundle.start())
            self.assertEqual(i + 1, len(self.bundle))
        # Unblock the spinners so stop() is able to join them.
        self.death.set()
        self.assertEqual(self.thread_count, self.bundle.stop())
        # Stopping is expected to leave the bindings in place.
        self.assertEqual(self.thread_count, len(self.bundle))

    def test_start_stop_order(self):
        start_events = collections.deque()
        death_events = collections.deque()

        # Each hook records (binding index, tag); the index is supplied
        # ahead of the hook's own argument via functools.partial below.
        def before_start(i, t):
            start_events.append((i, 'bs'))

        def before_join(i, t):
            death_events.append((i, 'bj'))
            # Let the spinner loops exit so the join that follows returns.
            self.death.set()

        def after_start(i, t):
            start_events.append((i, 'as'))

        def after_join(i, t):
            death_events.append((i, 'aj'))

        for i in range(self.thread_count):
            self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death),
                             before_join=functools.partial(before_join, i),
                             after_join=functools.partial(after_join, i),
                             before_start=functools.partial(before_start, i),
                             after_start=functools.partial(after_start, i))
        self.assertEqual(self.thread_count, self.bundle.start())
        self.assertEqual(self.thread_count, len(self.bundle))
        self.assertEqual(self.thread_count, self.bundle.stop())
        # A second stop() has nothing left to stop.
        self.assertEqual(0, self.bundle.stop())
        self.assertTrue(self.death.is_set())

        # Starting is expected to walk the bindings in bind order ...
        expected_start_events = []
        for i in range(self.thread_count):
            expected_start_events.extend([(i, 'bs'), (i, 'as')])
        self.assertEqual(expected_start_events, list(start_events))

        # ... and stopping is expected to walk them in reverse bind order.
        expected_death_events = []
        for j in reversed(range(self.thread_count)):
            expected_death_events.extend([(j, 'bj'), (j, 'aj')])
        self.assertEqual(expected_death_events, list(death_events))

    def test_start_stop(self):
        events = collections.deque()

        def before_start(t):
            events.append('bs')

        def before_join(t):
            events.append('bj')
            # Let the spinner loops exit so the join that follows returns.
            self.death.set()

        def after_start(t):
            events.append('as')

        def after_join(t):
            events.append('aj')

        for _i in range(self.thread_count):
            self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death),
                             before_join=before_join,
                             after_join=after_join,
                             before_start=before_start,
                             after_start=after_start)
        self.assertEqual(self.thread_count, self.bundle.start())
        self.assertEqual(self.thread_count, len(self.bundle))
        self.assertEqual(self.thread_count, self.bundle.stop())
        # Every hook must have fired exactly once per bound thread.
        for event in ['as', 'bs', 'bj', 'aj']:
            self.assertEqual(self.thread_count, events.count(event))
        # A second stop() has nothing left to stop.
        self.assertEqual(0, self.bundle.stop())
        self.assertTrue(self.death.is_set())