Refactor github auth handling into its own class

The GitHub client and authentication handling are spread over many
methods in the GithubConnection class. In preparation for further
improvements, move them into a dedicated class so that they are
easier to maintain.

Change-Id: Ia935776ae8ccfa8d044a134e1507e4d2e582fbe0
This commit is contained in:
Tobias Henkel 2020-02-26 14:30:50 +01:00
parent 087ab1d587
commit d94f313687
No known key found for this signature in database
GPG Key ID: 03750DEC158E5FA2
4 changed files with 236 additions and 187 deletions

View File

@ -63,6 +63,8 @@ import testtools.content_type
from git.exc import NoSuchPathError from git.exc import NoSuchPathError
import yaml import yaml
import paramiko import paramiko
from zuul.driver.github.githubconnection import GithubClientManager
from zuul.lib.connections import ConnectionRegistry from zuul.lib.connections import ConnectionRegistry
from psutil import Popen from psutil import Popen
@ -2122,26 +2124,13 @@ class FakeGithubPullRequest(object):
repo.heads[self.branch].commit = repo.commit(self.head_sha) repo.heads[self.branch].commit = repo.commit(self.head_sha)
class FakeGithubConnection(githubconnection.GithubConnection): class FakeGithubClientManager(GithubClientManager):
log = logging.getLogger("zuul.test.FakeGithubConnection")
def __init__(self, driver, connection_name, connection_config, rpcclient, def __init__(self, connection_config):
changes_db=None, upstream_root=None, git_url_with_auth=False): super().__init__(connection_config)
super(FakeGithubConnection, self).__init__(driver, connection_name,
connection_config)
self.connection_name = connection_name
self.pr_number = 0
self.pull_requests = changes_db
self.statuses = {}
self.upstream_root = upstream_root
self.merge_failure = False
self.merge_not_allowed_count = 0
self.reports = []
self.github_data = tests.fakegithub.FakeGithubData(changes_db)
self.recorded_clients = []
self.git_url_with_auth = git_url_with_auth
self.rpcclient = rpcclient
self.record_clients = False self.record_clients = False
self.recorded_clients = []
self.github_data = None
def getGithubClient(self, def getGithubClient(self,
project=None, project=None,
@ -2174,6 +2163,30 @@ class FakeGithubConnection(githubconnection.GithubConnection):
orgs[repo[0]] = inst_id orgs[repo[0]] = inst_id
self.installation_map['/'.join(repo)] = inst_id self.installation_map['/'.join(repo)] = inst_id
class FakeGithubConnection(githubconnection.GithubConnection):
log = logging.getLogger("zuul.test.FakeGithubConnection")
client_manager_class = FakeGithubClientManager
def __init__(self, driver, connection_name, connection_config, rpcclient,
changes_db=None, upstream_root=None, git_url_with_auth=False):
super(FakeGithubConnection, self).__init__(driver, connection_name,
connection_config)
self.connection_name = connection_name
self.pr_number = 0
self.pull_requests = changes_db
self.statuses = {}
self.upstream_root = upstream_root
self.merge_failure = False
self.merge_not_allowed_count = 0
self.reports = []
self.github_data = tests.fakegithub.FakeGithubData(changes_db)
self._github_client_manager.github_data = self.github_data
self.git_url_with_auth = git_url_with_auth
self.rpcclient = rpcclient
def setZuulWebPort(self, port): def setZuulWebPort(self, port):
self.zuul_web_port = port self.zuul_web_port = port

View File

@ -173,8 +173,9 @@ class TestZuulTriggerParentChangeEnqueuedGithub(ZuulGithubAppTestCase):
# After starting recording installation containing org2/project # After starting recording installation containing org2/project
# should not be contacted # should not be contacted
inst_id_to_check = self.fake_github.installation_map['org2/project'] gh_manager = self.fake_github._github_client_manager
inst_clients = [x for x in self.fake_github.recorded_clients inst_id_to_check = gh_manager.installation_map['org2/project']
inst_clients = [x for x in gh_manager.recorded_clients
if x._inst_id == inst_id_to_check] if x._inst_id == inst_id_to_check]
self.assertEqual(len(inst_clients), 0) self.assertEqual(len(inst_clients), 0)

View File

@ -763,38 +763,12 @@ class GithubUser(collections.Mapping):
} }
class GithubConnection(BaseConnection): class GithubClientManager:
driver_name = 'github' log = logging.getLogger('zuul.GithubConnection.GithubClientManager')
log = logging.getLogger("zuul.GithubConnection")
payload_path = 'payload'
def __init__(self, driver, connection_name, connection_config): def __init__(self, connection_config):
super(GithubConnection, self).__init__( self.connection_config = connection_config
driver, connection_name, connection_config)
self._change_cache = {}
self._change_update_lock = {}
self._installation_map_lock = threading.Lock()
self._project_branch_cache_include_unprotected = {}
self._project_branch_cache_exclude_unprotected = {}
self.projects = {}
self.git_ssh_key = self.connection_config.get('sshkey')
self.server = self.connection_config.get('server', 'github.com') self.server = self.connection_config.get('server', 'github.com')
self.canonical_hostname = self.connection_config.get(
'canonical_hostname', self.server)
self.source = driver.getSource(self)
self.event_queue = queue.Queue()
self._sha_pr_cache = GithubShaCache()
self._request_locks = {}
self.max_threads_per_installation = int(self.connection_config.get(
'max_threads_per_installation', 1))
# Logging of rate limit is optional as this does additional requests
rate_limit_logging = self.connection_config.get(
'rate_limit_logging', 'true')
self._log_rate_limit = True
if rate_limit_logging.lower() == 'false':
self._log_rate_limit = False
if self.server == 'github.com': if self.server == 'github.com':
self.api_base_url = GITHUB_BASE_URL self.api_base_url = GITHUB_BASE_URL
@ -809,13 +783,6 @@ class GithubConnection(BaseConnection):
if verify_ssl.lower() == 'false': if verify_ssl.lower() == 'false':
self.verify_ssl = False self.verify_ssl = False
self.app_id = None
self.app_key = None
self.sched = None
self.installation_map = {}
self.installation_token_cache = {}
# NOTE(jamielennox): Better here would be to cache to memcache or file # NOTE(jamielennox): Better here would be to cache to memcache or file
# or something external - but zuul already sucks at restarting so in # or something external - but zuul already sucks at restarting so in
# memory probably doesn't make this much worse. # memory probably doesn't make this much worse.
@ -840,49 +807,61 @@ class GithubConnection(BaseConnection):
cache_etags=True, cache_etags=True,
heuristic=NoAgeHeuristic()) heuristic=NoAgeHeuristic())
# The regex is based on the connection host. We do not yet support # Logging of rate limit is optional as this does additional requests
# cross-connection dependency gathering rate_limit_logging = self.connection_config.get(
self.depends_on_re = re.compile( 'rate_limit_logging', 'true')
r"^Depends-On: https://%s/.+/.+/pull/[0-9]+$" % self.server, self._log_rate_limit = True
re.MULTILINE | re.IGNORECASE) if rate_limit_logging.lower() == 'false':
self._log_rate_limit = False
self.graphql_client = GraphQLClient('%s/graphql' % self.api_base_url) self.app_id = None
self.app_key = None
self._initialized = False
def toDict(self): self._installation_map_lock = threading.Lock()
d = super().toDict() self.installation_map = {}
d.update({ self.installation_token_cache = {}
"baseurl": self.base_url,
"canonical_hostname": self.canonical_hostname,
"server": self.server,
})
return d
def onLoad(self): def initialize(self):
self.log.info('Starting GitHub connection: %s' % self.connection_name)
self.gearman_worker = GithubGearmanWorker(self)
self.log.info('Authing to GitHub') self.log.info('Authing to GitHub')
self._authenticateGithubAPI() self._authenticateGithubAPI()
self._prime_installation_map() self._prime_installation_map()
self.log.info('Starting event connector') self._initialized = True
self._start_event_connector()
self.log.info('Starting GearmanWorker')
self.gearman_worker.start()
def onStop(self): @property
# TODO(jeblair): remove this check which is here only so that def initialized(self):
# zuul-web can call connections.stop to shut down the sql return self._initialized
# connection.
if hasattr(self, 'gearman_worker'):
self.gearman_worker.stop()
self._stop_event_connector()
def _start_event_connector(self): @property
self.github_event_connector = GithubEventConnector(self) def usesAppAuthentication(self):
self.github_event_connector.start() return True if self.app_id else False
def _stop_event_connector(self): def _authenticateGithubAPI(self):
if self.github_event_connector: config = self.connection_config
self.github_event_connector.stop()
app_id = config.get('app_id')
app_key = None
app_key_file = config.get('app_key')
if app_key_file:
try:
with open(app_key_file, 'r') as f:
app_key = f.read()
except IOError:
m = "Failed to open app key file for reading: %s"
self.log.error(m, app_key_file)
if (app_id or app_key) and \
not (app_id and app_key):
self.log.warning("You must provide an app_id and "
"app_key to use installation based "
"authentication")
return
if app_id:
self.app_id = int(app_id)
if app_key:
self.app_key = app_key
def _createGithubClient(self, zuul_event_id=None): def _createGithubClient(self, zuul_event_id=None):
session = github3.session.GitHubSession(default_read_timeout=300) session = github3.session.GitHubSession(default_read_timeout=300)
@ -921,43 +900,6 @@ class GithubConnection(BaseConnection):
github._zuul_user_id = None github._zuul_user_id = None
return github return github
def _authenticateGithubAPI(self):
config = self.connection_config
app_id = config.get('app_id')
app_key = None
app_key_file = config.get('app_key')
if app_key_file:
try:
with open(app_key_file, 'r') as f:
app_key = f.read()
except IOError:
m = "Failed to open app key file for reading: %s"
self.log.error(m, app_key_file)
if (app_id or app_key) and \
not (app_id and app_key):
self.log.warning("You must provide an app_id and "
"app_key to use installation based "
"authentication")
return
if app_id:
self.app_id = int(app_id)
if app_key:
self.app_key = app_key
@staticmethod
def _append_accept_header(github, value):
old_header = github.session.headers.get('Accept', None)
if old_header:
new_value = '%s,%s' % (old_header, value)
else:
new_value = value
github.session.headers['Accept'] = new_value
def _get_app_auth_headers(self): def _get_app_auth_headers(self):
now = datetime.datetime.now(utc) now = datetime.datetime.now(utc)
expiry = now + datetime.timedelta(minutes=5) expiry = now + datetime.timedelta(minutes=5)
@ -972,8 +914,8 @@ class GithubConnection(BaseConnection):
return headers return headers
def _get_installation_key(self, project, inst_id=None, def get_installation_key(self, project, inst_id=None,
reprime=False): reprime=False):
installation_id = inst_id installation_id = inst_id
if project is not None: if project is not None:
installation_id = self.installation_map.get(project) installation_id = self.installation_map.get(project)
@ -982,9 +924,9 @@ class GithubConnection(BaseConnection):
if reprime: if reprime:
# prime installation map and try again without refreshing # prime installation map and try again without refreshing
self._prime_installation_map() self._prime_installation_map()
return self._get_installation_key(project, return self.get_installation_key(project,
inst_id=inst_id, inst_id=inst_id,
reprime=False) reprime=False)
self.log.error("No installation ID available for project %s", self.log.error("No installation ID available for project %s",
project) project)
@ -1061,7 +1003,7 @@ class GithubConnection(BaseConnection):
for install in installations: for install in installations:
inst_id = install.get('id') inst_id = install.get('id')
token_by_inst[inst_id] = executor.submit( token_by_inst[inst_id] = executor.submit(
self._get_installation_key, project=None, self.get_installation_key, project=None,
inst_id=inst_id) inst_id=inst_id)
for inst_id, result in token_by_inst.items(): for inst_id, result in token_by_inst.items():
@ -1091,22 +1033,20 @@ class GithubConnection(BaseConnection):
with self._installation_map_lock: with self._installation_map_lock:
self.log.debug('Finished waiting for fetching installations') self.log.debug('Finished waiting for fetching installations')
def get_request_lock(self, installation_id): def getGithubClientsForProjects(self, projects):
return self._request_locks.setdefault( # Get a list of projects with unique installation ids
installation_id, threading.Semaphore( installation_ids = set()
value=self.max_threads_per_installation)) installation_projects = set()
def addEvent(self, data, event=None, delivery=None): for project in projects:
return self.event_queue.put((time.time(), data, event, delivery)) installation_id = self.installation_map.get(project.name)
if installation_id not in installation_ids:
installation_ids.add(installation_id)
installation_projects.add(project.name)
def getEvent(self): clients = [self.getGithubClient(project_name)
return self.event_queue.get() for project_name in installation_projects]
return clients
def getEventQueueSize(self):
return self.event_queue.qsize()
def eventDone(self):
self.event_queue.task_done()
def getGithubClient(self, def getGithubClient(self,
project=None, project=None,
@ -1118,7 +1058,7 @@ class GithubConnection(BaseConnection):
if project and self.app_id: if project and self.app_id:
# Call get_installation_key to ensure the token gets refreshed in # Call get_installation_key to ensure the token gets refreshed in
# case it's expired. # case it's expired.
token = self._get_installation_key(project) token = self.get_installation_key(project)
# Only set the auth header if we have a token. If not, just don't # Only set the auth header if we have a token. If not, just don't
# set any auth header so we will be treated as anonymous. That's # set any auth header so we will be treated as anonymous. That's
@ -1157,6 +1097,113 @@ class GithubConnection(BaseConnection):
return github return github
class GithubConnection(BaseConnection):
driver_name = 'github'
log = logging.getLogger("zuul.GithubConnection")
payload_path = 'payload'
client_manager_class = GithubClientManager
def __init__(self, driver, connection_name, connection_config):
super(GithubConnection, self).__init__(
driver, connection_name, connection_config)
self._change_cache = {}
self._change_update_lock = {}
self._project_branch_cache_include_unprotected = {}
self._project_branch_cache_exclude_unprotected = {}
self.projects = {}
self.git_ssh_key = self.connection_config.get('sshkey')
self.server = self.connection_config.get('server', 'github.com')
self.canonical_hostname = self.connection_config.get(
'canonical_hostname', self.server)
self.source = driver.getSource(self)
self.event_queue = queue.Queue()
self._sha_pr_cache = GithubShaCache()
self._request_locks = {}
self.max_threads_per_installation = int(self.connection_config.get(
'max_threads_per_installation', 1))
self._github_client_manager = self.client_manager_class(
self.connection_config)
self.sched = None
# The regex is based on the connection host. We do not yet support
# cross-connection dependency gathering
self.depends_on_re = re.compile(
r"^Depends-On: https://%s/.+/.+/pull/[0-9]+$" % self.server,
re.MULTILINE | re.IGNORECASE)
self.graphql_client = GraphQLClient(
'%s/graphql' % self._github_client_manager.api_base_url)
def toDict(self):
d = super().toDict()
d.update({
"baseurl": self._github_client_manager.base_url,
"canonical_hostname": self.canonical_hostname,
"server": self.server,
})
return d
def onLoad(self):
self.log.info('Starting GitHub connection: %s' % self.connection_name)
self.gearman_worker = GithubGearmanWorker(self)
self._github_client_manager.initialize()
self.log.info('Starting event connector')
self._start_event_connector()
self.log.info('Starting GearmanWorker')
self.gearman_worker.start()
def onStop(self):
# TODO(jeblair): remove this check which is here only so that
# zuul-web can call connections.stop to shut down the sql
# connection.
if hasattr(self, 'gearman_worker'):
self.gearman_worker.stop()
self._stop_event_connector()
def _start_event_connector(self):
self.github_event_connector = GithubEventConnector(self)
self.github_event_connector.start()
def _stop_event_connector(self):
if self.github_event_connector:
self.github_event_connector.stop()
@staticmethod
def _append_accept_header(github, value):
old_header = github.session.headers.get('Accept', None)
if old_header:
new_value = '%s,%s' % (old_header, value)
else:
new_value = value
github.session.headers['Accept'] = new_value
def get_request_lock(self, installation_id):
return self._request_locks.setdefault(
installation_id, threading.Semaphore(
value=self.max_threads_per_installation))
def addEvent(self, data, event=None, delivery=None):
return self.event_queue.put((time.time(), data, event, delivery))
def getEvent(self):
return self.event_queue.get()
def getEventQueueSize(self):
return self.event_queue.qsize()
def eventDone(self):
self.event_queue.task_done()
def getGithubClient(self,
project=None,
zuul_event_id=None):
return self._github_client_manager.getGithubClient(
project=project, zuul_event_id=zuul_event_id)
def maintainCache(self, relevant): def maintainCache(self, relevant):
remove = set() remove = set()
for key, change in self._change_cache.items(): for key, change in self._change_cache.items():
@ -1245,37 +1292,17 @@ class GithubConnection(BaseConnection):
if not change.uris: if not change.uris:
return changes return changes
# Get a list of projects with unique installation ids if not projects:
installation_ids = set()
installation_projects = set()
if projects:
# We only need to find changes in projects in the supplied
# ChangeQueue. Find all of the github installations for
# all of those projects, and search using each of them, so
# that if we get the right results based on the
# permissions granted to each of the installations. The
# common case for this is likely to be just one
# installation -- change queues aren't likely to span more
# than one installation.
for project in projects:
installation_id = self.installation_map.get(project.name)
if installation_id not in installation_ids:
installation_ids.add(installation_id)
installation_projects.add(project.name)
else:
# We aren't in the context of a change queue and we just # We aren't in the context of a change queue and we just
# need to query all installations of this tenant. This currently # need to query all installations of this tenant. This currently
# only happens if certain features of the zuul trigger are # only happens if certain features of the zuul trigger are
# used; generally it should be avoided. # used; generally it should be avoided.
for project_name, installation_id in self.installation_map.items(): projects = [p for p in tenant.all_projects
trusted, project = tenant.getProject(project_name) if p.connection_name == self.connection_name]
# ignore projects from different tenants # Otherwise we use the input projects list and look for changes in the
if not project: # supplied projects.
continue clients = self._github_client_manager.getGithubClientsForProjects(
if installation_id not in installation_ids: projects)
installation_ids.add(installation_id)
installation_projects.add(project_name)
keys = set() keys = set()
# TODO: Max of 5 OR operators can be used per query and # TODO: Max of 5 OR operators can be used per query and
@ -1284,9 +1311,8 @@ class GithubConnection(BaseConnection):
# tests/fakegithub.py # tests/fakegithub.py
pattern = ' OR '.join(['"Depends-On: %s"' % x for x in change.uris]) pattern = ' OR '.join(['"Depends-On: %s"' % x for x in change.uris])
query = '%s type:pr is:open in:body' % pattern query = '%s type:pr is:open in:body' % pattern
# Repeat the search for each installation id (project) # Repeat the search for each client (project)
for installation_project in installation_projects: for github in clients:
github = self.getGithubClient(installation_project)
for issue in github.search_issues(query=query): for issue in github.search_issues(query=query):
pr = issue.issue.pull_request().as_dict() pr = issue.issue.pull_request().as_dict()
if not pr.get('url'): if not pr.get('url'):
@ -1401,17 +1427,17 @@ class GithubConnection(BaseConnection):
# if app_id is configured but self.app_id is empty we are not # if app_id is configured but self.app_id is empty we are not
# authenticated yet against github as app # authenticated yet against github as app
if not self.app_id and self.connection_config.get('app_id', None): if not self._github_client_manager.initialized:
self._authenticateGithubAPI() self._github_client_manager.initialize()
self._prime_installation_map()
if self.app_id: if self._github_client_manager.usesAppAuthentication:
# We may be in the context of a merger or executor here. The # We may be in the context of a merger or executor here. The
# mergers and executors don't receive webhook events so they miss # mergers and executors don't receive webhook events so they miss
# new repository installations. In order to cope with this we need # new repository installations. In order to cope with this we need
# to reprime the installation map if we don't find the repo there. # to reprime the installation map if we don't find the repo there.
installation_key = self._get_installation_key(project.name, installation_key = \
reprime=True) self._github_client_manager.get_installation_key(
project.name, reprime=True)
return 'https://x-access-token:%s@%s/%s' % (installation_key, return 'https://x-access-token:%s@%s/%s' % (installation_key,
self.server, self.server,
project.name) project.name)
@ -1799,7 +1825,7 @@ class GithubConnection(BaseConnection):
def getCommitChecks(self, project_name, sha, zuul_event_id=None): def getCommitChecks(self, project_name, sha, zuul_event_id=None):
log = get_annotated_logger(self.log, zuul_event_id) log = get_annotated_logger(self.log, zuul_event_id)
if not self.app_id: if not self._github_client_manager.usesAppAuthentication:
log.debug( log.debug(
"Not authenticated as Github app. Unable to retrieve commit " "Not authenticated as Github app. Unable to retrieve commit "
"checks for sha %s on %s", "checks for sha %s on %s",
@ -1855,7 +1881,7 @@ class GithubConnection(BaseConnection):
# Track a list of failed check run operations to report back to Github # Track a list of failed check run operations to report back to Github
errors = [] errors = []
if not self.app_id: if not self._github_client_manager.usesAppAuthentication:
# We don't try to update check runs, if we aren't authenticated as # We don't try to update check runs, if we aren't authenticated as
# Github app at all. If we are, we still have to ensure that we # Github app at all. If we are, we still have to ensure that we
# don't crash on missing permissions. # don't crash on missing permissions.

View File

@ -4362,6 +4362,15 @@ class Tenant(object):
self.authorization_rules = [] self.authorization_rules = []
@property
def all_projects(self):
"""
Return a generator for all projects of the tenant.
"""
for hostname_dict in self.projects.values():
for project in hostname_dict.values():
yield project
def _addProject(self, tpc): def _addProject(self, tpc):
"""Add a project to the project index """Add a project to the project index