Fix endpoint races
This change fixes two endpoint-related races:

* We now shut down the AWS driver delete thread rather than setting it
  to daemon. This provides more predictable behavior in tests. Some
  tests started launcher endpoints out-of-band (i.e., they didn't really
  use the launcher, but it still started and stopped outside of the main
  action of the test). To ensure that every endpoint we start threads
  for is also correctly stopped, standard start/stop methods are added
  to endpoints, and locks are used to make the start/stop sequence
  predictable. This also lets us use endpoints in the scheduler/web
  without starting background threads for them.

* It was possible to create multiple endpoints for a given region if
  they were created around the same time, since we did not use a lock
  around the critical section of the cache. To correct that, surround
  the cache lookup with a lock and move the behavior to a mixin class
  for reusability. (A cachetools decorator around getEndpoint was
  considered, but the provider object in the method signature makes
  that difficult.)

Change-Id: I5e8ca06e76fced1bb342250a953ecda576168874
This commit is contained in:
parent 407fce6494
commit a14971778a
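For context, the cache race in the second bullet is a classic unsynchronized check-then-create. A minimal sketch of the problem and of the pattern the new EndpointCacheMixin uses (the names here are illustrative stand-ins, not the Zuul API):

```python
import threading

endpoints = {}
endpoints_lock = threading.Lock()

def get_endpoint_racy(endpoint_id, factory):
    # BROKEN: two threads can both miss the cache here and each
    # create (and start threads for) a separate endpoint.
    if endpoint_id not in endpoints:
        endpoints[endpoint_id] = factory()
    return endpoints[endpoint_id]

def get_endpoint_locked(endpoint_id, factory):
    # FIXED: lookup and creation happen in one critical section,
    # so at most one endpoint can exist per id.
    with endpoints_lock:
        try:
            return endpoints[endpoint_id]
        except KeyError:
            endpoint = factory()
            endpoints[endpoint_id] = endpoint
            return endpoint
```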
@@ -216,6 +216,8 @@ class TestAwsDriver(BaseCloudDriverTest):
         layout = self.scheds.first.sched.abide.tenants.get('tenant-one').layout
         provider = layout.providers['aws-us-east-1-main']
+        # Start the endpoint since we're going to use the scheduler's endpoint.
+        provider.getEndpoint().start()

         with self.createZKContext(None) as ctx:
             node = AwsProviderNode.new(ctx, label=label)
@@ -62,12 +62,9 @@ class LauncherBaseTestCase(ZuulTestCase):
         self.s3.create_bucket(
             Bucket='zuul',
             CreateBucketConfiguration={'LocationConstraint': 'us-west-2'})
+        self.addCleanup(self.mock_aws.stop)
         super().setUp()

-    def tearDown(self):
-        self.mock_aws.stop()
-        super().tearDown()
-
     def _nodes_by_label(self):
         nodes = self.launcher.api.nodes_cache.getItems()
         nodes_by_label = defaultdict(list)
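The switch from tearDown to addCleanup above is worth noting: cleanups run in LIFO order after the test, and they run even when setUp raises partway through, which tearDown does not guarantee. A minimal, hypothetical illustration of the pattern:

```python
import unittest

class ExampleTestCase(unittest.TestCase):
    """Hypothetical test showing why addCleanup replaces tearDown."""

    def setUp(self):
        self.events = []
        # Stand-in for self.mock_aws.stop in the real test.
        self.addCleanup(self.events.append, 'stopped')
        super().setUp()  # even if this raised, the cleanup would still run

    def test_noop(self):
        self.assertEqual(self.events, [])
```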
@@ -16,16 +16,14 @@ import urllib

 from zuul.driver import Driver, ConnectionInterface, ProviderInterface
 from zuul.driver.aws import awsconnection, awsmodel, awsprovider, awsendpoint
+from zuul.provider import EndpointCacheMixin


-class AwsDriver(Driver, ConnectionInterface, ProviderInterface):
+class AwsDriver(Driver, EndpointCacheMixin,
+                ConnectionInterface, ProviderInterface):
     name = 'aws'
     _endpoint_class = awsendpoint.AwsProviderEndpoint

-    def __init__(self, *args, **kw):
-        super().__init__(*args, **kw)
-        self.endpoints = {}
-
     def getConnection(self, name, config):
         return awsconnection.AwsConnection(self, name, config)
@@ -52,15 +50,9 @@ class AwsDriver(Driver, ConnectionInterface, ProviderInterface):
             urllib.parse.quote_plus(provider.connection.connection_name),
             urllib.parse.quote_plus(provider.region),
         ])
-        try:
-            return self.endpoints[endpoint_id]
-        except KeyError:
-            pass
-        endpoint = self._endpoint_class(
-            self, provider.connection, provider.region)
-        self.endpoints[endpoint_id] = endpoint
-        return endpoint
+        return self.getEndpointById(
+            endpoint_id,
+            create_args=(self, provider.connection, provider.region))

     def stop(self):
-        for endpoint in self.endpoints.values():
-            endpoint.stop()
+        self.stopEndpoints()
@@ -356,50 +356,8 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
     def __init__(self, driver, connection, region):
         name = f'{connection.connection_name}-{region}'
         super().__init__(driver, connection, name)
+        self.region = region

-        # Wrap these instance methods with a per-instance LRU cache so
-        # that we don't leak memory over time when the adapter is
-        # occasionally replaced.
-        self._getInstanceType = functools.lru_cache(maxsize=None)(
-            self._getInstanceType)
-        self._getImage = functools.lru_cache(maxsize=None)(
-            self._getImage)

         self.log = logging.getLogger(f"zuul.aws.{self.name}")
-        self._running = True

-        # AWS has a default rate limit for creating instances that
-        # works out to a sustained 2 instances/sec, but the actual
-        # create instance API call takes 1 second or more. If we want
-        # to achieve faster than 1 instance/second throughput, we need
-        # to parallelize create instance calls, so we set up a
-        # threadworker to do that.

-        # A little bit of a heuristic here to set the worker count.
-        # It appears that AWS typically takes 1-1.5 seconds to execute
-        # a create API call. Figure out how many we have to do in
-        # parallel in order to run at the rate limit, then quadruple
-        # that for headroom. Max out at 8 so we don't end up with too
-        # many threads. In practice, this will be 8 with the default
-        # values, and only less if users slow down the rate.
-        workers = max(min(int(connection.rate * 4), 8), 1)
-        self.log.info("Create executor with max workers=%s", workers)
-        self.create_executor = ThreadPoolExecutor(max_workers=workers)

-        # We can batch delete instances using the AWS API, so to do
-        # that, create a queue for deletes, and a thread to process
-        # the queue. It will be greedy and collect as many pending
-        # instance deletes as possible to delete together. Typically
-        # under load, that will mean a single instance delete followed
-        # by larger batches. That strikes a balance between
-        # responsiveness and efficiency. Reducing the overall number
-        # of requests leaves more time for create instance calls.
-        self.delete_host_queue = queue.Queue()
-        self.delete_instance_queue = queue.Queue()
-        self.delete_thread = threading.Thread(target=self._deleteThread)
-        self.delete_thread.daemon = True
-        self.delete_thread.start()
-        self.region = region

         self.rate_limiter = RateLimiter(self.name,
                                         connection.rate)
@@ -416,6 +374,16 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
         # minutes.
         self.quota_service_rate_limiter = RateLimiter(self.name,
                                                       connection.rate)

+        # Wrap these instance methods with a per-instance LRU cache so
+        # that we don't leak memory over time when the adapter is
+        # occasionally replaced.
+        # TODO: This may be able to be a different kind of cache now
+        self._getInstanceType = functools.lru_cache(maxsize=None)(
+            self._getInstanceType)
+        self._getImage = functools.lru_cache(maxsize=None)(
+            self._getImage)
+
         self.image_id_by_filter_cache = cachetools.TTLCache(
             maxsize=8192, ttl=(5 * 60))
@@ -431,6 +399,45 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
         self.aws_quotas = self.aws.client("service-quotas")
         self.ebs_client = self.aws.client('ebs')

+    def startEndpoint(self):
+        self._running = True
+        self.log.debug("Starting AWS endpoint")
+
+        # AWS has a default rate limit for creating instances that
+        # works out to a sustained 2 instances/sec, but the actual
+        # create instance API call takes 1 second or more. If we want
+        # to achieve faster than 1 instance/second throughput, we need
+        # to parallelize create instance calls, so we set up a
+        # threadworker to do that.
+
+        # A little bit of a heuristic here to set the worker count.
+        # It appears that AWS typically takes 1-1.5 seconds to execute
+        # a create API call. Figure out how many we have to do in
+        # parallel in order to run at the rate limit, then quadruple
+        # that for headroom. Max out at 8 so we don't end up with too
+        # many threads. In practice, this will be 8 with the default
+        # values, and only less if users slow down the rate.
+        workers = max(min(int(self.connection.rate * 4), 8), 1)
+        self.log.info("Create executor with max workers=%s", workers)
+        self.create_executor = ThreadPoolExecutor(
+            thread_name_prefix=f'aws-create-{self.name}',
+            max_workers=workers)
+
+        # We can batch delete instances using the AWS API, so to do
+        # that, create a queue for deletes, and a thread to process
+        # the queue. It will be greedy and collect as many pending
+        # instance deletes as possible to delete together. Typically
+        # under load, that will mean a single instance delete followed
+        # by larger batches. That strikes a balance between
+        # responsiveness and efficiency. Reducing the overall number
+        # of requests leaves more time for create instance calls.
+        self.delete_host_queue = queue.Queue()
+        self.delete_instance_queue = queue.Queue()
+        self.delete_thread = threading.Thread(
+            name=f'aws-delete-{self.name}',
+            target=self._deleteThread)
+        self.delete_thread.start()
+
         workers = 10
         self.log.info("Create executor with max workers=%s", workers)
         self.api_executor = ThreadPoolExecutor(
@@ -466,10 +473,14 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
             SERVICE_QUOTA_CACHE_TTL, self.api_executor)(
                 self._listEBSQuotas)

-    def stop(self):
+    def stopEndpoint(self):
+        self.log.debug("Stopping AWS endpoint")
         self.create_executor.shutdown()
         self.api_executor.shutdown()
         self._running = False
+        self.delete_host_queue.put(None)
+        self.delete_instance_queue.put(None)
+        self.delete_thread.join()

     def listResources(self, bucket_name):
         for host in self._listHosts():
@@ -1502,12 +1513,18 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
     def _getBatch(the_queue):
         records = []
         try:
-            records.append(the_queue.get(block=True, timeout=10))
+            record = the_queue.get(block=True, timeout=10)
+            if record is None:
+                return records
+            records.append(record)
         except queue.Empty:
             return []
         while True:
             try:
-                records.append(the_queue.get(block=False))
+                record = the_queue.get(block=False)
+                if record is None:
+                    return records
+                records.append(record)
             except queue.Empty:
                 break
         # The terminate call has a limit of 1k, but AWS recommends
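The None sentinel added above lets stopEndpoint wake the delete thread immediately instead of letting it wait out the 10-second get timeout. A standalone sketch of the same greedy-drain behavior (get_batch is a stand-in for the _getBatch static method):

```python
import queue

def get_batch(the_queue):
    # Block for the first record, then greedily drain whatever else is
    # already queued; a None sentinel ends the batch immediately.
    records = []
    try:
        record = the_queue.get(block=True, timeout=10)
        if record is None:
            return records
        records.append(record)
    except queue.Empty:
        return []
    while True:
        try:
            record = the_queue.get(block=False)
            if record is None:
                return records
            records.append(record)
        except queue.Empty:
            break
    return records

q = queue.Queue()
for item in ('i-1', 'i-2', 'i-3'):
    q.put(item)
q.put(None)          # what stopEndpoint() enqueues to shut the thread down
print(get_batch(q))  # ['i-1', 'i-2', 'i-3']
```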
@@ -1527,6 +1544,8 @@ class AwsProviderEndpoint(BaseProviderEndpoint):
             with self.rate_limiter(log.debug, f"Deleted {count} instances"):
                 self.ec2_client.terminate_instances(InstanceIds=ids)

+            if not self._running:
+                return
             records = self._getBatch(self.delete_host_queue)
             if records:
                 ids = []
@@ -21,16 +21,14 @@ from zuul.driver.openstack import (
     openstackprovider,
     openstackendpoint,
 )
+from zuul.provider import EndpointCacheMixin


-class OpenstackDriver(Driver, ConnectionInterface, ProviderInterface):
+class OpenstackDriver(Driver, EndpointCacheMixin,
+                      ConnectionInterface, ProviderInterface):
     name = 'openstack'
     _endpoint_class = openstackendpoint.OpenstackProviderEndpoint

-    def __init__(self, *args, **kw):
-        super().__init__(*args, **kw)
-        self.endpoints = {}
-
     def getConnection(self, name, config):
         return openstackconnection.OpenstackConnection(self, name, config)
@@ -54,17 +52,11 @@ class OpenstackDriver(Driver, ConnectionInterface, ProviderInterface):
             urllib.parse.quote_plus(connection.connection_name),
             urllib.parse.quote_plus(region_str),
         ])
-        try:
-            return self.endpoints[endpoint_id]
-        except KeyError:
-            pass
-        endpoint = self._endpoint_class(self, connection, region)
-        self.endpoints[endpoint_id] = endpoint
-        return endpoint
+        return self.getEndpointById(endpoint_id,
+                                    create_args=(self, connection, region))

     def getEndpoint(self, provider):
         return self._getEndpoint(provider.connection, provider.region)

     def stop(self):
-        for endpoint in self.endpoints.values():
-            endpoint.stop()
+        self.stopEndpoints()
@@ -385,6 +385,7 @@ class OpenstackProviderEndpoint(BaseProviderEndpoint):
     def __init__(self, driver, connection, region):
         name = f'{connection.connection_name}-{region}'
         super().__init__(driver, connection, name)
+        self.log = logging.getLogger(f"zuul.openstack.{self.name}")
         self.region = region

         # Wrap these instance methods with a per-instance LRU cache so
@@ -399,9 +400,15 @@ class OpenstackProviderEndpoint(BaseProviderEndpoint):
         self._listAZs = functools.lru_cache(maxsize=None)(
             self._listAZs)

-        self.log = logging.getLogger(f"zuul.openstack.{self.name}")
-        self._running = True
+        self.rate_limiter = RateLimiter(self.name, connection.rate)
+
+        self._last_image_check_failure = time.time()
+        self._last_port_cleanup = None
+        self._client = self._getClient()
+
+    def startEndpoint(self):
+        self.log.debug("Starting OpenStack endpoint")
+        self._running = True
         # The default http connection pool size is 10; match it for
         # efficiency.
         workers = 10
@@ -424,14 +431,8 @@ class OpenstackProviderEndpoint(BaseProviderEndpoint):
             CACHE_TTL, self.api_executor)(
                 self._listFloatingIps)

-        self.rate_limiter = RateLimiter(self.name,
-                                        connection.rate)
-
-        self._last_image_check_failure = time.time()
-        self._last_port_cleanup = None
-        self._client = self._getClient()
-
-    def stop(self):
+    def stopEndpoint(self):
+        self.log.debug("Stopping OpenStack endpoint")
         self.api_executor.shutdown()
         self._running = False
@@ -881,6 +881,7 @@ class Launcher:
         self.connections.stop()
         self.upload_executor.shutdown()
         self.endpoint_upload_executor.shutdown()
+        # Endpoints are stopped by drivers
         self.log.debug("Stopped launcher")

     def join(self):
@@ -928,6 +929,7 @@ class Launcher:
                 continue
             endpoint = provider.getEndpoint()
             endpoints[endpoint.canonical_name] = endpoint
+            endpoint.start()
         self.endpoints = endpoints
         return updated
@@ -15,6 +15,7 @@
 import abc
 import json
 import math
+import threading
 import urllib.parse

 from zuul.lib.voluputil import Required, Optional, Nullable, assemble
@@ -121,6 +122,23 @@ class BaseProviderEndpoint(metaclass=abc.ABCMeta):
         self.driver = driver
         self.connection = connection
         self.name = name
+        self.start_lock = threading.Lock()
+        self.started = False
+        self.stopped = False
+
+    def start(self):
+        with self.start_lock:
+            if not self.stopped and not self.started:
+                self.startEndpoint()
+                self.started = True
+
+    def stop(self):
+        with self.start_lock:
+            if self.started:
+                self.stopEndpoint()
+            # Set the stopped flag regardless of whether we started so
+            # that we won't start after stopping.
+            self.stopped = True

     @property
     def canonical_name(self):
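The effect of the new guards can be seen in a standalone sketch; LifecycleGuard is a hypothetical stand-in for BaseProviderEndpoint, with prints standing in for startEndpoint/stopEndpoint:

```python
import threading

class LifecycleGuard:
    def __init__(self):
        self.start_lock = threading.Lock()
        self.started = False
        self.stopped = False

    def start(self):
        with self.start_lock:
            if not self.stopped and not self.started:
                print('starting threads')  # stand-in for startEndpoint()
                self.started = True

    def stop(self):
        with self.start_lock:
            if self.started:
                print('stopping threads')  # stand-in for stopEndpoint()
            # Record the stop even if we never started, so a later
            # start() becomes a no-op.
            self.stopped = True

g = LifecycleGuard()
g.start()  # prints 'starting threads'
g.start()  # no-op: already started
g.stop()   # prints 'stopping threads'
g.start()  # no-op: stopped endpoints never restart
```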
@@ -129,6 +147,21 @@ class BaseProviderEndpoint(metaclass=abc.ABCMeta):
             urllib.parse.quote_plus(self.name),
         ])

+    def startEndpoint(self):
+        """Start the endpoint
+
+        This method may start any threads necessary for the endpoint.
+        """
+        raise NotImplementedError()
+
+    def stopEndpoint(self):
+        """Stop the endpoint
+
+        This method must stop all endpoint threads.
+        """
+        raise NotImplementedError()
+

 class BaseProviderSchema(metaclass=abc.ABCMeta):
     def getLabelSchema(self):
@@ -526,3 +559,25 @@ class BaseProvider(zkobject.PolymorphicZKObjectMixin,
         :param ProviderNode node: The node of the server
         """
         pass
+
+
+class EndpointCacheMixin:
+    def __init__(self, *args, **kw):
+        super().__init__(*args, **kw)
+        self.endpoints = {}
+        self.endpoints_lock = threading.Lock()
+
+    def getEndpointById(self, endpoint_id, create_args):
+        with self.endpoints_lock:
+            try:
+                return self.endpoints[endpoint_id]
+            except KeyError:
+                pass
+            endpoint = self._endpoint_class(*create_args)
+            self.endpoints[endpoint_id] = endpoint
+            return endpoint
+
+    def stopEndpoints(self):
+        with self.endpoints_lock:
+            for endpoint in self.endpoints.values():
+                endpoint.stop()
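A sketch of how a driver is expected to use the mixin; ExampleDriver and ExampleEndpoint are hypothetical, and the endpoint id here is simplified to a bare region rather than the quoted canonical ids the real drivers build:

```python
from zuul.provider import EndpointCacheMixin

class ExampleEndpoint:
    def __init__(self, driver, connection, region):
        self.driver = driver
        self.connection = connection
        self.region = region

    def stop(self):
        pass

class ExampleDriver(EndpointCacheMixin):
    # getEndpointById() instantiates this class on a cache miss.
    _endpoint_class = ExampleEndpoint

    def getEndpoint(self, provider):
        # Lookup and creation both happen under endpoints_lock, so
        # concurrent calls for one region share a single endpoint.
        return self.getEndpointById(
            provider.region,
            create_args=(self, provider.connection, provider.region))

    def stop(self):
        self.stopEndpoints()
```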