diff --git a/tests/base.py b/tests/base.py index 4a4e536a33..40ffed90e5 100644 --- a/tests/base.py +++ b/tests/base.py @@ -63,6 +63,8 @@ import testtools.content_type from git.exc import NoSuchPathError import yaml import paramiko + +from zuul.driver.github.githubconnection import GithubClientManager from zuul.lib.connections import ConnectionRegistry from psutil import Popen @@ -2122,26 +2124,13 @@ class FakeGithubPullRequest(object): repo.heads[self.branch].commit = repo.commit(self.head_sha) -class FakeGithubConnection(githubconnection.GithubConnection): - log = logging.getLogger("zuul.test.FakeGithubConnection") +class FakeGithubClientManager(GithubClientManager): - def __init__(self, driver, connection_name, connection_config, rpcclient, - changes_db=None, upstream_root=None, git_url_with_auth=False): - super(FakeGithubConnection, self).__init__(driver, connection_name, - connection_config) - self.connection_name = connection_name - self.pr_number = 0 - self.pull_requests = changes_db - self.statuses = {} - self.upstream_root = upstream_root - self.merge_failure = False - self.merge_not_allowed_count = 0 - self.reports = [] - self.github_data = tests.fakegithub.FakeGithubData(changes_db) - self.recorded_clients = [] - self.git_url_with_auth = git_url_with_auth - self.rpcclient = rpcclient + def __init__(self, connection_config): + super().__init__(connection_config) self.record_clients = False + self.recorded_clients = [] + self.github_data = None def getGithubClient(self, project=None, @@ -2174,6 +2163,30 @@ class FakeGithubConnection(githubconnection.GithubConnection): orgs[repo[0]] = inst_id self.installation_map['/'.join(repo)] = inst_id + +class FakeGithubConnection(githubconnection.GithubConnection): + log = logging.getLogger("zuul.test.FakeGithubConnection") + client_manager_class = FakeGithubClientManager + + def __init__(self, driver, connection_name, connection_config, rpcclient, + changes_db=None, upstream_root=None, git_url_with_auth=False): + super(FakeGithubConnection, self).__init__(driver, connection_name, + connection_config) + self.connection_name = connection_name + self.pr_number = 0 + self.pull_requests = changes_db + self.statuses = {} + self.upstream_root = upstream_root + self.merge_failure = False + self.merge_not_allowed_count = 0 + self.reports = [] + + self.github_data = tests.fakegithub.FakeGithubData(changes_db) + self._github_client_manager.github_data = self.github_data + + self.git_url_with_auth = git_url_with_auth + self.rpcclient = rpcclient + def setZuulWebPort(self, port): self.zuul_web_port = port diff --git a/tests/unit/test_zuultrigger.py b/tests/unit/test_zuultrigger.py index aeda124332..00ee2efdef 100644 --- a/tests/unit/test_zuultrigger.py +++ b/tests/unit/test_zuultrigger.py @@ -173,8 +173,9 @@ class TestZuulTriggerParentChangeEnqueuedGithub(ZuulGithubAppTestCase): # After starting recording installation containing org2/project # should not be contacted - inst_id_to_check = self.fake_github.installation_map['org2/project'] - inst_clients = [x for x in self.fake_github.recorded_clients + gh_manager = self.fake_github._github_client_manager + inst_id_to_check = gh_manager.installation_map['org2/project'] + inst_clients = [x for x in gh_manager.recorded_clients if x._inst_id == inst_id_to_check] self.assertEqual(len(inst_clients), 0) diff --git a/zuul/driver/github/githubconnection.py b/zuul/driver/github/githubconnection.py index 19497008d3..2e84e77319 100644 --- a/zuul/driver/github/githubconnection.py +++ b/zuul/driver/github/githubconnection.py @@ -763,38 +763,12 @@ class GithubUser(collections.Mapping): } -class GithubConnection(BaseConnection): - driver_name = 'github' - log = logging.getLogger("zuul.GithubConnection") - payload_path = 'payload' +class GithubClientManager: + log = logging.getLogger('zuul.GithubConnection.GithubClientManager') - def __init__(self, driver, connection_name, connection_config): - super(GithubConnection, self).__init__( - driver, connection_name, connection_config) - self._change_cache = {} - self._change_update_lock = {} - self._installation_map_lock = threading.Lock() - self._project_branch_cache_include_unprotected = {} - self._project_branch_cache_exclude_unprotected = {} - self.projects = {} - self.git_ssh_key = self.connection_config.get('sshkey') + def __init__(self, connection_config): + self.connection_config = connection_config self.server = self.connection_config.get('server', 'github.com') - self.canonical_hostname = self.connection_config.get( - 'canonical_hostname', self.server) - self.source = driver.getSource(self) - self.event_queue = queue.Queue() - self._sha_pr_cache = GithubShaCache() - - self._request_locks = {} - self.max_threads_per_installation = int(self.connection_config.get( - 'max_threads_per_installation', 1)) - - # Logging of rate limit is optional as this does additional requests - rate_limit_logging = self.connection_config.get( - 'rate_limit_logging', 'true') - self._log_rate_limit = True - if rate_limit_logging.lower() == 'false': - self._log_rate_limit = False if self.server == 'github.com': self.api_base_url = GITHUB_BASE_URL @@ -809,13 +783,6 @@ class GithubConnection(BaseConnection): if verify_ssl.lower() == 'false': self.verify_ssl = False - self.app_id = None - self.app_key = None - self.sched = None - - self.installation_map = {} - self.installation_token_cache = {} - # NOTE(jamielennox): Better here would be to cache to memcache or file # or something external - but zuul already sucks at restarting so in # memory probably doesn't make this much worse. @@ -840,49 +807,61 @@ class GithubConnection(BaseConnection): cache_etags=True, heuristic=NoAgeHeuristic()) - # The regex is based on the connection host. We do not yet support - # cross-connection dependency gathering - self.depends_on_re = re.compile( - r"^Depends-On: https://%s/.+/.+/pull/[0-9]+$" % self.server, - re.MULTILINE | re.IGNORECASE) + # Logging of rate limit is optional as this does additional requests + rate_limit_logging = self.connection_config.get( + 'rate_limit_logging', 'true') + self._log_rate_limit = True + if rate_limit_logging.lower() == 'false': + self._log_rate_limit = False - self.graphql_client = GraphQLClient('%s/graphql' % self.api_base_url) + self.app_id = None + self.app_key = None + self._initialized = False - def toDict(self): - d = super().toDict() - d.update({ - "baseurl": self.base_url, - "canonical_hostname": self.canonical_hostname, - "server": self.server, - }) - return d + self._installation_map_lock = threading.Lock() + self.installation_map = {} + self.installation_token_cache = {} - def onLoad(self): - self.log.info('Starting GitHub connection: %s' % self.connection_name) - self.gearman_worker = GithubGearmanWorker(self) + def initialize(self): self.log.info('Authing to GitHub') self._authenticateGithubAPI() self._prime_installation_map() - self.log.info('Starting event connector') - self._start_event_connector() - self.log.info('Starting GearmanWorker') - self.gearman_worker.start() + self._initialized = True - def onStop(self): - # TODO(jeblair): remove this check which is here only so that - # zuul-web can call connections.stop to shut down the sql - # connection. - if hasattr(self, 'gearman_worker'): - self.gearman_worker.stop() - self._stop_event_connector() + @property + def initialized(self): + return self._initialized - def _start_event_connector(self): - self.github_event_connector = GithubEventConnector(self) - self.github_event_connector.start() + @property + def usesAppAuthentication(self): + return True if self.app_id else False - def _stop_event_connector(self): - if self.github_event_connector: - self.github_event_connector.stop() + def _authenticateGithubAPI(self): + config = self.connection_config + + app_id = config.get('app_id') + app_key = None + app_key_file = config.get('app_key') + + if app_key_file: + try: + with open(app_key_file, 'r') as f: + app_key = f.read() + except IOError: + m = "Failed to open app key file for reading: %s" + self.log.error(m, app_key_file) + + if (app_id or app_key) and \ + not (app_id and app_key): + self.log.warning("You must provide an app_id and " + "app_key to use installation based " + "authentication") + return + + if app_id: + self.app_id = int(app_id) + if app_key: + self.app_key = app_key def _createGithubClient(self, zuul_event_id=None): session = github3.session.GitHubSession(default_read_timeout=300) @@ -921,43 +900,6 @@ class GithubConnection(BaseConnection): github._zuul_user_id = None return github - def _authenticateGithubAPI(self): - config = self.connection_config - - app_id = config.get('app_id') - app_key = None - app_key_file = config.get('app_key') - - if app_key_file: - try: - with open(app_key_file, 'r') as f: - app_key = f.read() - except IOError: - m = "Failed to open app key file for reading: %s" - self.log.error(m, app_key_file) - - if (app_id or app_key) and \ - not (app_id and app_key): - self.log.warning("You must provide an app_id and " - "app_key to use installation based " - "authentication") - - return - - if app_id: - self.app_id = int(app_id) - if app_key: - self.app_key = app_key - - @staticmethod - def _append_accept_header(github, value): - old_header = github.session.headers.get('Accept', None) - if old_header: - new_value = '%s,%s' % (old_header, value) - else: - new_value = value - github.session.headers['Accept'] = new_value - def _get_app_auth_headers(self): now = datetime.datetime.now(utc) expiry = now + datetime.timedelta(minutes=5) @@ -972,8 +914,8 @@ class GithubConnection(BaseConnection): return headers - def _get_installation_key(self, project, inst_id=None, - reprime=False): + def get_installation_key(self, project, inst_id=None, + reprime=False): installation_id = inst_id if project is not None: installation_id = self.installation_map.get(project) @@ -982,9 +924,9 @@ class GithubConnection(BaseConnection): if reprime: # prime installation map and try again without refreshing self._prime_installation_map() - return self._get_installation_key(project, - inst_id=inst_id, - reprime=False) + return self.get_installation_key(project, + inst_id=inst_id, + reprime=False) self.log.error("No installation ID available for project %s", project) @@ -1061,7 +1003,7 @@ class GithubConnection(BaseConnection): for install in installations: inst_id = install.get('id') token_by_inst[inst_id] = executor.submit( - self._get_installation_key, project=None, + self.get_installation_key, project=None, inst_id=inst_id) for inst_id, result in token_by_inst.items(): @@ -1091,22 +1033,20 @@ class GithubConnection(BaseConnection): with self._installation_map_lock: self.log.debug('Finished waiting for fetching installations') - def get_request_lock(self, installation_id): - return self._request_locks.setdefault( - installation_id, threading.Semaphore( - value=self.max_threads_per_installation)) + def getGithubClientsForProjects(self, projects): + # Get a list of projects with unique installation ids + installation_ids = set() + installation_projects = set() - def addEvent(self, data, event=None, delivery=None): - return self.event_queue.put((time.time(), data, event, delivery)) + for project in projects: + installation_id = self.installation_map.get(project.name) + if installation_id not in installation_ids: + installation_ids.add(installation_id) + installation_projects.add(project.name) - def getEvent(self): - return self.event_queue.get() - - def getEventQueueSize(self): - return self.event_queue.qsize() - - def eventDone(self): - self.event_queue.task_done() + clients = [self.getGithubClient(project_name) + for project_name in installation_projects] + return clients def getGithubClient(self, project=None, @@ -1118,7 +1058,7 @@ class GithubConnection(BaseConnection): if project and self.app_id: # Call get_installation_key to ensure the token gets refresehd in # case it's expired. - token = self._get_installation_key(project) + token = self.get_installation_key(project) # Only set the auth header if we have a token. If not, just don't # set any auth header so we will be treated as anonymous. That's @@ -1157,6 +1097,113 @@ class GithubConnection(BaseConnection): return github + +class GithubConnection(BaseConnection): + driver_name = 'github' + log = logging.getLogger("zuul.GithubConnection") + payload_path = 'payload' + client_manager_class = GithubClientManager + + def __init__(self, driver, connection_name, connection_config): + super(GithubConnection, self).__init__( + driver, connection_name, connection_config) + self._change_cache = {} + self._change_update_lock = {} + self._project_branch_cache_include_unprotected = {} + self._project_branch_cache_exclude_unprotected = {} + self.projects = {} + self.git_ssh_key = self.connection_config.get('sshkey') + self.server = self.connection_config.get('server', 'github.com') + self.canonical_hostname = self.connection_config.get( + 'canonical_hostname', self.server) + self.source = driver.getSource(self) + self.event_queue = queue.Queue() + self._sha_pr_cache = GithubShaCache() + + self._request_locks = {} + self.max_threads_per_installation = int(self.connection_config.get( + 'max_threads_per_installation', 1)) + + self._github_client_manager = self.client_manager_class( + self.connection_config) + + self.sched = None + + # The regex is based on the connection host. We do not yet support + # cross-connection dependency gathering + self.depends_on_re = re.compile( + r"^Depends-On: https://%s/.+/.+/pull/[0-9]+$" % self.server, + re.MULTILINE | re.IGNORECASE) + + self.graphql_client = GraphQLClient( + '%s/graphql' % self._github_client_manager.api_base_url) + + def toDict(self): + d = super().toDict() + d.update({ + "baseurl": self._github_client_manager.base_url, + "canonical_hostname": self.canonical_hostname, + "server": self.server, + }) + return d + + def onLoad(self): + self.log.info('Starting GitHub connection: %s' % self.connection_name) + self.gearman_worker = GithubGearmanWorker(self) + self._github_client_manager.initialize() + self.log.info('Starting event connector') + self._start_event_connector() + self.log.info('Starting GearmanWorker') + self.gearman_worker.start() + + def onStop(self): + # TODO(jeblair): remove this check which is here only so that + # zuul-web can call connections.stop to shut down the sql + # connection. + if hasattr(self, 'gearman_worker'): + self.gearman_worker.stop() + self._stop_event_connector() + + def _start_event_connector(self): + self.github_event_connector = GithubEventConnector(self) + self.github_event_connector.start() + + def _stop_event_connector(self): + if self.github_event_connector: + self.github_event_connector.stop() + + @staticmethod + def _append_accept_header(github, value): + old_header = github.session.headers.get('Accept', None) + if old_header: + new_value = '%s,%s' % (old_header, value) + else: + new_value = value + github.session.headers['Accept'] = new_value + + def get_request_lock(self, installation_id): + return self._request_locks.setdefault( + installation_id, threading.Semaphore( + value=self.max_threads_per_installation)) + + def addEvent(self, data, event=None, delivery=None): + return self.event_queue.put((time.time(), data, event, delivery)) + + def getEvent(self): + return self.event_queue.get() + + def getEventQueueSize(self): + return self.event_queue.qsize() + + def eventDone(self): + self.event_queue.task_done() + + def getGithubClient(self, + project=None, + zuul_event_id=None): + return self._github_client_manager.getGithubClient( + project=project, zuul_event_id=zuul_event_id) + def maintainCache(self, relevant): remove = set() for key, change in self._change_cache.items(): @@ -1245,37 +1292,17 @@ class GithubConnection(BaseConnection): if not change.uris: return changes - # Get a list of projects with unique installation ids - installation_ids = set() - installation_projects = set() - - if projects: - # We only need to find changes in projects in the supplied - # ChangeQueue. Find all of the github installations for - # all of those projects, and search using each of them, so - # that if we get the right results based on the - # permissions granted to each of the installations. The - # common case for this is likely to be just one - # installation -- change queues aren't likely to span more - # than one installation. - for project in projects: - installation_id = self.installation_map.get(project.name) - if installation_id not in installation_ids: - installation_ids.add(installation_id) - installation_projects.add(project.name) - else: + if not projects: # We aren't in the context of a change queue and we just # need to query all installations of this tenant. This currently # only happens if certain features of the zuul trigger are # used; generally it should be avoided. - for project_name, installation_id in self.installation_map.items(): - trusted, project = tenant.getProject(project_name) - # ignore projects from different tenants - if not project: - continue - if installation_id not in installation_ids: - installation_ids.add(installation_id) - installation_projects.add(project_name) + projects = [p for p in tenant.all_projects + if p.connection_name == self.connection_name] + # Otherwise we use the input projects list and look for changes in the + # supplied projects. + clients = self._github_client_manager.getGithubClientsForProjects( + projects) keys = set() # TODO: Max of 5 OR operators can be used per query and @@ -1284,9 +1311,8 @@ class GithubConnection(BaseConnection): # tests/fakegithub.py pattern = ' OR '.join(['"Depends-On: %s"' % x for x in change.uris]) query = '%s type:pr is:open in:body' % pattern - # Repeat the search for each installation id (project) - for installation_project in installation_projects: - github = self.getGithubClient(installation_project) + # Repeat the search for each client (project) + for github in clients: for issue in github.search_issues(query=query): pr = issue.issue.pull_request().as_dict() if not pr.get('url'): @@ -1401,17 +1427,17 @@ class GithubConnection(BaseConnection): # if app_id is configured but self.app_id is empty we are not # authenticated yet against github as app - if not self.app_id and self.connection_config.get('app_id', None): - self._authenticateGithubAPI() - self._prime_installation_map() + if not self._github_client_manager.initialized: + self._github_client_manager.initialize() - if self.app_id: + if self._github_client_manager.usesAppAuthentication: # We may be in the context of a merger or executor here. The # mergers and executors don't receive webhook events so they miss # new repository installations. In order to cope with this we need # to reprime the installation map if we don't find the repo there. - installation_key = self._get_installation_key(project.name, - reprime=True) + installation_key = \ + self._github_client_manager.get_installation_key( + project.name, reprime=True) return 'https://x-access-token:%s@%s/%s' % (installation_key, self.server, project.name) @@ -1799,7 +1825,7 @@ class GithubConnection(BaseConnection): def getCommitChecks(self, project_name, sha, zuul_event_id=None): log = get_annotated_logger(self.log, zuul_event_id) - if not self.app_id: + if not self._github_client_manager.usesAppAuthentication: log.debug( "Not authenticated as Github app. Unable to retrieve commit " "checks for sha %s on %s", @@ -1855,7 +1881,7 @@ class GithubConnection(BaseConnection): # Track a list of failed check run operations to report back to Github errors = [] - if not self.app_id: + if not self._github_client_manager.usesAppAuthentication: # We don't try to update check runs, if we aren't authenticated as # Github app at all. If we are, we still have to ensure that we # don't crash on missing permissions. diff --git a/zuul/model.py b/zuul/model.py index a5dd8827db..5442df3a27 100644 --- a/zuul/model.py +++ b/zuul/model.py @@ -4362,6 +4362,15 @@ class Tenant(object): self.authorization_rules = [] + @property + def all_projects(self): + """ + Return a generator for all projects of the tenant. + """ + for hostname_dict in self.projects.values(): + for project in hostname_dict.values(): + yield project + def _addProject(self, tpc): """Add a project to the project index