diff --git a/tests/unit/test_aws_driver.py b/tests/unit/test_aws_driver.py
index 4d1d0880f2..db76152ec7 100644
--- a/tests/unit/test_aws_driver.py
+++ b/tests/unit/test_aws_driver.py
@@ -216,6 +216,8 @@ class TestAwsDriver(BaseCloudDriverTest):
         layout = self.scheds.first.sched.abide.tenants.get('tenant-one').layout
         provider = layout.providers['aws-us-east-1-main']
 
+        # Start the endpoint since we're going to use the scheduler's endpoint.
+        provider.getEndpoint().start()
         with self.createZKContext(None) as ctx:
             node = AwsProviderNode.new(ctx, label=label)
diff --git a/tests/unit/test_launcher.py b/tests/unit/test_launcher.py
index 997ed88d1c..5d696d912e 100644
--- a/tests/unit/test_launcher.py
+++ b/tests/unit/test_launcher.py
@@ -62,12 +62,9 @@ class LauncherBaseTestCase(ZuulTestCase):
         self.s3.create_bucket(
             Bucket='zuul',
             CreateBucketConfiguration={'LocationConstraint': 'us-west-2'})
+        self.addCleanup(self.mock_aws.stop)
         super().setUp()
 
-    def tearDown(self):
-        self.mock_aws.stop()
-        super().tearDown()
-
     def _nodes_by_label(self):
         nodes = self.launcher.api.nodes_cache.getItems()
         nodes_by_label = defaultdict(list)
diff --git a/zuul/driver/aws/__init__.py b/zuul/driver/aws/__init__.py
index 04ad9318e0..b302e66cda 100644
--- a/zuul/driver/aws/__init__.py
+++ b/zuul/driver/aws/__init__.py
@@ -16,16 +16,14 @@
 import urllib
 
 from zuul.driver import Driver, ConnectionInterface, ProviderInterface
 from zuul.driver.aws import awsconnection, awsmodel, awsprovider, awsendpoint
+from zuul.provider import EndpointCacheMixin
 
 
-class AwsDriver(Driver, ConnectionInterface, ProviderInterface):
+class AwsDriver(Driver, EndpointCacheMixin,
+                ConnectionInterface, ProviderInterface):
     name = 'aws'
     _endpoint_class = awsendpoint.AwsProviderEndpoint
 
-    def __init__(self, *args, **kw):
-        super().__init__(*args, **kw)
-        self.endpoints = {}
-
     def getConnection(self, name, config):
         return awsconnection.AwsConnection(self, name, config)
 
@@ -52,15 +50,9 @@ class AwsDriver(Driver, ConnectionInterface, ProviderInterface):
             urllib.parse.quote_plus(provider.connection.connection_name),
             urllib.parse.quote_plus(provider.region),
         ])
-        try:
-            return self.endpoints[endpoint_id]
-        except KeyError:
-            pass
-        endpoint = self._endpoint_class(
-            self, provider.connection, provider.region)
-        self.endpoints[endpoint_id] = endpoint
-        return endpoint
+        return self.getEndpointById(
+            endpoint_id,
+            create_args=(self, provider.connection, provider.region))
 
     def stop(self):
-        for endpoint in self.endpoints.values():
-            endpoint.stop()
+        self.stopEndpoints()
diff --git a/zuul/driver/aws/awsendpoint.py b/zuul/driver/aws/awsendpoint.py
index 379be0c66a..7cf370334b 100644
--- a/zuul/driver/aws/awsendpoint.py
+++ b/zuul/driver/aws/awsendpoint.py
@@ -356,50 +356,8 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
     def __init__(self, driver, connection, region):
         name = f'{connection.connection_name}-{region}'
         super().__init__(driver, connection, name)
-        self.region = region
-
-        # Wrap these instance methods with a per-instance LRU cache so
-        # that we don't leak memory over time when the adapter is
-        # occasionally replaced.
-        self._getInstanceType = functools.lru_cache(maxsize=None)(
-            self._getInstanceType)
-        self._getImage = functools.lru_cache(maxsize=None)(
-            self._getImage)
-
         self.log = logging.getLogger(f"zuul.aws.{self.name}")
-        self._running = True
-
-        # AWS has a default rate limit for creating instances that
-        # works out to a sustained 2 instances/sec, but the actual
-        # create instance API call takes 1 second or more.  If we want
-        # to achieve faster than 1 instance/second throughput, we need
-        # to parallelize create instance calls, so we set up a
-        # threadworker to do that.
-
-        # A little bit of a heuristic here to set the worker count.
-        # It appears that AWS typically takes 1-1.5 seconds to execute
-        # a create API call.  Figure out how many we have to do in
-        # parallel in order to run at the rate limit, then quadruple
-        # that for headroom.  Max out at 8 so we don't end up with too
-        # many threads.  In practice, this will be 8 with the default
-        # values, and only less if users slow down the rate.
-        workers = max(min(int(connection.rate * 4), 8), 1)
-        self.log.info("Create executor with max workers=%s", workers)
-        self.create_executor = ThreadPoolExecutor(max_workers=workers)
-
-        # We can batch delete instances using the AWS API, so to do
-        # that, create a queue for deletes, and a thread to process
-        # the queue.  It will be greedy and collect as many pending
-        # instance deletes as possible to delete together.  Typically
-        # under load, that will mean a single instance delete followed
-        # by larger batches.  That strikes a balance between
-        # responsiveness and efficiency.  Reducing the overall number
-        # of requests leaves more time for create instance calls.
-        self.delete_host_queue = queue.Queue()
-        self.delete_instance_queue = queue.Queue()
-        self.delete_thread = threading.Thread(target=self._deleteThread)
-        self.delete_thread.daemon = True
-        self.delete_thread.start()
+        self.region = region
 
         self.rate_limiter = RateLimiter(self.name, connection.rate)
 
@@ -416,6 +374,16 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
         # minutes.
         self.quota_service_rate_limiter = RateLimiter(self.name,
                                                       connection.rate)
+
+        # Wrap these instance methods with a per-instance LRU cache so
+        # that we don't leak memory over time when the adapter is
+        # occasionally replaced.
+        # TODO: This may be able to be a different kind of cache now
+        self._getInstanceType = functools.lru_cache(maxsize=None)(
+            self._getInstanceType)
+        self._getImage = functools.lru_cache(maxsize=None)(
+            self._getImage)
+
         self.image_id_by_filter_cache = cachetools.TTLCache(
             maxsize=8192, ttl=(5 * 60))
 
@@ -431,6 +399,45 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
         self.aws_quotas = self.aws.client("service-quotas")
         self.ebs_client = self.aws.client('ebs')
 
+    def startEndpoint(self):
+        self._running = True
+        self.log.debug("Starting AWS endpoint")
+
+        # AWS has a default rate limit for creating instances that
+        # works out to a sustained 2 instances/sec, but the actual
+        # create instance API call takes 1 second or more.  If we want
+        # to achieve faster than 1 instance/second throughput, we need
+        # to parallelize create instance calls, so we set up a
+        # threadworker to do that.
+
+        # A little bit of a heuristic here to set the worker count.
+        # It appears that AWS typically takes 1-1.5 seconds to execute
+        # a create API call.  Figure out how many we have to do in
+        # parallel in order to run at the rate limit, then quadruple
+        # that for headroom.  Max out at 8 so we don't end up with too
+        # many threads.  In practice, this will be 8 with the default
+        # values, and only less if users slow down the rate.
+        workers = max(min(int(self.connection.rate * 4), 8), 1)
+        self.log.info("Create executor with max workers=%s", workers)
+        self.create_executor = ThreadPoolExecutor(
+            thread_name_prefix=f'aws-create-{self.name}',
+            max_workers=workers)
+
+        # We can batch delete instances using the AWS API, so to do
+        # that, create a queue for deletes, and a thread to process
+        # the queue.  It will be greedy and collect as many pending
+        # instance deletes as possible to delete together.  Typically
+        # under load, that will mean a single instance delete followed
+        # by larger batches.  That strikes a balance between
+        # responsiveness and efficiency.  Reducing the overall number
+        # of requests leaves more time for create instance calls.
+        self.delete_host_queue = queue.Queue()
+        self.delete_instance_queue = queue.Queue()
+        self.delete_thread = threading.Thread(
+            name=f'aws-delete-{self.name}',
+            target=self._deleteThread)
+        self.delete_thread.start()
+
         workers = 10
         self.log.info("Create executor with max workers=%s", workers)
         self.api_executor = ThreadPoolExecutor(
@@ -466,10 +473,14 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
             SERVICE_QUOTA_CACHE_TTL, self.api_executor)(
                 self._listEBSQuotas)
 
-    def stop(self):
+    def stopEndpoint(self):
+        self.log.debug("Stopping AWS endpoint")
         self.create_executor.shutdown()
         self.api_executor.shutdown()
         self._running = False
+        self.delete_host_queue.put(None)
+        self.delete_instance_queue.put(None)
+        self.delete_thread.join()
 
     def listResources(self, bucket_name):
         for host in self._listHosts():
@@ -1502,12 +1513,18 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
     def _getBatch(the_queue):
         records = []
         try:
-            records.append(the_queue.get(block=True, timeout=10))
+            record = the_queue.get(block=True, timeout=10)
+            if record is None:
+                return records
+            records.append(record)
         except queue.Empty:
             return []
         while True:
             try:
-                records.append(the_queue.get(block=False))
+                record = the_queue.get(block=False)
+                if record is None:
+                    return records
+                records.append(record)
             except queue.Empty:
                 break
         # The terminate call has a limit of 1k, but AWS recommends
@@ -1527,6 +1544,8 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
                 with self.rate_limiter(log.debug, f"Deleted {count} instances"):
                     self.ec2_client.terminate_instances(InstanceIds=ids)
 
+            if not self._running:
+                return
             records = self._getBatch(self.delete_host_queue)
             if records:
                 ids = []
diff --git a/zuul/driver/openstack/__init__.py b/zuul/driver/openstack/__init__.py
index c2b02fa4ec..b6be116982 100644
--- a/zuul/driver/openstack/__init__.py
+++ b/zuul/driver/openstack/__init__.py
@@ -21,16 +21,14 @@ from zuul.driver.openstack import (
     openstackprovider,
     openstackendpoint,
 )
+from zuul.provider import EndpointCacheMixin
 
 
-class OpenstackDriver(Driver, ConnectionInterface, ProviderInterface):
+class OpenstackDriver(Driver, EndpointCacheMixin,
+                      ConnectionInterface, ProviderInterface):
     name = 'openstack'
     _endpoint_class = openstackendpoint.OpenstackProviderEndpoint
 
-    def __init__(self, *args, **kw):
-        super().__init__(*args, **kw)
-        self.endpoints = {}
-
     def getConnection(self, name, config):
         return openstackconnection.OpenstackConnection(self, name, config)
 
@@ -54,17 +52,11 @@ class OpenstackDriver(Driver, ConnectionInterface, ProviderInterface):
             urllib.parse.quote_plus(connection.connection_name),
             urllib.parse.quote_plus(region_str),
         ])
-        try:
-            return self.endpoints[endpoint_id]
-        except KeyError:
-            pass
-        endpoint = self._endpoint_class(self, connection, region)
-        self.endpoints[endpoint_id] = endpoint
-        return endpoint
+        return self.getEndpointById(endpoint_id,
+                                    create_args=(self, connection, region))
 
     def getEndpoint(self, provider):
         return self._getEndpoint(provider.connection, provider.region)
 
     def stop(self):
-        for endpoint in self.endpoints.values():
-            endpoint.stop()
+        self.stopEndpoints()
diff --git a/zuul/driver/openstack/openstackendpoint.py b/zuul/driver/openstack/openstackendpoint.py
index bc68ee285d..26e22a8678 100644
--- a/zuul/driver/openstack/openstackendpoint.py
+++ b/zuul/driver/openstack/openstackendpoint.py
@@ -385,6 +385,7 @@ class OpenstackProviderEndpoint(BaseProviderEndpoint):
     def __init__(self, driver, connection, region):
         name = f'{connection.connection_name}-{region}'
         super().__init__(driver, connection, name)
+        self.log = logging.getLogger(f"zuul.openstack.{self.name}")
         self.region = region
 
         # Wrap these instance methods with a per-instance LRU cache so
@@ -399,9 +400,15 @@ class OpenstackProviderEndpoint(BaseProviderEndpoint):
         self._listAZs = functools.lru_cache(maxsize=None)(
             self._listAZs)
 
-        self.log = logging.getLogger(f"zuul.openstack.{self.name}")
-        self._running = True
+        self.rate_limiter = RateLimiter(self.name, connection.rate)
+        self._last_image_check_failure = time.time()
+        self._last_port_cleanup = None
+        self._client = self._getClient()
+
+    def startEndpoint(self):
+        self.log.debug("Starting OpenStack endpoint")
+        self._running = True
 
         # The default http connection pool size is 10; match it for
         # efficiency.
         workers = 10
@@ -424,14 +431,8 @@ class OpenstackProviderEndpoint(BaseProviderEndpoint):
             CACHE_TTL, self.api_executor)(
                 self._listFloatingIps)
 
-        self.rate_limiter = RateLimiter(self.name,
-                                        connection.rate)
-
-        self._last_image_check_failure = time.time()
-        self._last_port_cleanup = None
-        self._client = self._getClient()
-
-    def stop(self):
+    def stopEndpoint(self):
+        self.log.debug("Stopping OpenStack endpoint")
         self.api_executor.shutdown()
         self._running = False
diff --git a/zuul/launcher/server.py b/zuul/launcher/server.py
index 80c3d582dd..9cd4691784 100644
--- a/zuul/launcher/server.py
+++ b/zuul/launcher/server.py
@@ -881,6 +881,7 @@ class Launcher:
         self.connections.stop()
         self.upload_executor.shutdown()
         self.endpoint_upload_executor.shutdown()
+        # Endpoints are stopped by drivers
         self.log.debug("Stopped launcher")
 
     def join(self):
@@ -928,6 +929,7 @@ class Launcher:
                 continue
             endpoint = provider.getEndpoint()
             endpoints[endpoint.canonical_name] = endpoint
+            endpoint.start()
         self.endpoints = endpoints
         return updated
diff --git a/zuul/provider/__init__.py b/zuul/provider/__init__.py
index 8641bbfb35..1fd8c91937 100644
--- a/zuul/provider/__init__.py
+++ b/zuul/provider/__init__.py
@@ -15,6 +15,7 @@
 import abc
 import json
 import math
+import threading
 import urllib.parse
 
 from zuul.lib.voluputil import Required, Optional, Nullable, assemble
@@ -121,6 +122,23 @@ class BaseProviderEndpoint(metaclass=abc.ABCMeta):
         self.driver = driver
         self.connection = connection
         self.name = name
+        self.start_lock = threading.Lock()
+        self.started = False
+        self.stopped = False
+
+    def start(self):
+        with self.start_lock:
+            if not self.stopped and not self.started:
+                self.startEndpoint()
+                self.started = True
+
+    def stop(self):
+        with self.start_lock:
+            if self.started:
+                self.stopEndpoint()
+            # Set the stopped flag regardless of whether we started so
+            # that we won't start after stopping.
+            self.stopped = True
 
     @property
     def canonical_name(self):
@@ -129,6 +147,21 @@ class BaseProviderEndpoint(metaclass=abc.ABCMeta):
             urllib.parse.quote_plus(self.name),
         ])
 
+    def startEndpoint(self):
+        """Start the endpoint
+
+        This method may start any threads necessary for the endpoint.
+
+        """
+        raise NotImplementedError()
+
+    def stopEndpoint(self):
+        """Stop the endpoint
+
+        This method must stop all endpoint threads.
+        """
+        raise NotImplementedError()
+
 
 class BaseProviderSchema(metaclass=abc.ABCMeta):
     def getLabelSchema(self):
@@ -526,3 +559,25 @@ class BaseProvider(zkobject.PolymorphicZKObjectMixin,
         :param ProviderNode node: The node of the server
         """
         pass
+
+
+class EndpointCacheMixin:
+    def __init__(self, *args, **kw):
+        super().__init__(*args, **kw)
+        self.endpoints = {}
+        self.endpoints_lock = threading.Lock()
+
+    def getEndpointById(self, endpoint_id, create_args):
+        with self.endpoints_lock:
+            try:
+                return self.endpoints[endpoint_id]
+            except KeyError:
+                pass
+            endpoint = self._endpoint_class(*create_args)
+            self.endpoints[endpoint_id] = endpoint
+            return endpoint
+
+    def stopEndpoints(self):
+        with self.endpoints_lock:
+            for endpoint in self.endpoints.values():
+                endpoint.stop()
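
Reviewer note (not part of the patch): the lifecycle this series introduces can be illustrated with a small, self-contained sketch. The class names below are stand-ins, not the real Zuul classes; the sketch only mirrors the semantics added in zuul/provider/__init__.py: an endpoint starts at most once, never starts again after it has been stopped, and a driver caches endpoints by id and stops them all on shutdown.

# Illustrative sketch only -- stand-in names, not Zuul's actual classes.
import threading


class SketchEndpoint:
    def __init__(self, name):
        self.name = name
        self.start_lock = threading.Lock()
        self.started = False
        self.stopped = False

    def start(self):
        with self.start_lock:
            # Start at most once, and never after a stop.
            if not self.stopped and not self.started:
                self.startEndpoint()
                self.started = True

    def stop(self):
        with self.start_lock:
            if self.started:
                self.stopEndpoint()
            # Refuse future starts even if we never started.
            self.stopped = True

    def startEndpoint(self):
        print(f"{self.name}: starting worker threads")

    def stopEndpoint(self):
        print(f"{self.name}: stopping worker threads")


class SketchEndpointCache:
    # Mirrors the EndpointCacheMixin idea: one endpoint per id,
    # stop every cached endpoint on shutdown.
    def __init__(self):
        self.endpoints = {}
        self.endpoints_lock = threading.Lock()

    def getEndpointById(self, endpoint_id):
        with self.endpoints_lock:
            if endpoint_id not in self.endpoints:
                self.endpoints[endpoint_id] = SketchEndpoint(endpoint_id)
            return self.endpoints[endpoint_id]

    def stopEndpoints(self):
        with self.endpoints_lock:
            for endpoint in self.endpoints.values():
                endpoint.stop()


if __name__ == '__main__':
    cache = SketchEndpointCache()
    endpoint = cache.getEndpointById('aws/us-east-1')
    endpoint.start()       # starts worker threads
    endpoint.start()       # no-op: already started
    cache.stopEndpoints()  # stops every cached endpoint
    endpoint.start()       # no-op: endpoints never restart after stop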