diff --git a/zuul/driver/github/githubconnection.py b/zuul/driver/github/githubconnection.py index 218c82c93f..d34dfbf5be 100644 --- a/zuul/driver/github/githubconnection.py +++ b/zuul/driver/github/githubconnection.py @@ -71,6 +71,26 @@ class UTC(datetime.tzinfo): utc = UTC() +class GithubShaCache(object): + def __init__(self): + self.projects = {} + + def update(self, project_name, pr): + project_cache = self.projects.setdefault(project_name, {}) + sha = pr['head']['sha'] + number = pr['number'] + cached_prs = project_cache.setdefault(sha, set()) + if pr['state'] == 'open': + cached_prs.add(number) + else: + cached_prs.discard(number) + + def get(self, project_name, sha): + project_cache = self.projects.get(project_name, {}) + cached_prs = project_cache.get(sha, set()) + return cached_prs + + class GithubGearmanWorker(object): """A thread that answers gearman requests""" log = logging.getLogger("zuul.GithubGearmanWorker") @@ -514,6 +534,7 @@ class GithubConnection(BaseConnection): 'canonical_hostname', self.server) self.source = driver.getSource(self) self.event_queue = queue.Queue() + self._sha_pr_cache = GithubShaCache() # Logging of rate limit is optional as this does additional requests rate_limit_logging = self.connection_config.get( @@ -1125,6 +1146,9 @@ class GithubConnection(BaseConnection): pr['files'] = [] pr['labels'] = [l.name for l in issueobj.labels()] + + self._sha_pr_cache.update(project_name, pr) + log.debug('Got PR %s#%s', project_name, number) self.log_rate_limit(self.log, github) return pr @@ -1162,16 +1186,26 @@ class GithubConnection(BaseConnection): return True def getPullBySha(self, sha, project, log): + cached_pr_numbers = self._sha_pr_cache.get(project, sha) + if len(cached_pr_numbers) > 1: + raise Exception('Multiple pulls found with head sha %s' % sha) + if len(cached_pr_numbers) == 1: + for pr in cached_pr_numbers: + return self.getPull(project, pr, log) + pulls = [] + project_name = project owner, project = project.split('/') github = self.getGithubClient("%s/%s" % (owner, project)) repo = github.repository(owner, project) for pr in repo.pull_requests(state='open'): + pr_dict = pr.as_dict() + self._sha_pr_cache.update(project_name, pr_dict) if pr.head.sha != sha: continue - if pr.as_dict() in pulls: + if pr_dict in pulls: continue - pulls.append(pr.as_dict()) + pulls.append(pr_dict) log.debug('Got PR on project %s for sha %s', project, sha) self.log_rate_limit(self.log, github)