Fix endpoint races

This change fixes two endpoint-related races:

* We now shut down the AWS driver delete thread rather than marking
  it as a daemon thread.  This provides more predictable behavior in
  tests (see the shutdown sketch after this list).

  Some tests started launcher endpoints out-of-band (i.e., they didn't
  really use the launcher, but it still started and stopped outside
  of the main action of the test).  To ensure that every endpoint we
  start threads for is also correctly stopped, standard start/stop
  methods are added to endpoints, and locks are used to make the
  start/stop sequence predictable (see the start/stop sketch below).

  This also lets us use endpoints in the scheduler/web without
  starting background threads for them.

* It was possible to create multiple endpoints for a given region
  if they were created around the same time, since we did not hold
  a lock around the critical section of the cache.  To correct that,
  surround the cache lookup with a lock and move the behavior to a
  mixin class for reusability; a sketch follows this list.  (A
  cachetools decorator around getEndpoint was considered, but the
  provider object in the method signature makes that difficult.)
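
For reference, the delete-thread shutdown uses a sentinel handshake: a
None placed on the queue unblocks the worker's get() so the thread can
exit and be joined.  A condensed, hypothetical sketch of the pattern
used in the AWS endpoint below (the DeleteWorker name is illustrative
only):

    import queue
    import threading

    class DeleteWorker:
        def __init__(self):
            self._running = True
            self.delete_queue = queue.Queue()
            # Deliberately not a daemon thread: stop() joins it
            # instead, so shutdown is deterministic in tests.
            self.thread = threading.Thread(target=self._run)
            self.thread.start()

        def _run(self):
            while self._running:
                item = self.delete_queue.get()
                if item is None:
                    # Sentinel: wakes the blocking get() so we can exit.
                    return
                # ... process the delete here ...

        def stop(self):
            self._running = False
            self.delete_queue.put(None)  # unblock the worker
            self.thread.join()           # wait for a clean exit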
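
The start/stop sequencing reduces to a lock plus two flags; this is a
condensed sketch of the guard added to BaseProviderEndpoint in the
last file below (the Endpoint class name here is illustrative):

    import threading

    class Endpoint:
        def __init__(self):
            self.start_lock = threading.Lock()
            self.started = False
            self.stopped = False

        def startEndpoint(self):
            pass  # subclasses start their threads here

        def stopEndpoint(self):
            pass  # subclasses stop their threads here

        def start(self):
            with self.start_lock:
                # Start at most once, and never after a stop.
                if not self.stopped and not self.started:
                    self.startEndpoint()
                    self.started = True

        def stop(self):
            with self.start_lock:
                if self.started:
                    self.stopEndpoint()
                # Set the stopped flag even if we never started, so a
                # late start() becomes a no-op.
                self.stopped = True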
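
The cache race is the usual check-then-act problem: two threads can
both miss the cache and both construct an endpoint for the same
region.  Holding the lock across both the lookup and the insert closes
the window; condensed from the EndpointCacheMixin added below:

    import threading

    class EndpointCacheMixin:
        _endpoint_class = None  # set by the driver class mixing this in

        def __init__(self, *args, **kw):
            super().__init__(*args, **kw)
            self.endpoints = {}
            self.endpoints_lock = threading.Lock()

        def getEndpointById(self, endpoint_id, create_args):
            with self.endpoints_lock:
                try:
                    return self.endpoints[endpoint_id]
                except KeyError:
                    pass
                # Still under the lock, so at most one thread creates
                # the endpoint for a given id.
                endpoint = self._endpoint_class(*create_args)
                self.endpoints[endpoint_id] = endpoint
                return endpoint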

Change-Id: I5e8ca06e76fced1bb342250a953ecda576168874
James E. Blair 2024-10-25 08:13:40 -07:00
parent 407fce6494
commit a14971778a
8 changed files with 149 additions and 89 deletions


@@ -216,6 +216,8 @@ class TestAwsDriver(BaseCloudDriverTest):
         layout = self.scheds.first.sched.abide.tenants.get('tenant-one').layout
         provider = layout.providers['aws-us-east-1-main']
+        # Start the endpoint since we're going to use the scheduler's endpoint.
+        provider.getEndpoint().start()
         with self.createZKContext(None) as ctx:
             node = AwsProviderNode.new(ctx, label=label)


@@ -62,12 +62,9 @@ class LauncherBaseTestCase(ZuulTestCase):
         self.s3.create_bucket(
             Bucket='zuul',
             CreateBucketConfiguration={'LocationConstraint': 'us-west-2'})
+        self.addCleanup(self.mock_aws.stop)
         super().setUp()

-    def tearDown(self):
-        self.mock_aws.stop()
-        super().tearDown()
-
     def _nodes_by_label(self):
         nodes = self.launcher.api.nodes_cache.getItems()
         nodes_by_label = defaultdict(list)


@@ -16,16 +16,14 @@ import urllib

 from zuul.driver import Driver, ConnectionInterface, ProviderInterface
 from zuul.driver.aws import awsconnection, awsmodel, awsprovider, awsendpoint
+from zuul.provider import EndpointCacheMixin


-class AwsDriver(Driver, ConnectionInterface, ProviderInterface):
+class AwsDriver(Driver, EndpointCacheMixin,
+                ConnectionInterface, ProviderInterface):
     name = 'aws'
     _endpoint_class = awsendpoint.AwsProviderEndpoint

-    def __init__(self, *args, **kw):
-        super().__init__(*args, **kw)
-        self.endpoints = {}
-
     def getConnection(self, name, config):
         return awsconnection.AwsConnection(self, name, config)

@@ -52,15 +50,9 @@ class AwsDriver(Driver, ConnectionInterface, ProviderInterface):
             urllib.parse.quote_plus(provider.connection.connection_name),
             urllib.parse.quote_plus(provider.region),
         ])
-        try:
-            return self.endpoints[endpoint_id]
-        except KeyError:
-            pass
-        endpoint = self._endpoint_class(
-            self, provider.connection, provider.region)
-        self.endpoints[endpoint_id] = endpoint
-        return endpoint
+        return self.getEndpointById(
+            endpoint_id,
+            create_args=(self, provider.connection, provider.region))

     def stop(self):
-        for endpoint in self.endpoints.values():
-            endpoint.stop()
+        self.stopEndpoints()


@@ -356,50 +356,8 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
     def __init__(self, driver, connection, region):
         name = f'{connection.connection_name}-{region}'
         super().__init__(driver, connection, name)
-        self.region = region
-        # Wrap these instance methods with a per-instance LRU cache so
-        # that we don't leak memory over time when the adapter is
-        # occasionally replaced.
-        self._getInstanceType = functools.lru_cache(maxsize=None)(
-            self._getInstanceType)
-        self._getImage = functools.lru_cache(maxsize=None)(
-            self._getImage)
         self.log = logging.getLogger(f"zuul.aws.{self.name}")
-        self._running = True
-
-        # AWS has a default rate limit for creating instances that
-        # works out to a sustained 2 instances/sec, but the actual
-        # create instance API call takes 1 second or more.  If we want
-        # to achieve faster than 1 instance/second throughput, we need
-        # to parallelize create instance calls, so we set up a
-        # threadworker to do that.
-
-        # A little bit of a heuristic here to set the worker count.
-        # It appears that AWS typically takes 1-1.5 seconds to execute
-        # a create API call.  Figure out how many we have to do in
-        # parallel in order to run at the rate limit, then quadruple
-        # that for headroom.  Max out at 8 so we don't end up with too
-        # many threads.  In practice, this will be 8 with the default
-        # values, and only less if users slow down the rate.
-        workers = max(min(int(connection.rate * 4), 8), 1)
-        self.log.info("Create executor with max workers=%s", workers)
-        self.create_executor = ThreadPoolExecutor(max_workers=workers)
-
-        # We can batch delete instances using the AWS API, so to do
-        # that, create a queue for deletes, and a thread to process
-        # the queue.  It will be greedy and collect as many pending
-        # instance deletes as possible to delete together.  Typically
-        # under load, that will mean a single instance delete followed
-        # by larger batches.  That strikes a balance between
-        # responsiveness and efficiency.  Reducing the overall number
-        # of requests leaves more time for create instance calls.
-        self.delete_host_queue = queue.Queue()
-        self.delete_instance_queue = queue.Queue()
-        self.delete_thread = threading.Thread(target=self._deleteThread)
-        self.delete_thread.daemon = True
-        self.delete_thread.start()
+        self.region = region

         self.rate_limiter = RateLimiter(self.name,
                                         connection.rate)
@@ -416,6 +374,16 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
         # minutes.
         self.quota_service_rate_limiter = RateLimiter(self.name,
                                                       connection.rate)
+
+        # Wrap these instance methods with a per-instance LRU cache so
+        # that we don't leak memory over time when the adapter is
+        # occasionally replaced.
+        # TODO: This may be able to be a different kind of cache now
+        self._getInstanceType = functools.lru_cache(maxsize=None)(
+            self._getInstanceType)
+        self._getImage = functools.lru_cache(maxsize=None)(
+            self._getImage)
+
         self.image_id_by_filter_cache = cachetools.TTLCache(
             maxsize=8192, ttl=(5 * 60))
@@ -431,6 +399,45 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
         self.aws_quotas = self.aws.client("service-quotas")
         self.ebs_client = self.aws.client('ebs')

+    def startEndpoint(self):
+        self._running = True
+        self.log.debug("Starting AWS endpoint")
+
+        # AWS has a default rate limit for creating instances that
+        # works out to a sustained 2 instances/sec, but the actual
+        # create instance API call takes 1 second or more.  If we want
+        # to achieve faster than 1 instance/second throughput, we need
+        # to parallelize create instance calls, so we set up a
+        # threadworker to do that.
+
+        # A little bit of a heuristic here to set the worker count.
+        # It appears that AWS typically takes 1-1.5 seconds to execute
+        # a create API call.  Figure out how many we have to do in
+        # parallel in order to run at the rate limit, then quadruple
+        # that for headroom.  Max out at 8 so we don't end up with too
+        # many threads.  In practice, this will be 8 with the default
+        # values, and only less if users slow down the rate.
+        workers = max(min(int(self.connection.rate * 4), 8), 1)
+        self.log.info("Create executor with max workers=%s", workers)
+        self.create_executor = ThreadPoolExecutor(
+            thread_name_prefix=f'aws-create-{self.name}',
+            max_workers=workers)
+
+        # We can batch delete instances using the AWS API, so to do
+        # that, create a queue for deletes, and a thread to process
+        # the queue.  It will be greedy and collect as many pending
+        # instance deletes as possible to delete together.  Typically
+        # under load, that will mean a single instance delete followed
+        # by larger batches.  That strikes a balance between
+        # responsiveness and efficiency.  Reducing the overall number
+        # of requests leaves more time for create instance calls.
+        self.delete_host_queue = queue.Queue()
+        self.delete_instance_queue = queue.Queue()
+        self.delete_thread = threading.Thread(
+            name=f'aws-delete-{self.name}',
+            target=self._deleteThread)
+        self.delete_thread.start()
+
         workers = 10
         self.log.info("Create executor with max workers=%s", workers)
         self.api_executor = ThreadPoolExecutor(
@@ -466,10 +473,14 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
             SERVICE_QUOTA_CACHE_TTL, self.api_executor)(
                 self._listEBSQuotas)

-    def stop(self):
+    def stopEndpoint(self):
         self.log.debug("Stopping AWS endpoint")
         self.create_executor.shutdown()
         self.api_executor.shutdown()
+        self._running = False
+        self.delete_host_queue.put(None)
+        self.delete_instance_queue.put(None)
+        self.delete_thread.join()

     def listResources(self, bucket_name):
         for host in self._listHosts():
@@ -1502,12 +1513,18 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
     def _getBatch(the_queue):
         records = []
         try:
-            records.append(the_queue.get(block=True, timeout=10))
+            record = the_queue.get(block=True, timeout=10)
+            if record is None:
+                return records
+            records.append(record)
         except queue.Empty:
             return []
         while True:
             try:
-                records.append(the_queue.get(block=False))
+                record = the_queue.get(block=False)
+                if record is None:
+                    return records
+                records.append(record)
             except queue.Empty:
                 break
         # The terminate call has a limit of 1k, but AWS recommends
# The terminate call has a limit of 1k, but AWS recommends
@@ -1527,6 +1544,8 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
             with self.rate_limiter(log.debug, f"Deleted {count} instances"):
                 self.ec2_client.terminate_instances(InstanceIds=ids)

+            if not self._running:
+                return
             records = self._getBatch(self.delete_host_queue)
             if records:
                 ids = []


@@ -21,16 +21,14 @@ from zuul.driver.openstack import (
     openstackprovider,
     openstackendpoint,
 )
+from zuul.provider import EndpointCacheMixin


-class OpenstackDriver(Driver, ConnectionInterface, ProviderInterface):
+class OpenstackDriver(Driver, EndpointCacheMixin,
+                      ConnectionInterface, ProviderInterface):
     name = 'openstack'
     _endpoint_class = openstackendpoint.OpenstackProviderEndpoint

-    def __init__(self, *args, **kw):
-        super().__init__(*args, **kw)
-        self.endpoints = {}
-
     def getConnection(self, name, config):
         return openstackconnection.OpenstackConnection(self, name, config)

@@ -54,17 +52,11 @@ class OpenstackDriver(Driver, ConnectionInterface, ProviderInterface):
             urllib.parse.quote_plus(connection.connection_name),
             urllib.parse.quote_plus(region_str),
         ])
-        try:
-            return self.endpoints[endpoint_id]
-        except KeyError:
-            pass
-        endpoint = self._endpoint_class(self, connection, region)
-        self.endpoints[endpoint_id] = endpoint
-        return endpoint
+        return self.getEndpointById(endpoint_id,
+                                    create_args=(self, connection, region))

     def getEndpoint(self, provider):
         return self._getEndpoint(provider.connection, provider.region)

     def stop(self):
-        for endpoint in self.endpoints.values():
-            endpoint.stop()
+        self.stopEndpoints()


@@ -385,6 +385,7 @@ class OpenstackProviderEndpoint(BaseProviderEndpoint):
     def __init__(self, driver, connection, region):
         name = f'{connection.connection_name}-{region}'
         super().__init__(driver, connection, name)
+        self.log = logging.getLogger(f"zuul.openstack.{self.name}")
         self.region = region

         # Wrap these instance methods with a per-instance LRU cache so

@@ -399,9 +400,15 @@ class OpenstackProviderEndpoint(BaseProviderEndpoint):
         self._listAZs = functools.lru_cache(maxsize=None)(
             self._listAZs)
-        self.log = logging.getLogger(f"zuul.openstack.{self.name}")
-        self._running = True
+        self.rate_limiter = RateLimiter(self.name, connection.rate)
+        self._last_image_check_failure = time.time()
+        self._last_port_cleanup = None
+        self._client = self._getClient()
+
+    def startEndpoint(self):
+        self.log.debug("Starting OpenStack endpoint")
+        self._running = True

         # The default http connection pool size is 10; match it for
         # efficiency.
         workers = 10

@@ -424,14 +431,8 @@ class OpenstackProviderEndpoint(BaseProviderEndpoint):
                 CACHE_TTL, self.api_executor)(
                     self._listFloatingIps)

-        self.rate_limiter = RateLimiter(self.name,
-                                        connection.rate)
-        self._last_image_check_failure = time.time()
-        self._last_port_cleanup = None
-        self._client = self._getClient()
-
-    def stop(self):
+    def stopEndpoint(self):
         self.log.debug("Stopping OpenStack endpoint")
         self.api_executor.shutdown()
         self._running = False


@@ -881,6 +881,7 @@ class Launcher:
         self.connections.stop()
         self.upload_executor.shutdown()
         self.endpoint_upload_executor.shutdown()
+        # Endpoints are stopped by drivers
         self.log.debug("Stopped launcher")

     def join(self):

@@ -928,6 +929,7 @@ class Launcher:
                 continue
             endpoint = provider.getEndpoint()
             endpoints[endpoint.canonical_name] = endpoint
+            endpoint.start()
         self.endpoints = endpoints
         return updated


@@ -15,6 +15,7 @@
 import abc
 import json
 import math
+import threading
 import urllib.parse

 from zuul.lib.voluputil import Required, Optional, Nullable, assemble
@@ -121,6 +122,23 @@ class BaseProviderEndpoint(metaclass=abc.ABCMeta):
         self.driver = driver
         self.connection = connection
         self.name = name
+        self.start_lock = threading.Lock()
+        self.started = False
+        self.stopped = False
+
+    def start(self):
+        with self.start_lock:
+            if not self.stopped and not self.started:
+                self.startEndpoint()
+                self.started = True
+
+    def stop(self):
+        with self.start_lock:
+            if self.started:
+                self.stopEndpoint()
+            # Set the stopped flag regardless of whether we started so
+            # that we won't start after stopping.
+            self.stopped = True

     @property
     def canonical_name(self):
@@ -129,6 +147,21 @@ class BaseProviderEndpoint(metaclass=abc.ABCMeta):
             urllib.parse.quote_plus(self.name),
         ])

+    def startEndpoint(self):
+        """Start the endpoint
+
+        This method may start any threads necessary for the endpoint.
+        """
+        raise NotImplementedError()
+
+    def stopEndpoint(self):
+        """Stop the endpoint
+
+        This method must stop all endpoint threads.
+        """
+        raise NotImplementedError()
+

 class BaseProviderSchema(metaclass=abc.ABCMeta):
     def getLabelSchema(self):
@@ -526,3 +559,25 @@ class BaseProvider(zkobject.PolymorphicZKObjectMixin,
         :param ProviderNode node: The node of the server
         """
         pass
+
+
+class EndpointCacheMixin:
+    def __init__(self, *args, **kw):
+        super().__init__(*args, **kw)
+        self.endpoints = {}
+        self.endpoints_lock = threading.Lock()
+
+    def getEndpointById(self, endpoint_id, create_args):
+        with self.endpoints_lock:
+            try:
+                return self.endpoints[endpoint_id]
+            except KeyError:
+                pass
+            endpoint = self._endpoint_class(*create_args)
+            self.endpoints[endpoint_id] = endpoint
+            return endpoint
+
+    def stopEndpoints(self):
+        with self.endpoints_lock:
+            for endpoint in self.endpoints.values():
+                endpoint.stop()