Merge "Refactor github auth handling into its own class"

This commit is contained in:
Zuul 2020-07-16 19:02:32 +00:00 committed by Gerrit Code Review
commit 0234d6b015
4 changed files with 236 additions and 187 deletions

View File

@ -63,6 +63,8 @@ import testtools.content_type
from git.exc import NoSuchPathError
import yaml
import paramiko
from zuul.driver.github.githubconnection import GithubClientManager
from zuul.lib.connections import ConnectionRegistry
from psutil import Popen
@ -2144,26 +2146,13 @@ class FakeGithubPullRequest(object):
repo.heads[self.branch].commit = repo.commit(self.head_sha)
class FakeGithubConnection(githubconnection.GithubConnection):
log = logging.getLogger("zuul.test.FakeGithubConnection")
class FakeGithubClientManager(GithubClientManager):
def __init__(self, driver, connection_name, connection_config, rpcclient,
changes_db=None, upstream_root=None, git_url_with_auth=False):
super(FakeGithubConnection, self).__init__(driver, connection_name,
connection_config)
self.connection_name = connection_name
self.pr_number = 0
self.pull_requests = changes_db
self.statuses = {}
self.upstream_root = upstream_root
self.merge_failure = False
self.merge_not_allowed_count = 0
self.reports = []
self.github_data = tests.fakegithub.FakeGithubData(changes_db)
self.recorded_clients = []
self.git_url_with_auth = git_url_with_auth
self.rpcclient = rpcclient
def __init__(self, connection_config):
    """Build the fake client manager on top of the real one.

    Client recording starts switched off with an empty record list, and
    no fake GitHub data is attached until the owner assigns
    ``github_data``.
    """
    super().__init__(connection_config)
    # Fake backing data is injected after construction.
    self.github_data = None
    # Recording is opt-in; the list is only a container at this point.
    self.recorded_clients = []
    self.record_clients = False
def getGithubClient(self,
project=None,
@ -2196,6 +2185,30 @@ class FakeGithubConnection(githubconnection.GithubConnection):
orgs[repo[0]] = inst_id
self.installation_map['/'.join(repo)] = inst_id
class FakeGithubConnection(githubconnection.GithubConnection):
log = logging.getLogger("zuul.test.FakeGithubConnection")
client_manager_class = FakeGithubClientManager
def __init__(self, driver, connection_name, connection_config, rpcclient,
             changes_db=None, upstream_root=None, git_url_with_auth=False):
    """Create a fake GitHub connection for tests.

    ``changes_db`` is kept as ``pull_requests`` and is also wrapped in a
    ``FakeGithubData`` object that is shared with the client manager.
    The remaining arguments are stored on the instance unchanged.
    """
    super().__init__(driver, connection_name, connection_config)

    # Constructor arguments kept verbatim.
    self.connection_name = connection_name
    self.rpcclient = rpcclient
    self.upstream_root = upstream_root
    self.git_url_with_auth = git_url_with_auth
    self.pull_requests = changes_db

    # Mutable test state, all starting empty / zero.
    self.pr_number = 0
    self.statuses = {}
    self.reports = []
    self.merge_failure = False
    self.merge_not_allowed_count = 0

    # The connection and its client manager share one fake data object.
    fake_data = tests.fakegithub.FakeGithubData(changes_db)
    self.github_data = fake_data
    self._github_client_manager.github_data = self.github_data
def setZuulWebPort(self, port):
    """Store ``port`` on this connection as ``zuul_web_port``."""
    self.zuul_web_port = port

View File

@ -173,8 +173,9 @@ class TestZuulTriggerParentChangeEnqueuedGithub(ZuulGithubAppTestCase):
# After starting recording installation containing org2/project
# should not be contacted
inst_id_to_check = self.fake_github.installation_map['org2/project']
inst_clients = [x for x in self.fake_github.recorded_clients
gh_manager = self.fake_github._github_client_manager
inst_id_to_check = gh_manager.installation_map['org2/project']
inst_clients = [x for x in gh_manager.recorded_clients
if x._inst_id == inst_id_to_check]
self.assertEqual(len(inst_clients), 0)

View File

@ -763,38 +763,12 @@ class GithubUser(collections.Mapping):
}
class GithubConnection(BaseConnection):
driver_name = 'github'
log = logging.getLogger("zuul.GithubConnection")
payload_path = 'payload'
class GithubClientManager:
log = logging.getLogger('zuul.GithubConnection.GithubClientManager')
def __init__(self, driver, connection_name, connection_config):
super(GithubConnection, self).__init__(
driver, connection_name, connection_config)
self._change_cache = {}
self._change_update_lock = {}
self._installation_map_lock = threading.Lock()
self._project_branch_cache_include_unprotected = {}
self._project_branch_cache_exclude_unprotected = {}
self.projects = {}
self.git_ssh_key = self.connection_config.get('sshkey')
def __init__(self, connection_config):
self.connection_config = connection_config
self.server = self.connection_config.get('server', 'github.com')
self.canonical_hostname = self.connection_config.get(
'canonical_hostname', self.server)
self.source = driver.getSource(self)
self.event_queue = queue.Queue()
self._sha_pr_cache = GithubShaCache()
self._request_locks = {}
self.max_threads_per_installation = int(self.connection_config.get(
'max_threads_per_installation', 1))
# Logging of rate limit is optional as this does additional requests
rate_limit_logging = self.connection_config.get(
'rate_limit_logging', 'true')
self._log_rate_limit = True
if rate_limit_logging.lower() == 'false':
self._log_rate_limit = False
if self.server == 'github.com':
self.api_base_url = GITHUB_BASE_URL
@ -809,13 +783,6 @@ class GithubConnection(BaseConnection):
if verify_ssl.lower() == 'false':
self.verify_ssl = False
self.app_id = None
self.app_key = None
self.sched = None
self.installation_map = {}
self.installation_token_cache = {}
# NOTE(jamielennox): Better here would be to cache to memcache or file
# or something external - but zuul already sucks at restarting so in
# memory probably doesn't make this much worse.
@ -840,49 +807,61 @@ class GithubConnection(BaseConnection):
cache_etags=True,
heuristic=NoAgeHeuristic())
# The regex is based on the connection host. We do not yet support
# cross-connection dependency gathering
self.depends_on_re = re.compile(
r"^Depends-On: https://%s/.+/.+/pull/[0-9]+$" % self.server,
re.MULTILINE | re.IGNORECASE)
# Logging of rate limit is optional as this does additional requests
rate_limit_logging = self.connection_config.get(
'rate_limit_logging', 'true')
self._log_rate_limit = True
if rate_limit_logging.lower() == 'false':
self._log_rate_limit = False
self.graphql_client = GraphQLClient('%s/graphql' % self.api_base_url)
self.app_id = None
self.app_key = None
self._initialized = False
def toDict(self):
d = super().toDict()
d.update({
"baseurl": self.base_url,
"canonical_hostname": self.canonical_hostname,
"server": self.server,
})
return d
self._installation_map_lock = threading.Lock()
self.installation_map = {}
self.installation_token_cache = {}
def onLoad(self):
self.log.info('Starting GitHub connection: %s' % self.connection_name)
self.gearman_worker = GithubGearmanWorker(self)
def initialize(self):
self.log.info('Authing to GitHub')
self._authenticateGithubAPI()
self._prime_installation_map()
self.log.info('Starting event connector')
self._start_event_connector()
self.log.info('Starting GearmanWorker')
self.gearman_worker.start()
self._initialized = True
def onStop(self):
# TODO(jeblair): remove this check which is here only so that
# zuul-web can call connections.stop to shut down the sql
# connection.
if hasattr(self, 'gearman_worker'):
self.gearman_worker.stop()
self._stop_event_connector()
@property
def initialized(self):
    """Expose the ``_initialized`` flag, which ``initialize()`` sets to True."""
    return self._initialized
def _start_event_connector(self):
self.github_event_connector = GithubEventConnector(self)
self.github_event_connector.start()
@property
def usesAppAuthentication(self):
    """Return True when ``app_id`` is set to a truthy value, else False."""
    return bool(self.app_id)
def _stop_event_connector(self):
if self.github_event_connector:
self.github_event_connector.stop()
def _authenticateGithubAPI(self):
config = self.connection_config
app_id = config.get('app_id')
app_key = None
app_key_file = config.get('app_key')
if app_key_file:
try:
with open(app_key_file, 'r') as f:
app_key = f.read()
except IOError:
m = "Failed to open app key file for reading: %s"
self.log.error(m, app_key_file)
if (app_id or app_key) and \
not (app_id and app_key):
self.log.warning("You must provide an app_id and "
"app_key to use installation based "
"authentication")
return
if app_id:
self.app_id = int(app_id)
if app_key:
self.app_key = app_key
def _createGithubClient(self, zuul_event_id=None):
session = github3.session.GitHubSession(default_read_timeout=300)
@ -921,43 +900,6 @@ class GithubConnection(BaseConnection):
github._zuul_user_id = None
return github
def _authenticateGithubAPI(self):
config = self.connection_config
app_id = config.get('app_id')
app_key = None
app_key_file = config.get('app_key')
if app_key_file:
try:
with open(app_key_file, 'r') as f:
app_key = f.read()
except IOError:
m = "Failed to open app key file for reading: %s"
self.log.error(m, app_key_file)
if (app_id or app_key) and \
not (app_id and app_key):
self.log.warning("You must provide an app_id and "
"app_key to use installation based "
"authentication")
return
if app_id:
self.app_id = int(app_id)
if app_key:
self.app_key = app_key
@staticmethod
def _append_accept_header(github, value):
old_header = github.session.headers.get('Accept', None)
if old_header:
new_value = '%s,%s' % (old_header, value)
else:
new_value = value
github.session.headers['Accept'] = new_value
def _get_app_auth_headers(self):
now = datetime.datetime.now(utc)
expiry = now + datetime.timedelta(minutes=5)
@ -972,8 +914,8 @@ class GithubConnection(BaseConnection):
return headers
def _get_installation_key(self, project, inst_id=None,
reprime=False):
def get_installation_key(self, project, inst_id=None,
reprime=False):
installation_id = inst_id
if project is not None:
installation_id = self.installation_map.get(project)
@ -982,9 +924,9 @@ class GithubConnection(BaseConnection):
if reprime:
# prime installation map and try again without refreshing
self._prime_installation_map()
return self._get_installation_key(project,
inst_id=inst_id,
reprime=False)
return self.get_installation_key(project,
inst_id=inst_id,
reprime=False)
self.log.error("No installation ID available for project %s",
project)
@ -1061,7 +1003,7 @@ class GithubConnection(BaseConnection):
for install in installations:
inst_id = install.get('id')
token_by_inst[inst_id] = executor.submit(
self._get_installation_key, project=None,
self.get_installation_key, project=None,
inst_id=inst_id)
for inst_id, result in token_by_inst.items():
@ -1091,22 +1033,20 @@ class GithubConnection(BaseConnection):
with self._installation_map_lock:
self.log.debug('Finished waiting for fetching installations')
def get_request_lock(self, installation_id):
return self._request_locks.setdefault(
installation_id, threading.Semaphore(
value=self.max_threads_per_installation))
def getGithubClientsForProjects(self, projects):
# Get a list of projects with unique installation ids
installation_ids = set()
installation_projects = set()
def addEvent(self, data, event=None, delivery=None):
return self.event_queue.put((time.time(), data, event, delivery))
for project in projects:
installation_id = self.installation_map.get(project.name)
if installation_id not in installation_ids:
installation_ids.add(installation_id)
installation_projects.add(project.name)
def getEvent(self):
return self.event_queue.get()
def getEventQueueSize(self):
return self.event_queue.qsize()
def eventDone(self):
self.event_queue.task_done()
clients = [self.getGithubClient(project_name)
for project_name in installation_projects]
return clients
def getGithubClient(self,
project=None,
@ -1118,7 +1058,7 @@ class GithubConnection(BaseConnection):
if project and self.app_id:
# Call get_installation_key to ensure the token gets refreshed in
# case it's expired.
token = self._get_installation_key(project)
token = self.get_installation_key(project)
# Only set the auth header if we have a token. If not, just don't
# set any auth header so we will be treated as anonymous. That's
@ -1157,6 +1097,113 @@ class GithubConnection(BaseConnection):
return github
class GithubConnection(BaseConnection):
driver_name = 'github'
log = logging.getLogger("zuul.GithubConnection")
payload_path = 'payload'
client_manager_class = GithubClientManager
# Build the connection: per-connection caches, settings read from
# connection_config, the event queue, the client manager (an instance of
# client_manager_class) and the GraphQL client.
def __init__(self, driver, connection_name, connection_config):
super(GithubConnection, self).__init__(
driver, connection_name, connection_config)
# Caches and bookkeeping, all starting empty.
self._change_cache = {}
self._change_update_lock = {}
self._project_branch_cache_include_unprotected = {}
self._project_branch_cache_exclude_unprotected = {}
self.projects = {}
self.git_ssh_key = self.connection_config.get('sshkey')
self.server = self.connection_config.get('server', 'github.com')
# canonical_hostname falls back to the server name when not configured.
self.canonical_hostname = self.connection_config.get(
'canonical_hostname', self.server)
# NOTE(review): getSource() receives this object while it is still being
# constructed; attributes assigned below are not yet set at this point.
self.source = driver.getSource(self)
self.event_queue = queue.Queue()
self._sha_pr_cache = GithubShaCache()
# Filled on demand by get_request_lock().
self._request_locks = {}
self.max_threads_per_installation = int(self.connection_config.get(
'max_threads_per_installation', 1))
# Client creation and authentication are delegated to this manager.
self._github_client_manager = self.client_manager_class(
self.connection_config)
self.sched = None
# The regex is based on the connection host. We do not yet support
# cross-connection dependency gathering
# NOTE(review): self.server is interpolated unescaped, so a '.' in the
# hostname matches any character; consider re.escape().
self.depends_on_re = re.compile(
r"^Depends-On: https://%s/.+/.+/pull/[0-9]+$" % self.server,
re.MULTILINE | re.IGNORECASE)
# The GraphQL endpoint is derived from the manager's api_base_url.
self.graphql_client = GraphQLClient(
'%s/graphql' % self._github_client_manager.api_base_url)
def toDict(self):
    """Return the base dict extended with GitHub connection details.

    Adds ``baseurl`` (taken from the client manager),
    ``canonical_hostname`` and ``server``.
    """
    github_details = {
        "baseurl": self._github_client_manager.base_url,
        "canonical_hostname": self.canonical_hostname,
        "server": self.server,
    }
    result = super().toDict()
    result.update(github_details)
    return result
# Start the connection. In order: create the gearman worker, initialize
# the client manager, start the event connector, then start the gearman
# worker.
def onLoad(self):
self.log.info('Starting GitHub connection: %s' % self.connection_name)
# The worker is created here but only started at the end.
self.gearman_worker = GithubGearmanWorker(self)
# Client manager set-up runs before any events are processed.
self._github_client_manager.initialize()
self.log.info('Starting event connector')
self._start_event_connector()
self.log.info('Starting GearmanWorker')
self.gearman_worker.start()
# Stop the gearman worker (when one exists) and the event connector.
def onStop(self):
# TODO(jeblair): remove this check which is here only so that
# zuul-web can call connections.stop to shut down the sql
# connection.
if hasattr(self, 'gearman_worker'):
self.gearman_worker.stop()
# NOTE(review): indentation is not visible here, so it is unclear whether
# the call below is inside the hasattr() guard -- confirm against the
# real file before relying on either reading.
self._stop_event_connector()
def _start_event_connector(self):
    """Create the event connector for this connection and start it.

    The connector is kept as ``github_event_connector`` so that it can
    be stopped later.
    """
    connector = GithubEventConnector(self)
    self.github_event_connector = connector
    connector.start()
def _stop_event_connector(self):
if self.github_event_connector:
self.github_event_connector.stop()
@staticmethod
def _append_accept_header(github, value):
old_header = github.session.headers.get('Accept', None)
if old_header:
new_value = '%s,%s' % (old_header, value)
else:
new_value = value
github.session.headers['Accept'] = new_value
def get_request_lock(self, installation_id):
    """Return the semaphore kept for ``installation_id``.

    The first call for an id stores a semaphore sized by
    ``max_threads_per_installation``; later calls return that stored
    object.
    """
    # A candidate is built on every call; setdefault() keeps the first one
    # that was stored and discards the rest.
    candidate = threading.Semaphore(
        value=self.max_threads_per_installation)
    return self._request_locks.setdefault(installation_id, candidate)
def addEvent(self, data, event=None, delivery=None):
    """Queue ``(timestamp, data, event, delivery)`` on the event queue.

    The timestamp is ``time.time()`` taken at the moment of the call.
    Returns whatever the queue's ``put`` returns.
    """
    entry = (time.time(), data, event, delivery)
    return self.event_queue.put(entry)
def getEvent(self):
    """Take the next item off the event queue (via the queue's ``get``)."""
    next_item = self.event_queue.get()
    return next_item
def getEventQueueSize(self):
    """Return the event queue's current ``qsize()``."""
    pending = self.event_queue.qsize()
    return pending
def eventDone(self):
    """Mark one previously fetched event as processed (``task_done``)."""
    events = self.event_queue
    events.task_done()
def getGithubClient(self,
                    project=None,
                    zuul_event_id=None):
    """Return a GitHub client obtained from the client manager.

    Pure delegation: both arguments are forwarded by keyword to the
    manager's ``getGithubClient``.
    """
    manager = self._github_client_manager
    return manager.getGithubClient(
        project=project, zuul_event_id=zuul_event_id)
def maintainCache(self, relevant):
remove = set()
for key, change in self._change_cache.items():
@ -1249,37 +1296,17 @@ class GithubConnection(BaseConnection):
if not change.uris:
return changes
# Get a list of projects with unique installation ids
installation_ids = set()
installation_projects = set()
if projects:
# We only need to find changes in projects in the supplied
# ChangeQueue. Find all of the github installations for
# all of those projects, and search using each of them, so
# that if we get the right results based on the
# permissions granted to each of the installations. The
# common case for this is likely to be just one
# installation -- change queues aren't likely to span more
# than one installation.
for project in projects:
installation_id = self.installation_map.get(project.name)
if installation_id not in installation_ids:
installation_ids.add(installation_id)
installation_projects.add(project.name)
else:
if not projects:
# We aren't in the context of a change queue and we just
# need to query all installations of this tenant. This currently
# only happens if certain features of the zuul trigger are
# used; generally it should be avoided.
for project_name, installation_id in self.installation_map.items():
trusted, project = tenant.getProject(project_name)
# ignore projects from different tenants
if not project:
continue
if installation_id not in installation_ids:
installation_ids.add(installation_id)
installation_projects.add(project_name)
projects = [p for p in tenant.all_projects
if p.connection_name == self.connection_name]
# Otherwise we use the input projects list and look for changes in the
# supplied projects.
clients = self._github_client_manager.getGithubClientsForProjects(
projects)
keys = set()
# TODO: Max of 5 OR operators can be used per query and
@ -1288,9 +1315,8 @@ class GithubConnection(BaseConnection):
# tests/fakegithub.py
pattern = ' OR '.join(['"Depends-On: %s"' % x for x in change.uris])
query = '%s type:pr is:open in:body' % pattern
# Repeat the search for each installation id (project)
for installation_project in installation_projects:
github = self.getGithubClient(installation_project)
# Repeat the search for each client (project)
for github in clients:
for issue in github.search_issues(query=query):
pr = issue.issue.pull_request().as_dict()
if not pr.get('url'):
@ -1405,17 +1431,17 @@ class GithubConnection(BaseConnection):
# if app_id is configured but self.app_id is empty we are not
# authenticated yet against github as app
if not self.app_id and self.connection_config.get('app_id', None):
self._authenticateGithubAPI()
self._prime_installation_map()
if not self._github_client_manager.initialized:
self._github_client_manager.initialize()
if self.app_id:
if self._github_client_manager.usesAppAuthentication:
# We may be in the context of a merger or executor here. The
# mergers and executors don't receive webhook events so they miss
# new repository installations. In order to cope with this we need
# to reprime the installation map if we don't find the repo there.
installation_key = self._get_installation_key(project.name,
reprime=True)
installation_key = \
self._github_client_manager.get_installation_key(
project.name, reprime=True)
return 'https://x-access-token:%s@%s/%s' % (installation_key,
self.server,
project.name)
@ -1803,7 +1829,7 @@ class GithubConnection(BaseConnection):
def getCommitChecks(self, project_name, sha, zuul_event_id=None):
log = get_annotated_logger(self.log, zuul_event_id)
if not self.app_id:
if not self._github_client_manager.usesAppAuthentication:
log.debug(
"Not authenticated as Github app. Unable to retrieve commit "
"checks for sha %s on %s",
@ -1859,7 +1885,7 @@ class GithubConnection(BaseConnection):
# Track a list of failed check run operations to report back to Github
errors = []
if not self.app_id:
if not self._github_client_manager.usesAppAuthentication:
# We don't try to update check runs, if we aren't authenticated as
# Github app at all. If we are, we still have to ensure that we
# don't crash on missing permissions.

View File

@ -4393,6 +4393,15 @@ class Tenant(object):
self.authorization_rules = []
@property
def all_projects(self):
    """
    Lazily iterate over every project of the tenant.

    ``self.projects`` is walked as a two-level mapping; the values of
    each inner mapping are yielded in iteration order.
    """
    for per_host in self.projects.values():
        yield from per_host.values()
def _addProject(self, tpc):
"""Add a project to the project index