diff --git a/src/git/src/mcp_server_git/__init__.py b/src/git/src/mcp_server_git/__init__.py index 2270018733..5502481bcd 100644 --- a/src/git/src/mcp_server_git/__init__.py +++ b/src/git/src/mcp_server_git/__init__.py @@ -4,6 +4,7 @@ import sys from .server import serve + @click.command() @click.option("--repository", "-r", type=Path, help="Git repository path") @click.option("-v", "--verbose", count=True) @@ -20,5 +21,6 @@ def main(repository: Path | None, verbose: bool) -> None: logging.basicConfig(level=logging_level, stream=sys.stderr) asyncio.run(serve(repository)) + if __name__ == "__main__": main() diff --git a/src/git/src/mcp_server_git/server.py b/src/git/src/mcp_server_git/server.py index 5ce953e545..4ea4bab4c1 100644 --- a/src/git/src/mcp_server_git/server.py +++ b/src/git/src/mcp_server_git/server.py @@ -20,60 +20,70 @@ # Default number of context lines to show in diff output DEFAULT_CONTEXT_LINES = 3 + class GitStatus(BaseModel): repo_path: str + class GitDiffUnstaged(BaseModel): repo_path: str context_lines: int = DEFAULT_CONTEXT_LINES + class GitDiffStaged(BaseModel): repo_path: str context_lines: int = DEFAULT_CONTEXT_LINES + class GitDiff(BaseModel): repo_path: str target: str context_lines: int = DEFAULT_CONTEXT_LINES + class GitCommit(BaseModel): repo_path: str message: str + class GitAdd(BaseModel): repo_path: str files: list[str] + class GitReset(BaseModel): repo_path: str + class GitLog(BaseModel): repo_path: str max_count: int = 10 start_timestamp: Optional[str] = Field( None, - description="Start timestamp for filtering commits. Accepts: ISO 8601 format (e.g., '2024-01-15T14:30:25'), relative dates (e.g., '2 weeks ago', 'yesterday'), or absolute dates (e.g., '2024-01-15', 'Jan 15 2024')" + description="Start timestamp for filtering commits. Accepts: ISO 8601 format (e.g., '2024-01-15T14:30:25'), relative dates (e.g., '2 weeks ago', 'yesterday'), or absolute dates (e.g., '2024-01-15', 'Jan 15 2024')", ) end_timestamp: Optional[str] = Field( None, - description="End timestamp for filtering commits. Accepts: ISO 8601 format (e.g., '2024-01-15T14:30:25'), relative dates (e.g., '2 weeks ago', 'yesterday'), or absolute dates (e.g., '2024-01-15', 'Jan 15 2024')" + description="End timestamp for filtering commits. Accepts: ISO 8601 format (e.g., '2024-01-15T14:30:25'), relative dates (e.g., '2 weeks ago', 'yesterday'), or absolute dates (e.g., '2024-01-15', 'Jan 15 2024')", ) + class GitCreateBranch(BaseModel): repo_path: str branch_name: str base_branch: str | None = None + class GitCheckout(BaseModel): repo_path: str branch_name: str + class GitShow(BaseModel): repo_path: str revision: str - class GitBranch(BaseModel): repo_path: str = Field( ..., @@ -108,16 +118,24 @@ class GitTools(str, Enum): BRANCH = "git_branch" + def git_status(repo: git.Repo) -> str: return repo.git.status() -def git_diff_unstaged(repo: git.Repo, context_lines: int = DEFAULT_CONTEXT_LINES) -> str: + +def git_diff_unstaged( + repo: git.Repo, context_lines: int = DEFAULT_CONTEXT_LINES +) -> str: return repo.git.diff(f"--unified={context_lines}") + def git_diff_staged(repo: git.Repo, context_lines: int = DEFAULT_CONTEXT_LINES) -> str: return repo.git.diff(f"--unified={context_lines}", "--cached") -def git_diff(repo: git.Repo, target: str, context_lines: int = DEFAULT_CONTEXT_LINES) -> str: + +def git_diff( + repo: git.Repo, target: str, context_lines: int = DEFAULT_CONTEXT_LINES +) -> str: # Defense in depth: reject targets starting with '-' to prevent flag injection, # even if a malicious ref with that name exists (e.g. via filesystem manipulation) if target.startswith("-"): @@ -125,10 +143,12 @@ def git_diff(repo: git.Repo, target: str, context_lines: int = DEFAULT_CONTEXT_L repo.rev_parse(target) # Validates target is a real git ref, throws BadName if not return repo.git.diff(f"--unified={context_lines}", target) + def git_commit(repo: git.Repo, message: str) -> str: commit = repo.index.commit(message) return f"Changes committed successfully with hash {commit.hexsha}" + def git_add(repo: git.Repo, files: list[str]) -> str: if files == ["."]: repo.git.add(".") @@ -137,26 +157,37 @@ def git_add(repo: git.Repo, files: list[str]) -> str: repo.git.add("--", *files) return "Files staged successfully" + def git_reset(repo: git.Repo) -> str: repo.index.reset() return "All staged changes reset" -def git_log(repo: git.Repo, max_count: int = 10, start_timestamp: Optional[str] = None, end_timestamp: Optional[str] = None) -> list[str]: + +def git_log( + repo: git.Repo, + max_count: int = 10, + start_timestamp: Optional[str] = None, + end_timestamp: Optional[str] = None, +) -> list[str]: if start_timestamp or end_timestamp: # Defense in depth: reject timestamps starting with '-' to prevent flag injection if start_timestamp and start_timestamp.startswith("-"): - raise ValueError(f"Invalid start_timestamp: '{start_timestamp}' - cannot start with '-'") + raise ValueError( + f"Invalid start_timestamp: '{start_timestamp}' - cannot start with '-'" + ) if end_timestamp and end_timestamp.startswith("-"): - raise ValueError(f"Invalid end_timestamp: '{end_timestamp}' - cannot start with '-'") + raise ValueError( + f"Invalid end_timestamp: '{end_timestamp}' - cannot start with '-'" + ) # Use git log command with date filtering args = [] if start_timestamp: - args.extend(['--since', start_timestamp]) + args.extend(["--since", start_timestamp]) if end_timestamp: - args.extend(['--until', end_timestamp]) - args.extend(['--format=%H%n%an%n%ad%n%s%n']) + args.extend(["--until", end_timestamp]) + args.extend(["--format=%H%n%an%n%ad%n%s%n"]) - log_output = repo.git.log(*args).split('\n') + log_output = repo.git.log(*args).split("\n") log = [] # Process commits in groups of 4 (hash, author, date, message) @@ -182,7 +213,10 @@ def git_log(repo: git.Repo, max_count: int = 10, start_timestamp: Optional[str] ) return log -def git_create_branch(repo: git.Repo, branch_name: str, base_branch: str | None = None) -> str: + +def git_create_branch( + repo: git.Repo, branch_name: str, base_branch: str | None = None +) -> str: # Defense in depth: reject names starting with '-' to prevent flag injection if branch_name.startswith("-"): raise BadName(f"Invalid branch name: '{branch_name}' - cannot start with '-'") @@ -196,17 +230,19 @@ def git_create_branch(repo: git.Repo, branch_name: str, base_branch: str | None repo.create_head(branch_name, base) return f"Created branch '{branch_name}' from '{base.name}'" + def git_checkout(repo: git.Repo, branch_name: str) -> str: # Defense in depth: reject branch names starting with '-' to prevent flag injection, # even if a malicious ref with that name exists (e.g. via filesystem manipulation) if branch_name.startswith("-"): raise BadName(f"Invalid branch name: '{branch_name}' - cannot start with '-'") - repo.rev_parse(branch_name) # Validates branch_name is a real git ref, throws BadName if not + repo.rev_parse( + branch_name + ) # Validates branch_name is a real git ref, throws BadName if not repo.git.checkout(branch_name) return f"Switched to branch '{branch_name}'" - def git_show(repo: git.Repo, revision: str) -> str: # Defense in depth: reject revisions starting with '-' to prevent flag injection, # even if a malicious ref with that name exists (e.g. via filesystem manipulation) @@ -229,11 +265,12 @@ def git_show(repo: git.Repo, revision: str) -> str: if d.diff is None: continue if isinstance(d.diff, bytes): - output.append(d.diff.decode('utf-8')) + output.append(d.diff.decode("utf-8")) else: output.append(d.diff) return "".join(output) + def validate_repo_path(repo_path: Path, allowed_repository: Path | None) -> None: """Validate that repo_path is within the allowed repository path.""" if allowed_repository is None: @@ -255,12 +292,19 @@ def validate_repo_path(repo_path: Path, allowed_repository: Path | None) -> None ) -def git_branch(repo: git.Repo, branch_type: str, contains: str | None = None, not_contains: str | None = None) -> str: +def git_branch( + repo: git.Repo, + branch_type: str, + contains: str | None = None, + not_contains: str | None = None, +) -> str: # Defense in depth: reject values starting with '-' to prevent flag injection if contains and contains.startswith("-"): raise BadName(f"Invalid contains value: '{contains}' - cannot start with '-'") if not_contains and not_contains.startswith("-"): - raise BadName(f"Invalid not_contains value: '{not_contains}' - cannot start with '-'") + raise BadName( + f"Invalid not_contains value: '{not_contains}' - cannot start with '-'" + ) match contains: case None: @@ -275,11 +319,11 @@ def git_branch(repo: git.Repo, branch_type: str, contains: str | None = None, no not_contains_sha = ("--no-contains", not_contains) match branch_type: - case 'local': + case "local": b_type = None - case 'remote': + case "remote": b_type = "-r" - case 'all': + case "all": b_type = "-a" case _: return f"Invalid branch type: {branch_type}" @@ -437,20 +481,24 @@ async def list_tools() -> list[Tool]: idempotentHint=True, openWorldHint=False, ), - ) + ), ] async def list_repos() -> Sequence[str]: async def by_roots() -> Sequence[str]: if not isinstance(server.request_context.session, ServerSession): - raise TypeError("server.request_context.session must be a ServerSession") + raise TypeError( + "server.request_context.session must be a ServerSession" + ) if not server.request_context.session.check_client_capability( ClientCapabilities(roots=RootsCapability()) ): return [] - roots_result: ListRootsResult = await server.request_context.session.list_roots() + roots_result: ListRootsResult = ( + await server.request_context.session.list_roots() + ) logger.debug(f"Roots result: {roots_result}") repo_paths = [] for root in roots_result.roots: @@ -482,52 +530,43 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]: match name: case GitTools.STATUS: status = git_status(repo) - return [TextContent( - type="text", - text=f"Repository status:\n{status}" - )] + return [TextContent(type="text", text=f"Repository status:\n{status}")] case GitTools.DIFF_UNSTAGED: - diff = git_diff_unstaged(repo, arguments.get("context_lines", DEFAULT_CONTEXT_LINES)) - return [TextContent( - type="text", - text=f"Unstaged changes:\n{diff}" - )] + diff = git_diff_unstaged( + repo, arguments.get("context_lines", DEFAULT_CONTEXT_LINES) + ) + return [TextContent(type="text", text=f"Unstaged changes:\n{diff}")] case GitTools.DIFF_STAGED: - diff = git_diff_staged(repo, arguments.get("context_lines", DEFAULT_CONTEXT_LINES)) - return [TextContent( - type="text", - text=f"Staged changes:\n{diff}" - )] + diff = git_diff_staged( + repo, arguments.get("context_lines", DEFAULT_CONTEXT_LINES) + ) + return [TextContent(type="text", text=f"Staged changes:\n{diff}")] case GitTools.DIFF: - diff = git_diff(repo, arguments["target"], arguments.get("context_lines", DEFAULT_CONTEXT_LINES)) - return [TextContent( - type="text", - text=f"Diff with {arguments['target']}:\n{diff}" - )] + diff = git_diff( + repo, + arguments["target"], + arguments.get("context_lines", DEFAULT_CONTEXT_LINES), + ) + return [ + TextContent( + type="text", text=f"Diff with {arguments['target']}:\n{diff}" + ) + ] case GitTools.COMMIT: result = git_commit(repo, arguments["message"]) - return [TextContent( - type="text", - text=result - )] + return [TextContent(type="text", text=result)] case GitTools.ADD: result = git_add(repo, arguments["files"]) - return [TextContent( - type="text", - text=result - )] + return [TextContent(type="text", text=result)] case GitTools.RESET: result = git_reset(repo) - return [TextContent( - type="text", - text=result - )] + return [TextContent(type="text", text=result)] # Update the LOG case: case GitTools.LOG: @@ -535,49 +574,34 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]: repo, arguments.get("max_count", 10), arguments.get("start_timestamp"), - arguments.get("end_timestamp") + arguments.get("end_timestamp"), ) - return [TextContent( - type="text", - text="Commit history:\n" + "\n".join(log) - )] + return [ + TextContent(type="text", text="Commit history:\n" + "\n".join(log)) + ] case GitTools.CREATE_BRANCH: result = git_create_branch( - repo, - arguments["branch_name"], - arguments.get("base_branch") + repo, arguments["branch_name"], arguments.get("base_branch") ) - return [TextContent( - type="text", - text=result - )] + return [TextContent(type="text", text=result)] case GitTools.CHECKOUT: result = git_checkout(repo, arguments["branch_name"]) - return [TextContent( - type="text", - text=result - )] + return [TextContent(type="text", text=result)] case GitTools.SHOW: result = git_show(repo, arguments["revision"]) - return [TextContent( - type="text", - text=result - )] + return [TextContent(type="text", text=result)] case GitTools.BRANCH: result = git_branch( repo, - arguments.get("branch_type", 'local'), + arguments.get("branch_type", "local"), arguments.get("contains", None), arguments.get("not_contains", None), ) - return [TextContent( - type="text", - text=result - )] + return [TextContent(type="text", text=result)] case _: raise ValueError(f"Unknown tool: {name}") diff --git a/src/git/tests/test_server.py b/src/git/tests/test_server.py index a5492adc85..b294141267 100644 --- a/src/git/tests/test_server.py +++ b/src/git/tests/test_server.py @@ -19,6 +19,7 @@ ) import shutil + @pytest.fixture def test_repository(tmp_path: Path): repo_path = tmp_path / "temp_test_repo" @@ -30,7 +31,17 @@ def test_repository(tmp_path: Path): yield test_repo - shutil.rmtree(repo_path) + test_repo.close() + + def remove_readonly(func, path, excinfo): + import stat + import os + + os.chmod(path, stat.S_IWRITE) + func(path) + + shutil.rmtree(repo_path, onerror=remove_readonly) + def test_git_checkout_existing_branch(test_repository): test_repository.git.branch("test-branch") @@ -39,31 +50,37 @@ def test_git_checkout_existing_branch(test_repository): assert "Switched to branch 'test-branch'" in result assert test_repository.active_branch.name == "test-branch" -def test_git_checkout_nonexistent_branch(test_repository): +def test_git_checkout_nonexistent_branch(test_repository): with pytest.raises(BadName): git_checkout(test_repository, "nonexistent-branch") + def test_git_branch_local(test_repository): test_repository.git.branch("new-branch-local") result = git_branch(test_repository, "local") assert "new-branch-local" in result + def test_git_branch_remote(test_repository): result = git_branch(test_repository, "remote") assert "" == result.strip() # Should be empty if no remote branches + def test_git_branch_all(test_repository): test_repository.git.branch("new-branch-all") result = git_branch(test_repository, "all") assert "new-branch-all" in result + def test_git_branch_contains(test_repository): # Get the default branch name (could be "main" or "master") default_branch = test_repository.active_branch.name # Create a new branch and commit to it test_repository.git.checkout("-b", "feature-branch") - Path(test_repository.working_dir / Path("feature.txt")).write_text("feature content") + Path(test_repository.working_dir / Path("feature.txt")).write_text( + "feature content" + ) test_repository.index.add(["feature.txt"]) commit = test_repository.index.commit("feature commit") test_repository.git.checkout(default_branch) @@ -72,12 +89,15 @@ def test_git_branch_contains(test_repository): assert "feature-branch" in result assert default_branch not in result + def test_git_branch_not_contains(test_repository): # Get the default branch name (could be "main" or "master") default_branch = test_repository.active_branch.name # Create a new branch and commit to it test_repository.git.checkout("-b", "another-feature-branch") - Path(test_repository.working_dir / Path("another_feature.txt")).write_text("another feature content") + Path(test_repository.working_dir / Path("another_feature.txt")).write_text( + "another feature content" + ) test_repository.index.add(["another_feature.txt"]) commit = test_repository.index.commit("another feature commit") test_repository.git.checkout(default_branch) @@ -86,6 +106,7 @@ def test_git_branch_not_contains(test_repository): assert "another-feature-branch" not in result assert default_branch in result + def test_git_add_all_files(test_repository): file_path = Path(test_repository.working_dir) / "all_file.txt" file_path.write_text("adding all") @@ -96,6 +117,7 @@ def test_git_add_all_files(test_repository): assert "all_file.txt" in staged_files assert result == "Files staged successfully" + def test_git_add_specific_files(test_repository): file1 = Path(test_repository.working_dir) / "file1.txt" file2 = Path(test_repository.working_dir) / "file2.txt" @@ -109,12 +131,14 @@ def test_git_add_specific_files(test_repository): assert "file2.txt" not in staged_files assert result == "Files staged successfully" + def test_git_status(test_repository): result = git_status(test_repository) assert result is not None assert "On branch" in result or "branch" in result.lower() + def test_git_diff_unstaged(test_repository): file_path = Path(test_repository.working_dir) / "test.txt" file_path.write_text("modified content") @@ -124,11 +148,13 @@ def test_git_diff_unstaged(test_repository): assert "test.txt" in result assert "modified content" in result + def test_git_diff_unstaged_empty(test_repository): result = git_diff_unstaged(test_repository) assert result == "" + def test_git_diff_staged(test_repository): file_path = Path(test_repository.working_dir) / "staged_file.txt" file_path.write_text("staged content") @@ -139,11 +165,13 @@ def test_git_diff_staged(test_repository): assert "staged_file.txt" in result assert "staged content" in result + def test_git_diff_staged_empty(test_repository): result = git_diff_staged(test_repository) assert result == "" + def test_git_diff(test_repository): # Get the default branch name (could be "main" or "master") default_branch = test_repository.active_branch.name @@ -158,6 +186,7 @@ def test_git_diff(test_repository): assert "test.txt" in result assert "feature changes" in result + def test_git_commit(test_repository): file_path = Path(test_repository.working_dir) / "commit_test.txt" file_path.write_text("content to commit") @@ -170,6 +199,7 @@ def test_git_commit(test_repository): latest_commit = test_repository.head.commit assert latest_commit.message.strip() == "test commit message" + def test_git_reset(test_repository): file_path = Path(test_repository.working_dir) / "reset_test.txt" file_path.write_text("content to reset") @@ -185,6 +215,7 @@ def test_git_reset(test_repository): staged_after = [item.a_path for item in test_repository.index.diff("HEAD")] assert "reset_test.txt" not in staged_after + def test_git_log(test_repository): for i in range(3): file_path = Path(test_repository.working_dir) / f"log_test_{i}.txt" @@ -201,6 +232,7 @@ def test_git_log(test_repository): assert "Date:" in result[0] assert "Message:" in result[0] + def test_git_log_default(test_repository): result = git_log(test_repository) @@ -208,6 +240,7 @@ def test_git_log_default(test_repository): assert len(result) >= 1 assert "initial commit" in result[0] + def test_git_create_branch(test_repository): result = git_create_branch(test_repository, "new-feature-branch") @@ -216,6 +249,7 @@ def test_git_create_branch(test_repository): branches = [ref.name for ref in test_repository.references] assert "new-feature-branch" in branches + def test_git_create_branch_from_base(test_repository): test_repository.git.checkout("-b", "base-branch") file_path = Path(test_repository.working_dir) / "base.txt" @@ -227,6 +261,7 @@ def test_git_create_branch_from_base(test_repository): assert "Created branch 'derived-branch' from 'base-branch'" in result + def test_git_show(test_repository): file_path = Path(test_repository.working_dir) / "show_test.txt" file_path.write_text("show content") @@ -242,6 +277,7 @@ def test_git_show(test_repository): assert "show test commit" in result assert "show_test.txt" in result + def test_git_show_initial_commit(test_repository): initial_commit = list(test_repository.iter_commits())[-1] @@ -254,6 +290,7 @@ def test_git_show_initial_commit(test_repository): # Tests for validate_repo_path (repository scoping security fix) + def test_validate_repo_path_no_restriction(): """When no repository restriction is configured, any path should be allowed.""" validate_repo_path(Path("/any/path"), None) # Should not raise @@ -313,8 +350,11 @@ def test_validate_repo_path_symlink_escape(tmp_path: Path): with pytest.raises(ValueError) as exc_info: validate_repo_path(symlink, allowed) assert "outside the allowed repository" in str(exc_info.value) + + # Tests for argument injection protection + def test_git_diff_rejects_flag_injection(test_repository): """git_diff should reject flags that could be used for argument injection.""" with pytest.raises(BadName): @@ -429,6 +469,7 @@ def test_git_checkout_rejects_malicious_refs(test_repository): # git_log, and git_branch — matching the existing guards on git_diff and # git_checkout. + def test_git_show_rejects_flag_injection(test_repository): """git_show should reject revisions starting with '-'.""" with pytest.raises(BadName):