diff --git a/webhook_server/libs/github_api.py b/webhook_server/libs/github_api.py index 4610a24e8..7bd17d8c3 100644 --- a/webhook_server/libs/github_api.py +++ b/webhook_server/libs/github_api.py @@ -56,6 +56,8 @@ run_command, ) +_SHA_PATTERN = re.compile(r"^[0-9a-f]{40}$") + class CountingRequester: """ @@ -115,6 +117,8 @@ def __init__(self, hook_data: dict[Any, Any], headers: Headers, logger: logging. self.repository_full_name: str = hook_data["repository"]["full_name"] self._bg_tasks: set[Task[Any]] = set() self.parent_committer: str = "" + self.pr_base_sha: str = "" + self.pr_head_sha: str = "" self.x_github_delivery: str = headers.get("X-GitHub-Delivery", "") self.github_event: str = headers["X-GitHub-Event"] self.config = Config(repository=self.repository_name, logger=self.logger) @@ -386,6 +390,43 @@ def redact_output(value: str) -> str: redacted_err = redact_output(err) self.logger.error(f"{self.log_prefix} Failed to fetch PR {pr_number} ref: {redacted_err}") raise RuntimeError(f"Failed to fetch PR {pr_number} ref: {redacted_err}") + + # Fetch payload SHAs explicitly to handle force-push race condition + # The webhook payload SHAs may differ from the current PR ref if the PR + # was force-pushed between webhook delivery and processing + # Validate SHA format first — reset invalid SHAs so the fetch is skipped + for sha_attr in ("pr_base_sha", "pr_head_sha"): + sha = getattr(self, sha_attr) + if not isinstance(sha, str) or (sha and not _SHA_PATTERN.match(sha)): + self.logger.warning( + f"{self.log_prefix} Invalid {sha_attr} format: {str(sha)[:20]}, will use API fallback" + ) + setattr(self, sha_attr, "") + + if self.pr_base_sha and self.pr_head_sha: + for sha in (self.pr_base_sha, self.pr_head_sha): + # Check if SHA exists in clone + rc_check, _, _ = await run_command( + command=f"{git_cmd} cat-file -e {sha}^{{commit}}", + log_prefix=self.log_prefix, + verify_stderr=False, + mask_sensitive=self.mask_sensitive, + ) + if not rc_check: + self.logger.debug( + f"{self.log_prefix} Payload SHA {sha[:7]} not in clone, fetching explicitly" + ) + rc_fetch, _, _ = await run_command( + command=f"{git_cmd} fetch origin {sha}", + log_prefix=self.log_prefix, + redact_secrets=[github_token], + mask_sensitive=self.mask_sensitive, + ) + if not rc_fetch: + self.logger.warning( + f"{self.log_prefix} Failed to fetch payload SHA {sha[:7]} — " + f"git diff may fail if this SHA is unreachable" + ) else: # For push events (tags only - branch pushes skip cloning) # checkout_ref guaranteed to be non-None by validation at function start @@ -449,7 +490,7 @@ async def _recheck_merge_eligibility(self, pull_request: PullRequest) -> None: """ await self._clone_repository(pull_request=pull_request) owners_file_handler = OwnersFileHandler(github_webhook=self) - owners_file_handler = await owners_file_handler.initialize(pull_request=pull_request) + owners_file_handler = await owners_file_handler.initialize() await PullRequestHandler(github_webhook=self, owners_file_handler=owners_file_handler).check_if_can_be_merged( pull_request=pull_request ) @@ -596,6 +637,18 @@ async def process(self) -> Any: self.parent_committer = pull_request.user.login self.last_committer = getattr(self.last_commit.committer, "login", self.parent_committer) + # Store PR SHAs: prefer webhook payload (avoids race condition with live API) + # For pull_request events, base.sha and head.sha are guaranteed by GitHub webhook spec. + # For other events (issue_comment, check_run), fall back to PullRequest API object. + if self.github_event == "pull_request": + self.pr_base_sha = self.hook_data["pull_request"]["base"]["sha"] + self.pr_head_sha = self.hook_data["pull_request"]["head"]["sha"] + else: + self.pr_base_sha, self.pr_head_sha = await asyncio.gather( + github_api_call(lambda: pull_request.base.sha, logger=self.logger, log_prefix=self.log_prefix), + github_api_call(lambda: pull_request.head.sha, logger=self.logger, log_prefix=self.log_prefix), + ) + # Clone repository for local file processing (OWNERS, changed files) # For check_run, status, and pull_request_review_thread events, # cloning happens later only when needed (inside their respective handlers) @@ -604,7 +657,7 @@ async def process(self) -> Any: if self.github_event == "issue_comment": owners_file_handler = OwnersFileHandler(github_webhook=self) - owners_file_handler = await owners_file_handler.initialize(pull_request=pull_request) + owners_file_handler = await owners_file_handler.initialize() await IssueCommentHandler( github_webhook=self, owners_file_handler=owners_file_handler @@ -618,7 +671,7 @@ async def process(self) -> Any: elif self.github_event == "pull_request": owners_file_handler = OwnersFileHandler(github_webhook=self) - owners_file_handler = await owners_file_handler.initialize(pull_request=pull_request) + owners_file_handler = await owners_file_handler.initialize() await PullRequestHandler( github_webhook=self, owners_file_handler=owners_file_handler @@ -632,7 +685,7 @@ async def process(self) -> Any: elif self.github_event == "pull_request_review": owners_file_handler = OwnersFileHandler(github_webhook=self) - owners_file_handler = await owners_file_handler.initialize(pull_request=pull_request) + owners_file_handler = await owners_file_handler.initialize() await PullRequestReviewHandler( github_webhook=self, owners_file_handler=owners_file_handler @@ -678,7 +731,7 @@ async def process(self) -> Any: await self._clone_repository(pull_request=pull_request) owners_file_handler = OwnersFileHandler(github_webhook=self) - owners_file_handler = await owners_file_handler.initialize(pull_request=pull_request) + owners_file_handler = await owners_file_handler.initialize() handled = await CheckRunHandler( github_webhook=self, owners_file_handler=owners_file_handler ).process_pull_request_check_run_webhook_data(pull_request=pull_request) diff --git a/webhook_server/libs/handlers/owners_files_handler.py b/webhook_server/libs/handlers/owners_files_handler.py index 25b188960..d9d27b20a 100644 --- a/webhook_server/libs/handlers/owners_files_handler.py +++ b/webhook_server/libs/handlers/owners_files_handler.py @@ -28,7 +28,7 @@ def __init__(self, github_webhook: "GithubWebhook") -> None: self.log_prefix: str = self.github_webhook.log_prefix self.repository: Repository = self.github_webhook.repository - async def initialize(self, pull_request: PullRequest) -> "OwnersFileHandler": + async def initialize(self) -> "OwnersFileHandler": """Initialize handler with PR data (optimized with parallel operations). Phase 1: Fetch independent data in parallel (changed files + OWNERS data) @@ -37,7 +37,7 @@ async def initialize(self, pull_request: PullRequest) -> "OwnersFileHandler": # Phase 1: Parallel data fetching - independent GitHub API operations self.changed_files, self.all_repository_approvers_and_reviewers = await asyncio.gather( - self.list_changed_files(pull_request=pull_request), + self.list_changed_files(), self.get_all_repository_approvers_and_reviewers(), ) @@ -84,14 +84,15 @@ def allowed_users(self) -> list[str]: self.logger.debug(f"{self.log_prefix} ROOT allowed users: {_allowed_users}") return _allowed_users - async def list_changed_files(self, pull_request: PullRequest) -> list[str]: + async def list_changed_files(self) -> list[str]: """List changed files in the PR using git diff on cloned repository. Uses local git diff command instead of GitHub API to reduce API calls. The repository is already cloned to self.github_webhook.clone_repo_dir. - - Args: - pull_request: PyGithub PullRequest object + SHAs are read from the webhook payload to avoid race conditions with + live API calls (base branch may receive new commits between clone and API call). + If the PR is force-pushed between webhook delivery and processing, + _clone_repository() explicitly fetches payload SHAs to ensure they exist. Returns: List of changed file paths relative to repository root @@ -100,11 +101,11 @@ async def list_changed_files(self, pull_request: PullRequest) -> list[str]: RuntimeError: If git diff command fails asyncio.CancelledError: Propagates cancellation (never caught) """ - # Get base and head SHAs (wrap property accesses in github_api_call for retry support) - base_sha, head_sha = await asyncio.gather( - github_api_call(lambda: pull_request.base.sha, logger=self.logger, log_prefix=self.log_prefix), - github_api_call(lambda: pull_request.head.sha, logger=self.logger, log_prefix=self.log_prefix), - ) + # SHAs are stored on the GithubWebhook instance during process(): + # - From webhook payload for pull_request events (avoids race condition with live API) + # - From PullRequest object for other event types (issue_comment, check_run, etc.) + base_sha = self.github_webhook.pr_base_sha + head_sha = self.github_webhook.pr_head_sha # Run git diff command on cloned repository # Quote clone_repo_dir to handle paths with spaces or special characters diff --git a/webhook_server/tests/test_github_api.py b/webhook_server/tests/test_github_api.py index f3ba9d186..c0a1bf4a4 100644 --- a/webhook_server/tests/test_github_api.py +++ b/webhook_server/tests/test_github_api.py @@ -39,7 +39,7 @@ def pull_request_payload(self) -> dict[str, Any]: "number": 123, "title": "Test PR", "user": {"login": "testuser"}, - "base": {"ref": "main"}, + "base": {"ref": "main", "sha": "base123"}, "head": {"sha": "abc123"}, }, } @@ -419,6 +419,113 @@ async def test_process_issue_comment_event( await webhook.process() mock_process_comment.assert_called_once() + @pytest.mark.asyncio + async def test_pr_sha_storage_from_webhook_payload( + self, pull_request_payload: dict[str, Any], webhook_headers: Headers + ) -> None: + """Test that pr_base_sha/pr_head_sha are stored from webhook payload for pull_request events.""" + # Add base SHA to the payload (head SHA already present) + pull_request_payload["pull_request"]["base"]["sha"] = "base-sha-from-payload" + pull_request_payload["pull_request"]["head"]["sha"] = "head-sha-from-payload" + + with ( + patch.dict(os.environ, {"WEBHOOK_SERVER_DATA_DIR": "webhook_server/tests/manifests"}), + patch("webhook_server.libs.github_api.get_api_with_highest_rate_limit") as mock_api_rate_limit, + patch("webhook_server.libs.github_api.get_repository_github_app_api") as mock_repo_api, + patch("webhook_server.utils.helpers.get_apis_and_tokes_from_config") as mock_get_apis, + patch("webhook_server.libs.config.Config.repository_local_data") as mock_repo_local_data, + patch("webhook_server.libs.github_api.GithubWebhook.add_api_users_to_auto_verified_and_merged_users"), + ): + mock_api = Mock() + mock_api.rate_limiting = [100, 5000] + mock_user = Mock() + mock_user.login = "test-user" + mock_api.get_user.return_value = mock_user + mock_api_rate_limit.return_value = (mock_api, "TOKEN", "USER") + mock_repo_api.return_value = Mock() + mock_get_apis.return_value = [] + mock_repo_local_data.return_value = {} + + webhook = GithubWebhook(hook_data=pull_request_payload, headers=webhook_headers, logger=Mock()) + + mock_pr = Mock() + mock_pr.draft = False + mock_pr.user.login = "testuser" + mock_pr.base.ref = "main" + # These API SHAs should NOT be used (payload takes priority) + mock_pr.base.sha = "api-base-sha-should-not-be-used" + mock_pr.head.sha = "api-head-sha-should-not-be-used" + mock_commit = Mock() + mock_pr.get_commits.return_value = [mock_commit] + + with ( + patch.object(webhook, "get_pull_request", return_value=mock_pr), + patch.object(webhook, "_clone_repository", new=AsyncMock(return_value=None)), + patch.object(OwnersFileHandler, "initialize", new=AsyncMock(return_value=None)), + patch( + "webhook_server.libs.handlers.pull_request_handler.PullRequestHandler.process_pull_request_webhook_data", + new=AsyncMock(return_value=None), + ), + ): + await webhook.process() + + # Verify SHAs came from webhook payload, not live API + assert webhook.pr_base_sha == "base-sha-from-payload" + assert webhook.pr_head_sha == "head-sha-from-payload" + + @pytest.mark.asyncio + async def test_pr_sha_storage_fallback_for_non_pr_events(self, issue_comment_payload: dict[str, Any]) -> None: + """Test that pr_base_sha/pr_head_sha fall back to API for non-pull_request events. + + issue_comment payloads have no top-level 'pull_request' dict with SHAs, + so the code must fall back to the PullRequest object's base.sha/head.sha. + """ + with ( + patch.dict(os.environ, {"WEBHOOK_SERVER_DATA_DIR": "webhook_server/tests/manifests"}), + patch("webhook_server.libs.github_api.get_api_with_highest_rate_limit") as mock_api_rate_limit, + patch("webhook_server.libs.github_api.get_repository_github_app_api") as mock_repo_api, + patch("webhook_server.utils.helpers.get_apis_and_tokes_from_config") as mock_get_apis, + patch("webhook_server.libs.config.Config.repository_local_data") as mock_repo_local_data, + patch("webhook_server.libs.github_api.GithubWebhook.add_api_users_to_auto_verified_and_merged_users"), + ): + mock_api = Mock() + mock_api.rate_limiting = [100, 5000] + mock_user = Mock() + mock_user.login = "test-user" + mock_api.get_user.return_value = mock_user + mock_api_rate_limit.return_value = (mock_api, "TOKEN", "USER") + mock_repo_api.return_value = Mock() + mock_get_apis.return_value = [] + mock_repo_local_data.return_value = {} + + headers = Headers({"X-GitHub-Event": "issue_comment"}) + webhook = GithubWebhook(hook_data=issue_comment_payload, headers=headers, logger=Mock()) + + mock_pr = Mock() + mock_pr.draft = False + mock_pr.user.login = "testuser" + mock_pr.base.ref = "main" + # These API SHAs SHOULD be used (no payload SHAs for issue_comment) + mock_pr.base.sha = "api-base-sha-fallback" + mock_pr.head.sha = "api-head-sha-fallback" + mock_commit = Mock() + mock_pr.get_commits.return_value = [mock_commit] + + with ( + patch.object(webhook, "get_pull_request", return_value=mock_pr), + patch.object(webhook, "_clone_repository", new=AsyncMock(return_value=None)), + patch.object(OwnersFileHandler, "initialize", new=AsyncMock(return_value=None)), + patch( + "webhook_server.libs.handlers.issue_comment_handler.IssueCommentHandler.process_comment_webhook_data", + new=AsyncMock(return_value=None), + ), + ): + await webhook.process() + + # Verify SHAs came from API fallback (no payload SHAs for issue_comment) + assert webhook.pr_base_sha == "api-base-sha-fallback" + assert webhook.pr_head_sha == "api-head-sha-fallback" + @patch.dict(os.environ, {"WEBHOOK_SERVER_DATA_DIR": "webhook_server/tests/manifests"}) @patch("webhook_server.libs.github_api.get_repository_github_app_api") @patch("webhook_server.libs.github_api.get_api_with_highest_rate_limit") diff --git a/webhook_server/tests/test_owners_files_handler.py b/webhook_server/tests/test_owners_files_handler.py index 561b6d2f7..e8d7f5551 100644 --- a/webhook_server/tests/test_owners_files_handler.py +++ b/webhook_server/tests/test_owners_files_handler.py @@ -35,7 +35,7 @@ def owners_file_handler(self, mock_github_webhook: Mock) -> OwnersFileHandler: return OwnersFileHandler(mock_github_webhook) @pytest.mark.asyncio - async def test_initialize(self, owners_file_handler: OwnersFileHandler, mock_pull_request: Mock) -> None: + async def test_initialize(self, owners_file_handler: OwnersFileHandler) -> None: """Test the initialize method.""" with patch.object(owners_file_handler, "list_changed_files", new=AsyncMock()) as mock_list_files: with patch.object( @@ -60,7 +60,7 @@ async def test_initialize(self, owners_file_handler: OwnersFileHandler, mock_pul mock_get_pr_approvers.return_value = ["user1"] mock_get_pr_reviewers.return_value = ["user2"] - result = await owners_file_handler.initialize(mock_pull_request) + result = await owners_file_handler.initialize() assert result == owners_file_handler assert owners_file_handler.changed_files == ["file1.py", "file2.py"] @@ -87,13 +87,11 @@ async def test_ensure_initialized_initialized(self, owners_file_handler: OwnersF owners_file_handler._ensure_initialized() # Should not raise @pytest.mark.asyncio - async def test_list_changed_files( - self, owners_file_handler: OwnersFileHandler, mock_pull_request: Mock, tmp_path: Path - ) -> None: - """Test list_changed_files method using git diff.""" - # Set up mock PR SHAs - mock_pull_request.base.sha = "base123abc" - mock_pull_request.head.sha = "head456def" + async def test_list_changed_files(self, owners_file_handler: OwnersFileHandler, tmp_path: Path) -> None: + """Test list_changed_files reads SHAs from GithubWebhook instance.""" + # SHAs are stored on the GithubWebhook instance during process() + owners_file_handler.github_webhook.pr_base_sha = "base123abc" + owners_file_handler.github_webhook.pr_head_sha = "head456def" # Set up handler properties owners_file_handler.github_webhook.clone_repo_dir = str(tmp_path) @@ -105,7 +103,7 @@ async def test_list_changed_files( ) as mock_run_command: mock_run_command.return_value = (True, "file1.py\nfile2.py\n", "") - result = await owners_file_handler.list_changed_files(mock_pull_request) + result = await owners_file_handler.list_changed_files() # Verify result assert result == ["file1.py", "file2.py"]