Compare commits

...

16 Commits

Author SHA1 Message Date
Patrick Devine
578c32e42e still more linter stuff 2026-03-19 17:29:12 -07:00
Patrick Devine
a10d2625ca linters ftw 2026-03-19 17:20:59 -07:00
Patrick Devine
b960d769ad more linter fixes 2026-03-19 17:11:43 -07:00
Patrick Devine
455a6099d1 gofumpt the linter 2026-03-19 16:52:35 -07:00
Patrick Devine
7e6e8377eb mlx: qwen3.5 vision support 2026-03-19 16:35:08 -07:00
Bruce MacDonald
126d8db7f3 parsers: robust xml tool repair (#14961)
Previous xml repair for glm was a good start, but we need to go further and repair any incorrect open or closing tags

Co-authored-by: Dongluo Chen <dongluo.chen@gmail.com>
2026-03-19 11:24:48 -07:00
Eva H
3f3a24b418 app: fix desktop app stuck loading when OLLAMA_HOST is an unspecified bind address (#14885) 2026-03-19 12:57:57 -04:00
Jesse Gross
96e36c0d90 mlxrunner: share KV cache across conversations with common prefixes
Enable multiple conversations to reuse cached computations when they
share token prefixes (e.g. the same system prompt). A prefix trie
tracks shared regions so switching between conversations only
recomputes tokens that diverge. Inactive conversation state is paged
from active GPU memory to other memory and restored on demand, with LRU
eviction to keep memory usage bounded.
2026-03-18 16:06:33 -07:00
Jesse Gross
6f8ddbb26b mlxrunner: fix Slice(0, 0) returning full dimension instead of empty
Slice used cmp.Or to resolve a zero stop value to the dimension size,
intended to support open-ended slices like a[i:]. This made Slice(0, 0)
indistinguishable from Slice(), so any slice with a zero stop would
silently include the entire dimension instead of being empty.

Replace cmp.Or with an explicit End sentinel and resolve negative
indices against the dimension size, matching Python/PyTorch semantics.
2026-03-18 16:06:33 -07:00
Eva H
b5e7888414 cmd/launch: skip redundant config writes when model unchanged (#14941) 2026-03-18 17:36:52 -04:00
Parth Sareen
eab4d22269 docs: update claude code and openclaw for web search (#14922) 2026-03-18 14:18:49 -07:00
Bruce MacDonald
5759c2d2d2 launch: fix openclaw not picking up newly selected model (#14943)
Sessions with a stale model field were not updated when the primary
changed, so the old model continued to be used.
2026-03-18 13:20:10 -07:00
Bruce MacDonald
42b1c2642b docs: update minimax-m2.5 references to m2.7 (#14942) 2026-03-18 12:59:28 -07:00
Bruce MacDonald
727d69ddf3 tui: fix signin on headless Linux systems (#14627)
Defensively handle environments without a display server to ensure signin remains usable on headless VMs and SSH sessions.

- Skip calling xdg-open when neither DISPLAY nor WAYLAND_DISPLAY is set, preventing silent failures or unexpected browser handlers
- Render the signin URL as plain text instead of wrapping it in OSC 8 hyperlink escape sequences, which can be garbled or hidden by terminals that don't support them
2026-03-18 11:11:17 -07:00
Jesse Gross
f622b0c5fc launch: disable claude attribution header to preserve KV cache
Claude Code sends an x-anthropic-billing-header that changes on every
request. This is embedded in the system prompt and consequently
breaks the KV cache for every request. Given the size of the prompts
that Claude Code uses, this has significant performance impact.
2026-03-17 20:48:03 -07:00
Bruce MacDonald
5d0000634c cmd/launch: check for both npm and git before installing OpenClaw (#14888)
The OpenClaw installer requires git in addition to npm. Update the
dependency check to detect both and provide specific install guidance
for whichever dependencies are missing.
2026-03-17 18:20:05 -07:00
39 changed files with 5215 additions and 479 deletions

View File

@@ -155,7 +155,7 @@ func (s *Server) ollamaProxy() http.Handler {
return
}
target := envconfig.Host()
target := envconfig.ConnectableHost()
s.log().Info("configuring ollama proxy", "target", target.String())
newProxy := httputil.NewSingleHostReverseProxy(target)

View File

@@ -2071,7 +2071,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
},
{
name: "explicit :cloud model without local stub returns not found by default",
model: "minimax-m2.5:cloud",
model: "minimax-m2.7:cloud",
showStatus: http.StatusNotFound,
whoamiStatus: http.StatusOK,
whoamiResp: api.UserResponse{Name: "testuser"},

View File

@@ -59,6 +59,7 @@ func (c *Claude) Run(model string, args []string) error {
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
"ANTHROPIC_API_KEY=",
"ANTHROPIC_AUTH_TOKEN=ollama",
"CLAUDE_CODE_ATTRIBUTION_HEADER=0",
)
env = append(env, c.modelEnvVars(model)...)

View File

@@ -310,7 +310,7 @@ func names(items []ModelItem) []string {
func TestBuildModelList_NoExistingModels(t *testing.T) {
items, _, _, _ := buildModelList(nil, nil, "")
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.5:cloud", "glm-4.7-flash", "qwen3.5"}
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.7:cloud", "glm-4.7-flash", "qwen3.5"}
if diff := cmp.Diff(want, names(items)); diff != "" {
t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
}
@@ -338,7 +338,7 @@ func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
got := names(items)
// Recommended pinned at top (local recs first, then cloud recs when only-local), then installed non-recs
want := []string{"glm-4.7-flash", "qwen3.5", "kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.5:cloud", "llama3.2", "qwen2.5"}
want := []string{"glm-4.7-flash", "qwen3.5", "kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.7:cloud", "llama3.2", "qwen2.5"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("recs pinned at top, local recs before cloud recs (-want +got):\n%s", diff)
}
@@ -354,7 +354,7 @@ func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
got := names(items)
// All recs pinned at top (cloud before local in mixed case), then non-recs
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.5:cloud", "glm-4.7-flash", "qwen3.5", "llama3.2"}
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.7:cloud", "glm-4.7-flash", "qwen3.5", "llama3.2"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("recs pinned at top, cloud recs first in mixed case (-want +got):\n%s", diff)
}
@@ -392,7 +392,7 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
if !strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
case "minimax-m2.5:cloud", "kimi-k2.5:cloud", "qwen3.5:cloud":
case "minimax-m2.7:cloud", "kimi-k2.5:cloud", "qwen3.5:cloud":
if strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
}
@@ -412,7 +412,7 @@ func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
// glm-4.7-flash and glm-5:cloud are installed so they sort normally;
// kimi-k2.5:cloud, qwen3.5:cloud, and qwen3.5 are not installed so they go to the bottom
// All recs: cloud first in mixed case, then local, in rec order within each
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.5:cloud", "glm-4.7-flash", "qwen3.5"}
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.7:cloud", "glm-4.7-flash", "qwen3.5"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("all recs, cloud first in mixed case (-want +got):\n%s", diff)
}
@@ -430,7 +430,7 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes
// kimi-k2.5:cloud is installed so it sorts normally;
// the rest of the recommendations are not installed so they go to the bottom
// All recs pinned at top (cloud first in mixed case), then non-recs
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.5:cloud", "glm-4.7-flash", "qwen3.5", "llama3.2"}
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.7:cloud", "glm-4.7-flash", "qwen3.5", "llama3.2"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("recs pinned at top, cloud first in mixed case (-want +got):\n%s", diff)
}
@@ -583,7 +583,7 @@ func TestBuildModelList_RecsAboveNonRecs(t *testing.T) {
lastRecIdx := -1
firstNonRecIdx := len(got)
for i, name := range got {
isRec := name == "glm-4.7-flash" || name == "qwen3.5" || name == "minimax-m2.5:cloud" || name == "glm-5:cloud" || name == "kimi-k2.5:cloud" || name == "qwen3.5:cloud"
isRec := name == "glm-4.7-flash" || name == "qwen3.5" || name == "minimax-m2.7:cloud" || name == "glm-5:cloud" || name == "kimi-k2.5:cloud" || name == "qwen3.5:cloud"
if isRec && i > lastRecIdx {
lastRecIdx = i
}

View File

@@ -413,9 +413,6 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
return "", err
}
fmt.Fprintf(os.Stderr, "Headless mode: auto-selected last used model %q\n", current)
if err := config.SetLastModel(current); err != nil {
return "", err
}
return current, nil
}
@@ -428,9 +425,6 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
return "", err
}
if err := config.SetLastModel(current); err != nil {
return "", err
}
return current, nil
}
}
@@ -439,8 +433,10 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
if err != nil {
return "", err
}
if err := config.SetLastModel(model); err != nil {
return "", err
if model != current {
if err := config.SetLastModel(model); err != nil {
return "", err
}
}
return model, nil
}
@@ -475,8 +471,10 @@ func (c *launcherClient) launchSingleIntegration(ctx context.Context, name strin
return nil
}
if err := config.SaveIntegration(name, []string{target}); err != nil {
return fmt.Errorf("failed to save: %w", err)
if target != current {
if err := config.SaveIntegration(name, []string{target}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
}
return launchAfterConfiguration(name, runner, target, req)

View File

@@ -951,7 +951,7 @@ func TestLaunchIntegration_OpenclawInstallsBeforeConfigSideEffects(t *testing.T)
if err == nil {
t.Fatal("expected launch to fail before configuration when OpenClaw is missing")
}
if !strings.Contains(err.Error(), "npm was not found") {
if !strings.Contains(err.Error(), "required dependencies are missing") {
t.Fatalf("expected install prerequisite error, got %v", err)
}
if selectorCalled {

View File

@@ -24,7 +24,7 @@ var recommendedModels = []ModelItem{
{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
{Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true},
{Name: "minimax-m2.5:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
{Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true},
}
@@ -43,7 +43,7 @@ type cloudModelLimit struct {
// cloudModelLimits maps cloud model base names to their token limits.
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
var cloudModelLimits = map[string]cloudModelLimit{
"minimax-m2.5": {Context: 204_800, Output: 128_000},
"minimax-m2.7": {Context: 204_800, Output: 128_000},
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
@@ -92,6 +92,10 @@ func OpenBrowser(url string) {
case "darwin":
_ = exec.Command("open", url).Start()
case "linux":
// Skip on headless systems where no display server is available
if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
return
}
_ = exec.Command("xdg-open", url).Start()
case "windows":
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()

View File

@@ -429,13 +429,17 @@ func ensureOpenclawInstalled() (string, error) {
return "clawdbot", nil
}
if _, err := exec.LookPath("npm"); err != nil {
return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
"Install Node.js first:\n" +
" https://nodejs.org/\n\n" +
"Then rerun:\n" +
" ollama launch\n" +
"and select OpenClaw")
_, npmErr := exec.LookPath("npm")
_, gitErr := exec.LookPath("git")
if npmErr != nil || gitErr != nil {
var missing []string
if npmErr != nil {
missing = append(missing, "npm (Node.js): https://nodejs.org/")
}
if gitErr != nil {
missing = append(missing, "git: https://git-scm.com/")
}
return "", fmt.Errorf("openclaw is not installed and required dependencies are missing\n\nInstall the following first:\n %s", strings.Join(missing, "\n "))
}
ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?")
@@ -605,6 +609,8 @@ func clearSessionModelOverride(primary string) {
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
delete(sess, "modelOverride")
delete(sess, "providerOverride")
}
if model, _ := sess["model"].(string); model != "" && model != primary {
sess["model"] = primary
changed = true
}

View File

@@ -1376,7 +1376,7 @@ func TestOpenclawModelConfig(t *testing.T) {
// report it as a remote/cloud model
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/show" {
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.5"}`)
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.7"}`)
return
}
w.WriteHeader(http.StatusNotFound)
@@ -1386,7 +1386,7 @@ func TestOpenclawModelConfig(t *testing.T) {
u, _ := url.Parse(srv.URL)
client := api.NewClient(u, srv.Client())
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.5:cloud")
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.7:cloud")
if !isCloud {
t.Error("expected isCloud = true for cloud model")
@@ -1768,3 +1768,124 @@ func TestRegisterWebSearchPlugin(t *testing.T) {
}
})
}
// TestClearSessionModelOverride verifies that clearSessionModelOverride
// rewrites the OpenClaw sessions.json so every session follows the new
// primary model: stale modelOverride/providerOverride entries are removed,
// stale model fields are repointed to the primary, and sessions already on
// the primary (or with no model field at all) are left untouched.
func TestClearSessionModelOverride(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
// Path that clearSessionModelOverride operates on under the test home.
sessionsDir := filepath.Join(tmpDir, ".openclaw", "agents", "main", "sessions")
sessionsPath := filepath.Join(sessionsDir, "sessions.json")
// writeSessionsFile seeds sessions.json with the given per-session maps.
writeSessionsFile := func(t *testing.T, sessions map[string]map[string]any) {
t.Helper()
if err := os.MkdirAll(sessionsDir, 0o755); err != nil {
t.Fatal(err)
}
data, err := json.Marshal(sessions)
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(sessionsPath, data, 0o600); err != nil {
t.Fatal(err)
}
}
// readSessionsFile parses sessions.json back for assertions.
readSessionsFile := func(t *testing.T) map[string]map[string]any {
t.Helper()
data, err := os.ReadFile(sessionsPath)
if err != nil {
t.Fatalf("reading sessions file: %v", err)
}
var sessions map[string]map[string]any
if err := json.Unmarshal(data, &sessions); err != nil {
t.Fatalf("parsing sessions file: %v", err)
}
return sessions
}
t.Run("clears modelOverride and updates model", func(t *testing.T) {
writeSessionsFile(t, map[string]map[string]any{
"sess1": {"model": "ollama/old-model", "modelOverride": "old-model", "providerOverride": "ollama"},
})
clearSessionModelOverride("new-model")
sessions := readSessionsFile(t)
sess := sessions["sess1"]
if _, ok := sess["modelOverride"]; ok {
t.Error("modelOverride should have been deleted")
}
if _, ok := sess["providerOverride"]; ok {
t.Error("providerOverride should have been deleted")
}
if sess["model"] != "new-model" {
t.Errorf("model = %q, want %q", sess["model"], "new-model")
}
})
t.Run("updates model field in sessions without modelOverride", func(t *testing.T) {
// This is the bug case: session has model pointing to old primary,
// but no explicit modelOverride. After changing primary, the session
// model field must also be updated.
writeSessionsFile(t, map[string]map[string]any{
"sess1": {"model": "ollama/old-model"},
})
clearSessionModelOverride("new-model")
sessions := readSessionsFile(t)
if sessions["sess1"]["model"] != "new-model" {
t.Errorf("model = %q, want %q", sessions["sess1"]["model"], "new-model")
}
})
t.Run("does not update session already using primary", func(t *testing.T) {
writeSessionsFile(t, map[string]map[string]any{
"sess1": {"model": "current-model"},
})
clearSessionModelOverride("current-model")
sessions := readSessionsFile(t)
if sessions["sess1"]["model"] != "current-model" {
t.Errorf("model = %q, want %q", sessions["sess1"]["model"], "current-model")
}
})
t.Run("does not update session with empty model field", func(t *testing.T) {
// A session that never had a model must not gain one.
writeSessionsFile(t, map[string]map[string]any{
"sess1": {"other": "data"},
})
clearSessionModelOverride("new-model")
sessions := readSessionsFile(t)
if _, ok := sessions["sess1"]["model"]; ok {
t.Error("model field should not have been added to session with no model")
}
})
t.Run("handles multiple sessions mixed", func(t *testing.T) {
// All four session shapes in one file, processed in a single call.
writeSessionsFile(t, map[string]map[string]any{
"with-override": {"model": "old", "modelOverride": "old", "providerOverride": "ollama"},
"without-override": {"model": "old"},
"already-current": {"model": "new-model"},
"no-model": {"other": "data"},
})
clearSessionModelOverride("new-model")
sessions := readSessionsFile(t)
if sessions["with-override"]["model"] != "new-model" {
t.Errorf("with-override model = %q, want %q", sessions["with-override"]["model"], "new-model")
}
if _, ok := sessions["with-override"]["modelOverride"]; ok {
t.Error("with-override: modelOverride should be deleted")
}
if sessions["without-override"]["model"] != "new-model" {
t.Errorf("without-override model = %q, want %q", sessions["without-override"]["model"], "new-model")
}
if sessions["already-current"]["model"] != "new-model" {
t.Errorf("already-current model = %q, want %q", sessions["already-current"]["model"], "new-model")
}
if _, ok := sessions["no-model"]["model"]; ok {
t.Error("no-model: model should not have been added")
}
})
t.Run("no-op when sessions file missing", func(t *testing.T) {
os.RemoveAll(sessionsDir)
clearSessionModelOverride("new-model") // should not panic or error
})
}

View File

@@ -97,11 +97,8 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
fmt.Fprintf(&s, "To use %s, please sign in.\n\n", selectorSelectedItemStyle.Render(modelName))
// Wrap in OSC 8 hyperlink so the entire URL is clickable even when wrapped.
// Padding is outside the hyperlink so spaces don't get underlined.
link := fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", signInURL, urlColor.Render(signInURL))
s.WriteString("Navigate to:\n")
s.WriteString(urlWrap.Render(link))
s.WriteString(urlWrap.Render(urlColor.Render(signInURL)))
s.WriteString("\n\n")
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(

View File

@@ -25,22 +25,6 @@ func TestRenderSignIn_ContainsURL(t *testing.T) {
}
}
func TestRenderSignIn_OSC8Hyperlink(t *testing.T) {
url := "https://ollama.com/connect?key=abc123"
got := renderSignIn("test:cloud", url, 0, 120)
// Should contain OSC 8 open sequence with the URL
osc8Open := "\033]8;;" + url + "\033\\"
if !strings.Contains(got, osc8Open) {
t.Error("should contain OSC 8 open sequence with URL")
}
// Should contain OSC 8 close sequence
osc8Close := "\033]8;;\033\\"
if !strings.Contains(got, osc8Close) {
t.Error("should contain OSC 8 close sequence")
}
}
func TestRenderSignIn_ContainsSpinner(t *testing.T) {
got := renderSignIn("test:cloud", "https://example.com", 0, 80)

View File

@@ -41,13 +41,27 @@ ollama launch claude --model kimi-k2.5:cloud
- `kimi-k2.5:cloud`
- `glm-5:cloud`
- `minimax-m2.5:cloud`
- `minimax-m2.7:cloud`
- `qwen3.5:cloud`
- `glm-4.7-flash`
- `qwen3.5`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
## Non-interactive (headless) mode
Run Claude Code without interaction for use in Docker, CI/CD, or scripts:
```shell
ollama launch claude --model kimi-k2.5:cloud --yes -- -p "how does this repository work?"
```
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified. Arguments after `--` are passed directly to Claude Code.
## Web search
Claude Code can search the web through Ollama's web search API. See the [web search documentation](/capabilities/web-search) for setup and usage.
## Scheduled Tasks with `/loop`
The `/loop` command runs a prompt or slash command on a recurring schedule inside Claude Code. This is useful for automating repetitive tasks like checking PRs, running research, or setting reminders.

View File

@@ -15,13 +15,29 @@ Ollama handles everything automatically:
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
2. **Security** — On the first launch, a security notice explains the risks of tool access
3. **Model** — Pick a model from the selector (local or cloud)
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, and sets your model as the primary
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, sets your model as the primary, and installs the web search and fetch plugin
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
## Web search and fetch
OpenClaw ships with a web search and fetch plugin that gives local or cloud models the ability to search the web and extract readable page content.
```bash
ollama launch openclaw
```
Web search and fetch is enabled automatically when launching OpenClaw through Ollama. To install the plugin directly:
```bash
openclaw plugins install @ollama/openclaw-web-search
```
<Note>Web search for local models requires `ollama signin`.</Note>
## Configure without launching
To change the model without starting the gateway and TUI:
@@ -43,7 +59,7 @@ If the gateway is already running, it restarts automatically to pick up the new
**Cloud models**:
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
- `minimax-m2.5:cloud` — Fast, efficient coding and real-world productivity
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
- `glm-5:cloud` — Reasoning and code generation
**Local models:**
@@ -52,6 +68,16 @@ If the gateway is already running, it restarts automatically to pick up the new
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
## Non-interactive (headless) mode
Run OpenClaw without interaction for use in Docker, CI/CD, or scripts:
```bash
ollama launch openclaw --model kimi-k2.5:cloud --yes
```
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified.
## Connect messaging apps
```bash

View File

@@ -59,6 +59,29 @@ func Host() *url.URL {
}
}
// ConnectableHost returns Host() with unspecified bind addresses (0.0.0.0, ::)
// replaced by the corresponding loopback address (127.0.0.1, ::1).
// Unspecified addresses are valid for binding a server socket but not for
// connecting as a client, which fails on Windows.
func ConnectableHost() *url.URL {
	u := Host()
	host, port, err := net.SplitHostPort(u.Host)
	if err != nil {
		// Host is not in host:port form; nothing to rewrite.
		return u
	}
	ip := net.ParseIP(host)
	if ip == nil || !ip.IsUnspecified() {
		// Hostnames and concrete addresses pass through unchanged.
		return u
	}
	// Pick the loopback address in the same family as the bind address.
	loopback := "::1"
	if ip.To4() != nil {
		loopback = "127.0.0.1"
	}
	u.Host = net.JoinHostPort(loopback, port)
	return u
}
// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
func AllowedOrigins() (origins []string) {
if s := Var("OLLAMA_ORIGINS"); s != "" {

View File

@@ -52,6 +52,37 @@ func TestHost(t *testing.T) {
}
}
// TestConnectableHost checks that ConnectableHost rewrites unspecified
// bind addresses (0.0.0.0, [::]) from OLLAMA_HOST to the matching loopback
// address, and leaves every other host untouched.
func TestConnectableHost(t *testing.T) {
	cases := []struct {
		name   string
		value  string
		expect string
	}{
		{"empty", "", "http://127.0.0.1:11434"},
		{"localhost", "127.0.0.1", "http://127.0.0.1:11434"},
		{"localhost and port", "127.0.0.1:1234", "http://127.0.0.1:1234"},
		{"ipv4 unspecified", "0.0.0.0", "http://127.0.0.1:11434"},
		{"ipv4 unspecified + port", "0.0.0.0:1234", "http://127.0.0.1:1234"},
		{"ipv6 unspecified", "[::]", "http://[::1]:11434"},
		{"ipv6 unspecified + port", "[::]:1234", "http://[::1]:1234"},
		{"ipv6 localhost", "[::1]", "http://[::1]:11434"},
		{"ipv6 localhost + port", "[::1]:1234", "http://[::1]:1234"},
		{"specific address", "192.168.1.5", "http://192.168.1.5:11434"},
		{"specific address + port", "192.168.1.5:8080", "http://192.168.1.5:8080"},
		{"hostname", "example.com", "http://example.com:11434"},
		{"hostname and port", "example.com:1234", "http://example.com:1234"},
		{"https unspecified + port", "https://0.0.0.0:4321", "https://127.0.0.1:4321"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Setenv("OLLAMA_HOST", tc.value)
			if got := ConnectableHost(); got.String() != tc.expect {
				t.Errorf("%s: expected %s, got %s", tc.name, tc.expect, got.String())
			}
		})
	}
}
func TestOrigins(t *testing.T) {
cases := []struct {
value string

View File

@@ -345,44 +345,163 @@ func escapeGLM46Content(s string) string {
return result.String()
}
// repairUnclosedArgValues inserts missing </arg_value> closing tags.
// GLM models sometimes omit the closing tag, producing XML like:
// repairPhase represents the expected next tag in the repair cycle.
type repairPhase int
const (
phaseArgKeyOpen repairPhase = iota // expecting <arg_key>
phaseArgKeyClose // expecting </arg_key>
phaseArgValOpen // expecting <arg_value>
phaseArgValClose // expecting </arg_value>
phaseCount // number of phases
)
// repairGLM46XML reconstructs well-formed XML from GLM model output that may
// have missing or mismatched tags. The expected structure is:
//
// <arg_value>value</tool_call>
// func_name
// <arg_key>key</arg_key>
// <arg_value>value</arg_value>
// ...
//
// instead of:
//
// <arg_value>value</arg_value></tool_call>
func repairUnclosedArgValues(s string) string {
// GLM models frequently omit opening or closing tags. This function follows
// the expected tag cycle, scanning forward for each expected tag in sequence.
// When a tag is missing, it inserts the tag and consumes any text in between.
func repairGLM46XML(s string) string {
// tagCycle is the repeating sequence of tags after the function name.
tagCycle := [phaseCount]string{"<arg_key>", "</arg_key>", "<arg_value>", "</arg_value>"}
// findNextTag returns the index and identity of the earliest known tag in s.
findNextTag := func(s string) (int, string) {
bestIdx := -1
bestTag := ""
for _, tag := range tagCycle {
if idx := strings.Index(s, tag); idx != -1 && (bestIdx == -1 || idx < bestIdx) {
bestIdx = idx
bestTag = tag
}
}
return bestIdx, bestTag
}
// tagIndex returns the phase corresponding to the given tag.
tagIndex := func(tag string) repairPhase {
for i, t := range tagCycle {
if t == tag {
return repairPhase(i)
}
}
return -1
}
var result strings.Builder
for {
openIdx := strings.Index(s, "<arg_value>")
if openIdx == -1 {
idx, firstTag := findNextTag(s)
if idx == -1 {
return s
}
prefix := s[:idx]
s = s[idx:]
// If the first tag is not <arg_key>, the text before it may contain both
// the function name and key content (e.g. "weather city</arg_key>").
// Function names cannot contain space, so split at the first space.
phase := phaseArgKeyOpen
if firstTag != "<arg_key>" {
if spIdx := strings.IndexFunc(prefix, unicode.IsSpace); spIdx != -1 {
result.WriteString(prefix[:spIdx])
keyContent := strings.TrimLeftFunc(prefix[spIdx:], unicode.IsSpace)
result.WriteString("<arg_key>")
result.WriteString(keyContent)
phase = phaseArgKeyClose
} else {
result.WriteString(prefix)
}
} else {
result.WriteString(prefix)
}
// Walk through the expected tag cycle. At each step, look for the
// expected tag. If a different tag appears first, emit the missing
// tags to catch up, then continue.
for len(s) > 0 {
idx, found := findNextTag(s)
expected := tagCycle[phase]
isOpen := phase%2 == 0 // even phases are opening tags
if idx == -1 {
// No more tags — emit remaining text with fixups
if isOpen {
// Expecting an opening tag but nothing left — we're done
break
}
// Expecting a closing tag — emit text then close
result.WriteString(s)
result.WriteString(expected)
phase = (phase + 1) % phaseCount
break
}
afterOpen := openIdx + len("<arg_value>")
closeIdx := strings.Index(s[afterOpen:], "</arg_value>")
nextKeyIdx := strings.Index(s[afterOpen:], "<arg_key>")
// Check if properly closed before the next <arg_key> (or no next key)
if closeIdx != -1 && (nextKeyIdx == -1 || closeIdx < nextKeyIdx) {
end := afterOpen + closeIdx + len("</arg_value>")
result.WriteString(s[:end])
s = s[end:]
if found == expected {
// Found the expected tag — emit any text before it, then the tag
result.WriteString(s[:idx])
result.WriteString(expected)
s = s[idx+len(expected):]
phase = (phase + 1) % phaseCount
continue
}
// Unclosed — insert </arg_value> before the next <arg_key> or at end
if nextKeyIdx != -1 {
insertAt := afterOpen + nextKeyIdx
result.WriteString(s[:insertAt])
result.WriteString("</arg_value>")
s = s[insertAt:]
} else {
result.WriteString(s)
result.WriteString("</arg_value>")
break
// Found a different tag. Insert missing tags to catch up.
foundIdx := tagIndex(found)
if isOpen && idx > 0 {
// Text before the found tag while expecting an opening tag —
// the opening tag was omitted. Emit it before the text.
result.WriteString(expected)
// Advance to the next phase (text content) and then look
// for the closing tag — but the found tag might be that
// closing tag or something further ahead. Emit text up to
// the found tag and insert any missing tags between.
result.WriteString(s[:idx])
phase = (phase + 1) % phaseCount // now expecting closing
s = s[idx:]
// Fall through to re-evaluate with the closing tag expected
continue
}
// Emit missing tags to advance from current phase to the found tag's phase
for phase != foundIdx {
tag := tagCycle[phase]
if phase%2 == 0 {
result.WriteString(tag)
} else {
// Closing tag — emit any text before the found tag first,
// but only if we're one step before the found tag
if (phase+1)%phaseCount == foundIdx && idx > 0 {
result.WriteString(s[:idx])
s = s[idx:]
idx = 0
}
result.WriteString(tag)
}
phase = (phase + 1) % phaseCount
}
// Now phase == foundIdx, re-process without advancing s
}
// If we stopped mid-pair (after an opening tag), close it
switch phase {
case phaseArgKeyClose: // after <arg_key>, expecting text/</arg_key>
result.WriteString("</arg_key>")
result.WriteString("<arg_value>")
result.WriteString("</arg_value>")
case phaseArgValOpen: // after </arg_key>, expecting <arg_value>
result.WriteString("<arg_value>")
result.WriteString("</arg_value>")
case phaseArgValClose: // after <arg_value>, expecting text/</arg_value>
result.WriteString("</arg_value>")
}
return result.String()
}
@@ -398,7 +517,7 @@ func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCa
var parsed GLMToolCallXML
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
parsed = GLMToolCallXML{}
repaired := "<tool_call>" + repairUnclosedArgValues(escaped) + "</tool_call>"
repaired := "<tool_call>" + repairGLM46XML(escaped) + "</tool_call>"
if err2 := xml.Unmarshal([]byte(repaired), &parsed); err2 != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
}

View File

@@ -887,6 +887,28 @@ line3</arg_value>`,
},
},
},
{
name: "unopened arg_value after arg_key",
tools: []api.Tool{},
rawToolCall: "get-weather\n<arg_key>city</arg_key>\nNew York</arg_value>\n<arg_key>unit</arg_key>\ncelsius</arg_value>",
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-weather",
Arguments: args(`{"city": "New York", "unit": "celsius"}`),
},
},
},
{
name: "mixed unopened and valid arg_values",
tools: []api.Tool{},
rawToolCall: "get-weather\n<arg_key>city</arg_key>\n<arg_value>Paris</arg_value>\n<arg_key>unit</arg_key>\ncelsius</arg_value>",
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get-weather",
Arguments: args(`{"city": "Paris", "unit": "celsius"}`),
},
},
},
}
for i, tc := range cases {
@@ -902,7 +924,7 @@ line3</arg_value>`,
}
}
func TestRepairUnclosedArgValues(t *testing.T) {
func TestRepairGLM46XML(t *testing.T) {
cases := []struct {
name string
input string
@@ -910,33 +932,63 @@ func TestRepairUnclosedArgValues(t *testing.T) {
}{
{
name: "already valid",
input: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
want: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
input: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
},
{
name: "unclosed at end",
input: `<arg_key>k</arg_key><arg_value>v`,
want: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
name: "missing </arg_value> at end",
input: `func<arg_key>k</arg_key><arg_value>v`,
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
},
{
name: "unclosed before next arg_key",
input: `<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2</arg_value>`,
want: `<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
name: "missing </arg_value> before next arg_key",
input: `func<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2</arg_value>`,
want: `func<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
},
{
name: "no arg_value tags",
name: "no tags at all",
input: `just plain text`,
want: `just plain text`,
},
{
name: "multiple unclosed",
input: `<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2`,
want: `<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
name: "missing <arg_value> open tag",
input: `func<arg_key>k</arg_key>v</arg_value>`,
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
},
{
name: "missing </arg_key> close tag",
input: `func<arg_key>k<arg_value>v</arg_value>`,
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
},
{
name: "missing <arg_key> open tag",
input: `func k</arg_key><arg_value>v</arg_value>`,
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
},
{
name: "all closing tags missing",
input: `func<arg_key>k<arg_value>v`,
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
},
{
name: "all opening tags missing",
input: "func k</arg_key>v</arg_value>",
want: "func<arg_key>k</arg_key><arg_value>v</arg_value>",
},
{
name: "multiple pairs with mixed missing tags",
input: `func<arg_key>a</arg_key>1</arg_value><arg_key>b<arg_value>2</arg_value>`,
want: `func<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
},
{
name: "newlines preserved",
input: "func\n<arg_key>city</arg_key>\nNew York</arg_value>",
want: "func\n<arg_key>city</arg_key><arg_value>\nNew York</arg_value>",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := repairUnclosedArgValues(tc.input)
got := repairGLM46XML(tc.input)
if got != tc.want {
t.Errorf("got %q, want %q", got, tc.want)
}

View File

@@ -134,14 +134,18 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
spinnerKey = "create"
capabilities = []string{"completion"}
// Check if model supports thinking based on architecture
if supportsThinking(opts.ModelDir) {
configData, _ := os.ReadFile(filepath.Join(opts.ModelDir, "config.json"))
mcfg := parseModelConfig(configData)
if mcfg.supportsThinking() {
capabilities = append(capabilities, "thinking")
}
if mcfg.supportsVision() {
capabilities = append(capabilities, "vision")
}
// Set parser and renderer name based on architecture
parserName = getParserName(opts.ModelDir)
rendererName = getRendererName(opts.ModelDir)
parserName = mcfg.parserName()
rendererName = mcfg.rendererName()
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
@@ -438,145 +442,76 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
return layers, nil
}
// supportsThinking checks if the model supports thinking mode based on its architecture.
// This reads the config.json from the model directory and checks the architectures field.
func supportsThinking(modelDir string) bool {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return false
}
// modelConfig holds the fields from config.json needed during model creation.
// visionConfig mirrors the "vision_config" object in config.json. Its mere
// presence is used as a vision-capability marker (see supportsVision); the
// fields themselves are not otherwise consumed during creation.
type visionConfig struct {
	Depth int32 `json:"depth"` // encoder depth — presumably layer count; TODO confirm against exporter
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
// modelConfig holds the subset of config.json fields needed during model
// creation: architecture identification plus vision-capability markers.
type modelConfig struct {
	// Architectures and ModelType identify the model family (matched
	// case-insensitively by archOrTypeContains).
	Architectures []string `json:"architectures"`
	ModelType     string   `json:"model_type"`
	// Pointers so "absent" is distinguishable from a zero value; any one
	// being non-nil marks the model as vision-capable (see supportsVision).
	VisionConfig       *visionConfig `json:"vision_config"`
	ImageTokenID       *int32        `json:"image_token_id"`
	VisionStartTokenID *int32        `json:"vision_start_token_id"`
	VisionEndTokenID   *int32        `json:"vision_end_token_id"`
}
// Check architectures that support thinking
thinkingArchitectures := []string{
"glm4moe", // GLM-4 MoE models
"deepseek", // DeepSeek models
"qwen3", // Qwen3 models
}
// parseModelConfig decodes config.json bytes into a modelConfig. Decoding is
// best-effort: on malformed or empty input whatever fields were decoded
// (possibly none) are returned, so capability detection simply finds nothing.
func parseModelConfig(data []byte) modelConfig {
	cfg := modelConfig{}
	// Error deliberately ignored: a zero/partial config disables the
	// optional capability and parser/renderer detection downstream.
	_ = json.Unmarshal(data, &cfg)
	return cfg
}
// Check the architecture list
for _, arch := range cfg.Architectures {
// archOrTypeContains returns true if any architecture or the model_type
// contains one of the given substrings (case-insensitive).
func (c *modelConfig) archOrTypeContains(substrs ...string) bool {
for _, arch := range c.Architectures {
archLower := strings.ToLower(arch)
for _, thinkArch := range thinkingArchitectures {
if strings.Contains(archLower, thinkArch) {
for _, s := range substrs {
if strings.Contains(archLower, s) {
return true
}
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
for _, thinkArch := range thinkingArchitectures {
if strings.Contains(typeLower, thinkArch) {
if c.ModelType != "" {
typeLower := strings.ToLower(c.ModelType)
for _, s := range substrs {
if strings.Contains(typeLower, s) {
return true
}
}
}
return false
}
// getParserName returns the parser name for a model based on its architecture.
// This reads the config.json from the model directory and determines the appropriate parser.
func getParserName(modelDir string) string {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
// supportsThinking reports whether the model belongs to an architecture
// family known to emit thinking/reasoning output.
func (c *modelConfig) supportsThinking() bool {
	thinkingFamilies := []string{"glm4moe", "deepseek", "qwen3"}
	return c.archOrTypeContains(thinkingFamilies...)
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return ""
}
// supportsVision reports whether config.json carried any vision marker:
// a vision_config block or any of the vision/image token IDs.
func (c *modelConfig) supportsVision() bool {
	markers := []bool{
		c.VisionConfig != nil,
		c.ImageTokenID != nil,
		c.VisionStartTokenID != nil,
		c.VisionEndTokenID != nil,
	}
	for _, present := range markers {
		if present {
			return true
		}
	}
	return false
}
// Check architectures for known parsers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(archLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3"
}
func (c *modelConfig) parserName() string {
switch {
case c.archOrTypeContains("glm4", "glm-4"):
return "glm-4.7"
case c.archOrTypeContains("deepseek"):
return "deepseek3"
case c.archOrTypeContains("qwen3"):
return "qwen3"
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3"
}
}
return ""
}
// getRendererName returns the renderer name for a model based on its architecture.
// This reads the config.json from the model directory and determines the appropriate renderer.
func getRendererName(modelDir string) string {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return ""
func (c *modelConfig) rendererName() string {
switch {
case c.archOrTypeContains("glm4", "glm-4"):
return "glm-4.7"
case c.archOrTypeContains("deepseek"):
return "deepseek3"
case c.archOrTypeContains("qwen3"):
return "qwen3-coder"
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return ""
}
// Check architectures for known renderers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(archLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3-coder"
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3-coder"
}
}
return ""
}

View File

@@ -339,3 +339,34 @@ func TestCreateModelfileLayersIncludesParameters(t *testing.T) {
t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
}
}
// TestSupportsVision verifies that vision capability is inferred from
// vision_config and/or vision token IDs, and not from architecture alone.
func TestSupportsVision(t *testing.T) {
	cases := []struct {
		name string
		json string
		want bool
	}{
		{
			name: "vision_config present",
			json: `{"vision_config": {"depth": 2}, "image_token_id": 151655}`,
			want: true,
		},
		{
			name: "token ids alone imply vision",
			json: `{"vision_start_token_id": 10, "vision_end_token_id": 11}`,
			want: true,
		},
		{
			name: "plain text model",
			json: `{"architectures": ["Qwen3_5ForCausalLM"]}`,
			want: false,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := parseModelConfig([]byte(tc.json))
			if got := cfg.supportsVision(); got != tc.want {
				t.Fatalf("supportsVision() = %v, want %v", got, tc.want)
			}
		})
	}
}

View File

@@ -1,8 +1,26 @@
// cache.go manages a shared KV cache across conversations using a compressed
// prefix trie. Each trie node stores a token sequence (edge) and optional
// per-layer snapshots that can be paged in/out of the live MLX cache arrays.
//
// Key properties:
// - Only one path through the trie is "active" (backed by live MLX arrays)
// at a time. Switching paths pages out the frontier node and pages in the
// new path.
// - Snapshots are only captured at the frontier (end) of the active path.
// Intermediate node snapshots come from split prefill.
// - All cache layers must stay at the same token offset.
// - Sibling edges must not share a common token prefix (compressed trie
// invariant).
// - begin() always re-evaluates at least one token so the pipeline can seed
// generation, even on a full prefix match.
package mlxrunner
import (
"cmp"
"fmt"
"log/slog"
"time"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
@@ -10,10 +28,13 @@ import (
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
const maxPagedOutBytes int64 = 8 << 30 // 8 GiB eviction threshold for paged-out snapshot memory
type kvCache struct {
// For now we only support a single entry, so this is just one sequence
tokens []int32
caches []cache.Cache
root *trieNode // root of the prefix trie
activePath []*trieNode // current root→leaf path with live MLX arrays
caches []cache.Cache
pagedOutBytes int64 // total bytes in paged-out snapshots across the trie
}
// cacheSession manages caches for a single pipeline run.
@@ -26,176 +47,555 @@ type cacheSession struct {
caches []cache.Cache
remaining []int32
// snapshotOffset, if > 0, is a trie node boundary where we need to
// capture a snapshot during prefill. This enables future requests
// branching at this point to restore non-rewindable caches (e.g.
// RecurrentCache) instead of re-evaluating from scratch.
snapshotOffset int
}
func appendCacheState(dst []*mlx.Array, c cache.Cache) []*mlx.Array {
if c == nil {
return dst
func (c *kvCache) ensureCaches(m base.Model) {
if len(c.caches) != 0 {
return
}
keys, values := c.State()
if keys != nil && keys.Valid() {
dst = append(dst, keys)
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
c.caches = cacheFactory.NewCaches()
return
}
if values != nil && values.Valid() {
dst = append(dst, values)
c.caches = make([]cache.Cache, m.NumLayers())
for i := range c.caches {
c.caches[i] = cache.NewKVCache()
}
return dst
}
func (c *kvCache) free() {
for i, kv := range c.caches {
if kv == nil {
continue
func (c *kvCache) ensureRoot() {
if c.root == nil {
c.root = &trieNode{
lastUsed: time.Now(),
}
kv.Free()
c.caches[i] = nil
}
c.caches = nil
c.tokens = nil
}
func (c *kvCache) cachesCanTrim() bool {
for _, kv := range c.caches {
if kv == nil {
continue
}
if !kv.CanTrim() {
return false
}
}
return true
}
func (c *kvCache) trimToPrefix(prefix int) {
for _, kv := range c.caches {
if kv == nil || !kv.CanTrim() {
continue
}
if trim := kv.Offset() - prefix; trim > 0 {
kv.Trim(trim)
}
}
if prefix < len(c.tokens) {
c.tokens = c.tokens[:prefix]
c.activePath = []*trieNode{c.root}
}
}
// begin prepares caches for a new request. It finds the nearest
// matching cache or creates new caches if none match.
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
ensureCaches := func() {
if len(c.caches) != 0 {
return
}
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
c.caches = cacheFactory.NewCaches()
return
}
c.caches = make([]cache.Cache, m.NumLayers())
for i := range c.caches {
c.caches[i] = cache.NewKVCache()
c.ensureCaches(m)
c.ensureRoot()
matchPath, matched := findBestMatch(c.root, inputs)
originalMatched := matched
// Always keep at least one token to re-evaluate so the
// pipeline can seed token generation from it.
if matched == len(inputs) && matched > 0 {
matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
}
// Check for partial match within a node's edge — truncate path
// to the parent boundary. snapshot() will split the node and
// create the branch point during prefill when caches are ready.
partialMatch := false
if len(matchPath) > 1 {
lastNode := matchPath[len(matchPath)-1]
matchedInEdge := matched - lastNode.startOffset()
if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
matchPath = matchPath[:len(matchPath)-1]
partialMatch = true
}
}
ensureCaches()
remaining := c.findRemaining(inputs)
ensureCaches()
// Switch to the matched path, paging in/out as needed.
c.switchToPath(matchPath)
// switchToPath aligns caches to a common offset
prefix := c.minCacheOffset()
remaining := inputs[prefix:]
// Schedule a snapshot at the branch point during prefill so future
// requests diverging here can restore instead of re-evaluating.
var snapshotAt int
if partialMatch || (prefix == 0 && matched > 0) {
snapshotAt = matched
}
args := []any{"total", len(inputs), "matched", originalMatched}
args = append(args, "cached", prefix, "left", len(remaining))
if snapshotAt > 0 {
args = append(args, "pending_snapshot", snapshotAt)
}
if prefix == 0 {
slog.Info("cache miss", args...)
} else {
slog.Info("cache hit", args...)
}
return &cacheSession{
cache: c,
inputs: inputs,
caches: c.caches,
remaining: remaining,
cache: c,
inputs: inputs,
snapshotOffset: snapshotAt,
caches: c.caches,
remaining: remaining,
}
}
// switchToPath transitions from the current active path to a new path,
// paging out diverging segments and paging in the new path. On return the
// active path is newPath (possibly trimmed to the caches' common offset)
// and all cache layers sit at the same token offset.
func (c *kvCache) switchToPath(newPath []*trieNode) {
	// Eviction runs after the switch so freshly paged-out snapshots are
	// accounted for in pagedOutBytes.
	defer c.enforceEvictionPolicy()
	// Find common ancestor index.
	commonLen := 0
	for commonLen < len(c.activePath) && commonLen < len(newPath) {
		if c.activePath[commonLen] != newPath[commonLen] {
			break
		}
		commonLen++
	}
	// Token offset covered by the shared prefix of both paths; everything
	// past it on the old path must be paged out or discarded.
	ancestorOffset := 0
	if commonLen > 0 {
		ancestorOffset = c.activePath[commonLen-1].endOffset
	}
	var pageOutCount, pageInCount int
	// Page out the leaf of the old path. Only the leaf's live cache
	// state is correct — intermediate nodes already have snapshots
	// captured during their creation (splitNode + prefill). Snapshotting
	// non-leaf nodes here would produce wrong results for non-rewindable
	// caches (e.g. RecurrentCache) whose state reflects the leaf, not
	// the intermediate boundary.
	if leaf := len(c.activePath) - 1; leaf >= commonLen {
		node := c.activePath[leaf]
		if !node.hasAllSnapshots() {
			fromOffset := node.startOffset()
			snaps := make([]cache.Snapshot, len(c.caches))
			for j, kv := range c.caches {
				if kv == nil {
					continue
				}
				snaps[j] = kv.Snapshot(fromOffset)
			}
			node.setSnapshots(snaps, &c.pagedOutBytes)
			pageOutCount++
			logutil.Trace(fmt.Sprintf("page out: [%d, %d)", fromOffset, node.endOffset))
		}
	}
	// Rewind each cache to the ancestor offset or free it. Freed
	// caches (e.g. RecurrentCache that can't rewind) will be restored
	// from snapshots during page-in.
	for _, kv := range c.caches {
		if kv == nil {
			continue
		}
		// Restore(nil, offset) rewinds using live state only.
		if !kv.Restore(nil, ancestorOffset) {
			kv.Free()
		}
	}
	// Page in — walk the full new path, restoring from snapshots.
	// Freed caches naturally pick up the first available snapshot.
	// Caches already past a node skip it via offset check.
	for _, node := range newPath {
		if len(node.snapshots) == 0 {
			continue
		}
		for j, kv := range c.caches {
			if kv == nil {
				continue
			}
			if j >= len(node.snapshots) || node.snapshots[j] == nil {
				continue
			}
			if kv.Offset() >= node.endOffset {
				continue
			}
			if !kv.Restore(node.snapshots[j], node.endOffset) {
				// A failed restore leaves layers inconsistent; drop all
				// live state and reset the active path to the root.
				slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset())
				c.freeAll()
				c.activePath = []*trieNode{c.root}
				return
			}
		}
		if node.endOffset > ancestorOffset {
			pageInCount++
			logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset))
		}
	}
	// Align all caches to the minimum offset.
	c.activePath = newPath
	minOff := c.minCacheOffset()
	for _, kv := range c.caches {
		if kv != nil && kv.Offset() != minOff {
			if !kv.Restore(nil, minOff) {
				slog.Warn("failed to restore cache, freeing all caches", "offset", minOff)
				c.freeAll()
				break
			}
		}
	}
	// Trim the active path back to the last node fully covered by minOff.
	for i := len(c.activePath) - 1; i >= 0; i-- {
		if c.activePath[i].endOffset <= minOff {
			c.activePath = c.activePath[:i+1]
			break
		}
	}
	if pageOutCount > 0 || pageInCount > 0 {
		slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount)
	}
}
// snapshot creates a snapshot at the current cache position. During prefill,
// it is called at branch points (user=false) to create restore points for
// future diverging requests and with user=true to mark an explicit reusable
// restore point.
func (s *cacheSession) snapshot(user bool) {
	c := s.cache
	// Use the minimum layer offset — the only position guaranteed to be
	// materialized in every cache layer.
	cacheOffset := c.minCacheOffset()
	if cacheOffset <= 0 {
		return
	}
	// Clear pending intermediate snapshot if we've reached or passed it.
	if s.snapshotOffset > 0 && cacheOffset >= s.snapshotOffset {
		s.snapshotOffset = 0
	}
	// The last node in activePath is the frontier where caches are advancing.
	// cacheOffset is always >= its endOffset: begin() restores caches to this
	// boundary and prefill advances monotonically forward.
	frontier := c.activePath[len(c.activePath)-1]
	// If the frontier already ends at cacheOffset, just ensure it has snapshots.
	if frontier.endOffset == cacheOffset {
		if user {
			frontier.user = true
		}
		if !frontier.hasAllSnapshots() {
			s.attachSnapshots(frontier, cacheOffset)
		}
		return
	}
	// Defensive: should not happen given the invariant above.
	if frontier.endOffset > cacheOffset {
		slog.Warn("snapshot skipped: cacheOffset is behind frontier", "cacheOffset", cacheOffset, "frontierEndOffset", frontier.endOffset)
		return
	}
	// Advance the trie to cacheOffset — find or create a node there.
	// Edge tokens come from the full input+output sequence of this run.
	edgeTokens := append(s.inputs, s.outputs...)[frontier.endOffset:cacheOffset]
	frontier = c.advancePath(frontier, edgeTokens, cacheOffset)
	// Attach fresh snapshots from the live caches. Always use fresh
	// snapshots even if the node already has some (e.g. from splitNode's
	// Cache.Split which may be incomplete for non-splittable caches
	// like RecurrentCache).
	if user {
		frontier.user = true
	}
	s.attachSnapshots(frontier, cacheOffset)
}
// advancePath advances the active path from the current frontier by matching
// tokens against existing trie children, splitting partial matches, and
// appending any remaining tokens as new nodes. Returns the new frontier.
// endOffset is the absolute token offset that the returned node must end at.
func (c *kvCache) advancePath(frontier *trieNode, tokens []int32, endOffset int) *trieNode {
	// Check if existing children already cover some or all of tokens.
	// tokens may span multiple trie nodes when extending a previous run's
	// leaf and this snapshot now overlaps that same range.
	matchPath, matched := findBestMatch(frontier, tokens)
	// matchPath[0] is frontier itself; the rest are newly traversed nodes.
	remaining := tokens[matched:]
	// Check for a partial match within the last node's edge — if so, split it.
	if len(matchPath) > 1 {
		lastNode := matchPath[len(matchPath)-1]
		// Convert the match length (relative to frontier) into a position
		// inside lastNode's edge.
		matchedInEdge := frontier.endOffset + matched - lastNode.startOffset()
		if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
			matchPath[len(matchPath)-1] = splitNode(lastNode, matchedInEdge, c.caches, &c.pagedOutBytes)
		}
	}
	// Append traversed nodes (excluding frontier) to the active path.
	c.activePath = append(c.activePath, matchPath[1:]...)
	dest := matchPath[len(matchPath)-1]
	if len(remaining) > 0 {
		// Drop non-user snapshots so appendTokens can extend in-place
		// rather than creating a new child node.
		if len(dest.children) == 0 && !dest.user {
			dest.setSnapshots(nil, &c.pagedOutBytes)
		}
		newDest := dest.appendTokens(c.root, remaining, endOffset)
		// appendTokens may extend dest in place or create a child node;
		// only a new node needs to join the active path.
		if newDest != dest {
			c.activePath = append(c.activePath, newDest)
		}
		dest = newDest
	}
	return dest
}
// attachSnapshots attaches cache snapshots to a trie node at the given offset.
// The node must be on the active path (and thus protected from eviction;
// lastUsed is updated in close()). All non-nil caches must be at the same
// offset (cacheOffset); a mismatch indicates a bug in the caller.
func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
	c := s.cache
	// Only the frontier's live state is valid to snapshot (see switchToPath).
	if c.activePath[len(c.activePath)-1] != node {
		slog.Warn("attachSnapshots skipped: node is not the active frontier", "nodeEndOffset", node.endOffset)
		return
	}
	snaps := make([]cache.Snapshot, len(c.caches))
	for i, kv := range c.caches {
		if kv != nil {
			if kv.Offset() != cacheOffset {
				panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, kv.Offset()))
			}
			// Each snapshot covers [node.startOffset(), cacheOffset).
			snaps[i] = kv.Snapshot(node.startOffset())
		}
	}
	node.setSnapshots(snaps, &c.pagedOutBytes)
	slog.Debug("created snapshot", "offset", cacheOffset)
	// New snapshots may push paged-out memory past the limit.
	c.enforceEvictionPolicy()
}
// clear releases live caches and drops the trie so future requests cannot
// reuse prompt state keyed only by token IDs.
func (c *kvCache) clear() {
	c.freeAll()
	// Close every paged-out snapshot before discarding the trie so their
	// pinned arrays can be reclaimed.
	walkNodes(c.root, func(node *trieNode) bool {
		for _, snap := range node.snapshots {
			if snap == nil {
				continue
			}
			snap.Close()
		}
		node.snapshots = nil
		return true
	})
	c.root = nil
	c.activePath = nil
}
// freeAll releases the live state of every non-nil cache layer.
func (c *kvCache) freeAll() {
	for _, layer := range c.caches {
		if layer == nil {
			continue
		}
		layer.Free()
	}
}
// minCacheOffset returns the smallest Offset() among non-nil cache layers,
// or 0 when there are none. Mixed cache types can transiently report
// different offsets, so the minimum is the safe common position.
func (c *kvCache) minCacheOffset() int {
	minOff, seen := 0, false
	for _, layer := range c.caches {
		if layer == nil {
			continue
		}
		off := layer.Offset()
		if !seen || off < minOff {
			minOff, seen = off, true
		}
	}
	return minOff
}
// close saves the token state if the forward pass ran.
func (s *cacheSession) close() {
if len(s.caches) == 0 {
offset := s.cache.minCacheOffset()
if offset <= 0 {
return
}
offset := -1
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
for _, kv := range s.caches {
if kv == nil {
continue
}
// Mixed cache types (e.g. recurrent + KV) can transiently report different
// offsets, so use the minimum as the safe reusable token prefix.
if off := kv.Offset(); offset < 0 || off < offset {
offset = off
}
arrays = appendCacheState(arrays, kv)
}
if offset <= 0 {
return
arrays = append(arrays, kv.State()...)
}
// Ensure that if we have run the forward pass and set the metadata
// that we also actually have the data.
mlx.AsyncEval(arrays...)
stored := append(s.inputs, s.outputs...)
if offset > len(stored) {
offset = len(stored)
}
s.cache.tokens = stored[:offset]
}
// Advance the trie frontier with any newly generated tokens.
c := s.cache
if len(c.activePath) > 0 {
frontier := c.activePath[len(c.activePath)-1]
stored := append(s.inputs, s.outputs...)
// findRemaining finds the longest common prefix between tokens and the cached
// sequence, trims stale cache entries, and returns the remaining tokens.
func (c *kvCache) findRemaining(tokens []int32) []int32 {
prefix := 0
for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
prefix++
}
// Always keep at least one token to re-evaluate so the
// pipeline can seed token generation from it.
if prefix == len(tokens) && prefix > 0 {
prefix--
}
if prefix < len(c.tokens) {
if c.cachesCanTrim() {
c.trimToPrefix(prefix)
} else {
c.free()
slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
return tokens
if offset > frontier.endOffset {
newTokens := stored[frontier.endOffset:offset]
c.advancePath(frontier, newTokens, offset)
}
now := time.Now()
for _, node := range c.activePath {
node.lastUsed = now
}
}
if prefix == 0 {
slog.Info("Cache miss", "left", len(tokens))
} else {
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
}
return tokens[prefix:]
}
func (c *kvCache) log() {
if len(c.caches) == 0 {
// enforceEvictionPolicy evicts eligible nodes until paged-out memory is within limits.
func (c *kvCache) enforceEvictionPolicy() {
if c.pagedOutBytes <= maxPagedOutBytes {
return
}
offset := -1
var totalBytes int
activeSet := make(map[*trieNode]bool, len(c.activePath))
for _, n := range c.activePath {
activeSet[n] = true
}
for c.pagedOutBytes > maxPagedOutBytes {
var best *trieNode
walkNodes(c.root, func(n *trieNode) bool {
if n == c.root || activeSet[n] || !n.hasSnapshots() {
return true
}
// Evict: oldest, then deepest, then largest.
if best == nil || cmp.Or(
n.lastUsed.Compare(best.lastUsed),
cmp.Compare(best.endOffset, n.endOffset),
cmp.Compare(best.snapshotBytes(), n.snapshotBytes()),
) < 0 {
best = n
}
return true
})
if best == nil {
break
}
c.evictNode(best)
}
}
// evictNode evicts a single node from the trie, freeing its snapshot memory.
// Leaves are removed outright, single-child interior nodes are merged into
// their child, and multi-child branch points only drop their snapshots.
func (c *kvCache) evictNode(node *trieNode) {
	switch {
	case len(node.children) == 0:
		// Leaf: remove entirely.
		parent := node.parent
		kind := "evicting leaf"
		if node.user {
			kind = "evicting user snapshot"
		}
		slog.Debug(kind, "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
		removeNode(node, &c.pagedOutBytes)
		// A regular (non-user) parent left with exactly one child is
		// collapsed into that child to keep the trie compressed.
		if parent != nil && !parent.user && len(parent.children) == 1 && parent != c.root {
			logutil.Trace(fmt.Sprintf("auto-merging parent at offset %d with single child", parent.endOffset))
			mergeWithChild(parent, c.caches, &c.pagedOutBytes)
		}
	case len(node.children) == 1:
		// Interior snapshot node with one child: merge with child.
		slog.Debug("evicting snapshot node", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
		mergeWithChild(node, c.caches, &c.pagedOutBytes)
	default:
		// Multi-child branch point: drop snapshots but keep the node.
		slog.Debug("evicting branch snapshot", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
		node.setSnapshots(nil, &c.pagedOutBytes)
	}
}
func (c *kvCache) dumpTree() {
// Summary stats
var cacheBytes int
for _, kv := range c.caches {
if kv == nil {
continue
}
if off := kv.Offset(); offset < 0 || off < offset {
offset = off
}
for _, a := range appendCacheState(nil, kv) {
totalBytes += a.NumBytes()
for _, a := range kv.State() {
if a != nil {
cacheBytes += a.NumBytes()
}
}
}
if offset < 0 {
return
// Build active path set for marking.
active := make(map[*trieNode]bool, len(c.activePath))
for _, n := range c.activePath {
active[n] = true
}
var nodeCount, snapshotCount int
var pagedBytes int64
var lines []string
var dump func(n *trieNode, prefix string, isLast bool)
dump = func(n *trieNode, prefix string, isLast bool) {
if n == nil {
return
}
nodeCount++
// Build connector
var connector string
if n.parent == nil {
connector = ""
} else if isLast {
connector = prefix + "`-- "
} else {
connector = prefix + "|-- "
}
// Node label
nodeBytes := n.snapshotBytes()
pagedBytes += nodeBytes
label := fmt.Sprintf("[%d,%d) %dt", n.startOffset(), n.endOffset, len(n.tokens))
if nodeBytes > 0 {
label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
}
var flags []string
if n.user {
flags = append(flags, "user")
}
if n.hasAllSnapshots() {
snapshotCount++
flags = append(flags, "snap")
}
if active[n] {
flags = append(flags, "active")
}
if len(flags) > 0 {
label += " (" + flags[0]
for _, f := range flags[1:] {
label += ", " + f
}
label += ")"
}
lines = append(lines, connector+label)
// Recurse children
childPrefix := prefix
if n.parent != nil {
if isLast {
childPrefix += " "
} else {
childPrefix += "| "
}
}
for i, child := range n.children {
dump(child, childPrefix, i == len(n.children)-1)
}
}
dump(c.root, "", true)
offset := c.minCacheOffset()
logutil.Trace(fmt.Sprintf("kv cache active_tokens: %d, active_size: %s, paged_out: %s, trie: nodes=%d, snapshots=%d",
offset, mlx.PrettyBytes(cacheBytes), mlx.PrettyBytes(int(pagedBytes)), nodeCount, snapshotCount))
for i, l := range lines {
if i == 0 {
logutil.Trace("cache trie: " + l)
} else {
logutil.Trace(" " + l)
}
}
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes)))
}

View File

@@ -8,13 +8,34 @@ import (
type Cache interface {
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
// State returns the cache-owned state roots that should be kept/evaluated.
State() (keys, values *mlx.Array)
CanTrim() bool
Trim(int) int
Clone() Cache
State() []*mlx.Array
Free()
Offset() int
Len() int
// Snapshot copies cache state from fromOffset to current offset into
// pinned VRAM arrays. The active cache is unchanged.
Snapshot(fromOffset int) Snapshot
// Restore brings the cache to target. If snapshot is nil, rewinds
// using the cache's own live state.
Restore(snapshot Snapshot, target int) bool
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
// Takes ownership of both inputs.
Merge(parent, child Snapshot) Snapshot
// Split divides a snapshot [a,c) at offset b into [a,b) and [b,c).
// Takes ownership of the input. Cache types that cannot split
// (e.g. recurrent) return (nil, snapshot).
Split(snapshot Snapshot, at int) (parent, child Snapshot)
}
// Snapshot is paged-out cache state that can be restored later.
// Implementations keep their arrays pinned until Close is called
// (see kvSnapshot, which pins on creation and unpins in Close).
type Snapshot interface {
	// Size returns the byte size of the paged-out data (in VRAM).
	Size() int
	// Close unpins the snapshot's arrays so they can be freed by Sweep.
	// The snapshot must not be used after Close.
	Close()
}
type KVCache struct {
@@ -59,40 +80,148 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
func (c *KVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return nil
}
return []*mlx.Array{
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
}
}
// kvSnapshot holds paged-out KV data for a range [fromOffset, toOffset).
type kvSnapshot struct {
// keys and values are pinned copies of the cache's K/V tensors covering
// only the snapshot's token range (sequence dim = toOffset-fromOffset).
keys, values *mlx.Array
// Half-open token range [fromOffset, toOffset) within the conversation
// that this snapshot covers.
fromOffset, toOffset int
}
// Size reports the snapshot's byte footprint (key bytes + value bytes).
func (s *kvSnapshot) Size() int { return s.keys.NumBytes() + s.values.NumBytes() }
// Close unpins both arrays so a later sweep can free them.
func (s *kvSnapshot) Close() { mlx.Unpin(s.keys, s.values) }
// Snapshot copies the KV range [max(fromOffset,0), c.offset) into freshly
// pinned arrays and schedules the copy asynchronously. The active cache is
// left unchanged. Returns nil when there is nothing to capture (no keys yet,
// or no tokens past fromOffset).
func (c *KVCache) Snapshot(fromOffset int) Snapshot {
if c.keys == nil || c.offset <= fromOffset {
return nil
}
from := max(0, fromOffset)
to := c.offset
// Slice the covered range along the sequence dim (axis 2), then make
// independent copies so later cache updates cannot alias the snapshot.
kSlice := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
vSlice := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
kCopy := mlx.Copy(kSlice)
vCopy := mlx.Copy(vSlice)
// Pin so the copies survive sweeps; evaluate asynchronously so the
// caller does not block on the copy completing.
mlx.Pin(kCopy, vCopy)
mlx.AsyncEval(kCopy, vCopy)
return &kvSnapshot{
keys: kCopy,
values: vCopy,
fromOffset: from,
toOffset: to,
}
}
// Restore brings the cache to target tokens. With a nil snapshot it performs
// a live rewind (a pure offset clamp — the underlying buffers still hold the
// data). With a snapshot it re-appends the snapshot's KV range on top of the
// cache's existing prefix, which must already reach the snapshot's start.
// Returns false when the required base data is missing.
func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
if snapshot == nil {
// Rewind using live state — just clamp offset.
target = max(0, min(target, c.offset))
c.offset = target
return true
}
snap := snapshot.(*kvSnapshot)
// Check that the cache has data up to the snapshot's starting point.
if c.offset < snap.fromOffset {
return false
}
// Rewind to snapshot start, then feed snapshot data through Update.
c.offset = snap.fromOffset
c.Update(snap.keys, snap.values)
// Clamp to target if needed (target may be less than full snapshot).
if target < c.offset {
c.offset = target
}
return true
}
// Merge combines two sequential snapshots [a,b) and [b,c) into one [a,c)
// snapshot by concatenating along the sequence dim. Takes ownership of both
// inputs: they are closed whether or not the merge succeeds, and a nil input
// on either side closes the other and yields nil.
func (c *KVCache) Merge(parent, child Snapshot) Snapshot {
if parent == nil || child == nil {
if parent != nil {
parent.Close()
}
if child != nil {
child.Close()
}
return nil
}
p := parent.(*kvSnapshot)
ch := child.(*kvSnapshot)
// Concatenate K and V along the sequence dim (axis 2).
mk := p.keys.Concatenate(2, ch.keys)
mv := p.values.Concatenate(2, ch.values)
mlx.Pin(mk, mv)
mlx.AsyncEval(mk, mv)
// NOTE(review): the inputs are unpinned (Close) right after AsyncEval is
// scheduled on mk/mv — presumably the graph retains the source arrays
// until evaluation completes; confirm against mlx eval/pin semantics.
p.Close()
ch.Close()
return &kvSnapshot{
keys: mk,
values: mv,
fromOffset: p.fromOffset,
toOffset: ch.toOffset,
}
}
func (c *KVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
if snapshot == nil {
return nil, nil
}
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *KVCache) CanTrim() bool { return true }
func (c *KVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n
return n
}
func (c *KVCache) Clone() Cache {
clone := &KVCache{
keys: c.keys.Clone(),
values: c.values.Clone(),
offset: c.offset,
step: c.step,
snap := snapshot.(*kvSnapshot)
splitIdx := at - snap.fromOffset
seqLen := snap.toOffset - snap.fromOffset
if splitIdx <= 0 {
return nil, snapshot
}
mlx.Pin(clone.keys, clone.values)
return clone
if splitIdx >= seqLen {
return snapshot, nil
}
pk := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()))
pv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()))
ck := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()))
cv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()))
mlx.Pin(pk, pv, ck, cv)
mlx.AsyncEval(pk, pv, ck, cv)
snap.Close()
p := &kvSnapshot{
keys: pk,
values: pv,
fromOffset: snap.fromOffset,
toOffset: at,
}
ch := &kvSnapshot{
keys: ck,
values: cv,
fromOffset: at,
toOffset: snap.toOffset,
}
return p, ch
}
// Free unpins and drops the cache's arrays and resets it to the empty state.
func (c *KVCache) Free() {
mlx.Unpin(c.keys, c.values)
c.keys, c.values = nil, nil
c.offset = 0
}
// Offset returns the number of tokens processed so far.
func (c *KVCache) Offset() int { return c.offset }
// Len returns the number of cached positions; for a full (non-rotating)
// cache this equals Offset.
func (c *KVCache) Len() int { return c.offset }
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
@@ -184,29 +313,104 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
}
func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
func (c *RotatingKVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return nil, nil
return nil
}
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *RotatingKVCache) CanTrim() bool { return true }
func (c *RotatingKVCache) Trim(n int) int {
n = min(c.offset, n)
c.offset -= n
c.idx -= n
return n
}
func (c *RotatingKVCache) Clone() Cache {
return &RotatingKVCache{
maxSize: c.maxSize,
idx: c.idx,
KVCache: c.KVCache.Clone().(*KVCache),
return []*mlx.Array{
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
}
}
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
// rotatingSnapshot holds paged-out data for a RotatingKVCache: the embedded
// kvSnapshot carries the window's KV data, plus the ring-buffer write
// position (idx) needed to resume updates in the correct slot after Restore.
//
// Size and Close come from the embedded kvSnapshot via method promotion;
// the explicit one-line delegating wrappers were redundant and are removed.
type rotatingSnapshot struct {
	kvSnapshot     // embedded KV data
	idx        int // buffer write position at snapshot time
}
// Snapshot captures the current window state as a pinned snapshot. Unlike
// KVCache.Snapshot, fromOffset only gates whether there is anything new to
// capture: a ring buffer cannot be sliced to an arbitrary token sub-range,
// so the snapshot always carries the full window (via State()) while its
// recorded range is [fromOffset, c.offset).
func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot {
if c.keys == nil || c.offset <= fromOffset {
return nil
}
state := c.State()
// NOTE(review): no mlx.Copy/AsyncEval here, unlike KVCache.Snapshot —
// presumably Clone detaches the data from the live buffer so later
// Updates cannot mutate the snapshot; confirm against mlx semantics.
k := state[0].Clone()
v := state[1].Clone()
mlx.Pin(k, v)
return &rotatingSnapshot{
kvSnapshot: kvSnapshot{
keys: k,
values: v,
fromOffset: fromOffset,
toOffset: c.offset,
},
idx: c.idx,
}
}
// Restore rewinds or rebuilds the rotating cache. With a nil snapshot it
// performs a live rewind, which is only valid before the window has wrapped.
// With a snapshot it replaces the buffer with the snapshot's full-window
// state and restores the saved write position (idx).
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
if snapshot == nil {
// Live rewind is only safe when the buffer hasn't filled yet
// (offset <= maxSize). Once the window has shifted, rewinding
// leaves fewer than maxSize trailing tokens to attend to —
// a snapshot is required to restore the full window.
if c.offset > c.maxSize {
return false
}
target = max(0, min(target, c.offset))
c.offset = target
c.idx = target
return true
}
snap := snapshot.(*rotatingSnapshot)
// Reject if clamping would leave an incomplete window.
if target < snap.toOffset && snap.toOffset > c.maxSize {
return false
}
// Restore from snapshot: rebuild buffer state.
// Free existing state first.
if c.keys != nil {
mlx.Unpin(c.keys, c.values)
}
c.keys = snap.keys.Clone()
c.values = snap.values.Clone()
mlx.Pin(c.keys, c.values)
c.offset = snap.toOffset
c.idx = snap.idx
// Clamp to target if needed.
if target < c.offset {
target = max(0, target)
c.offset = target
c.idx = target
}
return true
}
// Merge combines two sequential snapshots by discarding the parent: a
// rotating snapshot already holds the full window, so the child alone is the
// merged state. Ownership of the returned child passes to the caller.
func (c *RotatingKVCache) Merge(parent, child Snapshot) Snapshot {
// For rotating caches, the child snapshot supersedes the parent
// since it contains the full window state.
if parent != nil {
parent.Close()
}
return child
}
// Split declines to split: returns (nil, snapshot) unchanged, per the Cache
// interface contract for cache types that cannot divide their state.
func (c *RotatingKVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
// Rotating cache snapshots contain the full window state.
// Cannot cleanly split a ring buffer at an arbitrary point.
return nil, snapshot
}
// Free releases the underlying KV arrays and resets the ring-buffer write
// position to the start.
func (c *RotatingKVCache) Free() {
c.KVCache.Free()
c.idx = 0
}

271
x/mlxrunner/cache/cache_test.go vendored Normal file
View File

@@ -0,0 +1,271 @@
package cache
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// skipIfNoMLX skips the calling test when the MLX runtime cannot be
// initialized, since these tests allocate real arrays.
func skipIfNoMLX(t *testing.T) {
	t.Helper()
	err := mlx.CheckInit()
	if err == nil {
		return
	}
	t.Skipf("MLX not available: %v", err)
}
// TestKVCacheSnapshotRestoreNeedBase verifies that Restore fails when the
// cache lacks the base data up to the snapshot's starting offset: a [5,10)
// snapshot cannot be applied to a freed (empty) cache.
func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) {
skipIfNoMLX(t)
c := NewKVCache()
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
}
// Snapshot [5, 10).
snap := c.Snapshot(5)
// Free the cache completely — offset is now 0.
c.Free()
// Restore should fail because cache doesn't have data up to fromOffset=5.
if c.Restore(snap, 10) {
t.Fatal("expected Restore to fail with no base data")
}
}
// TestKVCacheDataSurvivesSnapshotRestore verifies that actual array data
// is preserved through a snapshot→free→restore cycle.
func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
skipIfNoMLX(t)
c := NewKVCache()
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
}
snap := c.Snapshot(0)
if snap == nil {
t.Fatal("Snapshot returned nil")
}
// Free and restore to a fresh cache.
c2 := NewKVCache()
if !c2.Restore(snap, 10) {
t.Fatal("Restore failed")
}
if c2.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c2.Offset())
}
// Verify State() returns arrays with correct sequence dimension.
state := c2.State()
if len(state) != 2 {
t.Fatalf("State() returned %d arrays, want 2", len(state))
}
// keys shape: [B, H, seqLen, Dk]
if state[0].Dim(2) != 10 {
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
}
if state[1].Dim(2) != 10 {
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
}
}
// TestKVCacheSplitPreservesData verifies that split produces two snapshots
// that can be sequentially restored to rebuild the original cache state.
func TestKVCacheSplitPreservesData(t *testing.T) {
skipIfNoMLX(t)
c := NewKVCache()
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
}
snap := c.Snapshot(0)
parent, child := c.Split(snap, 5)
if parent == nil || child == nil {
t.Fatal("Split returned nil")
}
// Restore parent → offset=5, seq dim=5.
c2 := NewKVCache()
if !c2.Restore(parent, 5) {
t.Fatal("Restore(parent) failed")
}
if c2.Offset() != 5 {
t.Fatalf("offset after parent = %d, want 5", c2.Offset())
}
state := c2.State()
if state[0].Dim(2) != 5 {
t.Fatalf("keys seq dim after parent = %d, want 5", state[0].Dim(2))
}
// Restore child on top → offset=10, seq dim=10.
if !c2.Restore(child, 10) {
t.Fatal("Restore(child) failed")
}
if c2.Offset() != 10 {
t.Fatalf("offset after child = %d, want 10", c2.Offset())
}
state = c2.State()
if state[0].Dim(2) != 10 {
t.Fatalf("keys seq dim after child = %d, want 10", state[0].Dim(2))
}
}
// TestKVCacheSplitMergeRoundTripData verifies that splitting and merging back
// produces a snapshot equivalent to the original.
func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
skipIfNoMLX(t)
c := NewKVCache()
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
}
snap := c.Snapshot(0)
parent, child := c.Split(snap, 6)
merged := c.Merge(parent, child)
if merged == nil {
t.Fatal("Merge returned nil")
}
c2 := NewKVCache()
if !c2.Restore(merged, 10) {
t.Fatal("Restore(merged) failed")
}
if c2.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c2.Offset())
}
state := c2.State()
if state[0].Dim(2) != 10 {
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
}
if state[1].Dim(2) != 10 {
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
}
}
// TestRotatingKVCacheRestoreOutsideWindow verifies that a live rewind (nil
// snapshot) fails once the rotating window has wrapped, because the tokens
// before the window have been evicted.
func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) {
skipIfNoMLX(t)
c := NewRotatingKVCache(4)
// Feed 10 tokens (window size 4, so positions 0-5 are evicted).
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
}
// Offset 3 is outside the window.
if c.Restore(nil, 3) {
t.Fatal("Restore(nil, 3) should fail when outside window")
}
}
// TestRotatingKVCacheSnapshotPreservesWindow verifies that after restoring
// from a snapshot, the rotating cache has the correct window of data.
func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
skipIfNoMLX(t)
c := NewRotatingKVCache(4)
// Feed 10 tokens one at a time. Window size 4, so only last 4 are kept.
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
}
snap := c.Snapshot(0)
if snap == nil {
t.Fatal("Snapshot returned nil")
}
// Feed 5 more tokens.
for range 5 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
}
// Restore to offset 10.
if !c.Restore(snap, 10) {
t.Fatal("Restore failed")
}
if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c.Offset())
}
state := c.State()
if len(state) != 2 {
t.Fatalf("State() returned %d arrays, want 2", len(state))
}
// Seq dim should be min(offset, maxSize) = min(10, 4) = 4.
seqDim := state[0].Dim(2)
if seqDim != 4 {
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
}
}
// TestRotatingKVCacheRestoreFromSnapshot verifies that restoring from a
// snapshot correctly preserves the write position (idx), so subsequent
// single-token updates land in the right buffer slot.
func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
skipIfNoMLX(t)
c := NewRotatingKVCache(4)
// Fill the window: 6 tokens into a size-4 window.
// After this, idx has wrapped and the buffer has rotated.
for range 6 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
}
if c.Offset() != 6 {
t.Fatalf("offset = %d, want 6", c.Offset())
}
snap := c.Snapshot(0)
// Mutate the cache further so live state diverges from snapshot.
for range 3 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
}
// Restore to snapshot state.
if !c.Restore(snap, 6) {
t.Fatal("Restore failed")
}
if c.Offset() != 6 {
t.Fatalf("offset after restore = %d, want 6", c.Offset())
}
// Feed one more token. If idx was restored correctly, this should
// produce a valid window of size 4 at offset 7.
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(k, v)
if c.Offset() != 7 {
t.Fatalf("offset after post-restore update = %d, want 7", c.Offset())
}
state := c.State()
if len(state) != 2 {
t.Fatalf("State() returned %d arrays, want 2", len(state))
}
seqDim := state[0].Dim(2)
if seqDim != 4 {
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
}
}

View File

@@ -56,16 +56,6 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
return detached
}
func snapshotPinned(a *mlx.Array) *mlx.Array {
if a == nil || !a.Valid() {
return nil
}
snap := mlx.Copy(a)
mlx.Eval(snap)
mlx.Pin(snap)
return snap
}
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
return &RecurrentCache{
convTail: int(convTail),
@@ -123,30 +113,69 @@ func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array
return keys, values
}
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
return c.convState, c.deltaState
func (c *RecurrentCache) State() []*mlx.Array {
return []*mlx.Array{c.convState, c.deltaState}
}
func (c *RecurrentCache) CanTrim() bool { return false }
func (c *RecurrentCache) Trim(n int) int {
// Recurrent state is not directly trimmable. Divergent prefixes must drop the cache.
_ = n
return 0
// recurrentSnapshot holds paged-out recurrent state. Self-contained —
// does not depend on any parent state.
type recurrentSnapshot struct {
// Pinned copies of the cache's convolutional tail and delta-rule state.
convState, deltaState *mlx.Array
// Token count the captured state accounts for (cache offset at capture time).
offset int
}
func (c *RecurrentCache) Clone() Cache {
clone := &RecurrentCache{
offset: c.offset,
convTail: c.convTail,
convDim: c.convDim,
numVHeads: c.numVHeads,
headVDim: c.headVDim,
headKDim: c.headKDim,
convState: snapshotPinned(c.convState),
deltaState: snapshotPinned(c.deltaState),
// Size reports the snapshot's byte footprint (conv state + delta state).
func (s *recurrentSnapshot) Size() int { return s.convState.NumBytes() + s.deltaState.NumBytes() }
// Close unpins the snapshot's arrays so they can be freed by a sweep.
func (s *recurrentSnapshot) Close() { mlx.Unpin(s.convState, s.deltaState) }
func (c *RecurrentCache) Snapshot(fromOffset int) Snapshot {
// Recurrent state is not position-sliceable — always snapshot the full state.
if c.convState == nil && c.deltaState == nil {
return nil
}
return clone
snap := &recurrentSnapshot{offset: c.offset}
snap.convState = c.convState.Clone()
snap.deltaState = c.deltaState.Clone()
mlx.Pin(snap.convState, snap.deltaState)
return snap
}
// Restore replaces the recurrent state with a snapshot's state. Because
// recurrent state is cumulative, restores are forward-only: the target must
// be at or beyond the snapshot's offset, and a nil-snapshot "rewind" only
// succeeds as a no-op. On success the cache offset becomes the snapshot's
// offset (not target — the tokens in between must be re-fed).
func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool {
if snapshot == nil {
// Recurrent state is cumulative and can't rewind. Only succeed
// if we're already at the target (no-op).
return target == c.offset
}
snap := snapshot.(*recurrentSnapshot)
// Recurrent state encodes all tokens up to snap.offset. Restoring
// to a target before that would leave stale state from tokens
// [target, snap.offset) baked in. Only allow restoring forward.
if target < snap.offset {
return false
}
c.convState = c.setStateRaw(c.convState, snap.convState)
c.deltaState = c.setStateRaw(c.deltaState, snap.deltaState)
c.offset = snap.offset
return true
}
// Merge discards the parent and returns the child: recurrent snapshots are
// self-contained, so the later snapshot alone represents the merged range.
func (c *RecurrentCache) Merge(parent, child Snapshot) Snapshot {
// Recurrent snapshots are self-contained — child supersedes parent.
if parent != nil {
parent.Close()
}
return child
}
// Split declines to split and returns (nil, snapshot) unchanged, per the
// Cache interface contract for non-sliceable cache types.
func (c *RecurrentCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
// Recurrent state is cumulative and not position-sliceable.
// Cannot recover intermediate state at the split point.
return nil, snapshot
}
func (c *RecurrentCache) Free() {
@@ -156,4 +185,3 @@ func (c *RecurrentCache) Free() {
}
func (c *RecurrentCache) Offset() int { return c.offset }
func (c *RecurrentCache) Len() int { return c.offset }

44
x/mlxrunner/cache/recurrent_test.go vendored Normal file
View File

@@ -0,0 +1,44 @@
package cache
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// TestRecurrentCacheRestoreDirectionality verifies that RecurrentCache only
// allows restoring forward (target >= snapshot offset), never backward.
func TestRecurrentCacheRestoreDirectionality(t *testing.T) {
skipIfNoMLX(t)
c := NewRecurrentCache(3, 12, 4, 8, 8)
_ = c.ConvState(1, mlx.DTypeFloat16)
_ = c.DeltaState(1, mlx.DTypeFloat16)
c.Advance(10)
snap := c.Snapshot(0)
c.Advance(5) // now at 15
// Restore backward should fail.
if c.Restore(snap, 5) {
t.Fatal("Restore(snap, 5) should fail — target < snap.offset")
}
// Restore to exact snap offset should succeed.
if !c.Restore(snap, 10) {
t.Fatal("Restore(snap, 10) should succeed")
}
if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c.Offset())
}
// Restore forward (target > snap offset) should succeed, offset = snap.offset.
snap2 := c.Snapshot(0)
if !c.Restore(snap2, 15) {
t.Fatal("Restore(snap, 15) should succeed")
}
// Recurrent state is at snap.offset (10), not target (15).
if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10 (snap offset)", c.Offset())
}
}

859
x/mlxrunner/cache_test.go Normal file
View File

@@ -0,0 +1,859 @@
package mlxrunner
import (
"slices"
"testing"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// snapshotTracker records every fakeSnapshot created and every Close() call
// so tests can detect leaked (created but never closed) or double-closed snapshots.
type snapshotTracker struct {
// all holds every snapshot ever registered via track, in creation order.
all []*fakeSnapshot
}
// track registers a snapshot with the tracker and points the snapshot back
// at the tracker. nil snapshots are ignored.
func (tr *snapshotTracker) track(s *fakeSnapshot) {
if s == nil {
return
}
s.tracker = tr
tr.all = append(tr.all, s)
}
// Fake caches that store actual token sequences so tests can verify the right
// data was restored, not just the right offset.
// fakeSnapshot stores a copy of the token sub-sequence it covers.
type fakeSnapshot struct {
// tokens is the copied sub-sequence for the half-open range [from, to).
tokens []int32
from, to int
byteSize int // configurable for eviction tests
// tracker is set by snapshotTracker.track for leak accounting.
tracker *snapshotTracker
// closeCount counts Close calls; leak checks expect exactly 0 (still live)
// or 1 (properly released).
closeCount int
}
// Size reports the configured byte size (used by eviction logic under test).
func (s *fakeSnapshot) Size() int { return s.byteSize }
// Close records the release; double-closes are surfaced by checkSnapshotLeaks.
func (s *fakeSnapshot) Close() {
s.closeCount++
}
// fakeRewindableCache tracks the full token sequence and supports
// arbitrary rewind via Restore(nil, target).
type fakeRewindableCache struct {
tokens []int32
tracker *snapshotTracker
}
func (c *fakeRewindableCache) feed(tokens []int32) {
c.tokens = append(c.tokens, tokens...)
}
func (c *fakeRewindableCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return nil, nil
}
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }
func (c *fakeRewindableCache) Offset() int { return len(c.tokens) }
func (c *fakeRewindableCache) Free() {
c.tokens = nil
}
// Snapshot captures the token range [max(fromOffset,0), len(c.tokens)) as a
// fakeSnapshot and registers it with the tracker for leak detection.
// Returns nil when there is nothing past fromOffset, mirroring the
// production KVCache.Snapshot behavior.
func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
	if fromOffset >= len(c.tokens) {
		return nil
	}
	// Clamp negative offsets to the start of the sequence (uses the Go 1.21
	// max builtin, matching the production cache code, instead of a manual if).
	from := max(fromOffset, 0)
	s := &fakeSnapshot{
		tokens: slices.Clone(c.tokens[from:]),
		from:   from,
		to:     len(c.tokens),
	}
	c.tracker.track(s)
	return s
}
// Restore applies a snapshot (or a live rewind when snapshot is nil) and
// truncates the result to target tokens, mirroring KVCache.Restore.
// Returns false when the cache lacks base data up to the snapshot's start.
//
// Fix: targets are clamped into [0, len] in both paths; the original sliced
// c.tokens[:target] unguarded in the snapshot path, which panics for a
// negative target.
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
	if snapshot == nil {
		// Live rewind: clamp and truncate.
		c.tokens = c.tokens[:min(max(target, 0), len(c.tokens))]
		return true
	}
	s := snapshot.(*fakeSnapshot)
	if len(c.tokens) < s.from {
		return false // don't have base data up to snapshot start
	}
	// Splice the snapshot's tokens over the suffix starting at s.from.
	c.tokens = append(c.tokens[:s.from], s.tokens...)
	// Clamp to target if it falls inside the restored sequence.
	if t := max(target, 0); t < len(c.tokens) {
		c.tokens = c.tokens[:t]
	}
	return true
}
func (c *fakeRewindableCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
if parent == nil || child == nil {
if parent != nil {
parent.Close()
}
if child != nil {
child.Close()
}
return nil
}
p := parent.(*fakeSnapshot)
ch := child.(*fakeSnapshot)
merged := make([]int32, len(p.tokens)+len(ch.tokens))
copy(merged, p.tokens)
copy(merged[len(p.tokens):], ch.tokens)
s := &fakeSnapshot{
tokens: merged,
from: p.from,
to: ch.to,
byteSize: p.byteSize + ch.byteSize,
}
c.tracker.track(s)
p.Close()
ch.Close()
return s
}
func (c *fakeRewindableCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
if snapshot == nil {
return nil, nil
}
s := snapshot.(*fakeSnapshot)
relAt := at - s.from
if relAt <= 0 {
return nil, snapshot
}
if relAt >= len(s.tokens) {
return snapshot, nil
}
p := &fakeSnapshot{
tokens: slices.Clone(s.tokens[:relAt]),
from: s.from,
to: at,
byteSize: s.byteSize,
}
ch := &fakeSnapshot{
tokens: slices.Clone(s.tokens[relAt:]),
from: at,
to: s.to,
byteSize: s.byteSize,
}
c.tracker.track(p)
c.tracker.track(ch)
s.Close()
return p, ch
}
// fakeSlidingWindowCache models RotatingKVCache semantics: stores the full
// token sequence but only the trailing maxSize tokens are "live" in the window.
// Once the window fills, live rewind is impossible without a snapshot.
type fakeSlidingWindowCache struct {
tokens []int32
maxSize int
tracker *snapshotTracker
}
func (c *fakeSlidingWindowCache) feed(tokens []int32) {
c.tokens = append(c.tokens, tokens...)
}
func (c *fakeSlidingWindowCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return nil, nil
}
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }
func (c *fakeSlidingWindowCache) Offset() int { return len(c.tokens) }
func (c *fakeSlidingWindowCache) Free() {
c.tokens = nil
}
func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
if len(c.tokens) == 0 || len(c.tokens) <= fromOffset {
return nil
}
// Snapshot captures the full window state (like RotatingKVCache.Snapshot).
s := &fakeSnapshot{
tokens: slices.Clone(c.tokens),
from: 0,
to: len(c.tokens),
}
c.tracker.track(s)
return s
}
// Restore mirrors RotatingKVCache.Restore semantics: a live rewind (nil
// snapshot) only succeeds while the window hasn't filled; a snapshot restore
// replaces the whole window and truncates to target.
//
// Fix: the original sliced c.tokens[:target] unguarded in both paths, which
// panics for target < 0, or for target > len(tokens) in the live-rewind path.
// Targets are now clamped to valid bounds so a bad target cannot crash the fake.
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
	if snapshot == nil {
		if target == len(c.tokens) {
			return true // already there — no-op
		}
		// Live rewind only works when buffer hasn't filled (offset <= maxSize).
		if len(c.tokens) > c.maxSize {
			return false
		}
		c.tokens = c.tokens[:min(max(target, 0), len(c.tokens))]
		return true
	}
	s := snapshot.(*fakeSnapshot)
	c.tokens = slices.Clone(s.tokens)
	if target >= 0 && target < len(c.tokens) {
		c.tokens = c.tokens[:target]
	}
	return true
}
// Merge mirrors RotatingKVCache.Merge: the child snapshot already carries
// the full window, so the parent is closed and the child returned as-is.
func (c *fakeSlidingWindowCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
// Child supersedes parent for sliding window (full window state).
if parent != nil {
parent.Close()
}
return child
}
// Split mirrors RotatingKVCache.Split: declines and returns the snapshot
// unchanged as the child.
func (c *fakeSlidingWindowCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
// Can't split a ring buffer at an arbitrary point.
return nil, snapshot
}
// fakeRecurrentCache models RecurrentCache semantics: stores tokens
// but cannot rewind without a snapshot.
type fakeRecurrentCache struct {
tokens []int32
tracker *snapshotTracker
}
func (c *fakeRecurrentCache) feed(tokens []int32) {
c.tokens = append(c.tokens, tokens...)
}
func (c *fakeRecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return nil, nil
}
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }
func (c *fakeRecurrentCache) Offset() int { return len(c.tokens) }
func (c *fakeRecurrentCache) Free() {
c.tokens = nil
}
func (c *fakeRecurrentCache) Snapshot(fromOffset int) cache.Snapshot {
// Recurrent state is cumulative; snapshot captures the full state.
if len(c.tokens) == 0 {
return nil
}
s := &fakeSnapshot{
tokens: slices.Clone(c.tokens),
from: 0,
to: len(c.tokens),
}
c.tracker.track(s)
return s
}
func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
if snapshot == nil {
return target == len(c.tokens) // can only no-op
}
s := snapshot.(*fakeSnapshot)
if target < s.to {
return false // can't go backward
}
c.tokens = slices.Clone(s.tokens)
return true
}
func (c *fakeRecurrentCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
// Child supersedes parent for cumulative state.
if parent != nil {
parent.Close()
}
return child
}
func (c *fakeRecurrentCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
return nil, snapshot // can't split cumulative state
}
type feedableCache interface {
cache.Cache
feed(tokens []int32)
}
// testEnv encapsulates a kvCache and its fake caches for a test scenario.
type testEnv struct {
kvc *kvCache
caches []cache.Cache // typed references for assertions
tracker *snapshotTracker
}
// newTransformerEnv creates a test environment with a single rewindable cache
// (pure transformer model).
func newTransformerEnv() *testEnv {
tracker := &snapshotTracker{}
caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}}
return &testEnv{
kvc: &kvCache{caches: caches},
caches: caches,
tracker: tracker,
}
}
// newSlidingWindowEnv creates a test environment with one rewindable cache and
// one sliding window cache (Mistral-style architecture).
func newSlidingWindowEnv() *testEnv {
tr := &snapshotTracker{}
rc := &fakeRewindableCache{tracker: tr}
sw := &fakeSlidingWindowCache{maxSize: 32, tracker: tr}
caches := []cache.Cache{rc, sw}
return &testEnv{
kvc: &kvCache{caches: caches},
caches: caches,
tracker: tr,
}
}
// newRecurrentEnv creates a test environment with one rewindable cache and one
// non-rewindable cache (Jamba-style architecture).
func newRecurrentEnv() *testEnv {
tr := &snapshotTracker{}
rc := &fakeRewindableCache{tracker: tr}
nrc := &fakeRecurrentCache{tracker: tr}
caches := []cache.Cache{rc, nrc}
return &testEnv{
kvc: &kvCache{caches: caches},
caches: caches,
tracker: tr,
}
}
// assertAllTokens checks that every cache in the environment contains exactly
// the expected token sequence.
func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32) {
t.Helper()
for i, c := range e.caches {
assertTokens(t, label, c, expected)
// Verify all caches report the same offset.
if i > 0 && c.Offset() != e.caches[0].Offset() {
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d",
label, i, c.Offset(), e.caches[0].Offset())
}
}
}
// simulateRequest mirrors the production pipeline lifecycle:
// begin -> prefill with snapshot(false) at branch points -> generate -> close
type requestResult struct {
remaining []int32
snapshotOffset int
}
// simulateRequest runs a request through the harness. If userSnapshotAt > 0,
// a user snapshot (snapshot(true)) is created at that offset during prefill.
func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult {
t.Helper()
userSnapAt := 0
if len(userSnapshotAt) > 0 {
userSnapAt = userSnapshotAt[0]
}
session := kvc.begin(nil, inputs)
result := requestResult{
remaining: slices.Clone(session.remaining),
snapshotOffset: session.snapshotOffset,
}
assertCacheOffsetAlignment(t, kvc, "after begin")
baseOffset := kvc.minCacheOffset()
remaining := inputs[baseOffset:]
// Collect snapshot points (offset -> user flag) in ascending order.
type snapPoint struct {
offset int
user bool
}
var points []snapPoint
if session.snapshotOffset > 0 && session.snapshotOffset > baseOffset {
points = append(points, snapPoint{session.snapshotOffset, false})
}
if userSnapAt > 0 && userSnapAt > baseOffset {
points = append(points, snapPoint{userSnapAt, true})
}
slices.SortFunc(points, func(a, b snapPoint) int { return a.offset - b.offset })
// Prefill: feed tokens, pausing at each snapshot point.
for _, sp := range points {
count := sp.offset - baseOffset
if count > len(remaining) {
break
}
if count > 0 {
feedAll(kvc.caches, remaining[:count])
remaining = remaining[count:]
baseOffset = sp.offset
}
assertCacheOffsetAlignment(t, kvc, "at snapshot point")
session.snapshot(sp.user)
}
// Feed rest of input tokens.
if len(remaining) > 0 {
feedAll(kvc.caches, remaining)
}
assertCacheOffsetAlignment(t, kvc, "after prefill")
// Generate tokens.
if len(generated) > 0 {
session.outputs = generated
feedAll(kvc.caches, generated)
}
assertCacheOffsetAlignment(t, kvc, "before close")
session.close()
return result
}
// feedAll appends tokens to every cache that implements the feedable test
// interface; caches that don't are left untouched.
func feedAll(caches []cache.Cache, tokens []int32) {
	for _, entry := range caches {
		fc, ok := entry.(feedableCache)
		if !ok {
			continue
		}
		fc.feed(tokens)
	}
}
// assertCacheOffsetAlignment verifies all caches report the same offset.
// With fewer than two caches there is nothing to compare.
func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) {
	t.Helper()
	if len(kvc.caches) < 2 {
		return
	}
	want := kvc.caches[0].Offset()
	for j, c := range kvc.caches[1:] {
		got := c.Offset()
		if got == want {
			continue
		}
		t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, j+1, got, want)
	}
}
// assertTokens checks that a feedable cache contains the expected token sequence.
// For sliding window caches, only the trailing maxSize tokens are checked.
func assertTokens(t *testing.T, label string, c cache.Cache, expected []int32) {
t.Helper()
switch fc := c.(type) {
case *fakeRewindableCache:
if !slices.Equal(fc.tokens, expected) {
t.Errorf("%s: rewindable tokens = %v, want %v", label, fc.tokens, expected)
}
case *fakeSlidingWindowCache:
// Sliding window stores full history but only trailing maxSize are live.
// Verify the full token sequence matches (the window semantics are
// enforced by Snapshot/Restore, not by the token log).
if !slices.Equal(fc.tokens, expected) {
t.Errorf("%s: sliding window tokens = %v, want %v", label, fc.tokens, expected)
}
case *fakeRecurrentCache:
if !slices.Equal(fc.tokens, expected) {
t.Errorf("%s: non-rewindable tokens = %v, want %v", label, fc.tokens, expected)
}
default:
t.Fatalf("%s: unknown cache type %T", label, c)
}
}
// checkTrieInvariants walks the trie and checks structural invariants.
func checkTrieInvariants(t *testing.T, root *trieNode) {
t.Helper()
walkNodes(root, func(n *trieNode) bool {
if n.parent != nil {
if n.startOffset() != n.parent.endOffset {
t.Errorf("node [%d,%d): startOffset %d != parent endOffset %d",
n.startOffset(), n.endOffset, n.startOffset(), n.parent.endOffset)
}
}
if len(n.tokens) != n.endOffset-n.startOffset() {
t.Errorf("node [%d,%d): token count %d != offset span %d",
n.startOffset(), n.endOffset, len(n.tokens), n.endOffset-n.startOffset())
}
for _, c := range n.children {
if c.parent != n {
t.Errorf("child [%d,%d) parent mismatch", c.startOffset(), c.endOffset)
}
}
// No two siblings should start with the same token.
seen := make(map[int32]bool)
for _, c := range n.children {
if len(c.tokens) > 0 {
first := c.tokens[0]
if seen[first] {
t.Errorf("node [%d,%d): duplicate sibling first token %d",
n.startOffset(), n.endOffset, first)
}
seen[first] = true
}
}
return true
})
}
// checkSnapshotLeaks verifies that every tracked snapshot is either still live
// in the trie (closeCount == 0) or has been closed exactly once. It reports
// leaked snapshots (not in trie, never closed) and double-closes.
func checkSnapshotLeaks(t *testing.T, tracker *snapshotTracker, root *trieNode) {
t.Helper()
if tracker == nil {
return
}
// Collect all live snapshots still referenced by trie nodes.
live := make(map[*fakeSnapshot]bool)
walkNodes(root, func(n *trieNode) bool {
for _, s := range n.snapshots {
if s != nil {
if fs, ok := s.(*fakeSnapshot); ok {
live[fs] = true
}
}
}
return true
})
for i, s := range tracker.all {
if live[s] {
if s.closeCount != 0 {
t.Errorf("snapshot #%d [%d,%d) is still in trie but was closed %d time(s)",
i, s.from, s.to, s.closeCount)
}
} else {
if s.closeCount == 0 {
t.Errorf("snapshot #%d [%d,%d) leaked: created but never closed and not in trie",
i, s.from, s.to)
} else if s.closeCount > 1 {
t.Errorf("snapshot #%d [%d,%d) double-closed: closed %d times",
i, s.from, s.to, s.closeCount)
}
}
}
}
// forEachEnv runs fn as subtests for three realistic model configurations:
// pure transformer, transformer + sliding window (Mistral-style), and
// transformer + recurrent (Jamba-style). Leak checking runs automatically
// at the end of each subtest.
func forEachEnv(t *testing.T, fn func(t *testing.T, env *testEnv)) {
	t.Helper()
	configs := []struct {
		name   string
		newEnv func() *testEnv
	}{
		{"Transformer", newTransformerEnv},
		{"SlidingWindow", newSlidingWindowEnv},
		{"Recurrent", newRecurrentEnv},
	}
	for _, cfg := range configs {
		t.Run(cfg.name, func(t *testing.T) {
			env := cfg.newEnv()
			// Verify snapshot accounting even when fn fails partway through.
			t.Cleanup(func() {
				checkSnapshotLeaks(t, env.tracker, env.kvc.root)
			})
			fn(t, env)
		})
	}
}
// TestBranchCreationAndReuse exercises the core multi-conversation lifecycle:
// two conversations share a prefix and diverge, creating a branch point.
// A third conversation extends the first. Verifies trie structure, cache
// hit lengths, and that semantic caches contain the correct token sequences.
func TestBranchCreationAndReuse(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc
		// Request A: [1,2,3,4,5,6,7,8] + generate [20,21] — full miss.
		resA := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8}, []int32{20, 21})
		if len(resA.remaining) != 8 {
			t.Fatalf("A: remaining = %d, want 8 (full miss)", len(resA.remaining))
		}
		env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
		// Verify trie was populated by close(): prompt plus generated tokens
		// should be findable as one 10-token path.
		_, mA := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
		if mA != 10 {
			t.Fatalf("A findable: expected 10 matched, got %d", mA)
		}
		// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
		// Partial match in A's edge triggers snapshotOffset.
		resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
		if resB.snapshotOffset != 5 {
			t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
		}
		// Cache was rewound to 0 (partial match truncates path to root),
		// so all tokens were re-evaluated.
		if len(resB.remaining) != 8 {
			t.Fatalf("B: remaining = %d, want 8", len(resB.remaining))
		}
		env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
		// Both A and B should be findable in the trie.
		_, mA2 := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
		if mA2 < 5 {
			t.Fatalf("A still findable: expected >= 5 matched, got %d", mA2)
		}
		_, mB := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
		if mB < 5 {
			t.Fatalf("B findable: expected >= 5 matched, got %d", mB)
		}
		// Request C: [1,2,3,4,5,6,7,8,40,41] — extends A's prefix.
		// Should get a cache hit for the shared prefix.
		resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41}, nil)
		if len(resC.remaining) >= 10 {
			t.Fatalf("C: remaining = %d, want < 10 (should get cache hit)", len(resC.remaining))
		}
		env.assertAllTokens(t, "after C", []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41})
		checkTrieInvariants(t, kvc.root)
	})
}
// TestExactMatchSeedBehavior verifies the holdback mechanism: when the exact
// same prompt is requested twice, the cache does not overclaim cached work.
// The last token must be re-evaluated to seed generation.
func TestExactMatchSeedBehavior(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc
		// Request A: first time. Populates the trie with the prompt and
		// generated tokens.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
		// Request B: identical prompt. Holdback means matched=4, partial in
		// the 5-token edge, so path truncates to root and all tokens are
		// re-evaluated. snapshotOffset should be set at the holdback point.
		resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
		if len(resB.remaining) != 5 {
			t.Fatalf("B: remaining = %d, want 5 (full re-eval due to holdback)", len(resB.remaining))
		}
		if resB.snapshotOffset != 4 {
			t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
		}
		env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})
		checkTrieInvariants(t, kvc.root)
	})
}
// TestConversationResumption tests the most common pattern: user sends a message,
// gets a response, then sends a follow-up. The follow-up should reuse the cached
// prefix (system prompt + first turn + assistant response).
func TestConversationResumption(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc
		// Turn 1: system prompt + user message, assistant generates response.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11, 12})
		env.assertAllTokens(t, "turn 1", []int32{1, 2, 3, 4, 5, 10, 11, 12})
		// Turn 2: full history + new user message. Should get a cache hit on
		// the prefix [1,2,3,4,5,10,11,12] and only need to evaluate [20,21].
		resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21}, []int32{30})
		if len(resB.remaining) > 5 {
			t.Fatalf("turn 2: remaining = %d, want <= 5 (should reuse most of history)", len(resB.remaining))
		}
		env.assertAllTokens(t, "turn 2", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30})
		// Turn 3: even longer history. The hit should again cover all but the
		// trailing new tokens (plus any holdback).
		resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41}, nil)
		if len(resC.remaining) > 5 {
			t.Fatalf("turn 3: remaining = %d, want <= 5", len(resC.remaining))
		}
		env.assertAllTokens(t, "turn 3", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41})
		checkTrieInvariants(t, kvc.root)
	})
}
// TestEvictionPreservesActiveConversations creates multiple conversations sharing
// a system prompt, triggers eviction via large snapshot sizes, and verifies the
// active path and shared prefix survive while memory stays bounded.
func TestEvictionPreservesActiveConversations(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc
		systemPrompt := []int32{1, 2, 3, 4, 5}
		// Create 5 conversations with unique suffixes.
		for i := range 5 {
			suffix := []int32{int32(100 + i*10), int32(101 + i*10), int32(102 + i*10)}
			inputs := append(slices.Clone(systemPrompt), suffix...)
			simulateRequest(t, kvc, inputs, []int32{int32(200 + i)})
		}
		// Inflate snapshot sizes to trigger eviction: replace each real
		// snapshot with an oversized fake so the total exceeds the budget.
		walkNodes(kvc.root, func(n *trieNode) bool {
			if !n.hasSnapshots() {
				return true
			}
			snaps := make([]cache.Snapshot, len(n.snapshots))
			for i, s := range n.snapshots {
				if s != nil {
					snaps[i] = &fakeSnapshot{byteSize: 2 * 1024 * 1024 * 1024} // 2 GiB per snapshot
				}
			}
			n.setSnapshots(snaps, &kvc.pagedOutBytes)
			return true
		})
		// Run eviction.
		kvc.enforceEvictionPolicy()
		// Memory should be within limits.
		if kvc.pagedOutBytes > maxPagedOutBytes {
			t.Fatalf("pagedOutBytes = %d, want <= %d", kvc.pagedOutBytes, maxPagedOutBytes)
		}
		// Active path should be untouched.
		if len(kvc.activePath) < 2 {
			t.Fatalf("activePath should have >= 2 nodes, got %d", len(kvc.activePath))
		}
		// System prompt prefix should still be findable (evicting a
		// multi-child branch point only drops snapshots, not the node).
		_, matched := findBestMatch(kvc.root, systemPrompt)
		if matched < len(systemPrompt) {
			t.Fatalf("system prompt match = %d, want %d", matched, len(systemPrompt))
		}
		checkTrieInvariants(t, kvc.root)
	})
}
// TestUserSnapshotPreservesRestorePoint verifies that user-created snapshots
// (snapshot(true)) resist structural changes that would destroy them:
//   - A user node forces new tokens into a child instead of extending in-place
//   - The snapshot remains restorable after other branches are added
func TestUserSnapshotPreservesRestorePoint(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc
		// Request A: user snapshot at offset 5, then generate.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}, 5)
		assertUserNodeExists(t, kvc, "after A")
		// Request B: extends A's prefix. The user node at offset 5 should
		// force tokens into a child rather than extending in-place.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21}, nil)
		env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21})
		assertUserNodeExists(t, kvc, "after B")
		// Request C: diverge from the user node, creating a sibling branch.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 30, 31}, []int32{40})
		// Request D: switch back to A's branch — user snapshot still restorable.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50}, nil)
		env.assertAllTokens(t, "back to A", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50})
		checkTrieInvariants(t, kvc.root)
	})
}
// TestUserSnapshotResistsAutoMerge verifies that when a sibling leaf is evicted,
// a user-marked parent node is not auto-merged with its remaining single child.
func TestUserSnapshotResistsAutoMerge(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc
		// Request A: user snapshot at offset 3, then continue to offset 5.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10}, 3)
		// Request B: diverges at the user node, creating a second child.
		simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{20})
		userNode := findUserNode(t, kvc)
		if len(userNode.children) != 2 {
			t.Fatalf("user node children = %d, want 2", len(userNode.children))
		}
		// Inflate snapshot sizes and evict. The non-active branch should be
		// evicted, leaving the user node with one child.
		walkNodes(kvc.root, func(n *trieNode) bool {
			if !n.hasSnapshots() {
				return true
			}
			snaps := make([]cache.Snapshot, len(n.snapshots))
			for i, s := range n.snapshots {
				if s != nil {
					snaps[i] = &fakeSnapshot{byteSize: 5 * 1024 * 1024 * 1024}
				}
			}
			n.setSnapshots(snaps, &kvc.pagedOutBytes)
			return true
		})
		kvc.enforceEvictionPolicy()
		// The user node should still exist (not auto-merged) even with one child.
		assertUserNodeExists(t, kvc, "after eviction")
		checkTrieInvariants(t, kvc.root)
	})
}
// findUserNode returns a user-marked node from the trie (the first found in
// a depth-first walk), failing the test if none exists.
func findUserNode(t *testing.T, kvc *kvCache) *trieNode {
	t.Helper()
	var found *trieNode
	walkNodes(kvc.root, func(n *trieNode) bool {
		if n.user {
			found = n
			// Stop the walk as soon as a user node is found; there is no
			// need to traverse the rest of the trie.
			return false
		}
		return true
	})
	if found == nil {
		t.Fatal("no user-marked node found")
	}
	return found
}
// assertUserNodeExists fails the test (with label for context) unless at
// least one user-marked node is present in the trie.
func assertUserNodeExists(t *testing.T, kvc *kvCache, label string) {
	t.Helper()
	exists := false
	walkNodes(kvc.root, func(n *trieNode) bool {
		if n.user {
			exists = true
			// One match is enough; stop the walk early.
			return false
		}
		return true
	})
	if !exists {
		t.Fatalf("%s: no user-marked node found", label)
	}
}
// TestBranchSwitchRestoresCorrectState exercises switching back to an older
// branch after working on a different one, verifying that the restored cache
// state contains the correct token sequence for both rewindable and
// non-rewindable caches.
func TestBranchSwitchRestoresCorrectState(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc
		// Request A: [1,2,3,4,5] + generate [10,11]
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
		env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 10, 11})
		// Request B: [1,2,3,6,7] — diverges at token 4
		simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{12, 13})
		env.assertAllTokens(t, "after B", []int32{1, 2, 3, 6, 7, 12, 13})
		// Request C: switch back to A's branch [1,2,3,4,5,10,11,20] — the
		// restored state must reflect A's tokens, not B's.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20}, nil)
		env.assertAllTokens(t, "after C (back to A)", []int32{1, 2, 3, 4, 5, 10, 11, 20})
		checkTrieInvariants(t, kvc.root)
	})
}

296
x/mlxrunner/cache_trie.go Normal file
View File

@@ -0,0 +1,296 @@
package mlxrunner
import (
"fmt"
"slices"
"time"
"github.com/ollama/ollama/x/mlxrunner/cache"
)
// trieNode represents a node in the compressed prefix trie for KV cache branching.
// Each node stores a compressed edge (multiple tokens) and optional paged-out
// snapshot data per cache layer.
type trieNode struct {
	tokens    []int32          // compressed edge — multiple tokens per node
	endOffset int              // cumulative tokens from root to end of this node
	parent    *trieNode        // nil only for the root node
	children  []*trieNode      // child edges; siblings start with distinct first tokens
	lastUsed  time.Time        // for LRU eviction
	snapshots []cache.Snapshot // per-layer paged-out snapshot data (nil if not paged out)
	user      bool             // true = explicit restore point (resist auto-merge)
}
// startOffset returns the cumulative token offset at the start of this node's
// edge, i.e. the number of tokens on the path from the root up to (but not
// including) this node's own tokens.
func (n *trieNode) startOffset() int {
	return n.endOffset - len(n.tokens)
}
// snapshotBytes returns the total bytes of paged-out snapshots on this node.
// Layers without snapshot data (nil entries) contribute nothing.
func (n *trieNode) snapshotBytes() int64 {
	var sum int64
	for _, snap := range n.snapshots {
		if snap == nil {
			continue
		}
		sum += int64(snap.Size())
	}
	return sum
}
// setSnapshots replaces this node's snapshots with snaps and closes the old ones.
// If counter is non-nil, the net byte delta is applied to it.
func (n *trieNode) setSnapshots(snaps []cache.Snapshot, counter *int64) {
	// swapSnapshots does the replacement and accounting; we just close
	// whatever it hands back.
	for _, prev := range n.swapSnapshots(snaps, counter) {
		if prev != nil {
			prev.Close()
		}
	}
}
// swapSnapshots is like setSnapshots but returns the previous snapshots
// without closing them. Use this when the old snapshots will be consumed
// (e.g. by Split/Merge).
func (n *trieNode) swapSnapshots(snaps []cache.Snapshot, counter *int64) []cache.Snapshot {
	prev := n.snapshots
	if counter == nil {
		n.snapshots = snaps
		return prev
	}
	// Apply the net byte delta: subtract the outgoing set, add the incoming.
	*counter -= n.snapshotBytes()
	n.snapshots = snaps
	*counter += n.snapshotBytes()
	return prev
}
// hasSnapshots returns true if any layer has snapshot data.
func (n *trieNode) hasSnapshots() bool {
	for _, snap := range n.snapshots {
		if snap != nil {
			return true
		}
	}
	return false
}
// hasAllSnapshots returns true if every layer has snapshot data.
// An empty snapshot slice counts as false.
func (n *trieNode) hasAllSnapshots() bool {
	if len(n.snapshots) == 0 {
		return false
	}
	for _, snap := range n.snapshots {
		if snap == nil {
			return false
		}
	}
	return true
}
// findBestMatch walks the trie matching input tokens, returning the path of
// nodes traversed and the total number of tokens matched. The returned path
// always starts at root; matched counts every token consumed along the path,
// including a trailing partial match inside the final node's edge.
func findBestMatch(root *trieNode, tokens []int32) (path []*trieNode, matched int) {
	if root == nil {
		return nil, 0
	}
	path = []*trieNode{root}
	pos := 0
	node := root
	for pos < len(tokens) {
		// When multiple children share the same first token (e.g. after
		// a split), prefer the child whose full edge matches over one
		// that only partially matches. This is just being defensive - it
		// shouldn't actually happen.
		var best *trieNode
		bestMatched := 0
		bestFull := false
		for _, child := range node.children {
			edge := child.tokens
			if len(edge) == 0 {
				continue
			}
			if edge[0] != tokens[pos] {
				continue
			}
			// Count matching tokens in this child's edge.
			j := 0
			for j < len(edge) && pos+j < len(tokens) && edge[j] == tokens[pos+j] {
				j++
			}
			full := j == len(edge)
			// Prefer full edge matches; among same type, prefer longer.
			if best == nil || (full && !bestFull) || (full == bestFull && j > bestMatched) {
				best = child
				bestMatched = j
				bestFull = full
			}
		}
		if best == nil {
			// No child continues the input from this position.
			break
		}
		pos += bestMatched
		path = append(path, best)
		if !bestFull {
			// Partial match within this edge; cannot descend further.
			break
		}
		node = best
	}
	return path, pos
}
// appendTokens either creates a new child node or extends the leaf in place,
// returning the node that now holds the tokens.
//
// A new child is created when extending in place would be unsafe: the root
// must stay empty, a node with children cannot grow its edge without breaking
// the sibling invariant, and a node with snapshots keeps its edge unchanged
// so the snapshot data stays aligned with it.
func (n *trieNode) appendTokens(root *trieNode, tokens []int32, endOffset int) *trieNode {
	if n == root || len(n.children) > 0 || n.hasSnapshots() {
		child := &trieNode{
			// Clone so later mutation of the caller's slice cannot
			// corrupt the trie.
			tokens:    slices.Clone(tokens),
			endOffset: endOffset,
			parent:    n,
			lastUsed:  n.lastUsed,
		}
		n.children = append(n.children, child)
		return child
	}
	// Safe to extend the leaf's edge in place.
	n.tokens = append(n.tokens, tokens...)
	n.endOffset = endOffset
	return n
}
// removeNode removes a leaf node from the trie, detaching it from its parent
// and closing any snapshots it holds (adjusting counter by the freed bytes).
// Panics if called on the root or on a node that still has children.
func removeNode(node *trieNode, counter *int64) {
	if node.parent == nil {
		panic("removeNode called on root")
	}
	if len(node.children) != 0 {
		panic("removeNode called on non-leaf node")
	}
	p := node.parent
	// slices.Delete also zeroes the vacated tail slot, so the parent's
	// backing array does not keep the removed node reachable (the previous
	// append-based splice left a stale duplicate pointer past len).
	if i := slices.Index(p.children, node); i >= 0 {
		p.children = slices.Delete(p.children, i, i+1)
	}
	node.parent = nil
	node.setSnapshots(nil, counter)
}
// splitNode splits a node at the given token offset within its edge,
// creating a new parent node. Returns the new parent.
// `at` is relative to the node's edge (0-based index into node.tokens).
// If caches are provided, snapshots are split between parent and child
// using Cache.Split; otherwise snapshots are invalidated (closed).
// The user flag intentionally stays on the original node only: the split
// point is not a user restore point.
func splitNode(node *trieNode, at int, caches []cache.Cache, counter *int64) *trieNode {
	if at <= 0 || at >= len(node.tokens) {
		panic(fmt.Sprintf("splitNode: invalid split offset %d for node with %d tokens", at, len(node.tokens)))
	}
	// Create new parent with the prefix of the edge.
	newParent := &trieNode{
		tokens:    make([]int32, at),
		endOffset: node.startOffset() + at,
		parent:    node.parent,
		children:  []*trieNode{node},
		lastUsed:  node.lastUsed,
	}
	copy(newParent.tokens, node.tokens[:at])
	// Update the original node to have only the suffix.
	node.tokens = node.tokens[at:]
	// endOffset stays the same for the original node.
	if node.hasSnapshots() {
		if len(caches) == 0 {
			// No caches to split with: the snapshots no longer line up
			// with the shortened edge, so invalidate (close) them.
			// (Previously this path dereferenced caches[i] and panicked.)
			node.setSnapshots(nil, counter)
		} else {
			// Split snapshots between parent and child using Cache.Split.
			// Split consumes the old snapshots, so we remove them first
			// (adjusting the counter), then assign the split halves
			// (adjusting it back).
			oldSnaps := node.swapSnapshots(nil, counter)
			parentSnaps := make([]cache.Snapshot, len(oldSnaps))
			childSnaps := make([]cache.Snapshot, len(oldSnaps))
			for i, snap := range oldSnaps {
				if snap != nil {
					parentSnaps[i], childSnaps[i] = caches[i].Split(snap, newParent.endOffset)
				}
			}
			newParent.setSnapshots(parentSnaps, counter)
			node.setSnapshots(childSnaps, counter)
		}
	}
	// Reparent: replace node with newParent in the old parent's children.
	if node.parent != nil {
		for i, child := range node.parent.children {
			if child == node {
				node.parent.children[i] = newParent
				break
			}
		}
	}
	node.parent = newParent
	return newParent
}
// mergeWithChild merges a node with its single child: concatenates tokens,
// merges snapshot data via Cache.Merge, and removes the child.
// Panics unless the node has exactly one child.
func mergeWithChild(node *trieNode, caches []cache.Cache, counter *int64) {
	if len(node.children) != 1 {
		panic(fmt.Sprintf("mergeWithChild called on node with %d children", len(node.children)))
	}
	child := node.children[0]
	// Concatenate tokens; the merged node now ends where the child ended.
	node.tokens = append(node.tokens, child.tokens...)
	node.endOffset = child.endOffset
	// Merge snapshots per layer. Merge consumes the old snapshots, so we
	// remove them first (adjusting the counter), then assign the merged
	// result (adjusting it back).
	if len(node.snapshots) > 0 || len(child.snapshots) > 0 {
		nodeSnaps := node.swapSnapshots(nil, counter)
		childSnaps := child.swapSnapshots(nil, counter)
		merged := make([]cache.Snapshot, len(caches))
		for i := range caches {
			var ps, cs cache.Snapshot
			if nodeSnaps != nil {
				ps = nodeSnaps[i]
			}
			if childSnaps != nil {
				cs = childSnaps[i]
			}
			merged[i] = caches[i].Merge(ps, cs)
		}
		node.setSnapshots(merged, counter)
	}
	// Adopt grandchildren.
	node.children = child.children
	for _, gc := range node.children {
		gc.parent = node
	}
	// Inherit the user flag from the child without clearing one already set
	// on this node: if either endpoint was a user restore point, the merged
	// node is too. (A plain assignment would have dropped node's own flag.)
	node.user = node.user || child.user
	// Update lastUsed to the more recent of the two.
	if child.lastUsed.After(node.lastUsed) {
		node.lastUsed = child.lastUsed
	}
	// Fully detach the consumed child.
	child.parent = nil
	child.children = nil
}
// walkNodes calls fn for every node in the trie (depth-first, preorder).
// If fn returns false, the walk stops immediately.
func walkNodes(root *trieNode, fn func(*trieNode) bool) {
	if root == nil {
		return
	}
	// Iterative preorder traversal. Children are pushed in reverse so they
	// are popped — and therefore visited — in their original order.
	stack := []*trieNode{root}
	for len(stack) > 0 {
		n := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		if !fn(n) {
			return
		}
		for i := len(n.children) - 1; i >= 0; i-- {
			stack = append(stack, n.children[i])
		}
	}
}

View File

@@ -0,0 +1,455 @@
package mlxrunner
import (
"slices"
"testing"
"time"
"github.com/ollama/ollama/x/mlxrunner/cache"
)
// newTestTrie builds a root node with a single child edge holding tokens,
// or just a bare root when tokens is empty.
func newTestTrie(tokens []int32) *trieNode {
	root := &trieNode{lastUsed: time.Now()}
	if len(tokens) == 0 {
		return root
	}
	root.children = []*trieNode{{
		tokens:    slices.Clone(tokens),
		endOffset: len(tokens),
		parent:    root,
		lastUsed:  time.Now(),
	}}
	return root
}
// TestFindBestMatchMultipleBranches verifies that findBestMatch selects the
// correct top-level branch by first token, and matches nothing when no
// branch starts with the input's first token.
func TestFindBestMatchMultipleBranches(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}
	branch1 := &trieNode{
		tokens:    []int32{1, 2, 3},
		endOffset: 3,
		parent:    root,
		lastUsed:  time.Now(),
	}
	branch2 := &trieNode{
		tokens:    []int32{4, 5, 6},
		endOffset: 3,
		parent:    root,
		lastUsed:  time.Now(),
	}
	root.children = []*trieNode{branch1, branch2}
	// Match branch 1.
	path, matched := findBestMatch(root, []int32{1, 2, 3, 7})
	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}
	if len(path) != 2 || path[1] != branch1 {
		t.Fatal("expected to match branch1")
	}
	// Match branch 2.
	path, matched = findBestMatch(root, []int32{4, 5, 6, 8})
	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}
	if len(path) != 2 || path[1] != branch2 {
		t.Fatal("expected to match branch2")
	}
	// Match neither: no branch starts with token 7.
	_, matched = findBestMatch(root, []int32{7, 8, 9})
	if matched != 0 {
		t.Fatalf("expected 0 matched, got %d", matched)
	}
}
// TestFindBestMatchPrefersFullEdge verifies the tie-break rule: when two
// siblings start with the same token, a child whose whole edge matches wins
// over one that matches more tokens but only partially.
func TestFindBestMatchPrefersFullEdge(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}
	shared := &trieNode{
		tokens:    []int32{1, 2, 3},
		endOffset: 3,
		parent:    root,
		lastUsed:  time.Now(),
	}
	root.children = []*trieNode{shared}
	longer := &trieNode{
		tokens:    []int32{10, 11, 12, 13, 14},
		endOffset: 8,
		parent:    shared,
		lastUsed:  time.Now(),
	}
	shorter := &trieNode{
		tokens:    []int32{10, 11, 12},
		endOffset: 6,
		parent:    shared,
		lastUsed:  time.Now(),
	}
	// Put longer first so naive first-match would pick it.
	shared.children = []*trieNode{longer, shorter}
	input := []int32{1, 2, 3, 10, 11, 12, 99, 100}
	path, matched := findBestMatch(root, input)
	if matched != 6 {
		t.Fatalf("expected 6 matched, got %d", matched)
	}
	if len(path) != 3 {
		t.Fatalf("expected 3 nodes in path, got %d", len(path))
	}
	if path[2] != shorter {
		t.Fatal("expected findBestMatch to pick shorter (full edge match), not longer (partial)")
	}
}
// TestFindBestMatchPrefersLongerPartial verifies that among two partial edge
// matches, findBestMatch picks the one matching more tokens, regardless of
// sibling order.
func TestFindBestMatchPrefersLongerPartial(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}
	child1 := &trieNode{
		tokens:    []int32{1, 2, 3, 4, 5},
		endOffset: 5,
		parent:    root,
		lastUsed:  time.Now(),
	}
	child2 := &trieNode{
		tokens:    []int32{1, 2, 9},
		endOffset: 3,
		parent:    root,
		lastUsed:  time.Now(),
	}
	// child2 first so naive first-match would pick the shorter partial.
	root.children = []*trieNode{child2, child1}
	input := []int32{1, 2, 3, 7, 8}
	path, matched := findBestMatch(root, input)
	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}
	if path[1] != child1 {
		t.Fatal("expected findBestMatch to pick child1 (longer partial match)")
	}
}
// TestSplitNodeWithSnapshots verifies that splitting a snapshotted node
// divides the snapshot between the new parent and the child, and that the
// user flag stays on the original (child) node only.
func TestSplitNodeWithSnapshots(t *testing.T) {
	root := newTestTrie([]int32{1, 2, 3, 4, 5})
	child := root.children[0]
	rc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
	child.snapshots = []cache.Snapshot{rc.Snapshot(0)}
	child.user = true
	caches := []cache.Cache{rc}
	newParent := splitNode(child, 3, caches, nil)
	if !newParent.hasSnapshots() {
		t.Fatal("newParent should have snapshots after split")
	}
	if newParent.user {
		t.Fatal("newParent should not be a user snapshot after splitNode")
	}
	if !child.hasSnapshots() {
		t.Fatal("child should have snapshots after split")
	}
	if !child.user {
		t.Fatal("child should remain a user snapshot")
	}
}
// TestFindSplitAppendSequence exercises the find → split → append sequence
// used when a request diverges mid-edge: the edge is split at the divergence
// point and the new suffix becomes a sibling branch.
func TestFindSplitAppendSequence(t *testing.T) {
	root := newTestTrie([]int32{1, 2, 3, 4, 5})
	path, matched := findBestMatch(root, []int32{1, 2, 3, 6, 7})
	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}
	lastNode := path[len(path)-1]
	// Convert the absolute match count into an offset within the edge.
	matchedInEdge := matched - lastNode.startOffset()
	split := splitNode(lastNode, matchedInEdge, nil, nil)
	split.appendTokens(root, []int32{6, 7}, 5)
	if len(root.children) != 1 {
		t.Fatalf("root should have 1 child, got %d", len(root.children))
	}
	shared := root.children[0]
	if !slices.Equal(shared.tokens, []int32{1, 2, 3}) {
		t.Fatalf("shared tokens = %v, want [1,2,3]", shared.tokens)
	}
	if len(shared.children) != 2 {
		t.Fatalf("shared should have 2 children, got %d", len(shared.children))
	}
	// Both branches and a diverging input should resolve correctly.
	_, m1 := findBestMatch(root, []int32{1, 2, 3, 4, 5})
	if m1 != 5 {
		t.Fatalf("original branch: expected 5 matched, got %d", m1)
	}
	_, m2 := findBestMatch(root, []int32{1, 2, 3, 6, 7})
	if m2 != 5 {
		t.Fatalf("new branch: expected 5 matched, got %d", m2)
	}
	_, m3 := findBestMatch(root, []int32{1, 2, 3, 9, 9})
	if m3 != 3 {
		t.Fatalf("unrelated input: expected 3 matched, got %d", m3)
	}
}
// TestRepeatedBranching splits the same path twice at different offsets and
// verifies all three resulting branches remain findable with the correct
// match lengths.
func TestRepeatedBranching(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}
	root.appendTokens(root, []int32{1, 2, 3, 4, 5}, 5)
	// B diverges after [1,2,3].
	_, matchedB := findBestMatch(root, []int32{1, 2, 3, 6, 7})
	if matchedB != 3 {
		t.Fatalf("B: expected 3 matched, got %d", matchedB)
	}
	nodeA := root.children[0]
	split1 := splitNode(nodeA, 3, nil, nil)
	split1.appendTokens(root, []int32{6, 7}, 5)
	// C diverges even earlier, after [1,2].
	_, matchedC := findBestMatch(root, []int32{1, 2, 8, 9})
	if matchedC != 2 {
		t.Fatalf("C: expected 2 matched, got %d", matchedC)
	}
	split2 := splitNode(split1, 2, nil, nil)
	split2.appendTokens(root, []int32{8, 9}, 4)
	// All three branches must still resolve fully.
	_, mA := findBestMatch(root, []int32{1, 2, 3, 4, 5})
	if mA != 5 {
		t.Fatalf("A: expected 5 matched, got %d", mA)
	}
	_, mB := findBestMatch(root, []int32{1, 2, 3, 6, 7})
	if mB != 5 {
		t.Fatalf("B: expected 5 matched, got %d", mB)
	}
	_, mC := findBestMatch(root, []int32{1, 2, 8, 9})
	if mC != 4 {
		t.Fatalf("C: expected 4 matched, got %d", mC)
	}
	checkTrieInvariants(t, root)
}
// TestMergeWithChild covers mergeWithChild: token/snapshot concatenation and
// grandchild adoption, user-flag inheritance, lastUsed propagation, and the
// multiple-children panic guard.
func TestMergeWithChild(t *testing.T) {
	t.Run("Basic", func(t *testing.T) {
		// root -> A[1,2,3] -> B[4,5] -> {C[6], D[7]}
		now := time.Now()
		root := &trieNode{lastUsed: now}
		a := &trieNode{
			tokens:    []int32{1, 2, 3},
			endOffset: 3,
			parent:    root,
			lastUsed:  now,
			snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3}, from: 0, to: 3}},
		}
		b := &trieNode{
			tokens:    []int32{4, 5},
			endOffset: 5,
			parent:    a,
			lastUsed:  now,
			snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{4, 5}, from: 3, to: 5}},
		}
		c := &trieNode{tokens: []int32{6}, endOffset: 6, parent: b, lastUsed: now}
		d := &trieNode{tokens: []int32{7}, endOffset: 6, parent: b, lastUsed: now}
		root.children = []*trieNode{a}
		a.children = []*trieNode{b}
		b.children = []*trieNode{c, d}
		mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
		mergeWithChild(a, []cache.Cache{mc}, nil)
		// Tokens concatenated.
		if !slices.Equal(a.tokens, []int32{1, 2, 3, 4, 5}) {
			t.Fatalf("merged tokens = %v, want [1,2,3,4,5]", a.tokens)
		}
		if a.endOffset != 5 {
			t.Fatalf("merged endOffset = %d, want 5", a.endOffset)
		}
		// Grandchildren reparented.
		if len(a.children) != 2 {
			t.Fatalf("merged children count = %d, want 2", len(a.children))
		}
		if c.parent != a || d.parent != a {
			t.Fatal("grandchildren should be reparented to merged node")
		}
		// B detached.
		if b.parent != nil || b.children != nil || b.snapshots != nil {
			t.Fatal("child B should be fully detached after merge")
		}
		// Merged snapshot should cover [0,5).
		if !a.hasSnapshots() {
			t.Fatal("merged node should have snapshots")
		}
		ms := a.snapshots[0].(*fakeSnapshot)
		if ms.from != 0 || ms.to != 5 {
			t.Fatalf("merged snapshot = [%d,%d), want [0,5)", ms.from, ms.to)
		}
		checkTrieInvariants(t, root)
	})
	t.Run("UserFlag", func(t *testing.T) {
		root := &trieNode{lastUsed: time.Now()}
		parent := &trieNode{
			tokens: []int32{1, 2}, endOffset: 2, parent: root,
			lastUsed: time.Now(), user: false,
		}
		child := &trieNode{
			tokens: []int32{3, 4}, endOffset: 4, parent: parent,
			lastUsed: time.Now(), user: true,
		}
		root.children = []*trieNode{parent}
		parent.children = []*trieNode{child}
		mergeWithChild(parent, nil, nil)
		if !parent.user {
			t.Fatal("merged node should inherit user=true from child")
		}
	})
	t.Run("LastUsed", func(t *testing.T) {
		now := time.Now()
		root := &trieNode{lastUsed: now}
		parent := &trieNode{
			tokens: []int32{1}, endOffset: 1, parent: root,
			lastUsed: now.Add(-1 * time.Hour),
		}
		child := &trieNode{
			tokens: []int32{2}, endOffset: 2, parent: parent,
			lastUsed: now.Add(1 * time.Hour),
		}
		root.children = []*trieNode{parent}
		parent.children = []*trieNode{child}
		mergeWithChild(parent, nil, nil)
		if !parent.lastUsed.Equal(now.Add(1 * time.Hour)) {
			t.Fatal("merged node should pick the more recent lastUsed")
		}
	})
	t.Run("PanicOnMultipleChildren", func(t *testing.T) {
		defer func() {
			if r := recover(); r == nil {
				t.Fatal("expected panic on node with 2 children")
			}
		}()
		root := &trieNode{lastUsed: time.Now()}
		node := &trieNode{
			tokens: []int32{1}, endOffset: 1, parent: root, lastUsed: time.Now(),
			children: []*trieNode{
				{tokens: []int32{2}, endOffset: 2, lastUsed: time.Now()},
				{tokens: []int32{3}, endOffset: 2, lastUsed: time.Now()},
			},
		}
		root.children = []*trieNode{node}
		mergeWithChild(node, nil, nil)
	})
}
// TestSplitMergeRoundTrip verifies that splitting a snapshotted node and then
// merging the halves back restores the original edge, offsets, and snapshot
// coverage.
func TestSplitMergeRoundTrip(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}
	leaf := &trieNode{
		tokens:    []int32{1, 2, 3, 4, 5},
		endOffset: 5,
		parent:    root,
		lastUsed:  time.Now(),
		snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3, 4, 5}, from: 0, to: 5}},
	}
	root.children = []*trieNode{leaf}
	mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
	caches := []cache.Cache{mc}
	// Split at 3: [1,2,3] -> [4,5]
	newParent := splitNode(leaf, 3, caches, nil)
	if !slices.Equal(newParent.tokens, []int32{1, 2, 3}) {
		t.Fatalf("after split: parent tokens = %v, want [1,2,3]", newParent.tokens)
	}
	if !slices.Equal(leaf.tokens, []int32{4, 5}) {
		t.Fatalf("after split: child tokens = %v, want [4,5]", leaf.tokens)
	}
	checkTrieInvariants(t, root)
	// Merge back: should restore [1,2,3,4,5]
	mergeWithChild(newParent, caches, nil)
	if !slices.Equal(newParent.tokens, []int32{1, 2, 3, 4, 5}) {
		t.Fatalf("after merge: tokens = %v, want [1,2,3,4,5]", newParent.tokens)
	}
	if newParent.endOffset != 5 {
		t.Fatalf("after merge: endOffset = %d, want 5", newParent.endOffset)
	}
	if len(newParent.children) != 0 {
		t.Fatalf("after merge: children count = %d, want 0", len(newParent.children))
	}
	// Merged snapshot should cover [0,5).
	if !newParent.hasSnapshots() {
		t.Fatal("after merge: should have snapshots")
	}
	ms := newParent.snapshots[0].(*fakeSnapshot)
	if ms.from != 0 || ms.to != 5 {
		t.Fatalf("after merge: snapshot = [%d,%d), want [0,5)", ms.from, ms.to)
	}
	checkTrieInvariants(t, root)
}
// TestRemoveNode covers removeNode: detaching a leaf (siblings survive,
// snapshots dropped) and the panic guards for the root and non-leaf nodes.
func TestRemoveNode(t *testing.T) {
	t.Run("Leaf", func(t *testing.T) {
		root := &trieNode{lastUsed: time.Now()}
		shared := &trieNode{
			tokens: []int32{1, 2, 3}, endOffset: 3, parent: root, lastUsed: time.Now(),
		}
		leafA := &trieNode{
			tokens: []int32{4, 5}, endOffset: 5, parent: shared, lastUsed: time.Now(),
			snapshots: []cache.Snapshot{&fakeSnapshot{from: 3, to: 5}},
		}
		leafB := &trieNode{
			tokens: []int32{6, 7}, endOffset: 5, parent: shared, lastUsed: time.Now(),
			snapshots: []cache.Snapshot{&fakeSnapshot{from: 3, to: 5}},
		}
		root.children = []*trieNode{shared}
		shared.children = []*trieNode{leafA, leafB}
		removeNode(leafA, nil)
		if len(shared.children) != 1 {
			t.Fatalf("parent should have 1 child, got %d", len(shared.children))
		}
		if shared.children[0] != leafB {
			t.Fatal("remaining child should be leafB")
		}
		if leafA.parent != nil {
			t.Fatal("removed node parent should be nil")
		}
		if leafA.snapshots != nil {
			t.Fatal("removed node snapshots should be nil")
		}
		checkTrieInvariants(t, root)
	})
	t.Run("PanicOnRoot", func(t *testing.T) {
		defer func() {
			if r := recover(); r == nil {
				t.Fatal("expected panic when removing root")
			}
		}()
		removeNode(&trieNode{}, nil)
	})
	t.Run("PanicOnNonLeaf", func(t *testing.T) {
		defer func() {
			if r := recover(); r == nil {
				t.Fatal("expected panic when removing non-leaf")
			}
		}()
		parent := &trieNode{parent: &trieNode{}}
		parent.children = []*trieNode{{}}
		removeNode(parent, nil)
	})
}

View File

@@ -106,6 +106,7 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
type completionRequest struct {
	Prompt  string          `json:"prompt"`
	Images  []llm.ImageData `json:"images,omitempty"` // optional image attachments for multimodal prompts
	Options *completionOpts `json:"options,omitempty"`
}
@@ -155,6 +156,7 @@ func (c *Client) Close() error {
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
creq := completionRequest{
Prompt: req.Prompt,
Images: req.Images,
}
if req.Options != nil {
creq.Options = &completionOpts{

View File

@@ -18,6 +18,10 @@ func Version() string {
}
func doEval(outputs []*Array, async bool) {
if len(outputs) == 0 {
return
}
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)

View File

@@ -304,6 +304,18 @@ func Exp(a *Array) *Array {
return out
}
// Sin returns the element-wise sine of a, computed on the default stream.
func Sin(a *Array) *Array {
	out := New("SIN")
	C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}
// Cos returns the element-wise cosine of a, computed on the default stream.
func Cos(a *Array) *Array {
	out := New("COS")
	C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}
func Log(a *Array) *Array {
out := New("LOG")
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)

View File

@@ -4,10 +4,14 @@ package mlx
import "C"
import (
"cmp"
"math"
"unsafe"
)
// End is a sentinel value meaning "to the end of the dimension",
// equivalent to an omitted stop in Python (e.g. a[i:]).
const End = math.MaxInt32
type slice struct {
args []int
}
@@ -16,6 +20,16 @@ func Slice(args ...int) slice {
return slice{args: args}
}
// resolve converts a user-facing slice bound into a concrete C index for a
// dimension of size dim: the End sentinel maps to dim, and negative values
// count back from the end of the dimension (Python-style).
func resolve(val, dim int) C.int {
	switch {
	case val == End:
		return C.int(dim)
	case val < 0:
		return C.int(dim + val)
	default:
		return C.int(val)
	}
}
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
if len(slices) != len(dims) {
panic("number of slice arguments must match number of tensor dimensions")
@@ -28,26 +42,28 @@ func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
}
for i, s := range slices {
dim := dims[i]
switch len(s.args) {
case 0:
// slice[:]
args[0][i] = C.int(0)
args[1][i] = C.int(dims[i])
args[1][i] = C.int(dim)
args[2][i] = C.int(1)
case 1:
// slice[i]
args[0][i] = C.int(s.args[0])
args[1][i] = C.int(s.args[0] + 1)
start := resolve(s.args[0], dim)
args[0][i] = start
args[1][i] = start + 1
args[2][i] = C.int(1)
case 2:
// slice[i:j]
args[0][i] = C.int(s.args[0])
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
args[0][i] = resolve(s.args[0], dim)
args[1][i] = resolve(s.args[1], dim)
args[2][i] = C.int(1)
case 3:
// slice[i:j:k]
args[0][i] = C.int(s.args[0])
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
args[0][i] = resolve(s.args[0], dim)
args[1][i] = resolve(s.args[1], dim)
args[2][i] = C.int(s.args[2])
default:
panic("invalid slice arguments")

View File

@@ -0,0 +1,32 @@
package base
import (
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// ImageInput is a single image attached to a prompt.
type ImageInput struct {
	ID   int    // request-scoped identifier matched against image tags in the prompt
	Data []byte // raw encoded image bytes as received in the request
}

// PromptTokenization contains tokenized prompt IDs plus optional request-scoped
// model metadata needed during forward.
type PromptTokenization struct {
	Tokens []int32 // expanded prompt token IDs
	State  any     // opaque per-request state handed back via ForwardWithState
}

// MultimodalPromptTokenizerWithState is an optional model interface used by
// mlxrunner to expand tagged multimodal prompts into token IDs, returning
// request-scoped state to be attached to the forward pass.
type MultimodalPromptTokenizerWithState interface {
	TokenizePromptWithImagesState(prompt string, images []ImageInput) (*PromptTokenization, error)
}

// ForwardWithStateModel is an optional model interface for request-scoped
// forward metadata that should not be stored in shared caches.
type ForwardWithStateModel interface {
	ForwardWithState(inputs *mlx.Array, cache []cache.Cache, state any) *mlx.Array
}

View File

@@ -12,12 +12,42 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
// prefillChunkSize returns the number of prompt tokens processed per prefill
// batch (2048).
func prefillChunkSize() int {
	const chunkTokens = 1 << 11 // 2048
	return chunkTokens
}
// tokenizeRequest converts a request prompt into model input token IDs.
// When the model supports multimodal tokenization and the request carries
// images, the prompt is expanded via TokenizePromptWithImagesState and the
// returned request-scoped state is handed back for the forward pass.
// Otherwise the prompt is plain-text tokenized and the state is nil.
func (r *Runner) tokenizeRequest(request Request) ([]int32, any, error) {
	if len(request.Images) > 0 {
		// The shared trie cache keys only on token IDs today, so multimodal
		// prompts must not reuse snapshots across distinct image inputs.
		r.cache.clear()
	}
	if multimodalTokenizer, ok := r.Model.(base.MultimodalPromptTokenizerWithState); ok && len(request.Images) > 0 {
		// Re-wrap the request images into the base package's input type.
		images := make([]base.ImageInput, len(request.Images))
		for i := range request.Images {
			images[i] = base.ImageInput{
				ID:   request.Images[i].ID,
				Data: request.Images[i].Data,
			}
		}
		out, err := multimodalTokenizer.TokenizePromptWithImagesState(request.Prompt, images)
		if err != nil {
			return nil, nil, err
		}
		if out == nil {
			return nil, nil, errors.New("empty multimodal tokenization result")
		}
		return out.Tokens, out.State, nil
	}
	// Text-only path: no request-scoped state.
	return r.Tokenizer.Encode(request.Prompt, true), nil, nil
}
func (r *Runner) TextGenerationPipeline(request Request) error {
if r.Model == nil {
return errors.New("model not loaded")
@@ -50,12 +80,15 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
mlx.LogArrays()
r.cache.log()
r.cache.dumpTree()
}
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
}()
inputs := r.Tokenizer.Encode(request.Prompt, true)
inputs, promptState, err := r.tokenizeRequest(request)
if err != nil {
return err
}
if len(inputs) == 0 {
return errors.New("empty prompt")
}
@@ -83,10 +116,17 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
tokens := session.remaining
prefillChunk := prefillChunkSize()
modelForward := func(tokens *mlx.Array) *mlx.Array {
if withState, ok := r.Model.(base.ForwardWithStateModel); ok {
return withState.ForwardWithState(tokens, caches, promptState)
}
return r.Model.Forward(tokens, caches)
}
materializeCaches := func() {
state := make([]*mlx.Array, 0, 2*len(caches))
for _, c := range caches {
state = appendCacheState(state, c)
state = append(state, c.State()...)
}
if len(state) == 0 {
return
@@ -102,16 +142,37 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
n := min(prefillChunk, total-processed-1)
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
// If there's a pending intermediate snapshot, split the batch
// so we can capture it at the exact offset. The cache offset
// after this batch will be: baseOffset + processed + n.
if session.snapshotOffset > 0 {
baseOffset := len(session.inputs) - len(tokens)
tokensUntilSnapshot := session.snapshotOffset - (baseOffset + processed)
if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
n = tokensUntilSnapshot
}
}
modelForward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0))
mlx.Sweep()
materializeCaches()
processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total)
// Create snapshot at branch point for future diverging requests.
if session.snapshotOffset > 0 {
baseOffset := len(session.inputs) - len(tokens)
if baseOffset+processed >= session.snapshotOffset {
session.snapshot(false)
}
}
mlx.ClearCache()
}
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
fwd := r.Model.Forward(token.ExpandDims(0), caches)
fwd := modelForward(token.ExpandDims(0))
logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

View File

@@ -11,6 +11,7 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
@@ -29,7 +30,8 @@ type Request struct {
}
type TextCompletionsRequest struct {
Prompt string `json:"prompt"`
Prompt string `json:"prompt"`
Images []llm.ImageData `json:"images,omitempty"`
Options struct {
Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"`

View File

@@ -169,7 +169,7 @@ func topK(s *Sampler, logprobs *mlx.Array) *mlx.Array {
return logprobs
}
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, 0))
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}

View File

@@ -0,0 +1,354 @@
package qwen3_5
import (
"fmt"
"math"
"regexp"
"strconv"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
var imageTagRE = regexp.MustCompile(`\[img-(\d+)\]`)
type promptVisionSpan struct {
Start int32
End int32
Main *mlx.Array
Grid *VisionGrid
}
type promptVisionState struct {
Spans []promptVisionSpan
PositionCache []int32
}
// promptStartPosFromCaches returns the smallest offset across the non-nil
// layer caches, i.e. the absolute position at which the current chunk's
// tokens start. With no usable caches it returns 0.
func promptStartPosFromCaches(caches []cache.Cache) int32 {
	best := -1
	for _, c := range caches {
		if c == nil {
			continue
		}
		if o := c.Offset(); best < 0 || o < best {
			best = o
		}
	}
	if best < 0 {
		return 0
	}
	return int32(best)
}
// promptVisionStateFromState extracts the vision prompt state from the opaque
// request state, returning nil when the state is absent or of another type.
func promptVisionStateFromState(state any) *promptVisionState {
	if typed, ok := state.(*promptVisionState); ok {
		return typed
	}
	return nil
}
// overlapRange intersects the half-open token range [chunkStart,
// chunkStart+chunkLen) with the span [spanStart, spanEnd). It returns the
// intersection expressed twice — relative to the chunk start and relative to
// the span start — plus ok=false when the two ranges do not overlap.
func overlapRange(chunkStart, chunkLen, spanStart, spanEnd int32) (int32, int32, int32, int32, bool) {
	lo := spanStart
	if chunkStart > lo {
		lo = chunkStart
	}
	hi := spanEnd
	if chunkEnd := chunkStart + chunkLen; chunkEnd < hi {
		hi = chunkEnd
	}
	if lo >= hi {
		return 0, 0, 0, 0, false
	}
	return lo - chunkStart, hi - chunkStart, lo - spanStart, hi - spanStart, true
}
// applyPromptVisionEmbeddings overwrites the rows of h that fall inside an
// image placeholder span with the corresponding slice of that span's
// precomputed vision embeddings. startPos is the absolute position of h's
// first token within the full prompt; h is assumed to be (batch, seq,
// hidden) — anything not 3-D, or a nil/empty state, is returned unchanged.
func (m *Model) applyPromptVisionEmbeddings(h *mlx.Array, startPos int32, state *promptVisionState) *mlx.Array {
	if m == nil || h == nil || state == nil || len(state.Spans) == 0 {
		return h
	}
	dims := h.Dims()
	if len(dims) != 3 {
		return h
	}
	L := int32(dims[1])
	for _, span := range state.Spans {
		chunkLo, chunkHi, spanLo, spanHi, ok := overlapRange(startPos, L, span.Start, span.End)
		if !ok || span.Main == nil || !span.Main.Valid() {
			continue
		}
		// Take the portion of the span's embeddings that overlaps this chunk.
		repl := span.Main.Slice(
			mlx.Slice(),
			mlx.Slice(int(spanLo), int(spanHi)),
			mlx.Slice(),
		)
		// Match the hidden-state dtype before splicing the rows in.
		repl = repl.AsType(h.DType())
		h = h.SliceUpdate(
			repl,
			mlx.Slice(),
			mlx.Slice(int(chunkLo), int(chunkHi)),
			mlx.Slice(),
		)
	}
	return h
}
// findImageByID returns the image whose ID matches id, plus whether one was
// found. A miss returns the zero ImageInput.
func findImageByID(images []base.ImageInput, id int) (base.ImageInput, bool) {
	for _, img := range images {
		if img.ID == id {
			return img, true
		}
	}
	return base.ImageInput{}, false
}
// mapPromptPosition translates a raw token index into its rotary position.
// Indices covered by the recorded PositionCache use the cached value; indices
// past the end continue sequentially from the last cached position. With no
// state (or an empty cache) the mapping is the identity.
func mapPromptPosition(state *promptVisionState, id int32) int32 {
	if state == nil || len(state.PositionCache) == 0 {
		return id
	}
	cached := int32(len(state.PositionCache))
	if id < cached {
		return state.PositionCache[id]
	}
	return state.PositionCache[cached-1] + (id - cached) + 1
}
// promptVisionGridSpan returns the position span an image occupies: the
// larger of the merged grid's width and height, each clamped to at least 1.
// A nil grid falls back to fallback (also clamped to at least 1).
func promptVisionGridSpan(grid *VisionGrid, merge int32, fallback int32) int32 {
	if fallback <= 0 {
		fallback = 1
	}
	if grid == nil {
		return fallback
	}
	if merge <= 0 {
		merge = 1
	}
	w := grid.Width / merge
	if w < 1 {
		w = 1
	}
	h := grid.Height / merge
	if h < 1 {
		h = 1
	}
	if w > h {
		return w
	}
	return h
}
// normalizeMRoPESections copies up to four configured MRoPE section sizes
// into a fixed array, dropping non-positive entries to zero so downstream
// bound computations stay monotone.
func normalizeMRoPESections(sections []int32) [4]int32 {
	var out [4]int32
	n := len(sections)
	if n > 4 {
		n = 4
	}
	for i := 0; i < n; i++ {
		if s := sections[i]; s > 0 {
			out[i] = s
		}
	}
	return out
}
// mropePairComponent maps a rotary frequency-pair index to the MRoPE
// component (0, 1 or 2 for the three configured sections — presumably
// temporal/height/width; verify against the position builder — or 3 for
// pairs beyond them). In interleaved layout components cycle pair%3 within
// per-component bounds; in flat layout the sections occupy contiguous
// index ranges.
func mropePairComponent(pair int32, sections [4]int32, interleaved bool) int {
	if interleaved {
		switch pair % 3 {
		case 0:
			if pair < 3*sections[0] {
				return 0
			}
		case 1:
			if pair < 1+3*sections[1] {
				return 1
			}
		case 2:
			if pair < 2+3*sections[2] {
				return 2
			}
		}
		return 3
	}
	boundT := sections[0]
	boundH := boundT + sections[1]
	boundW := boundH + sections[2]
	switch {
	case pair < boundT:
		return 0
	case pair < boundH:
		return 1
	case pair < boundW:
		return 2
	}
	return 3
}
// buildPromptMRoPEPositions builds per-component position arrays for a chunk
// of chunkLen tokens starting at absolute position startPos. All three MRoPE
// components default to the (possibly cache-remapped) sequential position;
// tokens inside an image span additionally get row/column offsets derived
// from the span's merged grid width.
//
// Fix: the original dereferenced state.Spans without a nil check, even
// though mapPromptPosition explicitly defends against a nil state — a nil
// state would panic here. Guard and return the sequential positions instead.
func (m *Model) buildPromptMRoPEPositions(state *promptVisionState, startPos, chunkLen int32) [4][]int32 {
	var positions [4][]int32
	for i := range positions {
		positions[i] = make([]int32, chunkLen)
	}
	// positions[3] stays zero — it covers RoPE dims beyond the 3 MRoPE sections.
	for i := range chunkLen {
		p := mapPromptPosition(state, startPos+i)
		positions[0][i] = p
		positions[1][i] = p
		positions[2][i] = p
	}
	if state == nil {
		return positions
	}
	merge := int32(1)
	if m != nil && m.Config != nil && m.Config.Vision != nil {
		merge = m.Config.Vision.SpatialMergeSize
	}
	for _, span := range state.Spans {
		if span.Grid == nil {
			continue
		}
		chunkLo, chunkHi, spanLo, _, ok := overlapRange(startPos, chunkLen, span.Start, span.End)
		if !ok {
			continue
		}
		// rel walks the span's patches row-major; width determines the wrap.
		w := max(int32(1), span.Grid.Width/merge)
		for i := chunkLo; i < chunkHi; i++ {
			rel := spanLo + (i - chunkLo)
			positions[1][i] += rel / w
			positions[2][i] += rel % w
		}
	}
	return positions
}
// buildPromptMRoPECosSin precomputes cos/sin rotary tables of shape
// (1, 1, chunkLen, ropeDim) for a chunk of the prompt, selecting each
// frequency pair's position via mropePairComponent so image tokens rotate by
// their spatial offsets. Returns (nil, nil) when MRoPE is not configured, the
// state is absent, or the rope dimension degenerates. A non-zero dtype casts
// the tables before returning.
//
// Fix (lint): the inner loop's local variable was named "base", shadowing
// the imported base package used elsewhere in this file; renamed to row.
func (m *Model) buildPromptMRoPECosSin(state *promptVisionState, startPos, chunkLen int32, dtype mlx.DType) (*mlx.Array, *mlx.Array) {
	if m == nil || m.Config == nil || state == nil || chunkLen <= 0 || len(m.Config.MRoPESections) == 0 {
		return nil, nil
	}
	// Round the rotary dimension down to an even number of channels.
	ropeDim := m.Config.RopeDim
	if ropeDim%2 != 0 {
		ropeDim--
	}
	if ropeDim <= 0 {
		return nil, nil
	}
	half := ropeDim / 2
	positions := m.buildPromptMRoPEPositions(state, startPos, chunkLen)
	sections := normalizeMRoPESections(m.Config.MRoPESections)
	theta := m.Config.RopeTheta
	if theta <= 0 {
		theta = 100000.0
	}
	// Inverse-frequency table: theta^(-2j/ropeDim).
	freqs := make([]float64, half)
	for j := range half {
		freqs[j] = math.Pow(float64(theta), -2.0*float64(j)/float64(ropeDim))
	}
	// Each token row holds ropeDim angles; the second half duplicates the
	// first so rotate-half application sees matching cos/sin per channel.
	angles := make([]float32, chunkLen*ropeDim)
	for i := range chunkLen {
		row := i * ropeDim
		for j := range half {
			component := mropePairComponent(j, sections, m.Config.MRoPEInterleaved)
			angle := float32(float64(positions[component][i]) * freqs[j])
			angles[row+j] = angle
			angles[row+half+j] = angle
		}
	}
	arr := mlx.FromValues(angles, 1, 1, int(chunkLen), int(ropeDim))
	cos := mlx.Cos(arr)
	sin := mlx.Sin(arr)
	if dtype != 0 {
		cos = cos.AsType(dtype)
		sin = sin.AsType(dtype)
	}
	return cos, sin
}
// tokenizePromptWithResolvedImages expands [img-N] tags in prompt into runs
// of image placeholder tokens, resolving each referenced image to vision
// embeddings via resolve (called once per distinct image ID). It returns the
// expanded token IDs and the per-request vision state recording, for each
// image, the token span to overwrite with embeddings and the per-token rotary
// positions. Without vision support (or a nil resolve) it falls back to plain
// text tokenization with nil state.
func (m *Model) tokenizePromptWithResolvedImages(
	prompt string,
	images []base.ImageInput,
	resolve func([]byte) (*VisionEmbeddings, error),
) ([]int32, *promptVisionState, error) {
	if m == nil || m.tok == nil {
		return nil, nil, fmt.Errorf("qwen3_5: tokenizer not initialized")
	}
	if m.Vision == nil || m.ImageProcessor == nil || resolve == nil {
		return m.tok.Encode(prompt, true), nil, nil
	}
	// parts[i] is the text before matches[i]; the final part has no match.
	parts := imageTagRE.Split(prompt, -1)
	matches := imageTagRE.FindAllStringSubmatch(prompt, -1)
	resolved := make(map[int]*VisionEmbeddings, len(images))
	var out []int32
	state := &promptVisionState{}
	// p is the running rotary position assigned to the next token.
	var p int32
	appendToken := func(tok, pos int32) {
		out = append(out, tok)
		state.PositionCache = append(state.PositionCache, pos)
	}
	for i, part := range parts {
		// Only the very first text segment gets special-token treatment.
		for _, tok := range m.tok.Encode(part, i == 0) {
			appendToken(tok, p)
			p++
		}
		if i >= len(matches) {
			continue
		}
		imageID, err := strconv.Atoi(matches[i][1])
		if err != nil {
			return nil, nil, fmt.Errorf("qwen3_5: invalid image tag %q: %w", matches[i][0], err)
		}
		img, ok := findImageByID(images, imageID)
		if !ok {
			return nil, nil, fmt.Errorf("invalid image index: %d", imageID)
		}
		// Resolve each distinct image at most once per request.
		embeds := resolved[imageID]
		if embeds == nil {
			embeds, err = resolve(img.Data)
			if err != nil {
				return nil, nil, err
			}
			resolved[imageID] = embeds
		}
		if embeds == nil || embeds.Main == nil || !embeds.Main.Valid() || embeds.Main.NumDims() < 2 {
			return nil, nil, fmt.Errorf("qwen3_5: invalid vision embeddings")
		}
		// Dim 1 of the embeddings determines how many placeholder tokens to emit.
		tokensPerImage := int32(embeds.Main.Dim(1))
		if tokensPerImage <= 0 {
			return nil, nil, fmt.Errorf("qwen3_5: invalid image token count: %d", tokensPerImage)
		}
		appendToken(m.VisionStartToken, p)
		p++
		// All placeholder tokens share the same base position; their spatial
		// offsets are added later from the span's grid.
		basePos := p
		spanStart := int32(len(out))
		for range tokensPerImage {
			appendToken(m.ImageTokenID, basePos)
		}
		spanEnd := int32(len(out))
		merge := int32(1)
		if m.Config != nil && m.Config.Vision != nil {
			merge = m.Config.Vision.SpatialMergeSize
		}
		// Advance the position counter by the image's spatial span, not by
		// its token count.
		gridSpan := promptVisionGridSpan(embeds.Grid, merge, tokensPerImage)
		p += gridSpan
		appendToken(m.VisionEndToken, p)
		p++
		state.Spans = append(state.Spans, promptVisionSpan{
			Start: spanStart,
			End:   spanEnd,
			Main:  embeds.Main,
			Grid:  embeds.Grid,
		})
	}
	return out, state, nil
}
// TokenizePromptWithImagesState implements the runner's multimodal tokenizer
// interface: it expands image tags in prompt into placeholder tokens using
// EncodeVisionImage and returns the token IDs with the request-scoped vision
// state attached.
func (m *Model) TokenizePromptWithImagesState(prompt string, images []base.ImageInput) (*base.PromptTokenization, error) {
	tokens, state, err := m.tokenizePromptWithResolvedImages(prompt, images, m.EncodeVisionImage)
	if err != nil {
		return nil, err
	}
	result := &base.PromptTokenization{
		Tokens: tokens,
		State:  state,
	}
	return result, nil
}

View File

@@ -2,6 +2,7 @@
package qwen3_5
import (
"cmp"
"encoding/json"
"fmt"
"math"
@@ -22,16 +23,26 @@ func init() {
base.Register("Qwen3NextForConditionalGeneration", NewModel)
}
var (
_ base.MultimodalPromptTokenizerWithState = (*Model)(nil)
_ base.ForwardWithStateModel = (*Model)(nil)
)
// RopeParameters carries optional rope metadata embedded under rope_parameters.
type RopeParameters struct {
Type string `json:"type"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
MRoPEInterleaved bool `json:"mrope_interleaved"`
MRoPESection []int32 `json:"mrope_section"`
DimensionSections []int32 `json:"dimension_sections"`
}
// Config holds Qwen 3.5 text config (top-level or nested text_config).
type Config struct {
// TextConfig holds the Qwen 3.5 text-model architecture fields.
// In VLM configs these live under the "text_config" key; in text-only
// configs they appear at the top level.
type TextConfig struct {
ModelType string `json:"model_type"`
HiddenSize int32 `json:"hidden_size"`
IntermediateSize int32 `json:"intermediate_size"`
@@ -67,6 +78,19 @@ type Config struct {
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
RopeScaling map[string]any `json:"rope_scaling"`
RopeParameters *RopeParameters `json:"rope_parameters"`
MRoPESections []int32 `json:"mrope_sections"`
MRoPEInterleaved bool `json:"mrope_interleaved"`
}
// Config is the full model config. It embeds TextConfig for the text-model
// fields and adds top-level-only fields (vision, token IDs, quantization).
type Config struct {
TextConfig
Vision *VisionConfig `json:"vision_config"`
ImageTokenID int32 `json:"image_token_id"`
VisionStartToken int32 `json:"vision_start_token_id"`
VisionEndToken int32 `json:"vision_end_token_id"`
// Quantization metadata.
QuantGroupSize int `json:"-"`
@@ -90,6 +114,9 @@ type Model struct {
*Config
weightPrefix string
Vision *VisionModel
ImageProcessor *VisionImageProcessor
}
// Layer is a transformer decoder layer.
@@ -190,17 +217,24 @@ func parseConfig(configData []byte) (Config, error) {
var cfg Config
activeRaw := rawTop
// First pass: unmarshal the full config to pick up top-level fields
// (vision_config, image_token_id, etc.) and text fields for text-only models.
if err := json.Unmarshal(configData, &cfg); err != nil {
return Config{}, fmt.Errorf("parse config: %w", err)
}
// Second pass: if text_config exists, unmarshal it into TextConfig so
// text-model fields from text_config take priority over any top-level
// duplicates. Top-level-only fields (Vision, token IDs) are unaffected
// because they live on Config, not TextConfig.
if textRaw, ok := rawTop["text_config"]; ok {
if err := json.Unmarshal(textRaw, &cfg); err != nil {
if err := json.Unmarshal(textRaw, &cfg.TextConfig); err != nil {
return Config{}, fmt.Errorf("parse text_config: %w", err)
}
if err := json.Unmarshal(textRaw, &activeRaw); err != nil {
return Config{}, fmt.Errorf("parse text_config envelope: %w", err)
}
} else {
if err := json.Unmarshal(configData, &cfg); err != nil {
return Config{}, fmt.Errorf("parse config: %w", err)
}
}
if cfg.HiddenSize <= 0 {
@@ -225,12 +259,8 @@ func parseConfig(configData []byte) (Config, error) {
return Config{}, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
if cfg.LinearConvKernelDim <= 0 {
cfg.LinearConvKernelDim = 4
}
cfg.RMSNormEps = cmp.Or(cfg.RMSNormEps, 1e-6)
cfg.LinearConvKernelDim = cmp.Or(cfg.LinearConvKernelDim, 4)
if cfg.LinearNumKeyHeads <= 0 || cfg.LinearNumValueHeads <= 0 || cfg.LinearKeyHeadDim <= 0 || cfg.LinearValueHeadDim <= 0 {
return Config{}, fmt.Errorf("invalid linear attention config (k_heads=%d v_heads=%d k_dim=%d v_dim=%d)",
cfg.LinearNumKeyHeads, cfg.LinearNumValueHeads, cfg.LinearKeyHeadDim, cfg.LinearValueHeadDim)
@@ -246,14 +276,21 @@ func parseConfig(configData []byte) (Config, error) {
if cfg.RopeParameters.PartialRotaryFactor > 0 {
cfg.PartialRotaryFactor = cfg.RopeParameters.PartialRotaryFactor
}
if len(cfg.MRoPESections) == 0 {
switch {
case len(cfg.RopeParameters.MRoPESection) > 0:
cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.MRoPESection...)
case len(cfg.RopeParameters.DimensionSections) > 0:
cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.DimensionSections...)
}
}
cfg.MRoPEInterleaved = cmp.Or(cfg.MRoPEInterleaved, cfg.RopeParameters.MRoPEInterleaved)
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 100000.0
if len(cfg.MRoPESections) > 4 {
cfg.MRoPESections = cfg.MRoPESections[:4]
}
if cfg.PartialRotaryFactor == 0 {
cfg.PartialRotaryFactor = 0.25
}
if cfg.PartialRotaryFactor < 0 {
cfg.RopeTheta = cmp.Or(cfg.RopeTheta, 100000.0)
if cfg.PartialRotaryFactor <= 0 {
cfg.PartialRotaryFactor = 0.25
}
ropeDim := int32(float32(cfg.HeadDim) * cfg.PartialRotaryFactor)
@@ -281,24 +318,23 @@ func parseConfig(configData []byte) (Config, error) {
}
if cfg.NumExperts > 0 {
if cfg.NumExpertsPerTok <= 0 {
cfg.NumExpertsPerTok = 1
}
if cfg.MoeIntermediateSize <= 0 {
cfg.MoeIntermediateSize = cfg.IntermediateSize
}
if cfg.SharedExpertIntermediateSize <= 0 {
cfg.SharedExpertIntermediateSize = cfg.IntermediateSize
}
cfg.NumExpertsPerTok = cmp.Or(cfg.NumExpertsPerTok, int32(1))
cfg.MoeIntermediateSize = cmp.Or(cfg.MoeIntermediateSize, cfg.IntermediateSize)
cfg.SharedExpertIntermediateSize = cmp.Or(cfg.SharedExpertIntermediateSize, cfg.IntermediateSize)
if _, ok := activeRaw["norm_topk_prob"]; !ok {
cfg.NormTopKProb = true
}
if cfg.DecoderSparseStep <= 0 {
cfg.DecoderSparseStep = 1
}
cfg.DecoderSparseStep = cmp.Or(cfg.DecoderSparseStep, int32(1))
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
if cfg.Vision != nil {
cfg.Vision.applyDefaults()
}
cfg.ImageTokenID = cmp.Or(cfg.ImageTokenID, int32(151655))
cfg.VisionStartToken = cmp.Or(cfg.VisionStartToken, int32(151652))
cfg.VisionEndToken = cmp.Or(cfg.VisionEndToken, int32(151653))
return cfg, nil
}
@@ -364,6 +400,11 @@ func NewModel(root *model.Root) (base.Model, error) {
if err != nil {
return nil, err
}
if cfg.Vision != nil {
if preprocessorData, err := root.Manifest.ReadConfig("preprocessor_config.json"); err == nil {
cfg.Vision.applyPreprocessorConfig(preprocessorData)
}
}
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
@@ -1060,6 +1101,15 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.Layers[i] = layer
}
if cfg.Vision != nil && cfg.Vision.Depth > 0 {
vision, processor, err := loadVisionComponents(tensors, linears, cfg, m.weightPrefix)
if err != nil {
return err
}
m.Vision = vision
m.ImageProcessor = processor
}
return nil
}
@@ -1117,7 +1167,51 @@ func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k,
return q, k, v, z, b, a
}
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
// rotateHalf splits x's last dimension at its midpoint into (x1, x2) and
// returns concat(-x2, x1) — the rotate-half operation used by rotary
// position embeddings. x with a last dimension of size <= 1 is returned
// unchanged. Assumes x is 4-D (four index components are built); callers
// pass (B, L, heads, dim) shaped tensors — confirm if reused elsewhere.
func rotateHalf(x *mlx.Array) *mlx.Array {
	shape := x.Dims()
	last := int32(shape[len(shape)-1])
	half := last / 2
	if half <= 0 {
		return x
	}
	x1 := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), half})
	x2 := mlx.SliceStartStop(x, []int32{0, 0, 0, half}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last})
	return mlx.Concatenate([]*mlx.Array{mlx.Neg(x2), x1}, -1)
}
// applyTextRoPE rotates the first ropeDim channels of x's last dimension with
// precomputed cos/sin tables (rot' = rot*cos + rotateHalf(rot)*sin) and
// re-attaches the untouched tail channels. ropeDim is clamped to the last
// dimension and rounded down to an even value. Degenerate inputs — nil
// arguments, non-4-D x, or a resulting ropeDim <= 0 — are returned unchanged.
func applyTextRoPE(x, cos, sin *mlx.Array, ropeDim int32) *mlx.Array {
	if x == nil || cos == nil || sin == nil || ropeDim <= 0 {
		return x
	}
	shape := x.Dims()
	if len(shape) != 4 {
		return x
	}
	last := int32(shape[len(shape)-1])
	if ropeDim > last {
		ropeDim = last
	}
	// Rotation operates on channel pairs, so the dimension must be even.
	if ropeDim%2 != 0 {
		ropeDim--
	}
	if ropeDim <= 0 {
		return x
	}
	rot := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), ropeDim})
	rot = mlx.Add(mlx.Mul(rot, cos), mlx.Mul(rotateHalf(rot), sin))
	if ropeDim == last {
		return rot
	}
	// Channels past ropeDim are passed through unrotated.
	tail := mlx.SliceStartStop(x, []int32{0, 0, 0, ropeDim}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last})
	return mlx.Concatenate([]*mlx.Array{rot, tail}, -1)
}
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, ropeSin *mlx.Array) *mlx.Array {
qg := a.QProj.Forward(x)
qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2)
q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim})
@@ -1140,8 +1234,13 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
if ropeCos != nil && ropeSin != nil {
q = applyTextRoPE(q, ropeCos, ropeSin, cfg.RopeDim)
k = applyTextRoPE(k, ropeCos, ropeSin, cfg.RopeDim)
} else {
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
}
if c != nil {
k, v = c.Update(k, v)
@@ -1323,13 +1422,13 @@ func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
return mlx.Reshape(y, B, L, cfg.HiddenSize)
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, ropeSin *mlx.Array) *mlx.Array {
var r *mlx.Array
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
if l.IsLinear {
r = l.Linear.Forward(normed, c, B, L, cfg)
} else {
r = l.FullAttn.Forward(normed, c, B, L, cfg)
r = l.FullAttn.Forward(normed, c, B, L, cfg, ropeCos, ropeSin)
}
h := mlx.Add(x, r)
r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg)
@@ -1337,16 +1436,27 @@ func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *m
}
// Forward runs the text-only forward pass by delegating to ForwardWithState
// with no request-scoped state.
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	return m.ForwardWithState(tokens, caches, nil)
}
func (m *Model) ForwardWithState(tokens *mlx.Array, caches []cache.Cache, state any) *mlx.Array {
dims := tokens.Dims()
B, L := int32(dims[0]), int32(dims[1])
startPos := promptStartPosFromCaches(caches)
promptState := promptVisionStateFromState(state)
h := m.EmbedTokens.Forward(tokens)
h = m.applyPromptVisionEmbeddings(h, startPos, promptState)
var ropeCos, ropeSin *mlx.Array
if len(m.MRoPESections) > 0 {
ropeCos, ropeSin = m.buildPromptMRoPECosSin(promptState, startPos, L, h.DType())
}
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
h = layer.Forward(h, c, B, L, m.Config, ropeCos, ropeSin)
}
out := m.Norm.Forward(h, m.RMSNormEps)
return out

View File

@@ -1,10 +1,14 @@
package qwen3_5
import (
"fmt"
"slices"
"testing"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/tokenizer"
)
func skipIfNoMLX(t *testing.T) {
@@ -60,13 +64,13 @@ func TestParseConfigNestedDefaults(t *testing.T) {
}
func TestLayerSelectionHelpers(t *testing.T) {
cfg := &Config{
cfg := &Config{TextConfig: TextConfig{
NumHiddenLayers: 6,
FullAttentionInterval: 3,
NumExperts: 8,
DecoderSparseStep: 2,
MLPOnlyLayers: []int32{1},
}
}}
if !layerIsLinear(cfg, 0) {
t.Fatalf("layer 0 should be linear")
@@ -133,13 +137,13 @@ func TestResolveTensorPathLayout(t *testing.T) {
func TestNewCachesLayout(t *testing.T) {
m := &Model{
Config: &Config{
Config: &Config{TextConfig: TextConfig{
LinearConvKernelDim: 4,
LinearNumKeyHeads: 2,
LinearKeyHeadDim: 8,
LinearNumValueHeads: 4,
LinearValueHeadDim: 16,
},
}},
Layers: []*Layer{
{IsLinear: true},
{IsLinear: false},
@@ -166,7 +170,7 @@ func TestNewCachesLayout(t *testing.T) {
func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
skipIfNoMLX(t)
cfg := &Config{
cfg := &Config{TextConfig: TextConfig{
HiddenSize: 4,
IntermediateSize: 8,
NumHiddenLayers: 2,
@@ -182,7 +186,7 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
LinearValueHeadDim: 2,
LinearConvKernelDim: 4,
FullAttentionInterval: 2,
}
}}
m := &Model{
Config: cfg,
@@ -343,3 +347,389 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
t.Fatalf("k norm dtype = %v, want %v", got, f32)
}
}
// TestParseConfigVisionFields verifies that parseConfig reads vision_config
// and the vision token IDs from the top level while text fields (including
// rope_theta under rope_parameters) come from the nested text_config, and
// that vision defaults (GridPerSide, vision RopeTheta) are applied.
func TestParseConfigVisionFields(t *testing.T) {
	data := []byte(`{
"text_config": {
"hidden_size": 4096,
"intermediate_size": 14336,
"num_hidden_layers": 4,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"head_dim": 128,
"linear_num_value_heads": 64,
"linear_num_key_heads": 16,
"linear_key_head_dim": 128,
"linear_value_head_dim": 128,
"linear_conv_kernel_dim": 4,
"rope_parameters": {
"rope_theta": 10000000
}
},
"vision_config": {
"depth": 2,
"hidden_size": 256,
"num_heads": 8,
"in_channels": 3,
"patch_size": 14,
"spatial_merge_size": 2,
"layer_norm_epsilon": 0.000001,
"temporal_patch_size": 2,
"num_position_embeddings": 2304
},
"image_token_id": 111,
"vision_start_token_id": 112,
"vision_end_token_id": 113
}`)
	cfg, err := parseConfig(data)
	if err != nil {
		t.Fatalf("parseConfig failed: %v", err)
	}
	if cfg.Vision == nil {
		t.Fatal("vision config should be parsed")
	}
	if cfg.Vision.Depth != 2 {
		t.Fatalf("vision.depth mismatch: got %d", cfg.Vision.Depth)
	}
	if cfg.Vision.GridPerSide != 48 {
		t.Fatalf("vision grid-per-side mismatch: got %d want 48", cfg.Vision.GridPerSide)
	}
	if cfg.Vision.RopeTheta != 10000 {
		t.Fatalf("vision rope_theta should default to 10000, got %v", cfg.Vision.RopeTheta)
	}
	if cfg.RopeTheta != 10000000 {
		t.Fatalf("text rope_theta mismatch: got %v", cfg.RopeTheta)
	}
	if cfg.ImageTokenID != 111 || cfg.VisionStartToken != 112 || cfg.VisionEndToken != 113 {
		t.Fatalf("vision token ids mismatch: got image=%d start=%d end=%d", cfg.ImageTokenID, cfg.VisionStartToken, cfg.VisionEndToken)
	}
}
// TestParseConfigMRoPEFromRopeParameters verifies that mrope_interleaved and
// mrope_section are lifted out of text_config.rope_parameters, and that
// RopeDim is derived from head_dim * partial_rotary_factor (256 * 0.25 = 64).
func TestParseConfigMRoPEFromRopeParameters(t *testing.T) {
	data := []byte(`{
"text_config": {
"hidden_size": 2048,
"intermediate_size": 8192,
"num_hidden_layers": 4,
"num_attention_heads": 16,
"num_key_value_heads": 2,
"head_dim": 256,
"linear_num_value_heads": 32,
"linear_num_key_heads": 16,
"linear_key_head_dim": 128,
"linear_value_head_dim": 128,
"linear_conv_kernel_dim": 4,
"rope_parameters": {
"rope_theta": 10000000,
"partial_rotary_factor": 0.25,
"mrope_interleaved": true,
"mrope_section": [11, 11, 10]
}
}
}`)
	cfg, err := parseConfig(data)
	if err != nil {
		t.Fatalf("parseConfig failed: %v", err)
	}
	if !cfg.MRoPEInterleaved {
		t.Fatal("mrope_interleaved should be parsed from rope_parameters")
	}
	if !slices.Equal(cfg.MRoPESections, []int32{11, 11, 10}) {
		t.Fatalf("mrope sections mismatch: got %v", cfg.MRoPESections)
	}
	if cfg.RopeDim != 64 {
		t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
	}
}
// TestParseConfigVisionTokenDefaults verifies that configs without explicit
// vision token IDs fall back to the Qwen defaults (151655/151652/151653).
func TestParseConfigVisionTokenDefaults(t *testing.T) {
	data := []byte(`{
"text_config": {
"hidden_size": 4096,
"intermediate_size": 14336,
"num_hidden_layers": 2,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"head_dim": 128,
"linear_num_value_heads": 64,
"linear_num_key_heads": 16,
"linear_key_head_dim": 128,
"linear_value_head_dim": 128,
"linear_conv_kernel_dim": 4
}
}`)
	cfg, err := parseConfig(data)
	if err != nil {
		t.Fatalf("parseConfig failed: %v", err)
	}
	if cfg.ImageTokenID != 151655 {
		t.Fatalf("default image token mismatch: got %d", cfg.ImageTokenID)
	}
	if cfg.VisionStartToken != 151652 {
		t.Fatalf("default vision start token mismatch: got %d", cfg.VisionStartToken)
	}
	if cfg.VisionEndToken != 151653 {
		t.Fatalf("default vision end token mismatch: got %d", cfg.VisionEndToken)
	}
}
// TestResolveVisionPrefix is a table test covering the two vision tensor
// naming layouts resolveVisionPrefix must recognize: the legacy
// "model.visual" prefix and the imported "vision_tower" prefix.
func TestResolveVisionPrefix(t *testing.T) {
	tests := []struct {
		name    string
		tensors map[string]*mlx.Array
		want    string
	}{
		{
			name: "legacy visual prefix",
			tensors: map[string]*mlx.Array{
				"model.visual.patch_embed.proj.weight": mlx.New("patch"),
			},
			want: "model.visual",
		},
		{
			name: "imported vision tower prefix",
			tensors: map[string]*mlx.Array{
				"vision_tower.blocks.0.attn.qkv.weight": mlx.New("qkv"),
			},
			want: "vision_tower",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := resolveVisionPrefix(tt.tensors, "language_model."); got != tt.want {
				t.Fatalf("resolveVisionPrefix() = %q, want %q", got, tt.want)
			}
		})
	}
}
// TestVisionPreprocessorOverridesDefaults verifies that values from
// preprocessor_config.json (patch/temporal/merge sizes, size bounds, image
// mean/std) override the defaults set by applyDefaults.
func TestVisionPreprocessorOverridesDefaults(t *testing.T) {
	v := &VisionConfig{}
	v.applyDefaults()
	v.applyPreprocessorConfig([]byte(`{
"patch_size": 16,
"temporal_patch_size": 3,
"merge_size": 4,
"size": {
"shortest_edge": 1024,
"longest_edge": 8192
},
"image_mean": [0.1, 0.2, 0.3],
"image_std": [0.9, 0.8, 0.7]
}`))
	if v.PatchSize != 16 {
		t.Fatalf("patch_size mismatch: got %d want 16", v.PatchSize)
	}
	if v.TemporalPatchSize != 3 {
		t.Fatalf("temporal_patch_size mismatch: got %d want 3", v.TemporalPatchSize)
	}
	if v.SpatialMergeSize != 4 {
		t.Fatalf("merge_size mismatch: got %d want 4", v.SpatialMergeSize)
	}
	if v.Size.ShortestEdge != 1024 || v.Size.LongestEdge != 8192 {
		t.Fatalf("size mismatch: got shortest=%d longest=%d", v.Size.ShortestEdge, v.Size.LongestEdge)
	}
	if v.ImageMean[0] != 0.1 || v.ImageStd[2] != 0.7 {
		t.Fatalf("image preprocessing stats mismatch: mean=%v std=%v", v.ImageMean, v.ImageStd)
	}
}
// TestVisionImageProcessorUsesPreprocessorSize confirms the pixel budget from
// preprocessor_config.json flows through to the constructed image processor.
func TestVisionImageProcessorUsesPreprocessorSize(t *testing.T) {
	cfg := &VisionConfig{}
	cfg.applyDefaults()
	cfg.applyPreprocessorConfig([]byte(`{
"size": {
"shortest_edge": 65536,
"longest_edge": 16777216
},
"patch_size": 16,
"temporal_patch_size": 2,
"merge_size": 2,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}`))
	proc := newVisionImageProcessor(cfg)
	if proc == nil {
		t.Fatal("newVisionImageProcessor returned nil")
	}
	if proc.shortestEdge != 65536 || proc.longestEdge != 16777216 {
		t.Fatalf("processor size mismatch: shortest=%d longest=%d", proc.shortestEdge, proc.longestEdge)
	}
}
// testTokenizer builds a minimal single-token BPE tokenizer used by the
// prompt-expansion tests.
func testTokenizer(t *testing.T) *tokenizer.Tokenizer {
	t.Helper()
	spec := []byte(`{
"model": {
"type": "BPE",
"vocab": {"a": 0},
"merges": []
}
}`)
	tok, err := tokenizer.LoadFromBytes(spec)
	if err != nil {
		t.Fatalf("failed to load test tokenizer: %v", err)
	}
	return tok
}
// TestTokenizePromptWithResolvedImagesStoresVisionSpans checks that a prompt
// containing the same image placeholder twice (a) resolves the image exactly
// once, (b) expands each placeholder into start/image.../end tokens, and
// (c) records a vision span and collapsed position cache for each occurrence.
func TestTokenizePromptWithResolvedImagesStoresVisionSpans(t *testing.T) {
	skipIfNoMLX(t)
	m := &Model{
		tok: testTokenizer(t),
		Config: &Config{
			ImageTokenID:     101,
			VisionStartToken: 102,
			VisionEndToken:   103,
			Vision:           &VisionConfig{SpatialMergeSize: 2},
		},
		Vision:         &VisionModel{},
		ImageProcessor: &VisionImageProcessor{},
	}
	// One merged embedding per image token: a 2x2 grid with merge size 2
	// yields 2 image tokens per placeholder.
	main := mlx.FromValues([]float32{
		10, 11,
		20, 21,
	}, 1, 2, 2)
	resolveCalls := 0
	got, state, err := m.tokenizePromptWithResolvedImages(
		"a[img-7][img-7]a",
		[]base.ImageInput{{ID: 7, Data: []byte("img7")}},
		// Resolver stub: counts calls so we can assert caching of the
		// duplicate placeholder.
		func(data []byte) (*VisionEmbeddings, error) {
			if string(data) != "img7" {
				return nil, fmt.Errorf("unexpected data: %q", string(data))
			}
			resolveCalls++
			return &VisionEmbeddings{
				Main: main,
				Grid: &VisionGrid{Height: 2, Width: 2, Temporal: 1},
			}, nil
		},
	)
	if err != nil {
		t.Fatalf("tokenizePromptWithResolvedImages returned error: %v", err)
	}
	if resolveCalls != 1 {
		t.Fatalf("resolve calls mismatch: got %d want 1", resolveCalls)
	}
	// "a" tokenizes to 0; each placeholder becomes start(102), 2x image(101),
	// end(103).
	want := []int32{
		0,
		102, 101, 101, 103,
		102, 101, 101, 103,
		0,
	}
	if !slices.Equal(got, want) {
		t.Fatalf("expanded tokens mismatch: got %v want %v", got, want)
	}
	if state == nil {
		t.Fatal("expected prompt vision state")
	}
	if len(state.Spans) != 2 {
		t.Fatalf("prompt span count mismatch: got %d want 2", len(state.Spans))
	}
	// Spans cover only the image tokens, excluding the start/end markers.
	if state.Spans[0].Start != 2 || state.Spans[0].End != 4 {
		t.Fatalf("first span mismatch: got [%d,%d)", state.Spans[0].Start, state.Spans[0].End)
	}
	if state.Spans[1].Start != 6 || state.Spans[1].End != 8 {
		t.Fatalf("second span mismatch: got [%d,%d)", state.Spans[1].Start, state.Spans[1].End)
	}
	// MRoPE position cache: image tokens within one span share a position.
	wantPos := []int32{0, 1, 2, 2, 3, 4, 5, 5, 6, 7}
	if !slices.Equal(state.PositionCache, wantPos) {
		t.Fatalf("position cache mismatch: got %v want %v", state.PositionCache, wantPos)
	}
}
// TestBuildPromptMRoPEPositions verifies the three MRoPE axes (time, height,
// width) produced for a prompt containing one 4x6-patch image span.
func TestBuildPromptMRoPEPositions(t *testing.T) {
	model := &Model{
		Config: &Config{
			Vision: &VisionConfig{SpatialMergeSize: 2},
		},
	}
	vs := &promptVisionState{
		PositionCache: []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6},
		Spans: []promptVisionSpan{
			{
				Start: 2,
				End:   8,
				Grid:  &VisionGrid{Height: 4, Width: 6, Temporal: 1},
			},
		},
	}
	positions := model.buildPromptMRoPEPositions(vs, 0, 10)
	wantTime := []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6}
	if got := positions[0]; !slices.Equal(got, wantTime) {
		t.Fatalf("time positions mismatch: got %v want %v", got, wantTime)
	}
	wantHeight := []int32{0, 1, 2, 2, 2, 3, 3, 3, 5, 6}
	if got := positions[1]; !slices.Equal(got, wantHeight) {
		t.Fatalf("height positions mismatch: got %v want %v", got, wantHeight)
	}
	wantWidth := []int32{0, 1, 2, 3, 4, 2, 3, 4, 5, 6}
	if got := positions[2]; !slices.Equal(got, wantWidth) {
		t.Fatalf("width positions mismatch: got %v want %v", got, wantWidth)
	}
}
// TestMapPromptPositionContinuesAfterCache checks both a lookup inside the
// position cache and extrapolation past its end.
func TestMapPromptPositionContinuesAfterCache(t *testing.T) {
	vs := &promptVisionState{PositionCache: []int32{0, 1, 2, 2, 3}}
	if mapped := mapPromptPosition(vs, 3); mapped != 2 {
		t.Fatalf("mapPromptPosition(3) = %d, want 2", mapped)
	}
	if mapped := mapPromptPosition(vs, 5); mapped != 4 {
		t.Fatalf("mapPromptPosition(5) = %d, want 4", mapped)
	}
}
// TestApplyPromptVisionEmbeddings confirms that hidden states inside a
// vision span are overwritten with the image embeddings while everything
// outside the span is untouched.
func TestApplyPromptVisionEmbeddings(t *testing.T) {
	skipIfNoMLX(t)
	model := &Model{}
	vs := &promptVisionState{
		Spans: []promptVisionSpan{
			{
				Start: 1,
				End:   3,
				Main: mlx.FromValues([]float32{
					10, 11,
					20, 21,
				}, 1, 2, 2),
			},
		},
	}
	hidden := mlx.FromValues([]float32{
		0, 1,
		2, 3,
		4, 5,
		6, 7,
	}, 1, 4, 2)
	out := model.applyPromptVisionEmbeddings(hidden, 0, vs)
	mlx.Eval(out)
	expected := []float32{
		0, 1,
		10, 11,
		20, 21,
		6, 7,
	}
	if !slices.Equal(out.Floats(), expected) {
		t.Fatalf("embedding replacement mismatch: got %v want %v", out.Floats(), expected)
	}
}

854
x/models/qwen3_5/vision.go Normal file
View File

@@ -0,0 +1,854 @@
package qwen3_5
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"math"
"github.com/ollama/ollama/model/imageproc"
"github.com/ollama/ollama/x/mlxrunner/mlx"
mlxmodel "github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/models/nn"
)
// errNoVisionModel is returned when vision processing is requested but the
// model, processor, or inputs needed for it are absent.
var errNoVisionModel = errors.New("qwen3_5: no vision model")

// VisionConfig mirrors Qwen3.5/Qwen3-Next vision_config.
// Zero values are replaced by applyDefaults; preprocessor_config.json values
// are overlaid by applyPreprocessorConfig.
type VisionConfig struct {
	Depth                 int32   `json:"depth"`           // number of encoder blocks; vision is disabled when <= 0
	HiddenSize            int32   `json:"hidden_size"`
	NumHeads              int32   `json:"num_heads"`
	InChannels            int32   `json:"in_channels"`
	PatchSize             int32   `json:"patch_size"`
	SpatialMergeSize      int32   `json:"spatial_merge_size"` // patches merged per side into one language token
	LayerNormEpsilon      float32 `json:"layer_norm_epsilon"`
	RopeTheta             float32 `json:"rope_theta"`
	TemporalPatchSize     int32   `json:"temporal_patch_size"`
	NumPositionEmbeddings int32   `json:"num_position_embeddings"` // rows of the learned position table
	// Size bounds the total pixel count of a preprocessed image
	// (see VisionImageProcessor.smartResize).
	Size struct {
		ShortestEdge int32 `json:"shortest_edge"`
		LongestEdge  int32 `json:"longest_edge"`
	} `json:"size"`
	ImageMean []float32 `json:"image_mean"`
	ImageStd  []float32 `json:"image_std"`
	// GridPerSide is derived (sqrt of NumPositionEmbeddings), not serialized.
	GridPerSide int32 `json:"-"`
}
// applyDefaults fills any unset (zero or negative) fields with the standard
// Qwen vision defaults and derives GridPerSide from the position-embedding
// count. Safe to call on a nil receiver and safe to call repeatedly.
func (v *VisionConfig) applyDefaults() {
	if v == nil {
		return
	}
	defaultI32 := func(dst *int32, def int32) {
		if *dst <= 0 {
			*dst = def
		}
	}
	defaultI32(&v.HiddenSize, 1280)
	defaultI32(&v.NumHeads, 16)
	defaultI32(&v.InChannels, 3)
	defaultI32(&v.PatchSize, 14)
	defaultI32(&v.SpatialMergeSize, 2)
	defaultI32(&v.TemporalPatchSize, 2)
	defaultI32(&v.NumPositionEmbeddings, 2304)
	if v.LayerNormEpsilon == 0 {
		v.LayerNormEpsilon = 1e-6
	}
	if v.RopeTheta == 0 {
		v.RopeTheta = 10000
	}
	if len(v.ImageMean) < 3 {
		v.ImageMean = []float32{0.5, 0.5, 0.5}
	}
	if len(v.ImageStd) < 3 {
		v.ImageStd = []float32{0.5, 0.5, 0.5}
	}
	// Size bounds are pixel budgets, not edge lengths.
	defaultI32(&v.Size.ShortestEdge, 64<<10)
	defaultI32(&v.Size.LongestEdge, 2<<20)
	// The learned position table is a square grid of side
	// sqrt(NumPositionEmbeddings); fall back to 48 if that comes out empty.
	side := int32(math.Sqrt(float64(v.NumPositionEmbeddings)))
	if side <= 0 {
		side = 48
	}
	v.GridPerSide = side
}
// applyPreprocessorConfig overlays values from preprocessor_config.json onto
// the config; absent or non-positive fields keep their current values.
// Malformed JSON is ignored (best effort), and defaults are re-applied after
// the overlay so derived fields stay consistent.
func (v *VisionConfig) applyPreprocessorConfig(data []byte) {
	if v == nil || len(data) == 0 {
		return
	}
	var overrides struct {
		Size struct {
			ShortestEdge int32 `json:"shortest_edge"`
			LongestEdge  int32 `json:"longest_edge"`
		} `json:"size"`
		PatchSize         int32     `json:"patch_size"`
		TemporalPatchSize int32     `json:"temporal_patch_size"`
		MergeSize         int32     `json:"merge_size"`
		ImageMean         []float32 `json:"image_mean"`
		ImageStd          []float32 `json:"image_std"`
	}
	if err := json.Unmarshal(data, &overrides); err != nil {
		return
	}
	if overrides.PatchSize > 0 {
		v.PatchSize = overrides.PatchSize
	}
	if overrides.TemporalPatchSize > 0 {
		v.TemporalPatchSize = overrides.TemporalPatchSize
	}
	if overrides.MergeSize > 0 {
		v.SpatialMergeSize = overrides.MergeSize
	}
	if overrides.Size.ShortestEdge > 0 {
		v.Size.ShortestEdge = overrides.Size.ShortestEdge
	}
	if overrides.Size.LongestEdge > 0 {
		v.Size.LongestEdge = overrides.Size.LongestEdge
	}
	if len(overrides.ImageMean) >= 3 {
		v.ImageMean = overrides.ImageMean
	}
	if len(overrides.ImageStd) >= 3 {
		v.ImageStd = overrides.ImageStd
	}
	v.applyDefaults()
}
// VisionGrid tracks patch-grid dimensions for an image.
type VisionGrid struct {
	Height   int32 // patches along the vertical axis
	Width    int32 // patches along the horizontal axis
	Temporal int32 // temporal patches; set to 1 for static images (see ProcessImage)
}
// VisionImageProcessor reproduces qwen3vl image preprocessing.
// Fields are copied from VisionConfig by newVisionImageProcessor.
type VisionImageProcessor struct {
	numChannels       int32
	patchSize         int32
	temporalPatchSize int32
	mergeSize         int32
	shortestEdge      int32 // minimum total pixel count after resize
	longestEdge       int32 // maximum total pixel count after resize
	factor            int32 // patchSize * mergeSize; resized edges are multiples of this
	imageMean         [3]float32
	imageStd          [3]float32
}
// newVisionImageProcessor builds an image processor from the vision config,
// or returns nil when no vision config is available.
func newVisionImageProcessor(cfg *VisionConfig) *VisionImageProcessor {
	if cfg == nil {
		return nil
	}
	proc := &VisionImageProcessor{
		numChannels:       cfg.InChannels,
		patchSize:         cfg.PatchSize,
		temporalPatchSize: cfg.TemporalPatchSize,
		mergeSize:         cfg.SpatialMergeSize,
		shortestEdge:      cfg.Size.ShortestEdge,
		longestEdge:       cfg.Size.LongestEdge,
		// Resized dimensions must be divisible by a full merged-patch cell.
		factor: cfg.PatchSize * cfg.SpatialMergeSize,
	}
	proc.imageMean = [3]float32{cfg.ImageMean[0], cfg.ImageMean[1], cfg.ImageMean[2]}
	proc.imageStd = [3]float32{cfg.ImageStd[0], cfg.ImageStd[1], cfg.ImageStd[2]}
	return proc
}
// smartResize computes output dimensions that are multiples of the
// patch*merge factor while keeping the total pixel count within the
// [shortestEdge, longestEdge] budget (mirroring qwen-vl smart_resize).
// It returns (height, width) in that order, matching the input order.
func (p *VisionImageProcessor) smartResize(height, width int) (int, int, error) {
	factor := int(p.factor)
	if factor <= 0 {
		return 0, 0, fmt.Errorf("invalid factor: %d", factor)
	}
	if height < factor || width < factor {
		return 0, 0, fmt.Errorf("height (%d) or width (%d) must be >= factor (%d)", height, width, factor)
	}
	// Both dimensions are now >= factor >= 1, so a separate zero-dimension
	// check (previously here) is unreachable and has been removed.
	if max(height, width)/min(height, width) > 200 {
		return 0, 0, fmt.Errorf("aspect ratio too large: %dx%d", width, height)
	}
	roundEven := func(x float64) int { return int(math.RoundToEven(x)) }
	// Round each edge to the nearest multiple of factor.
	hBar := roundEven(float64(height)/float64(factor)) * factor
	wBar := roundEven(float64(width)/float64(factor)) * factor
	if hBar*wBar > int(p.longestEdge) {
		// Too many pixels: scale down uniformly, flooring so we stay under
		// the budget.
		beta := math.Sqrt(float64(height*width) / float64(p.longestEdge))
		hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
		wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
	} else if hBar*wBar < int(p.shortestEdge) {
		// Too few pixels: scale up uniformly, ceiling so we meet the minimum.
		beta := math.Sqrt(float64(p.shortestEdge) / float64(height*width))
		hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
		wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
	}
	return hBar, wBar, nil
}
// ProcessImage composites, resizes, and normalizes a decoded image, then
// flattens it into patch vectors. It returns a [1, numPatches, patchDim]
// pixel tensor together with the resulting patch grid.
func (p *VisionImageProcessor) ProcessImage(img image.Image) (*mlx.Array, *VisionGrid, error) {
	if p == nil {
		return nil, nil, errNoVisionModel
	}
	img = imageproc.Composite(img)
	newH, newW, err := p.smartResize(img.Bounds().Dy(), img.Bounds().Dx())
	if err != nil {
		return nil, nil, err
	}
	resized := imageproc.Resize(img, image.Point{X: newW, Y: newH}, imageproc.ResizeBilinear)
	pixels := imageproc.Normalize(resized, p.imageMean, p.imageStd, true, true)
	grid := &VisionGrid{
		Height:   int32(newH / int(p.patchSize)),
		Width:    int32(newW / int(p.patchSize)),
		Temporal: 1,
	}
	patchVectors := p.createPatches(pixels, newH, newW, grid)
	patchDim := int(p.numChannels * p.temporalPatchSize * p.patchSize * p.patchSize)
	patchCount := int(grid.Height * grid.Width)
	pixelValues := mlx.FromValues(patchVectors, patchCount, patchDim).ExpandDims(0)
	return pixelValues, grid, nil
}
// createPatches rearranges normalized CHW pixel data into flattened
// per-patch vectors, ordered so that each mergeSize x mergeSize group of
// neighboring patches is contiguous — the layout groupMergedTokens and the
// patch merger expect. Each patch vector has layout
// [channel][temporal][py][px]; for static images the single frame is
// replicated across the temporal dimension.
func (p *VisionImageProcessor) createPatches(pixels []float32, height, width int, grid *VisionGrid) []float32 {
	channels := int(p.numChannels)
	patchSize := int(p.patchSize)
	mergeSize := int(p.mergeSize)
	temporalPatchSize := int(p.temporalPatchSize)
	// Temporal is always 1 for static images; only spatial patches are created.
	numPatches := int(grid.Height * grid.Width)
	patchDim := channels * temporalPatchSize * patchSize * patchSize
	result := make([]float32, numPatches*patchDim)
	patchIndex := 0
	// Walk merge-group by merge-group so grouped patches end up adjacent in
	// the output.
	for h := 0; h < int(grid.Height); h += mergeSize {
		for w := 0; w < int(grid.Width); w += mergeSize {
			for mh := range mergeSize {
				for mw := range mergeSize {
					baseOffset := patchIndex * patchDim
					for c := range channels {
						channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
						for py := range patchSize {
							for px := range patchSize {
								// Source pixel in the CHW buffer.
								y := (h+mh)*patchSize + py
								x := (w+mw)*patchSize + px
								srcIdx := c*height*width + y*width + x
								dstIdx := channelOffset + py*patchSize + px
								// Bounds guard: out-of-range cells stay zero.
								if srcIdx < len(pixels) && dstIdx < len(result) {
									result[dstIdx] = pixels[srcIdx]
								}
							}
						}
					}
					if temporalPatchSize > 1 {
						// Replicate frame 0 into the remaining temporal slots.
						for c := range channels {
							channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
							frameSize := patchSize * patchSize
							for tp := 1; tp < temporalPatchSize; tp++ {
								cur := channelOffset + tp*frameSize
								copy(result[cur:cur+frameSize], result[channelOffset:channelOffset+frameSize])
							}
						}
					}
					patchIndex++
				}
			}
		}
	}
	return result
}
// VisionAttention runs one self-attention block inside the vision encoder.
// Depending on the checkpoint layout, either the fused QKV projection or the
// separate Query/Key/Value projections is populated (see loadVisionComponents).
type VisionAttention struct {
	QKV    nn.LinearLayer // fused qkv projection; nil when split weights are used
	Query  nn.LinearLayer
	Key    nn.LinearLayer
	Value  nn.LinearLayer
	Output nn.LinearLayer // output projection applied after attention
}
// applyVisionRoPE applies rotary position embedding to x using precomputed
// cos/sin tables: x*cos + rotate_half(x)*sin.
func applyVisionRoPE(x, cos, sin *mlx.Array) *mlx.Array {
	return mlx.Add(mlx.Mul(x, cos), mlx.Mul(rotateHalf(x), sin))
}
// Forward computes multi-head self-attention over x (shape [B, L, D]) with
// rotary position embedding applied to queries and keys. cos/sin come from
// VisionModel.rotaryEmbeddings and broadcast against [B, L, heads, headDim].
func (a *VisionAttention) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
	shape := x.Dims()
	if len(shape) != 3 {
		return nil, fmt.Errorf("vision attention expects [B,L,D], got %v", shape)
	}
	B, L, hidden := int32(shape[0]), int32(shape[1]), int32(shape[2])
	headDim := cfg.HiddenSize / cfg.NumHeads
	if headDim <= 0 {
		return nil, fmt.Errorf("invalid vision head dim: %d", headDim)
	}
	var q, k, v *mlx.Array
	if a.QKV != nil {
		// Fused layout: project once, then slice q/k/v out of axis 2.
		qkv := a.QKV.Forward(x)
		qkv = mlx.Reshape(qkv, B, L, 3, cfg.NumHeads, headDim)
		q = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 0, 0, 0}, []int32{B, L, 1, cfg.NumHeads, headDim}), 2)
		k = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 1, 0, 0}, []int32{B, L, 2, cfg.NumHeads, headDim}), 2)
		v = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 2, 0, 0}, []int32{B, L, 3, cfg.NumHeads, headDim}), 2)
	} else {
		if a.Query == nil || a.Key == nil || a.Value == nil {
			return nil, errors.New("vision attention is missing q/k/v projections")
		}
		q = mlx.Reshape(a.Query.Forward(x), B, L, cfg.NumHeads, headDim)
		k = mlx.Reshape(a.Key.Forward(x), B, L, cfg.NumHeads, headDim)
		v = mlx.Reshape(a.Value.Forward(x), B, L, cfg.NumHeads, headDim)
	}
	q = applyVisionRoPE(q, cos, sin)
	k = applyVisionRoPE(k, cos, sin)
	// [B, L, H, D] -> [B, H, L, D] for the attention kernel.
	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	// NOTE(review): the trailing false presumably disables the causal mask so
	// vision tokens attend bidirectionally — confirm against the mlx API.
	attn := mlx.ScaledDotProductAttentionCausal(q, k, v, scale, false)
	attn = mlx.Reshape(mlx.Transpose(attn, 0, 2, 1, 3), B, L, hidden)
	if a.Output == nil {
		return nil, errors.New("vision attention is missing output projection")
	}
	return a.Output.Forward(attn), nil
}
// VisionMLP is the vision feed-forward block (fc1 -> GELU -> fc2).
type VisionMLP struct {
	FC1 nn.LinearLayer
	FC2 nn.LinearLayer
}
// Forward applies the feed-forward block: fc1, approximate GELU, fc2.
func (m *VisionMLP) Forward(x *mlx.Array) (*mlx.Array, error) {
	if m.FC1 == nil || m.FC2 == nil {
		return nil, errors.New("vision mlp is missing fc1/fc2")
	}
	hidden := mlx.GELUApprox(m.FC1.Forward(x))
	return m.FC2.Forward(hidden), nil
}
// VisionEncoderLayer is one transformer block in the vision encoder:
// pre-norm attention followed by a pre-norm MLP, each with a residual add.
type VisionEncoderLayer struct {
	Norm1 *nn.LayerNorm // applied before attention
	Attn  *VisionAttention
	Norm2 *nn.LayerNorm // applied before the MLP
	MLP   *VisionMLP
}
// Forward runs pre-norm attention and pre-norm MLP, each wrapped in a
// residual connection.
func (l *VisionEncoderLayer) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
	if l.Norm1 == nil || l.Norm2 == nil || l.Attn == nil || l.MLP == nil {
		return nil, errors.New("vision layer is incomplete")
	}
	attnOut, err := l.Attn.Forward(l.Norm1.Forward(x), cos, sin, cfg)
	if err != nil {
		return nil, err
	}
	x = mlx.Add(x, attnOut)
	mlpOut, err := l.MLP.Forward(l.Norm2.Forward(x))
	if err != nil {
		return nil, err
	}
	return mlx.Add(x, mlpOut), nil
}
// VisionPatchMerger projects merged spatial groups into language embedding
// space: layer norm, group adjacent tokens, then fc1 -> GELU -> fc2.
type VisionPatchMerger struct {
	Norm *nn.LayerNorm
	FC1  nn.LinearLayer
	FC2  nn.LinearLayer
}
// groupMergedTokens reshapes [B, L, D] so that each merge*merge run of
// adjacent tokens is concatenated into a single [B, L/(merge*merge),
// merge*merge*D] token, matching the ordering produced by createPatches.
func groupMergedTokens(x *mlx.Array, merge int32) (*mlx.Array, error) {
	dims := x.Dims()
	if len(dims) != 3 {
		return nil, fmt.Errorf("expected [B,L,D], got %v", dims)
	}
	if merge <= 0 {
		merge = 1
	}
	batch, seq, dim := int32(dims[0]), int32(dims[1]), int32(dims[2])
	groupSize := merge * merge
	if groupSize <= 0 || seq%groupSize != 0 {
		return nil, fmt.Errorf("invalid merge layout: L=%d merge=%d", seq, merge)
	}
	grouped := mlx.Reshape(x, batch, seq/groupSize, groupSize, dim)
	return mlx.Reshape(grouped, batch, seq/groupSize, groupSize*dim), nil
}
// Forward normalizes the encoder output, groups merged spatial tokens, and
// projects them into the language embedding space.
func (m *VisionPatchMerger) Forward(x *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
	if m == nil || m.Norm == nil || m.FC1 == nil || m.FC2 == nil {
		return nil, errors.New("vision patch merger is incomplete")
	}
	grouped, err := groupMergedTokens(m.Norm.Forward(x), cfg.SpatialMergeSize)
	if err != nil {
		return nil, err
	}
	return m.FC2.Forward(mlx.GELUApprox(m.FC1.Forward(grouped))), nil
}
// VisionModel contains the full Qwen vision tower.
type VisionModel struct {
	PatchProjection nn.LinearLayer // projects flattened patch vectors into the hidden size
	PositionEmbed   *nn.Embedding  // learned position table, bilinearly interpolated per grid
	Layers          []*VisionEncoderLayer
	PatchMerger     *VisionPatchMerger // final projection into language embedding space
	cfg             *VisionConfig
}
// mergedPatchCoordinates returns (row, col) patch coordinates in the same
// merge-group ordering that createPatches emits, so position and rotary
// embeddings line up with the patch sequence.
func mergedPatchCoordinates(grid *VisionGrid, merge int32) [][2]int32 {
	if merge <= 0 {
		merge = 1
	}
	// Temporal is always 1 for static images; only spatial coordinates are generated.
	out := make([][2]int32, 0, grid.Height*grid.Width)
	for row := int32(0); row < grid.Height; row += merge {
		for col := int32(0); col < grid.Width; col += merge {
			for dr := int32(0); dr < merge; dr++ {
				for dc := int32(0); dc < merge; dc++ {
					out = append(out, [2]int32{row + dr, col + dc})
				}
			}
		}
	}
	return out
}
// addPositionEmbedding adds learned position embeddings to x ([B, L, D]).
// The GridPerSide x GridPerSide position table is resampled to the image's
// patch grid by bilinear interpolation: each patch coordinate is mapped into
// table space and blended from its four surrounding table entries.
// A nil PositionEmbed is a no-op.
func (m *VisionModel) addPositionEmbedding(x *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
	if m.PositionEmbed == nil {
		return x, nil
	}
	shape := x.Dims()
	if len(shape) != 3 {
		return nil, fmt.Errorf("vision embeddings expect [B,L,D], got %v", shape)
	}
	B, D := int32(shape[0]), int32(shape[2])
	coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
	L := int32(len(coords))
	if L != int32(shape[1]) {
		return nil, fmt.Errorf("vision sequence mismatch: hidden L=%d coords=%d", shape[1], L)
	}
	// Scale factors mapping grid coordinates onto the position table.
	stepH := float32(0)
	if grid.Height > 1 {
		stepH = float32(m.cfg.GridPerSide-1) / float32(grid.Height-1)
	}
	stepW := float32(0)
	if grid.Width > 1 {
		stepW = float32(m.cfg.GridPerSide-1) / float32(grid.Width-1)
	}
	// Four table indices and bilinear weights per patch.
	indices := make([]int32, 0, L*4)
	weights := make([]float32, 0, L*4)
	for _, c := range coords {
		y := float32(c[0]) * stepH
		x0 := float32(c[1]) * stepW
		fy := int32(y) // floor cell
		fx := int32(x0)
		cy := min(fy+1, m.cfg.GridPerSide-1) // ceil cell, clamped to the table edge
		cx := min(fx+1, m.cfg.GridPerSide-1)
		indices = append(indices,
			fy*m.cfg.GridPerSide+fx,
			fy*m.cfg.GridPerSide+cx,
			cy*m.cfg.GridPerSide+fx,
			cy*m.cfg.GridPerSide+cx,
		)
		dy := y - float32(fy)
		dx := x0 - float32(fx)
		weights = append(weights,
			(1-dy)*(1-dx),
			(1-dy)*dx,
			dy*(1-dx),
			dy*dx,
		)
	}
	idxArr := mlx.FromValues(indices, int(L), 4)
	wArr := mlx.FromValues(weights, int(L), 4, 1)
	pos := m.PositionEmbed.Forward(idxArr)
	wArr = wArr.AsType(pos.DType())
	// Weighted sum over the 4 neighbors -> one embedding per patch.
	pos = mlx.Sum(mlx.Mul(pos, wArr), 1, false)
	if D != int32(pos.Dim(1)) {
		return nil, fmt.Errorf("position embedding dim mismatch: hidden=%d pos=%d", D, pos.Dim(1))
	}
	pos = mlx.ExpandDims(pos, 0)
	if B > 1 {
		pos = mlx.Tile(pos, []int32{B, 1, 1})
	}
	return mlx.Add(x, pos), nil
}
// rotaryEmbeddings precomputes cos/sin tables for 2D rotary position
// embedding over the patch grid. Within each head, the first quarter of the
// dimensions encodes the row coordinate, the second quarter the column
// coordinate, and the second half repeats the first (the rotate-half layout
// used by applyVisionRoPE). Returned shapes are [1, L, 1, headDim] for
// broadcasting over batch and heads.
func (m *VisionModel) rotaryEmbeddings(grid *VisionGrid) (*mlx.Array, *mlx.Array, error) {
	headDim := m.cfg.HiddenSize / m.cfg.NumHeads
	if headDim <= 0 {
		return nil, nil, fmt.Errorf("invalid vision head dim: %d", headDim)
	}
	coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
	L := int32(len(coords))
	half := headDim / 2
	quarter := half / 2
	if quarter <= 0 {
		return nil, nil, fmt.Errorf("invalid vision rotary layout: head_dim=%d", headDim)
	}
	angles := make([]float32, L*headDim)
	for i, c := range coords {
		base := int32(i) * headDim
		for j := range quarter {
			// Standard RoPE inverse-frequency schedule over half the head dim.
			freq := 1.0 / math.Pow(float64(m.cfg.RopeTheta), float64(2*j)/float64(half))
			angles[base+j] = float32(float64(c[0]) * freq)         // row angle
			angles[base+quarter+j] = float32(float64(c[1]) * freq) // column angle
		}
		// Duplicate the first half into the second half.
		for j := range half {
			angles[base+half+j] = angles[base+j]
		}
	}
	arr := mlx.FromValues(angles, int(L), int(headDim))
	cos := mlx.ExpandDims(mlx.ExpandDims(mlx.Cos(arr), 0), 2)
	sin := mlx.ExpandDims(mlx.ExpandDims(mlx.Sin(arr), 0), 2)
	return cos, sin, nil
}
// Forward runs the full vision tower: patch projection, interpolated
// position embeddings, the encoder stack with 2D RoPE, and the patch merger.
func (m *VisionModel) Forward(pixelValues *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
	if m == nil || pixelValues == nil || grid == nil {
		return nil, errNoVisionModel
	}
	if m.PatchProjection == nil || m.PatchMerger == nil {
		return nil, errors.New("vision model is missing required projections")
	}
	hidden := m.PatchProjection.Forward(pixelValues)
	hidden, err := m.addPositionEmbedding(hidden, grid)
	if err != nil {
		return nil, err
	}
	cos, sin, err := m.rotaryEmbeddings(grid)
	if err != nil {
		return nil, err
	}
	for i, layer := range m.Layers {
		if hidden, err = layer.Forward(hidden, cos, sin, m.cfg); err != nil {
			return nil, fmt.Errorf("vision layer %d: %w", i, err)
		}
	}
	merged, err := m.PatchMerger.Forward(hidden, m.cfg)
	if err != nil {
		return nil, fmt.Errorf("vision patch merger: %w", err)
	}
	return merged, nil
}
// VisionEmbeddings is the output of the vision tower for one image.
type VisionEmbeddings struct {
	Main *mlx.Array  // merged patch embeddings ready for prompt splicing
	Grid *VisionGrid // patch-grid dimensions of the processed image
}
// EncodeVisionImage decodes raw image bytes (JPEG/PNG), preprocesses them,
// and runs the vision tower, returning the merged embeddings and patch grid.
// It returns errNoVisionModel when the model has no vision components.
func (m *Model) EncodeVisionImage(multimodalData []byte) (*VisionEmbeddings, error) {
	if m == nil || m.Vision == nil || m.ImageProcessor == nil {
		return nil, errNoVisionModel
	}
	decoded, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, err
	}
	pixels, grid, err := m.ImageProcessor.ProcessImage(decoded)
	if err != nil {
		return nil, err
	}
	embeddings, err := m.Vision.Forward(pixels, grid)
	if err != nil {
		return nil, err
	}
	return &VisionEmbeddings{Main: embeddings, Grid: grid}, nil
}
// resolveVisionPrefix probes the tensor map for known vision-tower layouts
// and returns the first prefix under which a recognizable vision tensor
// exists, or "" when no vision tensors are present.
func resolveVisionPrefix(tensors map[string]*mlx.Array, weightPrefix string) string {
	// Tensor names that identify a vision tower regardless of layout.
	markers := []string{
		".patch_embed.proj.weight",
		".patch_embed.weight",
		".pos_embed.weight",
		".blocks.0.attn.qkv.weight",
		".merger.linear_fc1.weight",
		".merger.mlp.0.weight",
	}
	for _, prefix := range []string{
		"vision_tower",
		weightPrefix + "vision_tower",
		"model.visual",
		"visual",
		weightPrefix + "model.visual",
		weightPrefix + "visual",
	} {
		for _, marker := range markers {
			if tensors[prefix+marker] != nil {
				return prefix
			}
		}
	}
	return ""
}
// firstLinear returns the first linear layer the factory can build from the
// candidate weight paths, or nil when none resolve.
func firstLinear(linears mlxmodel.LinearFactory, paths ...string) nn.LinearLayer {
	for _, path := range paths {
		if layer := linears.Make(path); layer != nil {
			return layer
		}
	}
	return nil
}
// loadLayerNorm builds a LayerNorm from the first base name with a weight
// tensor, trying the "<base>.weight"/".bias" convention before the bare
// "<base>"/"<base>_bias" one. Returns nil when no candidate matches.
func loadLayerNorm(tensors map[string]*mlx.Array, eps float32, bases ...string) *nn.LayerNorm {
	for _, name := range bases {
		switch {
		case tensors[name+".weight"] != nil:
			return &nn.LayerNorm{Weight: tensors[name+".weight"], Bias: tensors[name+".bias"], Eps: eps}
		case tensors[name] != nil:
			return &nn.LayerNorm{Weight: tensors[name], Bias: tensors[name+"_bias"], Eps: eps}
		}
	}
	return nil
}
// loadVisionPatchMerger assembles the patch merger from the first base name
// under which all three components (norm, fc1, fc2) resolve, covering both
// the linear_fc and mlp.N weight naming schemes. Returns nil if incomplete.
func loadVisionPatchMerger(
	tensors map[string]*mlx.Array,
	linears mlxmodel.LinearFactory,
	eps float32,
	bases ...string,
) *VisionPatchMerger {
	for _, base := range bases {
		merger := &VisionPatchMerger{
			Norm: loadLayerNorm(tensors, eps, base+".norm", base+".ln_q"),
			FC1:  firstLinear(linears, base+".linear_fc1", base+".mlp.0"),
			FC2:  firstLinear(linears, base+".linear_fc2", base+".mlp.2"),
		}
		if merger.Norm != nil && merger.FC1 != nil && merger.FC2 != nil {
			return merger
		}
	}
	return nil
}
// flattenPatchEmbeddingWeight converts a conv-style patch embedding weight
// (e.g. [out, C, T, H, W]) into the 2D [out, in] matrix a linear projection
// expects; 2D weights pass through unchanged.
func flattenPatchEmbeddingWeight(w *mlx.Array) (*mlx.Array, error) {
	switch {
	case w == nil || !w.Valid():
		return nil, errors.New("missing patch embedding weight")
	case w.NumDims() < 2:
		return nil, fmt.Errorf("patch embedding weight must be >=2D, got %dD", w.NumDims())
	case w.NumDims() == 2:
		return w, nil
	}
	rows := int32(w.Dim(0))
	cols := int32(w.Size() / w.Dim(0))
	return mlx.Reshape(w, rows, cols), nil
}
// loadVisionComponents builds the vision tower and image processor from the
// loaded tensor map. It returns (nil, nil, nil) when the config does not
// enable vision, and an error when vision is enabled but required tensors
// are missing. Both legacy ("model.visual"/"visual") and imported
// ("vision_tower") weight layouts are supported via resolveVisionPrefix and
// the alternative tensor names tried below.
func loadVisionComponents(
	tensors map[string]*mlx.Array,
	linears mlxmodel.LinearFactory,
	cfg *Config,
	weightPrefix string,
) (*VisionModel, *VisionImageProcessor, error) {
	// Vision is opt-in: no config or zero depth means a text-only model.
	if cfg == nil || cfg.Vision == nil || cfg.Vision.Depth <= 0 {
		return nil, nil, nil
	}
	cfg.Vision.applyDefaults()
	visionPrefix := resolveVisionPrefix(tensors, weightPrefix)
	if visionPrefix == "" {
		return nil, nil, errors.New("vision enabled in config but vision tensors were not found")
	}
	patchW, _ := tensorAny(
		tensors,
		visionPrefix+".patch_embed.proj.weight",
		visionPrefix+".patch_embed.weight",
	)
	if patchW == nil {
		return nil, nil, fmt.Errorf("missing vision patch embedding weight under %s", visionPrefix)
	}
	// Conv-style weights are flattened into a linear projection matrix.
	patchW, err := flattenPatchEmbeddingWeight(patchW)
	if err != nil {
		return nil, nil, err
	}
	patchB, _ := tensorAny(
		tensors,
		visionPrefix+".patch_embed.proj.bias",
		visionPrefix+".patch_embed.bias",
	)
	patchProj := nn.NewLinear(patchW, patchB)
	// Sanity-check the flattened input dim against C*T*P*P from the config.
	if got := int32(patchW.Dim(1)); got != cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize {
		return nil, nil, fmt.Errorf(
			"vision patch embedding input dim mismatch: got %d expected %d",
			got,
			cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize,
		)
	}
	posW, _ := tensorAny(
		tensors,
		visionPrefix+".pos_embed.weight",
		visionPrefix+".position_embedding.weight",
	)
	if posW == nil {
		return nil, nil, fmt.Errorf("missing vision position embedding under %s", visionPrefix)
	}
	// Trust the checkpoint's table size over the config, then re-derive
	// GridPerSide from it.
	cfg.Vision.NumPositionEmbeddings = int32(posW.Dim(0))
	cfg.Vision.applyDefaults()
	vm := &VisionModel{
		PatchProjection: patchProj,
		PositionEmbed:   nn.NewEmbedding(posW),
		Layers:          make([]*VisionEncoderLayer, cfg.Vision.Depth),
		cfg:             cfg.Vision,
	}
	for i := range cfg.Vision.Depth {
		layerPrefix := fmt.Sprintf("%s.blocks.%d", visionPrefix, i)
		// Each projection tries both naming schemes; fused QKV takes
		// precedence over split q/k/v in VisionAttention.Forward.
		layer := &VisionEncoderLayer{
			Norm1: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm1"),
			Norm2: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm2"),
			Attn: &VisionAttention{
				QKV: firstLinear(
					linears,
					layerPrefix+".attn.qkv",
					layerPrefix+".attn_qkv",
				),
				Query: firstLinear(
					linears,
					layerPrefix+".attn.q_proj",
					layerPrefix+".attn_q",
				),
				Key: firstLinear(
					linears,
					layerPrefix+".attn.k_proj",
					layerPrefix+".attn_k",
				),
				Value: firstLinear(
					linears,
					layerPrefix+".attn.v_proj",
					layerPrefix+".attn_v",
				),
				Output: firstLinear(
					linears,
					layerPrefix+".attn.proj",
					layerPrefix+".attn_out",
					layerPrefix+".attn.o_proj",
				),
			},
			MLP: &VisionMLP{
				FC1: firstLinear(
					linears,
					layerPrefix+".mlp.fc1",
					layerPrefix+".mlp.linear_fc1",
				),
				FC2: firstLinear(
					linears,
					layerPrefix+".mlp.fc2",
					layerPrefix+".mlp.linear_fc2",
				),
			},
		}
		if layer.Norm1 == nil || layer.Norm2 == nil {
			return nil, nil, fmt.Errorf("vision layer %d: missing norm1/norm2", i)
		}
		// Either a fused QKV or a complete q/k/v triple must be present.
		if layer.Attn.Output == nil || (layer.Attn.QKV == nil && (layer.Attn.Query == nil || layer.Attn.Key == nil || layer.Attn.Value == nil)) {
			return nil, nil, fmt.Errorf("vision layer %d: missing attention projections", i)
		}
		if layer.MLP.FC1 == nil || layer.MLP.FC2 == nil {
			return nil, nil, fmt.Errorf("vision layer %d: missing mlp projections", i)
		}
		vm.Layers[i] = layer
	}
	vm.PatchMerger = loadVisionPatchMerger(
		tensors,
		linears,
		cfg.Vision.LayerNormEpsilon,
		visionPrefix+".merger",
	)
	if vm.PatchMerger == nil {
		return nil, nil, fmt.Errorf("missing vision patch merger under %s", visionPrefix)
	}
	return vm, newVisionImageProcessor(cfg.Vision), nil
}