mirror of
https://github.com/ollama/ollama.git
synced 2026-04-27 03:05:43 +02:00
Compare commits
6 Commits
launch-cop
...
v0.21.0-rc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b9cb535407 | ||
|
|
031baef094 | ||
|
|
7d271e6dc9 | ||
|
|
c88dae2d6b | ||
|
|
9e3618d663 | ||
|
|
e585ecd11f |
@@ -55,7 +55,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
|
|||||||
ollama
|
ollama
|
||||||
```
|
```
|
||||||
|
|
||||||
You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `claude`, `codex`, `openclaw` and more.
|
You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `Claude Code`, `OpenClaw`, `OpenCode` , `Codex`, `Copilot`, and more.
|
||||||
|
|
||||||
### Coding
|
### Coding
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ To launch a specific integration:
|
|||||||
ollama launch claude
|
ollama launch claude
|
||||||
```
|
```
|
||||||
|
|
||||||
Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
|
Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Copilot CLI](https://docs.ollama.com/integrations/copilot-cli), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
|
||||||
|
|
||||||
### AI assistant
|
### AI assistant
|
||||||
|
|
||||||
|
|||||||
76
cmd/launch/copilot.go
Normal file
76
cmd/launch/copilot.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Copilot implements Runner for GitHub Copilot CLI integration.
|
||||||
|
type Copilot struct{}
|
||||||
|
|
||||||
|
func (c *Copilot) String() string { return "Copilot CLI" }
|
||||||
|
|
||||||
|
func (c *Copilot) args(model string, extra []string) []string {
|
||||||
|
var args []string
|
||||||
|
if model != "" {
|
||||||
|
args = append(args, "--model", model)
|
||||||
|
}
|
||||||
|
args = append(args, extra...)
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Copilot) findPath() (string, error) {
|
||||||
|
if p, err := exec.LookPath("copilot"); err == nil {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
name := "copilot"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "copilot.exe"
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(home, ".local", "bin", name)
|
||||||
|
if _, err := os.Stat(fallback); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fallback, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Copilot) Run(model string, args []string) error {
|
||||||
|
copilotPath, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("copilot is not installed, install from https://docs.github.com/en/copilot/how-tos/set-up/install-copilot-cli")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(copilotPath, c.args(model, args)...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
cmd.Env = append(os.Environ(), c.envVars(model)...)
|
||||||
|
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// envVars returns the environment variables that configure Copilot CLI
|
||||||
|
// to use Ollama as its model provider.
|
||||||
|
func (c *Copilot) envVars(model string) []string {
|
||||||
|
env := []string{
|
||||||
|
"COPILOT_PROVIDER_BASE_URL=" + envconfig.Host().String() + "/v1",
|
||||||
|
"COPILOT_PROVIDER_API_KEY=",
|
||||||
|
"COPILOT_PROVIDER_WIRE_API=responses",
|
||||||
|
}
|
||||||
|
|
||||||
|
if model != "" {
|
||||||
|
env = append(env, "COPILOT_MODEL="+model)
|
||||||
|
}
|
||||||
|
|
||||||
|
return env
|
||||||
|
}
|
||||||
161
cmd/launch/copilot_test.go
Normal file
161
cmd/launch/copilot_test.go
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCopilotIntegration(t *testing.T) {
|
||||||
|
c := &Copilot{}
|
||||||
|
|
||||||
|
t.Run("String", func(t *testing.T) {
|
||||||
|
if got := c.String(); got != "Copilot CLI" {
|
||||||
|
t.Errorf("String() = %q, want %q", got, "Copilot CLI")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Runner", func(t *testing.T) {
|
||||||
|
var _ Runner = c
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopilotFindPath(t *testing.T) {
|
||||||
|
c := &Copilot{}
|
||||||
|
|
||||||
|
t.Run("finds copilot in PATH", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
name := "copilot"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "copilot.exe"
|
||||||
|
}
|
||||||
|
fakeBin := filepath.Join(tmpDir, name)
|
||||||
|
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
|
||||||
|
got, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != fakeBin {
|
||||||
|
t.Errorf("findPath() = %q, want %q", got, fakeBin)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error when not in PATH", func(t *testing.T) {
|
||||||
|
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||||
|
|
||||||
|
_, err := c.findPath()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to ~/.local/bin/copilot", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||||
|
|
||||||
|
name := "copilot"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "copilot.exe"
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(tmpDir, ".local", "bin", name)
|
||||||
|
os.MkdirAll(filepath.Dir(fallback), 0o755)
|
||||||
|
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
|
||||||
|
|
||||||
|
got, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != fallback {
|
||||||
|
t.Errorf("findPath() = %q, want %q", got, fallback)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||||
|
|
||||||
|
_, err := c.findPath()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopilotArgs(t *testing.T) {
|
||||||
|
c := &Copilot{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
args []string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||||
|
{"empty model", "", nil, nil},
|
||||||
|
{"with model and extra", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
|
||||||
|
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := c.args(tt.model, tt.args)
|
||||||
|
if !slices.Equal(got, tt.want) {
|
||||||
|
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopilotEnvVars(t *testing.T) {
|
||||||
|
c := &Copilot{}
|
||||||
|
|
||||||
|
envMap := func(envs []string) map[string]string {
|
||||||
|
m := make(map[string]string)
|
||||||
|
for _, e := range envs {
|
||||||
|
k, v, _ := strings.Cut(e, "=")
|
||||||
|
m[k] = v
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("sets required provider env vars with model", func(t *testing.T) {
|
||||||
|
got := envMap(c.envVars("llama3.2"))
|
||||||
|
if got["COPILOT_PROVIDER_BASE_URL"] == "" {
|
||||||
|
t.Error("COPILOT_PROVIDER_BASE_URL should be set")
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(got["COPILOT_PROVIDER_BASE_URL"], "/v1") {
|
||||||
|
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want /v1 suffix", got["COPILOT_PROVIDER_BASE_URL"])
|
||||||
|
}
|
||||||
|
if _, ok := got["COPILOT_PROVIDER_API_KEY"]; !ok {
|
||||||
|
t.Error("COPILOT_PROVIDER_API_KEY should be set (empty)")
|
||||||
|
}
|
||||||
|
if got["COPILOT_PROVIDER_WIRE_API"] != "responses" {
|
||||||
|
t.Errorf("COPILOT_PROVIDER_WIRE_API = %q, want %q", got["COPILOT_PROVIDER_WIRE_API"], "responses")
|
||||||
|
}
|
||||||
|
if got["COPILOT_MODEL"] != "llama3.2" {
|
||||||
|
t.Errorf("COPILOT_MODEL = %q, want %q", got["COPILOT_MODEL"], "llama3.2")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("omits COPILOT_MODEL when model is empty", func(t *testing.T) {
|
||||||
|
got := envMap(c.envVars(""))
|
||||||
|
if _, ok := got["COPILOT_MODEL"]; ok {
|
||||||
|
t.Errorf("COPILOT_MODEL should not be set for empty model, got %q", got["COPILOT_MODEL"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", "http://myhost:9999")
|
||||||
|
got := envMap(c.envVars("test"))
|
||||||
|
if !strings.Contains(got["COPILOT_PROVIDER_BASE_URL"], "myhost:9999") {
|
||||||
|
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want custom host", got["COPILOT_PROVIDER_BASE_URL"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -206,6 +206,7 @@ Supported integrations:
|
|||||||
claude Claude Code
|
claude Claude Code
|
||||||
cline Cline
|
cline Cline
|
||||||
codex Codex
|
codex Codex
|
||||||
|
copilot Copilot CLI (aliases: copilot-cli)
|
||||||
droid Droid
|
droid Droid
|
||||||
hermes Hermes Agent
|
hermes Hermes Agent
|
||||||
opencode OpenCode
|
opencode OpenCode
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ type IntegrationInfo struct {
|
|||||||
Description string
|
Description string
|
||||||
}
|
}
|
||||||
|
|
||||||
var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "droid", "pi"}
|
var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi"}
|
||||||
|
|
||||||
var integrationSpecs = []*IntegrationSpec{
|
var integrationSpecs = []*IntegrationSpec{
|
||||||
{
|
{
|
||||||
@@ -74,6 +74,19 @@ var integrationSpecs = []*IntegrationSpec{
|
|||||||
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "copilot",
|
||||||
|
Runner: &Copilot{},
|
||||||
|
Aliases: []string{"copilot-cli"},
|
||||||
|
Description: "GitHub's AI coding agent for the terminal",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := (&Copilot{}).findPath()
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://github.com/features/copilot/cli/",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "droid",
|
Name: "droid",
|
||||||
Runner: &Droid{},
|
Runner: &Droid{},
|
||||||
|
|||||||
@@ -120,6 +120,7 @@
|
|||||||
"pages": [
|
"pages": [
|
||||||
"/integrations/claude-code",
|
"/integrations/claude-code",
|
||||||
"/integrations/codex",
|
"/integrations/codex",
|
||||||
|
"/integrations/copilot-cli",
|
||||||
"/integrations/opencode",
|
"/integrations/opencode",
|
||||||
"/integrations/droid",
|
"/integrations/droid",
|
||||||
"/integrations/goose",
|
"/integrations/goose",
|
||||||
|
|||||||
93
docs/integrations/copilot-cli.mdx
Normal file
93
docs/integrations/copilot-cli.mdx
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
---
|
||||||
|
title: Copilot CLI
|
||||||
|
---
|
||||||
|
|
||||||
|
GitHub Copilot CLI is GitHub's AI coding agent for the terminal. It can understand your codebase, make edits, run commands, and help you build software faster.
|
||||||
|
|
||||||
|
Open models can be used with Copilot CLI through Ollama, enabling you to use models such as `qwen3.5`, `glm-5.1:cloud`, `kimi-k2.5:cloud`.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
Install [Copilot CLI](https://github.com/features/copilot/cli/):
|
||||||
|
|
||||||
|
<CodeGroup>
|
||||||
|
|
||||||
|
```shell macOS / Linux (Homebrew)
|
||||||
|
brew install copilot-cli
|
||||||
|
```
|
||||||
|
|
||||||
|
```shell npm (all platforms)
|
||||||
|
npm install -g @github/copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
```shell macOS / Linux (script)
|
||||||
|
curl -fsSL https://gh.io/copilot-install | bash
|
||||||
|
```
|
||||||
|
|
||||||
|
```powershell Windows (WinGet)
|
||||||
|
winget install GitHub.Copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
</CodeGroup>
|
||||||
|
|
||||||
|
## Usage with Ollama
|
||||||
|
|
||||||
|
### Quick setup
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run directly with a model
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch copilot --model kimi-k2.5:cloud
|
||||||
|
```
|
||||||
|
|
||||||
|
## Recommended Models
|
||||||
|
|
||||||
|
- `kimi-k2.5:cloud`
|
||||||
|
- `glm-5:cloud`
|
||||||
|
- `minimax-m2.7:cloud`
|
||||||
|
- `qwen3.5:cloud`
|
||||||
|
- `glm-4.7-flash`
|
||||||
|
- `qwen3.5`
|
||||||
|
|
||||||
|
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||||
|
|
||||||
|
## Non-interactive (headless) mode
|
||||||
|
|
||||||
|
Run Copilot CLI without interaction for use in Docker, CI/CD, or scripts:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch copilot --model kimi-k2.5:cloud --yes -- -p "how does this repository work?"
|
||||||
|
```
|
||||||
|
|
||||||
|
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified. Arguments after `--` are passed directly to Copilot CLI.
|
||||||
|
|
||||||
|
## Manual setup
|
||||||
|
|
||||||
|
Copilot CLI connects to Ollama using the OpenAI-compatible API via environment variables.
|
||||||
|
|
||||||
|
1. Set the environment variables:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
export COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1
|
||||||
|
export COPILOT_PROVIDER_API_KEY=
|
||||||
|
export COPILOT_PROVIDER_WIRE_API=responses
|
||||||
|
export COPILOT_MODEL=qwen3.5
|
||||||
|
```
|
||||||
|
|
||||||
|
1. Run Copilot CLI:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
Or run with environment variables inline:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1 COPILOT_PROVIDER_API_KEY= COPILOT_PROVIDER_WIRE_API=responses COPILOT_MODEL=glm-5:cloud copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note:** Copilot requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||||
@@ -10,6 +10,7 @@ Coding assistants that can read, modify, and execute code in your projects.
|
|||||||
|
|
||||||
- [Claude Code](/integrations/claude-code)
|
- [Claude Code](/integrations/claude-code)
|
||||||
- [Codex](/integrations/codex)
|
- [Codex](/integrations/codex)
|
||||||
|
- [Copilot CLI](/integrations/copilot-cli)
|
||||||
- [OpenCode](/integrations/opencode)
|
- [OpenCode](/integrations/opencode)
|
||||||
- [Droid](/integrations/droid)
|
- [Droid](/integrations/droid)
|
||||||
- [Goose](/integrations/goose)
|
- [Goose](/integrations/goose)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
// <|tool_call>/<|tool_response> tags for function calling.
|
// <|tool_call>/<|tool_response> tags for function calling.
|
||||||
type Gemma4Renderer struct {
|
type Gemma4Renderer struct {
|
||||||
useImgTags bool
|
useImgTags bool
|
||||||
|
emptyBlockOnNothink bool
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -124,6 +125,9 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
|||||||
// Generation prompt.
|
// Generation prompt.
|
||||||
if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
|
if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
|
||||||
sb.WriteString("<|turn>model\n")
|
sb.WriteString("<|turn>model\n")
|
||||||
|
if r.emptyBlockOnNothink && !hasThink {
|
||||||
|
sb.WriteString("<|channel>thought\n<channel|>")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ package renderers
|
|||||||
// TestGemma4RendererMatchesReference verifies our renderer matches the checked-in
|
// TestGemma4RendererMatchesReference verifies our renderer matches the checked-in
|
||||||
// Gemma 4 reference template.
|
// Gemma 4 reference template.
|
||||||
//
|
//
|
||||||
// Current upstream Gemma 4 chat templates differ by model size, so the checked-in
|
// Current upstream Gemma 4 chat templates differ by model size. The checked-in
|
||||||
// reference intentionally uses the shared baseline without an empty generation-time
|
// reference cases below use the small (e2b/e4b-style) baseline, with large
|
||||||
// thought channel until renderer selection is split by size.
|
// (26b/31b-style) checks covered separately in this file.
|
||||||
//
|
//
|
||||||
// To regenerate expected values, save the E2B template to
|
// To regenerate expected values, save the E2B template to
|
||||||
// gemma4_e2b_chat_template.jinja2 and run:
|
// gemma4_e2b_chat_template.jinja2 and run:
|
||||||
@@ -1474,6 +1474,47 @@ Hi<turn|>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGemma4RendererVariantsMatchExpectedGenerationPrompt(t *testing.T) {
|
||||||
|
messages := []api.Message{{Role: "user", Content: "Hello"}}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rendererName string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "legacy_alias",
|
||||||
|
rendererName: "gemma4",
|
||||||
|
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "small",
|
||||||
|
rendererName: "gemma4-small",
|
||||||
|
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large",
|
||||||
|
rendererName: "gemma4-large",
|
||||||
|
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := RenderWithRenderer(tt.rendererName, messages, nil, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expected, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGemma4LargeRendererOmitsEmptyThoughtBlockWhenThinkingEnabled(t *testing.T) {
|
||||||
|
got, err := RenderWithRenderer("gemma4-large", []api.Message{{Role: "user", Content: "Hello"}}, nil, thinkTrue())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "<bos><|turn>system\n<|think|>\n<turn|>\n<|turn>user\nHello<turn|>\n<|turn>model\n", got)
|
||||||
|
assert.NotContains(t, got, "<|channel>thought\n<channel|>")
|
||||||
|
}
|
||||||
|
|
||||||
func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
|
func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
|
||||||
if os.Getenv("VERIFY_JINJA2") == "" {
|
if os.Getenv("VERIFY_JINJA2") == "" {
|
||||||
t.Skip("set VERIFY_JINJA2=1 to run expanded Jinja2 parity checks")
|
t.Skip("set VERIFY_JINJA2=1 to run expanded Jinja2 parity checks")
|
||||||
@@ -1616,17 +1657,37 @@ func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
variants := []struct {
|
||||||
|
name string
|
||||||
|
renderer *Gemma4Renderer
|
||||||
|
templateRel string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "small",
|
||||||
|
renderer: &Gemma4Renderer{useImgTags: RenderImgTags},
|
||||||
|
templateRel: gemma4E2BTemplate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large",
|
||||||
|
renderer: &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true},
|
||||||
|
templateRel: gemma431BTemplate,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, variant := range variants {
|
||||||
|
t.Run(variant.name, func(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
renderer := &Gemma4Renderer{useImgTags: RenderImgTags}
|
got, err := variant.renderer.Render(tt.messages, tt.tools, tt.think)
|
||||||
got, err := renderer.Render(tt.messages, tt.tools, tt.think)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, tt.think)
|
jinja2Output := renderWithJinja2Template(t, variant.templateRel, tt.messages, tt.tools, tt.think)
|
||||||
assert.Equal(t, jinja2Output, got,
|
assert.Equal(t, jinja2Output, got,
|
||||||
"renderer output doesn't match Jinja2 template output")
|
"renderer output doesn't match Jinja2 template output")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGemma4RendererKnownJinja2Differences(t *testing.T) {
|
func TestGemma4RendererKnownJinja2Differences(t *testing.T) {
|
||||||
|
|||||||
@@ -81,8 +81,10 @@ func rendererForName(name string) Renderer {
|
|||||||
return renderer
|
return renderer
|
||||||
case "nemotron-3-nano":
|
case "nemotron-3-nano":
|
||||||
return &Nemotron3NanoRenderer{}
|
return &Nemotron3NanoRenderer{}
|
||||||
case "gemma4":
|
case "gemma4", "gemma4-small":
|
||||||
return &Gemma4Renderer{useImgTags: RenderImgTags}
|
return &Gemma4Renderer{useImgTags: RenderImgTags}
|
||||||
|
case "gemma4-large":
|
||||||
|
return &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true}
|
||||||
case "functiongemma":
|
case "functiongemma":
|
||||||
return &FunctionGemmaRenderer{}
|
return &FunctionGemmaRenderer{}
|
||||||
case "glm-4.7":
|
case "glm-4.7":
|
||||||
|
|||||||
@@ -523,7 +523,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
|||||||
arch := layer.GGML.KV().Architecture()
|
arch := layer.GGML.KV().Architecture()
|
||||||
switch arch {
|
switch arch {
|
||||||
case "gemma4":
|
case "gemma4":
|
||||||
config.Renderer = cmp.Or(config.Renderer, "gemma4")
|
config.Renderer = cmp.Or(config.Renderer, gemma4RendererLegacy)
|
||||||
config.Parser = cmp.Or(config.Parser, "gemma4")
|
config.Parser = cmp.Or(config.Parser, "gemma4")
|
||||||
if _, ok := r.Parameters["stop"]; !ok {
|
if _, ok := r.Parameters["stop"]; !ok {
|
||||||
if r.Parameters == nil {
|
if r.Parameters == nil {
|
||||||
|
|||||||
78
server/gemma4_test.go
Normal file
78
server/gemma4_test.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestResolveGemma4Renderer(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model *Model
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil model falls back to legacy alias",
|
||||||
|
model: nil,
|
||||||
|
want: gemma4RendererLegacy,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit small passes through",
|
||||||
|
model: &Model{
|
||||||
|
Config: testConfigWithRenderer(gemma4RendererSmall),
|
||||||
|
},
|
||||||
|
want: gemma4RendererSmall,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit large passes through",
|
||||||
|
model: &Model{
|
||||||
|
Config: testConfigWithRenderer(gemma4RendererLarge),
|
||||||
|
},
|
||||||
|
want: gemma4RendererLarge,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy e4b tag resolves small",
|
||||||
|
model: &Model{
|
||||||
|
Name: "gemma4:e4b",
|
||||||
|
ShortName: "gemma4:e4b",
|
||||||
|
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||||
|
},
|
||||||
|
want: gemma4RendererSmall,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy 31b tag resolves large",
|
||||||
|
model: &Model{
|
||||||
|
Name: "gemma4:31b-cloud",
|
||||||
|
ShortName: "gemma4:31b-cloud",
|
||||||
|
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||||
|
},
|
||||||
|
want: gemma4RendererLarge,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy model type resolves small",
|
||||||
|
model: &Model{
|
||||||
|
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "4.3B"),
|
||||||
|
},
|
||||||
|
want: gemma4RendererSmall,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy model type resolves large",
|
||||||
|
model: &Model{
|
||||||
|
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
|
||||||
|
},
|
||||||
|
want: gemma4RendererLarge,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy unknown defaults small",
|
||||||
|
model: &Model{
|
||||||
|
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||||
|
},
|
||||||
|
want: gemma4RendererSmall,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := resolveGemma4Renderer(tt.model); got != tt.want {
|
||||||
|
t.Fatalf("resolveGemma4Renderer() = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -156,7 +156,7 @@ func (m *Model) Capabilities() []model.Capability {
|
|||||||
|
|
||||||
// Temporary workaround — suppress vision/audio for gemma4 MLX models
|
// Temporary workaround — suppress vision/audio for gemma4 MLX models
|
||||||
// until multimodal runtime pipeline lands. Remove when imageproc.go is wired up.
|
// until multimodal runtime pipeline lands. Remove when imageproc.go is wired up.
|
||||||
if m.Config.ModelFormat == "safetensors" && m.Config.Renderer == "gemma4" {
|
if m.Config.ModelFormat == "safetensors" && isGemma4Renderer(m.Config.Renderer) {
|
||||||
capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool {
|
capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool {
|
||||||
return c == model.CapabilityVision || c == "audio"
|
return c == model.CapabilityVision || c == "audio"
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -118,6 +118,39 @@ func TestModelCapabilities(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectedCaps: []model.Capability{model.CapabilityEmbedding},
|
expectedCaps: []model.Capability{model.CapabilityEmbedding},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "gemma4 small safetensors suppresses vision and audio",
|
||||||
|
model: Model{
|
||||||
|
Config: model.ConfigV2{
|
||||||
|
ModelFormat: "safetensors",
|
||||||
|
Renderer: gemma4RendererSmall,
|
||||||
|
Capabilities: []string{"vision", "audio"},
|
||||||
|
},
|
||||||
|
Template: chatTemplate,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gemma4 large safetensors suppresses vision and audio",
|
||||||
|
model: Model{
|
||||||
|
Config: model.ConfigV2{
|
||||||
|
ModelFormat: "safetensors",
|
||||||
|
Renderer: gemma4RendererLarge,
|
||||||
|
Capabilities: []string{"vision", "audio"},
|
||||||
|
},
|
||||||
|
Template: chatTemplate,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy gemma4 safetensors suppresses vision and audio",
|
||||||
|
model: Model{
|
||||||
|
Config: model.ConfigV2{
|
||||||
|
ModelFormat: "safetensors",
|
||||||
|
Renderer: gemma4RendererLegacy,
|
||||||
|
Capabilities: []string{"vision", "audio"},
|
||||||
|
},
|
||||||
|
Template: chatTemplate,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// compare two slices of model.Capability regardless of order
|
// compare two slices of model.Capability regardless of order
|
||||||
|
|||||||
@@ -115,7 +115,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
|
|
||||||
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
||||||
if m.Config.Renderer != "" {
|
if m.Config.Renderer != "" {
|
||||||
rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
|
rendererName := resolveRendererName(m)
|
||||||
|
rendered, err := renderers.RenderWithRenderer(rendererName, msgs, tools, think)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,14 @@ import (
|
|||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func testConfigWithRenderer(renderer string) model.ConfigV2 {
|
||||||
|
return model.ConfigV2{Renderer: renderer}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testConfigWithRendererAndType(renderer, modelType string) model.ConfigV2 {
|
||||||
|
return model.ConfigV2{Renderer: renderer, ModelType: modelType}
|
||||||
|
}
|
||||||
|
|
||||||
func TestChatPrompt(t *testing.T) {
|
func TestChatPrompt(t *testing.T) {
|
||||||
type expect struct {
|
type expect struct {
|
||||||
prompt string
|
prompt string
|
||||||
@@ -397,3 +405,43 @@ func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
|
|||||||
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) {
|
||||||
|
msgs := []api.Message{{Role: "user", Content: "Hello"}}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model Model
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "small from name",
|
||||||
|
model: Model{
|
||||||
|
Name: "gemma4:e4b",
|
||||||
|
ShortName: "gemma4:e4b",
|
||||||
|
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||||
|
},
|
||||||
|
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large from model type",
|
||||||
|
model: Model{
|
||||||
|
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
|
||||||
|
},
|
||||||
|
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := renderPrompt(&tt.model, msgs, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||||
|
t.Fatalf("rendered prompt mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
110
server/renderer_resolution.go
Normal file
110
server/renderer_resolution.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
gemma4RendererLegacy = "gemma4"
|
||||||
|
gemma4RendererSmall = "gemma4-small"
|
||||||
|
gemma4RendererLarge = "gemma4-large"
|
||||||
|
|
||||||
|
// Gemma 4 small templates cover the e2b/e4b family, while 26b/31b use the
|
||||||
|
// large template. Default to the small prompt unless the model is clearly in
|
||||||
|
// the large range.
|
||||||
|
gemma4LargeMinParameterCount = 16_000_000_000
|
||||||
|
)
|
||||||
|
|
||||||
|
func resolveRendererName(m *Model) string {
|
||||||
|
if m == nil || m.Config.Renderer == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch m.Config.Renderer {
|
||||||
|
case gemma4RendererLegacy:
|
||||||
|
return resolveGemma4Renderer(m)
|
||||||
|
default:
|
||||||
|
return m.Config.Renderer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveGemma4Renderer(m *Model) string {
|
||||||
|
if m == nil || m.Config.Renderer != gemma4RendererLegacy {
|
||||||
|
if m == nil {
|
||||||
|
return gemma4RendererLegacy
|
||||||
|
}
|
||||||
|
return m.Config.Renderer
|
||||||
|
}
|
||||||
|
|
||||||
|
if renderer, ok := gemma4RendererFromName(m.ShortName); ok {
|
||||||
|
return renderer
|
||||||
|
}
|
||||||
|
|
||||||
|
if renderer, ok := gemma4RendererFromName(m.Name); ok {
|
||||||
|
return renderer
|
||||||
|
}
|
||||||
|
|
||||||
|
if parameterCount, ok := parseHumanParameterCount(m.Config.ModelType); ok {
|
||||||
|
return gemma4RendererForParameterCount(parameterCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return gemma4RendererSmall
|
||||||
|
}
|
||||||
|
|
||||||
|
func gemma4RendererForParameterCount(parameterCount uint64) string {
|
||||||
|
if parameterCount >= gemma4LargeMinParameterCount {
|
||||||
|
return gemma4RendererLarge
|
||||||
|
}
|
||||||
|
|
||||||
|
return gemma4RendererSmall
|
||||||
|
}
|
||||||
|
|
||||||
|
func gemma4RendererFromName(name string) (string, bool) {
|
||||||
|
lower := strings.ToLower(name)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(lower, "e2b"), strings.Contains(lower, "e4b"):
|
||||||
|
return gemma4RendererSmall, true
|
||||||
|
case strings.Contains(lower, "26b"), strings.Contains(lower, "31b"):
|
||||||
|
return gemma4RendererLarge, true
|
||||||
|
default:
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHumanParameterCount(s string) (uint64, bool) {
|
||||||
|
if s == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
unit := strings.ToUpper(s[len(s)-1:])
|
||||||
|
var multiplier float64
|
||||||
|
switch unit {
|
||||||
|
case "B":
|
||||||
|
multiplier = float64(format.Billion)
|
||||||
|
case "M":
|
||||||
|
multiplier = float64(format.Million)
|
||||||
|
case "K":
|
||||||
|
multiplier = float64(format.Thousand)
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := strconv.ParseFloat(s[:len(s)-1], 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint64(value * multiplier), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isGemma4Renderer(renderer string) bool {
|
||||||
|
switch renderer {
|
||||||
|
case gemma4RendererLegacy, gemma4RendererSmall, gemma4RendererLarge:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -928,6 +928,59 @@ func TestCreateDetectTemplate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateGemma4KeepsDynamicRendererAlias(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
p := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
|
var s Server
|
||||||
|
|
||||||
|
_, digest := createBinFile(t, ggml.KV{
|
||||||
|
"general.architecture": "gemma4",
|
||||||
|
"general.parameter_count": uint64(25_200_000_000),
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Files: map[string]string{"test.gguf": digest},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
mf, err := manifest.ParseNamedManifest(model.ParseName("test"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse manifest: %v", err)
|
||||||
|
}
|
||||||
|
if mf.Config.Digest == "" {
|
||||||
|
t.Fatalf("unexpected empty config digest for manifest")
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath, err := manifest.BlobsPath(mf.Config.Digest)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("config blob path: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfgFile, err := os.Open(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open config blob: %v", err)
|
||||||
|
}
|
||||||
|
defer cfgFile.Close()
|
||||||
|
|
||||||
|
var cfg model.ConfigV2
|
||||||
|
if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
|
||||||
|
t.Fatalf("decode config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Renderer != gemma4RendererLegacy {
|
||||||
|
t.Fatalf("expected renderer %q, got %q", gemma4RendererLegacy, cfg.Renderer)
|
||||||
|
}
|
||||||
|
if cfg.Parser != "gemma4" {
|
||||||
|
t.Fatalf("expected parser %q, got %q", "gemma4", cfg.Parser)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDetectModelTypeFromFiles(t *testing.T) {
|
func TestDetectModelTypeFromFiles(t *testing.T) {
|
||||||
t.Run("gguf file", func(t *testing.T) {
|
t.Run("gguf file", func(t *testing.T) {
|
||||||
_, digest := createBinFile(t, nil, nil)
|
_, digest := createBinFile(t, nil, nil)
|
||||||
|
|||||||
@@ -115,36 +115,7 @@ func (s *Server) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
|
|||||||
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
|
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
|
||||||
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", s.modelName, "--port", strconv.Itoa(port))
|
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", s.modelName, "--port", strconv.Itoa(port))
|
||||||
cmd.Env = os.Environ()
|
cmd.Env = os.Environ()
|
||||||
|
configureMLXSubprocessEnv(cmd, ml.LibraryPaths(gpus))
|
||||||
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
|
|
||||||
libraryPaths := []string{ml.LibOllamaPath}
|
|
||||||
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
|
|
||||||
libraryPaths = append(libraryPaths, mlxDirs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append existing LD_LIBRARY_PATH if set
|
|
||||||
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
|
|
||||||
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
|
||||||
|
|
||||||
// Update or add LD_LIBRARY_PATH in cmd.Env
|
|
||||||
found := false
|
|
||||||
for i := range cmd.Env {
|
|
||||||
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
|
|
||||||
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
|
|
||||||
}
|
|
||||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.cmd = cmd
|
s.cmd = cmd
|
||||||
|
|
||||||
@@ -200,6 +171,53 @@ func (s *Server) Ping(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mlxLibraryPathEnv() string {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "windows":
|
||||||
|
return "PATH"
|
||||||
|
case "darwin":
|
||||||
|
return "DYLD_LIBRARY_PATH"
|
||||||
|
default:
|
||||||
|
return "LD_LIBRARY_PATH"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureMLXSubprocessEnv(cmd *exec.Cmd, libraryPaths []string) {
|
||||||
|
if len(libraryPaths) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search order for the imagegen runner is:
|
||||||
|
// 1. bundled lib/ollama root
|
||||||
|
// 2. backend-specific library dirs selected during GPU discovery
|
||||||
|
// 3. any existing caller-provided library path values
|
||||||
|
pathEnv := mlxLibraryPathEnv()
|
||||||
|
pathEnvPaths := append([]string{}, libraryPaths...)
|
||||||
|
if existingPath, ok := os.LookupEnv(pathEnv); ok {
|
||||||
|
pathEnvPaths = append(pathEnvPaths, filepath.SplitList(existingPath)...)
|
||||||
|
}
|
||||||
|
setSubprocessEnv(cmd, pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
|
||||||
|
slog.Debug("mlx subprocess library path", pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
|
||||||
|
|
||||||
|
ollamaLibraryPaths := append([]string{}, libraryPaths...)
|
||||||
|
if existingPath, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
|
||||||
|
ollamaLibraryPaths = append(ollamaLibraryPaths, filepath.SplitList(existingPath)...)
|
||||||
|
}
|
||||||
|
setSubprocessEnv(cmd, "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
|
||||||
|
slog.Debug("mlx subprocess library path", "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func setSubprocessEnv(cmd *exec.Cmd, key, value string) {
|
||||||
|
for i := range cmd.Env {
|
||||||
|
name, _, ok := strings.Cut(cmd.Env[i], "=")
|
||||||
|
if ok && strings.EqualFold(name, key) {
|
||||||
|
cmd.Env[i] = key + "=" + value
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cmd.Env = append(cmd.Env, key+"="+value)
|
||||||
|
}
|
||||||
|
|
||||||
// getLastErr returns the last stderr line.
|
// getLastErr returns the last stderr line.
|
||||||
func (s *Server) getLastErr() string {
|
func (s *Server) getLastErr() string {
|
||||||
s.lastErrLock.Lock()
|
s.lastErrLock.Lock()
|
||||||
|
|||||||
@@ -1061,14 +1061,12 @@ func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
|
var donorKV *sharedKVEntry
|
||||||
|
h, donorKV = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
|
||||||
|
|
||||||
// If this layer is a donor, store its cached KV for later shared layers.
|
// If this layer is a donor, store its cached KV for later shared layers.
|
||||||
if layer.IsDonor && c != nil {
|
if layer.IsDonor && donorKV != nil {
|
||||||
state := c.State()
|
sharedKV[layer.LayerIdx] = *donorKV
|
||||||
if len(state) >= 2 && state[0] != nil && state[1] != nil {
|
|
||||||
sharedKV[layer.LayerIdx] = sharedKVEntry{K: state[0], V: state[1], Offset: c.Offset()}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1190,9 +1188,9 @@ func sliceLayerDim(combined *mlx.Array, layerIdx, B, L, pleDim int32) *mlx.Array
|
|||||||
return mlx.Squeeze(sliced, 2)
|
return mlx.Squeeze(sliced, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array {
|
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
|
||||||
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
|
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||||
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
|
attnOut, donorKV := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
|
||||||
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||||
h := mlx.Add(x, attnOut)
|
h := mlx.Add(x, attnOut)
|
||||||
|
|
||||||
@@ -1237,10 +1235,10 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
|
|||||||
h = mlx.Mul(h, l.LayerScalar)
|
h = mlx.Mul(h, l.LayerScalar)
|
||||||
}
|
}
|
||||||
|
|
||||||
return h
|
return h, donorKV
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array {
|
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
|
||||||
// Determine head dim and scale based on layer type.
|
// Determine head dim and scale based on layer type.
|
||||||
headDim := cfg.HeadDim
|
headDim := cfg.HeadDim
|
||||||
scale := cfg.SlidingScale
|
scale := cfg.SlidingScale
|
||||||
@@ -1274,6 +1272,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||||||
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs)
|
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs)
|
||||||
|
|
||||||
var k, v *mlx.Array
|
var k, v *mlx.Array
|
||||||
|
var donorKV *sharedKVEntry
|
||||||
|
|
||||||
if donorEntry != nil {
|
if donorEntry != nil {
|
||||||
// Shared layer: use donor's cached K/V.
|
// Shared layer: use donor's cached K/V.
|
||||||
@@ -1312,6 +1311,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||||||
// Update cache.
|
// Update cache.
|
||||||
if c != nil {
|
if c != nil {
|
||||||
k, v = c.Update(k, v)
|
k, v = c.Update(k, v)
|
||||||
|
donorKV = &sharedKVEntry{K: k, V: v, Offset: c.Offset()}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1365,7 +1365,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||||||
// strided views differently. Metal handles them natively.
|
// strided views differently. Metal handles them natively.
|
||||||
out = mlx.Contiguous(out, false)
|
out = mlx.Contiguous(out, false)
|
||||||
}
|
}
|
||||||
return a.OProj.Forward(out)
|
return a.OProj.Forward(out), donorKV
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||||
|
|||||||
Reference in New Issue
Block a user