Add a thread bundle helper utility + tests

To make it easier to create a bunch of threads in a single
call (and stop them in a single call) create a concept of
a thread bundle (similar to a thread group) that will call
into a provided set of factories to get a thread, activate
callbacks to notify others that a thread is about to start
or stop and then perform the start or stop of the bound
threads in a orderly manner.

Change-Id: I7d233cccb230b716af41243ad27220b988eec14c
This commit is contained in:
Joshua Harlow
2015-01-24 00:45:36 -08:00
parent 1ae7a8e67b
commit ca82e20efe
4 changed files with 229 additions and 21 deletions

View File

@@ -59,10 +59,16 @@ class WorkerTaskExecutor(executor.TaskExecutor):
transport=transport,
transport_options=transport_options,
retry_options=retry_options)
self._proxy_thread = None
self._periodic = wt.PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD),
[self._notify_topics])
self._periodic_thread = None
self._helpers = tu.ThreadBundle()
self._helpers.bind(lambda: tu.daemon_thread(self._proxy.start),
after_start=lambda t: self._proxy.wait(),
before_join=lambda t: self._proxy.stop())
self._helpers.bind(lambda: tu.daemon_thread(self._periodic.start),
before_join=lambda t: self._periodic.stop(),
after_join=lambda t: self._periodic.reset(),
before_start=lambda t: self._periodic.reset())
def _process_notify(self, notify, message):
"""Process notify message from remote side."""
@@ -226,24 +232,10 @@ class WorkerTaskExecutor(executor.TaskExecutor):
def start(self):
"""Starts proxy thread and associated topic notification thread."""
if not tu.is_alive(self._proxy_thread):
self._proxy_thread = tu.daemon_thread(self._proxy.start)
self._proxy_thread.start()
self._proxy.wait()
if not tu.is_alive(self._periodic_thread):
self._periodic.reset()
self._periodic_thread = tu.daemon_thread(self._periodic.start)
self._periodic_thread.start()
self._helpers.start()
def stop(self):
"""Stops proxy thread and associated topic notification thread."""
if self._periodic_thread is not None:
self._periodic.stop()
self._periodic_thread.join()
self._periodic_thread = None
if self._proxy_thread is not None:
self._proxy.stop()
self._proxy_thread.join()
self._proxy_thread = None
self._helpers.stop()
self._requests_cache.clear(self._handle_expired_request)
self._workers.clear()

View File

@@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import collections
import time
from taskflow import test
from taskflow.utils import threading_utils as tu
def _spinner(death):
while not death.is_set():
time.sleep(0.1)
class TestThreadHelpers(test.TestCase):
def test_event_wait(self):
e = tu.Event()
e.set()
self.assertTrue(e.wait())
def test_alive_thread_falsey(self):
for v in [False, 0, None, ""]:
self.assertFalse(tu.is_alive(v))
def test_alive_thread(self):
death = tu.Event()
t = tu.daemon_thread(_spinner, death)
self.assertFalse(tu.is_alive(t))
t.start()
self.assertTrue(tu.is_alive(t))
death.set()
t.join()
self.assertFalse(tu.is_alive(t))
def test_daemon_thread(self):
death = tu.Event()
t = tu.daemon_thread(_spinner, death)
self.assertTrue(t.daemon)
class TestThreadBundle(test.TestCase):
thread_count = 5
def setUp(self):
super(TestThreadBundle, self).setUp()
self.bundle = tu.ThreadBundle()
self.death = tu.Event()
self.addCleanup(self.bundle.stop)
self.addCleanup(self.death.set)
def test_bind_invalid(self):
self.assertRaises(ValueError, self.bundle.bind, 1)
for k in ['after_start', 'before_start',
'before_join', 'after_join']:
kwargs = {
k: 1,
}
self.assertRaises(ValueError, self.bundle.bind,
lambda: tu.daemon_thread(_spinner, self.death),
**kwargs)
def test_bundle_length(self):
self.assertEqual(0, len(self.bundle))
for i in range(0, self.thread_count):
self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death))
self.assertEqual(1, self.bundle.start())
self.assertEqual(i + 1, len(self.bundle))
self.death.set()
self.assertEqual(self.thread_count, self.bundle.stop())
self.assertEqual(self.thread_count, len(self.bundle))
def test_start_stop(self):
events = collections.deque()
def before_start(t):
events.append('bs')
def before_join(t):
events.append('bj')
self.death.set()
def after_start(t):
events.append('as')
def after_join(t):
events.append('aj')
for _i in range(0, self.thread_count):
self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death),
before_join=before_join,
after_join=after_join,
before_start=before_start,
after_start=after_start)
self.assertEqual(self.thread_count, self.bundle.start())
self.assertEqual(self.thread_count, len(self.bundle))
self.assertEqual(self.thread_count, self.bundle.stop())
for event in ['as', 'bs', 'bj', 'aj']:
self.assertEqual(self.thread_count,
len([e for e in events if e == event]))
self.assertEqual(0, self.bundle.stop())
self.assertTrue(self.death.is_set())

View File

@@ -353,9 +353,6 @@ class TestWorkerTaskExecutor(test.MockTestCase):
ex = self.executor()
ex.start()
# wait until executor thread is done
ex._proxy_thread.join()
# stop executor
ex.stop()

View File

@@ -14,10 +14,12 @@
# License for the specific language governing permissions and limitations
# under the License.
import collections
import multiprocessing
import sys
import threading
import six
from six.moves import _thread
@@ -71,3 +73,105 @@ def daemon_thread(target, *args, **kwargs):
# unless the daemon property is set to True.
thread.daemon = True
return thread
# Container for thread creator + associated callbacks.
_ThreadBuilder = collections.namedtuple('_ThreadBuilder',
['thread_factory',
'before_start', 'after_start',
'before_join', 'after_join'])
_ThreadBuilder.callables = tuple([
# Attribute name -> none allowed as a valid value...
('thread_factory', False),
('before_start', True),
('after_start', True),
('before_join', True),
('after_join', True),
])
class ThreadBundle(object):
"""A group/bundle of threads that start/stop together."""
def __init__(self):
self._threads = []
self._lock = threading.Lock()
def bind(self, thread_factory,
before_start=None, after_start=None,
before_join=None, after_join=None):
"""Adds a thread (to-be) into this bundle (with given callbacks).
NOTE(harlowja): callbacks provided should not attempt to call
mutating methods (:meth:`.stop`, :meth:`.start`,
:meth:`.bind` ...) on this object as that will result
in dead-lock since the lock on this object is not
meant to be (and is not) reentrant...
"""
builder = _ThreadBuilder(thread_factory,
before_start, after_start,
before_join, after_join)
for attr_name, none_allowed in builder.callables:
cb = getattr(builder, attr_name)
if cb is None and none_allowed:
continue
if not six.callable(cb):
raise ValueError("Provided callback for argument"
" '%s' must be callable" % attr_name)
with self._lock:
self._threads.append([
builder,
# The built thread.
None,
# Whether the built thread was started (and should have
# ran or still be running).
False,
])
@staticmethod
def _trigger_callback(callback, thread):
if callback is not None:
callback(thread)
def start(self):
"""Creates & starts all associated threads (that are not running)."""
count = 0
with self._lock:
for i, (builder, thread, started) in enumerate(self._threads):
if thread and started:
continue
if not thread:
self._threads[i][1] = thread = builder.thread_factory()
self._trigger_callback(builder.before_start, thread)
thread.start()
count += 1
try:
self._trigger_callback(builder.after_start, thread)
finally:
# Just incase the 'after_start' callback blows up make sure
# we always set this...
self._threads[i][2] = started = True
return count
def stop(self):
"""Stops & joins all associated threads (that have been started)."""
count = 0
with self._lock:
for i, (builder, thread, started) in enumerate(self._threads):
if not thread or not started:
continue
self._trigger_callback(builder.before_join, thread)
thread.join()
count += 1
try:
self._trigger_callback(builder.after_join, thread)
finally:
# Just incase the 'after_join' callback blows up make sure
# we always set/reset these...
self._threads[i][1] = thread = None
self._threads[i][2] = started = False
return count
def __len__(self):
"""Returns how many threads (to-be) are in this bundle."""
return len(self._threads)