Import nodescan framework from Nodepool

This change adds the nodescan machinery as it is in Nodepool today,
with a few small adaptations to make the tests pass.

The nodescan framework will be plugged into the launcher logic and
adapted to the launcher in a later change.

Change-Id: I1820dfb9bd05d7d2ec66ddbc60451a2f0132c988
This commit is contained in:
Simon Westphahl
2025-01-21 16:20:16 +01:00
parent 70d2788cb1
commit 2a49ce23db
3 changed files with 682 additions and 1 deletions
+276
View File
@@ -0,0 +1,276 @@
# Copyright (C) 2023 Acme Gating, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
import testtools
from zuul import exceptions
from zuul.model import ProviderNode
from zuul.launcher.server import NodescanWorker, NodescanRequest
from zuul.zk import zkobject
from tests.base import (
BaseTestCase,
iterate_timeout,
okay_tracebacks,
)
class FakeSocket:
    """Minimal socket stand-in used by the nodescan tests.

    It only supports the calls the nodescan request makes on a socket:
    toggling blocking mode, querying a socket option, connecting and
    returning a file descriptor number.
    """

    def __init__(self):
        # Fake descriptor number; tests may overwrite it to simulate
        # several distinct sockets.
        self.fd = 1
        self.blocking = True

    def setblocking(self, b):
        """Record the requested blocking mode."""
        self.blocking = b

    def getsockopt(self, level, optname):
        """Always report a falsy value (no pending socket error)."""
        return None

    def connect(self, addr):
        """Never connect; raise according to the current blocking mode."""
        if self.blocking:
            raise Exception("blocking connect attempted")
        # A non-blocking connect is reported as "in progress".
        raise BlockingIOError()

    def fileno(self):
        """Return the fake file descriptor number."""
        return self.fd
class FakePoll:
    """Polling-object stand-in that reports every registered item as ready.

    When constructed with ``_fail=True``, poll() never reports anything.
    """

    def __init__(self, _fail=False):
        self._fail = _fail
        self.fds = []

    def register(self, fd, bitmap):
        """Remember the descriptor (or socket-like object); ignore flags."""
        self.fds.append(fd)

    def unregister(self, fd):
        """Forget the descriptor if it is currently registered."""
        try:
            self.fds.remove(fd)
        except ValueError:
            pass

    def poll(self, timeout=None):
        """Return ``(fd, 0)`` for every registered item.

        FakeSocket instances are dropped from the registration list once
        they have been reported; everything else stays registered.
        """
        if self._fail:
            return []
        snapshot = list(self.fds)
        events = []
        remaining = []
        for item in snapshot:
            if not isinstance(item, FakeSocket):
                remaining.append(item)
            # Report objects by descriptor number and plain numbers as-is.
            fd = item.fileno() if hasattr(item, 'fileno') else item
            events.append((fd, 0))
        self.fds = remaining
        return events
class Dummy:
    """Empty attribute container; the fakes attach attributes ad hoc."""
class FakeKey:
    """Host key stand-in that returns fixed name and base64 strings."""

    def get_name(self):
        """Return the fixed key type name."""
        name = 'fake key'
        return name

    def get_base64(self):
        """Return the fixed base64 key body."""
        body = 'fake base64'
        return body
class FakeTransport:
    """Stand-in for the SSH transport object used by the nodescan tests.

    ``_fail=True`` makes start_client() never signal completion;
    ``active=False`` presets the ``active`` attribute to False.
    """

    def __init__(self, _fail=False, active=True):
        self._fail = _fail
        self.active = active

    def start_client(self, event=None, timeout=None):
        """Signal the event immediately unless configured to fail."""
        if self._fail:
            return
        event.set()

    def get_security_options(self):
        """Return an options object advertising a single key type."""
        options = Dummy()
        options.key_types = ['rsa']
        return options

    def get_remote_server_key(self):
        """Return a fake host key."""
        return FakeKey()

    def get_exception(self):
        """Return the fake SSH error."""
        return Exception("Fake ssh error")
class DummyProviderNode(ProviderNode, subclass_id="dummy-nodescan"):
    """ProviderNode subclass used only as a node fixture in these tests."""
class TestNodescanWorker(BaseTestCase):
    """Tests for NodescanWorker and NodescanRequest.

    socket.socket, select.epoll and paramiko's Transport are patched with
    the fakes defined in this module.  Every test stops and joins the
    worker thread in a ``finally`` block so that a failing assertion or a
    wait timeout does not leave the thread running into later tests.
    """

    def createZKContext(self, lock=None):
        """Return a ZKContext built from this test's client and logger."""
        return zkobject.ZKContext(self.zk_client, lock,
                                  None, self.log)

    def _makeNode(self, ip='198.51.100.1'):
        """Return a DummyProviderNode reachable via ssh on port 22."""
        node = DummyProviderNode()
        node._set(
            interface_ip=ip,
            connection_port=22,
            connection_type='ssh',
        )
        return node

    def _waitForRequests(self, timeout, *requests):
        """Block until all requests are complete or the timeout expires."""
        for _ in iterate_timeout(timeout, 'waiting for nodescan'):
            if all(request.complete for request in requests):
                break

    @patch('paramiko.transport.Transport')
    @patch('socket.socket')
    @patch('select.epoll')
    def test_nodescan(self, mock_epoll, mock_socket, mock_transport):
        # Test the nodescan worker
        mock_socket.return_value = FakeSocket()
        mock_epoll.return_value = FakePoll()
        mock_transport.return_value = FakeTransport()
        worker = NodescanWorker()
        node = self._makeNode()
        worker.start()
        try:
            request = NodescanRequest(node, True, 300, self.log)
            worker.addRequest(request)
            self._waitForRequests(30, request)
            result = request.result()
            self.assertEqual(result, ['fake key fake base64'])
        finally:
            worker.stop()
            worker.join()

    @patch('paramiko.transport.Transport')
    @patch('socket.socket')
    @patch('select.epoll')
    def test_nodescan_connection_timeout(
            self, mock_epoll, mock_socket, mock_transport):
        # Test a timeout during socket connection
        mock_socket.return_value = FakeSocket()
        mock_epoll.return_value = FakePoll(_fail=True)
        mock_transport.return_value = FakeTransport()
        worker = NodescanWorker()
        node = self._makeNode()
        worker.start()
        try:
            request = NodescanRequest(node, True, 1, self.log)
            worker.addRequest(request)
            self._waitForRequests(30, request)
            with testtools.ExpectedException(
                    exceptions.ConnectionTimeoutException):
                request.result()
        finally:
            worker.stop()
            worker.join()

    @patch('paramiko.transport.Transport')
    @patch('socket.socket')
    @patch('select.epoll')
    def test_nodescan_ssh_timeout(
            self, mock_epoll, mock_socket, mock_transport):
        # Test a timeout during ssh connection
        mock_socket.return_value = FakeSocket()
        mock_epoll.return_value = FakePoll()
        mock_transport.return_value = FakeTransport(_fail=True)
        worker = NodescanWorker()
        node = self._makeNode()
        worker.start()
        try:
            request = NodescanRequest(node, True, 1, self.log)
            worker.addRequest(request)
            self._waitForRequests(30, request)
            with testtools.ExpectedException(
                    exceptions.ConnectionTimeoutException):
                request.result()
        finally:
            worker.stop()
            worker.join()

    @patch('paramiko.transport.Transport')
    @patch('socket.socket')
    @patch('select.epoll')
    @okay_tracebacks('Fake ssh error')
    def test_nodescan_ssh_error(
            self, mock_epoll, mock_socket, mock_transport):
        # Test an ssh error
        mock_socket.return_value = FakeSocket()
        mock_epoll.return_value = FakePoll()
        mock_transport.return_value = FakeTransport(active=False)
        worker = NodescanWorker()
        node = self._makeNode()
        worker.start()
        try:
            request = NodescanRequest(node, True, 1, self.log)
            worker.addRequest(request)
            self._waitForRequests(30, request)
            with testtools.ExpectedException(
                    exceptions.ConnectionTimeoutException):
                request.result()
        finally:
            worker.stop()
            worker.join()

    @patch('paramiko.transport.Transport')
    @patch('socket.socket')
    @patch('select.epoll')
    def test_nodescan_queue(self, mock_epoll, mock_socket, mock_transport):
        # Test the max_requests queueing function
        fake_socket1 = FakeSocket()
        fake_socket2 = FakeSocket()
        fake_socket2.fd = 2
        # We get two sockets for each host
        sockets = [fake_socket1, fake_socket1, fake_socket2, fake_socket2]

        def getsocket(*args, **kw):
            return sockets.pop(0)

        mock_socket.side_effect = getsocket
        mock_epoll.return_value = FakePoll()
        mock_transport.return_value = FakeTransport()
        worker = NodescanWorker()
        # Only one active request at a time, so the second one is queued.
        worker.MAX_REQUESTS = 1
        node1 = self._makeNode('198.51.100.1')
        node2 = self._makeNode('198.51.100.2')
        request1 = NodescanRequest(node1, True, 300, self.log)
        request2 = NodescanRequest(node2, True, 300, self.log)
        worker.addRequest(request1)
        worker.addRequest(request2)
        worker.start()
        try:
            self._waitForRequests(5, request1, request2)
            result1 = request1.result()
            result2 = request2.result()
            self.assertEqual(result1, ['fake key fake base64'])
            self.assertEqual(result2, ['fake key fake base64'])
        finally:
            worker.stop()
            worker.join()
+8
View File
@@ -79,6 +79,14 @@ class CapacityException(Exception):
statsd_key = 'error.capacity'
class TimeoutException(Exception):
    """Base class for timeout errors; carries no extra state."""
class ConnectionTimeoutException(TimeoutException):
    """Timeout error raised when connecting to a host takes too long."""

    # Presumably the statsd metric key for this error class; the code
    # that consumes it is not visible here.
    statsd_key = 'error.ssh'
class RuntimeConfigurationException(Exception):
    """Exception type for runtime configuration errors.

    NOTE(review): the raising sites are not visible here; confirm the
    intended use against callers.
    """
+398 -1
View File
@@ -16,10 +16,13 @@
from concurrent.futures import ThreadPoolExecutor
import concurrent.futures
import collections
import errno
import fcntl
import itertools
import logging
import os
import random
import select
import socket
import subprocess
import threading
@@ -27,9 +30,10 @@ import time
import uuid
import mmh3
import paramiko
import requests
from zuul import model
from zuul import exceptions, model
from zuul.lib import commandsocket, tracing
from zuul.lib.collections import DefaultKeyDict
from zuul.lib.config import get_default
@@ -246,6 +250,399 @@ class EndpointUploadJob:
self.launcher.addImageValidateEvent(self.upload)
class NodescanRequest:
    """A state machine for a nodescan request.

    When complete, use the result() method to obtain the keys or raise
    an exception if an error was encountered during processing.

    The request is driven by repeated calls to advance() from a worker
    (see setWorker()).  It first opens a plain connection to learn the
    key types the transport offers, then reconnects once per key type to
    collect the corresponding host key.
    """

    START = 'start'
    CONNECTING_INIT = 'connecting'
    NEGOTIATING_INIT = 'negotiating'
    CONNECTING_KEY = 'connecting key'
    NEGOTIATING_KEY = 'negotiating key'
    COMPLETE = 'complete'

    # For unit testing
    FAKE = False

    def __init__(self, node, host_key_checking, timeout, log):
        """Prepare a scan of ``node``.

        :param node: object providing connection_type, interface_ip and
            connection_port attributes.
        :param host_key_checking: when false, the request completes
            immediately without connecting.
        :param timeout: overall time budget in seconds, measured from
            construction.
        :param log: logger used for progress and error messages.
        """
        self.state = self.START
        self.node = node
        self.host_key_checking = host_key_checking
        self.timeout = timeout
        self.log = log
        self.complete = False
        self.keys = []
        # Host keys are only gathered for ssh-like connection types; for
        # anything else we only check that the port is reachable.
        if (node.connection_type == 'ssh' or
                node.connection_type == 'network_cli'):
            self.gather_hostkeys = True
        else:
            self.gather_hostkeys = False

        self.ip = node.interface_ip
        self.port = node.connection_port
        # Fake hosts are never resolved or connected to.
        if 'fake' not in self.ip and not self.FAKE:
            addrinfo = socket.getaddrinfo(self.ip, self.port)[0]
            self.family = addrinfo[0]
            self.sockaddr = addrinfo[4]
        self.sock = None
        self.transport = None
        self.event = None
        self.key_types = None
        self.key_index = None
        self.key_type = None
        self.start_time = time.monotonic()
        self.worker = None
        self.exception = None
        self.connect_start_time = None
        # Stats
        self.init_connection_attempts = 0
        self.key_connection_failures = 0
        self.key_negotiation_failures = 0

    def setWorker(self, worker):
        """Store a reference to the worker thread so we can register and
        unregister the socket file descriptor from the polling object"""
        self.worker = worker

    def fail(self, exception):
        """Declare this request a failure and store the related exception"""
        self.exception = exception
        self.cleanup()
        self.complete = True
        self.state = self.COMPLETE

    def cleanup(self):
        """Try to close everything and unregister from the worker"""
        self._close()
        if self.exception:
            status = 'failed'
        else:
            status = 'complete'
        dt = int(time.monotonic() - self.start_time)
        self.log.debug("Nodescan request %s with %s keys, "
                       "%s initial connection attempts, "
                       "%s key connection failures, "
                       "%s key negotiation failures in %s seconds",
                       status, len(self.keys),
                       self.init_connection_attempts,
                       self.key_connection_failures,
                       self.key_negotiation_failures,
                       dt)

    def result(self):
        """Return the resulting keys, or raise an exception"""
        if self.exception:
            raise self.exception
        return self.keys

    def _close(self):
        """Close the transport and socket, ignoring errors while closing."""
        if self.transport:
            try:
                self.transport.close()
            except Exception:
                pass
        self.transport = None
        self.event = None
        if self.sock:
            self.worker.unRegisterDescriptor(self.sock)
            try:
                self.sock.close()
            except Exception:
                pass
        self.sock = None

    def _checkTimeout(self):
        """Raise ConnectionTimeoutException once the time budget is spent."""
        now = time.monotonic()
        if now - self.start_time > self.timeout:
            raise exceptions.ConnectionTimeoutException(
                f"Timeout connecting to {self.ip} on port {self.port}")

    def _checkTransport(self):
        """Raise the transport's error if negotiation did not succeed."""
        # This stanza is from
        # https://github.com/paramiko/paramiko/blob/main/paramiko/transport.py
        if not self.transport.active:
            e = self.transport.get_exception()
            if e is not None:
                raise e
            raise paramiko.exceptions.SSHException("Negotiation failed.")

    def _connect(self):
        """Start a non-blocking connect and register the socket for polling."""
        if self.sock:
            self.worker.unRegisterDescriptor(self.sock)
        self.sock = socket.socket(self.family, socket.SOCK_STREAM)
        # Set nonblocking so we can poll for connection completion
        self.sock.setblocking(False)
        try:
            self.sock.connect(self.sockaddr)
        except BlockingIOError:
            self.state = self.CONNECTING_INIT
            self.connect_start_time = time.monotonic()
        self.worker.registerDescriptor(self.sock)

    def _start(self):
        """Begin SSH negotiation on the connected socket."""
        # Use our Event subclass that will wake the worker when the
        # event is set.
        self.event = NodescanEvent(self.worker)
        # Return the socket to blocking mode as we hand it off to paramiko.
        self.sock.setblocking(True)
        self.transport = paramiko.transport.Transport(self.sock)
        if self.key_type is not None:
            opts = self.transport.get_security_options()
            opts.key_types = [self.key_type]
        # This starts a thread.
        self.transport.start_client(
            event=self.event, timeout=self.timeout)

    def _nextKey(self):
        """Move on to the next key type, or finish if none are left.

        Returns True when the request transitions to COMPLETE and None
        otherwise.
        """
        self._close()
        self.key_index += 1
        if self.key_index >= len(self.key_types):
            self.state = self.COMPLETE
            return True
        self.key_type = self.key_types[self.key_index]
        self._connect()
        self.state = self.CONNECTING_KEY

    def advance(self, socket_ready):
        """Run the state machine as far as it can go without blocking.

        :param socket_ready: truthy when the poll reported this request's
            socket as ready.
        :raises Exception: on timeout or any unexpected error; the caller
            is expected to pass it to fail().
        """
        if self.state == self.START:
            if self.worker is None:
                raise Exception("Request not registered with worker")
            if not self.host_key_checking:
                self.state = self.COMPLETE
            else:
                if 'fake' in self.ip or self.FAKE:
                    if self.gather_hostkeys:
                        self.keys = ['ssh-rsa FAKEKEY']
                    self.state = self.COMPLETE
                else:
                    self.init_connection_attempts += 1
                    self._connect()

        if self.state == self.CONNECTING_INIT:
            if not socket_ready:
                # Check the overall timeout
                self._checkTimeout()
                # If we're still here, then don't let any individual
                # connection attempt last more than 10 seconds:
                if time.monotonic() - self.connect_start_time >= 10:
                    self._close()
                    self.state = self.START
                return
            eno = self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
            if eno:
                if eno not in [errno.ECONNREFUSED, errno.EHOSTUNREACH]:
                    # No exception is being handled here, so use error()
                    # rather than exception(), which would append a
                    # meaningless "NoneType: None" traceback.
                    self.log.error(
                        f"Error {eno} connecting to {self.ip} "
                        f"on port {self.port}")
                # Try again.  Don't immediately start to reconnect
                # since econnrefused can happen very quickly, so we
                # could end up busy-waiting.
                self._close()
                self.state = self.START
                self._checkTimeout()
                return
            if self.gather_hostkeys:
                self._start()
                self.state = self.NEGOTIATING_INIT
            else:
                self.state = self.COMPLETE

        if self.state == self.NEGOTIATING_INIT:
            if not self.event.is_set():
                self._checkTimeout()
                return
            # This will raise an exception on ssh errors
            try:
                self._checkTransport()
            except Exception:
                self.log.exception(
                    f"SSH error connecting to {self.ip} on port {self.port}")
                # Try again
                self._close()
                self.key_negotiation_failures += 1
                self.state = self.START
                self._checkTimeout()
                self._connect()
                return
            # This is our first successful connection.  Now that
            # we've done it, start again specifying the first key
            # type.
            opts = self.transport.get_security_options()
            self.key_types = opts.key_types
            self.key_index = -1
            self._nextKey()

        if self.state == self.CONNECTING_KEY:
            if not socket_ready:
                self._checkTimeout()
                return
            eno = self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
            if eno:
                self.log.error(
                    f"Error {eno} connecting to {self.ip} on port {self.port}")
                self.key_connection_failures += 1
                self._nextKey()
                return
            self._start()
            self.state = self.NEGOTIATING_KEY

        if self.state == self.NEGOTIATING_KEY:
            if not self.event.is_set():
                self._checkTimeout()
                return
            # This will raise an exception on ssh errors
            try:
                self._checkTransport()
            except Exception as e:
                msg = str(e)
                if 'no acceptable host key' not in msg:
                    # We expect some host keys to not be valid when
                    # scanning; only log if the error isn't due to
                    # mismatched host key types.
                    self.log.exception(
                        f"SSH error connecting to {self.ip} "
                        f"on port {self.port}")
                    self.key_negotiation_failures += 1
                self._nextKey()

            # Check if we're still in the same state
            if self.state == self.NEGOTIATING_KEY:
                key = self.transport.get_remote_server_key()
                if key:
                    self.keys.append("%s %s" % (key.get_name(),
                                                key.get_base64()))
                    self.log.debug('Added ssh host key: %s', key.get_name())
                self._nextKey()

        if self.state == self.COMPLETE:
            self._close()
            self.complete = True
class NodescanEvent(threading.Event):
    """An Event that also wakes the NodescanWorker poll when it is set."""

    def __init__(self, worker, *event_args, **event_kw):
        threading.Event.__init__(self, *event_args, **event_kw)
        # Worker whose wake pipe is written to by set().
        self._zuul_worker = worker

    def set(self):
        """Set the event, then write one byte to the worker's wake pipe."""
        threading.Event.set(self)
        try:
            wake_fd = self._zuul_worker.wake_write
            os.write(wake_fd, b'\n')
        except Exception:
            # Best effort: the event is already set, so a failed wake-up
            # is ignored.
            pass
class NodescanWorker:
    """Handles requests for nodescans.

    This class has a single thread that drives nodescan requests
    submitted by the launcher.
    """

    # This process is highly scalable, except for paramiko which
    # spawns a thread for each ssh connection.  To avoid thread
    # overload, we set a max value for concurrent requests.
    # Simultaneous requests higher than this value will be queued.
    MAX_REQUESTS = 100

    def __init__(self):
        # Pipe used to interrupt a sleeping poll; the read end is
        # non-blocking so it can be drained without stalling.
        self.wake_read, self.wake_write = os.pipe()
        fcntl.fcntl(self.wake_read, fcntl.F_SETFL, os.O_NONBLOCK)
        self._running = False
        self._active_requests = []
        self._pending_requests = []
        self.poll = select.epoll()
        self.poll.register(self.wake_read, select.EPOLLIN)

    def start(self):
        """Start the worker thread."""
        self._running = True
        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()

    def stop(self):
        """Ask the worker thread to exit and wake it up."""
        self._running = False
        os.write(self.wake_write, b'\n')

    def join(self):
        """Wait for the worker thread to exit."""
        self.thread.join()

    def addRequest(self, request):
        """Submit a nodescan request"""
        request.setWorker(self)
        if len(self._active_requests) >= self.MAX_REQUESTS:
            self._pending_requests.append(request)
        else:
            self._active_requests.append(request)
            # If the poll is sleeping, wake it up for immediate action
            os.write(self.wake_write, b'\n')

    def removeRequest(self, request):
        """Remove the request and cleanup"""
        if request is None:
            return
        request.cleanup()
        try:
            self._active_requests.remove(request)
        except ValueError:
            pass
        try:
            self._pending_requests.remove(request)
        except ValueError:
            pass

    def registerDescriptor(self, fd):
        """Register the fd with the poll object"""
        # Oneshot means that once it triggers, it will automatically
        # be removed.  That's great for us since we only use this for
        # detecting when the initial connection is complete and have
        # no further use.
        self.poll.register(
            fd, select.EPOLLOUT | select.EPOLLERR |
            select.EPOLLHUP | select.EPOLLONESHOT)

    def unRegisterDescriptor(self, fd):
        """Unregister the fd with the poll object"""
        try:
            self.poll.unregister(fd)
        except Exception:
            # The fd may never have been registered; ignore.
            pass

    def run(self):
        """Main loop: poll for events and advance every active request."""
        while self._running:
            # Set the poll timeout to 1 second so that we check all
            # requests for timeouts every second.  This could be
            # increased to a few seconds without significant impact.
            timeout = 1
            while (self._pending_requests and
                   len(self._active_requests) < self.MAX_REQUESTS):
                # If we have room for more requests, add them and set
                # the timeout to 0 so that we immediately start
                # advancing them.
                request = self._pending_requests.pop(0)
                self._active_requests.append(request)
                timeout = 0
            ready = {event[0] for event in self.poll.poll(timeout=timeout)}
            if self.wake_read in ready:
                # Empty the wake pipe
                while True:
                    try:
                        os.read(self.wake_read, 1024)
                    except BlockingIOError:
                        break
            # Iterate over a snapshot: removeRequest() mutates
            # _active_requests, and removing from the list being
            # iterated would skip the following request for this pass.
            # A skipped request would also miss its one-shot
            # "socket ready" notification.
            for request in list(self._active_requests):
                try:
                    socket_ready = (request.sock and
                                    request.sock.fileno() in ready)
                    request.advance(socket_ready)
                except Exception as e:
                    request.fail(e)
                if request.complete:
                    self.removeRequest(request)
class Launcher:
log = logging.getLogger("zuul.Launcher")
# Max. time to wait for a cache to sync