feat: Add git branch functionality and unit tests

This commit introduces the `git branch` tool to the MCP Git server, allowing users to list branches with various filtering options.

Changes include:
- Implemented `git_branch` function in `src/git/src/mcp_server_git/server.py` to support listing local, remote, and all branches, as well as filtering by `contains` and `not_contains` SHA values.
- Added comprehensive unit tests for the `git branch` functionality in `src/git/tests/test_server.py`, covering different branch types and commit filtering scenarios.
- Updated `src/git/README.md`.
This commit is contained in:
JavieHush
2025-05-28 16:18:42 +08:00
parent 6f920efa8a
commit e4d856214b
3 changed files with 117 additions and 4 deletions

View File

@@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Sequence
from typing import Sequence, Optional
from mcp.server import Server
from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server
@@ -13,7 +13,7 @@ from mcp.types import (
)
from enum import Enum
import git
from pydantic import BaseModel
from pydantic import BaseModel, Field
class GitStatus(BaseModel):
repo_path: str
@@ -59,6 +59,24 @@ class GitShow(BaseModel):
class GitInit(BaseModel):
repo_path: str
class GitBranch(BaseModel):
repo_path: str = Field(
...,
description="The path to the Git repository.",
)
branch_type: str = Field(
...,
description="Whether to list local branches ('local'), remote branches ('remote') or all branches('all').",
)
contains: Optional[str] = Field(
None,
description="The commit sha that branch should contain. Do not pass anything to this param if no commit sha is specified",
)
not_contains: Optional[str] = Field(
None,
description="The commit sha that branch should NOT contain. Do not pass anything to this param if no commit sha is specified",
)
class GitTools(str, Enum):
STATUS = "git_status"
DIFF_UNSTAGED = "git_diff_unstaged"
@@ -72,6 +90,7 @@ class GitTools(str, Enum):
CHECKOUT = "git_checkout"
SHOW = "git_show"
INIT = "git_init"
BRANCH = "git_branch"
def git_status(repo: git.Repo) -> str:
return repo.git.status()
@@ -147,6 +166,34 @@ def git_show(repo: git.Repo, revision: str) -> str:
output.append(d.diff.decode('utf-8'))
return "".join(output)
def git_branch(repo: git.Repo, branch_type: str, contains: str | None = None, not_contains: str | None = None) -> str:
match contains:
case None:
contains_sha = (None,)
case _:
contains_sha = ("--contains", contains)
match not_contains:
case None:
not_contains_sha = (None,)
case _:
not_contains_sha = ("--no-contains", not_contains)
match branch_type:
case 'local':
b_type = None
case 'remote':
b_type = "-r"
case 'all':
b_type = "-a"
case _:
return f"Invalid branch type: {branch_type}"
# None value will be auto deleted by GitPython
branch_info = repo.git.branch(b_type, *contains_sha, *not_contains_sha)
return branch_info
async def serve(repository: Path | None) -> None:
logger = logging.getLogger(__name__)
@@ -222,6 +269,11 @@ async def serve(repository: Path | None) -> None:
name=GitTools.INIT,
description="Initialize a new Git repository",
inputSchema=GitInit.schema(),
),
Tool(
name=GitTools.BRANCH,
description="List Git branches",
inputSchema=GitBranch.schema(),
)
]
@@ -351,6 +403,18 @@ async def serve(repository: Path | None) -> None:
text=result
)]
case GitTools.BRANCH:
result = git_branch(
repo,
arguments.get("branch_type", 'local'),
arguments.get("contains", None),
arguments.get("not_contains", None),
)
return [TextContent(
type="text",
text=result
)]
case _:
raise ValueError(f"Unknown tool: {name}")