Add a github graphql query for branch protection

The only way to find out if a branch is locked is by using a graphql
query for branch protection rules.  This change adds a such a query.
Because neither our graphql client nor our fake graphql server fully
implemented pagination, and we expect to routinely have more than 100
branches, this change also handles pagination.

And since this in graphql, it handles nested pagination.

The word "Fake" is removed from the fake graphql server classes because
graphene introspects them to determine their object type names.  Using
their real names makes the object types match which is now important
since one of our new queries specifies an object type.

The github connection class is updated to use this query for protected
branches.  Locking support will come in a future change.

Change-Id: I2633332396b79280984f0ebfa64a955d24fb7bae
This commit is contained in:
James E. Blair 2024-02-29 14:25:25 -08:00
parent 2298392446
commit 69b09211f6
7 changed files with 260 additions and 68 deletions

View File

@ -12,42 +12,58 @@
# License for the specific language governing permissions and limitations
# under the License.
from graphene import Boolean, Field, Int, List, ObjectType, String
from graphene import Boolean, Field, Int, List, ObjectType, String, Schema
class FakePageInfo(ObjectType):
class ID(String):
"""Github global object ids are strings"""
pass
class PageInfo(ObjectType):
end_cursor = String()
has_previous_page = Boolean()
has_next_page = Boolean()
def resolve_end_cursor(parent, info):
return 'testcursor'
return str(parent['after'] + parent['first'])
def resolve_has_previous_page(parent, info):
return parent['after'] > 0
def resolve_has_next_page(parent, info):
return False
return parent['after'] + parent['first'] < parent['length']
class FakeMatchingRef(ObjectType):
class MatchingRef(ObjectType):
name = String()
def resolve_name(parent, info):
return parent
class FakeMatchingRefs(ObjectType):
nodes = List(FakeMatchingRef)
class MatchingRefs(ObjectType):
pageInfo = Field(PageInfo)
nodes = List(MatchingRef)
def resolve_nodes(parent, info):
# To simplify tests just return the pattern and a bogus ref that should
# not disturb zuul.
return [parent.pattern, 'bogus-ref']
return parent['nodes']
def resolve_pageInfo(parent, info):
return parent
class FakeBranchProtectionRule(ObjectType):
class BranchProtectionRule(ObjectType):
id = ID()
pattern = String()
requiredStatusCheckContexts = List(String)
requiresApprovingReviews = Boolean()
requiresCodeOwnerReviews = Boolean()
matchingRefs = Field(FakeMatchingRefs, first=Int())
matchingRefs = Field(MatchingRefs, first=Int(), after=String())
lockBranch = Boolean()
def resolve_id(parent, info):
return parent.id
def resolve_pattern(parent, info):
return parent.pattern
@ -61,25 +77,41 @@ class FakeBranchProtectionRule(ObjectType):
def resolve_requiresCodeOwnerReviews(parent, info):
return parent.require_codeowners_review
def resolve_matchingRefs(parent, info, first=None):
def resolve_lockBranch(parent, info):
return parent.lock_branch
def resolve_matchingRefs(parent, info, first, after=None):
if after is None:
after = '0'
after = int(after)
values = parent.matching_refs
return dict(
length=len(values),
nodes=values[after:after + first],
first=first,
after=after,
)
class BranchProtectionRules(ObjectType):
pageInfo = Field(PageInfo)
nodes = List(BranchProtectionRule)
def resolve_nodes(parent, info):
return parent['nodes']
def resolve_pageInfo(parent, info):
return parent
class FakeBranchProtectionRules(ObjectType):
nodes = List(FakeBranchProtectionRule)
def resolve_nodes(parent, info):
return parent.values()
class FakeActor(ObjectType):
class Actor(ObjectType):
login = String()
class FakeStatusContext(ObjectType):
class StatusContext(ObjectType):
state = String()
context = String()
creator = Field(FakeActor)
creator = Field(Actor)
def resolve_state(parent, info):
state = parent.state.upper()
@ -92,14 +124,14 @@ class FakeStatusContext(ObjectType):
return parent.creator
class FakeStatus(ObjectType):
contexts = List(FakeStatusContext)
class Status(ObjectType):
contexts = List(StatusContext)
def resolve_contexts(parent, info):
return parent
class FakeCheckRun(ObjectType):
class CheckRun(ObjectType):
name = String()
conclusion = String()
@ -112,21 +144,21 @@ class FakeCheckRun(ObjectType):
return None
class FakeCheckRuns(ObjectType):
nodes = List(FakeCheckRun)
class CheckRuns(ObjectType):
nodes = List(CheckRun)
def resolve_nodes(parent, info):
return parent
class FakeApp(ObjectType):
class App(ObjectType):
slug = String()
name = String()
class FakeCheckSuite(ObjectType):
app = Field(FakeApp)
checkRuns = Field(FakeCheckRuns, first=Int())
class CheckSuite(ObjectType):
app = Field(App)
checkRuns = Field(CheckRuns, first=Int())
def resolve_app(parent, info):
if not parent:
@ -143,9 +175,9 @@ class FakeCheckSuite(ObjectType):
return check_runs_by_name.values()
class FakeCheckSuites(ObjectType):
class CheckSuites(ObjectType):
nodes = List(FakeCheckSuite)
nodes = List(CheckSuite)
def resolve_nodes(parent, info):
# Note: we only use a single check suite in the tests so return a
@ -153,15 +185,15 @@ class FakeCheckSuites(ObjectType):
return [parent]
class FakeCommit(ObjectType):
class Commit(ObjectType):
class Meta:
# Graphql object type that defaults to the class name, but we require
# 'Commit'.
name = 'Commit'
status = Field(FakeStatus)
checkSuites = Field(FakeCheckSuites, first=Int())
status = Field(Status)
checkSuites = Field(CheckSuites, first=Int())
def resolve_status(parent, info):
seen = set()
@ -178,7 +210,7 @@ class FakeCommit(ObjectType):
return parent._check_runs
class FakePullRequest(ObjectType):
class PullRequest(ObjectType):
isDraft = Boolean()
reviewDecision = String()
mergeable = String()
@ -210,18 +242,28 @@ class FakePullRequest(ObjectType):
return 'REVIEW_REQUIRED'
class FakeRepository(ObjectType):
class Repository(ObjectType):
name = String()
branchProtectionRules = Field(FakeBranchProtectionRules, first=Int())
pullRequest = Field(FakePullRequest, number=Int(required=True))
object = Field(FakeCommit, expression=String(required=True))
branchProtectionRules = Field(BranchProtectionRules,
first=Int(), after=String())
pullRequest = Field(PullRequest, number=Int(required=True))
object = Field(Commit, expression=String(required=True))
def resolve_name(parent, info):
org, name = parent.name.split('/')
return name
def resolve_branchProtectionRules(parent, info, first):
return parent._branch_protection_rules
def resolve_branchProtectionRules(parent, info, first, after=None):
if after is None:
after = '0'
after = int(after)
values = list(parent._branch_protection_rules.values())
return dict(
length=len(values),
nodes=values[after:after + first],
first=first,
after=after,
)
def resolve_pullRequest(parent, info, number):
return parent.data.pull_requests.get(number)
@ -231,8 +273,19 @@ class FakeRepository(ObjectType):
class FakeGithubQuery(ObjectType):
repository = Field(FakeRepository, owner=String(required=True),
repository = Field(Repository, owner=String(required=True),
name=String(required=True))
node = Field(BranchProtectionRule, id=ID(required=True))
def resolve_repository(root, info, owner, name):
return info.context._data.repos.get((owner, name))
def resolve_node(root, info, id):
for repo in info.context._data.repos.values():
for rule in repo._branch_protection_rules.values():
if rule.id == id:
return rule
def getGrapheneSchema():
return Schema(query=FakeGithubQuery, types=[ID])

View File

@ -27,14 +27,13 @@ import uuid
import string
import random
from tests.fake_graphql import FakeGithubQuery
from tests.fake_graphql import getGrapheneSchema
import zuul.driver.github.githubconnection as githubconnection
from zuul.driver.github.githubconnection import utc, GithubClientManager
from tests.util import random_sha1
import git
import github3.exceptions
import graphene
import requests
from requests.structures import CaseInsensitiveDict
import requests_mock
@ -684,16 +683,21 @@ class FakeRepository(object):
return self._branches
def _set_branch_protection(self, branch_name, protected=True,
contexts=None, require_review=False):
contexts=None, require_review=False,
locked=False):
if not protected:
if branch_name in self._branch_protection_rules:
del self._branch_protection_rules[branch_name]
return
rule = self._branch_protection_rules[branch_name]
rule.id = str(uuid.uuid4())
rule.pattern = branch_name
rule.required_contexts = contexts or []
rule.require_reviews = require_review
rule.matching_refs = [branch_name]
rule.lock_branch = locked
return rule
def _set_permission(self, key, value):
# NOTE (felix): Currently, this is only used to mock a repo with
@ -1029,7 +1033,7 @@ class FakeGithubSession(object):
self.client = client
self.headers = CaseInsensitiveDict()
self._base_url = None
self.schema = graphene.Schema(query=FakeGithubQuery)
self.schema = getGrapheneSchema()
# Imitate hooks dict. This will be unused and ignored in the tests.
self.hooks = {

View File

@ -2829,3 +2829,30 @@ class TestGithubSchemaWarnings(ZuulTestCase):
self.assertIn(
"Use 'rerequested' instead",
str(tenant.layout.loading_errors[8].error))
class TestGithubGraphQL(ZuulTestCase):
config_file = 'zuul-github-driver.conf'
scheduler_count = 1
@simple_layout('layouts/basic-github.yaml', driver='github')
def test_graphql_query_branch_protection(self):
project = self.fake_github.getProject('org/project')
github = self.fake_github.getGithubClient(project.name)
repo = github.repo_from_project('org/project')
# Ensure that both parts of the query hit pagination by having
# more than 100 results for each.
num_rules = 110
num_extra_branches = 104
for branch_no in range(num_rules):
rule = repo._set_branch_protection(f'branch{branch_no}',
protected=True,
locked=True)
# Add more fake matching refs to the rule here
for suffix in range(num_extra_branches):
rule.matching_refs.append(f'branch{branch_no}_{suffix}')
branches = list(
self.fake_github.graphql_client.fetch_branch_protection(
github, project))
self.assertEqual(num_rules * (1 + num_extra_branches), len(branches))

View File

@ -1859,8 +1859,9 @@ class GithubConnection(ZKChangeCacheMixin, ZKBranchCacheMixin, BaseConnection):
branch_infos = {}
if BranchFlag.PROTECTED in required_flags:
valid_flags.add(BranchFlag.PROTECTED)
for branch_name in self._fetchProjectBranchesREST(
github, project, protected=True):
for branch_name, locked in \
self.graphql_client.fetch_branch_protection(
github, project).items():
bi = branch_infos.setdefault(
branch_name, BranchInfo(branch_name))
bi.protected = True
@ -1883,7 +1884,7 @@ class GithubConnection(ZKChangeCacheMixin, ZKBranchCacheMixin, BaseConnection):
if protected:
params['protected'] = 1
branches = []
branches = set()
while url:
resp = github.session.get(
url, headers=headers, params=params)
@ -1902,8 +1903,8 @@ class GithubConnection(ZKChangeCacheMixin, ZKBranchCacheMixin, BaseConnection):
raise Exception("Got status code 404 when listing branches "
"of project %s" % project.name)
branches.extend([x['name'] for x in resp.json()])
for x in resp.json():
branches.add(x['name'])
return branches
def _fetchProjectMergeModes(self, project):

View File

@ -1,4 +1,5 @@
# Copyright 2020 BMW Group
# Copyright 2024 Acme Gating, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@ -38,6 +39,8 @@ class GraphQLClient:
query_names = [
'canmerge',
'canmerge-legacy',
'branch-protection',
'branch-protection-inner',
]
for query_name in query_names:
f = importlib.resources.files('zuul').joinpath(
@ -52,31 +55,37 @@ class GraphQLClient:
}
return data
def _fetch_canmerge(self, github, owner, repo, pull, sha):
variables = {
'zuul_query': 'canmerge', # used for logging
'owner': owner,
'repo': repo,
'pull': pull,
'head_sha': sha,
}
def _run_query(self, log, github, query_name, **args):
args['zuul_query'] = query_name # used for logging
query = self.queries[query_name]
query = self._prepare_query(query, args)
response = github.session.post(self.url, json=query)
response = response.json()
if 'data' not in response:
log.error("Error running query %s: %s",
query_name, response)
return response
def _fetch_canmerge(self, log, github, owner, repo, pull, sha):
if github.version and github.version[:2] < (2, 21):
# Github Enterprise prior to 2.21 doesn't offer the review decision
# so don't request it as this will result in an error.
query = self.queries['canmerge-legacy']
query = 'canmerge-legacy'
else:
# Since GitHub Enterprise 2.21 and on github.com we can request the
# review decision state of the pull request.
query = self.queries['canmerge']
query = self._prepare_query(query, variables)
response = github.session.post(self.url, json=query)
return response.json()
query = 'canmerge'
return self._run_query(log, github, query,
owner=owner,
repo=repo,
pull=pull,
head_sha=sha)
def fetch_canmerge(self, github, change, zuul_event_id=None):
log = get_annotated_logger(self.log, zuul_event_id)
owner, repo = change.project.name.split('/')
data = self._fetch_canmerge(github, owner, repo, change.number,
data = self._fetch_canmerge(log, github, owner, repo, change.number,
change.patchset)
result = {}
@ -146,3 +155,53 @@ class GraphQLClient:
}
return result
def _fetch_branch_protection(self, log, github, project,
zuul_event_id=None):
owner, repo = project.name.split('/')
branches = {}
branch_subqueries = []
cursor = None
while True:
data = self._run_query(
log, github, 'branch-protection',
owner=owner,
repo=repo,
cursor=cursor)['data']
for rule in data['repository']['branchProtectionRules']['nodes']:
for branch in rule['matchingRefs']['nodes']:
branches[branch['name']] = rule['lockBranch']
refs_pageinfo = rule['matchingRefs']['pageInfo']
if refs_pageinfo['hasNextPage']:
branch_subqueries.append(dict(
rule_node_id=rule['id'],
cursor=refs_pageinfo['endCursor']))
rules_pageinfo = data['repository']['branchProtectionRules'
]['pageInfo']
if not rules_pageinfo['hasNextPage']:
break
cursor = rules_pageinfo['endCursor']
for subquery in branch_subqueries:
cursor = subquery['cursor']
while True:
data = self._run_query(
log, github, 'branch-protection-inner',
rule_node_id=subquery['rule_node_id'],
cursor=cursor)['data']
for branch in data['node']['matchingRefs']['nodes']:
branches[branch['name']] = rule['lockBranch']
refs_pageinfo = data['node']['matchingRefs']['pageInfo']
if not refs_pageinfo['hasNextPage']:
break
cursor = refs_pageinfo['endCursor']
return branches
def fetch_branch_protection(self, github, project, zuul_event_id=None):
"""Return a dictionary of branches and whether they are locked"""
log = get_annotated_logger(self.log, zuul_event_id)
return self._fetch_branch_protection(log, github, project,
zuul_event_id)

View File

@ -0,0 +1,19 @@
query ruleProtectedBranches(
$rule_node_id: ID!
$cursor: String
) {
node(id: $rule_node_id) {
... on BranchProtectionRule {
matchingRefs(first: 100, after: $cursor) {
pageInfo {
endCursor
hasNextPage
hasPreviousPage
}
nodes {
name
}
}
}
}
}

View File

@ -0,0 +1,29 @@
query repoProtectedBranches(
$owner: String!
$repo: String!
$cursor: String
) {
repository(owner: $owner, name: $repo) {
branchProtectionRules(first: 100, after: $cursor) {
pageInfo {
endCursor
hasNextPage
hasPreviousPage
}
nodes {
id
lockBranch
matchingRefs(first: 100) {
pageInfo {
endCursor
hasNextPage
hasPreviousPage
}
nodes {
name
}
}
}
}
}
}