diff --git a/openstack/common/threadgroup.py b/openstack/common/threadgroup.py index 31097f1..0f4bbf8 100644 --- a/openstack/common/threadgroup.py +++ b/openstack/common/threadgroup.py @@ -85,7 +85,7 @@ class ThreadGroup(object): def thread_done(self, thread): self.threads.remove(thread) - def stop(self): + def _stop_threads(self): current = threading.current_thread() # Iterate over a copy of self.threads so thread_done doesn't @@ -99,6 +99,7 @@ class ThreadGroup(object): except Exception as ex: LOG.exception(ex) + def _stop_timers(self): for x in self.timers: try: x.stop() @@ -106,6 +107,23 @@ class ThreadGroup(object): LOG.exception(ex) self.timers = [] + def stop(self, graceful=False): + """stop function has the option of graceful=True/False. + + * In case of graceful=True, wait for all threads to be finished. + Never kill threads. + * In case of graceful=False, kill threads immediately. + """ + self._stop_timers() + if graceful: + # In case of graceful=True, wait for all threads to be + # finished, never kill threads + self.wait() + else: + # In case of graceful=False(Default), kill threads + # immediately + self._stop_threads() + def wait(self): for x in self.timers: try: diff --git a/tests/unit/test_threadgroup.py b/tests/unit/test_threadgroup.py index 08446e2..adbb6e5 100644 --- a/tests/unit/test_threadgroup.py +++ b/tests/unit/test_threadgroup.py @@ -17,6 +17,8 @@ Unit Tests for thread groups """ +import time + from oslotest import base as test_base from openstack.common import threadgroup @@ -44,3 +46,27 @@ class ThreadGroupTestCase(test_base.BaseTestCase): self.assertTrue(timer._running) self.assertEqual(('arg',), timer.args) self.assertEqual({'kwarg': 'kwarg'}, timer.kw) + + def test_stop_immediately(self): + + def foo(*args, **kwargs): + time.sleep(1) + start_time = time.time() + self.tg.add_thread(foo, 'arg', kwarg='kwarg') + self.tg.stop() + end_time = time.time() + + self.assertEqual(0, len(self.tg.threads)) + self.assertTrue(end_time - start_time < 1) + + def test_stop_gracefully(self): + + def foo(*args, **kwargs): + time.sleep(1) + start_time = time.time() + self.tg.add_thread(foo, 'arg', kwarg='kwarg') + self.tg.stop(True) + end_time = time.time() + + self.assertEqual(0, len(self.tg.threads)) + self.assertTrue(end_time - start_time >= 1)