Handle early AWS spot instance reclamations

If an AWS spot instance is used as a metastatic backing node, an
unexpected series of events can occur:

* aws driver creates backing node instance
* aws driver scans ssh keys and stores them on backing node
* aws reclaims spot instance
* aws re-uses IP from backing node
* metastatic driver creates node
* metastatic driver scans ssh keys and stores them on node

Zuul would then use the wrong node (whether that succeeds depends
on what else has happened to the node in the interim).

To avoid this situation, we implement this change:
* After scanning the metastatic node ssh keys, we compare them to
  the backing node ssh keys and if they differ, trigger an error
  in the metastatic node and mark the backing node as failed.

In case the node is reclaimed one step early in the above sequence,
we implement this change:
* After completing the nodescan, the aws driver will double check
  that the instance is still running; if not, it will trigger an
  error.

The above is still subject to a small race if the nodescan time
takes less than the cache interval of the instance list, and the
node is reclaimed after the nodescan and within the cache interval
(currently 10 seconds).  In the unlikely event that does happen,
then the metastatic key check should still catch the issue as long
as the replacement node also does not boot within those 10 seconds.
(Technically possible, but the combination of all of these things
should be very unlikely in practice.)

Change-Id: I9ce1f6df04e9c49deceda99c8e4024dd98ea88f9
This commit is contained in:
James E. Blair 2024-11-05 13:07:43 -08:00
parent e84465e2d4
commit 242f9cc3e6
6 changed files with 174 additions and 15 deletions

View File

@ -877,6 +877,23 @@ class AwsAdapter(statemachine.Adapter):
iops=label.iops)) iops=label.iops))
return quota return quota
def notifyNodescanResult(self, label, external_id, success, keys):
instance_id = external_id.get('instance')
host_id = external_id.get('host')
# Verify that the instance or host have not been deleted in
# the interim (i.e., due to spot instance reclamation).
if instance_id:
instance = {'InstanceId': instance_id, 'state': 'unknown'}
instance = self._refresh(instance)
if instance['State']['Name'].lower() != 'running':
raise Exception(f"Instance {instance_id} is not running")
if host_id:
host = {'HostId': host_id, 'state': 'unknown'}
host = self._refresh(host)
if host['State'].lower() != 'available':
raise Exception(f"Host {host_id} is not available")
def uploadImage(self, provider_image, image_name, filename, def uploadImage(self, provider_image, image_name, filename,
image_format, metadata, md5, sha256): image_format, metadata, md5, sha256):
self.log.debug(f"Uploading image {image_name}") self.log.debug(f"Uploading image {image_name}")

View File

@ -352,20 +352,31 @@ class MetastaticAdapter(statemachine.Adapter):
def getQuotaForLabel(self, label): def getQuotaForLabel(self, label):
return QuotaInformation(instances=1) return QuotaInformation(instances=1)
def notifyNodescanFailure(self, label, external_id): def notifyNodescanResult(self, label, external_id, success, keys):
exc = None
with self.allocation_lock: with self.allocation_lock:
backing_node_records = self.backing_node_records.get( backing_node_records = self.backing_node_records.get(
label.name, []) label.name, [])
for bnr in backing_node_records: for bnr in backing_node_records:
if bnr.backsNode(external_id): if bnr.backsNode(external_id):
self.log.info( break
"Nodescan failure of %s on %s, failing backing node", else:
external_id, bnr.node_id) raise Exception(
bnr.failed = True f"Unable to find backing node for {external_id}")
backing_node = self._getNode(bnr.node_id) backing_node = self._getNode(bnr.node_id)
backing_node.user_data = self._makeBackingNodeUserData(bnr) if success and backing_node.host_keys and keys:
self.zk.storeNode(backing_node) if sorted(keys) != sorted(backing_node.host_keys):
return exc = Exception(f"Key mismatch on {external_id}")
success = False
if not success:
self.log.info(
"Nodescan failure of %s on %s, failing backing node",
external_id, bnr.node_id)
bnr.failed = True
backing_node.user_data = self._makeBackingNodeUserData(bnr)
self.zk.storeNode(backing_node)
if exc:
raise exc
# Local implementation below # Local implementation below

View File

@ -242,6 +242,7 @@ class StateMachineNodeLauncher(stats.StatsReporter):
instance = None instance = None
node = self.node node = self.node
statsd_key = 'ready' statsd_key = 'ready'
label = self.handler.pool.labels[node.type[0]]
try: try:
if self.state_machine is None: if self.state_machine is None:
@ -261,8 +262,16 @@ class StateMachineNodeLauncher(stats.StatsReporter):
self.log.warning("Error scanning keys: %s", str(e)) self.log.warning("Error scanning keys: %s", str(e))
else: else:
self.log.exception("Exception scanning keys:") self.log.exception("Exception scanning keys:")
try:
self.manager.adapter.notifyNodescanResult(
label, node.external_id, False, None)
except Exception:
self.log.exception(
"Exception processing failed nodescan result:")
raise exceptions.LaunchKeyscanException( raise exceptions.LaunchKeyscanException(
"Can't scan instance %s key" % node.id) "Can't scan instance %s key" % node.id)
self.manager.adapter.notifyNodescanResult(
label, node.external_id, True, keys)
if keys: if keys:
node.host_keys = keys node.host_keys = keys
self.log.debug(f"Node {node.id} is ready") self.log.debug(f"Node {node.id} is ready")
@ -297,7 +306,6 @@ class StateMachineNodeLauncher(stats.StatsReporter):
self.updateNodeFromInstance(instance) self.updateNodeFromInstance(instance)
self.log.debug("Submitting nodescan request for %s", self.log.debug("Submitting nodescan request for %s",
node.interface_ip) node.interface_ip)
label = self.handler.pool.labels[self.node.type[0]]
self.nodescan_request = NodescanRequest( self.nodescan_request = NodescanRequest(
node, node,
label.host_key_checking, label.host_key_checking,
@ -357,9 +365,6 @@ class StateMachineNodeLauncher(stats.StatsReporter):
if isinstance(e, exceptions.LaunchKeyscanException): if isinstance(e, exceptions.LaunchKeyscanException):
try: try:
label = self.handler.pool.labels[node.type[0]]
self.manager.adapter.notifyNodescanFailure(
label, node.external_id)
console = self.manager.adapter.getConsoleLog( console = self.manager.adapter.getConsoleLog(
label, node.external_id) label, node.external_id)
if console: if console:
@ -1701,10 +1706,12 @@ class Adapter:
""" """
raise NotImplementedError() raise NotImplementedError()
def notifyNodescanFailure(self, label, external_id): def notifyNodescanResult(self, label, external_id, success, keys):
"""Notify the adapter of a nodescan failure """Notify the adapter of a nodescan resurt
:param label ConfigLabel: The label config for the node :param label ConfigLabel: The label config for the node
:param external_id str or dict: The external id of the server :param external_id str or dict: The external id of the server
:param success bool: Whether the nodescan succeeded
:param keys str or None: The retrieved keys
""" """
pass pass

View File

@ -19,6 +19,7 @@ providers:
- name: ec2-us-west-2 - name: ec2-us-west-2
driver: aws driver: aws
region-name: us-west-2 region-name: us-west-2
launch-retries: 1
cloud-images: cloud-images:
- name: ubuntu1404 - name: ubuntu1404
image-id: ami-1e749f67 image-id: ami-1e749f67

View File

@ -1326,6 +1326,44 @@ class TestDriverAws(tests.DBTestCase):
self.assertTrue(node.node_properties['spot']) self.assertTrue(node.node_properties['spot'])
def test_aws_provisioning_spot_early_reclaim(self):
    # Simulate AWS reclaiming the spot instance while the nodescan
    # is still in progress; the node request should end up FAILED.
    real_notify = nodepool.driver.aws.adapter.AwsAdapter.\
        notifyNodescanResult

    def reclaiming_notify(*args, **kw):
        # Terminate every instance, wait for the adapter's cached
        # instance listing to catch up, then hand off to the real
        # notification handler (which should now see the instance
        # gone and raise).
        adapter = self.pool.getProviderManager('ec2-us-west-2').adapter
        for inst in adapter._listInstances():
            self.ec2_client.terminate_instances(
                InstanceIds=[inst['InstanceId']])
        for _ in iterate_timeout(60, Exception,
                                 "Instance list cache to update",
                                 interval=1):
            if not any(inst['State']['Name'].lower() == 'running'
                       for inst in adapter._listInstances()):
                break
        return real_notify(*args, **kw)

    self.useFixture(fixtures.MonkeyPatch(
        'nodepool.driver.aws.adapter.AwsAdapter.notifyNodescanResult',
        reclaiming_notify))

    configfile = self.setup_config('aws/aws-spot.yaml')
    self.pool = self.useNodepool(configfile, watermark_sleep=1)
    self.startPool(self.pool)

    req = zk.NodeRequest()
    req.state = zk.REQUESTED
    req.tenant_name = 'tenant-1'
    req.node_types.append('ubuntu1404-spot')
    self.zk.storeNodeRequest(req)

    self.log.debug("Waiting for request %s", req.id)
    self.waitForNodeRequest(req, states=(zk.FAILED,))
def test_aws_dedicated_host(self): def test_aws_dedicated_host(self):
req = self.requestNode('aws/aws-dedicated-host.yaml', 'ubuntu') req = self.requestNode('aws/aws-dedicated-host.yaml', 'ubuntu')
for _ in iterate_timeout(60, Exception, for _ in iterate_timeout(60, Exception,

View File

@ -309,6 +309,91 @@ class TestDriverMetastatic(tests.DBTestCase):
nodes = self._getNodes() nodes = self._getNodes()
self.assertEqual(nodes, []) self.assertEqual(nodes, [])
def test_metastatic_nodescan_key_mismatch(self):
    # A keyscan on a metastatic node that disagrees with the keys
    # recorded on its backing node should fail the metastatic node
    # and take the backing node out of service.  This can happen if
    # the cloud reclaimed the backing node and reused its IP.
    scan_index = -1
    # Scripted scan results, consumed in order: bn1, node1, node2
    scripted_keys = [
        ['ssh-rsa bnkey'],
        ['ssh-rsa otherkey'],
        ['ssh-rsa bnkey'],
    ]
    real_advance = nodepool.driver.statemachine.NodescanRequest.advance

    def patched_advance(obj, *args, **kw):
        # Run the real advance, then override the discovered keys
        # with the next scripted value.
        nonlocal scan_index
        if scan_index >= len(scripted_keys):
            return real_advance(obj, *args, **kw)
        ret = real_advance(obj, *args, **kw)
        obj.keys = scripted_keys[scan_index]
        scan_index += 1
        return ret

    self.useFixture(fixtures.MonkeyPatch(
        'nodepool.driver.statemachine.NodescanRequest.advance',
        patched_advance))

    configfile = self.setup_config('metastatic.yaml')
    pool = self.useNodepool(configfile, watermark_sleep=1)
    self.startPool(pool)
    manager = pool.getProviderManager('fake-provider')
    manager.adapter._client.create_image(name="fake-image")

    # First metastatic node comes up on backing node bn1
    node1 = self._requestNode()
    nodes = self._getNodes()
    self.assertEqual(len(nodes), 2)
    bn1 = nodes[1]
    self.assertEqual(bn1.provider, 'fake-provider')
    self.assertEqual(bn1.id, node1.driver_data['backing_node'])

    # The second node's scan mismatches, so bn1 is failed and a
    # second backing node is created
    node2 = self._requestNode()
    nodes = self._getNodes()
    bn2 = nodes[3]
    # Reload bn1 since the userdata failed
    self.assertEqual(bn1.id, nodes[1].id)
    bn1 = nodes[1]
    self.assertNotEqual(bn1.id, bn2.id)
    self.assertEqual(nodes, [node1, bn1, node2, bn2])
    self.assertEqual(bn2.id, node2.driver_data['backing_node'])

    # A third node should land on the healthy second backing node
    node3 = self._requestNode()
    nodes = self._getNodes()
    self.assertEqual(nodes, [node1, bn1, node2, bn2, node3])
    self.assertEqual(bn2.id, node3.driver_data['backing_node'])

    # Delete node3; both backing nodes should remain
    node3.state = zk.DELETING
    self.zk.storeNode(node3)
    self.waitForNodeDeletion(node3)
    nodes = self._getNodes()
    self.assertEqual(nodes, [node1, bn1, node2, bn2])

    # Delete node2; bn2 is now unused and should be cleaned up
    node2.state = zk.DELETING
    self.zk.storeNode(node2)
    self.waitForNodeDeletion(node2)
    self.waitForNodeDeletion(bn2)
    nodes = self._getNodes()
    self.assertEqual(nodes, [node1, bn1])

    # Delete node1; nothing should be left
    node1.state = zk.DELETING
    self.zk.storeNode(node1)
    self.waitForNodeDeletion(node1)
    self.waitForNodeDeletion(bn1)
    nodes = self._getNodes()
    self.assertEqual(nodes, [])
def test_metastatic_min_retention(self): def test_metastatic_min_retention(self):
# Test that the metastatic driver honors min-retention # Test that the metastatic driver honors min-retention
configfile = self.setup_config('metastatic.yaml') configfile = self.setup_config('metastatic.yaml')