Merge "Add graceful stop function to ThreadGroup.stop"
This commit is contained in:
		@@ -85,7 +85,7 @@ class ThreadGroup(object):
 | 
			
		||||
    def thread_done(self, thread):
 | 
			
		||||
        self.threads.remove(thread)
 | 
			
		||||
 | 
			
		||||
    def stop(self):
 | 
			
		||||
    def _stop_threads(self):
 | 
			
		||||
        current = threading.current_thread()
 | 
			
		||||
 | 
			
		||||
        # Iterate over a copy of self.threads so thread_done doesn't
 | 
			
		||||
@@ -99,6 +99,7 @@ class ThreadGroup(object):
 | 
			
		||||
            except Exception as ex:
 | 
			
		||||
                LOG.exception(ex)
 | 
			
		||||
 | 
			
		||||
    def _stop_timers(self):
 | 
			
		||||
        for x in self.timers:
 | 
			
		||||
            try:
 | 
			
		||||
                x.stop()
 | 
			
		||||
@@ -106,6 +107,23 @@ class ThreadGroup(object):
 | 
			
		||||
                LOG.exception(ex)
 | 
			
		||||
        self.timers = []
 | 
			
		||||
 | 
			
		||||
    def stop(self, graceful=False):
 | 
			
		||||
        """stop function has the option of graceful=True/False.
 | 
			
		||||
 | 
			
		||||
        * In case of graceful=True, wait for all threads to be finished.
 | 
			
		||||
          Never kill threads.
 | 
			
		||||
        * In case of graceful=False, kill threads immediately.
 | 
			
		||||
        """
 | 
			
		||||
        self._stop_timers()
 | 
			
		||||
        if graceful:
 | 
			
		||||
            # In case of graceful=True, wait for all threads to be
 | 
			
		||||
            # finished, never kill threads
 | 
			
		||||
            self.wait()
 | 
			
		||||
        else:
 | 
			
		||||
            # In case of graceful=False(Default), kill threads
 | 
			
		||||
            # immediately
 | 
			
		||||
            self._stop_threads()
 | 
			
		||||
 | 
			
		||||
    def wait(self):
 | 
			
		||||
        for x in self.timers:
 | 
			
		||||
            try:
 | 
			
		||||
 
 | 
			
		||||
@@ -17,6 +17,8 @@
 | 
			
		||||
Unit Tests for thread groups
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
from oslotest import base as test_base
 | 
			
		||||
 | 
			
		||||
from openstack.common import threadgroup
 | 
			
		||||
@@ -44,3 +46,27 @@ class ThreadGroupTestCase(test_base.BaseTestCase):
 | 
			
		||||
        self.assertTrue(timer._running)
 | 
			
		||||
        self.assertEqual(('arg',), timer.args)
 | 
			
		||||
        self.assertEqual({'kwarg': 'kwarg'}, timer.kw)
 | 
			
		||||
 | 
			
		||||
    def test_stop_immediately(self):
 | 
			
		||||
 | 
			
		||||
        def foo(*args, **kwargs):
 | 
			
		||||
            time.sleep(1)
 | 
			
		||||
        start_time = time.time()
 | 
			
		||||
        self.tg.add_thread(foo, 'arg', kwarg='kwarg')
 | 
			
		||||
        self.tg.stop()
 | 
			
		||||
        end_time = time.time()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(0, len(self.tg.threads))
 | 
			
		||||
        self.assertTrue(end_time - start_time < 1)
 | 
			
		||||
 | 
			
		||||
    def test_stop_gracefully(self):
 | 
			
		||||
 | 
			
		||||
        def foo(*args, **kwargs):
 | 
			
		||||
            time.sleep(1)
 | 
			
		||||
        start_time = time.time()
 | 
			
		||||
        self.tg.add_thread(foo, 'arg', kwarg='kwarg')
 | 
			
		||||
        self.tg.stop(True)
 | 
			
		||||
        end_time = time.time()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(0, len(self.tg.threads))
 | 
			
		||||
        self.assertTrue(end_time - start_time >= 1)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user