From a14971778a6513ad01ba7749e4c51c9409d2c5fd Mon Sep 17 00:00:00 2001 From: "James E. Blair" Date: Fri, 25 Oct 2024 08:13:40 -0700 Subject: [PATCH] Fix endpoint races This change fixes two endpoint-related races: * We now shut down the aws driver delete thread rather than setting it to daemon. This provides more predictable behavior in tests. Some tests started launcher endpoints out-of-band (ie, they didn't really use the launcher, but it still started and stopped outside of the main action of the test). In order to ensure that every endpoint that we start threads for is also correctly stopped, standard stop/start methods are added to endpoints and locks are used to ensure that the start/stop sequence is predictable. This also lets us use endpoints in the scheduler/web without starting background threads for them. * It was possible to create multiple endpoints for a given region if they were created around the same time since we did not use a lock around the critical section of the cache. To correct that, surrond the cache lookup with a lock and move the behavior to a mixin class for reusability. (A cachetools decorator around getEndpoint was considered but the provider object in the method signature makes that difficult). Change-Id: I5e8ca06e76fced1bb342250a953ecda576168874 --- tests/unit/test_aws_driver.py | 2 + tests/unit/test_launcher.py | 5 +- zuul/driver/aws/__init__.py | 22 ++-- zuul/driver/aws/awsendpoint.py | 111 ++++++++++++--------- zuul/driver/openstack/__init__.py | 20 ++-- zuul/driver/openstack/openstackendpoint.py | 21 ++-- zuul/launcher/server.py | 2 + zuul/provider/__init__.py | 55 ++++++++++ 8 files changed, 149 insertions(+), 89 deletions(-) diff --git a/tests/unit/test_aws_driver.py b/tests/unit/test_aws_driver.py index 4d1d0880f2..db76152ec7 100644 --- a/tests/unit/test_aws_driver.py +++ b/tests/unit/test_aws_driver.py @@ -216,6 +216,8 @@ class TestAwsDriver(BaseCloudDriverTest): layout = self.scheds.first.sched.abide.tenants.get('tenant-one').layout provider = layout.providers['aws-us-east-1-main'] + # Start the endpoint since we're going to use the scheduler's endpoint. + provider.getEndpoint().start() with self.createZKContext(None) as ctx: node = AwsProviderNode.new(ctx, label=label) diff --git a/tests/unit/test_launcher.py b/tests/unit/test_launcher.py index 997ed88d1c..5d696d912e 100644 --- a/tests/unit/test_launcher.py +++ b/tests/unit/test_launcher.py @@ -62,12 +62,9 @@ class LauncherBaseTestCase(ZuulTestCase): self.s3.create_bucket( Bucket='zuul', CreateBucketConfiguration={'LocationConstraint': 'us-west-2'}) + self.addCleanup(self.mock_aws.stop) super().setUp() - def tearDown(self): - self.mock_aws.stop() - super().tearDown() - def _nodes_by_label(self): nodes = self.launcher.api.nodes_cache.getItems() nodes_by_label = defaultdict(list) diff --git a/zuul/driver/aws/__init__.py b/zuul/driver/aws/__init__.py index 04ad9318e0..b302e66cda 100644 --- a/zuul/driver/aws/__init__.py +++ b/zuul/driver/aws/__init__.py @@ -16,16 +16,14 @@ import urllib from zuul.driver import Driver, ConnectionInterface, ProviderInterface from zuul.driver.aws import awsconnection, awsmodel, awsprovider, awsendpoint +from zuul.provider import EndpointCacheMixin -class AwsDriver(Driver, ConnectionInterface, ProviderInterface): +class AwsDriver(Driver, EndpointCacheMixin, + ConnectionInterface, ProviderInterface): name = 'aws' _endpoint_class = awsendpoint.AwsProviderEndpoint - def __init__(self, *args, **kw): - super().__init__(*args, **kw) - self.endpoints = {} - def getConnection(self, name, config): return awsconnection.AwsConnection(self, name, config) @@ -52,15 +50,9 @@ class AwsDriver(Driver, ConnectionInterface, ProviderInterface): urllib.parse.quote_plus(provider.connection.connection_name), urllib.parse.quote_plus(provider.region), ]) - try: - return self.endpoints[endpoint_id] - except KeyError: - pass - endpoint = self._endpoint_class( - self, provider.connection, provider.region) - self.endpoints[endpoint_id] = endpoint - return endpoint + return self.getEndpointById( + endpoint_id, + create_args=(self, provider.connection, provider.region)) def stop(self): - for endpoint in self.endpoints.values(): - endpoint.stop() + self.stopEndpoints() diff --git a/zuul/driver/aws/awsendpoint.py b/zuul/driver/aws/awsendpoint.py index 379be0c66a..7cf370334b 100644 --- a/zuul/driver/aws/awsendpoint.py +++ b/zuul/driver/aws/awsendpoint.py @@ -356,50 +356,8 @@ class AwsProviderEndpoint(BaseProviderEndpoint): def __init__(self, driver, connection, region): name = f'{connection.connection_name}-{region}' super().__init__(driver, connection, name) - self.region = region - - # Wrap these instance methods with a per-instance LRU cache so - # that we don't leak memory over time when the adapter is - # occasionally replaced. - self._getInstanceType = functools.lru_cache(maxsize=None)( - self._getInstanceType) - self._getImage = functools.lru_cache(maxsize=None)( - self._getImage) - self.log = logging.getLogger(f"zuul.aws.{self.name}") - self._running = True - - # AWS has a default rate limit for creating instances that - # works out to a sustained 2 instances/sec, but the actual - # create instance API call takes 1 second or more. If we want - # to achieve faster than 1 instance/second throughput, we need - # to parallelize create instance calls, so we set up a - # threadworker to do that. - - # A little bit of a heuristic here to set the worker count. - # It appears that AWS typically takes 1-1.5 seconds to execute - # a create API call. Figure out how many we have to do in - # parallel in order to run at the rate limit, then quadruple - # that for headroom. Max out at 8 so we don't end up with too - # many threads. In practice, this will be 8 with the default - # values, and only less if users slow down the rate. - workers = max(min(int(connection.rate * 4), 8), 1) - self.log.info("Create executor with max workers=%s", workers) - self.create_executor = ThreadPoolExecutor(max_workers=workers) - - # We can batch delete instances using the AWS API, so to do - # that, create a queue for deletes, and a thread to process - # the queue. It will be greedy and collect as many pending - # instance deletes as possible to delete together. Typically - # under load, that will mean a single instance delete followed - # by larger batches. That strikes a balance between - # responsiveness and efficiency. Reducing the overall number - # of requests leaves more time for create instance calls. - self.delete_host_queue = queue.Queue() - self.delete_instance_queue = queue.Queue() - self.delete_thread = threading.Thread(target=self._deleteThread) - self.delete_thread.daemon = True - self.delete_thread.start() + self.region = region self.rate_limiter = RateLimiter(self.name, connection.rate) @@ -416,6 +374,16 @@ class AwsProviderEndpoint(BaseProviderEndpoint): # minutes. self.quota_service_rate_limiter = RateLimiter(self.name, connection.rate) + + # Wrap these instance methods with a per-instance LRU cache so + # that we don't leak memory over time when the adapter is + # occasionally replaced. + # TODO: This may be able to be a different kind of cache now + self._getInstanceType = functools.lru_cache(maxsize=None)( + self._getInstanceType) + self._getImage = functools.lru_cache(maxsize=None)( + self._getImage) + self.image_id_by_filter_cache = cachetools.TTLCache( maxsize=8192, ttl=(5 * 60)) @@ -431,6 +399,45 @@ class AwsProviderEndpoint(BaseProviderEndpoint): self.aws_quotas = self.aws.client("service-quotas") self.ebs_client = self.aws.client('ebs') + def startEndpoint(self): + self._running = True + self.log.debug("Starting AWS endpoint") + + # AWS has a default rate limit for creating instances that + # works out to a sustained 2 instances/sec, but the actual + # create instance API call takes 1 second or more. If we want + # to achieve faster than 1 instance/second throughput, we need + # to parallelize create instance calls, so we set up a + # threadworker to do that. + + # A little bit of a heuristic here to set the worker count. + # It appears that AWS typically takes 1-1.5 seconds to execute + # a create API call. Figure out how many we have to do in + # parallel in order to run at the rate limit, then quadruple + # that for headroom. Max out at 8 so we don't end up with too + # many threads. In practice, this will be 8 with the default + # values, and only less if users slow down the rate. + workers = max(min(int(self.connection.rate * 4), 8), 1) + self.log.info("Create executor with max workers=%s", workers) + self.create_executor = ThreadPoolExecutor( + thread_name_prefix=f'aws-create-{self.name}', + max_workers=workers) + + # We can batch delete instances using the AWS API, so to do + # that, create a queue for deletes, and a thread to process + # the queue. It will be greedy and collect as many pending + # instance deletes as possible to delete together. Typically + # under load, that will mean a single instance delete followed + # by larger batches. That strikes a balance between + # responsiveness and efficiency. Reducing the overall number + # of requests leaves more time for create instance calls. + self.delete_host_queue = queue.Queue() + self.delete_instance_queue = queue.Queue() + self.delete_thread = threading.Thread( + name=f'aws-delete-{self.name}', + target=self._deleteThread) + self.delete_thread.start() + workers = 10 self.log.info("Create executor with max workers=%s", workers) self.api_executor = ThreadPoolExecutor( @@ -466,10 +473,14 @@ class AwsProviderEndpoint(BaseProviderEndpoint): SERVICE_QUOTA_CACHE_TTL, self.api_executor)( self._listEBSQuotas) - def stop(self): + def stopEndpoint(self): + self.log.debug("Stopping AWS endpoint") self.create_executor.shutdown() self.api_executor.shutdown() self._running = False + self.delete_host_queue.put(None) + self.delete_instance_queue.put(None) + self.delete_thread.join() def listResources(self, bucket_name): for host in self._listHosts(): @@ -1502,12 +1513,18 @@ class AwsProviderEndpoint(BaseProviderEndpoint): def _getBatch(the_queue): records = [] try: - records.append(the_queue.get(block=True, timeout=10)) + record = the_queue.get(block=True, timeout=10) + if record is None: + return records + records.append(record) except queue.Empty: return [] while True: try: - records.append(the_queue.get(block=False)) + record = the_queue.get(block=False) + if record is None: + return records + records.append(record) except queue.Empty: break # The terminate call has a limit of 1k, but AWS recommends @@ -1527,6 +1544,8 @@ class AwsProviderEndpoint(BaseProviderEndpoint): with self.rate_limiter(log.debug, f"Deleted {count} instances"): self.ec2_client.terminate_instances(InstanceIds=ids) + if not self._running: + return records = self._getBatch(self.delete_host_queue) if records: ids = [] diff --git a/zuul/driver/openstack/__init__.py b/zuul/driver/openstack/__init__.py index c2b02fa4ec..b6be116982 100644 --- a/zuul/driver/openstack/__init__.py +++ b/zuul/driver/openstack/__init__.py @@ -21,16 +21,14 @@ from zuul.driver.openstack import ( openstackprovider, openstackendpoint, ) +from zuul.provider import EndpointCacheMixin -class OpenstackDriver(Driver, ConnectionInterface, ProviderInterface): +class OpenstackDriver(Driver, EndpointCacheMixin, + ConnectionInterface, ProviderInterface): name = 'openstack' _endpoint_class = openstackendpoint.OpenstackProviderEndpoint - def __init__(self, *args, **kw): - super().__init__(*args, **kw) - self.endpoints = {} - def getConnection(self, name, config): return openstackconnection.OpenstackConnection(self, name, config) @@ -54,17 +52,11 @@ class OpenstackDriver(Driver, ConnectionInterface, ProviderInterface): urllib.parse.quote_plus(connection.connection_name), urllib.parse.quote_plus(region_str), ]) - try: - return self.endpoints[endpoint_id] - except KeyError: - pass - endpoint = self._endpoint_class(self, connection, region) - self.endpoints[endpoint_id] = endpoint - return endpoint + return self.getEndpointById(endpoint_id, + create_args=(self, connection, region)) def getEndpoint(self, provider): return self._getEndpoint(provider.connection, provider.region) def stop(self): - for endpoint in self.endpoints.values(): - endpoint.stop() + self.stopEndpoints() diff --git a/zuul/driver/openstack/openstackendpoint.py b/zuul/driver/openstack/openstackendpoint.py index bc68ee285d..26e22a8678 100644 --- a/zuul/driver/openstack/openstackendpoint.py +++ b/zuul/driver/openstack/openstackendpoint.py @@ -385,6 +385,7 @@ class OpenstackProviderEndpoint(BaseProviderEndpoint): def __init__(self, driver, connection, region): name = f'{connection.connection_name}-{region}' super().__init__(driver, connection, name) + self.log = logging.getLogger(f"zuul.openstack.{self.name}") self.region = region # Wrap these instance methods with a per-instance LRU cache so @@ -399,9 +400,15 @@ class OpenstackProviderEndpoint(BaseProviderEndpoint): self._listAZs = functools.lru_cache(maxsize=None)( self._listAZs) - self.log = logging.getLogger(f"zuul.openstack.{self.name}") - self._running = True + self.rate_limiter = RateLimiter(self.name, connection.rate) + self._last_image_check_failure = time.time() + self._last_port_cleanup = None + self._client = self._getClient() + + def startEndpoint(self): + self.log.debug("Starting OpenStack endpoint") + self._running = True # The default http connection pool size is 10; match it for # efficiency. workers = 10 @@ -424,14 +431,8 @@ class OpenstackProviderEndpoint(BaseProviderEndpoint): CACHE_TTL, self.api_executor)( self._listFloatingIps) - self.rate_limiter = RateLimiter(self.name, - connection.rate) - - self._last_image_check_failure = time.time() - self._last_port_cleanup = None - self._client = self._getClient() - - def stop(self): + def stopEndpoint(self): + self.log.debug("Stopping OpenStack endpoint") self.api_executor.shutdown() self._running = False diff --git a/zuul/launcher/server.py b/zuul/launcher/server.py index 80c3d582dd..9cd4691784 100644 --- a/zuul/launcher/server.py +++ b/zuul/launcher/server.py @@ -881,6 +881,7 @@ class Launcher: self.connections.stop() self.upload_executor.shutdown() self.endpoint_upload_executor.shutdown() + # Endpoints are stopped by drivers self.log.debug("Stopped launcher") def join(self): @@ -928,6 +929,7 @@ class Launcher: continue endpoint = provider.getEndpoint() endpoints[endpoint.canonical_name] = endpoint + endpoint.start() self.endpoints = endpoints return updated diff --git a/zuul/provider/__init__.py b/zuul/provider/__init__.py index 8641bbfb35..1fd8c91937 100644 --- a/zuul/provider/__init__.py +++ b/zuul/provider/__init__.py @@ -15,6 +15,7 @@ import abc import json import math +import threading import urllib.parse from zuul.lib.voluputil import Required, Optional, Nullable, assemble @@ -121,6 +122,23 @@ class BaseProviderEndpoint(metaclass=abc.ABCMeta): self.driver = driver self.connection = connection self.name = name + self.start_lock = threading.Lock() + self.started = False + self.stopped = False + + def start(self): + with self.start_lock: + if not self.stopped and not self.started: + self.startEndpoint() + self.started = True + + def stop(self): + with self.start_lock: + if self.started: + self.stopEndpoint() + # Set the stopped flag regardless of whether we started so + # that we won't start after stopping. + self.stopped = True @property def canonical_name(self): @@ -129,6 +147,21 @@ class BaseProviderEndpoint(metaclass=abc.ABCMeta): urllib.parse.quote_plus(self.name), ]) + def handleStart(self): + """Start the endpoint + + This method may start any threads necessary for the endpoint. + + """ + raise NotImplementedError() + + def handleStop(self): + """Stop the endpoint + + This method must stop all endpoint threads. + """ + raise NotImplementedError() + class BaseProviderSchema(metaclass=abc.ABCMeta): def getLabelSchema(self): @@ -526,3 +559,25 @@ class BaseProvider(zkobject.PolymorphicZKObjectMixin, :param ProviderNode node: The node of the server """ pass + + +class EndpointCacheMixin: + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self.endpoints = {} + self.endpoints_lock = threading.Lock() + + def getEndpointById(self, endpoint_id, create_args): + with self.endpoints_lock: + try: + return self.endpoints[endpoint_id] + except KeyError: + pass + endpoint = self._endpoint_class(*create_args) + self.endpoints[endpoint_id] = endpoint + return endpoint + + def stopEndpoints(self): + with self.endpoints_lock: + for endpoint in self.endpoints.values(): + endpoint.stop()