From ca82e20efe8f5c5d50b3db89be0342710ef7f73b Mon Sep 17 00:00:00 2001 From: Joshua Harlow Date: Sat, 24 Jan 2015 00:45:36 -0800 Subject: [PATCH] Add a thread bundle helper utility + tests To make it easier to create a bunch of threads in a single call (and stop them in a single call) create a concept of a thread bundle (similar to a thread group) that will call into a provided set of factories to get a thread, activate callbacks to notify others that a thread is about to start or stop and then perform the start or stop of the bound threads in a orderly manner. Change-Id: I7d233cccb230b716af41243ad27220b988eec14c --- taskflow/engines/worker_based/executor.py | 28 ++--- .../tests/unit/test_utils_threading_utils.py | 115 ++++++++++++++++++ .../tests/unit/worker_based/test_executor.py | 3 - taskflow/utils/threading_utils.py | 104 ++++++++++++++++ 4 files changed, 229 insertions(+), 21 deletions(-) create mode 100644 taskflow/tests/unit/test_utils_threading_utils.py diff --git a/taskflow/engines/worker_based/executor.py b/taskflow/engines/worker_based/executor.py index cda37458..8290ba61 100644 --- a/taskflow/engines/worker_based/executor.py +++ b/taskflow/engines/worker_based/executor.py @@ -59,10 +59,16 @@ class WorkerTaskExecutor(executor.TaskExecutor): transport=transport, transport_options=transport_options, retry_options=retry_options) - self._proxy_thread = None self._periodic = wt.PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD), [self._notify_topics]) - self._periodic_thread = None + self._helpers = tu.ThreadBundle() + self._helpers.bind(lambda: tu.daemon_thread(self._proxy.start), + after_start=lambda t: self._proxy.wait(), + before_join=lambda t: self._proxy.stop()) + self._helpers.bind(lambda: tu.daemon_thread(self._periodic.start), + before_join=lambda t: self._periodic.stop(), + after_join=lambda t: self._periodic.reset(), + before_start=lambda t: self._periodic.reset()) def _process_notify(self, notify, message): """Process notify message from remote side.""" @@ -226,24 +232,10 @@ class WorkerTaskExecutor(executor.TaskExecutor): def start(self): """Starts proxy thread and associated topic notification thread.""" - if not tu.is_alive(self._proxy_thread): - self._proxy_thread = tu.daemon_thread(self._proxy.start) - self._proxy_thread.start() - self._proxy.wait() - if not tu.is_alive(self._periodic_thread): - self._periodic.reset() - self._periodic_thread = tu.daemon_thread(self._periodic.start) - self._periodic_thread.start() + self._helpers.start() def stop(self): """Stops proxy thread and associated topic notification thread.""" - if self._periodic_thread is not None: - self._periodic.stop() - self._periodic_thread.join() - self._periodic_thread = None - if self._proxy_thread is not None: - self._proxy.stop() - self._proxy_thread.join() - self._proxy_thread = None + self._helpers.stop() self._requests_cache.clear(self._handle_expired_request) self._workers.clear() diff --git a/taskflow/tests/unit/test_utils_threading_utils.py b/taskflow/tests/unit/test_utils_threading_utils.py new file mode 100644 index 00000000..974285fa --- /dev/null +++ b/taskflow/tests/unit/test_utils_threading_utils.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- + +# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import collections +import time + +from taskflow import test +from taskflow.utils import threading_utils as tu + + +def _spinner(death): + while not death.is_set(): + time.sleep(0.1) + + +class TestThreadHelpers(test.TestCase): + def test_event_wait(self): + e = tu.Event() + e.set() + self.assertTrue(e.wait()) + + def test_alive_thread_falsey(self): + for v in [False, 0, None, ""]: + self.assertFalse(tu.is_alive(v)) + + def test_alive_thread(self): + death = tu.Event() + t = tu.daemon_thread(_spinner, death) + self.assertFalse(tu.is_alive(t)) + t.start() + self.assertTrue(tu.is_alive(t)) + death.set() + t.join() + self.assertFalse(tu.is_alive(t)) + + def test_daemon_thread(self): + death = tu.Event() + t = tu.daemon_thread(_spinner, death) + self.assertTrue(t.daemon) + + +class TestThreadBundle(test.TestCase): + thread_count = 5 + + def setUp(self): + super(TestThreadBundle, self).setUp() + self.bundle = tu.ThreadBundle() + self.death = tu.Event() + self.addCleanup(self.bundle.stop) + self.addCleanup(self.death.set) + + def test_bind_invalid(self): + self.assertRaises(ValueError, self.bundle.bind, 1) + for k in ['after_start', 'before_start', + 'before_join', 'after_join']: + kwargs = { + k: 1, + } + self.assertRaises(ValueError, self.bundle.bind, + lambda: tu.daemon_thread(_spinner, self.death), + **kwargs) + + def test_bundle_length(self): + self.assertEqual(0, len(self.bundle)) + for i in range(0, self.thread_count): + self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death)) + self.assertEqual(1, self.bundle.start()) + self.assertEqual(i + 1, len(self.bundle)) + self.death.set() + self.assertEqual(self.thread_count, self.bundle.stop()) + self.assertEqual(self.thread_count, len(self.bundle)) + + def test_start_stop(self): + events = collections.deque() + + def before_start(t): + events.append('bs') + + def before_join(t): + events.append('bj') + self.death.set() + + def after_start(t): + events.append('as') + + def after_join(t): + events.append('aj') + + for _i in range(0, self.thread_count): + self.bundle.bind(lambda: tu.daemon_thread(_spinner, self.death), + before_join=before_join, + after_join=after_join, + before_start=before_start, + after_start=after_start) + self.assertEqual(self.thread_count, self.bundle.start()) + self.assertEqual(self.thread_count, len(self.bundle)) + self.assertEqual(self.thread_count, self.bundle.stop()) + for event in ['as', 'bs', 'bj', 'aj']: + self.assertEqual(self.thread_count, + len([e for e in events if e == event])) + self.assertEqual(0, self.bundle.stop()) + self.assertTrue(self.death.is_set()) diff --git a/taskflow/tests/unit/worker_based/test_executor.py b/taskflow/tests/unit/worker_based/test_executor.py index cdb421d1..101031c4 100644 --- a/taskflow/tests/unit/worker_based/test_executor.py +++ b/taskflow/tests/unit/worker_based/test_executor.py @@ -353,9 +353,6 @@ class TestWorkerTaskExecutor(test.MockTestCase): ex = self.executor() ex.start() - # wait until executor thread is done - ex._proxy_thread.join() - # stop executor ex.stop() diff --git a/taskflow/utils/threading_utils.py b/taskflow/utils/threading_utils.py index 5048401c..cea0760d 100644 --- a/taskflow/utils/threading_utils.py +++ b/taskflow/utils/threading_utils.py @@ -14,10 +14,12 @@ # License for the specific language governing permissions and limitations # under the License. +import collections import multiprocessing import sys import threading +import six from six.moves import _thread @@ -71,3 +73,105 @@ def daemon_thread(target, *args, **kwargs): # unless the daemon property is set to True. thread.daemon = True return thread + + +# Container for thread creator + associated callbacks. +_ThreadBuilder = collections.namedtuple('_ThreadBuilder', + ['thread_factory', + 'before_start', 'after_start', + 'before_join', 'after_join']) +_ThreadBuilder.callables = tuple([ + # Attribute name -> none allowed as a valid value... + ('thread_factory', False), + ('before_start', True), + ('after_start', True), + ('before_join', True), + ('after_join', True), +]) + + +class ThreadBundle(object): + """A group/bundle of threads that start/stop together.""" + + def __init__(self): + self._threads = [] + self._lock = threading.Lock() + + def bind(self, thread_factory, + before_start=None, after_start=None, + before_join=None, after_join=None): + """Adds a thread (to-be) into this bundle (with given callbacks). + + NOTE(harlowja): callbacks provided should not attempt to call + mutating methods (:meth:`.stop`, :meth:`.start`, + :meth:`.bind` ...) on this object as that will result + in dead-lock since the lock on this object is not + meant to be (and is not) reentrant... + """ + builder = _ThreadBuilder(thread_factory, + before_start, after_start, + before_join, after_join) + for attr_name, none_allowed in builder.callables: + cb = getattr(builder, attr_name) + if cb is None and none_allowed: + continue + if not six.callable(cb): + raise ValueError("Provided callback for argument" + " '%s' must be callable" % attr_name) + with self._lock: + self._threads.append([ + builder, + # The built thread. + None, + # Whether the built thread was started (and should have + # ran or still be running). + False, + ]) + + @staticmethod + def _trigger_callback(callback, thread): + if callback is not None: + callback(thread) + + def start(self): + """Creates & starts all associated threads (that are not running).""" + count = 0 + with self._lock: + for i, (builder, thread, started) in enumerate(self._threads): + if thread and started: + continue + if not thread: + self._threads[i][1] = thread = builder.thread_factory() + self._trigger_callback(builder.before_start, thread) + thread.start() + count += 1 + try: + self._trigger_callback(builder.after_start, thread) + finally: + # Just incase the 'after_start' callback blows up make sure + # we always set this... + self._threads[i][2] = started = True + return count + + def stop(self): + """Stops & joins all associated threads (that have been started).""" + count = 0 + with self._lock: + for i, (builder, thread, started) in enumerate(self._threads): + if not thread or not started: + continue + self._trigger_callback(builder.before_join, thread) + thread.join() + count += 1 + try: + self._trigger_callback(builder.after_join, thread) + finally: + # Just incase the 'after_join' callback blows up make sure + # we always set/reset these... + self._threads[i][1] = thread = None + self._threads[i][2] = started = False + return count + + def __len__(self): + """Returns how many threads (to-be) are in this bundle.""" + return len(self._threads)