Import nodescan framework from Nodepool
This changes adds the nodescan machinerie as it is in Nodepool today, with a few small adaptions to make the tests pass. The nodescan framework will be plugged into the launcher logic and adapted to the launcher in a later change. Change-Id: I1820dfb9bd05d7d2ec66ddbc60451a2f0132c988
This commit is contained in:
@@ -0,0 +1,276 @@
|
||||
# Copyright (C) 2023 Acme Gating, LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import testtools
|
||||
|
||||
from zuul import exceptions
|
||||
from zuul.model import ProviderNode
|
||||
from zuul.launcher.server import NodescanWorker, NodescanRequest
|
||||
from zuul.zk import zkobject
|
||||
|
||||
from tests.base import (
|
||||
BaseTestCase,
|
||||
iterate_timeout,
|
||||
okay_tracebacks,
|
||||
)
|
||||
|
||||
|
||||
class FakeSocket:
|
||||
def __init__(self):
|
||||
self.blocking = True
|
||||
self.fd = 1
|
||||
|
||||
def setblocking(self, b):
|
||||
self.blocking = b
|
||||
|
||||
def getsockopt(self, level, optname):
|
||||
return None
|
||||
|
||||
def connect(self, addr):
|
||||
if not self.blocking:
|
||||
raise BlockingIOError()
|
||||
raise Exception("blocking connect attempted")
|
||||
|
||||
def fileno(self):
|
||||
return self.fd
|
||||
|
||||
|
||||
class FakePoll:
|
||||
def __init__(self, _fail=False):
|
||||
self.fds = []
|
||||
self._fail = _fail
|
||||
|
||||
def register(self, fd, bitmap):
|
||||
self.fds.append(fd)
|
||||
|
||||
def unregister(self, fd):
|
||||
if fd in self.fds:
|
||||
self.fds.remove(fd)
|
||||
|
||||
def poll(self, timeout=None):
|
||||
if self._fail:
|
||||
return []
|
||||
fds = self.fds[:]
|
||||
self.fds = [f for f in fds if not isinstance(f, FakeSocket)]
|
||||
fds = [f.fileno() if hasattr(f, 'fileno') else f for f in fds]
|
||||
return [(f, 0) for f in fds]
|
||||
|
||||
|
||||
class Dummy:
|
||||
pass
|
||||
|
||||
|
||||
class FakeKey:
|
||||
def get_name(self):
|
||||
return 'fake key'
|
||||
|
||||
def get_base64(self):
|
||||
return 'fake base64'
|
||||
|
||||
|
||||
class FakeTransport:
|
||||
def __init__(self, _fail=False, active=True):
|
||||
self.active = active
|
||||
self._fail = _fail
|
||||
|
||||
def start_client(self, event=None, timeout=None):
|
||||
if not self._fail:
|
||||
event.set()
|
||||
|
||||
def get_security_options(self):
|
||||
ret = Dummy()
|
||||
ret.key_types = ['rsa']
|
||||
return ret
|
||||
|
||||
def get_remote_server_key(self):
|
||||
return FakeKey()
|
||||
|
||||
def get_exception(self):
|
||||
return Exception("Fake ssh error")
|
||||
|
||||
|
||||
class DummyProviderNode(ProviderNode, subclass_id="dummy-nodescan"):
|
||||
pass
|
||||
|
||||
|
||||
class TestNodescanWorker(BaseTestCase):
|
||||
|
||||
def createZKContext(self, lock=None):
|
||||
return zkobject.ZKContext(self.zk_client, lock,
|
||||
None, self.log)
|
||||
|
||||
@patch('paramiko.transport.Transport')
|
||||
@patch('socket.socket')
|
||||
@patch('select.epoll')
|
||||
def test_nodescan(self, mock_epoll, mock_socket, mock_transport):
|
||||
# Test the nodescan worker
|
||||
fake_socket = FakeSocket()
|
||||
mock_socket.return_value = fake_socket
|
||||
mock_epoll.return_value = FakePoll()
|
||||
mock_transport.return_value = FakeTransport()
|
||||
worker = NodescanWorker()
|
||||
node = DummyProviderNode()
|
||||
node._set(
|
||||
interface_ip='198.51.100.1',
|
||||
connection_port=22,
|
||||
connection_type='ssh',
|
||||
)
|
||||
worker.start()
|
||||
request = NodescanRequest(node, True, 300, self.log)
|
||||
worker.addRequest(request)
|
||||
for _ in iterate_timeout(30, 'waiting for nodescan'):
|
||||
if request.complete:
|
||||
break
|
||||
result = request.result()
|
||||
self.assertEqual(result, ['fake key fake base64'])
|
||||
worker.stop()
|
||||
worker.join()
|
||||
|
||||
@patch('paramiko.transport.Transport')
|
||||
@patch('socket.socket')
|
||||
@patch('select.epoll')
|
||||
def test_nodescan_connection_timeout(
|
||||
self, mock_epoll, mock_socket, mock_transport):
|
||||
# Test a timeout during socket connection
|
||||
fake_socket = FakeSocket()
|
||||
mock_socket.return_value = fake_socket
|
||||
mock_epoll.return_value = FakePoll(_fail=True)
|
||||
mock_transport.return_value = FakeTransport()
|
||||
worker = NodescanWorker()
|
||||
node = DummyProviderNode()
|
||||
node._set(
|
||||
interface_ip='198.51.100.1',
|
||||
connection_port=22,
|
||||
connection_type='ssh',
|
||||
)
|
||||
worker.start()
|
||||
request = NodescanRequest(node, True, 1, self.log)
|
||||
worker.addRequest(request)
|
||||
for _ in iterate_timeout(30, 'waiting for nodescan'):
|
||||
if request.complete:
|
||||
break
|
||||
with testtools.ExpectedException(
|
||||
exceptions.ConnectionTimeoutException):
|
||||
request.result()
|
||||
worker.stop()
|
||||
worker.join()
|
||||
|
||||
@patch('paramiko.transport.Transport')
|
||||
@patch('socket.socket')
|
||||
@patch('select.epoll')
|
||||
def test_nodescan_ssh_timeout(
|
||||
self, mock_epoll, mock_socket, mock_transport):
|
||||
# Test a timeout during ssh connection
|
||||
fake_socket = FakeSocket()
|
||||
mock_socket.return_value = fake_socket
|
||||
mock_epoll.return_value = FakePoll()
|
||||
mock_transport.return_value = FakeTransport(_fail=True)
|
||||
worker = NodescanWorker()
|
||||
node = DummyProviderNode()
|
||||
node._set(
|
||||
interface_ip='198.51.100.1',
|
||||
connection_port=22,
|
||||
connection_type='ssh',
|
||||
)
|
||||
worker.start()
|
||||
request = NodescanRequest(node, True, 1, self.log)
|
||||
worker.addRequest(request)
|
||||
for _ in iterate_timeout(30, 'waiting for nodescan'):
|
||||
if request.complete:
|
||||
break
|
||||
with testtools.ExpectedException(
|
||||
exceptions.ConnectionTimeoutException):
|
||||
request.result()
|
||||
worker.stop()
|
||||
worker.join()
|
||||
|
||||
@patch('paramiko.transport.Transport')
|
||||
@patch('socket.socket')
|
||||
@patch('select.epoll')
|
||||
@okay_tracebacks('Fake ssh error')
|
||||
def test_nodescan_ssh_error(
|
||||
self, mock_epoll, mock_socket, mock_transport):
|
||||
# Test an ssh error
|
||||
fake_socket = FakeSocket()
|
||||
mock_socket.return_value = fake_socket
|
||||
mock_epoll.return_value = FakePoll()
|
||||
mock_transport.return_value = FakeTransport(active=False)
|
||||
worker = NodescanWorker()
|
||||
node = DummyProviderNode()
|
||||
node._set(
|
||||
interface_ip='198.51.100.1',
|
||||
connection_port=22,
|
||||
connection_type='ssh',
|
||||
)
|
||||
worker.start()
|
||||
request = NodescanRequest(node, True, 1, self.log)
|
||||
worker.addRequest(request)
|
||||
for _ in iterate_timeout(30, 'waiting for nodescan'):
|
||||
if request.complete:
|
||||
break
|
||||
with testtools.ExpectedException(
|
||||
exceptions.ConnectionTimeoutException):
|
||||
request.result()
|
||||
worker.stop()
|
||||
worker.join()
|
||||
|
||||
@patch('paramiko.transport.Transport')
|
||||
@patch('socket.socket')
|
||||
@patch('select.epoll')
|
||||
def test_nodescan_queue(self, mock_epoll, mock_socket, mock_transport):
|
||||
# Test the max_requests queing function
|
||||
fake_socket1 = FakeSocket()
|
||||
fake_socket2 = FakeSocket()
|
||||
fake_socket2.fd = 2
|
||||
# We get two sockets for each host
|
||||
sockets = [fake_socket1, fake_socket1, fake_socket2, fake_socket2]
|
||||
|
||||
def getsocket(*args, **kw):
|
||||
return sockets.pop(0)
|
||||
|
||||
mock_socket.side_effect = getsocket
|
||||
mock_epoll.return_value = FakePoll()
|
||||
mock_transport.return_value = FakeTransport()
|
||||
worker = NodescanWorker()
|
||||
worker.MAX_REQUESTS = 1
|
||||
node1 = DummyProviderNode()
|
||||
node1._set(
|
||||
interface_ip='198.51.100.1',
|
||||
connection_port=22,
|
||||
connection_type='ssh',
|
||||
)
|
||||
node2 = DummyProviderNode()
|
||||
node2._set(
|
||||
interface_ip='198.51.100.2',
|
||||
connection_port=22,
|
||||
connection_type='ssh',
|
||||
)
|
||||
|
||||
request1 = NodescanRequest(node1, True, 300, self.log)
|
||||
request2 = NodescanRequest(node2, True, 300, self.log)
|
||||
worker.addRequest(request1)
|
||||
worker.addRequest(request2)
|
||||
worker.start()
|
||||
for _ in iterate_timeout(5, 'waiting for nodescan'):
|
||||
if request1.complete and request2.complete:
|
||||
break
|
||||
result1 = request1.result()
|
||||
result2 = request2.result()
|
||||
self.assertEqual(result1, ['fake key fake base64'])
|
||||
self.assertEqual(result2, ['fake key fake base64'])
|
||||
worker.stop()
|
||||
worker.join()
|
||||
@@ -79,6 +79,14 @@ class CapacityException(Exception):
|
||||
statsd_key = 'error.capacity'
|
||||
|
||||
|
||||
class TimeoutException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ConnectionTimeoutException(TimeoutException):
|
||||
statsd_key = 'error.ssh'
|
||||
|
||||
|
||||
class RuntimeConfigurationException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
+398
-1
@@ -16,10 +16,13 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import concurrent.futures
|
||||
import collections
|
||||
import errno
|
||||
import fcntl
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import select
|
||||
import socket
|
||||
import subprocess
|
||||
import threading
|
||||
@@ -27,9 +30,10 @@ import time
|
||||
import uuid
|
||||
|
||||
import mmh3
|
||||
import paramiko
|
||||
import requests
|
||||
|
||||
from zuul import model
|
||||
from zuul import exceptions, model
|
||||
from zuul.lib import commandsocket, tracing
|
||||
from zuul.lib.collections import DefaultKeyDict
|
||||
from zuul.lib.config import get_default
|
||||
@@ -246,6 +250,399 @@ class EndpointUploadJob:
|
||||
self.launcher.addImageValidateEvent(self.upload)
|
||||
|
||||
|
||||
class NodescanRequest:
|
||||
"""A state machine for a nodescan request.
|
||||
|
||||
When complete, use the result() method to obtain the keys or raise
|
||||
an exception if an errer was encountered during processing.
|
||||
|
||||
"""
|
||||
|
||||
START = 'start'
|
||||
CONNECTING_INIT = 'connecting'
|
||||
NEGOTIATING_INIT = 'negotiating'
|
||||
CONNECTING_KEY = 'connecting key'
|
||||
NEGOTIATING_KEY = 'negotiating key'
|
||||
COMPLETE = 'complete'
|
||||
# For unit testing
|
||||
FAKE = False
|
||||
|
||||
def __init__(self, node, host_key_checking, timeout, log):
|
||||
self.state = self.START
|
||||
self.node = node
|
||||
self.host_key_checking = host_key_checking
|
||||
self.timeout = timeout
|
||||
self.log = log
|
||||
self.complete = False
|
||||
self.keys = []
|
||||
if (node.connection_type == 'ssh' or
|
||||
node.connection_type == 'network_cli'):
|
||||
self.gather_hostkeys = True
|
||||
else:
|
||||
self.gather_hostkeys = False
|
||||
self.ip = node.interface_ip
|
||||
self.port = node.connection_port
|
||||
if 'fake' not in self.ip and not self.FAKE:
|
||||
addrinfo = socket.getaddrinfo(self.ip, self.port)[0]
|
||||
self.family = addrinfo[0]
|
||||
self.sockaddr = addrinfo[4]
|
||||
self.sock = None
|
||||
self.transport = None
|
||||
self.event = None
|
||||
self.key_types = None
|
||||
self.key_index = None
|
||||
self.key_type = None
|
||||
self.start_time = time.monotonic()
|
||||
self.worker = None
|
||||
self.exception = None
|
||||
self.connect_start_time = None
|
||||
# Stats
|
||||
self.init_connection_attempts = 0
|
||||
self.key_connection_failures = 0
|
||||
self.key_negotiation_failures = 0
|
||||
|
||||
def setWorker(self, worker):
|
||||
"""Store a reference to the worker thread so we register and unregister
|
||||
the socket file descriptor from the polling object"""
|
||||
self.worker = worker
|
||||
|
||||
def fail(self, exception):
|
||||
"""Declare this request a failure and store the related exception"""
|
||||
self.exception = exception
|
||||
self.cleanup()
|
||||
self.complete = True
|
||||
self.state = self.COMPLETE
|
||||
|
||||
def cleanup(self):
|
||||
"""Try to close everything and unregister from the worker"""
|
||||
self._close()
|
||||
if self.exception:
|
||||
status = 'failed'
|
||||
else:
|
||||
status = 'complete'
|
||||
dt = int(time.monotonic() - self.start_time)
|
||||
self.log.debug("Nodescan request %s with %s keys, "
|
||||
"%s initial connection attempts, "
|
||||
"%s key connection failures, "
|
||||
"%s key negotiation failures in %s seconds",
|
||||
status, len(self.keys),
|
||||
self.init_connection_attempts,
|
||||
self.key_connection_failures,
|
||||
self.key_negotiation_failures,
|
||||
dt)
|
||||
|
||||
def result(self):
|
||||
"""Return the resulting keys, or raise an exception"""
|
||||
if self.exception:
|
||||
raise self.exception
|
||||
return self.keys
|
||||
|
||||
def _close(self):
|
||||
if self.transport:
|
||||
try:
|
||||
self.transport.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.transport = None
|
||||
self.event = None
|
||||
if self.sock:
|
||||
self.worker.unRegisterDescriptor(self.sock)
|
||||
try:
|
||||
self.sock.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.sock = None
|
||||
|
||||
def _checkTimeout(self):
|
||||
now = time.monotonic()
|
||||
if now - self.start_time > self.timeout:
|
||||
raise exceptions.ConnectionTimeoutException(
|
||||
f"Timeout connecting to {self.ip} on port {self.port}")
|
||||
|
||||
def _checkTransport(self):
|
||||
# This stanza is from
|
||||
# https://github.com/paramiko/paramiko/blob/main/paramiko/transport.py
|
||||
if not self.transport.active:
|
||||
e = self.transport.get_exception()
|
||||
if e is not None:
|
||||
raise e
|
||||
raise paramiko.exceptions.SSHException("Negotiation failed.")
|
||||
|
||||
def _connect(self):
|
||||
if self.sock:
|
||||
self.worker.unRegisterDescriptor(self.sock)
|
||||
self.sock = socket.socket(self.family, socket.SOCK_STREAM)
|
||||
# Set nonblocking so we can poll for connection completion
|
||||
self.sock.setblocking(False)
|
||||
try:
|
||||
self.sock.connect(self.sockaddr)
|
||||
except BlockingIOError:
|
||||
self.state = self.CONNECTING_INIT
|
||||
self.connect_start_time = time.monotonic()
|
||||
self.worker.registerDescriptor(self.sock)
|
||||
|
||||
def _start(self):
|
||||
# Use our Event subclass that will wake the worker when the
|
||||
# event is set.
|
||||
self.event = NodescanEvent(self.worker)
|
||||
# Return the socket to blocking mode as we hand it off to paramiko.
|
||||
self.sock.setblocking(True)
|
||||
self.transport = paramiko.transport.Transport(self.sock)
|
||||
if self.key_type is not None:
|
||||
opts = self.transport.get_security_options()
|
||||
opts.key_types = [self.key_type]
|
||||
# This starts a thread.
|
||||
self.transport.start_client(
|
||||
event=self.event, timeout=self.timeout)
|
||||
|
||||
def _nextKey(self):
|
||||
self._close()
|
||||
self.key_index += 1
|
||||
if self.key_index >= len(self.key_types):
|
||||
self.state = self.COMPLETE
|
||||
return True
|
||||
self.key_type = self.key_types[self.key_index]
|
||||
self._connect()
|
||||
self.state = self.CONNECTING_KEY
|
||||
|
||||
def advance(self, socket_ready):
|
||||
if self.state == self.START:
|
||||
if self.worker is None:
|
||||
raise Exception("Request not registered with worker")
|
||||
if not self.host_key_checking:
|
||||
self.state = self.COMPLETE
|
||||
else:
|
||||
if 'fake' in self.ip or self.FAKE:
|
||||
if self.gather_hostkeys:
|
||||
self.keys = ['ssh-rsa FAKEKEY']
|
||||
self.state = self.COMPLETE
|
||||
else:
|
||||
self.init_connection_attempts += 1
|
||||
self._connect()
|
||||
|
||||
if self.state == self.CONNECTING_INIT:
|
||||
if not socket_ready:
|
||||
# Check the overall timeout
|
||||
self._checkTimeout()
|
||||
# If we're still here, then don't let any individual
|
||||
# connection attempt last more than 10 seconds:
|
||||
if time.monotonic() - self.connect_start_time >= 10:
|
||||
self._close()
|
||||
self.state = self.START
|
||||
return
|
||||
eno = self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
|
||||
if eno:
|
||||
if eno not in [errno.ECONNREFUSED, errno.EHOSTUNREACH]:
|
||||
self.log.exception(
|
||||
f"Error {eno} connecting to {self.ip} "
|
||||
f"on port {self.port}")
|
||||
# Try again. Don't immediately start to reconnect
|
||||
# since econnrefused can happen very quickly, so we
|
||||
# could end up busy-waiting.
|
||||
self._close()
|
||||
self.state = self.START
|
||||
self._checkTimeout()
|
||||
return
|
||||
if self.gather_hostkeys:
|
||||
self._start()
|
||||
self.state = self.NEGOTIATING_INIT
|
||||
else:
|
||||
self.state = self.COMPLETE
|
||||
|
||||
if self.state == self.NEGOTIATING_INIT:
|
||||
if not self.event.is_set():
|
||||
self._checkTimeout()
|
||||
return
|
||||
# This will raise an exception on ssh errors
|
||||
try:
|
||||
self._checkTransport()
|
||||
except Exception:
|
||||
self.log.exception(
|
||||
f"SSH error connecting to {self.ip} on port {self.port}")
|
||||
# Try again
|
||||
self._close()
|
||||
self.key_negotiation_failures += 1
|
||||
self.state = self.START
|
||||
self._checkTimeout()
|
||||
self._connect()
|
||||
return
|
||||
# This is our first successful connection. Now that
|
||||
# we've done it, start again specifying the first key
|
||||
# type.
|
||||
opts = self.transport.get_security_options()
|
||||
self.key_types = opts.key_types
|
||||
self.key_index = -1
|
||||
self._nextKey()
|
||||
|
||||
if self.state == self.CONNECTING_KEY:
|
||||
if not socket_ready:
|
||||
self._checkTimeout()
|
||||
return
|
||||
eno = self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
|
||||
if eno:
|
||||
self.log.error(
|
||||
f"Error {eno} connecting to {self.ip} on port {self.port}")
|
||||
self.key_connection_failures += 1
|
||||
self._nextKey()
|
||||
return
|
||||
self._start()
|
||||
self.state = self.NEGOTIATING_KEY
|
||||
|
||||
if self.state == self.NEGOTIATING_KEY:
|
||||
if not self.event.is_set():
|
||||
self._checkTimeout()
|
||||
return
|
||||
# This will raise an exception on ssh errors
|
||||
try:
|
||||
self._checkTransport()
|
||||
except Exception as e:
|
||||
msg = str(e)
|
||||
if 'no acceptable host key' not in msg:
|
||||
# We expect some host keys to not be valid
|
||||
# when scanning only log if the error isn't
|
||||
# due to mismatched host key types.
|
||||
self.log.exception(
|
||||
f"SSH error connecting to {self.ip} "
|
||||
f"on port {self.port}")
|
||||
self.key_negotiation_failures += 1
|
||||
self._nextKey()
|
||||
|
||||
# Check if we're still in the same state
|
||||
if self.state == self.NEGOTIATING_KEY:
|
||||
key = self.transport.get_remote_server_key()
|
||||
if key:
|
||||
self.keys.append("%s %s" % (key.get_name(), key.get_base64()))
|
||||
self.log.debug('Added ssh host key: %s', key.get_name())
|
||||
self._nextKey()
|
||||
|
||||
if self.state == self.COMPLETE:
|
||||
self._close()
|
||||
self.complete = True
|
||||
|
||||
|
||||
class NodescanEvent(threading.Event):
|
||||
"""A subclass of event that will wake the NodescanWorker poll"""
|
||||
def __init__(self, worker, *args, **kw):
|
||||
super().__init__(*args, **kw)
|
||||
self._zuul_worker = worker
|
||||
|
||||
def set(self):
|
||||
super().set()
|
||||
try:
|
||||
os.write(self._zuul_worker.wake_write, b'\n')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class NodescanWorker:
|
||||
"""Handles requests for nodescans.
|
||||
|
||||
This class has a single thread that drives nodescan requests
|
||||
submitted by the launcher.
|
||||
|
||||
"""
|
||||
# This process is highly scalable, except for paramiko which
|
||||
# spawns a thread for each ssh connection. To avoid thread
|
||||
# overload, we set a max value for concurrent requests.
|
||||
# Simultaneous requests higher than this value will be queued.
|
||||
MAX_REQUESTS = 100
|
||||
|
||||
def __init__(self):
|
||||
self.wake_read, self.wake_write = os.pipe()
|
||||
fcntl.fcntl(self.wake_read, fcntl.F_SETFL, os.O_NONBLOCK)
|
||||
self._running = False
|
||||
self._active_requests = []
|
||||
self._pending_requests = []
|
||||
self.poll = select.epoll()
|
||||
self.poll.register(self.wake_read, select.EPOLLIN)
|
||||
|
||||
def start(self):
|
||||
self._running = True
|
||||
self.thread = threading.Thread(target=self.run, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def stop(self):
|
||||
self._running = False
|
||||
os.write(self.wake_write, b'\n')
|
||||
|
||||
def join(self):
|
||||
self.thread.join()
|
||||
|
||||
def addRequest(self, request):
|
||||
"""Submit a nodescan request"""
|
||||
request.setWorker(self)
|
||||
if len(self._active_requests) >= self.MAX_REQUESTS:
|
||||
self._pending_requests.append(request)
|
||||
else:
|
||||
self._active_requests.append(request)
|
||||
# If the poll is sleeping, wake it up for immediate action
|
||||
os.write(self.wake_write, b'\n')
|
||||
|
||||
def removeRequest(self, request):
|
||||
"""Remove the request and cleanup"""
|
||||
if request is None:
|
||||
return
|
||||
request.cleanup()
|
||||
try:
|
||||
self._active_requests.remove(request)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
self._pending_requests.remove(request)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def registerDescriptor(self, fd):
|
||||
"""Register the fd with the poll object"""
|
||||
# Oneshot means that once it triggers, it will automatically
|
||||
# be removed. That's great for us since we only use this for
|
||||
# detecting when the initial connection is complete and have
|
||||
# no further use.
|
||||
self.poll.register(
|
||||
fd, select.EPOLLOUT | select.EPOLLERR |
|
||||
select.EPOLLHUP | select.EPOLLONESHOT)
|
||||
|
||||
def unRegisterDescriptor(self, fd):
|
||||
"""Unregister the fd with the poll object"""
|
||||
try:
|
||||
self.poll.unregister(fd)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def run(self):
|
||||
while self._running:
|
||||
# Set the poll timeout to 1 second so that we check all
|
||||
# requests for timeouts every second. This could be
|
||||
# increased to a few seconds without significant impact.
|
||||
timeout = 1
|
||||
while (self._pending_requests and
|
||||
len(self._active_requests) < self.MAX_REQUESTS):
|
||||
# If we have room for more requests, add them and set
|
||||
# the timeout to 0 so that we immediately start
|
||||
# advancing them.
|
||||
request = self._pending_requests.pop(0)
|
||||
self._active_requests.append(request)
|
||||
timeout = 0
|
||||
ready = self.poll.poll(timeout=timeout)
|
||||
ready = [x[0] for x in ready]
|
||||
if self.wake_read in ready:
|
||||
# Empty the wake pipe
|
||||
while True:
|
||||
try:
|
||||
os.read(self.wake_read, 1024)
|
||||
except BlockingIOError:
|
||||
break
|
||||
for request in self._active_requests:
|
||||
try:
|
||||
socket_ready = (request.sock and
|
||||
request.sock.fileno() in ready)
|
||||
request.advance(socket_ready)
|
||||
except Exception as e:
|
||||
request.fail(e)
|
||||
if request.complete:
|
||||
self.removeRequest(request)
|
||||
|
||||
|
||||
class Launcher:
|
||||
log = logging.getLogger("zuul.Launcher")
|
||||
# Max. time to wait for a cache to sync
|
||||
|
||||
Reference in New Issue
Block a user