Refactor branch cache to support more queries

The current branch cache is hyper-optimized to support exactly two
types of branch queries: all branches for a project, or unprotected
branches for a project.  GitHub provides another axis: "locked"
branches.  Cconceivably, other code review systems could as well, and
there may be even more axes in the future.

In order to support locked branches in a future change, we must first
refactor the branch cache to support more than two queries.  This
change implements that with the following scheme:

The branch cache will be a dictionary of project_name -> ProjectInfo,
and ProjectInfo will hold general information about the project such
as supported merge modes and default branch, as well as another
dictionary branch_name -> BranchInfo.  The BranchInfo record will hold
boolean flags indicating whether the branch is protected or locked.

Additionally, the project_info record will hold a set of which queries
for these flags have been performed and whether they were successful
or not.  This allows us to determine whether the branch_info flags
are valid or not.  For example, if we have only performed a query to
get all the branches, and a caller requests a list of protected
branches, we know that the BranchInfo.protected bool is not valid,
so we return a LookupEerror to the caller which will trigger another
query to get the protected branch list which will then be used to
update the branch cache, setting the protected bool to true where
appropriate on existing BranchInfo objects and setting the protected
query flag on ProjectInfo.  The result matches the current behavior,
but is extensible to support more flags.

In order to minimize the size of the branch cache in ZooKeeper, the
BranchInfo object is serialized as a simple integer with a bitmap of
the associated booleans.  Likewise, the several queries are stored in
the serialized ProjectInfo as two bitmaps (one for success, and one
for failed).

This change stubs out the "locked" flag and query in some places, just
to demonstrate sufficiency for future use, but it does not implement
support for locked branches yet.  A future change will do that.
As long as we don't actually add any locked branches, we can still
serialize to the old branch cache data structure, so this change does
so to enable rolling upgrades.  Tests of the upgrade path and
continued operation on only the old data path are included.

Change-Id: I8841e675295f15e5d6dd004f9e34836b8bbbdb63
This commit is contained in:
James E. Blair 2024-02-29 15:56:38 -08:00
parent 02133ca0ff
commit 7741fc923c
11 changed files with 715 additions and 243 deletions

View File

@ -198,3 +198,9 @@ Version 26
:Prior Zuul version: 9.5.0
:Description: Refactor circular dependencies.
Affects schedulers and executors.
Version 27
----------
:Prior Zuul version: 10.0.0
:Description: Refactor branch cache.
Affects schedulers and web.

View File

@ -1,4 +1,4 @@
# Copyright 2022 Acme Gating, LLC
# Copyright 2022, 2024 Acme Gating, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@ -67,3 +67,84 @@ class TestModelUpgrade(ZuulTestCase):
for _ in iterate_timeout(30, "model api to update"):
if component_registry.model_api == 1:
break
class TestGithubModelUpgrade(ZuulTestCase):
config_file = "zuul-gerrit-github.conf"
scheduler_count = 1
@model_version(26)
@simple_layout('layouts/gate-github.yaml', driver='github')
def test_model_26(self):
# This excercises the backwards-compat branch cache
# serialization code; no uprade happens in this test.
first = self.scheds.first
second = self.createScheduler()
second.start()
self.assertEqual(len(self.scheds), 2)
for _ in iterate_timeout(10, "until priming is complete"):
state_one = first.sched.local_layout_state.get("tenant-one")
if state_one:
break
for _ in iterate_timeout(
10, "all schedulers to have the same layout state"):
if (second.sched.local_layout_state.get(
"tenant-one") == state_one):
break
conn = first.connections.connections['github']
with self.createZKContext() as ctx:
# There's a lot of exception catching in the branch cache,
# so exercise a serialize/deserialize cycle.
old = conn._branch_cache.cache.serialize(ctx)
data = json.loads(old)
self.assertEqual(['master'],
data['remainder']['org/common-config'])
new = conn._branch_cache.cache.deserialize(old, ctx)
self.assertTrue(new['projects'][
'org/common-config'].branches['master'].present)
with first.sched.layout_update_lock, first.sched.run_handler_lock:
A = self.fake_github.openFakePullRequest(
'org/project', 'master', 'A')
self.fake_github.emitEvent(A.getPullRequestOpenedEvent())
self.waitUntilSettled(matcher=[second])
self.waitUntilSettled()
self.assertHistory([
dict(name='project-test1', result='SUCCESS'),
dict(name='project-test2', result='SUCCESS'),
], ordered=False)
@model_version(26)
@simple_layout('layouts/gate-github.yaml', driver='github')
def test_model_26_27(self):
# This excercises the branch cache upgrade.
first = self.scheds.first
self.model_test_component_info.model_api = 27
second = self.createScheduler()
second.start()
self.assertEqual(len(self.scheds), 2)
for _ in iterate_timeout(10, "until priming is complete"):
state_one = first.sched.local_layout_state.get("tenant-one")
if state_one:
break
for _ in iterate_timeout(
10, "all schedulers to have the same layout state"):
if (second.sched.local_layout_state.get(
"tenant-one") == state_one):
break
with first.sched.layout_update_lock, first.sched.run_handler_lock:
A = self.fake_github.openFakePullRequest(
'org/project', 'master', 'A')
self.fake_github.emitEvent(A.getPullRequestOpenedEvent())
self.waitUntilSettled(matcher=[second])
self.waitUntilSettled()
self.assertHistory([
dict(name='project-test1', result='SUCCESS'),
dict(name='project-test2', result='SUCCESS'),
], ordered=False)

View File

@ -1,5 +1,5 @@
# Copyright 2019 Red Hat, Inc.
# Copyright 2022 Acme Gating, LLC
# Copyright 2022, 2024 Acme Gating, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@ -27,7 +27,7 @@ from zuul.lib import yamlutil as yaml
from zuul.model import BuildRequest, HoldRequest, MergeRequest
from zuul.zk import ZooKeeperClient
from zuul.zk.blob_store import BlobStore
from zuul.zk.branch_cache import BranchCache
from zuul.zk.branch_cache import BranchCache, BranchFlag, BranchInfo
from zuul.zk.change_cache import (
AbstractChangeCache,
ChangeKey,
@ -1830,132 +1830,193 @@ class TestBranchCache(ZooKeeperBaseTestCase):
conn = DummyConnection()
cache = BranchCache(self.zk_client, conn, self.component_registry)
protected_flags = {BranchFlag.PROTECTED}
all_flags = {BranchFlag.PRESENT}
test_data = {
'project1': {
'all': ['protected1', 'protected2',
'unprotected1', 'unprotected2'],
'protected': ['protected1', 'protected2'],
'all': [
BranchInfo('protected1', present=True),
BranchInfo('protected2', present=True),
BranchInfo('unprotected1', present=True),
BranchInfo('unprotected2', present=True),
],
'protected': [
BranchInfo('protected1', protected=True),
BranchInfo('protected2', protected=True),
],
},
}
# Test a protected-only query followed by all
cache.setProjectBranches('project1', True,
cache.setProjectBranches('project1', protected_flags,
test_data['project1']['protected'])
self.assertEqual(
sorted(cache.getProjectBranches('project1', True)),
test_data['project1']['protected']
sorted([bi.name for bi in
cache.getProjectBranches('project1', protected_flags)
if bi.protected is True]),
[bi.name for bi in test_data['project1']['protected']]
)
self.assertRaises(
LookupError,
lambda: cache.getProjectBranches('project1', False)
lambda: cache.getProjectBranches('project1', all_flags),
)
cache.setProjectBranches('project1', False,
cache.setProjectBranches('project1', all_flags,
test_data['project1']['all'])
self.assertEqual(
sorted(cache.getProjectBranches('project1', True)),
test_data['project1']['protected']
sorted([bi.name for bi in
cache.getProjectBranches('project1', protected_flags)
if bi.protected is True]),
[bi.name for bi in test_data['project1']['protected']]
)
self.assertEqual(
sorted(cache.getProjectBranches('project1', False)),
test_data['project1']['all']
sorted([bi.name for bi in
cache.getProjectBranches('project1', all_flags)]),
[bi.name for bi in test_data['project1']['all']]
)
# There's a lot of exception catching in the branch cache,
# so exercise a serialize/deserialize cycle.
ctx = ZKContext(self.zk_client, None, None, self.log)
data = cache.cache.serialize(ctx)
cache.cache.deserialize(data, ctx)
def test_branch_cache_all_then_protected(self):
conn = DummyConnection()
cache = BranchCache(self.zk_client, conn, self.component_registry)
protected_flags = {BranchFlag.PROTECTED}
all_flags = {BranchFlag.PRESENT}
test_data = {
'project1': {
'all': ['protected1', 'protected2',
'unprotected1', 'unprotected2'],
'protected': ['protected1', 'protected2'],
'all': [
BranchInfo('protected1', present=True),
BranchInfo('protected2', present=True),
BranchInfo('unprotected1', present=True),
BranchInfo('unprotected2', present=True),
],
'protected': [
BranchInfo('protected1', protected=True),
BranchInfo('protected2', protected=True),
],
},
}
self.assertRaises(
LookupError,
lambda: cache.getProjectBranches('project1', True)
lambda: cache.getProjectBranches('project1', protected_flags)
)
self.assertRaises(
LookupError,
lambda: cache.getProjectBranches('project1', False)
lambda: cache.getProjectBranches('project1', all_flags)
)
# Test the other order; all followed by protected-only
cache.setProjectBranches('project1', False,
cache.setProjectBranches('project1', all_flags,
test_data['project1']['all'])
self.assertRaises(
LookupError,
lambda: cache.getProjectBranches('project1', True)
lambda: cache.getProjectBranches('project1', protected_flags)
)
self.assertEqual(
sorted(cache.getProjectBranches('project1', False)),
test_data['project1']['all']
sorted([bi.name for bi in
cache.getProjectBranches('project1', all_flags)]),
[bi.name for bi in test_data['project1']['all']]
)
cache.setProjectBranches('project1', True,
cache.setProjectBranches('project1', protected_flags,
test_data['project1']['protected'])
self.assertEqual(
sorted(cache.getProjectBranches('project1', True)),
test_data['project1']['protected']
sorted([bi.name for bi in
cache.getProjectBranches('project1', protected_flags)
if bi.protected is True]),
[bi.name for bi in test_data['project1']['protected']]
)
self.assertEqual(
sorted(cache.getProjectBranches('project1', False)),
test_data['project1']['all']
sorted([bi.name for bi in
cache.getProjectBranches('project1', all_flags)]),
[bi.name for bi in test_data['project1']['all']]
)
# There's a lot of exception catching in the branch cache,
# so exercise a serialize/deserialize cycle.
ctx = ZKContext(self.zk_client, None, None, self.log)
data = cache.cache.serialize(ctx)
cache.cache.deserialize(data, ctx)
def test_branch_cache_change_protected(self):
conn = DummyConnection()
cache = BranchCache(self.zk_client, conn, self.component_registry)
protected_flags = {BranchFlag.PROTECTED}
all_flags = {BranchFlag.PRESENT}
data1 = {
'project1': {
'all': ['newbranch', 'protected'],
'protected': ['protected'],
'all': [
BranchInfo('newbranch', present=True),
BranchInfo('protected', present=True),
],
'protected': [
BranchInfo('protected', protected=True),
],
},
}
data2 = {
'project1': {
'all': ['newbranch', 'protected'],
'protected': ['newbranch', 'protected'],
'all': [
BranchInfo('newbranch', present=True),
BranchInfo('protected', present=True),
],
'protected': [
BranchInfo('newbranch', present=True, protected=True),
BranchInfo('protected', protected=True),
],
},
}
# Create a new unprotected branch
cache.setProjectBranches('project1', False,
cache.setProjectBranches('project1', all_flags,
data1['project1']['all'])
cache.setProjectBranches('project1', True,
cache.setProjectBranches('project1', protected_flags,
data1['project1']['protected'])
self.assertEqual(
cache.getProjectBranches('project1', True),
data1['project1']['protected']
sorted([bi.name for bi in
cache.getProjectBranches('project1', protected_flags)
if bi.protected is True]),
[bi.name for bi in data1['project1']['protected']]
)
self.assertEqual(
sorted(cache.getProjectBranches('project1', False)),
data1['project1']['all']
sorted([bi.name for bi in
cache.getProjectBranches('project1', all_flags)]),
[bi.name for bi in data1['project1']['all']]
)
# Change it to protected
cache.setProtected('project1', 'newbranch', True)
self.assertEqual(
sorted(cache.getProjectBranches('project1', True)),
data2['project1']['protected']
sorted([bi.name for bi in
cache.getProjectBranches('project1', protected_flags)
if bi.protected is True]),
[bi.name for bi in data2['project1']['protected']]
)
self.assertEqual(
sorted(cache.getProjectBranches('project1', False)),
data2['project1']['all']
sorted([bi.name for bi in
cache.getProjectBranches('project1', all_flags)]),
[bi.name for bi in data2['project1']['all']]
)
# Change it back
cache.setProtected('project1', 'newbranch', False)
self.assertEqual(
sorted(cache.getProjectBranches('project1', True)),
data1['project1']['protected']
sorted([bi.name for bi in
cache.getProjectBranches('project1', protected_flags)
if bi.protected is True]),
[bi.name for bi in data1['project1']['protected']]
)
self.assertEqual(
sorted(cache.getProjectBranches('project1', False)),
data1['project1']['all']
sorted([bi.name for bi in
cache.getProjectBranches('project1', all_flags)]),
[bi.name for bi in data1['project1']['all']]
)
def test_branch_cache_lookup_error(self):

View File

@ -155,18 +155,17 @@ class ZKBranchCacheMixin:
pass
@abc.abstractmethod
def _fetchProjectBranches(self, project, exclude_unprotected):
def _fetchProjectBranches(self, project, required_flags):
"""Perform a remote query to determine the project's branches.
Connection subclasses should implement this method.
:param model.Project project:
The project.
:param bool exclude_unprotected:
Whether the query should exclude unprotected branches from
the response.
:param set(BranchFlag) required_flags:
Which flags need to be valid in the result set.
:returns: A list of branch names.
:returns: A list of BranchInfo objects
"""
def _fetchProjectMergeModes(self, project):
@ -247,21 +246,15 @@ class ZKBranchCacheMixin:
The project for which the branches are returned.
"""
# Figure out which queries we have a cache for
protected_branches = self._branch_cache.getProjectBranches(
project.name, True, default=None)
all_branches = self._branch_cache.getProjectBranches(
project.name, False, default=None)
# Update them if we have them
if protected_branches is not None:
protected_branches = self._fetchProjectBranches(project, True)
required_flags = self._branch_cache._getProjectCompletedFlags(
project.name)
if required_flags:
# Update them if we have them
valid_flags, branch_infos = self._fetchProjectBranches(
project, required_flags)
self._branch_cache.setProjectBranches(
project.name, True, protected_branches)
if all_branches is not None:
all_branches = self._fetchProjectBranches(project, False)
self._branch_cache.setProjectBranches(
project.name, False, all_branches)
project.name, valid_flags, branch_infos)
merge_modes = self._fetchProjectMergeModes(project)
self._branch_cache.setProjectMergeModes(
@ -285,12 +278,18 @@ class ZKBranchCacheMixin:
:returns: The list of branch names.
"""
exclude_unprotected = tenant.getExcludeUnprotectedBranches(project)
exclude_locked = False
branches = None
required_flags = self._fetchProjectBranchesRequiredFlags(
exclude_unprotected, exclude_locked)
if self._branch_cache:
try:
branches = self._branch_cache.getProjectBranches(
project.name, exclude_unprotected, min_ltime)
project.name, required_flags, min_ltime)
if branches is not None:
branches = [b.name for b in self._filterProjectBranches(
branches, exclude_unprotected, exclude_locked)]
except LookupError:
if self.read_only:
# A scheduler hasn't attempted to fetch them yet
@ -308,9 +307,17 @@ class ZKBranchCacheMixin:
raise RuntimeError(
"Will not fetch project branches as read-only is set")
# We need to perform a query
# Above we calculated a set of flags needed to answer the
# query. If the fetch below fails, we will mark that set of
# flags as failed in the ProjectInfo structure. However, if
# the fetch below succeeds, it can supply its own set of valid
# flags that we will record as successful. This lets the
# driver indicate that the returned results include more data
# than strictly necessary (ie, protected+locked and not just
# protected).
try:
branches = self._fetchProjectBranches(project, exclude_unprotected)
valid_flags, branch_infos = self._fetchProjectBranches(
project, required_flags)
except Exception:
# We weren't able to get the branches. We need to tell
# future schedulers to try again but tell zuul-web that we
@ -319,15 +326,15 @@ class ZKBranchCacheMixin:
# time we encounter None in the cache, we will try again.
if self._branch_cache:
self._branch_cache.setProjectBranches(
project.name, exclude_unprotected, None)
project.name, required_flags, None)
raise
self.log.info("Got branches for %s" % project.name)
if self._branch_cache:
self._branch_cache.setProjectBranches(
project.name, exclude_unprotected, branches)
project.name, valid_flags, branch_infos)
return sorted(branches)
return sorted([bi.name for bi in branch_infos])
def getProjectMergeModes(self, project, tenant, min_ltime=-1):
"""Get the merge modes for the given project.
@ -442,9 +449,8 @@ class ZKBranchCacheMixin:
return default_branch
def checkBranchCache(self, project_name: str, event,
protected: bool = None) -> None:
"""Clear the cache for a project when a branch event is processed
def checkBranchCache(self, project_name, event, protected=None):
"""Update the cache for a project when a branch event is processed
This method must be called when a branch event is processed: if the
event references a branch and the unprotected branches are excluded,
@ -462,22 +468,17 @@ class ZKBranchCacheMixin:
protected = self.isBranchProtected(project_name, event.branch,
zuul_event_id=event)
if protected is not None:
# If the branch appears in the exclude_unprotected cache but
# is unprotected, clear the exclude cache.
# If the branch does not appear in the exclude_unprotected
# cache but is protected, clear the exclude cache.
# All branches should always appear in the include_unprotected
# cache, so we never clear it.
required_flags = self._fetchProjectBranchesRequiredFlags(
exclude_unprotected=True, exclude_locked=False)
branches = self._branch_cache.getProjectBranches(
project_name, True, default=None)
project_name, required_flags, default=None)
if not branches:
branches = []
branches = [b.name for b in branches]
update = False
if (event.branch in branches) and (not protected):
update = True

View File

@ -55,7 +55,7 @@ from zuul.driver.git.gitwatcher import GitWatcher
from zuul.lib import tracing
from zuul.lib.logutil import get_annotated_logger
from zuul.model import Ref, Tag, Branch, Project
from zuul.zk.branch_cache import BranchCache
from zuul.zk.branch_cache import BranchCache, BranchFlag, BranchInfo
from zuul.zk.change_cache import (
AbstractChangeCache,
ChangeKey,
@ -1127,12 +1127,21 @@ class GerritConnection(ZKChangeCacheMixin, ZKBranchCacheMixin, BaseConnection):
not any(part.startswith('.') or part.endswith('.lock')
for part in parts))
def _fetchProjectBranches(self, project, exclude_unprotected):
def _fetchProjectBranchesRequiredFlags(
self, exclude_unprotected, exclude_locked):
return {BranchFlag.PRESENT}
def _filterProjectBranches(
self, branch_infos, exclude_unprotected, exclude_locked):
return branch_infos
def _fetchProjectBranches(self, project, required_flags):
refs = self.getInfoRefs(project)
heads = [str(k[len('refs/heads/'):]) for k in refs
if k.startswith('refs/heads/') and
GerritConnection._checkRefFormat(k)]
return heads
branch_infos = [BranchInfo(h, present=True) for h in heads]
return {BranchFlag.PRESENT}, branch_infos
def _fetchProjectDefaultBranch(self, project):
if not self.session:

View File

@ -27,7 +27,7 @@ from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from itertools import chain
from json.decoder import JSONDecodeError
from typing import List, Optional
from typing import Optional
import cherrypy
import cachecontrol
@ -55,7 +55,7 @@ from zuul.model import Ref, Branch, Tag, Project
from zuul.exceptions import MergeFailure
from zuul.driver.github.githubmodel import PullRequest, GithubTriggerEvent
from zuul.model import DequeueEvent
from zuul.zk.branch_cache import BranchCache
from zuul.zk.branch_cache import (BranchCache, BranchFlag, BranchInfo)
from zuul.zk.change_cache import (
AbstractChangeCache,
ChangeKey,
@ -666,10 +666,13 @@ class GithubEventProcessor(object):
# Save all protected branches
cached_branches = self.connection._branch_cache.getProjectBranches(
project_name, True, default=None)
project_name, {BranchFlag.PROTECTED}, default=None)
if cached_branches is None:
raise RuntimeError(f"No branches for project {project_name}")
else:
cached_branches = [b.name for b in cached_branches
if b.protected is True]
old_protected_branches = set(cached_branches)
# Update the project banches
@ -681,9 +684,11 @@ class GithubEventProcessor(object):
self.connection.updateProjectBranches(project)
# Get all protected branches
new_protected_branches = set(
new_protected_branches =\
self.connection._branch_cache.getProjectBranches(
project_name, True))
project_name, {BranchFlag.PROTECTED})
new_protected_branches = set(
[b.name for b in new_protected_branches if b.protected is True])
newly_protected = new_protected_branches - old_protected_branches
newly_unprotected = old_protected_branches - new_protected_branches
@ -1825,15 +1830,57 @@ class GithubConnection(ZKChangeCacheMixin, ZKBranchCacheMixin, BaseConnection):
def addProject(self, project):
self.projects[project.name] = project
def _fetchProjectBranches(self, project: Project,
exclude_unprotected: bool) -> List[str]:
def _fetchProjectBranchesRequiredFlags(
self, exclude_unprotected, exclude_locked):
required_flags = set()
if exclude_unprotected:
required_flags.add(BranchFlag.PROTECTED)
if exclude_locked:
required_flags.add(BranchFlag.LOCKED)
if not exclude_unprotected:
# We also need all the branches:
required_flags.add(BranchFlag.PRESENT)
if not required_flags:
required_flags = {BranchFlag.PRESENT}
return required_flags
def _filterProjectBranches(
self, branch_infos, exclude_unprotected, exclude_locked):
if exclude_unprotected:
branch_infos = [b for b in branch_infos if b.protected is True]
if exclude_locked:
branch_infos = [b for b in branch_infos if b.locked is not True]
return branch_infos
def _fetchProjectBranches(self, project, required_flags):
github = self.getGithubClient(project.name)
valid_flags = set()
branch_infos = {}
if BranchFlag.PROTECTED in required_flags:
valid_flags.add(BranchFlag.PROTECTED)
for branch_name in self._fetchProjectBranchesREST(
github, project, protected=True):
bi = branch_infos.setdefault(
branch_name, BranchInfo(branch_name))
bi.protected = True
if BranchFlag.PRESENT in required_flags:
valid_flags.add(BranchFlag.PRESENT)
for branch_name in self._fetchProjectBranchesREST(
github, project, protected=False):
bi = branch_infos.setdefault(
branch_name, BranchInfo(branch_name))
bi.present = True
return valid_flags, list(branch_infos.values())
def _fetchProjectBranchesREST(self, github, project, protected):
# Fetch the project branches from the rest api
url = github.session.build_url('repos', project.name,
'branches')
headers = {'Accept': 'application/vnd.github.loki-preview+json'}
params = {'per_page': 100}
if exclude_unprotected:
if protected:
params['protected'] = 1
branches = []

View File

@ -24,7 +24,6 @@ from zuul.model import Change, TriggerEvent, EventFilter, RefFilter
from zuul.model import FalseWithReason
from zuul.driver.util import time_to_seconds, to_list
EMPTY_GIT_REF = '0' * 40 # git sha of all zeros, used during creates/deletes

View File

@ -27,7 +27,6 @@ import urllib3
import dateutil.parser
from urllib.parse import quote_plus
from typing import List, Optional
from opentelemetry import trace
@ -40,9 +39,9 @@ from zuul.lib.http import ZuulHTTPAdapter
from zuul.lib.logutil import get_annotated_logger
from zuul.lib.config import any_to_bool
from zuul.exceptions import MergeFailure
from zuul.model import Branch, Project, Ref, Tag
from zuul.model import Branch, Ref, Tag
from zuul.driver.gitlab.gitlabmodel import GitlabTriggerEvent, MergeRequest
from zuul.zk.branch_cache import BranchCache
from zuul.zk.branch_cache import BranchCache, BranchFlag, BranchInfo
from zuul.zk.change_cache import (
AbstractChangeCache,
ConcurrentUpdateError,
@ -597,14 +596,42 @@ class GitlabConnection(ZKChangeCacheMixin, ZKBranchCacheMixin, BaseConnection):
def addProject(self, project):
self.projects[project.name] = project
def _fetchProjectBranches(self, project: Project,
exclude_unprotected: bool) -> List[str]:
branches = self.gl_client.get_project_branches(project.name,
exclude_unprotected)
return branches
def _fetchProjectBranchesRequiredFlags(
self, exclude_unprotected, exclude_locked):
required_flags = set()
if exclude_unprotected:
required_flags.add(BranchFlag.PROTECTED)
if not required_flags:
required_flags = {BranchFlag.PRESENT}
return required_flags
def isBranchProtected(self, project_name: str, branch_name: str,
zuul_event_id=None) -> Optional[bool]:
def _filterProjectBranches(
self, branch_infos, exclude_unprotected, exclude_locked):
if exclude_unprotected:
branch_infos = [b for b in branch_infos if b.protected is True]
return branch_infos
def _fetchProjectBranches(self, project, required_flags):
valid_flags = set()
branch_infos = {}
if BranchFlag.PROTECTED in required_flags:
valid_flags.add(BranchFlag.PROTECTED)
for branch_name in self.gl_client.get_project_branches(
project.name, True):
bi = branch_infos.setdefault(
branch_name, BranchInfo(branch_name))
bi.protected = True
if BranchFlag.PRESENT in required_flags:
valid_flags.add(BranchFlag.PRESENT)
for branch_name in self.gl_client.get_project_branches(
project.name, False):
bi = branch_infos.setdefault(
branch_name, BranchInfo(branch_name))
bi.present = True
return valid_flags, list(branch_infos.values())
def isBranchProtected(self, project_name, branch_name,
zuul_event_id=None):
branch = self.gl_client.get_project_branch(project_name, branch_name,
zuul_event_id)
return branch.get('protected')

View File

@ -33,7 +33,7 @@ from zuul.web.handler import BaseWebController
from zuul.model import Ref, Branch, Tag
from zuul.lib import tracing
from zuul.lib import dependson
from zuul.zk.branch_cache import BranchCache
from zuul.zk.branch_cache import BranchCache, BranchFlag, BranchInfo
from zuul.zk.change_cache import (
AbstractChangeCache,
ConcurrentUpdateError,
@ -592,12 +592,22 @@ class PagureConnection(ZKChangeCacheMixin, ZKBranchCacheMixin, BaseConnection):
url += '/c/%s' % sha
return url
def _fetchProjectBranches(self, project, exclude_unprotected):
def _fetchProjectBranchesRequiredFlags(
self, exclude_unprotected, exclude_locked):
return {BranchFlag.PRESENT}
def _filterProjectBranches(
self, branch_infos, exclude_unprotected, exclude_locked):
return branch_infos
def _fetchProjectBranches(self, project, required_flags):
pagure = self.get_project_api_client(project.name)
branches = pagure.get_project_branches()
self.log.info("Got branches for %s" % project.name)
return branches
branch_infos = [BranchInfo(name, present=True)
for name in branches]
return {BranchFlag.PRESENT}, branch_infos
def isBranchProtected(self, project_name, branch_name,
zuul_event_id=None):

View File

@ -14,4 +14,4 @@
# When making ZK schema changes, increment this and add a record to
# doc/source/developer/model-changelog.rst
MODEL_API = 26
MODEL_API = 27

View File

@ -1,6 +1,6 @@
# Copyright 2014 Rackspace Australia
# Copyright 2021 BMW Group
# Copyright 2021 Acme Gating, LLC
# Copyright 2021, 2024 Acme Gating, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@ -15,12 +15,20 @@
# under the License.
import collections
from enum import Enum
import logging
import json
from functools import reduce
from operator import ior
from zuul.zk.zkobject import ZKContext, ShardedZKObject
from zuul.zk.locks import SessionAwareReadLock, SessionAwareWriteLock, locked
from zuul.zk.locks import (
SessionAwareReadLock,
SessionAwareWriteLock,
locked as zk_locked
)
from zuul import model
from zuul.zk.components import COMPONENT_REGISTRY
from kazoo.exceptions import NoNodeError
@ -28,32 +36,147 @@ from kazoo.exceptions import NoNodeError
RAISE_EXCEPTION = object()
# These flags should be the purview of the drivers, but we need to
# know about them in order to support backwards compatability to
# MODEL_API < 27. In the future, we should be able to make these
# driver-specific and have driver-specific subclasses of BranchInfo,
# etc.
class BranchFlag(Enum):
PRESENT = 0x1
PROTECTED = 0x2
LOCKED = 0x4
# A helper method for the branch cache below.
def return_default(default, project_name):
if default is RAISE_EXCEPTION:
raise LookupError(
f"No branches for project {project_name}")
return default
class BranchInfo:
def __init__(self, name, present=None, protected=None, locked=None):
self.name = name
# These are tri-state: None means indeterminate, true or false
# are definitive.
self.present = present
self.protected = protected
self.locked = locked
def update(self, other):
if other.present is not None:
self.present = other.present
if other.protected is not None:
self.protected = other.protected
if other.locked is not None:
self.locked = other.locked
def toDict(self):
# This doesn't really return a dict, but like other toDict
# methods, it returns the object that will be encoded into
# JSON. It just happens we don't need a full dict for this.
flags = 0
valid_flags = 0
for f in self.flags:
flags |= f.value
for f in self.valid_flags:
valid_flags |= f.value
return [flags, valid_flags]
@property
def flags(self):
flags = set()
if self.present:
flags.add(BranchFlag.PRESENT)
if self.protected:
flags.add(BranchFlag.PROTECTED)
if self.locked:
flags.add(BranchFlag.LOCKED)
return flags
@property
def valid_flags(self):
# If a flag is None, then we don't know it for this branch so
# we consider it invalid.
valid_flags = set()
if self.present is not None:
valid_flags.add(BranchFlag.PRESENT)
if self.protected is not None:
valid_flags.add(BranchFlag.PROTECTED)
if self.locked is not None:
valid_flags.add(BranchFlag.LOCKED)
return valid_flags
@classmethod
def fromDict(cls, name, data):
o = cls(name)
flags, valid_flags = data
if valid_flags & BranchFlag.PRESENT.value:
o.present = bool(flags & BranchFlag.PRESENT.value)
if valid_flags & BranchFlag.PROTECTED.value:
o.protected = bool(flags & BranchFlag.PROTECTED.value)
if valid_flags & BranchFlag.LOCKED.value:
o.locked = bool(flags & BranchFlag.LOCKED.value)
return o
class ProjectInfo:
"""Store branch cache project information in ZK
If a project is absent from the cache, it needs to be queried from
the source.
"""
def __init__(self, name, merge_modes=None, default_branch=None):
self.name = name
self.merge_modes = merge_modes
self.default_branch = default_branch
self.branches = {}
# The set of flags we have performed queries for:
self.completed_flags = set()
# If there was an error fetching the branches for a given set
# of flags, the failure will be recorded here:
self.failed_flags = set()
def toDict(self):
return {
'merge_modes': self.merge_modes,
'default_branch': self.default_branch,
'branches': {b.name: b.toDict() for b in self.branches.values()},
'flags': [
reduce(ior, [x.value for x in self.completed_flags], 0),
reduce(ior, [x.value for x in self.failed_flags], 0),
],
}
@classmethod
def fromDict(cls, name, data):
o = cls(name)
o.merge_modes = data['merge_modes']
o.default_branch = data['default_branch']
o.branches = {
name: BranchInfo.fromDict(name, bdata)
for name, bdata in data['branches'].items()
}
completed_flags = data['flags'][0]
failed_flags = data['flags'][1]
for flag in BranchFlag:
if flag.value & completed_flags:
o.completed_flags.add(flag)
if flag.value & failed_flags:
o.failed_flags.add(flag)
return o
class BranchCacheZKObject(ShardedZKObject):
"""Store the branch cache in ZK
There are two projects dictionaries, protected and remainder.
Each is project_name:str -> branches:list.
The protected dictionary contains only the protected branches.
The remainder dictionary contains any other branches.
If there has never been a query that included unprotected
branches, the projects key will not be present in the remaider
dictionary. If there has never been a query that excluded
unprotected branches, then the protected dictionary will not have
the project's key.
If a project is absent from the dict, it needs to be queried from
the source.
If there was an error fetching the branches, None will be stored
as a sentinel value.
When performing an exclude_unprotected query, remove any duplicate
branches from remaider to save space. When determining the full
list of branches, combine both lists.
"""
# We can always recreate data if necessary, so go ahead and
@ -65,31 +188,117 @@ class BranchCacheZKObject(ShardedZKObject):
def __init__(self):
super().__init__()
self._set(protected={},
remainder={},
merge_modes={},
default_branch={})
self._set(
projects={},
)
def serialize(self, context):
data = {
"protected": self.protected,
"remainder": self.remainder,
"merge_modes": self.merge_modes,
"default_branch": self.default_branch,
}
if COMPONENT_REGISTRY.model_api < 27:
data = self.serialize_old()
else:
data = self.serialize_new()
return json.dumps(data, sort_keys=True).encode("utf8")
def serialize_new(self):
return {
"projects": {p.name: p.toDict() for p in self.projects.values()},
}
def serialize_old(self):
protected = {}
remainder = {}
merge_modes = {}
default_branch = {}
for pi in self.projects.values():
merge_modes[pi.name] = pi.merge_modes
default_branch[pi.name] = pi.default_branch
if BranchFlag.PROTECTED in pi.completed_flags:
pl = protected[pi.name] = []
elif BranchFlag.PROTECTED in pi.failed_flags:
pl = protected[pi.name] = None
else:
pl = None
if BranchFlag.PRESENT in pi.completed_flags:
rl = remainder[pi.name] = []
elif BranchFlag.PRESENT in pi.failed_flags:
rl = remainder[pi.name] = None
else:
rl = None
for bi in pi.branches.values():
if bi.protected:
if pl is not None:
pl.append(bi.name)
elif rl is not None:
rl.append(bi.name)
elif rl is not None:
rl.append(bi.name)
return {
"protected": protected,
"remainder": remainder,
"merge_modes": merge_modes,
"default_branch": default_branch,
}
def deserialize(self, raw, context):
data = super().deserialize(raw, context)
# MODEL_API < 11
if "protected" in data:
# MODEL_API < 27
self.deserialize_old(data)
else:
self.deserialize_new(data)
return data
def deserialize_new(self, data):
projects = {}
for project_name, project_data in data['projects'].items():
projects[project_name] = ProjectInfo.fromDict(
project_name, project_data)
data['projects'] = projects
def deserialize_old(self, data):
if "merge_modes" not in data:
# MODEL_API < 11
data["merge_modes"] = collections.defaultdict(
lambda: model.ALL_MERGE_MODES)
# MODEL_API < 16
if "default_branch" not in data:
# MODEL_API < 16
data["default_branch"] = collections.defaultdict(
lambda: 'master')
return data
projects = {}
for project_name, branches in data['protected'].items():
project_info = ProjectInfo(
project_name,
data['merge_modes'].get(project_name, model.ALL_MERGE_MODES),
data['default_branch'].get(project_name, 'master'))
projects[project_name] = project_info
if branches is None:
project_info.failed_flags.add(BranchFlag.PROTECTED)
elif branches:
project_info.completed_flags.add(BranchFlag.PROTECTED)
for branch_name in branches:
project_info.branches[branch_name] = BranchInfo(
branch_name, protected=True)
for project_name, branches in data['remainder'].items():
project_info = projects.get(project_name)
if project_info is None:
project_info = ProjectInfo(
project_name,
data['merge_modes'].get(project_name,
model.ALL_MERGE_MODES),
data['default_branch'].get(project_name, 'master'))
projects[project_name] = project_info
if branches is None:
project_info.failed_flags.add(BranchFlag.PRESENT)
elif branches:
project_info.completed_flags.add(BranchFlag.PRESENT)
for branch_name in branches:
# Create a branchinfo object
project_info.branches[branch_name] = BranchInfo(
branch_name, present=True)
data.clear()
data['projects'] = projects
def _save(self, context, data, create=False):
super()._save(context, data, create)
@ -122,7 +331,7 @@ class BranchCache:
self.zk_context = ZKContext(zk_client, self.wlock, None, self.log)
with (self.zk_context as ctx,
locked(self.wlock)):
zk_locked(self.wlock)):
try:
self.cache = BranchCacheZKObject.fromZK(
ctx, data_path, _path=data_path)
@ -132,22 +341,33 @@ class BranchCache:
def clear(self, projects=None):
"""Clear the cache"""
with (locked(self.wlock),
with (zk_locked(self.wlock),
self.zk_context as ctx,
self.cache.activeContext(ctx)):
if projects is None:
self.cache.protected.clear()
self.cache.remainder.clear()
self.cache.merge_modes.clear()
self.cache.default_branch.clear()
self.cache.projects.clear()
else:
for p in projects:
self.cache.protected.pop(p, None)
self.cache.remainder.pop(p, None)
self.cache.merge_modes.pop(p, None)
self.cache.default_branch.pop(p, None)
self.cache.projects.pop(p, None)
def getProjectBranches(self, project_name, exclude_unprotected,
def _getRequiredFlags(self, exclude_unprotected, exclude_locked):
required_flags = set()
if exclude_unprotected:
required_flags.add(BranchFlag.PROTECTED)
if exclude_locked:
required_flags.add(BranchFlag.LOCKED)
if not required_flags:
required_flags = {BranchFlag.PRESENT}
return required_flags
def _getProjectCompletedFlags(self, project_name):
try:
project_info = self.cache.projects[project_name]
except KeyError:
return set()
return project_info.completed_flags
def getProjectBranches(self, project_name, required_flags,
min_ltime=-1, default=RAISE_EXCEPTION):
"""Get the branch names for the given project.
@ -168,8 +388,9 @@ class BranchCache:
:param str project_name:
The project for which the branches are returned.
:param bool exclude_unprotected:
Whether to return all or only protected branches.
:param bool required_flags:
The branch flags we must have completed queries for in order
for the cache to be considered valid.
:param int min_ltime:
The minimum cache ltime to consider the cache valid.
:param any default:
@ -179,37 +400,33 @@ class BranchCache:
an error when fetching the branches.
"""
if self.ltime < min_ltime:
with (locked(self.rlock),
with (zk_locked(self.rlock),
self.zk_context as ctx):
self.cache.refresh(ctx)
protected_branches = None
project_info = None
try:
protected_branches = self.cache.protected[project_name]
project_info = self.cache.projects[project_name]
except KeyError:
if exclude_unprotected:
if default is RAISE_EXCEPTION:
raise LookupError(
f"No branches for project {project_name}")
else:
return default
return return_default(default, project_name)
if not exclude_unprotected:
try:
remainder_branches = self.cache.remainder[project_name]
except KeyError:
if default is RAISE_EXCEPTION:
raise LookupError(
f"No branches for project {project_name}")
else:
return default
# We've definitely stored a failure, so return that.
if project_info is None:
return None
if remainder_branches is not None:
return (protected_branches or []) + remainder_branches
# Determine if we have enough info to answer the question
if not (required_flags.issubset(project_info.completed_flags)):
# We don't have the data, either because we haven't
# queried it or the query failed. Figure out which.
if (required_flags & project_info.failed_flags):
return None
return return_default(default, project_name)
return protected_branches
# We have the necessary info for this filtering.
return list(project_info.branches.values())
def setProjectBranches(self, project_name, exclude_unprotected, branches):
def setProjectBranches(self, project_name,
valid_flags, branch_infos):
"""Set the branch names for the given project.
Use None as a sentinel value for the branches to indicate that
@ -217,30 +434,53 @@ class BranchCache:
:param str project_name:
The project for the branches.
:param bool exclude_unprotected:
Whether this is a list of all or only protected branches.
:param set(int) queries:
The queries this list of branches is able to satisfy.
:param list[str] branches:
The list of branches or None to indicate a fetch error.
"""
with (locked(self.wlock),
with (zk_locked(self.wlock),
self.zk_context as ctx,
self.cache.activeContext(ctx)):
if exclude_unprotected:
self.cache.protected[project_name] = branches
remainder_branches = self.cache.remainder.get(project_name)
if remainder_branches and branches:
remainder = list(set(remainder_branches) -
set(branches))
self.cache.remainder[project_name] = remainder
else:
protected_branches = self.cache.protected.get(project_name)
if protected_branches and branches:
remainder = list(set(branches) -
set(protected_branches))
project_info = self.cache.projects.get(project_name)
if project_info is None:
project_info = ProjectInfo(project_name)
self.cache.projects[project_name] = project_info
if branch_infos is None:
# We're storing an error, set the bits accordingly
for flag in valid_flags:
project_info.failed_flags.add(flag)
project_info.completed_flags.discard(flag)
return
# Set the bits indicating a good query.
for flag in valid_flags:
project_info.failed_flags.discard(flag)
project_info.completed_flags.add(flag)
# Add or update branch info
for branch_info in branch_infos:
existing = project_info.branches.get(branch_info.name)
if existing:
existing.update(branch_info)
else:
remainder = branches
self.cache.remainder[project_name] = remainder
project_info.branches[branch_info.name] = branch_info
# Delete any existing branches which we would expect to be
# in the results but aren't. At the time of writing, this
# isn't strictly necessary beacuse we clear the branch
# cache on branch deletion, but this may enable us to
# change that in the future.
valid_branches = set([bi.name for bi in branch_infos])
for branch_name in list(project_info.branches.keys()):
if branch_name in valid_branches:
continue
branch_info = project_info.branches[branch_name]
if branch_info.valid_flags.issubset(valid_flags):
del project_info.branches[branch_name]
def setProtected(self, project_name, branch, protected):
"""Correct the protection state of a branch.
@ -249,34 +489,21 @@ class BranchCache:
receiving an explicit event.
"""
with (locked(self.wlock),
with (zk_locked(self.wlock),
self.zk_context as ctx,
self.cache.activeContext(ctx)):
protected_branches = self.cache.protected.get(project_name)
remainder_branches = self.cache.remainder.get(project_name)
if protected:
if protected_branches is None:
# We've never run a protected query, so we
# should ignore this branch.
return
else:
# We have run a protected query; if we have
# also run an unprotected query, we need to
# move the branch from remainder to protected.
if remainder_branches and branch in remainder_branches:
remainder_branches.remove(branch)
if branch not in protected_branches:
protected_branches.append(branch)
else:
if protected_branches and branch in protected_branches:
protected_branches.remove(branch)
if remainder_branches is None:
# We've never run an unprotected query, so we
# should ignore this branch.
return
else:
if branch not in remainder_branches:
remainder_branches.append(branch)
project_info = self.cache.projects.get(project_name)
if project_info is None:
project_info = ProjectInfo(project_name)
self.cache.projects[project_name] = project_info
branch_info = project_info.branches.get(branch)
if branch_info is None:
branch_info = BranchInfo(branch)
project_info.branches[branch] = branch_info
branch_info.protected = protected
def getProjectMergeModes(self, project_name,
min_ltime=-1, default=RAISE_EXCEPTION):
@ -308,20 +535,19 @@ class BranchCache:
an error when fetching the merge modes.
"""
if self.ltime < min_ltime:
with locked(self.rlock):
with zk_locked(self.rlock):
self.cache.refresh(self.zk_context)
merge_modes = None
project_info = None
try:
merge_modes = self.cache.merge_modes[project_name]
project_info = self.cache.projects[project_name]
except KeyError:
if default is RAISE_EXCEPTION:
raise LookupError(
f"No merge modes for project {project_name}")
else:
return default
return return_default(default, project_name)
return merge_modes
if project_info is None:
return None
return project_info.merge_modes
def setProjectMergeModes(self, project_name, merge_modes):
"""Set the supported merge modes for the given project.
@ -336,9 +562,12 @@ class BranchCache:
"""
with locked(self.wlock):
with zk_locked(self.wlock):
with self.cache.activeContext(self.zk_context):
self.cache.merge_modes[project_name] = merge_modes
project_info = self.cache.projects.get(project_name)
if project_info is None:
project_info = ProjectInfo(project_name)
project_info.merge_modes = merge_modes
def getProjectDefaultBranch(self, project_name,
min_ltime=-1, default=RAISE_EXCEPTION):
@ -371,20 +600,19 @@ class BranchCache:
"""
if self.ltime < min_ltime:
with locked(self.rlock):
with zk_locked(self.rlock):
self.cache.refresh(self.zk_context)
default_branch = None
project_info = None
try:
default_branch = self.cache.default_branch[project_name]
project_info = self.cache.projects[project_name]
except KeyError:
if default is RAISE_EXCEPTION:
raise LookupError(
f"No default branch for project {project_name}")
else:
return default
return return_default(default, project_name)
return default_branch
if project_info is None:
return None
return project_info.default_branch
def setProjectDefaultBranch(self, project_name, default_branch):
"""Set the upstream default branch for the given project.
@ -399,9 +627,12 @@ class BranchCache:
"""
with locked(self.wlock):
with zk_locked(self.wlock):
with self.cache.activeContext(self.zk_context):
self.cache.default_branch[project_name] = default_branch
project_info = self.cache.projects.get(project_name)
if project_info is None:
project_info = ProjectInfo(project_name)
project_info.default_branch = default_branch
@property
def ltime(self):