diff --git a/tests/fakegithub.py b/tests/fakegithub.py index f278e49401..ef0478f75f 100644 --- a/tests/fakegithub.py +++ b/tests/fakegithub.py @@ -283,6 +283,10 @@ class FakeIssue(object): def pull_request(self): return FakePull(self._fake_pull_request) + @property + def number(self): + return self._fake_pull_request.number + class FakeFile(object): def __init__(self, filename): @@ -449,6 +453,14 @@ class FakeGithubClient(object): def tokenize(s): return re.findall(r'[\w]+', s) + def query_is_sha(s): + return re.match(r'[a-z0-9]{40}', s) + + if query_is_sha(query): + return (FakeIssueSearchResult(FakeIssue(pr)) + for pr in self._data.pull_requests.values() + if pr.head_sha == query) + parts = tokenize(query) terms = set() results = [] @@ -471,4 +483,4 @@ class FakeGithubClient(object): issue = FakeIssue(pr) results.append(FakeIssueSearchResult(issue)) - return results + return iter(results) diff --git a/zuul/driver/github/githubconnection.py b/zuul/driver/github/githubconnection.py index 69a902c2e5..f201110620 100644 --- a/zuul/driver/github/githubconnection.py +++ b/zuul/driver/github/githubconnection.py @@ -1280,7 +1280,7 @@ class GithubConnection(BaseConnection): def getPull(self, project_name, number, event=None): log = get_annotated_logger(self.log, event) - github = self.getGithubClient(project_name) + github = self.getGithubClient(project_name, zuul_event_id=event) owner, proj = project_name.split('/') for retry in range(5): try: @@ -1344,6 +1344,8 @@ class GithubConnection(BaseConnection): def getPullBySha(self, sha, project_name, event): log = get_annotated_logger(self.log, event) + + # Serve from the cache if existing cached_pr_numbers = self._sha_pr_cache.get(project_name, sha) if len(cached_pr_numbers) > 1: raise Exception('Multiple pulls found with head sha %s' % sha) @@ -1352,31 +1354,20 @@ class GithubConnection(BaseConnection): pr_body, pr_obj = self.getPull(project_name, pr, event) return pr_body - pulls = [] - github = self.getGithubClient(project_name) - owner, repository = project_name.split('/') - repo = github.repository(owner, repository) - for pr in repo.pull_requests(state='open', - # We sort by updated from oldest to newest - # as that will prefer more recently - # PRs in our LRU cache. - sort='updated', - direction='asc'): - pr_dict = pr.as_dict() - self._sha_pr_cache.update(project_name, pr_dict) - if pr.head.sha != sha: - continue - if pr_dict in pulls: - continue - pulls.append(pr_dict) + github = self.getGithubClient(project_name, zuul_event_id=event) + issues = list(github.search_issues(sha)) log.debug('Got PR on project %s for sha %s', project_name, sha) - if len(pulls) > 1: + if len(issues) > 1: raise Exception('Multiple pulls found with head sha %s' % sha) - if len(pulls) == 0: + if len(issues) == 0: return None - return pulls.pop() + + pr_body, pr_obj = self.getPull( + project_name, issues.pop().issue.number, event) + self._sha_pr_cache.update(project_name, pr_body) + return pr_body def getPullReviews(self, pr_obj, project, number, event): log = get_annotated_logger(self.log, event)