Use ZuulTreeCache for OIDC signing keys

Change-Id: I696d631af2c3f75a9bd035b2537099dec95675a1
This commit is contained in:
Dong Zhang
2025-02-14 10:30:42 +01:00
parent 302341cc2a
commit a5b2abc388
8 changed files with 98 additions and 16 deletions

View File

@@ -2800,7 +2800,7 @@ class ZuulTestCase(BaseTestCase):
password = self.config.get("keystore", "password")
keystore = zuul.lib.keystorage.KeyStorage(
self.zk_client, password=password)
self.zk_client, password=password, start_cache=False)
import_keys = {}
import_data = {'keys': import_keys}
@@ -2814,6 +2814,7 @@ class ZuulTestCase(BaseTestCase):
import_keys[path] = json.load(i)
keystore.importKeys(import_data, False)
keystore.stop()
def copyDirToRepo(self, project, source_path):
self.init_repo(project)
@@ -3004,6 +3005,14 @@ class ZuulTestCase(BaseTestCase):
log_str += "".join(traceback.format_stack(stack_frame))
self.log.debug(log_str)
raise Exception("More than one thread is running: %s" % threads)
self.cleanupTestServers()
def cleanupTestServers(self):
    """Delete the server and ZooKeeper client attributes from this test case.

    Removes ``executor_server``, ``scheds``, ``launcher``,
    ``fake_nodepool`` and ``zk_client`` in that order. As with a plain
    ``del`` statement, a missing attribute raises AttributeError and the
    attributes after it are left untouched.
    """
    doomed_attrs = (
        "executor_server",
        "scheds",
        "launcher",
        "fake_nodepool",
        "zk_client",
    )
    for attr_name in doomed_attrs:
        delattr(self, attr_name)
def assertCleanShutdown(self):
    """Perform no checks; this implementation is intentionally a no-op."""
    return None

View File

@@ -7470,11 +7470,15 @@ class TestExecutor(ZuulTestCase):
# so skip these checks.
pass
def cleanupTestServers(self):
    """Do nothing when invoked as the regular cleanup step.

    NOTE(review): presumably this keeps ``executor_server`` alive until
    ``assertCleanShutdown`` has inspected it; that method invokes the
    parent implementation itself -- confirm against the full class.
    """
    return None
def assertCleanShutdown(self):
self.log.debug("Assert clean shutdown")
# After shutdown, make sure no jobs are running
self.assertEqual({}, self.executor_server.job_workers)
super().cleanupTestServers()
# Make sure that git.Repo objects have been garbage collected.
gc.disable()

View File

@@ -13,6 +13,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import copy
import io
import json
import logging
@@ -5594,8 +5595,12 @@ class TestOIDCSigningKeys(ZuulTestCase):
# a new key should be appended
time.sleep(rotation_interval + 1)
keystore.rotateOidcSigningKeys(algorithm, rotation_interval, max_ttl)
test_keys3 = keystore.getOidcSigningKeyData(algorithm)
self.assertEqual(len(test_keys3.keys), 2)
for _ in iterate_timeout(10, 'cache to sync'):
test_keys3 = keystore.getOidcSigningKeyData(algorithm)
if len(test_keys3.keys) == 2:
# avoid test_keys3 being modified in place by cache update
test_keys3 = copy.deepcopy(test_keys3)
break
private_key3, _, version3 = keystore.getLatestOidcSigningKeys(
algorithm)
self.assertEqual(
@@ -5617,8 +5622,10 @@ class TestOIDCSigningKeys(ZuulTestCase):
# the old key should be removed
time.sleep(max_ttl + 1)
keystore.rotateOidcSigningKeys(algorithm, rotation_interval, max_ttl)
test_keys4 = keystore.getOidcSigningKeyData(algorithm)
self.assertEqual(len(test_keys4.keys), 1)
for _ in iterate_timeout(10, 'cache to sync'):
test_keys4 = keystore.getOidcSigningKeyData(algorithm)
if len(test_keys4.keys) == 1:
break
private_key4, _, version4 = keystore.getLatestOidcSigningKeys(
algorithm)
self.assertEqual(

View File

@@ -3979,6 +3979,7 @@ class ExecutorServer(BaseMergeServer):
def stop(self):
self.log.debug("Stopping executor")
self.component_info.state = self.component_info.STOPPED
self.keystore.stop()
self.connections.stop()
self.disk_accountant.stop()
# The governor can change function registration, so make sure

View File

@@ -25,11 +25,35 @@ import paramiko
from zuul.exceptions import AlgorithmNotSupportedException
from zuul.lib import encryption, strings
from zuul.zk import ZooKeeperBase
from zuul.zk.cache import ZuulTreeCache
from zuul.zk.zkobject import ZKContext, ZKObject
RSA_KEY_SIZE = 2048
class OIDCSigningKeysCache(ZuulTreeCache):
    """ZuulTreeCache subclass rooted at ``OIDCSigningKeys.OIDC_ROOT_PATH``.

    Cache keys have the form ``("oidc_keys", algorithm)``.
    """

    def __init__(self, client):
        root_path = OIDCSigningKeys.OIDC_ROOT_PATH
        super().__init__(client, root_path, async_worker=False)

    def objectFromRaw(self, key, data, zstat):
        """Build an OIDCSigningKeys object from raw data and its stat."""
        signing_keys = OIDCSigningKeys._fromRaw(data, zstat, None)
        return signing_keys

    def updateFromRaw(self, obj, key, data, zstat):
        """Refresh ``obj`` by delegating to its ``_updateFromRaw``."""
        obj._updateFromRaw(data, zstat, None)

    def parsePath(self, path):
        """Turn a path into a cache key using its last ``/`` component."""
        last_component = path.rsplit('/', 1)[-1]
        return self._formatKey(last_component)

    def getSigningKeys(self, algorithm):
        """Look up the cached object stored for ``algorithm``.

        Uses ``.get`` on the cached objects, so a miss yields that
        mapping's default (presumably None).
        """
        cache_key = self._formatKey(algorithm)
        return self._cached_objects.get(cache_key)

    def _formatKey(self, algorithm):
        # key format: ("oidc_keys", algorithm)
        cache_key = ('oidc_keys', algorithm)
        return cache_key
class OIDCSigningKeys(ZKObject):
OIDC_ROOT_PATH = "/keystorage-oidc"
@@ -175,14 +199,29 @@ class KeyStorage(ZooKeeperBase):
SECRETS_PATH = PROJECT_PATH + "/secrets"
SSH_PATH = PROJECT_PATH + "/ssh"
def __init__(self, zookeeper_client, password, backup=None):
def __init__(self, zookeeper_client, password, backup=None,
start_cache=True):
super().__init__(zookeeper_client)
self.password = password
self.password_bytes = password.encode("utf-8")
if start_cache:
self.oidc_signing_keys_cache = OIDCSigningKeysCache(self.client)
else:
self.oidc_signing_keys_cache = None
self.getProjectSSHKeys = cachetools.func.lru_cache(maxsize=None)(
self.getProjectSSHKeys)
self.getProjectSecretsKeys = cachetools.func.lru_cache(maxsize=None)(
self.getProjectSecretsKeys)
def createZKContext(self):
    """Return a new ZKContext built from this object's client and logger.

    The second and third ZKContext arguments are passed as None.
    """
    zk_client, logger = self.client, self.log
    return ZKContext(zk_client, None, None, logger)
def stop(self):
    """Stop the OIDC signing keys cache, if any, then the parent class.

    The cache attribute is None when the object was created with
    ``start_cache=False``; in that case only the parent ``stop`` runs.
    """
    cache = self.oidc_signing_keys_cache
    if cache:
        cache.stop()
    super().stop()
def _walk(self, root):
ret = []
children = self.kazoo_client.get_children(root)
@@ -227,7 +266,6 @@ class KeyStorage(ZooKeeperBase):
key_path = self.SSH_PATH.format(connection_name, prefix, name)
return key_path
@cachetools.cached(cache={})
def getProjectSSHKeys(self, connection_name, project_name):
"""Return the public and private keys"""
key = self._getSSHKey(connection_name, project_name)
@@ -306,7 +344,6 @@ class KeyStorage(ZooKeeperBase):
key_path = self.SECRETS_PATH.format(connection_name, prefix, name)
return key_path
@cachetools.cached(cache={})
def getProjectSecretsKeys(self, connection_name, project_name):
"""Return the public and private keys"""
pem_private_key = self._getSecretsKey(connection_name, project_name)
@@ -415,20 +452,35 @@ class KeyStorage(ZooKeeperBase):
rotation_interval, max_ttl)
def getOidcSigningKeyData(self, algorithm):
"""Return the key data of an algorithm of OIDC singing keys"""
with self.createZKContext() as context:
oidc_signing_keys = OIDCSigningKeys.loadKeys(
context, algorithm)
if not oidc_signing_keys:
OIDCSigningKeys.createAndStoreKeys(
context, algorithm, self.password_bytes)
"""
Return the key data of an algorithm of OIDC signing keys.
The returned data comes from ZuulTreeCache and may not be in sync with
the actual data in ZooKeeper.
"""
oidc_signing_keys = self.oidc_signing_keys_cache.getSigningKeys(
algorithm)
# If it is not found in cache, it could be the cache hasn't
# been synced or the key has not been created yet. We need to
# check both.
if not oidc_signing_keys:
with self.createZKContext() as context:
oidc_signing_keys = OIDCSigningKeys.loadKeys(
context, algorithm)
if not oidc_signing_keys:
OIDCSigningKeys.createAndStoreKeys(
context, algorithm, self.password_bytes)
oidc_signing_keys = OIDCSigningKeys.loadKeys(
context, algorithm)
return oidc_signing_keys
def getLatestOidcSigningKeys(self, algorithm):
"""Return the latest key pair of an algorithm of OIDC singing keys"""
"""
Return the latest key pair of an algorithm of OIDC signing keys.
The returned data comes from ZuulTreeCache and may not be in sync with
the actual data in ZooKeeper.
"""
signing_key_data = self.getOidcSigningKeyData(algorithm=algorithm)
latest_key = signing_key_data.keys[-1]
pem_private_key = latest_key["private_key"].encode("utf-8")

View File

@@ -400,6 +400,7 @@ class Scheduler(threading.Thread):
def stop(self):
self.log.debug("Stopping scheduler")
self.keystore.stop()
self._stopped = True
self.wake_event.set()
# Main thread, connections and layout update may be waiting

View File

@@ -3143,6 +3143,7 @@ class ZuulWeb(object):
def stop(self):
self.log.info("ZuulWeb stopping")
self.keystore.stop()
self._running = False
self.component_info.state = self.component_info.STOPPED
cherrypy.engine.exit()

View File

@@ -279,6 +279,13 @@ class ZooKeeperBase(ZooKeeperSimpleBase):
self.client.on_reconnect_listeners.append(self._onReconnect)
self.client.on_suspended_listeners.append(self._onSuspended)
def stop(self):
    """Deregister this object's connection-state callbacks from the client.

    Removes ``_onConnect``, ``_onDisconnect``, ``_onReconnect`` and
    ``_onSuspended`` from the client's corresponding listener lists.
    Does nothing when ``self.client`` is falsy.

    The call is idempotent: a callback that is no longer registered
    (for example because ``stop()`` already ran) is skipped instead of
    raising ValueError, and the remaining callbacks are still removed.
    """
    client = self.client
    if not client:
        return

    registrations = (
        (client.on_connect_listeners, self._onConnect),
        (client.on_disconnect_listeners, self._onDisconnect),
        (client.on_reconnect_listeners, self._onReconnect),
        (client.on_suspended_listeners, self._onSuspended),
    )
    for listeners, callback in registrations:
        try:
            listeners.remove(callback)
        except ValueError:
            # Already deregistered; deliberate best-effort so a repeated
            # stop() cannot fail or leave later listeners attached.
            pass
def _onConnect(self):
pass