diff --git a/src/git/src/mcp_server_git/server.py b/src/git/src/mcp_server_git/server.py index fe1e3f59..85f48e0e 100644 --- a/src/git/src/mcp_server_git/server.py +++ b/src/git/src/mcp_server_git/server.py @@ -39,6 +39,11 @@ class GitLog(BaseModel): repo_path: str max_count: int = 10 +class GitCreateBranch(BaseModel): + repo_path: str + branch_name: str + base_branch: str | None = None + class GitTools(str, Enum): STATUS = "git_status" DIFF_UNSTAGED = "git_diff_unstaged" @@ -47,6 +52,7 @@ class GitTools(str, Enum): ADD = "git_add" RESET = "git_reset" LOG = "git_log" + CREATE_BRANCH = "git_create_branch" def git_status(repo: git.Repo) -> str: return repo.git.status() @@ -81,6 +87,15 @@ def git_log(repo: git.Repo, max_count: int = 10) -> list[str]: ) return log +def git_create_branch(repo: git.Repo, branch_name: str, base_branch: str | None = None) -> str: + if base_branch: + base = repo.refs[base_branch] + else: + base = repo.active_branch + + repo.create_head(branch_name, base) + return f"Created branch '{branch_name}' from {base.name}" + async def serve(repository: Path | None) -> None: logger = logging.getLogger(__name__) @@ -132,6 +147,11 @@ async def serve(repository: Path | None) -> None: description="Shows the commit logs", inputSchema=GitLog.schema(), ), + Tool( + name=GitTools.CREATE_BRANCH, + description="Creates a new branch from an optional base branch", + inputSchema=GitCreateBranch.schema(), + ), ] async def list_repos() -> Sequence[str]: @@ -218,6 +238,17 @@ async def serve(repository: Path | None) -> None: text="Commit history:\n" + "\n".join(log) )] + case GitTools.CREATE_BRANCH: + result = git_create_branch( + repo, + arguments["branch_name"], + arguments.get("base_branch") + ) + return [TextContent( + type="text", + text=result + )] + case _: raise ValueError(f"Unknown tool: {name}")