Add a stop thread to statemachine

Providers have a "stop" and a "join" method so that when new
configuration arrives, the launcher can tell all existing
providers to quickly "stop".  If the program is exiting, they
can optionally be "joined".

In the statemachine "stop" method, we wait for all statemachines
to complete before shutting down the keyscan threadpool,
because otherwise shutting down the threadpool would cause any
existing statemachines to fail at the end of the process.
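
As a rough illustration of that failure mode, here is a minimal sketch.
It assumes the keyscan threadpool behaves like Python's
concurrent.futures.ThreadPoolExecutor (this message does not say which
implementation is used); the submitted task is a placeholder.

```python
from concurrent.futures import ThreadPoolExecutor

# Assumption: the pool behaves like a standard ThreadPoolExecutor.
pool = ThreadPoolExecutor(max_workers=1)

# Shut the pool down while a (hypothetical) statemachine still has
# one final task left to submit.
pool.shutdown(wait=True)

try:
    # Placeholder for the final keyscan step of a statemachine.
    pool.submit(print, "keyscan")
    submit_failed = False
except RuntimeError:
    # A shut-down executor refuses new work.
    submit_failed = True
```

With a pool of this kind, work submitted after shutdown is rejected,
which is why the shutdown has to wait for in-flight statemachines.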

However, this causes the "stop" method to take a long time, and
during that time, new requests can be assigned to the provider,
extending the process even more.

To allow the stop method to quickly return, yet also support
the extended sequence needed due to the keyscan threadpool, this
change causes the stop method to spawn a thread dedicated to
stopping the provider manager.  The join method now waits for
that thread to complete.  This restores the statemachine
provider manager to the intended "fire and forget" approach
of the stop method.
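
The pattern described above can be sketched roughly as follows. This
is a simplified, hypothetical stand-in (SketchProvider, pending and
stopped are illustrative names), not the real StateMachineProvider.

```python
import threading
import time


class SketchProvider:
    """Hypothetical stand-in for a provider; not the real class."""

    def __init__(self):
        self.pending = 3  # pretend work that is still in flight
        self.stop_thread = None
        self.stopped = threading.Event()

    def stop(self):
        # Return immediately; the slow shutdown runs in the background.
        self.stop_thread = threading.Thread(
            target=self._stop,
            daemon=True)
        self.stop_thread.start()

    def _stop(self):
        # Let in-flight work drain before releasing shared resources.
        while self.pending:
            time.sleep(0.01)
            self.pending -= 1
        self.stopped.set()

    def join(self):
        # Only callers that need a complete shutdown wait here.
        if self.stop_thread:
            self.stop_thread.join()


provider = SketchProvider()
provider.stop()  # returns right away
provider.join()  # blocks until _stop has finished
```

The daemon flag means an unjoined stop thread will not keep the
program alive at exit, which fits the optional nature of "join".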

Change-Id: I430c70ce0daa7fc28cbd43ecc64c7a974239950d
James E. Blair 2023-02-24 10:51:51 -08:00
parent aaecb9659e
commit da59284230
1 changed file with 9 additions and 0 deletions

@@ -532,6 +532,7 @@ class StateMachineProvider(Provider, QuotaSupport):
         self.label_quota_cache = cachetools.LRUCache(num_labels)
         self.possibly_leaked_nodes = {}
         self.possibly_leaked_uploads = {}
+        self.stop_thread = None

     def start(self, zk_conn):
         super().start(zk_conn)
@@ -546,6 +547,12 @@ class StateMachineProvider(Provider, QuotaSupport):

     def stop(self):
         self.log.debug("Stopping")
+        self.stop_thread = threading.Thread(
+            target=self._stop,
+            daemon=True)
+        self.stop_thread.start()
+
+    def _stop(self):
         if self.state_machine_thread:
             while self.launchers or self.deleters:
                 time.sleep(1)
@@ -562,6 +569,8 @@ class StateMachineProvider(Provider, QuotaSupport):
         self.log.debug("Joining")
         if self.state_machine_thread:
             self.state_machine_thread.join()
+        if self.stop_thread:
+            self.stop_thread.join()
         self.log.debug("Joined")

     def _runStateMachines(self):