feat(git): add git_diff tool for branch comparison

Added new git_diff tool to allow comparison between branches or commits.
This adds the ability to compare branches directly through the MCP interface.
This commit is contained in:
monkeydaichan
2024-12-07 01:25:33 +09:00
parent 2ecb382a02
commit ba301c4a66
2 changed files with 38 additions and 7 deletions

View File

@@ -24,6 +24,10 @@ class GitDiffUnstaged(BaseModel):
class GitDiffStaged(BaseModel):
repo_path: str
class GitDiff(BaseModel):
repo_path: str
target: str
class GitCommit(BaseModel):
repo_path: str
message: str
@@ -48,6 +52,7 @@ class GitTools(str, Enum):
STATUS = "git_status"
DIFF_UNSTAGED = "git_diff_unstaged"
DIFF_STAGED = "git_diff_staged"
DIFF = "git_diff"
COMMIT = "git_commit"
ADD = "git_add"
RESET = "git_reset"
@@ -63,6 +68,9 @@ def git_diff_unstaged(repo: git.Repo) -> str:
def git_diff_staged(repo: git.Repo) -> str:
return repo.git.diff("--cached")
def git_diff(repo: git.Repo, target: str) -> str:
return repo.git.diff(target)
def git_commit(repo: git.Repo, message: str) -> str:
commit = repo.index.commit(message)
return f"Changes committed successfully with hash {commit.hexsha}"
@@ -127,6 +135,11 @@ async def serve(repository: Path | None) -> None:
description="Shows changes that are staged for commit",
inputSchema=GitDiffStaged.schema(),
),
Tool(
name=GitTools.DIFF,
description="Shows differences between branches or commits",
inputSchema=GitDiff.schema(),
),
Tool(
name=GitTools.COMMIT,
description="Records changes to the repository",
@@ -210,6 +223,13 @@ async def serve(repository: Path | None) -> None:
text=f"Staged changes:\n{diff}"
)]
case GitTools.DIFF:
diff = git_diff(repo, arguments["target"])
return [TextContent(
type="text",
text=f"Diff with {arguments['target']}:\n{diff}"
)]
case GitTools.COMMIT:
result = git_commit(repo, arguments["message"])
return [TextContent(
@@ -254,4 +274,4 @@ async def serve(repository: Path | None) -> None:
options = server.create_initialization_options()
async with stdio_server() as (read_stream, write_stream):
await server.run(read_stream, write_stream, options, raise_exceptions=True)
await server.run(read_stream, write_stream, options, raise_exceptions=True)