Add a thread bundle helper utility + tests
To make it easier to create a bunch of threads in a single call (and stop them in a single call) create a concept of a thread bundle (similar to a thread group) that will call into a provided set of factories to get a thread, activate callbacks to notify others that a thread is about to start or stop and then perform the start or stop of the bound threads in a orderly manner. Change-Id: I7d233cccb230b716af41243ad27220b988eec14c
This commit is contained in:
@@ -59,10 +59,16 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
transport=transport,
|
||||
transport_options=transport_options,
|
||||
retry_options=retry_options)
|
||||
self._proxy_thread = None
|
||||
self._periodic = wt.PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD),
|
||||
[self._notify_topics])
|
||||
self._periodic_thread = None
|
||||
self._helpers = tu.ThreadBundle()
|
||||
self._helpers.bind(lambda: tu.daemon_thread(self._proxy.start),
|
||||
after_start=lambda t: self._proxy.wait(),
|
||||
before_join=lambda t: self._proxy.stop())
|
||||
self._helpers.bind(lambda: tu.daemon_thread(self._periodic.start),
|
||||
before_join=lambda t: self._periodic.stop(),
|
||||
after_join=lambda t: self._periodic.reset(),
|
||||
before_start=lambda t: self._periodic.reset())
|
||||
|
||||
def _process_notify(self, notify, message):
|
||||
"""Process notify message from remote side."""
|
||||
@@ -226,24 +232,10 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
|
||||
def start(self):
|
||||
"""Starts proxy thread and associated topic notification thread."""
|
||||
if not tu.is_alive(self._proxy_thread):
|
||||
self._proxy_thread = tu.daemon_thread(self._proxy.start)
|
||||
self._proxy_thread.start()
|
||||
self._proxy.wait()
|
||||
if not tu.is_alive(self._periodic_thread):
|
||||
self._periodic.reset()
|
||||
self._periodic_thread = tu.daemon_thread(self._periodic.start)
|
||||
self._periodic_thread.start()
|
||||
self._helpers.start()
|
||||
|
||||
def stop(self):
|
||||
"""Stops proxy thread and associated topic notification thread."""
|
||||
if self._periodic_thread is not None:
|
||||
self._periodic.stop()
|
||||
self._periodic_thread.join()
|
||||
self._periodic_thread = None
|
||||
if self._proxy_thread is not None:
|
||||
self._proxy.stop()
|
||||
self._proxy_thread.join()
|
||||
self._proxy_thread = None
|
||||
self._helpers.stop()
|
||||
self._requests_cache.clear(self._handle_expired_request)
|
||||
self._workers.clear()
|
||||
|
||||
115
taskflow/tests/unit/test_utils_threading_utils.py
Normal file
115
taskflow/tests/unit/test_utils_threading_utils.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import collections
|
||||
import time
|
||||
|
||||
from taskflow import test
|
||||
from taskflow.utils import threading_utils as tu
|
||||
|
||||
|
||||
def _spinner(death):
|
||||
while not death.is_set():
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
class TestThreadHelpers(test.TestCase):
|
||||
def test_event_wait(self):
|
||||
e = tu.Event()
|
||||
e.set()
|
||||
self.assertTrue(e.wait())
|
||||
|
||||
def test_alive_thread_falsey(self):
|
||||
for v in [False, 0, None, ""]:
|
||||
self.assertFalse(tu.is_alive(v))
|
||||
|
||||
def test_alive_thread(self):
|
||||
death = tu.Event()
|
||||
t = tu.daemon_thread(_spinner, death)
|
||||
self.assertFalse(tu.is_alive(t))
|
||||
t.start()
|
||||
self.assertTrue(tu.is_alive(t))
|
||||
death.set()
|
||||
t.join()
|
||||
self.assertFalse(tu.is_alive(t))
|
||||
|
||||
def test_daemon_thread(self):
|
||||
death = tu.Event()
|
||||
t = tu.daemon_thread(_spinner, death)
|
||||
self.assertTrue(t.daemon)
|
||||
|
||||
|
||||
class TestThreadBundle(test.TestCase):
|
||||
thread_count = 5
|
||||
|
||||
def setUp(self):
|
||||
super(TestThreadBundle, self).setUp()
|
||||
self.bundle = tu.ThreadBundle()
|
||||
self.death = tu.Event()
|
||||
self.addCleanup(self.bundle.stop)
|
||||
self.addCleanup(self.death.set)
|
||||
|
||||
def test_bind_invalid(self):
|
||||
self.assertRaises(ValueError, self.bundle.bind, 1)
|
||||
for k in ['after_start', 'before_start',
|
||||
'before_join', 'after_join']:
|
||||
kwargs = {
|
||||
k: 1,
|
||||
}
|
||||
self.assertRaises(ValueError, self.bundle.bind,
|
||||
lambda: tu.daemon_thread(_spinner, self.death),
|
||||
**kwargs)
|
||||
|
||||
def test_bundle_length(self):
|
||||
self.assertEqual(0, len(self.bundle))
|
||||
for i in range(0, self.thread_count):
|
||||
self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death))
|
||||
self.assertEqual(1, self.bundle.start())
|
||||
self.assertEqual(i + 1, len(self.bundle))
|
||||
self.death.set()
|
||||
self.assertEqual(self.thread_count, self.bundle.stop())
|
||||
self.assertEqual(self.thread_count, len(self.bundle))
|
||||
|
||||
def test_start_stop(self):
|
||||
events = collections.deque()
|
||||
|
||||
def before_start(t):
|
||||
events.append('bs')
|
||||
|
||||
def before_join(t):
|
||||
events.append('bj')
|
||||
self.death.set()
|
||||
|
||||
def after_start(t):
|
||||
events.append('as')
|
||||
|
||||
def after_join(t):
|
||||
events.append('aj')
|
||||
|
||||
for _i in range(0, self.thread_count):
|
||||
self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death),
|
||||
before_join=before_join,
|
||||
after_join=after_join,
|
||||
before_start=before_start,
|
||||
after_start=after_start)
|
||||
self.assertEqual(self.thread_count, self.bundle.start())
|
||||
self.assertEqual(self.thread_count, len(self.bundle))
|
||||
self.assertEqual(self.thread_count, self.bundle.stop())
|
||||
for event in ['as', 'bs', 'bj', 'aj']:
|
||||
self.assertEqual(self.thread_count,
|
||||
len([e for e in events if e == event]))
|
||||
self.assertEqual(0, self.bundle.stop())
|
||||
self.assertTrue(self.death.is_set())
|
||||
@@ -353,9 +353,6 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
ex = self.executor()
|
||||
ex.start()
|
||||
|
||||
# wait until executor thread is done
|
||||
ex._proxy_thread.join()
|
||||
|
||||
# stop executor
|
||||
ex.stop()
|
||||
|
||||
|
||||
@@ -14,10 +14,12 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import collections
|
||||
import multiprocessing
|
||||
import sys
|
||||
import threading
|
||||
|
||||
import six
|
||||
from six.moves import _thread
|
||||
|
||||
|
||||
@@ -71,3 +73,105 @@ def daemon_thread(target, *args, **kwargs):
|
||||
# unless the daemon property is set to True.
|
||||
thread.daemon = True
|
||||
return thread
|
||||
|
||||
|
||||
# Container for thread creator + associated callbacks.
|
||||
_ThreadBuilder = collections.namedtuple('_ThreadBuilder',
|
||||
['thread_factory',
|
||||
'before_start', 'after_start',
|
||||
'before_join', 'after_join'])
|
||||
_ThreadBuilder.callables = tuple([
|
||||
# Attribute name -> none allowed as a valid value...
|
||||
('thread_factory', False),
|
||||
('before_start', True),
|
||||
('after_start', True),
|
||||
('before_join', True),
|
||||
('after_join', True),
|
||||
])
|
||||
|
||||
|
||||
class ThreadBundle(object):
|
||||
"""A group/bundle of threads that start/stop together."""
|
||||
|
||||
def __init__(self):
|
||||
self._threads = []
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def bind(self, thread_factory,
|
||||
before_start=None, after_start=None,
|
||||
before_join=None, after_join=None):
|
||||
"""Adds a thread (to-be) into this bundle (with given callbacks).
|
||||
|
||||
NOTE(harlowja): callbacks provided should not attempt to call
|
||||
mutating methods (:meth:`.stop`, :meth:`.start`,
|
||||
:meth:`.bind` ...) on this object as that will result
|
||||
in dead-lock since the lock on this object is not
|
||||
meant to be (and is not) reentrant...
|
||||
"""
|
||||
builder = _ThreadBuilder(thread_factory,
|
||||
before_start, after_start,
|
||||
before_join, after_join)
|
||||
for attr_name, none_allowed in builder.callables:
|
||||
cb = getattr(builder, attr_name)
|
||||
if cb is None and none_allowed:
|
||||
continue
|
||||
if not six.callable(cb):
|
||||
raise ValueError("Provided callback for argument"
|
||||
" '%s' must be callable" % attr_name)
|
||||
with self._lock:
|
||||
self._threads.append([
|
||||
builder,
|
||||
# The built thread.
|
||||
None,
|
||||
# Whether the built thread was started (and should have
|
||||
# ran or still be running).
|
||||
False,
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def _trigger_callback(callback, thread):
|
||||
if callback is not None:
|
||||
callback(thread)
|
||||
|
||||
def start(self):
|
||||
"""Creates & starts all associated threads (that are not running)."""
|
||||
count = 0
|
||||
with self._lock:
|
||||
for i, (builder, thread, started) in enumerate(self._threads):
|
||||
if thread and started:
|
||||
continue
|
||||
if not thread:
|
||||
self._threads[i][1] = thread = builder.thread_factory()
|
||||
self._trigger_callback(builder.before_start, thread)
|
||||
thread.start()
|
||||
count += 1
|
||||
try:
|
||||
self._trigger_callback(builder.after_start, thread)
|
||||
finally:
|
||||
# Just incase the 'after_start' callback blows up make sure
|
||||
# we always set this...
|
||||
self._threads[i][2] = started = True
|
||||
return count
|
||||
|
||||
def stop(self):
|
||||
"""Stops & joins all associated threads (that have been started)."""
|
||||
count = 0
|
||||
with self._lock:
|
||||
for i, (builder, thread, started) in enumerate(self._threads):
|
||||
if not thread or not started:
|
||||
continue
|
||||
self._trigger_callback(builder.before_join, thread)
|
||||
thread.join()
|
||||
count += 1
|
||||
try:
|
||||
self._trigger_callback(builder.after_join, thread)
|
||||
finally:
|
||||
# Just incase the 'after_join' callback blows up make sure
|
||||
# we always set/reset these...
|
||||
self._threads[i][1] = thread = None
|
||||
self._threads[i][2] = started = False
|
||||
return count
|
||||
|
||||
def __len__(self):
|
||||
"""Returns how many threads (to-be) are in this bundle."""
|
||||
return len(self._threads)
|
||||
|
||||
Reference in New Issue
Block a user