mirror of
https://github.com/ollama/ollama.git
synced 2026-04-19 22:54:32 +02:00
Compare commits
16 Commits
v0.18.2-rc
...
pdevine/qw
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
578c32e42e | ||
|
|
a10d2625ca | ||
|
|
b960d769ad | ||
|
|
455a6099d1 | ||
|
|
7e6e8377eb | ||
|
|
126d8db7f3 | ||
|
|
3f3a24b418 | ||
|
|
96e36c0d90 | ||
|
|
6f8ddbb26b | ||
|
|
b5e7888414 | ||
|
|
eab4d22269 | ||
|
|
5759c2d2d2 | ||
|
|
42b1c2642b | ||
|
|
727d69ddf3 | ||
|
|
f622b0c5fc | ||
|
|
5d0000634c |
@@ -155,7 +155,7 @@ func (s *Server) ollamaProxy() http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
target := envconfig.Host()
|
||||
target := envconfig.ConnectableHost()
|
||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||
|
||||
newProxy := httputil.NewSingleHostReverseProxy(target)
|
||||
|
||||
@@ -2071,7 +2071,7 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "explicit :cloud model without local stub returns not found by default",
|
||||
model: "minimax-m2.5:cloud",
|
||||
model: "minimax-m2.7:cloud",
|
||||
showStatus: http.StatusNotFound,
|
||||
whoamiStatus: http.StatusOK,
|
||||
whoamiResp: api.UserResponse{Name: "testuser"},
|
||||
|
||||
@@ -59,6 +59,7 @@ func (c *Claude) Run(model string, args []string) error {
|
||||
"ANTHROPIC_BASE_URL="+envconfig.Host().String(),
|
||||
"ANTHROPIC_API_KEY=",
|
||||
"ANTHROPIC_AUTH_TOKEN=ollama",
|
||||
"CLAUDE_CODE_ATTRIBUTION_HEADER=0",
|
||||
)
|
||||
|
||||
env = append(env, c.modelEnvVars(model)...)
|
||||
|
||||
@@ -310,7 +310,7 @@ func names(items []ModelItem) []string {
|
||||
func TestBuildModelList_NoExistingModels(t *testing.T) {
|
||||
items, _, _, _ := buildModelList(nil, nil, "")
|
||||
|
||||
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.5:cloud", "glm-4.7-flash", "qwen3.5"}
|
||||
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.7:cloud", "glm-4.7-flash", "qwen3.5"}
|
||||
if diff := cmp.Diff(want, names(items)); diff != "" {
|
||||
t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -338,7 +338,7 @@ func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
||||
got := names(items)
|
||||
|
||||
// Recommended pinned at top (local recs first, then cloud recs when only-local), then installed non-recs
|
||||
want := []string{"glm-4.7-flash", "qwen3.5", "kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.5:cloud", "llama3.2", "qwen2.5"}
|
||||
want := []string{"glm-4.7-flash", "qwen3.5", "kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.7:cloud", "llama3.2", "qwen2.5"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("recs pinned at top, local recs before cloud recs (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -354,7 +354,7 @@ func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
|
||||
got := names(items)
|
||||
|
||||
// All recs pinned at top (cloud before local in mixed case), then non-recs
|
||||
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.5:cloud", "glm-4.7-flash", "qwen3.5", "llama3.2"}
|
||||
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.7:cloud", "glm-4.7-flash", "qwen3.5", "llama3.2"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("recs pinned at top, cloud recs first in mixed case (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -392,7 +392,7 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
|
||||
if !strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
case "minimax-m2.5:cloud", "kimi-k2.5:cloud", "qwen3.5:cloud":
|
||||
case "minimax-m2.7:cloud", "kimi-k2.5:cloud", "qwen3.5:cloud":
|
||||
if strings.HasSuffix(item.Description, "(not downloaded)") {
|
||||
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
|
||||
}
|
||||
@@ -412,7 +412,7 @@ func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
|
||||
// glm-4.7-flash and glm-5:cloud are installed so they sort normally;
|
||||
// kimi-k2.5:cloud, qwen3.5:cloud, and qwen3.5 are not installed so they go to the bottom
|
||||
// All recs: cloud first in mixed case, then local, in rec order within each
|
||||
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.5:cloud", "glm-4.7-flash", "qwen3.5"}
|
||||
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.7:cloud", "glm-4.7-flash", "qwen3.5"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("all recs, cloud first in mixed case (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -430,7 +430,7 @@ func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *tes
|
||||
// kimi-k2.5:cloud is installed so it sorts normally;
|
||||
// the rest of the recommendations are not installed so they go to the bottom
|
||||
// All recs pinned at top (cloud first in mixed case), then non-recs
|
||||
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.5:cloud", "glm-4.7-flash", "qwen3.5", "llama3.2"}
|
||||
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5:cloud", "minimax-m2.7:cloud", "glm-4.7-flash", "qwen3.5", "llama3.2"}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("recs pinned at top, cloud first in mixed case (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -583,7 +583,7 @@ func TestBuildModelList_RecsAboveNonRecs(t *testing.T) {
|
||||
lastRecIdx := -1
|
||||
firstNonRecIdx := len(got)
|
||||
for i, name := range got {
|
||||
isRec := name == "glm-4.7-flash" || name == "qwen3.5" || name == "minimax-m2.5:cloud" || name == "glm-5:cloud" || name == "kimi-k2.5:cloud" || name == "qwen3.5:cloud"
|
||||
isRec := name == "glm-4.7-flash" || name == "qwen3.5" || name == "minimax-m2.7:cloud" || name == "glm-5:cloud" || name == "kimi-k2.5:cloud" || name == "qwen3.5:cloud"
|
||||
if isRec && i > lastRecIdx {
|
||||
lastRecIdx = i
|
||||
}
|
||||
|
||||
@@ -413,9 +413,6 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
|
||||
return "", err
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Headless mode: auto-selected last used model %q\n", current)
|
||||
if err := config.SetLastModel(current); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return current, nil
|
||||
}
|
||||
|
||||
@@ -428,9 +425,6 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
|
||||
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := config.SetLastModel(current); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return current, nil
|
||||
}
|
||||
}
|
||||
@@ -439,8 +433,10 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := config.SetLastModel(model); err != nil {
|
||||
return "", err
|
||||
if model != current {
|
||||
if err := config.SetLastModel(model); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return model, nil
|
||||
}
|
||||
@@ -475,8 +471,10 @@ func (c *launcherClient) launchSingleIntegration(ctx context.Context, name strin
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
if target != current {
|
||||
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
||||
return fmt.Errorf("failed to save: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return launchAfterConfiguration(name, runner, target, req)
|
||||
|
||||
@@ -951,7 +951,7 @@ func TestLaunchIntegration_OpenclawInstallsBeforeConfigSideEffects(t *testing.T)
|
||||
if err == nil {
|
||||
t.Fatal("expected launch to fail before configuration when OpenClaw is missing")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "npm was not found") {
|
||||
if !strings.Contains(err.Error(), "required dependencies are missing") {
|
||||
t.Fatalf("expected install prerequisite error, got %v", err)
|
||||
}
|
||||
if selectorCalled {
|
||||
|
||||
@@ -24,7 +24,7 @@ var recommendedModels = []ModelItem{
|
||||
{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
|
||||
{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
|
||||
{Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true},
|
||||
{Name: "minimax-m2.5:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
|
||||
{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
|
||||
{Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
|
||||
{Name: "qwen3.5", Description: "Reasoning, coding, and visual understanding locally", Recommended: true},
|
||||
}
|
||||
@@ -43,7 +43,7 @@ type cloudModelLimit struct {
|
||||
// cloudModelLimits maps cloud model base names to their token limits.
|
||||
// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
|
||||
var cloudModelLimits = map[string]cloudModelLimit{
|
||||
"minimax-m2.5": {Context: 204_800, Output: 128_000},
|
||||
"minimax-m2.7": {Context: 204_800, Output: 128_000},
|
||||
"cogito-2.1:671b": {Context: 163_840, Output: 65_536},
|
||||
"deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
|
||||
"deepseek-v3.2": {Context: 163_840, Output: 65_536},
|
||||
@@ -92,6 +92,10 @@ func OpenBrowser(url string) {
|
||||
case "darwin":
|
||||
_ = exec.Command("open", url).Start()
|
||||
case "linux":
|
||||
// Skip on headless systems where no display server is available
|
||||
if os.Getenv("DISPLAY") == "" && os.Getenv("WAYLAND_DISPLAY") == "" {
|
||||
return
|
||||
}
|
||||
_ = exec.Command("xdg-open", url).Start()
|
||||
case "windows":
|
||||
_ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
|
||||
|
||||
@@ -429,13 +429,17 @@ func ensureOpenclawInstalled() (string, error) {
|
||||
return "clawdbot", nil
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("npm"); err != nil {
|
||||
return "", fmt.Errorf("openclaw is not installed and npm was not found\n\n" +
|
||||
"Install Node.js first:\n" +
|
||||
" https://nodejs.org/\n\n" +
|
||||
"Then rerun:\n" +
|
||||
" ollama launch\n" +
|
||||
"and select OpenClaw")
|
||||
_, npmErr := exec.LookPath("npm")
|
||||
_, gitErr := exec.LookPath("git")
|
||||
if npmErr != nil || gitErr != nil {
|
||||
var missing []string
|
||||
if npmErr != nil {
|
||||
missing = append(missing, "npm (Node.js): https://nodejs.org/")
|
||||
}
|
||||
if gitErr != nil {
|
||||
missing = append(missing, "git: https://git-scm.com/")
|
||||
}
|
||||
return "", fmt.Errorf("openclaw is not installed and required dependencies are missing\n\nInstall the following first:\n %s", strings.Join(missing, "\n "))
|
||||
}
|
||||
|
||||
ok, err := ConfirmPrompt("OpenClaw is not installed. Install with npm?")
|
||||
@@ -605,6 +609,8 @@ func clearSessionModelOverride(primary string) {
|
||||
if override, _ := sess["modelOverride"].(string); override != "" && override != primary {
|
||||
delete(sess, "modelOverride")
|
||||
delete(sess, "providerOverride")
|
||||
}
|
||||
if model, _ := sess["model"].(string); model != "" && model != primary {
|
||||
sess["model"] = primary
|
||||
changed = true
|
||||
}
|
||||
|
||||
@@ -1376,7 +1376,7 @@ func TestOpenclawModelConfig(t *testing.T) {
|
||||
// report it as a remote/cloud model
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/show" {
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.5"}`)
|
||||
fmt.Fprintf(w, `{"capabilities":[],"model_info":{},"remote_model":"minimax-m2.7"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
@@ -1386,7 +1386,7 @@ func TestOpenclawModelConfig(t *testing.T) {
|
||||
u, _ := url.Parse(srv.URL)
|
||||
client := api.NewClient(u, srv.Client())
|
||||
|
||||
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.5:cloud")
|
||||
cfg, isCloud := openclawModelConfig(context.Background(), client, "minimax-m2.7:cloud")
|
||||
|
||||
if !isCloud {
|
||||
t.Error("expected isCloud = true for cloud model")
|
||||
@@ -1768,3 +1768,124 @@ func TestRegisterWebSearchPlugin(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearSessionModelOverride(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
sessionsDir := filepath.Join(tmpDir, ".openclaw", "agents", "main", "sessions")
|
||||
sessionsPath := filepath.Join(sessionsDir, "sessions.json")
|
||||
|
||||
writeSessionsFile := func(t *testing.T, sessions map[string]map[string]any) {
|
||||
t.Helper()
|
||||
if err := os.MkdirAll(sessionsDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := json.Marshal(sessions)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(sessionsPath, data, 0o600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
readSessionsFile := func(t *testing.T) map[string]map[string]any {
|
||||
t.Helper()
|
||||
data, err := os.ReadFile(sessionsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("reading sessions file: %v", err)
|
||||
}
|
||||
var sessions map[string]map[string]any
|
||||
if err := json.Unmarshal(data, &sessions); err != nil {
|
||||
t.Fatalf("parsing sessions file: %v", err)
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
|
||||
t.Run("clears modelOverride and updates model", func(t *testing.T) {
|
||||
writeSessionsFile(t, map[string]map[string]any{
|
||||
"sess1": {"model": "ollama/old-model", "modelOverride": "old-model", "providerOverride": "ollama"},
|
||||
})
|
||||
clearSessionModelOverride("new-model")
|
||||
sessions := readSessionsFile(t)
|
||||
sess := sessions["sess1"]
|
||||
if _, ok := sess["modelOverride"]; ok {
|
||||
t.Error("modelOverride should have been deleted")
|
||||
}
|
||||
if _, ok := sess["providerOverride"]; ok {
|
||||
t.Error("providerOverride should have been deleted")
|
||||
}
|
||||
if sess["model"] != "new-model" {
|
||||
t.Errorf("model = %q, want %q", sess["model"], "new-model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updates model field in sessions without modelOverride", func(t *testing.T) {
|
||||
// This is the bug case: session has model pointing to old primary,
|
||||
// but no explicit modelOverride. After changing primary, the session
|
||||
// model field must also be updated.
|
||||
writeSessionsFile(t, map[string]map[string]any{
|
||||
"sess1": {"model": "ollama/old-model"},
|
||||
})
|
||||
clearSessionModelOverride("new-model")
|
||||
sessions := readSessionsFile(t)
|
||||
if sessions["sess1"]["model"] != "new-model" {
|
||||
t.Errorf("model = %q, want %q", sessions["sess1"]["model"], "new-model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("does not update session already using primary", func(t *testing.T) {
|
||||
writeSessionsFile(t, map[string]map[string]any{
|
||||
"sess1": {"model": "current-model"},
|
||||
})
|
||||
clearSessionModelOverride("current-model")
|
||||
sessions := readSessionsFile(t)
|
||||
if sessions["sess1"]["model"] != "current-model" {
|
||||
t.Errorf("model = %q, want %q", sessions["sess1"]["model"], "current-model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("does not update session with empty model field", func(t *testing.T) {
|
||||
writeSessionsFile(t, map[string]map[string]any{
|
||||
"sess1": {"other": "data"},
|
||||
})
|
||||
clearSessionModelOverride("new-model")
|
||||
sessions := readSessionsFile(t)
|
||||
if _, ok := sessions["sess1"]["model"]; ok {
|
||||
t.Error("model field should not have been added to session with no model")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles multiple sessions mixed", func(t *testing.T) {
|
||||
writeSessionsFile(t, map[string]map[string]any{
|
||||
"with-override": {"model": "old", "modelOverride": "old", "providerOverride": "ollama"},
|
||||
"without-override": {"model": "old"},
|
||||
"already-current": {"model": "new-model"},
|
||||
"no-model": {"other": "data"},
|
||||
})
|
||||
clearSessionModelOverride("new-model")
|
||||
sessions := readSessionsFile(t)
|
||||
|
||||
if sessions["with-override"]["model"] != "new-model" {
|
||||
t.Errorf("with-override model = %q, want %q", sessions["with-override"]["model"], "new-model")
|
||||
}
|
||||
if _, ok := sessions["with-override"]["modelOverride"]; ok {
|
||||
t.Error("with-override: modelOverride should be deleted")
|
||||
}
|
||||
if sessions["without-override"]["model"] != "new-model" {
|
||||
t.Errorf("without-override model = %q, want %q", sessions["without-override"]["model"], "new-model")
|
||||
}
|
||||
if sessions["already-current"]["model"] != "new-model" {
|
||||
t.Errorf("already-current model = %q, want %q", sessions["already-current"]["model"], "new-model")
|
||||
}
|
||||
if _, ok := sessions["no-model"]["model"]; ok {
|
||||
t.Error("no-model: model should not have been added")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no-op when sessions file missing", func(t *testing.T) {
|
||||
os.RemoveAll(sessionsDir)
|
||||
clearSessionModelOverride("new-model") // should not panic or error
|
||||
})
|
||||
}
|
||||
|
||||
@@ -97,11 +97,8 @@ func renderSignIn(modelName, signInURL string, spinner, width int) string {
|
||||
|
||||
fmt.Fprintf(&s, "To use %s, please sign in.\n\n", selectorSelectedItemStyle.Render(modelName))
|
||||
|
||||
// Wrap in OSC 8 hyperlink so the entire URL is clickable even when wrapped.
|
||||
// Padding is outside the hyperlink so spaces don't get underlined.
|
||||
link := fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", signInURL, urlColor.Render(signInURL))
|
||||
s.WriteString("Navigate to:\n")
|
||||
s.WriteString(urlWrap.Render(link))
|
||||
s.WriteString(urlWrap.Render(urlColor.Render(signInURL)))
|
||||
s.WriteString("\n\n")
|
||||
|
||||
s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
|
||||
|
||||
@@ -25,22 +25,6 @@ func TestRenderSignIn_ContainsURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_OSC8Hyperlink(t *testing.T) {
|
||||
url := "https://ollama.com/connect?key=abc123"
|
||||
got := renderSignIn("test:cloud", url, 0, 120)
|
||||
|
||||
// Should contain OSC 8 open sequence with the URL
|
||||
osc8Open := "\033]8;;" + url + "\033\\"
|
||||
if !strings.Contains(got, osc8Open) {
|
||||
t.Error("should contain OSC 8 open sequence with URL")
|
||||
}
|
||||
|
||||
// Should contain OSC 8 close sequence
|
||||
osc8Close := "\033]8;;\033\\"
|
||||
if !strings.Contains(got, osc8Close) {
|
||||
t.Error("should contain OSC 8 close sequence")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderSignIn_ContainsSpinner(t *testing.T) {
|
||||
got := renderSignIn("test:cloud", "https://example.com", 0, 80)
|
||||
|
||||
@@ -41,13 +41,27 @@ ollama launch claude --model kimi-k2.5:cloud
|
||||
|
||||
- `kimi-k2.5:cloud`
|
||||
- `glm-5:cloud`
|
||||
- `minimax-m2.5:cloud`
|
||||
- `minimax-m2.7:cloud`
|
||||
- `qwen3.5:cloud`
|
||||
- `glm-4.7-flash`
|
||||
- `qwen3.5`
|
||||
|
||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||
|
||||
## Non-interactive (headless) mode
|
||||
|
||||
Run Claude Code without interaction for use in Docker, CI/CD, or scripts:
|
||||
|
||||
```shell
|
||||
ollama launch claude --model kimi-k2.5:cloud --yes -- -p "how does this repository work?"
|
||||
```
|
||||
|
||||
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified. Arguments after `--` are passed directly to Claude Code.
|
||||
|
||||
## Web search
|
||||
|
||||
Claude Code can search the web through Ollama's web search API. See the [web search documentation](/capabilities/web-search) for setup and usage.
|
||||
|
||||
## Scheduled Tasks with `/loop`
|
||||
|
||||
The `/loop` command runs a prompt or slash command on a recurring schedule inside Claude Code. This is useful for automating repetitive tasks like checking PRs, running research, or setting reminders.
|
||||
|
||||
@@ -15,13 +15,29 @@ Ollama handles everything automatically:
|
||||
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
|
||||
2. **Security** — On the first launch, a security notice explains the risks of tool access
|
||||
3. **Model** — Pick a model from the selector (local or cloud)
|
||||
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, and sets your model as the primary
|
||||
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, sets your model as the primary, and installs the web search and fetch plugin
|
||||
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
|
||||
|
||||
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
|
||||
|
||||
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
|
||||
|
||||
## Web search and fetch
|
||||
|
||||
OpenClaw ships with a web search and fetch plugin that gives local or cloud models the ability to search the web and extract readable page content.
|
||||
|
||||
```bash
|
||||
ollama launch openclaw
|
||||
```
|
||||
|
||||
Web search and fetch is enabled automatically when launching OpenClaw through Ollama. To install the plugin directly:
|
||||
|
||||
```bash
|
||||
openclaw plugins install @ollama/openclaw-web-search
|
||||
```
|
||||
|
||||
<Note>Web search for local models requires `ollama signin`.</Note>
|
||||
|
||||
## Configure without launching
|
||||
|
||||
To change the model without starting the gateway and TUI:
|
||||
@@ -43,7 +59,7 @@ If the gateway is already running, it restarts automatically to pick up the new
|
||||
**Cloud models**:
|
||||
|
||||
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
|
||||
- `minimax-m2.5:cloud` — Fast, efficient coding and real-world productivity
|
||||
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
|
||||
- `glm-5:cloud` — Reasoning and code generation
|
||||
|
||||
**Local models:**
|
||||
@@ -52,6 +68,16 @@ If the gateway is already running, it restarts automatically to pick up the new
|
||||
|
||||
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
|
||||
|
||||
## Non-interactive (headless) mode
|
||||
|
||||
Run OpenClaw without interaction for use in Docker, CI/CD, or scripts:
|
||||
|
||||
```bash
|
||||
ollama launch openclaw --model kimi-k2.5:cloud --yes
|
||||
```
|
||||
|
||||
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified.
|
||||
|
||||
## Connect messaging apps
|
||||
|
||||
```bash
|
||||
|
||||
@@ -59,6 +59,29 @@ func Host() *url.URL {
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectableHost returns Host() with unspecified bind addresses (0.0.0.0, ::)
|
||||
// replaced by the corresponding loopback address (127.0.0.1, ::1).
|
||||
// Unspecified addresses are valid for binding a server socket but not for
|
||||
// connecting as a client, which fails on Windows.
|
||||
func ConnectableHost() *url.URL {
|
||||
u := Host()
|
||||
host, port, err := net.SplitHostPort(u.Host)
|
||||
if err != nil {
|
||||
return u
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil && ip.IsUnspecified() {
|
||||
if ip.To4() != nil {
|
||||
host = "127.0.0.1"
|
||||
} else {
|
||||
host = "::1"
|
||||
}
|
||||
u.Host = net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
|
||||
func AllowedOrigins() (origins []string) {
|
||||
if s := Var("OLLAMA_ORIGINS"); s != "" {
|
||||
|
||||
@@ -52,6 +52,37 @@ func TestHost(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectableHost(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
value string
|
||||
expect string
|
||||
}{
|
||||
"empty": {"", "http://127.0.0.1:11434"},
|
||||
"localhost": {"127.0.0.1", "http://127.0.0.1:11434"},
|
||||
"localhost and port": {"127.0.0.1:1234", "http://127.0.0.1:1234"},
|
||||
"ipv4 unspecified": {"0.0.0.0", "http://127.0.0.1:11434"},
|
||||
"ipv4 unspecified + port": {"0.0.0.0:1234", "http://127.0.0.1:1234"},
|
||||
"ipv6 unspecified": {"[::]", "http://[::1]:11434"},
|
||||
"ipv6 unspecified + port": {"[::]:1234", "http://[::1]:1234"},
|
||||
"ipv6 localhost": {"[::1]", "http://[::1]:11434"},
|
||||
"ipv6 localhost + port": {"[::1]:1234", "http://[::1]:1234"},
|
||||
"specific address": {"192.168.1.5", "http://192.168.1.5:11434"},
|
||||
"specific address + port": {"192.168.1.5:8080", "http://192.168.1.5:8080"},
|
||||
"hostname": {"example.com", "http://example.com:11434"},
|
||||
"hostname and port": {"example.com:1234", "http://example.com:1234"},
|
||||
"https unspecified + port": {"https://0.0.0.0:4321", "https://127.0.0.1:4321"},
|
||||
}
|
||||
|
||||
for name, tt := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", tt.value)
|
||||
if host := ConnectableHost(); host.String() != tt.expect {
|
||||
t.Errorf("%s: expected %s, got %s", name, tt.expect, host.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrigins(t *testing.T) {
|
||||
cases := []struct {
|
||||
value string
|
||||
|
||||
@@ -345,44 +345,163 @@ func escapeGLM46Content(s string) string {
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// repairUnclosedArgValues inserts missing </arg_value> closing tags.
|
||||
// GLM models sometimes omit the closing tag, producing XML like:
|
||||
// repairPhase represents the expected next tag in the repair cycle.
|
||||
type repairPhase int
|
||||
|
||||
const (
|
||||
phaseArgKeyOpen repairPhase = iota // expecting <arg_key>
|
||||
phaseArgKeyClose // expecting </arg_key>
|
||||
phaseArgValOpen // expecting <arg_value>
|
||||
phaseArgValClose // expecting </arg_value>
|
||||
phaseCount // number of phases
|
||||
)
|
||||
|
||||
// repairGLM46XML reconstructs well-formed XML from GLM model output that may
|
||||
// have missing or mismatched tags. The expected structure is:
|
||||
//
|
||||
// <arg_value>value</tool_call>
|
||||
// func_name
|
||||
// <arg_key>key</arg_key>
|
||||
// <arg_value>value</arg_value>
|
||||
// ...
|
||||
//
|
||||
// instead of:
|
||||
//
|
||||
// <arg_value>value</arg_value></tool_call>
|
||||
func repairUnclosedArgValues(s string) string {
|
||||
// GLM models frequently omit opening or closing tags. This function follows
|
||||
// the expected tag cycle, scanning forward for each expected tag in sequence.
|
||||
// When a tag is missing, it inserts the tag and consumes any text in between.
|
||||
func repairGLM46XML(s string) string {
|
||||
// tagCycle is the repeating sequence of tags after the function name.
|
||||
tagCycle := [phaseCount]string{"<arg_key>", "</arg_key>", "<arg_value>", "</arg_value>"}
|
||||
|
||||
// findNextTag returns the index and identity of the earliest known tag in s.
|
||||
findNextTag := func(s string) (int, string) {
|
||||
bestIdx := -1
|
||||
bestTag := ""
|
||||
for _, tag := range tagCycle {
|
||||
if idx := strings.Index(s, tag); idx != -1 && (bestIdx == -1 || idx < bestIdx) {
|
||||
bestIdx = idx
|
||||
bestTag = tag
|
||||
}
|
||||
}
|
||||
return bestIdx, bestTag
|
||||
}
|
||||
|
||||
// tagIndex returns the phase corresponding to the given tag.
|
||||
tagIndex := func(tag string) repairPhase {
|
||||
for i, t := range tagCycle {
|
||||
if t == tag {
|
||||
return repairPhase(i)
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for {
|
||||
openIdx := strings.Index(s, "<arg_value>")
|
||||
if openIdx == -1 {
|
||||
|
||||
idx, firstTag := findNextTag(s)
|
||||
if idx == -1 {
|
||||
return s
|
||||
}
|
||||
prefix := s[:idx]
|
||||
s = s[idx:]
|
||||
|
||||
// If the first tag is not <arg_key>, the text before it may contain both
|
||||
// the function name and key content (e.g. "weather city</arg_key>").
|
||||
// Function names cannot contain space, so split at the first space.
|
||||
phase := phaseArgKeyOpen
|
||||
if firstTag != "<arg_key>" {
|
||||
if spIdx := strings.IndexFunc(prefix, unicode.IsSpace); spIdx != -1 {
|
||||
result.WriteString(prefix[:spIdx])
|
||||
keyContent := strings.TrimLeftFunc(prefix[spIdx:], unicode.IsSpace)
|
||||
result.WriteString("<arg_key>")
|
||||
result.WriteString(keyContent)
|
||||
phase = phaseArgKeyClose
|
||||
} else {
|
||||
result.WriteString(prefix)
|
||||
}
|
||||
} else {
|
||||
result.WriteString(prefix)
|
||||
}
|
||||
|
||||
// Walk through the expected tag cycle. At each step, look for the
|
||||
// expected tag. If a different tag appears first, emit the missing
|
||||
// tags to catch up, then continue.
|
||||
for len(s) > 0 {
|
||||
idx, found := findNextTag(s)
|
||||
expected := tagCycle[phase]
|
||||
isOpen := phase%2 == 0 // even phases are opening tags
|
||||
|
||||
if idx == -1 {
|
||||
// No more tags — emit remaining text with fixups
|
||||
if isOpen {
|
||||
// Expecting an opening tag but nothing left — we're done
|
||||
break
|
||||
}
|
||||
// Expecting a closing tag — emit text then close
|
||||
result.WriteString(s)
|
||||
result.WriteString(expected)
|
||||
phase = (phase + 1) % phaseCount
|
||||
break
|
||||
}
|
||||
afterOpen := openIdx + len("<arg_value>")
|
||||
closeIdx := strings.Index(s[afterOpen:], "</arg_value>")
|
||||
nextKeyIdx := strings.Index(s[afterOpen:], "<arg_key>")
|
||||
// Check if properly closed before the next <arg_key> (or no next key)
|
||||
if closeIdx != -1 && (nextKeyIdx == -1 || closeIdx < nextKeyIdx) {
|
||||
end := afterOpen + closeIdx + len("</arg_value>")
|
||||
result.WriteString(s[:end])
|
||||
s = s[end:]
|
||||
|
||||
if found == expected {
|
||||
// Found the expected tag — emit any text before it, then the tag
|
||||
result.WriteString(s[:idx])
|
||||
result.WriteString(expected)
|
||||
s = s[idx+len(expected):]
|
||||
phase = (phase + 1) % phaseCount
|
||||
continue
|
||||
}
|
||||
// Unclosed — insert </arg_value> before the next <arg_key> or at end
|
||||
if nextKeyIdx != -1 {
|
||||
insertAt := afterOpen + nextKeyIdx
|
||||
result.WriteString(s[:insertAt])
|
||||
result.WriteString("</arg_value>")
|
||||
s = s[insertAt:]
|
||||
} else {
|
||||
result.WriteString(s)
|
||||
result.WriteString("</arg_value>")
|
||||
break
|
||||
|
||||
// Found a different tag. Insert missing tags to catch up.
|
||||
foundIdx := tagIndex(found)
|
||||
|
||||
if isOpen && idx > 0 {
|
||||
// Text before the found tag while expecting an opening tag —
|
||||
// the opening tag was omitted. Emit it before the text.
|
||||
result.WriteString(expected)
|
||||
// Advance to the next phase (text content) and then look
|
||||
// for the closing tag — but the found tag might be that
|
||||
// closing tag or something further ahead. Emit text up to
|
||||
// the found tag and insert any missing tags between.
|
||||
result.WriteString(s[:idx])
|
||||
phase = (phase + 1) % phaseCount // now expecting closing
|
||||
s = s[idx:]
|
||||
// Fall through to re-evaluate with the closing tag expected
|
||||
continue
|
||||
}
|
||||
|
||||
// Emit missing tags to advance from current phase to the found tag's phase
|
||||
for phase != foundIdx {
|
||||
tag := tagCycle[phase]
|
||||
if phase%2 == 0 {
|
||||
result.WriteString(tag)
|
||||
} else {
|
||||
// Closing tag — emit any text before the found tag first,
|
||||
// but only if we're one step before the found tag
|
||||
if (phase+1)%phaseCount == foundIdx && idx > 0 {
|
||||
result.WriteString(s[:idx])
|
||||
s = s[idx:]
|
||||
idx = 0
|
||||
}
|
||||
result.WriteString(tag)
|
||||
}
|
||||
phase = (phase + 1) % phaseCount
|
||||
}
|
||||
// Now phase == foundIdx, re-process without advancing s
|
||||
}
|
||||
|
||||
// If we stopped mid-pair (after an opening tag), close it
|
||||
switch phase {
|
||||
case phaseArgKeyClose: // after <arg_key>, expecting text/</arg_key>
|
||||
result.WriteString("</arg_key>")
|
||||
result.WriteString("<arg_value>")
|
||||
result.WriteString("</arg_value>")
|
||||
case phaseArgValOpen: // after </arg_key>, expecting <arg_value>
|
||||
result.WriteString("<arg_value>")
|
||||
result.WriteString("</arg_value>")
|
||||
case phaseArgValClose: // after <arg_value>, expecting text/</arg_value>
|
||||
result.WriteString("</arg_value>")
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
@@ -398,7 +517,7 @@ func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCa
|
||||
var parsed GLMToolCallXML
|
||||
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
|
||||
parsed = GLMToolCallXML{}
|
||||
repaired := "<tool_call>" + repairUnclosedArgValues(escaped) + "</tool_call>"
|
||||
repaired := "<tool_call>" + repairGLM46XML(escaped) + "</tool_call>"
|
||||
if err2 := xml.Unmarshal([]byte(repaired), &parsed); err2 != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
|
||||
}
|
||||
|
||||
@@ -887,6 +887,28 @@ line3</arg_value>`,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unopened arg_value after arg_key",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: "get-weather\n<arg_key>city</arg_key>\nNew York</arg_value>\n<arg_key>unit</arg_key>\ncelsius</arg_value>",
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get-weather",
|
||||
Arguments: args(`{"city": "New York", "unit": "celsius"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mixed unopened and valid arg_values",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: "get-weather\n<arg_key>city</arg_key>\n<arg_value>Paris</arg_value>\n<arg_key>unit</arg_key>\ncelsius</arg_value>",
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get-weather",
|
||||
Arguments: args(`{"city": "Paris", "unit": "celsius"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range cases {
|
||||
@@ -902,7 +924,7 @@ line3</arg_value>`,
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairUnclosedArgValues(t *testing.T) {
|
||||
func TestRepairGLM46XML(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
@@ -910,33 +932,63 @@ func TestRepairUnclosedArgValues(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "already valid",
|
||||
input: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||
want: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||
input: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||
},
|
||||
{
|
||||
name: "unclosed at end",
|
||||
input: `<arg_key>k</arg_key><arg_value>v`,
|
||||
want: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||
name: "missing </arg_value> at end",
|
||||
input: `func<arg_key>k</arg_key><arg_value>v`,
|
||||
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||
},
|
||||
{
|
||||
name: "unclosed before next arg_key",
|
||||
input: `<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||
want: `<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||
name: "missing </arg_value> before next arg_key",
|
||||
input: `func<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||
want: `func<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||
},
|
||||
{
|
||||
name: "no arg_value tags",
|
||||
name: "no tags at all",
|
||||
input: `just plain text`,
|
||||
want: `just plain text`,
|
||||
},
|
||||
{
|
||||
name: "multiple unclosed",
|
||||
input: `<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2`,
|
||||
want: `<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||
name: "missing <arg_value> open tag",
|
||||
input: `func<arg_key>k</arg_key>v</arg_value>`,
|
||||
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||
},
|
||||
{
|
||||
name: "missing </arg_key> close tag",
|
||||
input: `func<arg_key>k<arg_value>v</arg_value>`,
|
||||
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||
},
|
||||
{
|
||||
name: "missing <arg_key> open tag",
|
||||
input: `func k</arg_key><arg_value>v</arg_value>`,
|
||||
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||
},
|
||||
{
|
||||
name: "all closing tags missing",
|
||||
input: `func<arg_key>k<arg_value>v`,
|
||||
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||
},
|
||||
{
|
||||
name: "all opening tags missing",
|
||||
input: "func k</arg_key>v</arg_value>",
|
||||
want: "func<arg_key>k</arg_key><arg_value>v</arg_value>",
|
||||
},
|
||||
{
|
||||
name: "multiple pairs with mixed missing tags",
|
||||
input: `func<arg_key>a</arg_key>1</arg_value><arg_key>b<arg_value>2</arg_value>`,
|
||||
want: `func<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||
},
|
||||
{
|
||||
name: "newlines preserved",
|
||||
input: "func\n<arg_key>city</arg_key>\nNew York</arg_value>",
|
||||
want: "func\n<arg_key>city</arg_key><arg_value>\nNew York</arg_value>",
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := repairUnclosedArgValues(tc.input)
|
||||
got := repairGLM46XML(tc.input)
|
||||
if got != tc.want {
|
||||
t.Errorf("got %q, want %q", got, tc.want)
|
||||
}
|
||||
|
||||
@@ -134,14 +134,18 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
||||
spinnerKey = "create"
|
||||
capabilities = []string{"completion"}
|
||||
|
||||
// Check if model supports thinking based on architecture
|
||||
if supportsThinking(opts.ModelDir) {
|
||||
configData, _ := os.ReadFile(filepath.Join(opts.ModelDir, "config.json"))
|
||||
mcfg := parseModelConfig(configData)
|
||||
|
||||
if mcfg.supportsThinking() {
|
||||
capabilities = append(capabilities, "thinking")
|
||||
}
|
||||
if mcfg.supportsVision() {
|
||||
capabilities = append(capabilities, "vision")
|
||||
}
|
||||
|
||||
// Set parser and renderer name based on architecture
|
||||
parserName = getParserName(opts.ModelDir)
|
||||
rendererName = getRendererName(opts.ModelDir)
|
||||
parserName = mcfg.parserName()
|
||||
rendererName = mcfg.rendererName()
|
||||
} else {
|
||||
modelType = "image generation model"
|
||||
spinnerKey = "imagegen"
|
||||
@@ -438,145 +442,76 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
// supportsThinking checks if the model supports thinking mode based on its architecture.
|
||||
// This reads the config.json from the model directory and checks the architectures field.
|
||||
func supportsThinking(modelDir string) bool {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// modelConfig holds the fields from config.json needed during model creation.
|
||||
type visionConfig struct {
|
||||
Depth int32 `json:"depth"`
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
type modelConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
VisionConfig *visionConfig `json:"vision_config"`
|
||||
ImageTokenID *int32 `json:"image_token_id"`
|
||||
VisionStartTokenID *int32 `json:"vision_start_token_id"`
|
||||
VisionEndTokenID *int32 `json:"vision_end_token_id"`
|
||||
}
|
||||
|
||||
// Check architectures that support thinking
|
||||
thinkingArchitectures := []string{
|
||||
"glm4moe", // GLM-4 MoE models
|
||||
"deepseek", // DeepSeek models
|
||||
"qwen3", // Qwen3 models
|
||||
}
|
||||
func parseModelConfig(data []byte) modelConfig {
|
||||
var cfg modelConfig
|
||||
_ = json.Unmarshal(data, &cfg)
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Check the architecture list
|
||||
for _, arch := range cfg.Architectures {
|
||||
// archOrTypeContains returns true if any architecture or the model_type
|
||||
// contains one of the given substrings (case-insensitive).
|
||||
func (c *modelConfig) archOrTypeContains(substrs ...string) bool {
|
||||
for _, arch := range c.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
for _, thinkArch := range thinkingArchitectures {
|
||||
if strings.Contains(archLower, thinkArch) {
|
||||
for _, s := range substrs {
|
||||
if strings.Contains(archLower, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
for _, thinkArch := range thinkingArchitectures {
|
||||
if strings.Contains(typeLower, thinkArch) {
|
||||
if c.ModelType != "" {
|
||||
typeLower := strings.ToLower(c.ModelType)
|
||||
for _, s := range substrs {
|
||||
if strings.Contains(typeLower, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getParserName returns the parser name for a model based on its architecture.
|
||||
// This reads the config.json from the model directory and determines the appropriate parser.
|
||||
func getParserName(modelDir string) string {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
func (c *modelConfig) supportsThinking() bool {
|
||||
return c.archOrTypeContains("glm4moe", "deepseek", "qwen3")
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return ""
|
||||
}
|
||||
func (c *modelConfig) supportsVision() bool {
|
||||
return c.VisionConfig != nil || c.ImageTokenID != nil || c.VisionStartTokenID != nil || c.VisionEndTokenID != nil
|
||||
}
|
||||
|
||||
// Check architectures for known parsers
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(archLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3"
|
||||
}
|
||||
func (c *modelConfig) parserName() string {
|
||||
switch {
|
||||
case c.archOrTypeContains("glm4", "glm-4"):
|
||||
return "glm-4.7"
|
||||
case c.archOrTypeContains("deepseek"):
|
||||
return "deepseek3"
|
||||
case c.archOrTypeContains("qwen3"):
|
||||
return "qwen3"
|
||||
}
|
||||
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(typeLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// getRendererName returns the renderer name for a model based on its architecture.
|
||||
// This reads the config.json from the model directory and determines the appropriate renderer.
|
||||
func getRendererName(modelDir string) string {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return ""
|
||||
func (c *modelConfig) rendererName() string {
|
||||
switch {
|
||||
case c.archOrTypeContains("glm4", "glm-4"):
|
||||
return "glm-4.7"
|
||||
case c.archOrTypeContains("deepseek"):
|
||||
return "deepseek3"
|
||||
case c.archOrTypeContains("qwen3"):
|
||||
return "qwen3-coder"
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check architectures for known renderers
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(archLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
}
|
||||
}
|
||||
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
if strings.Contains(typeLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3-coder"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -339,3 +339,34 @@ func TestCreateModelfileLayersIncludesParameters(t *testing.T) {
|
||||
t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsVision(t *testing.T) {
|
||||
t.Run("vision_config present", func(t *testing.T) {
|
||||
cfg := parseModelConfig([]byte(`{
|
||||
"vision_config": {"depth": 2},
|
||||
"image_token_id": 151655
|
||||
}`))
|
||||
if !cfg.supportsVision() {
|
||||
t.Fatal("supportsVision() = false, want true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("token ids alone imply vision", func(t *testing.T) {
|
||||
cfg := parseModelConfig([]byte(`{
|
||||
"vision_start_token_id": 10,
|
||||
"vision_end_token_id": 11
|
||||
}`))
|
||||
if !cfg.supportsVision() {
|
||||
t.Fatal("supportsVision() = false, want true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("plain text model", func(t *testing.T) {
|
||||
cfg := parseModelConfig([]byte(`{
|
||||
"architectures": ["Qwen3_5ForCausalLM"]
|
||||
}`))
|
||||
if cfg.supportsVision() {
|
||||
t.Fatal("supportsVision() = true, want false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,8 +1,26 @@
|
||||
// cache.go manages a shared KV cache across conversations using a compressed
|
||||
// prefix trie. Each trie node stores a token sequence (edge) and optional
|
||||
// per-layer snapshots that can be paged in/out of the live MLX cache arrays.
|
||||
//
|
||||
// Key properties:
|
||||
// - Only one path through the trie is "active" (backed by live MLX arrays)
|
||||
// at a time. Switching paths pages out the frontier node and pages in the
|
||||
// new path.
|
||||
// - Snapshots are only captured at the frontier (end) of the active path.
|
||||
// Intermediate node snapshots come from split prefill.
|
||||
// - All cache layers must stay at the same token offset.
|
||||
// - Sibling edges must not share a common token prefix (compressed trie
|
||||
// invariant).
|
||||
// - begin() always re-evaluates at least one token so the pipeline can seed
|
||||
// generation, even on a full prefix match.
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
@@ -10,10 +28,13 @@ import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
)
|
||||
|
||||
const maxPagedOutBytes int64 = 8 << 30 // 8 GiB eviction threshold for paged-out snapshot memory
|
||||
|
||||
type kvCache struct {
|
||||
// For now we only support a single entry, so this is just one sequence
|
||||
tokens []int32
|
||||
caches []cache.Cache
|
||||
root *trieNode // root of the prefix trie
|
||||
activePath []*trieNode // current root→leaf path with live MLX arrays
|
||||
caches []cache.Cache
|
||||
pagedOutBytes int64 // total bytes in paged-out snapshots across the trie
|
||||
}
|
||||
|
||||
// cacheSession manages caches for a single pipeline run.
|
||||
@@ -26,176 +47,555 @@ type cacheSession struct {
|
||||
|
||||
caches []cache.Cache
|
||||
remaining []int32
|
||||
|
||||
// snapshotOffset, if > 0, is a trie node boundary where we need to
|
||||
// capture a snapshot during prefill. This enables future requests
|
||||
// branching at this point to restore non-rewindable caches (e.g.
|
||||
// RecurrentCache) instead of re-evaluating from scratch.
|
||||
snapshotOffset int
|
||||
}
|
||||
|
||||
func appendCacheState(dst []*mlx.Array, c cache.Cache) []*mlx.Array {
|
||||
if c == nil {
|
||||
return dst
|
||||
func (c *kvCache) ensureCaches(m base.Model) {
|
||||
if len(c.caches) != 0 {
|
||||
return
|
||||
}
|
||||
|
||||
keys, values := c.State()
|
||||
if keys != nil && keys.Valid() {
|
||||
dst = append(dst, keys)
|
||||
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
c.caches = cacheFactory.NewCaches()
|
||||
return
|
||||
}
|
||||
if values != nil && values.Valid() {
|
||||
dst = append(dst, values)
|
||||
c.caches = make([]cache.Cache, m.NumLayers())
|
||||
for i := range c.caches {
|
||||
c.caches[i] = cache.NewKVCache()
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func (c *kvCache) free() {
|
||||
for i, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
func (c *kvCache) ensureRoot() {
|
||||
if c.root == nil {
|
||||
c.root = &trieNode{
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
kv.Free()
|
||||
c.caches[i] = nil
|
||||
}
|
||||
c.caches = nil
|
||||
c.tokens = nil
|
||||
}
|
||||
|
||||
func (c *kvCache) cachesCanTrim() bool {
|
||||
for _, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if !kv.CanTrim() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *kvCache) trimToPrefix(prefix int) {
|
||||
for _, kv := range c.caches {
|
||||
if kv == nil || !kv.CanTrim() {
|
||||
continue
|
||||
}
|
||||
if trim := kv.Offset() - prefix; trim > 0 {
|
||||
kv.Trim(trim)
|
||||
}
|
||||
}
|
||||
if prefix < len(c.tokens) {
|
||||
c.tokens = c.tokens[:prefix]
|
||||
c.activePath = []*trieNode{c.root}
|
||||
}
|
||||
}
|
||||
|
||||
// begin prepares caches for a new request. It finds the nearest
|
||||
// matching cache or creates new caches if none match.
|
||||
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
||||
ensureCaches := func() {
|
||||
if len(c.caches) != 0 {
|
||||
return
|
||||
}
|
||||
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
c.caches = cacheFactory.NewCaches()
|
||||
return
|
||||
}
|
||||
c.caches = make([]cache.Cache, m.NumLayers())
|
||||
for i := range c.caches {
|
||||
c.caches[i] = cache.NewKVCache()
|
||||
c.ensureCaches(m)
|
||||
c.ensureRoot()
|
||||
|
||||
matchPath, matched := findBestMatch(c.root, inputs)
|
||||
originalMatched := matched
|
||||
|
||||
// Always keep at least one token to re-evaluate so the
|
||||
// pipeline can seed token generation from it.
|
||||
if matched == len(inputs) && matched > 0 {
|
||||
matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
|
||||
}
|
||||
|
||||
// Check for partial match within a node's edge — truncate path
|
||||
// to the parent boundary. snapshot() will split the node and
|
||||
// create the branch point during prefill when caches are ready.
|
||||
partialMatch := false
|
||||
if len(matchPath) > 1 {
|
||||
lastNode := matchPath[len(matchPath)-1]
|
||||
matchedInEdge := matched - lastNode.startOffset()
|
||||
if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
|
||||
matchPath = matchPath[:len(matchPath)-1]
|
||||
partialMatch = true
|
||||
}
|
||||
}
|
||||
ensureCaches()
|
||||
|
||||
remaining := c.findRemaining(inputs)
|
||||
ensureCaches()
|
||||
// Switch to the matched path, paging in/out as needed.
|
||||
c.switchToPath(matchPath)
|
||||
|
||||
// switchToPath aligns caches to a common offset
|
||||
prefix := c.minCacheOffset()
|
||||
remaining := inputs[prefix:]
|
||||
|
||||
// Schedule a snapshot at the branch point during prefill so future
|
||||
// requests diverging here can restore instead of re-evaluating.
|
||||
var snapshotAt int
|
||||
if partialMatch || (prefix == 0 && matched > 0) {
|
||||
snapshotAt = matched
|
||||
}
|
||||
|
||||
args := []any{"total", len(inputs), "matched", originalMatched}
|
||||
args = append(args, "cached", prefix, "left", len(remaining))
|
||||
if snapshotAt > 0 {
|
||||
args = append(args, "pending_snapshot", snapshotAt)
|
||||
}
|
||||
if prefix == 0 {
|
||||
slog.Info("cache miss", args...)
|
||||
} else {
|
||||
slog.Info("cache hit", args...)
|
||||
}
|
||||
|
||||
return &cacheSession{
|
||||
cache: c,
|
||||
inputs: inputs,
|
||||
caches: c.caches,
|
||||
remaining: remaining,
|
||||
cache: c,
|
||||
inputs: inputs,
|
||||
snapshotOffset: snapshotAt,
|
||||
caches: c.caches,
|
||||
remaining: remaining,
|
||||
}
|
||||
}
|
||||
|
||||
// switchToPath transitions from the current active path to a new path,
|
||||
// paging out diverging segments and paging in the new path.
|
||||
func (c *kvCache) switchToPath(newPath []*trieNode) {
|
||||
defer c.enforceEvictionPolicy()
|
||||
|
||||
// Find common ancestor index.
|
||||
commonLen := 0
|
||||
for commonLen < len(c.activePath) && commonLen < len(newPath) {
|
||||
if c.activePath[commonLen] != newPath[commonLen] {
|
||||
break
|
||||
}
|
||||
commonLen++
|
||||
}
|
||||
|
||||
ancestorOffset := 0
|
||||
if commonLen > 0 {
|
||||
ancestorOffset = c.activePath[commonLen-1].endOffset
|
||||
}
|
||||
|
||||
var pageOutCount, pageInCount int
|
||||
|
||||
// Page out the leaf of the old path. Only the leaf's live cache
|
||||
// state is correct — intermediate nodes already have snapshots
|
||||
// captured during their creation (splitNode + prefill). Snapshotting
|
||||
// non-leaf nodes here would produce wrong results for non-rewindable
|
||||
// caches (e.g. RecurrentCache) whose state reflects the leaf, not
|
||||
// the intermediate boundary.
|
||||
if leaf := len(c.activePath) - 1; leaf >= commonLen {
|
||||
node := c.activePath[leaf]
|
||||
if !node.hasAllSnapshots() {
|
||||
fromOffset := node.startOffset()
|
||||
snaps := make([]cache.Snapshot, len(c.caches))
|
||||
for j, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
snaps[j] = kv.Snapshot(fromOffset)
|
||||
}
|
||||
node.setSnapshots(snaps, &c.pagedOutBytes)
|
||||
pageOutCount++
|
||||
logutil.Trace(fmt.Sprintf("page out: [%d, %d)", fromOffset, node.endOffset))
|
||||
}
|
||||
}
|
||||
|
||||
// Rewind each cache to the ancestor offset or free it. Freed
|
||||
// caches (e.g. RecurrentCache that can't rewind) will be restored
|
||||
// from snapshots during page-in.
|
||||
for _, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if !kv.Restore(nil, ancestorOffset) {
|
||||
kv.Free()
|
||||
}
|
||||
}
|
||||
|
||||
// Page in — walk the full new path, restoring from snapshots.
|
||||
// Freed caches naturally pick up the first available snapshot.
|
||||
// Caches already past a node skip it via offset check.
|
||||
for _, node := range newPath {
|
||||
if len(node.snapshots) == 0 {
|
||||
continue
|
||||
}
|
||||
for j, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if j >= len(node.snapshots) || node.snapshots[j] == nil {
|
||||
continue
|
||||
}
|
||||
if kv.Offset() >= node.endOffset {
|
||||
continue
|
||||
}
|
||||
if !kv.Restore(node.snapshots[j], node.endOffset) {
|
||||
slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset())
|
||||
c.freeAll()
|
||||
c.activePath = []*trieNode{c.root}
|
||||
return
|
||||
}
|
||||
}
|
||||
if node.endOffset > ancestorOffset {
|
||||
pageInCount++
|
||||
logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset))
|
||||
}
|
||||
}
|
||||
|
||||
// Align all caches to the minimum offset.
|
||||
c.activePath = newPath
|
||||
minOff := c.minCacheOffset()
|
||||
for _, kv := range c.caches {
|
||||
if kv != nil && kv.Offset() != minOff {
|
||||
if !kv.Restore(nil, minOff) {
|
||||
slog.Warn("failed to restore cache, freeing all caches", "offset", minOff)
|
||||
c.freeAll()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
for i := len(c.activePath) - 1; i >= 0; i-- {
|
||||
if c.activePath[i].endOffset <= minOff {
|
||||
c.activePath = c.activePath[:i+1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if pageOutCount > 0 || pageInCount > 0 {
|
||||
slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount)
|
||||
}
|
||||
}
|
||||
|
||||
// snapshot creates a snapshot at the current cache position. During prefill,
|
||||
// it is called at branch points (user=false) to create restore points for
|
||||
// future diverging requests and with user=true to mark an explicit reusable
|
||||
// restore point.
|
||||
func (s *cacheSession) snapshot(user bool) {
|
||||
c := s.cache
|
||||
cacheOffset := c.minCacheOffset()
|
||||
if cacheOffset <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear pending intermediate snapshot if we've reached or passed it.
|
||||
if s.snapshotOffset > 0 && cacheOffset >= s.snapshotOffset {
|
||||
s.snapshotOffset = 0
|
||||
}
|
||||
|
||||
// The last node in activePath is the frontier where caches are advancing.
|
||||
// cacheOffset is always >= its endOffset: begin() restores caches to this
|
||||
// boundary and prefill advances monotonically forward.
|
||||
frontier := c.activePath[len(c.activePath)-1]
|
||||
|
||||
// If the frontier already ends at cacheOffset, just ensure it has snapshots.
|
||||
if frontier.endOffset == cacheOffset {
|
||||
if user {
|
||||
frontier.user = true
|
||||
}
|
||||
if !frontier.hasAllSnapshots() {
|
||||
s.attachSnapshots(frontier, cacheOffset)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if frontier.endOffset > cacheOffset {
|
||||
slog.Warn("snapshot skipped: cacheOffset is behind frontier", "cacheOffset", cacheOffset, "frontierEndOffset", frontier.endOffset)
|
||||
return
|
||||
}
|
||||
|
||||
// Advance the trie to cacheOffset — find or create a node there.
|
||||
edgeTokens := append(s.inputs, s.outputs...)[frontier.endOffset:cacheOffset]
|
||||
frontier = c.advancePath(frontier, edgeTokens, cacheOffset)
|
||||
|
||||
// Attach fresh snapshots from the live caches. Always use fresh
|
||||
// snapshots even if the node already has some (e.g. from splitNode's
|
||||
// Cache.Split which may be incomplete for non-splittable caches
|
||||
// like RecurrentCache).
|
||||
if user {
|
||||
frontier.user = true
|
||||
}
|
||||
s.attachSnapshots(frontier, cacheOffset)
|
||||
}
|
||||
|
||||
// advancePath advances the active path from the current frontier by matching
|
||||
// tokens against existing trie children, splitting partial matches, and
|
||||
// appending any remaining tokens as new nodes. Returns the new frontier.
|
||||
func (c *kvCache) advancePath(frontier *trieNode, tokens []int32, endOffset int) *trieNode {
|
||||
// Check if existing children already cover some or all of tokens.
|
||||
// tokens may span multiple trie nodes when extending a previous run's
|
||||
// leaf and this snapshot now overlaps that same range.
|
||||
matchPath, matched := findBestMatch(frontier, tokens)
|
||||
// matchPath[0] is frontier itself; the rest are newly traversed nodes.
|
||||
remaining := tokens[matched:]
|
||||
|
||||
// Check for a partial match within the last node's edge — if so, split it.
|
||||
if len(matchPath) > 1 {
|
||||
lastNode := matchPath[len(matchPath)-1]
|
||||
matchedInEdge := frontier.endOffset + matched - lastNode.startOffset()
|
||||
if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
|
||||
matchPath[len(matchPath)-1] = splitNode(lastNode, matchedInEdge, c.caches, &c.pagedOutBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// Append traversed nodes (excluding frontier) to the active path.
|
||||
c.activePath = append(c.activePath, matchPath[1:]...)
|
||||
dest := matchPath[len(matchPath)-1]
|
||||
|
||||
if len(remaining) > 0 {
|
||||
// Drop non-user snapshots so appendTokens can extend in-place
|
||||
// rather than creating a new child node.
|
||||
if len(dest.children) == 0 && !dest.user {
|
||||
dest.setSnapshots(nil, &c.pagedOutBytes)
|
||||
}
|
||||
newDest := dest.appendTokens(c.root, remaining, endOffset)
|
||||
if newDest != dest {
|
||||
c.activePath = append(c.activePath, newDest)
|
||||
}
|
||||
dest = newDest
|
||||
}
|
||||
return dest
|
||||
}
|
||||
|
||||
// attachSnapshots attaches cache snapshots to a trie node at the given offset.
|
||||
// The node must be on the active path (and thus protected from eviction;
|
||||
// lastUsed is updated in close()). All non-nil caches must be at the same
|
||||
// offset (cacheOffset); a mismatch indicates a bug in the caller.
|
||||
func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
|
||||
c := s.cache
|
||||
|
||||
if c.activePath[len(c.activePath)-1] != node {
|
||||
slog.Warn("attachSnapshots skipped: node is not the active frontier", "nodeEndOffset", node.endOffset)
|
||||
return
|
||||
}
|
||||
|
||||
snaps := make([]cache.Snapshot, len(c.caches))
|
||||
for i, kv := range c.caches {
|
||||
if kv != nil {
|
||||
if kv.Offset() != cacheOffset {
|
||||
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, kv.Offset()))
|
||||
}
|
||||
snaps[i] = kv.Snapshot(node.startOffset())
|
||||
}
|
||||
}
|
||||
node.setSnapshots(snaps, &c.pagedOutBytes)
|
||||
slog.Debug("created snapshot", "offset", cacheOffset)
|
||||
c.enforceEvictionPolicy()
|
||||
}
|
||||
|
||||
// clear releases live caches and drops the trie so future requests cannot
|
||||
// reuse prompt state keyed only by token IDs.
|
||||
func (c *kvCache) clear() {
|
||||
c.freeAll()
|
||||
walkNodes(c.root, func(n *trieNode) bool {
|
||||
for _, s := range n.snapshots {
|
||||
if s != nil {
|
||||
s.Close()
|
||||
}
|
||||
}
|
||||
n.snapshots = nil
|
||||
return true
|
||||
})
|
||||
c.root = nil
|
||||
c.activePath = nil
|
||||
}
|
||||
|
||||
// freeAll releases all cache layers.
|
||||
func (c *kvCache) freeAll() {
|
||||
for _, kv := range c.caches {
|
||||
if kv != nil {
|
||||
kv.Free()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *kvCache) minCacheOffset() int {
|
||||
offset := 0
|
||||
found := false
|
||||
for _, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if off := kv.Offset(); !found || off < offset {
|
||||
offset = off
|
||||
found = true
|
||||
}
|
||||
}
|
||||
return offset
|
||||
}
|
||||
|
||||
// close saves the token state if the forward pass ran.
|
||||
func (s *cacheSession) close() {
|
||||
if len(s.caches) == 0 {
|
||||
offset := s.cache.minCacheOffset()
|
||||
if offset <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
offset := -1
|
||||
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
||||
for _, kv := range s.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
// Mixed cache types (e.g. recurrent + KV) can transiently report different
|
||||
// offsets, so use the minimum as the safe reusable token prefix.
|
||||
if off := kv.Offset(); offset < 0 || off < offset {
|
||||
offset = off
|
||||
}
|
||||
arrays = appendCacheState(arrays, kv)
|
||||
}
|
||||
if offset <= 0 {
|
||||
return
|
||||
arrays = append(arrays, kv.State()...)
|
||||
}
|
||||
|
||||
// Ensure that if we have run the forward pass and set the metadata
|
||||
// that we also actually have the data.
|
||||
mlx.AsyncEval(arrays...)
|
||||
|
||||
stored := append(s.inputs, s.outputs...)
|
||||
if offset > len(stored) {
|
||||
offset = len(stored)
|
||||
}
|
||||
s.cache.tokens = stored[:offset]
|
||||
}
|
||||
// Advance the trie frontier with any newly generated tokens.
|
||||
c := s.cache
|
||||
if len(c.activePath) > 0 {
|
||||
frontier := c.activePath[len(c.activePath)-1]
|
||||
stored := append(s.inputs, s.outputs...)
|
||||
|
||||
// findRemaining finds the longest common prefix between tokens and the cached
|
||||
// sequence, trims stale cache entries, and returns the remaining tokens.
|
||||
func (c *kvCache) findRemaining(tokens []int32) []int32 {
|
||||
prefix := 0
|
||||
for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
|
||||
prefix++
|
||||
}
|
||||
|
||||
// Always keep at least one token to re-evaluate so the
|
||||
// pipeline can seed token generation from it.
|
||||
if prefix == len(tokens) && prefix > 0 {
|
||||
prefix--
|
||||
}
|
||||
|
||||
if prefix < len(c.tokens) {
|
||||
if c.cachesCanTrim() {
|
||||
c.trimToPrefix(prefix)
|
||||
} else {
|
||||
c.free()
|
||||
slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
|
||||
return tokens
|
||||
if offset > frontier.endOffset {
|
||||
newTokens := stored[frontier.endOffset:offset]
|
||||
c.advancePath(frontier, newTokens, offset)
|
||||
}
|
||||
now := time.Now()
|
||||
for _, node := range c.activePath {
|
||||
node.lastUsed = now
|
||||
}
|
||||
}
|
||||
|
||||
if prefix == 0 {
|
||||
slog.Info("Cache miss", "left", len(tokens))
|
||||
} else {
|
||||
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
|
||||
}
|
||||
return tokens[prefix:]
|
||||
}
|
||||
|
||||
func (c *kvCache) log() {
|
||||
if len(c.caches) == 0 {
|
||||
// enforceEvictionPolicy evicts eligible nodes until paged-out memory is within limits.
|
||||
func (c *kvCache) enforceEvictionPolicy() {
|
||||
if c.pagedOutBytes <= maxPagedOutBytes {
|
||||
return
|
||||
}
|
||||
offset := -1
|
||||
var totalBytes int
|
||||
|
||||
activeSet := make(map[*trieNode]bool, len(c.activePath))
|
||||
for _, n := range c.activePath {
|
||||
activeSet[n] = true
|
||||
}
|
||||
|
||||
for c.pagedOutBytes > maxPagedOutBytes {
|
||||
var best *trieNode
|
||||
walkNodes(c.root, func(n *trieNode) bool {
|
||||
if n == c.root || activeSet[n] || !n.hasSnapshots() {
|
||||
return true
|
||||
}
|
||||
// Evict: oldest, then deepest, then largest.
|
||||
if best == nil || cmp.Or(
|
||||
n.lastUsed.Compare(best.lastUsed),
|
||||
cmp.Compare(best.endOffset, n.endOffset),
|
||||
cmp.Compare(best.snapshotBytes(), n.snapshotBytes()),
|
||||
) < 0 {
|
||||
best = n
|
||||
}
|
||||
return true
|
||||
})
|
||||
if best == nil {
|
||||
break
|
||||
}
|
||||
c.evictNode(best)
|
||||
}
|
||||
}
|
||||
|
||||
// evictNode evicts a single node from the trie, freeing its snapshot memory.
|
||||
func (c *kvCache) evictNode(node *trieNode) {
|
||||
if len(node.children) == 0 {
|
||||
// Leaf: remove entirely.
|
||||
parent := node.parent
|
||||
kind := "evicting leaf"
|
||||
if node.user {
|
||||
kind = "evicting user snapshot"
|
||||
}
|
||||
slog.Debug(kind, "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
||||
removeNode(node, &c.pagedOutBytes)
|
||||
|
||||
// If parent is a regular (non-user-snapshot) node with one remaining child, auto-merge.
|
||||
if parent != nil && !parent.user && len(parent.children) == 1 && parent != c.root {
|
||||
logutil.Trace(fmt.Sprintf("auto-merging parent at offset %d with single child", parent.endOffset))
|
||||
mergeWithChild(parent, c.caches, &c.pagedOutBytes)
|
||||
}
|
||||
} else if len(node.children) == 1 {
|
||||
// Interior snapshot node with one child: merge with child.
|
||||
slog.Debug("evicting snapshot node", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
||||
mergeWithChild(node, c.caches, &c.pagedOutBytes)
|
||||
} else {
|
||||
// Multi-child branch point: drop snapshots but keep the node.
|
||||
slog.Debug("evicting branch snapshot", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
||||
node.setSnapshots(nil, &c.pagedOutBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *kvCache) dumpTree() {
|
||||
// Summary stats
|
||||
var cacheBytes int
|
||||
for _, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if off := kv.Offset(); offset < 0 || off < offset {
|
||||
offset = off
|
||||
}
|
||||
for _, a := range appendCacheState(nil, kv) {
|
||||
totalBytes += a.NumBytes()
|
||||
for _, a := range kv.State() {
|
||||
if a != nil {
|
||||
cacheBytes += a.NumBytes()
|
||||
}
|
||||
}
|
||||
}
|
||||
if offset < 0 {
|
||||
return
|
||||
|
||||
// Build active path set for marking.
|
||||
active := make(map[*trieNode]bool, len(c.activePath))
|
||||
for _, n := range c.activePath {
|
||||
active[n] = true
|
||||
}
|
||||
|
||||
var nodeCount, snapshotCount int
|
||||
var pagedBytes int64
|
||||
var lines []string
|
||||
var dump func(n *trieNode, prefix string, isLast bool)
|
||||
dump = func(n *trieNode, prefix string, isLast bool) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
nodeCount++
|
||||
|
||||
// Build connector
|
||||
var connector string
|
||||
if n.parent == nil {
|
||||
connector = ""
|
||||
} else if isLast {
|
||||
connector = prefix + "`-- "
|
||||
} else {
|
||||
connector = prefix + "|-- "
|
||||
}
|
||||
|
||||
// Node label
|
||||
nodeBytes := n.snapshotBytes()
|
||||
pagedBytes += nodeBytes
|
||||
|
||||
label := fmt.Sprintf("[%d,%d) %dt", n.startOffset(), n.endOffset, len(n.tokens))
|
||||
if nodeBytes > 0 {
|
||||
label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
|
||||
}
|
||||
var flags []string
|
||||
if n.user {
|
||||
flags = append(flags, "user")
|
||||
}
|
||||
if n.hasAllSnapshots() {
|
||||
snapshotCount++
|
||||
flags = append(flags, "snap")
|
||||
}
|
||||
if active[n] {
|
||||
flags = append(flags, "active")
|
||||
}
|
||||
if len(flags) > 0 {
|
||||
label += " (" + flags[0]
|
||||
for _, f := range flags[1:] {
|
||||
label += ", " + f
|
||||
}
|
||||
label += ")"
|
||||
}
|
||||
lines = append(lines, connector+label)
|
||||
|
||||
// Recurse children
|
||||
childPrefix := prefix
|
||||
if n.parent != nil {
|
||||
if isLast {
|
||||
childPrefix += " "
|
||||
} else {
|
||||
childPrefix += "| "
|
||||
}
|
||||
}
|
||||
for i, child := range n.children {
|
||||
dump(child, childPrefix, i == len(n.children)-1)
|
||||
}
|
||||
}
|
||||
dump(c.root, "", true)
|
||||
|
||||
offset := c.minCacheOffset()
|
||||
logutil.Trace(fmt.Sprintf("kv cache active_tokens: %d, active_size: %s, paged_out: %s, trie: nodes=%d, snapshots=%d",
|
||||
offset, mlx.PrettyBytes(cacheBytes), mlx.PrettyBytes(int(pagedBytes)), nodeCount, snapshotCount))
|
||||
for i, l := range lines {
|
||||
if i == 0 {
|
||||
logutil.Trace("cache trie: " + l)
|
||||
} else {
|
||||
logutil.Trace(" " + l)
|
||||
}
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes)))
|
||||
}
|
||||
|
||||
300
x/mlxrunner/cache/cache.go
vendored
300
x/mlxrunner/cache/cache.go
vendored
@@ -8,13 +8,34 @@ import (
|
||||
type Cache interface {
|
||||
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
||||
// State returns the cache-owned state roots that should be kept/evaluated.
|
||||
State() (keys, values *mlx.Array)
|
||||
CanTrim() bool
|
||||
Trim(int) int
|
||||
Clone() Cache
|
||||
State() []*mlx.Array
|
||||
Free()
|
||||
Offset() int
|
||||
Len() int
|
||||
|
||||
// Snapshot copies cache state from fromOffset to current offset into
|
||||
// pinned VRAM arrays. The active cache is unchanged.
|
||||
Snapshot(fromOffset int) Snapshot
|
||||
|
||||
// Restore brings the cache to target. If snapshot is nil, rewinds
|
||||
// using the cache's own live state.
|
||||
Restore(snapshot Snapshot, target int) bool
|
||||
|
||||
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
|
||||
// Takes ownership of both inputs.
|
||||
Merge(parent, child Snapshot) Snapshot
|
||||
|
||||
// Split divides a snapshot [a,c) at offset b into [a,b) and [b,c).
|
||||
// Takes ownership of the input. Cache types that cannot split
|
||||
// (e.g. recurrent) return (nil, snapshot).
|
||||
Split(snapshot Snapshot, at int) (parent, child Snapshot)
|
||||
}
|
||||
|
||||
// Snapshot is paged-out cache state that can be restored later.
|
||||
type Snapshot interface {
|
||||
// Size returns the byte size of the paged-out data (in VRAM).
|
||||
Size() int
|
||||
// Close unpins the snapshot's arrays so they can be freed by Sweep.
|
||||
Close()
|
||||
}
|
||||
|
||||
type KVCache struct {
|
||||
@@ -59,40 +80,148 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
}
|
||||
|
||||
func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
|
||||
func (c *KVCache) State() []*mlx.Array {
|
||||
if c.keys == nil || c.values == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{
|
||||
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
}
|
||||
}
|
||||
|
||||
// kvSnapshot holds paged-out KV data for a range [fromOffset, toOffset).
|
||||
type kvSnapshot struct {
|
||||
keys, values *mlx.Array
|
||||
fromOffset, toOffset int
|
||||
}
|
||||
|
||||
func (s *kvSnapshot) Size() int { return s.keys.NumBytes() + s.values.NumBytes() }
|
||||
func (s *kvSnapshot) Close() { mlx.Unpin(s.keys, s.values) }
|
||||
|
||||
func (c *KVCache) Snapshot(fromOffset int) Snapshot {
|
||||
if c.keys == nil || c.offset <= fromOffset {
|
||||
return nil
|
||||
}
|
||||
from := max(0, fromOffset)
|
||||
to := c.offset
|
||||
|
||||
kSlice := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
|
||||
vSlice := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
|
||||
kCopy := mlx.Copy(kSlice)
|
||||
vCopy := mlx.Copy(vSlice)
|
||||
mlx.Pin(kCopy, vCopy)
|
||||
mlx.AsyncEval(kCopy, vCopy)
|
||||
|
||||
return &kvSnapshot{
|
||||
keys: kCopy,
|
||||
values: vCopy,
|
||||
fromOffset: from,
|
||||
toOffset: to,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
|
||||
if snapshot == nil {
|
||||
// Rewind using live state — just clamp offset.
|
||||
target = max(0, min(target, c.offset))
|
||||
c.offset = target
|
||||
return true
|
||||
}
|
||||
|
||||
snap := snapshot.(*kvSnapshot)
|
||||
|
||||
// Check that the cache has data up to the snapshot's starting point.
|
||||
if c.offset < snap.fromOffset {
|
||||
return false
|
||||
}
|
||||
|
||||
// Rewind to snapshot start, then feed snapshot data through Update.
|
||||
c.offset = snap.fromOffset
|
||||
c.Update(snap.keys, snap.values)
|
||||
|
||||
// Clamp to target if needed (target may be less than full snapshot).
|
||||
if target < c.offset {
|
||||
c.offset = target
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *KVCache) Merge(parent, child Snapshot) Snapshot {
|
||||
if parent == nil || child == nil {
|
||||
if parent != nil {
|
||||
parent.Close()
|
||||
}
|
||||
if child != nil {
|
||||
child.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
p := parent.(*kvSnapshot)
|
||||
ch := child.(*kvSnapshot)
|
||||
|
||||
mk := p.keys.Concatenate(2, ch.keys)
|
||||
mv := p.values.Concatenate(2, ch.values)
|
||||
mlx.Pin(mk, mv)
|
||||
mlx.AsyncEval(mk, mv)
|
||||
|
||||
p.Close()
|
||||
ch.Close()
|
||||
|
||||
return &kvSnapshot{
|
||||
keys: mk,
|
||||
values: mv,
|
||||
fromOffset: p.fromOffset,
|
||||
toOffset: ch.toOffset,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *KVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
||||
if snapshot == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
}
|
||||
|
||||
func (c *KVCache) CanTrim() bool { return true }
|
||||
|
||||
func (c *KVCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
return n
|
||||
}
|
||||
|
||||
func (c *KVCache) Clone() Cache {
|
||||
clone := &KVCache{
|
||||
keys: c.keys.Clone(),
|
||||
values: c.values.Clone(),
|
||||
offset: c.offset,
|
||||
step: c.step,
|
||||
snap := snapshot.(*kvSnapshot)
|
||||
splitIdx := at - snap.fromOffset
|
||||
seqLen := snap.toOffset - snap.fromOffset
|
||||
if splitIdx <= 0 {
|
||||
return nil, snapshot
|
||||
}
|
||||
mlx.Pin(clone.keys, clone.values)
|
||||
return clone
|
||||
if splitIdx >= seqLen {
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
pk := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()))
|
||||
pv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()))
|
||||
ck := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()))
|
||||
cv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()))
|
||||
mlx.Pin(pk, pv, ck, cv)
|
||||
mlx.AsyncEval(pk, pv, ck, cv)
|
||||
|
||||
snap.Close()
|
||||
|
||||
p := &kvSnapshot{
|
||||
keys: pk,
|
||||
values: pv,
|
||||
fromOffset: snap.fromOffset,
|
||||
toOffset: at,
|
||||
}
|
||||
ch := &kvSnapshot{
|
||||
keys: ck,
|
||||
values: cv,
|
||||
fromOffset: at,
|
||||
toOffset: snap.toOffset,
|
||||
}
|
||||
return p, ch
|
||||
}
|
||||
|
||||
func (c *KVCache) Free() {
|
||||
mlx.Unpin(c.keys, c.values)
|
||||
c.keys, c.values = nil, nil
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
func (c *KVCache) Len() int { return c.offset }
|
||||
|
||||
// RotatingKVCache implements sliding window attention with bounded memory
|
||||
type RotatingKVCache struct {
|
||||
@@ -184,29 +313,104 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
|
||||
func (c *RotatingKVCache) State() []*mlx.Array {
|
||||
if c.keys == nil || c.values == nil {
|
||||
return nil, nil
|
||||
return nil
|
||||
}
|
||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) CanTrim() bool { return true }
|
||||
|
||||
func (c *RotatingKVCache) Trim(n int) int {
|
||||
n = min(c.offset, n)
|
||||
c.offset -= n
|
||||
c.idx -= n
|
||||
return n
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Clone() Cache {
|
||||
return &RotatingKVCache{
|
||||
maxSize: c.maxSize,
|
||||
idx: c.idx,
|
||||
KVCache: c.KVCache.Clone().(*KVCache),
|
||||
return []*mlx.Array{
|
||||
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
|
||||
// rotatingSnapshot holds paged-out data for a RotatingKVCache.
|
||||
type rotatingSnapshot struct {
|
||||
kvSnapshot // embedded KV data
|
||||
idx int // buffer write position at snapshot time
|
||||
}
|
||||
|
||||
func (s *rotatingSnapshot) Size() int { return s.kvSnapshot.Size() }
|
||||
func (s *rotatingSnapshot) Close() { s.kvSnapshot.Close() }
|
||||
|
||||
func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot {
|
||||
if c.keys == nil || c.offset <= fromOffset {
|
||||
return nil
|
||||
}
|
||||
|
||||
state := c.State()
|
||||
k := state[0].Clone()
|
||||
v := state[1].Clone()
|
||||
mlx.Pin(k, v)
|
||||
|
||||
return &rotatingSnapshot{
|
||||
kvSnapshot: kvSnapshot{
|
||||
keys: k,
|
||||
values: v,
|
||||
fromOffset: fromOffset,
|
||||
toOffset: c.offset,
|
||||
},
|
||||
idx: c.idx,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
||||
if snapshot == nil {
|
||||
// Live rewind is only safe when the buffer hasn't filled yet
|
||||
// (offset <= maxSize). Once the window has shifted, rewinding
|
||||
// leaves fewer than maxSize trailing tokens to attend to —
|
||||
// a snapshot is required to restore the full window.
|
||||
if c.offset > c.maxSize {
|
||||
return false
|
||||
}
|
||||
target = max(0, min(target, c.offset))
|
||||
c.offset = target
|
||||
c.idx = target
|
||||
return true
|
||||
}
|
||||
|
||||
snap := snapshot.(*rotatingSnapshot)
|
||||
|
||||
// Reject if clamping would leave an incomplete window.
|
||||
if target < snap.toOffset && snap.toOffset > c.maxSize {
|
||||
return false
|
||||
}
|
||||
|
||||
// Restore from snapshot: rebuild buffer state.
|
||||
// Free existing state first.
|
||||
if c.keys != nil {
|
||||
mlx.Unpin(c.keys, c.values)
|
||||
}
|
||||
c.keys = snap.keys.Clone()
|
||||
c.values = snap.values.Clone()
|
||||
mlx.Pin(c.keys, c.values)
|
||||
c.offset = snap.toOffset
|
||||
c.idx = snap.idx
|
||||
|
||||
// Clamp to target if needed.
|
||||
if target < c.offset {
|
||||
target = max(0, target)
|
||||
c.offset = target
|
||||
c.idx = target
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Merge(parent, child Snapshot) Snapshot {
|
||||
// For rotating caches, the child snapshot supersedes the parent
|
||||
// since it contains the full window state.
|
||||
if parent != nil {
|
||||
parent.Close()
|
||||
}
|
||||
return child
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
||||
// Rotating cache snapshots contain the full window state.
|
||||
// Cannot cleanly split a ring buffer at an arbitrary point.
|
||||
return nil, snapshot
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Free() {
|
||||
c.KVCache.Free()
|
||||
c.idx = 0
|
||||
}
|
||||
|
||||
271
x/mlxrunner/cache/cache_test.go
vendored
Normal file
271
x/mlxrunner/cache/cache_test.go
vendored
Normal file
@@ -0,0 +1,271 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
}
|
||||
|
||||
// Snapshot [5, 10).
|
||||
snap := c.Snapshot(5)
|
||||
|
||||
// Free the cache completely — offset is now 0.
|
||||
c.Free()
|
||||
|
||||
// Restore should fail because cache doesn't have data up to fromOffset=5.
|
||||
if c.Restore(snap, 10) {
|
||||
t.Fatal("expected Restore to fail with no base data")
|
||||
}
|
||||
}
|
||||
|
||||
// TestKVCacheDataSurvivesSnapshotRestore verifies that actual array data
|
||||
// is preserved through a snapshot→free→restore cycle.
|
||||
func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
if snap == nil {
|
||||
t.Fatal("Snapshot returned nil")
|
||||
}
|
||||
|
||||
// Free and restore to a fresh cache.
|
||||
c2 := NewKVCache()
|
||||
if !c2.Restore(snap, 10) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if c2.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c2.Offset())
|
||||
}
|
||||
|
||||
// Verify State() returns arrays with correct sequence dimension.
|
||||
state := c2.State()
|
||||
if len(state) != 2 {
|
||||
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
||||
}
|
||||
// keys shape: [B, H, seqLen, Dk]
|
||||
if state[0].Dim(2) != 10 {
|
||||
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
|
||||
}
|
||||
if state[1].Dim(2) != 10 {
|
||||
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
|
||||
}
|
||||
}
|
||||
|
||||
// TestKVCacheSplitPreservesData verifies that split produces two snapshots
|
||||
// that can be sequentially restored to rebuild the original cache state.
|
||||
func TestKVCacheSplitPreservesData(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
parent, child := c.Split(snap, 5)
|
||||
if parent == nil || child == nil {
|
||||
t.Fatal("Split returned nil")
|
||||
}
|
||||
|
||||
// Restore parent → offset=5, seq dim=5.
|
||||
c2 := NewKVCache()
|
||||
if !c2.Restore(parent, 5) {
|
||||
t.Fatal("Restore(parent) failed")
|
||||
}
|
||||
if c2.Offset() != 5 {
|
||||
t.Fatalf("offset after parent = %d, want 5", c2.Offset())
|
||||
}
|
||||
state := c2.State()
|
||||
if state[0].Dim(2) != 5 {
|
||||
t.Fatalf("keys seq dim after parent = %d, want 5", state[0].Dim(2))
|
||||
}
|
||||
|
||||
// Restore child on top → offset=10, seq dim=10.
|
||||
if !c2.Restore(child, 10) {
|
||||
t.Fatal("Restore(child) failed")
|
||||
}
|
||||
if c2.Offset() != 10 {
|
||||
t.Fatalf("offset after child = %d, want 10", c2.Offset())
|
||||
}
|
||||
state = c2.State()
|
||||
if state[0].Dim(2) != 10 {
|
||||
t.Fatalf("keys seq dim after child = %d, want 10", state[0].Dim(2))
|
||||
}
|
||||
}
|
||||
|
||||
// TestKVCacheSplitMergeRoundTripData verifies that splitting and merging back
|
||||
// produces a snapshot equivalent to the original.
|
||||
func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
parent, child := c.Split(snap, 6)
|
||||
merged := c.Merge(parent, child)
|
||||
if merged == nil {
|
||||
t.Fatal("Merge returned nil")
|
||||
}
|
||||
|
||||
c2 := NewKVCache()
|
||||
if !c2.Restore(merged, 10) {
|
||||
t.Fatal("Restore(merged) failed")
|
||||
}
|
||||
if c2.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c2.Offset())
|
||||
}
|
||||
|
||||
state := c2.State()
|
||||
if state[0].Dim(2) != 10 {
|
||||
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
|
||||
}
|
||||
if state[1].Dim(2) != 10 {
|
||||
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRotatingKVCache(4)
|
||||
|
||||
// Feed 10 tokens (window size 4, so positions 0-5 are evicted).
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
}
|
||||
|
||||
// Offset 3 is outside the window.
|
||||
if c.Restore(nil, 3) {
|
||||
t.Fatal("Restore(nil, 3) should fail when outside window")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheSnapshotPreservesWindow verifies that after restoring
|
||||
// from a snapshot, the rotating cache has the correct window of data.
|
||||
func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRotatingKVCache(4)
|
||||
|
||||
// Feed 10 tokens one at a time. Window size 4, so only last 4 are kept.
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
if snap == nil {
|
||||
t.Fatal("Snapshot returned nil")
|
||||
}
|
||||
|
||||
// Feed 5 more tokens.
|
||||
for range 5 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
}
|
||||
|
||||
// Restore to offset 10.
|
||||
if !c.Restore(snap, 10) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if c.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c.Offset())
|
||||
}
|
||||
|
||||
state := c.State()
|
||||
if len(state) != 2 {
|
||||
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
||||
}
|
||||
// Seq dim should be min(offset, maxSize) = min(10, 4) = 4.
|
||||
seqDim := state[0].Dim(2)
|
||||
if seqDim != 4 {
|
||||
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheRestoreFromSnapshot verifies that restoring from a
|
||||
// snapshot correctly preserves the write position (idx), so subsequent
|
||||
// single-token updates land in the right buffer slot.
|
||||
func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRotatingKVCache(4)
|
||||
|
||||
// Fill the window: 6 tokens into a size-4 window.
|
||||
// After this, idx has wrapped and the buffer has rotated.
|
||||
for range 6 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
}
|
||||
if c.Offset() != 6 {
|
||||
t.Fatalf("offset = %d, want 6", c.Offset())
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
|
||||
// Mutate the cache further so live state diverges from snapshot.
|
||||
for range 3 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
}
|
||||
|
||||
// Restore to snapshot state.
|
||||
if !c.Restore(snap, 6) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if c.Offset() != 6 {
|
||||
t.Fatalf("offset after restore = %d, want 6", c.Offset())
|
||||
}
|
||||
|
||||
// Feed one more token. If idx was restored correctly, this should
|
||||
// produce a valid window of size 4 at offset 7.
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(k, v)
|
||||
|
||||
if c.Offset() != 7 {
|
||||
t.Fatalf("offset after post-restore update = %d, want 7", c.Offset())
|
||||
}
|
||||
state := c.State()
|
||||
if len(state) != 2 {
|
||||
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
||||
}
|
||||
seqDim := state[0].Dim(2)
|
||||
if seqDim != 4 {
|
||||
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
|
||||
}
|
||||
}
|
||||
88
x/mlxrunner/cache/recurrent.go
vendored
88
x/mlxrunner/cache/recurrent.go
vendored
@@ -56,16 +56,6 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
|
||||
return detached
|
||||
}
|
||||
|
||||
func snapshotPinned(a *mlx.Array) *mlx.Array {
|
||||
if a == nil || !a.Valid() {
|
||||
return nil
|
||||
}
|
||||
snap := mlx.Copy(a)
|
||||
mlx.Eval(snap)
|
||||
mlx.Pin(snap)
|
||||
return snap
|
||||
}
|
||||
|
||||
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
|
||||
return &RecurrentCache{
|
||||
convTail: int(convTail),
|
||||
@@ -123,30 +113,69 @@ func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array
|
||||
return keys, values
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
|
||||
return c.convState, c.deltaState
|
||||
func (c *RecurrentCache) State() []*mlx.Array {
|
||||
return []*mlx.Array{c.convState, c.deltaState}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) CanTrim() bool { return false }
|
||||
|
||||
func (c *RecurrentCache) Trim(n int) int {
|
||||
// Recurrent state is not directly trimmable. Divergent prefixes must drop the cache.
|
||||
_ = n
|
||||
return 0
|
||||
// recurrentSnapshot holds paged-out recurrent state. Self-contained —
|
||||
// does not depend on any parent state.
|
||||
type recurrentSnapshot struct {
|
||||
convState, deltaState *mlx.Array
|
||||
offset int
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Clone() Cache {
|
||||
clone := &RecurrentCache{
|
||||
offset: c.offset,
|
||||
convTail: c.convTail,
|
||||
convDim: c.convDim,
|
||||
numVHeads: c.numVHeads,
|
||||
headVDim: c.headVDim,
|
||||
headKDim: c.headKDim,
|
||||
convState: snapshotPinned(c.convState),
|
||||
deltaState: snapshotPinned(c.deltaState),
|
||||
func (s *recurrentSnapshot) Size() int { return s.convState.NumBytes() + s.deltaState.NumBytes() }
|
||||
func (s *recurrentSnapshot) Close() { mlx.Unpin(s.convState, s.deltaState) }
|
||||
|
||||
func (c *RecurrentCache) Snapshot(fromOffset int) Snapshot {
|
||||
// Recurrent state is not position-sliceable — always snapshot the full state.
|
||||
if c.convState == nil && c.deltaState == nil {
|
||||
return nil
|
||||
}
|
||||
return clone
|
||||
|
||||
snap := &recurrentSnapshot{offset: c.offset}
|
||||
snap.convState = c.convState.Clone()
|
||||
snap.deltaState = c.deltaState.Clone()
|
||||
mlx.Pin(snap.convState, snap.deltaState)
|
||||
|
||||
return snap
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool {
|
||||
if snapshot == nil {
|
||||
// Recurrent state is cumulative and can't rewind. Only succeed
|
||||
// if we're already at the target (no-op).
|
||||
return target == c.offset
|
||||
}
|
||||
|
||||
snap := snapshot.(*recurrentSnapshot)
|
||||
|
||||
// Recurrent state encodes all tokens up to snap.offset. Restoring
|
||||
// to a target before that would leave stale state from tokens
|
||||
// [target, snap.offset) baked in. Only allow restoring forward.
|
||||
if target < snap.offset {
|
||||
return false
|
||||
}
|
||||
|
||||
c.convState = c.setStateRaw(c.convState, snap.convState)
|
||||
c.deltaState = c.setStateRaw(c.deltaState, snap.deltaState)
|
||||
c.offset = snap.offset
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Merge(parent, child Snapshot) Snapshot {
|
||||
// Recurrent snapshots are self-contained — child supersedes parent.
|
||||
if parent != nil {
|
||||
parent.Close()
|
||||
}
|
||||
return child
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
||||
// Recurrent state is cumulative and not position-sliceable.
|
||||
// Cannot recover intermediate state at the split point.
|
||||
return nil, snapshot
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Free() {
|
||||
@@ -156,4 +185,3 @@ func (c *RecurrentCache) Free() {
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Offset() int { return c.offset }
|
||||
func (c *RecurrentCache) Len() int { return c.offset }
|
||||
|
||||
44
x/mlxrunner/cache/recurrent_test.go
vendored
Normal file
44
x/mlxrunner/cache/recurrent_test.go
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// TestRecurrentCacheRestoreDirectionality verifies that RecurrentCache only
|
||||
// allows restoring forward (target >= snapshot offset), never backward.
|
||||
func TestRecurrentCacheRestoreDirectionality(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRecurrentCache(3, 12, 4, 8, 8)
|
||||
_ = c.ConvState(1, mlx.DTypeFloat16)
|
||||
_ = c.DeltaState(1, mlx.DTypeFloat16)
|
||||
c.Advance(10)
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
|
||||
c.Advance(5) // now at 15
|
||||
|
||||
// Restore backward should fail.
|
||||
if c.Restore(snap, 5) {
|
||||
t.Fatal("Restore(snap, 5) should fail — target < snap.offset")
|
||||
}
|
||||
|
||||
// Restore to exact snap offset should succeed.
|
||||
if !c.Restore(snap, 10) {
|
||||
t.Fatal("Restore(snap, 10) should succeed")
|
||||
}
|
||||
if c.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c.Offset())
|
||||
}
|
||||
|
||||
// Restore forward (target > snap offset) should succeed, offset = snap.offset.
|
||||
snap2 := c.Snapshot(0)
|
||||
if !c.Restore(snap2, 15) {
|
||||
t.Fatal("Restore(snap, 15) should succeed")
|
||||
}
|
||||
// Recurrent state is at snap.offset (10), not target (15).
|
||||
if c.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10 (snap offset)", c.Offset())
|
||||
}
|
||||
}
|
||||
859
x/mlxrunner/cache_test.go
Normal file
859
x/mlxrunner/cache_test.go
Normal file
@@ -0,0 +1,859 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// snapshotTracker records every fakeSnapshot created and every Close() call
// so tests can detect leaked (created but never closed) or double-closed snapshots.
type snapshotTracker struct {
	all []*fakeSnapshot // every snapshot ever created, in creation order
}

// track registers s with the tracker and back-links it so Close() calls can
// be counted later by checkSnapshotLeaks. nil snapshots are ignored.
func (tr *snapshotTracker) track(s *fakeSnapshot) {
	if s == nil {
		return
	}
	s.tracker = tr
	tr.all = append(tr.all, s)
}
|
||||
|
||||
// Fake caches that store actual token sequences so tests can verify the right
// data was restored, not just the right offset.

// fakeSnapshot stores a copy of the token sub-sequence it covers.
type fakeSnapshot struct {
	tokens   []int32 // copied tokens covered by this snapshot
	from, to int     // absolute token offsets covered: [from, to)
	byteSize int     // configurable for eviction tests

	tracker    *snapshotTracker // set by snapshotTracker.track
	closeCount int              // Close() call count, for leak/double-close checks
}

// Size reports the configured byte size used by eviction-policy tests.
func (s *fakeSnapshot) Size() int { return s.byteSize }

// Close only records the call; checkSnapshotLeaks inspects closeCount.
func (s *fakeSnapshot) Close() {
	s.closeCount++
}
|
||||
|
||||
// fakeRewindableCache tracks the full token sequence and supports
// arbitrary rewind via Restore(nil, target).
type fakeRewindableCache struct {
	tokens  []int32          // full token history, in order
	tracker *snapshotTracker // records snapshots for leak checking
}

// feed appends tokens, simulating evaluation into the cache.
func (c *fakeRewindableCache) feed(tokens []int32) {
	c.tokens = append(c.tokens, tokens...)
}

// Update is a no-op: the fake holds no real KV tensors.
func (c *fakeRewindableCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	return nil, nil
}
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }
func (c *fakeRewindableCache) Offset() int         { return len(c.tokens) }

func (c *fakeRewindableCache) Free() {
	c.tokens = nil
}

// Snapshot captures tokens from fromOffset to the current end; nil when
// there is nothing past fromOffset to capture. Negative offsets clamp to 0.
func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
	if fromOffset >= len(c.tokens) {
		return nil
	}
	from := fromOffset
	if from < 0 {
		from = 0
	}
	s := &fakeSnapshot{
		tokens: slices.Clone(c.tokens[from:]),
		from:   from,
		to:     len(c.tokens),
	}
	c.tracker.track(s)
	return s
}

// Restore rewinds live state (snapshot == nil) or splices a snapshot's
// tokens over the live suffix, then trims to target.
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
	if snapshot == nil {
		// Rewind live state.
		if target < 0 {
			target = 0
		}
		if target > len(c.tokens) {
			target = len(c.tokens)
		}
		c.tokens = c.tokens[:target]
		return true
	}
	s := snapshot.(*fakeSnapshot)
	if len(c.tokens) < s.from {
		return false // don't have base data up to snapshot start
	}
	c.tokens = append(c.tokens[:s.from], s.tokens...)
	if target < len(c.tokens) {
		c.tokens = c.tokens[:target]
	}
	return true
}

// Merge concatenates parent and child into one snapshot spanning both
// ranges, closing the inputs. If either side is nil, both are closed
// (where non-nil) and nil is returned.
func (c *fakeRewindableCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
	if parent == nil || child == nil {
		if parent != nil {
			parent.Close()
		}
		if child != nil {
			child.Close()
		}
		return nil
	}
	p := parent.(*fakeSnapshot)
	ch := child.(*fakeSnapshot)
	merged := make([]int32, len(p.tokens)+len(ch.tokens))
	copy(merged, p.tokens)
	copy(merged[len(p.tokens):], ch.tokens)
	s := &fakeSnapshot{
		tokens:   merged,
		from:     p.from,
		to:       ch.to,
		byteSize: p.byteSize + ch.byteSize,
	}
	c.tracker.track(s)
	p.Close()
	ch.Close()
	return s
}

// Split divides a snapshot at absolute offset at into (parent, child)
// halves, closing the original. Degenerate splits return the original on
// one side and nil on the other.
func (c *fakeRewindableCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
	if snapshot == nil {
		return nil, nil
	}
	s := snapshot.(*fakeSnapshot)
	relAt := at - s.from
	if relAt <= 0 {
		return nil, snapshot
	}
	if relAt >= len(s.tokens) {
		return snapshot, nil
	}
	p := &fakeSnapshot{
		tokens:   slices.Clone(s.tokens[:relAt]),
		from:     s.from,
		to:       at,
		byteSize: s.byteSize,
	}
	ch := &fakeSnapshot{
		tokens:   slices.Clone(s.tokens[relAt:]),
		from:     at,
		to:       s.to,
		byteSize: s.byteSize,
	}
	c.tracker.track(p)
	c.tracker.track(ch)
	s.Close()
	return p, ch
}
|
||||
|
||||
// fakeSlidingWindowCache models RotatingKVCache semantics: stores the full
// token sequence but only the trailing maxSize tokens are "live" in the window.
// Once the window fills, live rewind is impossible without a snapshot.
type fakeSlidingWindowCache struct {
	tokens  []int32 // full token log (window semantics enforced in Restore)
	maxSize int     // window capacity
	tracker *snapshotTracker
}

// feed appends tokens, simulating evaluation into the cache.
func (c *fakeSlidingWindowCache) feed(tokens []int32) {
	c.tokens = append(c.tokens, tokens...)
}

// Update is a no-op: the fake holds no real KV tensors.
func (c *fakeSlidingWindowCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	return nil, nil
}
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }
func (c *fakeSlidingWindowCache) Offset() int         { return len(c.tokens) }

func (c *fakeSlidingWindowCache) Free() {
	c.tokens = nil
}

// Snapshot captures the full window state (like RotatingKVCache.Snapshot),
// regardless of fromOffset; nil when nothing new exists past fromOffset.
func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
	if len(c.tokens) == 0 || len(c.tokens) <= fromOffset {
		return nil
	}
	s := &fakeSnapshot{
		tokens: slices.Clone(c.tokens),
		from:   0,
		to:     len(c.tokens),
	}
	c.tracker.track(s)
	return s
}

// Restore rewinds to target. Live rewind (snapshot == nil) is only possible
// while the ring buffer hasn't filled; a snapshot restore replaces the
// whole window and trims to target.
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
	if snapshot == nil {
		if target == len(c.tokens) {
			return true // no-op
		}
		// Live rewind only works when buffer hasn't filled (offset <= maxSize).
		if len(c.tokens) > c.maxSize {
			return false
		}
		c.tokens = c.tokens[:target]
		return true
	}
	s := snapshot.(*fakeSnapshot)
	c.tokens = slices.Clone(s.tokens)
	if target < len(c.tokens) {
		c.tokens = c.tokens[:target]
	}
	return true
}

// Merge keeps only the child: a sliding-window snapshot already contains
// the full window state, so the parent is redundant and is closed.
func (c *fakeSlidingWindowCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
	if parent != nil {
		parent.Close()
	}
	return child
}

// Split cannot divide a ring buffer at an arbitrary point; the whole
// snapshot stays on the child side.
func (c *fakeSlidingWindowCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
	return nil, snapshot
}
|
||||
|
||||
// fakeRecurrentCache models RecurrentCache semantics: stores tokens
// but cannot rewind without a snapshot.
type fakeRecurrentCache struct {
	tokens  []int32 // full token log (cumulative-state semantics in Restore)
	tracker *snapshotTracker
}

// feed appends tokens, simulating evaluation into the cache.
func (c *fakeRecurrentCache) feed(tokens []int32) {
	c.tokens = append(c.tokens, tokens...)
}

// Update is a no-op: the fake holds no real state tensors.
func (c *fakeRecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
	return nil, nil
}
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }
func (c *fakeRecurrentCache) Offset() int         { return len(c.tokens) }

func (c *fakeRecurrentCache) Free() {
	c.tokens = nil
}

// Snapshot captures the full cumulative state; fromOffset is ignored
// because recurrent state cannot be sliced. nil when the cache is empty.
func (c *fakeRecurrentCache) Snapshot(fromOffset int) cache.Snapshot {
	if len(c.tokens) == 0 {
		return nil
	}
	s := &fakeSnapshot{
		tokens: slices.Clone(c.tokens),
		from:   0,
		to:     len(c.tokens),
	}
	c.tracker.track(s)
	return s
}

// Restore mirrors RecurrentCache.Restore: forward-only, and after a
// snapshot restore the cache sits at the snapshot's end, not at target.
func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
	if snapshot == nil {
		return target == len(c.tokens) // can only no-op
	}
	s := snapshot.(*fakeSnapshot)
	if target < s.to {
		return false // can't go backward
	}
	c.tokens = slices.Clone(s.tokens)
	return true
}

// Merge keeps only the child: cumulative state supersedes the parent,
// which is closed.
func (c *fakeRecurrentCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
	if parent != nil {
		parent.Close()
	}
	return child
}

// Split cannot divide cumulative state; the whole snapshot stays on the
// child side.
func (c *fakeRecurrentCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
	return nil, snapshot
}
|
||||
|
||||
// feedableCache is a cache.Cache that can additionally be fed raw tokens,
// standing in for real token evaluation in these tests.
type feedableCache interface {
	cache.Cache
	feed(tokens []int32)
}
|
||||
|
||||
// testEnv encapsulates a kvCache and its fake caches for a test scenario.
type testEnv struct {
	kvc     *kvCache
	caches  []cache.Cache    // typed references for assertions
	tracker *snapshotTracker // shared tracker for snapshot leak checks
}
||||
|
||||
// newTransformerEnv creates a test environment with a single rewindable cache
|
||||
// (pure transformer model).
|
||||
func newTransformerEnv() *testEnv {
|
||||
tracker := &snapshotTracker{}
|
||||
caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}}
|
||||
return &testEnv{
|
||||
kvc: &kvCache{caches: caches},
|
||||
caches: caches,
|
||||
tracker: tracker,
|
||||
}
|
||||
}
|
||||
|
||||
// newSlidingWindowEnv creates a test environment with one rewindable cache and
|
||||
// one sliding window cache (Mistral-style architecture).
|
||||
func newSlidingWindowEnv() *testEnv {
|
||||
tr := &snapshotTracker{}
|
||||
rc := &fakeRewindableCache{tracker: tr}
|
||||
sw := &fakeSlidingWindowCache{maxSize: 32, tracker: tr}
|
||||
caches := []cache.Cache{rc, sw}
|
||||
return &testEnv{
|
||||
kvc: &kvCache{caches: caches},
|
||||
caches: caches,
|
||||
tracker: tr,
|
||||
}
|
||||
}
|
||||
|
||||
// newRecurrentEnv creates a test environment with one rewindable cache and one
|
||||
// non-rewindable cache (Jamba-style architecture).
|
||||
func newRecurrentEnv() *testEnv {
|
||||
tr := &snapshotTracker{}
|
||||
rc := &fakeRewindableCache{tracker: tr}
|
||||
nrc := &fakeRecurrentCache{tracker: tr}
|
||||
caches := []cache.Cache{rc, nrc}
|
||||
return &testEnv{
|
||||
kvc: &kvCache{caches: caches},
|
||||
caches: caches,
|
||||
tracker: tr,
|
||||
}
|
||||
}
|
||||
|
||||
// assertAllTokens checks that every cache in the environment contains exactly
|
||||
// the expected token sequence.
|
||||
func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32) {
|
||||
t.Helper()
|
||||
for i, c := range e.caches {
|
||||
assertTokens(t, label, c, expected)
|
||||
// Verify all caches report the same offset.
|
||||
if i > 0 && c.Offset() != e.caches[0].Offset() {
|
||||
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d",
|
||||
label, i, c.Offset(), e.caches[0].Offset())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// simulateRequest mirrors the production pipeline lifecycle:
// begin -> prefill with snapshot(false) at branch points -> generate -> close

// requestResult captures what a simulated request observed at begin time.
type requestResult struct {
	remaining      []int32 // tokens left to evaluate after the cache hit
	snapshotOffset int     // where the session was asked to snapshot, if anywhere
}
|
||||
|
||||
// simulateRequest runs a request through the harness. If userSnapshotAt > 0,
// a user snapshot (snapshot(true)) is created at that offset during prefill.
//
// The sequence intentionally mirrors production: begin, feed prompt tokens
// (pausing at each snapshot point), feed generated tokens, then close.
// Offsets of all caches are checked for alignment at each phase boundary.
func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult {
	t.Helper()

	userSnapAt := 0
	if len(userSnapshotAt) > 0 {
		userSnapAt = userSnapshotAt[0]
	}

	session := kvc.begin(nil, inputs)
	// Capture what begin decided before prefill mutates anything.
	result := requestResult{
		remaining:      slices.Clone(session.remaining),
		snapshotOffset: session.snapshotOffset,
	}

	assertCacheOffsetAlignment(t, kvc, "after begin")

	baseOffset := kvc.minCacheOffset()
	remaining := inputs[baseOffset:]

	// Collect snapshot points (offset -> user flag) in ascending order.
	type snapPoint struct {
		offset int
		user   bool
	}
	var points []snapPoint
	if session.snapshotOffset > 0 && session.snapshotOffset > baseOffset {
		points = append(points, snapPoint{session.snapshotOffset, false})
	}
	if userSnapAt > 0 && userSnapAt > baseOffset {
		points = append(points, snapPoint{userSnapAt, true})
	}
	slices.SortFunc(points, func(a, b snapPoint) int { return a.offset - b.offset })

	// Prefill: feed tokens, pausing at each snapshot point.
	for _, sp := range points {
		count := sp.offset - baseOffset
		if count > len(remaining) {
			break
		}
		if count > 0 {
			feedAll(kvc.caches, remaining[:count])
			remaining = remaining[count:]
			baseOffset = sp.offset
		}
		assertCacheOffsetAlignment(t, kvc, "at snapshot point")
		session.snapshot(sp.user)
	}

	// Feed rest of input tokens.
	if len(remaining) > 0 {
		feedAll(kvc.caches, remaining)
	}

	assertCacheOffsetAlignment(t, kvc, "after prefill")

	// Generate tokens.
	if len(generated) > 0 {
		session.outputs = generated
		feedAll(kvc.caches, generated)
	}

	assertCacheOffsetAlignment(t, kvc, "before close")
	session.close()
	return result
}
|
||||
|
||||
func feedAll(caches []cache.Cache, tokens []int32) {
|
||||
for _, c := range caches {
|
||||
if fc, ok := c.(feedableCache); ok {
|
||||
fc.feed(tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// assertCacheOffsetAlignment verifies all caches report the same offset.
|
||||
func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) {
|
||||
t.Helper()
|
||||
if len(kvc.caches) < 2 {
|
||||
return
|
||||
}
|
||||
expected := kvc.caches[0].Offset()
|
||||
for i := 1; i < len(kvc.caches); i++ {
|
||||
if got := kvc.caches[i].Offset(); got != expected {
|
||||
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, got, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// assertTokens checks that a feedable cache contains the expected token sequence.
|
||||
// For sliding window caches, only the trailing maxSize tokens are checked.
|
||||
func assertTokens(t *testing.T, label string, c cache.Cache, expected []int32) {
|
||||
t.Helper()
|
||||
switch fc := c.(type) {
|
||||
case *fakeRewindableCache:
|
||||
if !slices.Equal(fc.tokens, expected) {
|
||||
t.Errorf("%s: rewindable tokens = %v, want %v", label, fc.tokens, expected)
|
||||
}
|
||||
case *fakeSlidingWindowCache:
|
||||
// Sliding window stores full history but only trailing maxSize are live.
|
||||
// Verify the full token sequence matches (the window semantics are
|
||||
// enforced by Snapshot/Restore, not by the token log).
|
||||
if !slices.Equal(fc.tokens, expected) {
|
||||
t.Errorf("%s: sliding window tokens = %v, want %v", label, fc.tokens, expected)
|
||||
}
|
||||
case *fakeRecurrentCache:
|
||||
if !slices.Equal(fc.tokens, expected) {
|
||||
t.Errorf("%s: non-rewindable tokens = %v, want %v", label, fc.tokens, expected)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("%s: unknown cache type %T", label, c)
|
||||
}
|
||||
}
|
||||
|
||||
// checkTrieInvariants walks the trie and checks structural invariants:
// contiguous offsets between parent and child, token counts matching offset
// spans, correct parent back-links, and unique first tokens among siblings.
func checkTrieInvariants(t *testing.T, root *trieNode) {
	t.Helper()
	walkNodes(root, func(n *trieNode) bool {
		// A child's edge must begin exactly where its parent's ends.
		if n.parent != nil {
			if n.startOffset() != n.parent.endOffset {
				t.Errorf("node [%d,%d): startOffset %d != parent endOffset %d",
					n.startOffset(), n.endOffset, n.startOffset(), n.parent.endOffset)
			}
		}
		// The stored edge tokens must account for the full offset span.
		if len(n.tokens) != n.endOffset-n.startOffset() {
			t.Errorf("node [%d,%d): token count %d != offset span %d",
				n.startOffset(), n.endOffset, len(n.tokens), n.endOffset-n.startOffset())
		}
		for _, c := range n.children {
			if c.parent != n {
				t.Errorf("child [%d,%d) parent mismatch", c.startOffset(), c.endOffset)
			}
		}
		// No two siblings should start with the same token.
		seen := make(map[int32]bool)
		for _, c := range n.children {
			if len(c.tokens) > 0 {
				first := c.tokens[0]
				if seen[first] {
					t.Errorf("node [%d,%d): duplicate sibling first token %d",
						n.startOffset(), n.endOffset, first)
				}
				seen[first] = true
			}
		}
		return true
	})
}
|
||||
|
||||
// checkSnapshotLeaks verifies that every tracked snapshot is either still live
// in the trie (closeCount == 0) or has been closed exactly once. It reports
// leaked snapshots (not in trie, never closed) and double-closes.
func checkSnapshotLeaks(t *testing.T, tracker *snapshotTracker, root *trieNode) {
	t.Helper()
	if tracker == nil {
		return
	}

	// Collect all live snapshots still referenced by trie nodes.
	live := make(map[*fakeSnapshot]bool)
	walkNodes(root, func(n *trieNode) bool {
		for _, s := range n.snapshots {
			if s != nil {
				// Only fakeSnapshots are tracked; others are ignored here.
				if fs, ok := s.(*fakeSnapshot); ok {
					live[fs] = true
				}
			}
		}
		return true
	})

	for i, s := range tracker.all {
		if live[s] {
			// Still referenced: must never have been closed.
			if s.closeCount != 0 {
				t.Errorf("snapshot #%d [%d,%d) is still in trie but was closed %d time(s)",
					i, s.from, s.to, s.closeCount)
			}
		} else {
			// Dropped from the trie: must have been closed exactly once.
			if s.closeCount == 0 {
				t.Errorf("snapshot #%d [%d,%d) leaked: created but never closed and not in trie",
					i, s.from, s.to)
			} else if s.closeCount > 1 {
				t.Errorf("snapshot #%d [%d,%d) double-closed: closed %d times",
					i, s.from, s.to, s.closeCount)
			}
		}
	}
}
|
||||
|
||||
// forEachEnv runs fn as subtests for three realistic model configurations:
|
||||
// pure transformer, transformer + sliding window (Mistral-style), and
|
||||
// transformer + recurrent (Jamba-style). Leak checking runs automatically
|
||||
// at the end of each subtest.
|
||||
func forEachEnv(t *testing.T, fn func(t *testing.T, env *testEnv)) {
|
||||
t.Helper()
|
||||
run := func(t *testing.T, env *testEnv) {
|
||||
t.Cleanup(func() {
|
||||
checkSnapshotLeaks(t, env.tracker, env.kvc.root)
|
||||
})
|
||||
fn(t, env)
|
||||
}
|
||||
t.Run("Transformer", func(t *testing.T) { run(t, newTransformerEnv()) })
|
||||
t.Run("SlidingWindow", func(t *testing.T) { run(t, newSlidingWindowEnv()) })
|
||||
t.Run("Recurrent", func(t *testing.T) { run(t, newRecurrentEnv()) })
|
||||
}
|
||||
|
||||
// TestBranchCreationAndReuse exercises the core multi-conversation lifecycle:
// two conversations share a prefix and diverge, creating a branch point.
// A third conversation extends the first. Verifies trie structure, cache
// hit lengths, and that semantic caches contain the correct token sequences.
func TestBranchCreationAndReuse(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc

		// Request A: [1,2,3,4,5,6,7,8] + generate [20,21] — full miss.
		resA := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8}, []int32{20, 21})
		if len(resA.remaining) != 8 {
			t.Fatalf("A: remaining = %d, want 8 (full miss)", len(resA.remaining))
		}
		env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})

		// Verify trie was populated by close().
		_, mA := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
		if mA != 10 {
			t.Fatalf("A findable: expected 10 matched, got %d", mA)
		}

		// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
		// Partial match in A's edge triggers snapshotOffset.
		resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
		if resB.snapshotOffset != 5 {
			t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
		}
		// Cache was rewound to 0 (partial match truncates path to root),
		// so all tokens were re-evaluated.
		if len(resB.remaining) != 8 {
			t.Fatalf("B: remaining = %d, want 8", len(resB.remaining))
		}
		env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})

		// Both A and B should be findable in the trie.
		_, mA2 := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
		if mA2 < 5 {
			t.Fatalf("A still findable: expected >= 5 matched, got %d", mA2)
		}
		_, mB := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
		if mB < 5 {
			t.Fatalf("B findable: expected >= 5 matched, got %d", mB)
		}

		// Request C: [1,2,3,4,5,6,7,8,40,41] — extends A's prefix.
		// Should get a cache hit for the shared prefix.
		resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41}, nil)
		if len(resC.remaining) >= 10 {
			t.Fatalf("C: remaining = %d, want < 10 (should get cache hit)", len(resC.remaining))
		}
		env.assertAllTokens(t, "after C", []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41})

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||
|
||||
// TestExactMatchSeedBehavior verifies the holdback mechanism: when the exact
// same prompt is requested twice, the cache does not overclaim cached work.
// The last token must be re-evaluated to seed generation.
func TestExactMatchSeedBehavior(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc

		// Request A: first time.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})

		// Request B: identical prompt. Holdback means matched=4, partial in
		// the 5-token edge, so path truncates to root and all tokens are
		// re-evaluated. snapshotOffset should be set at the holdback point.
		resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
		if len(resB.remaining) != 5 {
			t.Fatalf("B: remaining = %d, want 5 (full re-eval due to holdback)", len(resB.remaining))
		}
		if resB.snapshotOffset != 4 {
			t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
		}
		env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||
|
||||
// TestConversationResumption tests the most common pattern: user sends a message,
// gets a response, then sends a follow-up. The follow-up should reuse the cached
// prefix (system prompt + first turn + assistant response).
func TestConversationResumption(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc

		// Turn 1: system prompt + user message, assistant generates response.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11, 12})
		env.assertAllTokens(t, "turn 1", []int32{1, 2, 3, 4, 5, 10, 11, 12})

		// Turn 2: full history + new user message. Should get a cache hit on
		// the prefix [1,2,3,4,5,10,11,12] and only need to evaluate [20,21].
		resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21}, []int32{30})
		if len(resB.remaining) > 5 {
			t.Fatalf("turn 2: remaining = %d, want <= 5 (should reuse most of history)", len(resB.remaining))
		}
		env.assertAllTokens(t, "turn 2", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30})

		// Turn 3: even longer history.
		resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41}, nil)
		if len(resC.remaining) > 5 {
			t.Fatalf("turn 3: remaining = %d, want <= 5", len(resC.remaining))
		}
		env.assertAllTokens(t, "turn 3", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41})

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||
|
||||
// TestEvictionPreservesActiveConversations creates multiple conversations sharing
// a system prompt, triggers eviction via large snapshot sizes, and verifies the
// active path and shared prefix survive while memory stays bounded.
func TestEvictionPreservesActiveConversations(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc
		systemPrompt := []int32{1, 2, 3, 4, 5}

		// Create 5 conversations with unique suffixes.
		for i := range 5 {
			suffix := []int32{int32(100 + i*10), int32(101 + i*10), int32(102 + i*10)}
			inputs := append(slices.Clone(systemPrompt), suffix...)
			simulateRequest(t, kvc, inputs, []int32{int32(200 + i)})
		}

		// Inflate snapshot sizes to trigger eviction. The replacements are
		// deliberately untracked: the leak checker must not flag snapshots
		// that the test itself swapped in.
		walkNodes(kvc.root, func(n *trieNode) bool {
			if !n.hasSnapshots() {
				return true
			}
			snaps := make([]cache.Snapshot, len(n.snapshots))
			for i, s := range n.snapshots {
				if s != nil {
					snaps[i] = &fakeSnapshot{byteSize: 2 * 1024 * 1024 * 1024} // 2 GiB per snapshot
				}
			}
			n.setSnapshots(snaps, &kvc.pagedOutBytes)
			return true
		})

		// Run eviction.
		kvc.enforceEvictionPolicy()

		// Memory should be within limits.
		if kvc.pagedOutBytes > maxPagedOutBytes {
			t.Fatalf("pagedOutBytes = %d, want <= %d", kvc.pagedOutBytes, maxPagedOutBytes)
		}

		// Active path should be untouched.
		if len(kvc.activePath) < 2 {
			t.Fatalf("activePath should have >= 2 nodes, got %d", len(kvc.activePath))
		}

		// System prompt prefix should still be findable (evicting a
		// multi-child branch point only drops snapshots, not the node).
		_, matched := findBestMatch(kvc.root, systemPrompt)
		if matched < len(systemPrompt) {
			t.Fatalf("system prompt match = %d, want %d", matched, len(systemPrompt))
		}

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||
|
||||
// TestUserSnapshotPreservesRestorePoint verifies that user-created snapshots
// (snapshot(true)) resist structural changes that would destroy them:
// - A user node forces new tokens into a child instead of extending in-place
// - The snapshot remains restorable after other branches are added
func TestUserSnapshotPreservesRestorePoint(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc

		// Request A: user snapshot at offset 5, then generate.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}, 5)

		assertUserNodeExists(t, kvc, "after A")

		// Request B: extends A's prefix. The user node at offset 5 should
		// force tokens into a child rather than extending in-place.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21}, nil)
		env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21})
		assertUserNodeExists(t, kvc, "after B")

		// Request C: diverge from the user node.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 30, 31}, []int32{40})

		// Request D: switch back to A's branch — user snapshot still restorable.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50}, nil)
		env.assertAllTokens(t, "back to A", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50})

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||
|
||||
// TestUserSnapshotResistsAutoMerge verifies that when a sibling leaf is evicted,
// a user-marked parent node is not auto-merged with its remaining single child.
func TestUserSnapshotResistsAutoMerge(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc

		// Request A: user snapshot at offset 3, then continue to offset 5.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10}, 3)

		// Request B: diverges at the user node, creating a second child.
		simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{20})

		userNode := findUserNode(t, kvc)
		if len(userNode.children) != 2 {
			t.Fatalf("user node children = %d, want 2", len(userNode.children))
		}

		// Inflate snapshot sizes and evict. The non-active branch should be
		// evicted, leaving the user node with one child. (Replacement
		// snapshots are intentionally untracked; see leak checker.)
		walkNodes(kvc.root, func(n *trieNode) bool {
			if !n.hasSnapshots() {
				return true
			}
			snaps := make([]cache.Snapshot, len(n.snapshots))
			for i, s := range n.snapshots {
				if s != nil {
					snaps[i] = &fakeSnapshot{byteSize: 5 * 1024 * 1024 * 1024}
				}
			}
			n.setSnapshots(snaps, &kvc.pagedOutBytes)
			return true
		})
		kvc.enforceEvictionPolicy()

		// The user node should still exist (not auto-merged) even with one child.
		assertUserNodeExists(t, kvc, "after eviction")

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||
|
||||
func findUserNode(t *testing.T, kvc *kvCache) *trieNode {
|
||||
t.Helper()
|
||||
var found *trieNode
|
||||
walkNodes(kvc.root, func(n *trieNode) bool {
|
||||
if n.user {
|
||||
found = n
|
||||
}
|
||||
return true
|
||||
})
|
||||
if found == nil {
|
||||
t.Fatal("no user-marked node found")
|
||||
}
|
||||
return found
|
||||
}
|
||||
|
||||
func assertUserNodeExists(t *testing.T, kvc *kvCache, label string) {
|
||||
t.Helper()
|
||||
var exists bool
|
||||
walkNodes(kvc.root, func(n *trieNode) bool {
|
||||
if n.user {
|
||||
exists = true
|
||||
}
|
||||
return true
|
||||
})
|
||||
if !exists {
|
||||
t.Fatalf("%s: no user-marked node found", label)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBranchSwitchRestoresCorrectState exercises switching back to an older
// branch after working on a different one, verifying that the restored cache
// state contains the correct token sequence for both rewindable and
// non-rewindable caches.
func TestBranchSwitchRestoresCorrectState(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc

		// Request A: [1,2,3,4,5] + generate [10,11]
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
		env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 10, 11})

		// Request B: [1,2,3,6,7] — diverges at token 4
		simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{12, 13})
		env.assertAllTokens(t, "after B", []int32{1, 2, 3, 6, 7, 12, 13})

		// Request C: switch back to A's branch [1,2,3,4,5,10,11,20]
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20}, nil)
		env.assertAllTokens(t, "after C (back to A)", []int32{1, 2, 3, 4, 5, 10, 11, 20})

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||
296
x/mlxrunner/cache_trie.go
Normal file
296
x/mlxrunner/cache_trie.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
)
|
||||
|
||||
// trieNode represents a node in the compressed prefix trie for KV cache branching.
// Each node stores a compressed edge (multiple tokens) and optional paged-out
// snapshot data per cache layer. Token offsets are cumulative from the root:
// endOffset is the total token count from the root through this node's edge.
type trieNode struct {
	tokens    []int32 // compressed edge — multiple tokens per node
	endOffset int     // cumulative tokens from root to end of this node
	parent    *trieNode
	children  []*trieNode
	lastUsed  time.Time        // for LRU eviction
	snapshots []cache.Snapshot // per-layer paged-out snapshot data (nil if not paged out)
	user      bool             // true = explicit restore point (resist auto-merge)
}
|
||||
|
||||
// startOffset returns the cumulative token offset at the start of this node's edge.
|
||||
func (n *trieNode) startOffset() int {
|
||||
return n.endOffset - len(n.tokens)
|
||||
}
|
||||
|
||||
// snapshotBytes returns the total bytes of paged-out snapshots on this node.
|
||||
func (n *trieNode) snapshotBytes() int64 {
|
||||
var total int64
|
||||
for _, s := range n.snapshots {
|
||||
if s != nil {
|
||||
total += int64(s.Size())
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// setSnapshots replaces this node's snapshots with snaps and closes the old ones.
|
||||
// If counter is non-nil, the net byte delta is applied to it.
|
||||
func (n *trieNode) setSnapshots(snaps []cache.Snapshot, counter *int64) {
|
||||
old := n.swapSnapshots(snaps, counter)
|
||||
for _, s := range old {
|
||||
if s != nil {
|
||||
s.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// swapSnapshots is like setSnapshots but returns the previous snapshots
|
||||
// without closing them. Use this when the old snapshots will be consumed
|
||||
// (e.g. by Split/Merge).
|
||||
func (n *trieNode) swapSnapshots(snaps []cache.Snapshot, counter *int64) []cache.Snapshot {
|
||||
old := n.snapshots
|
||||
if counter != nil {
|
||||
*counter -= n.snapshotBytes()
|
||||
}
|
||||
n.snapshots = snaps
|
||||
if counter != nil {
|
||||
*counter += n.snapshotBytes()
|
||||
}
|
||||
return old
|
||||
}
|
||||
|
||||
// hasSnapshots returns true if any layer has snapshot data.
|
||||
func (n *trieNode) hasSnapshots() bool {
|
||||
return slices.ContainsFunc(n.snapshots, func(s cache.Snapshot) bool { return s != nil })
|
||||
}
|
||||
|
||||
// hasAllSnapshots returns true if every layer has snapshot data.
|
||||
func (n *trieNode) hasAllSnapshots() bool {
|
||||
return len(n.snapshots) > 0 && !slices.Contains(n.snapshots, nil)
|
||||
}
|
||||
|
||||
// findBestMatch walks the trie matching input tokens, returning the path of
// nodes traversed and the total number of tokens matched.
//
// path always starts with root. The final node in path may be only partially
// matched (the match ended inside its edge); matched is the cumulative count
// of tokens consumed, so matched - lastNode.startOffset() gives the number of
// tokens matched inside the last node's edge.
func findBestMatch(root *trieNode, tokens []int32) (path []*trieNode, matched int) {
	if root == nil {
		return nil, 0
	}

	path = []*trieNode{root}
	pos := 0

	node := root
	for pos < len(tokens) {
		// When multiple children share the same first token (e.g. after
		// a split), prefer the child whose full edge matches over one
		// that only partially matches. This is just being defensive - it
		// shouldn't actually happen.
		var best *trieNode
		bestMatched := 0
		bestFull := false
		for _, child := range node.children {
			edge := child.tokens
			if len(edge) == 0 {
				continue
			}
			if edge[0] != tokens[pos] {
				continue
			}
			// Count matching tokens in this child's edge.
			j := 0
			for j < len(edge) && pos+j < len(tokens) && edge[j] == tokens[pos+j] {
				j++
			}
			full := j == len(edge)
			// Prefer full edge matches; among same type, prefer longer.
			if best == nil || (full && !bestFull) || (full == bestFull && j > bestMatched) {
				best = child
				bestMatched = j
				bestFull = full
			}
		}
		if best == nil {
			// No child starts with tokens[pos]; the match ends here.
			break
		}

		pos += bestMatched
		path = append(path, best)

		if !bestFull {
			// Partial match within this edge
			break
		}
		node = best
	}

	return path, pos
}
|
||||
|
||||
// appendTokens either creates a new child node or extends the leaf in place,
|
||||
// returning the node that now holds the tokens.
|
||||
func (n *trieNode) appendTokens(root *trieNode, tokens []int32, endOffset int) *trieNode {
|
||||
if n == root || len(n.children) > 0 || n.hasSnapshots() {
|
||||
child := &trieNode{
|
||||
tokens: make([]int32, len(tokens)),
|
||||
endOffset: endOffset,
|
||||
parent: n,
|
||||
lastUsed: n.lastUsed,
|
||||
}
|
||||
copy(child.tokens, tokens)
|
||||
n.children = append(n.children, child)
|
||||
return child
|
||||
}
|
||||
n.tokens = append(n.tokens, tokens...)
|
||||
n.endOffset = endOffset
|
||||
return n
|
||||
}
|
||||
|
||||
// removeNode removes a leaf node from the trie.
|
||||
func removeNode(node *trieNode, counter *int64) {
|
||||
if node.parent == nil {
|
||||
panic("removeNode called on root")
|
||||
}
|
||||
if len(node.children) != 0 {
|
||||
panic("removeNode called on non-leaf node")
|
||||
}
|
||||
p := node.parent
|
||||
for i, child := range p.children {
|
||||
if child == node {
|
||||
p.children = append(p.children[:i], p.children[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
node.parent = nil
|
||||
node.setSnapshots(nil, counter)
|
||||
}
|
||||
|
||||
// splitNode splits a node at the given token offset within its edge,
// creating a new parent node. Returns the new parent.
// `at` is relative to the node's edge (0-based index into node.tokens).
// If caches are provided, snapshots are split between parent and child
// using Cache.Split; otherwise snapshots are invalidated.
//
// NOTE(review): if node has snapshots and caches is nil, the loop below
// would panic on caches[i] — callers appear to always pass caches when
// snapshots are present; confirm that invariant at the call sites.
//
// The user flag is deliberately left on the original node: the new parent
// starts with user=false, so an explicit restore point keeps marking the
// full (suffix) position it was created at.
func splitNode(node *trieNode, at int, caches []cache.Cache, counter *int64) *trieNode {
	if at <= 0 || at >= len(node.tokens) {
		panic(fmt.Sprintf("splitNode: invalid split offset %d for node with %d tokens", at, len(node.tokens)))
	}

	// Create new parent with the prefix of the edge.
	newParent := &trieNode{
		tokens:    make([]int32, at),
		endOffset: node.startOffset() + at,
		parent:    node.parent,
		children:  []*trieNode{node},
		lastUsed:  node.lastUsed,
	}
	copy(newParent.tokens, node.tokens[:at])

	// Update the original node to have only the suffix.
	node.tokens = node.tokens[at:]
	// endOffset stays the same for the original node.

	// Split snapshots between parent and child using Cache.Split.
	// Split consumes the old snapshots, so we remove them first (adjusting
	// the counter), then assign the split halves (adjusting it back).
	if node.hasSnapshots() {
		oldSnaps := node.swapSnapshots(nil, counter)
		parentSnaps := make([]cache.Snapshot, len(oldSnaps))
		childSnaps := make([]cache.Snapshot, len(oldSnaps))
		for i, snap := range oldSnaps {
			if snap != nil {
				parentSnaps[i], childSnaps[i] = caches[i].Split(snap, newParent.endOffset)
			}
		}
		newParent.setSnapshots(parentSnaps, counter)
		node.setSnapshots(childSnaps, counter)
	}

	// Reparent: replace node with newParent in the old parent's children.
	if node.parent != nil {
		for i, child := range node.parent.children {
			if child == node {
				node.parent.children[i] = newParent
				break
			}
		}
	}
	node.parent = newParent

	return newParent
}
|
||||
|
||||
// mergeWithChild merges a node with its single child: concatenates tokens,
|
||||
// merges snapshot data via Cache.Merge, and removes the child.
|
||||
func mergeWithChild(node *trieNode, caches []cache.Cache, counter *int64) {
|
||||
if len(node.children) != 1 {
|
||||
panic(fmt.Sprintf("mergeWithChild called on node with %d children", len(node.children)))
|
||||
}
|
||||
|
||||
child := node.children[0]
|
||||
|
||||
// Concatenate tokens.
|
||||
node.tokens = append(node.tokens, child.tokens...)
|
||||
node.endOffset = child.endOffset
|
||||
|
||||
// Merge snapshots per layer. Merge consumes the old snapshots, so we
|
||||
// remove them first (adjusting the counter), then assign the merged
|
||||
// result (adjusting it back).
|
||||
if len(node.snapshots) > 0 || len(child.snapshots) > 0 {
|
||||
nodeSnaps := node.swapSnapshots(nil, counter)
|
||||
childSnaps := child.swapSnapshots(nil, counter)
|
||||
merged := make([]cache.Snapshot, len(caches))
|
||||
for i := range caches {
|
||||
var ps, cs cache.Snapshot
|
||||
if nodeSnaps != nil {
|
||||
ps = nodeSnaps[i]
|
||||
}
|
||||
if childSnaps != nil {
|
||||
cs = childSnaps[i]
|
||||
}
|
||||
|
||||
merged[i] = caches[i].Merge(ps, cs)
|
||||
}
|
||||
node.setSnapshots(merged, counter)
|
||||
}
|
||||
|
||||
// Adopt grandchildren.
|
||||
node.children = child.children
|
||||
for _, gc := range node.children {
|
||||
gc.parent = node
|
||||
}
|
||||
|
||||
// Inherit user flag from child if child was a user-created snapshot node.
|
||||
node.user = child.user
|
||||
|
||||
// Update lastUsed to the more recent of the two.
|
||||
if child.lastUsed.After(node.lastUsed) {
|
||||
node.lastUsed = child.lastUsed
|
||||
}
|
||||
|
||||
child.parent = nil
|
||||
child.children = nil
|
||||
}
|
||||
|
||||
// walkNodes calls fn for every node in the trie (depth-first).
|
||||
// If fn returns false, the walk stops.
|
||||
func walkNodes(root *trieNode, fn func(*trieNode) bool) {
|
||||
if root == nil {
|
||||
return
|
||||
}
|
||||
var walk func(*trieNode) bool
|
||||
walk = func(n *trieNode) bool {
|
||||
if !fn(n) {
|
||||
return false
|
||||
}
|
||||
for _, child := range n.children {
|
||||
if !walk(child) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
walk(root)
|
||||
}
|
||||
455
x/mlxrunner/cache_trie_test.go
Normal file
455
x/mlxrunner/cache_trie_test.go
Normal file
@@ -0,0 +1,455 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
)
|
||||
|
||||
func newTestTrie(tokens []int32) *trieNode {
|
||||
root := &trieNode{lastUsed: time.Now()}
|
||||
if len(tokens) > 0 {
|
||||
child := &trieNode{
|
||||
tokens: slices.Clone(tokens),
|
||||
endOffset: len(tokens),
|
||||
parent: root,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
root.children = []*trieNode{child}
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
// TestFindBestMatchMultipleBranches verifies that findBestMatch selects the
// sibling whose edge matches the input, and reports zero matches when no
// branch shares a first token.
func TestFindBestMatchMultipleBranches(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}

	branch1 := &trieNode{
		tokens:    []int32{1, 2, 3},
		endOffset: 3,
		parent:    root,
		lastUsed:  time.Now(),
	}
	branch2 := &trieNode{
		tokens:    []int32{4, 5, 6},
		endOffset: 3,
		parent:    root,
		lastUsed:  time.Now(),
	}
	root.children = []*trieNode{branch1, branch2}

	// Match branch 1.
	path, matched := findBestMatch(root, []int32{1, 2, 3, 7})
	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}
	if len(path) != 2 || path[1] != branch1 {
		t.Fatal("expected to match branch1")
	}

	// Match branch 2.
	path, matched = findBestMatch(root, []int32{4, 5, 6, 8})
	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}
	if len(path) != 2 || path[1] != branch2 {
		t.Fatal("expected to match branch2")
	}

	// Match neither.
	_, matched = findBestMatch(root, []int32{7, 8, 9})
	if matched != 0 {
		t.Fatalf("expected 0 matched, got %d", matched)
	}
}
|
||||
|
||||
// TestFindBestMatchPrefersFullEdge checks the tie-breaking rule in
// findBestMatch: when two siblings share the same first token, a child whose
// entire edge matches beats a longer child that matches only partially.
func TestFindBestMatchPrefersFullEdge(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}

	shared := &trieNode{
		tokens:    []int32{1, 2, 3},
		endOffset: 3,
		parent:    root,
		lastUsed:  time.Now(),
	}
	root.children = []*trieNode{shared}

	longer := &trieNode{
		tokens:    []int32{10, 11, 12, 13, 14},
		endOffset: 8,
		parent:    shared,
		lastUsed:  time.Now(),
	}
	shorter := &trieNode{
		tokens:    []int32{10, 11, 12},
		endOffset: 6,
		parent:    shared,
		lastUsed:  time.Now(),
	}
	// Put longer first so naive first-match would pick it.
	shared.children = []*trieNode{longer, shorter}

	input := []int32{1, 2, 3, 10, 11, 12, 99, 100}
	path, matched := findBestMatch(root, input)

	if matched != 6 {
		t.Fatalf("expected 6 matched, got %d", matched)
	}
	if len(path) != 3 {
		t.Fatalf("expected 3 nodes in path, got %d", len(path))
	}
	if path[2] != shorter {
		t.Fatal("expected findBestMatch to pick shorter (full edge match), not longer (partial)")
	}
}
|
||||
|
||||
// TestFindBestMatchPrefersLongerPartial checks that among children whose
// edges both match only partially, findBestMatch picks the one matching
// more tokens.
func TestFindBestMatchPrefersLongerPartial(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}

	child1 := &trieNode{
		tokens:    []int32{1, 2, 3, 4, 5},
		endOffset: 5,
		parent:    root,
		lastUsed:  time.Now(),
	}
	child2 := &trieNode{
		tokens:    []int32{1, 2, 9},
		endOffset: 3,
		parent:    root,
		lastUsed:  time.Now(),
	}
	// child2 listed first so position in the slice cannot decide the winner.
	root.children = []*trieNode{child2, child1}

	input := []int32{1, 2, 3, 7, 8}
	path, matched := findBestMatch(root, input)

	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}
	if path[1] != child1 {
		t.Fatal("expected findBestMatch to pick child1 (longer partial match)")
	}
}
|
||||
|
||||
// TestSplitNodeWithSnapshots verifies that splitNode distributes snapshot
// data to both halves and keeps the user flag on the original (suffix) node
// rather than the newly created parent.
func TestSplitNodeWithSnapshots(t *testing.T) {
	root := newTestTrie([]int32{1, 2, 3, 4, 5})
	child := root.children[0]

	rc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
	child.snapshots = []cache.Snapshot{rc.Snapshot(0)}
	child.user = true

	caches := []cache.Cache{rc}

	newParent := splitNode(child, 3, caches, nil)

	if !newParent.hasSnapshots() {
		t.Fatal("newParent should have snapshots after split")
	}
	if newParent.user {
		t.Fatal("newParent should not be a user snapshot after splitNode")
	}
	if !child.hasSnapshots() {
		t.Fatal("child should have snapshots after split")
	}
	if !child.user {
		t.Fatal("child should remain a user snapshot")
	}
}
|
||||
|
||||
// TestFindSplitAppendSequence exercises the standard branching sequence:
// find the best match, split the matched node at the divergence point, and
// append the new suffix — then verifies all three lookups resolve correctly.
func TestFindSplitAppendSequence(t *testing.T) {
	root := newTestTrie([]int32{1, 2, 3, 4, 5})

	path, matched := findBestMatch(root, []int32{1, 2, 3, 6, 7})
	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}

	// Split at the offset inside the last node's edge where matching stopped.
	lastNode := path[len(path)-1]
	matchedInEdge := matched - lastNode.startOffset()
	split := splitNode(lastNode, matchedInEdge, nil, nil)

	split.appendTokens(root, []int32{6, 7}, 5)

	if len(root.children) != 1 {
		t.Fatalf("root should have 1 child, got %d", len(root.children))
	}
	shared := root.children[0]
	if !slices.Equal(shared.tokens, []int32{1, 2, 3}) {
		t.Fatalf("shared tokens = %v, want [1,2,3]", shared.tokens)
	}
	if len(shared.children) != 2 {
		t.Fatalf("shared should have 2 children, got %d", len(shared.children))
	}

	_, m1 := findBestMatch(root, []int32{1, 2, 3, 4, 5})
	if m1 != 5 {
		t.Fatalf("original branch: expected 5 matched, got %d", m1)
	}
	_, m2 := findBestMatch(root, []int32{1, 2, 3, 6, 7})
	if m2 != 5 {
		t.Fatalf("new branch: expected 5 matched, got %d", m2)
	}
	_, m3 := findBestMatch(root, []int32{1, 2, 3, 9, 9})
	if m3 != 3 {
		t.Fatalf("unrelated input: expected 3 matched, got %d", m3)
	}
}
|
||||
|
||||
// TestRepeatedBranching splits the same trie twice at different depths and
// verifies that all three resulting branches remain fully matchable.
func TestRepeatedBranching(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}

	// Branch A: [1,2,3,4,5].
	root.appendTokens(root, []int32{1, 2, 3, 4, 5}, 5)

	// Branch B diverges after [1,2,3].
	_, matchedB := findBestMatch(root, []int32{1, 2, 3, 6, 7})
	if matchedB != 3 {
		t.Fatalf("B: expected 3 matched, got %d", matchedB)
	}
	nodeA := root.children[0]
	split1 := splitNode(nodeA, 3, nil, nil)
	split1.appendTokens(root, []int32{6, 7}, 5)

	// Branch C diverges earlier, after [1,2].
	_, matchedC := findBestMatch(root, []int32{1, 2, 8, 9})
	if matchedC != 2 {
		t.Fatalf("C: expected 2 matched, got %d", matchedC)
	}
	split2 := splitNode(split1, 2, nil, nil)
	split2.appendTokens(root, []int32{8, 9}, 4)

	_, mA := findBestMatch(root, []int32{1, 2, 3, 4, 5})
	if mA != 5 {
		t.Fatalf("A: expected 5 matched, got %d", mA)
	}
	_, mB := findBestMatch(root, []int32{1, 2, 3, 6, 7})
	if mB != 5 {
		t.Fatalf("B: expected 5 matched, got %d", mB)
	}
	_, mC := findBestMatch(root, []int32{1, 2, 8, 9})
	if mC != 4 {
		t.Fatalf("C: expected 4 matched, got %d", mC)
	}

	checkTrieInvariants(t, root)
}
|
||||
|
||||
// TestMergeWithChild covers mergeWithChild: token/snapshot concatenation and
// grandchild adoption (Basic), user-flag inheritance (UserFlag), lastUsed
// propagation (LastUsed), and the single-child precondition
// (PanicOnMultipleChildren).
func TestMergeWithChild(t *testing.T) {
	t.Run("Basic", func(t *testing.T) {
		// root -> A[1,2,3] -> B[4,5] -> {C[6], D[7]}
		now := time.Now()
		root := &trieNode{lastUsed: now}
		a := &trieNode{
			tokens:    []int32{1, 2, 3},
			endOffset: 3,
			parent:    root,
			lastUsed:  now,
			snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3}, from: 0, to: 3}},
		}
		b := &trieNode{
			tokens:    []int32{4, 5},
			endOffset: 5,
			parent:    a,
			lastUsed:  now,
			snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{4, 5}, from: 3, to: 5}},
		}
		c := &trieNode{tokens: []int32{6}, endOffset: 6, parent: b, lastUsed: now}
		d := &trieNode{tokens: []int32{7}, endOffset: 6, parent: b, lastUsed: now}
		root.children = []*trieNode{a}
		a.children = []*trieNode{b}
		b.children = []*trieNode{c, d}

		mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
		mergeWithChild(a, []cache.Cache{mc}, nil)

		// Tokens concatenated.
		if !slices.Equal(a.tokens, []int32{1, 2, 3, 4, 5}) {
			t.Fatalf("merged tokens = %v, want [1,2,3,4,5]", a.tokens)
		}
		if a.endOffset != 5 {
			t.Fatalf("merged endOffset = %d, want 5", a.endOffset)
		}
		// Grandchildren reparented.
		if len(a.children) != 2 {
			t.Fatalf("merged children count = %d, want 2", len(a.children))
		}
		if c.parent != a || d.parent != a {
			t.Fatal("grandchildren should be reparented to merged node")
		}
		// B detached.
		if b.parent != nil || b.children != nil || b.snapshots != nil {
			t.Fatal("child B should be fully detached after merge")
		}
		// Merged snapshot should cover [0,5).
		if !a.hasSnapshots() {
			t.Fatal("merged node should have snapshots")
		}
		ms := a.snapshots[0].(*fakeSnapshot)
		if ms.from != 0 || ms.to != 5 {
			t.Fatalf("merged snapshot = [%d,%d), want [0,5)", ms.from, ms.to)
		}

		checkTrieInvariants(t, root)
	})

	t.Run("UserFlag", func(t *testing.T) {
		root := &trieNode{lastUsed: time.Now()}
		parent := &trieNode{
			tokens: []int32{1, 2}, endOffset: 2, parent: root,
			lastUsed: time.Now(), user: false,
		}
		child := &trieNode{
			tokens: []int32{3, 4}, endOffset: 4, parent: parent,
			lastUsed: time.Now(), user: true,
		}
		root.children = []*trieNode{parent}
		parent.children = []*trieNode{child}

		mergeWithChild(parent, nil, nil)

		if !parent.user {
			t.Fatal("merged node should inherit user=true from child")
		}
	})

	t.Run("LastUsed", func(t *testing.T) {
		now := time.Now()
		root := &trieNode{lastUsed: now}
		parent := &trieNode{
			tokens: []int32{1}, endOffset: 1, parent: root,
			lastUsed: now.Add(-1 * time.Hour),
		}
		child := &trieNode{
			tokens: []int32{2}, endOffset: 2, parent: parent,
			lastUsed: now.Add(1 * time.Hour),
		}
		root.children = []*trieNode{parent}
		parent.children = []*trieNode{child}

		mergeWithChild(parent, nil, nil)

		if !parent.lastUsed.Equal(now.Add(1 * time.Hour)) {
			t.Fatal("merged node should pick the more recent lastUsed")
		}
	})

	t.Run("PanicOnMultipleChildren", func(t *testing.T) {
		defer func() {
			if r := recover(); r == nil {
				t.Fatal("expected panic on node with 2 children")
			}
		}()
		root := &trieNode{lastUsed: time.Now()}
		node := &trieNode{
			tokens: []int32{1}, endOffset: 1, parent: root, lastUsed: time.Now(),
			children: []*trieNode{
				{tokens: []int32{2}, endOffset: 2, lastUsed: time.Now()},
				{tokens: []int32{3}, endOffset: 2, lastUsed: time.Now()},
			},
		}
		root.children = []*trieNode{node}
		mergeWithChild(node, nil, nil)
	})
}
|
||||
|
||||
// TestSplitMergeRoundTrip verifies that splitting a node and then merging
// the halves back restores the original tokens, offsets, and snapshot span.
func TestSplitMergeRoundTrip(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}
	leaf := &trieNode{
		tokens:    []int32{1, 2, 3, 4, 5},
		endOffset: 5,
		parent:    root,
		lastUsed:  time.Now(),
		snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3, 4, 5}, from: 0, to: 5}},
	}
	root.children = []*trieNode{leaf}

	mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
	caches := []cache.Cache{mc}

	// Split at 3: [1,2,3] -> [4,5]
	newParent := splitNode(leaf, 3, caches, nil)
	if !slices.Equal(newParent.tokens, []int32{1, 2, 3}) {
		t.Fatalf("after split: parent tokens = %v, want [1,2,3]", newParent.tokens)
	}
	if !slices.Equal(leaf.tokens, []int32{4, 5}) {
		t.Fatalf("after split: child tokens = %v, want [4,5]", leaf.tokens)
	}
	checkTrieInvariants(t, root)

	// Merge back: should restore [1,2,3,4,5]
	mergeWithChild(newParent, caches, nil)
	if !slices.Equal(newParent.tokens, []int32{1, 2, 3, 4, 5}) {
		t.Fatalf("after merge: tokens = %v, want [1,2,3,4,5]", newParent.tokens)
	}
	if newParent.endOffset != 5 {
		t.Fatalf("after merge: endOffset = %d, want 5", newParent.endOffset)
	}
	if len(newParent.children) != 0 {
		t.Fatalf("after merge: children count = %d, want 0", len(newParent.children))
	}
	// Merged snapshot should cover [0,5).
	if !newParent.hasSnapshots() {
		t.Fatal("after merge: should have snapshots")
	}
	ms := newParent.snapshots[0].(*fakeSnapshot)
	if ms.from != 0 || ms.to != 5 {
		t.Fatalf("after merge: snapshot = [%d,%d), want [0,5)", ms.from, ms.to)
	}

	checkTrieInvariants(t, root)
}
|
||||
|
||||
// TestRemoveNode covers leaf removal (sibling list updated, node detached,
// snapshots released) and the two panic preconditions: removing the root
// and removing a node that still has children.
func TestRemoveNode(t *testing.T) {
	t.Run("Leaf", func(t *testing.T) {
		root := &trieNode{lastUsed: time.Now()}
		shared := &trieNode{
			tokens: []int32{1, 2, 3}, endOffset: 3, parent: root, lastUsed: time.Now(),
		}
		leafA := &trieNode{
			tokens: []int32{4, 5}, endOffset: 5, parent: shared, lastUsed: time.Now(),
			snapshots: []cache.Snapshot{&fakeSnapshot{from: 3, to: 5}},
		}
		leafB := &trieNode{
			tokens: []int32{6, 7}, endOffset: 5, parent: shared, lastUsed: time.Now(),
			snapshots: []cache.Snapshot{&fakeSnapshot{from: 3, to: 5}},
		}
		root.children = []*trieNode{shared}
		shared.children = []*trieNode{leafA, leafB}

		removeNode(leafA, nil)

		if len(shared.children) != 1 {
			t.Fatalf("parent should have 1 child, got %d", len(shared.children))
		}
		if shared.children[0] != leafB {
			t.Fatal("remaining child should be leafB")
		}
		if leafA.parent != nil {
			t.Fatal("removed node parent should be nil")
		}
		if leafA.snapshots != nil {
			t.Fatal("removed node snapshots should be nil")
		}

		checkTrieInvariants(t, root)
	})

	t.Run("PanicOnRoot", func(t *testing.T) {
		defer func() {
			if r := recover(); r == nil {
				t.Fatal("expected panic when removing root")
			}
		}()
		removeNode(&trieNode{}, nil)
	})

	t.Run("PanicOnNonLeaf", func(t *testing.T) {
		defer func() {
			if r := recover(); r == nil {
				t.Fatal("expected panic when removing non-leaf")
			}
		}()
		parent := &trieNode{parent: &trieNode{}}
		parent.children = []*trieNode{{}}
		removeNode(parent, nil)
	})
}
|
||||
@@ -106,6 +106,7 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
||||
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
// Images and Options carry omitempty, so zero-value fields are dropped from
// the wire payload.
type completionRequest struct {
	Prompt  string          `json:"prompt"`
	Images  []llm.ImageData `json:"images,omitempty"`
	Options *completionOpts `json:"options,omitempty"`
}
|
||||
|
||||
@@ -155,6 +156,7 @@ func (c *Client) Close() error {
|
||||
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
creq := completionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Images: req.Images,
|
||||
}
|
||||
if req.Options != nil {
|
||||
creq.Options = &completionOpts{
|
||||
|
||||
@@ -18,6 +18,10 @@ func Version() string {
|
||||
}
|
||||
|
||||
func doEval(outputs []*Array, async bool) {
|
||||
if len(outputs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
vector := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
|
||||
@@ -304,6 +304,18 @@ func Exp(a *Array) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
// Sin returns a new array produced by mlx_sin applied to a (element-wise
// sine), evaluated on the default stream.
func Sin(a *Array) *Array {
	out := New("SIN")
	C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}
|
||||
|
||||
// Cos returns a new array produced by mlx_cos applied to a (element-wise
// cosine), evaluated on the default stream.
func Cos(a *Array) *Array {
	out := New("COS")
	C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}
|
||||
|
||||
func Log(a *Array) *Array {
|
||||
out := New("LOG")
|
||||
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
|
||||
@@ -4,10 +4,14 @@ package mlx
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// End is a sentinel value meaning "to the end of the dimension",
|
||||
// equivalent to an omitted stop in Python (e.g. a[i:]).
|
||||
const End = math.MaxInt32
|
||||
|
||||
type slice struct {
|
||||
args []int
|
||||
}
|
||||
@@ -16,6 +20,16 @@ func Slice(args ...int) slice {
|
||||
return slice{args: args}
|
||||
}
|
||||
|
||||
func resolve(val, dim int) C.int {
|
||||
if val == End {
|
||||
return C.int(dim)
|
||||
}
|
||||
if val < 0 {
|
||||
return C.int(dim + val)
|
||||
}
|
||||
return C.int(val)
|
||||
}
|
||||
|
||||
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
|
||||
if len(slices) != len(dims) {
|
||||
panic("number of slice arguments must match number of tensor dimensions")
|
||||
@@ -28,26 +42,28 @@ func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
|
||||
}
|
||||
|
||||
for i, s := range slices {
|
||||
dim := dims[i]
|
||||
switch len(s.args) {
|
||||
case 0:
|
||||
// slice[:]
|
||||
args[0][i] = C.int(0)
|
||||
args[1][i] = C.int(dims[i])
|
||||
args[1][i] = C.int(dim)
|
||||
args[2][i] = C.int(1)
|
||||
case 1:
|
||||
// slice[i]
|
||||
args[0][i] = C.int(s.args[0])
|
||||
args[1][i] = C.int(s.args[0] + 1)
|
||||
start := resolve(s.args[0], dim)
|
||||
args[0][i] = start
|
||||
args[1][i] = start + 1
|
||||
args[2][i] = C.int(1)
|
||||
case 2:
|
||||
// slice[i:j]
|
||||
args[0][i] = C.int(s.args[0])
|
||||
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
|
||||
args[0][i] = resolve(s.args[0], dim)
|
||||
args[1][i] = resolve(s.args[1], dim)
|
||||
args[2][i] = C.int(1)
|
||||
case 3:
|
||||
// slice[i:j:k]
|
||||
args[0][i] = C.int(s.args[0])
|
||||
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
|
||||
args[0][i] = resolve(s.args[0], dim)
|
||||
args[1][i] = resolve(s.args[1], dim)
|
||||
args[2][i] = C.int(s.args[2])
|
||||
default:
|
||||
panic("invalid slice arguments")
|
||||
|
||||
32
x/mlxrunner/model/base/multimodal.go
Normal file
32
x/mlxrunner/model/base/multimodal.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// ImageInput is a single image attached to a prompt.
type ImageInput struct {
	ID   int    // request-scoped image identifier
	Data []byte // raw image bytes as received in the request
}

// PromptTokenization contains tokenized prompt IDs plus optional request-scoped
// model metadata needed during forward.
type PromptTokenization struct {
	Tokens []int32 // expanded prompt token IDs
	State  any     // opaque model-specific state, later passed to ForwardWithState
}

// MultimodalPromptTokenizerWithState is an optional model interface used by
// mlxrunner to expand tagged multimodal prompts into token IDs, returning
// request-scoped state to be attached to the forward pass.
type MultimodalPromptTokenizerWithState interface {
	TokenizePromptWithImagesState(prompt string, images []ImageInput) (*PromptTokenization, error)
}

// ForwardWithStateModel is an optional model interface for request-scoped
// forward metadata that should not be stored in shared caches.
type ForwardWithStateModel interface {
	ForwardWithState(inputs *mlx.Array, cache []cache.Cache, state any) *mlx.Array
}
|
||||
@@ -12,12 +12,42 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
)
|
||||
|
||||
// prefillChunkSize returns the number of prompt tokens processed per
// forward pass during prefill.
func prefillChunkSize() int {
	const chunk = 1 << 11 // 2048 tokens per prefill batch
	return chunk
}
|
||||
|
||||
// tokenizeRequest converts a request's prompt (and any attached images) into
// model input token IDs, plus optional request-scoped state to thread into
// the forward pass. For multimodal requests the shared cache is cleared
// first, and the model's multimodal tokenizer (if implemented) is used;
// otherwise the plain text tokenizer is used with nil state.
func (r *Runner) tokenizeRequest(request Request) ([]int32, any, error) {
	if len(request.Images) > 0 {
		// The shared trie cache keys only on token IDs today, so multimodal
		// prompts must not reuse snapshots across distinct image inputs.
		r.cache.clear()
	}

	if multimodalTokenizer, ok := r.Model.(base.MultimodalPromptTokenizerWithState); ok && len(request.Images) > 0 {
		images := make([]base.ImageInput, len(request.Images))
		for i := range request.Images {
			images[i] = base.ImageInput{
				ID:   request.Images[i].ID,
				Data: request.Images[i].Data,
			}
		}

		out, err := multimodalTokenizer.TokenizePromptWithImagesState(request.Prompt, images)
		if err != nil {
			return nil, nil, err
		}
		if out == nil {
			return nil, nil, errors.New("empty multimodal tokenization result")
		}
		return out.Tokens, out.State, nil
	}

	// Text-only path (or a model without multimodal support). NOTE(review):
	// the `true` argument presumably enables special-token handling —
	// confirm against Tokenizer.Encode.
	return r.Tokenizer.Encode(request.Prompt, true), nil, nil
}
|
||||
|
||||
func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
if r.Model == nil {
|
||||
return errors.New("model not loaded")
|
||||
@@ -50,12 +80,15 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
|
||||
mlx.LogArrays()
|
||||
r.cache.log()
|
||||
r.cache.dumpTree()
|
||||
}
|
||||
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
|
||||
}()
|
||||
|
||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||
inputs, promptState, err := r.tokenizeRequest(request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(inputs) == 0 {
|
||||
return errors.New("empty prompt")
|
||||
}
|
||||
@@ -83,10 +116,17 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
tokens := session.remaining
|
||||
prefillChunk := prefillChunkSize()
|
||||
|
||||
modelForward := func(tokens *mlx.Array) *mlx.Array {
|
||||
if withState, ok := r.Model.(base.ForwardWithStateModel); ok {
|
||||
return withState.ForwardWithState(tokens, caches, promptState)
|
||||
}
|
||||
return r.Model.Forward(tokens, caches)
|
||||
}
|
||||
|
||||
materializeCaches := func() {
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, c := range caches {
|
||||
state = appendCacheState(state, c)
|
||||
state = append(state, c.State()...)
|
||||
}
|
||||
if len(state) == 0 {
|
||||
return
|
||||
@@ -102,16 +142,37 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
|
||||
n := min(prefillChunk, total-processed-1)
|
||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
||||
|
||||
// If there's a pending intermediate snapshot, split the batch
|
||||
// so we can capture it at the exact offset. The cache offset
|
||||
// after this batch will be: baseOffset + processed + n.
|
||||
if session.snapshotOffset > 0 {
|
||||
baseOffset := len(session.inputs) - len(tokens)
|
||||
tokensUntilSnapshot := session.snapshotOffset - (baseOffset + processed)
|
||||
if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
|
||||
n = tokensUntilSnapshot
|
||||
}
|
||||
}
|
||||
|
||||
modelForward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0))
|
||||
mlx.Sweep()
|
||||
materializeCaches()
|
||||
processed += n
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
|
||||
// Create snapshot at branch point for future diverging requests.
|
||||
if session.snapshotOffset > 0 {
|
||||
baseOffset := len(session.inputs) - len(tokens)
|
||||
if baseOffset+processed >= session.snapshotOffset {
|
||||
session.snapshot(false)
|
||||
}
|
||||
}
|
||||
|
||||
mlx.ClearCache()
|
||||
}
|
||||
|
||||
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
||||
fwd := modelForward(token.ExpandDims(0))
|
||||
logits := r.Model.Unembed(fwd)
|
||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
@@ -29,7 +30,8 @@ type Request struct {
|
||||
}
|
||||
|
||||
type TextCompletionsRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Prompt string `json:"prompt"`
|
||||
Images []llm.ImageData `json:"images,omitempty"`
|
||||
Options struct {
|
||||
Temperature float32 `json:"temperature"`
|
||||
TopP float32 `json:"top_p"`
|
||||
|
||||
@@ -169,7 +169,7 @@ func topK(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
||||
return logprobs
|
||||
}
|
||||
|
||||
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, 0))
|
||||
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
||||
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
||||
}
|
||||
|
||||
|
||||
354
x/models/qwen3_5/multimodal.go
Normal file
354
x/models/qwen3_5/multimodal.go
Normal file
@@ -0,0 +1,354 @@
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
)
|
||||
|
||||
var imageTagRE = regexp.MustCompile(`\[img-(\d+)\]`)
|
||||
|
||||
type promptVisionSpan struct {
|
||||
Start int32
|
||||
End int32
|
||||
|
||||
Main *mlx.Array
|
||||
Grid *VisionGrid
|
||||
}
|
||||
|
||||
type promptVisionState struct {
|
||||
Spans []promptVisionSpan
|
||||
PositionCache []int32
|
||||
}
|
||||
|
||||
func promptStartPosFromCaches(caches []cache.Cache) int32 {
|
||||
offset := -1
|
||||
for _, c := range caches {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
off := c.Offset()
|
||||
if offset < 0 || off < offset {
|
||||
offset = off
|
||||
}
|
||||
}
|
||||
if offset < 0 {
|
||||
return 0
|
||||
}
|
||||
return int32(offset)
|
||||
}
|
||||
|
||||
func promptVisionStateFromState(state any) *promptVisionState {
|
||||
typed, _ := state.(*promptVisionState)
|
||||
return typed
|
||||
}
|
||||
|
||||
func overlapRange(chunkStart, chunkLen, spanStart, spanEnd int32) (int32, int32, int32, int32, bool) {
|
||||
chunkEnd := chunkStart + chunkLen
|
||||
overlapStart := max(chunkStart, spanStart)
|
||||
overlapEnd := min(chunkEnd, spanEnd)
|
||||
if overlapStart >= overlapEnd {
|
||||
return 0, 0, 0, 0, false
|
||||
}
|
||||
|
||||
chunkLo := overlapStart - chunkStart
|
||||
chunkHi := overlapEnd - chunkStart
|
||||
spanLo := overlapStart - spanStart
|
||||
spanHi := overlapEnd - spanStart
|
||||
return chunkLo, chunkHi, spanLo, spanHi, true
|
||||
}
|
||||
|
||||
func (m *Model) applyPromptVisionEmbeddings(h *mlx.Array, startPos int32, state *promptVisionState) *mlx.Array {
|
||||
if m == nil || h == nil || state == nil || len(state.Spans) == 0 {
|
||||
return h
|
||||
}
|
||||
|
||||
dims := h.Dims()
|
||||
if len(dims) != 3 {
|
||||
return h
|
||||
}
|
||||
|
||||
L := int32(dims[1])
|
||||
for _, span := range state.Spans {
|
||||
chunkLo, chunkHi, spanLo, spanHi, ok := overlapRange(startPos, L, span.Start, span.End)
|
||||
if !ok || span.Main == nil || !span.Main.Valid() {
|
||||
continue
|
||||
}
|
||||
|
||||
repl := span.Main.Slice(
|
||||
mlx.Slice(),
|
||||
mlx.Slice(int(spanLo), int(spanHi)),
|
||||
mlx.Slice(),
|
||||
)
|
||||
repl = repl.AsType(h.DType())
|
||||
h = h.SliceUpdate(
|
||||
repl,
|
||||
mlx.Slice(),
|
||||
mlx.Slice(int(chunkLo), int(chunkHi)),
|
||||
mlx.Slice(),
|
||||
)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func findImageByID(images []base.ImageInput, id int) (base.ImageInput, bool) {
|
||||
for i := range images {
|
||||
if images[i].ID == id {
|
||||
return images[i], true
|
||||
}
|
||||
}
|
||||
return base.ImageInput{}, false
|
||||
}
|
||||
|
||||
func mapPromptPosition(state *promptVisionState, id int32) int32 {
|
||||
if state == nil {
|
||||
return id
|
||||
}
|
||||
if id < int32(len(state.PositionCache)) {
|
||||
return state.PositionCache[id]
|
||||
}
|
||||
if len(state.PositionCache) > 0 {
|
||||
return id - int32(len(state.PositionCache)) + state.PositionCache[len(state.PositionCache)-1] + 1
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func promptVisionGridSpan(grid *VisionGrid, merge int32, fallback int32) int32 {
|
||||
if fallback <= 0 {
|
||||
fallback = 1
|
||||
}
|
||||
if grid == nil {
|
||||
return fallback
|
||||
}
|
||||
if merge <= 0 {
|
||||
merge = 1
|
||||
}
|
||||
return max(max(int32(1), grid.Width/merge), max(int32(1), grid.Height/merge))
|
||||
}
|
||||
|
||||
func normalizeMRoPESections(sections []int32) [4]int32 {
|
||||
var out [4]int32
|
||||
for i := range min(4, len(sections)) {
|
||||
if sections[i] > 0 {
|
||||
out[i] = sections[i]
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mropePairComponent(pair int32, sections [4]int32, interleaved bool) int {
|
||||
if interleaved {
|
||||
if pair%3 == 1 && pair < 1+3*sections[1] {
|
||||
return 1
|
||||
}
|
||||
if pair%3 == 2 && pair < 2+3*sections[2] {
|
||||
return 2
|
||||
}
|
||||
if pair%3 == 0 && pair < 3*sections[0] {
|
||||
return 0
|
||||
}
|
||||
return 3
|
||||
}
|
||||
|
||||
secW := sections[0] + sections[1]
|
||||
secE := secW + sections[2]
|
||||
switch {
|
||||
case pair < sections[0]:
|
||||
return 0
|
||||
case pair < secW:
|
||||
return 1
|
||||
case pair < secE:
|
||||
return 2
|
||||
default:
|
||||
return 3
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) buildPromptMRoPEPositions(state *promptVisionState, startPos, chunkLen int32) [4][]int32 {
|
||||
var positions [4][]int32
|
||||
for i := range positions {
|
||||
positions[i] = make([]int32, chunkLen)
|
||||
}
|
||||
|
||||
// positions[3] stays zero — it covers RoPE dims beyond the 3 MRoPE sections.
|
||||
for i := range chunkLen {
|
||||
p := mapPromptPosition(state, startPos+i)
|
||||
positions[0][i] = p
|
||||
positions[1][i] = p
|
||||
positions[2][i] = p
|
||||
}
|
||||
|
||||
merge := int32(1)
|
||||
if m != nil && m.Config != nil && m.Config.Vision != nil {
|
||||
merge = m.Config.Vision.SpatialMergeSize
|
||||
}
|
||||
for _, span := range state.Spans {
|
||||
if span.Grid == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
chunkLo, chunkHi, spanLo, _, ok := overlapRange(startPos, chunkLen, span.Start, span.End)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
w := max(int32(1), span.Grid.Width/merge)
|
||||
for i := chunkLo; i < chunkHi; i++ {
|
||||
rel := spanLo + (i - chunkLo)
|
||||
positions[1][i] += rel / w
|
||||
positions[2][i] += rel % w
|
||||
}
|
||||
}
|
||||
|
||||
return positions
|
||||
}
|
||||
|
||||
func (m *Model) buildPromptMRoPECosSin(state *promptVisionState, startPos, chunkLen int32, dtype mlx.DType) (*mlx.Array, *mlx.Array) {
|
||||
if m == nil || m.Config == nil || state == nil || chunkLen <= 0 || len(m.Config.MRoPESections) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ropeDim := m.Config.RopeDim
|
||||
if ropeDim%2 != 0 {
|
||||
ropeDim--
|
||||
}
|
||||
if ropeDim <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
half := ropeDim / 2
|
||||
positions := m.buildPromptMRoPEPositions(state, startPos, chunkLen)
|
||||
sections := normalizeMRoPESections(m.Config.MRoPESections)
|
||||
theta := m.Config.RopeTheta
|
||||
if theta <= 0 {
|
||||
theta = 100000.0
|
||||
}
|
||||
|
||||
freqs := make([]float64, half)
|
||||
for j := range half {
|
||||
freqs[j] = math.Pow(float64(theta), -2.0*float64(j)/float64(ropeDim))
|
||||
}
|
||||
|
||||
angles := make([]float32, chunkLen*ropeDim)
|
||||
for i := range chunkLen {
|
||||
base := i * ropeDim
|
||||
for j := range half {
|
||||
component := mropePairComponent(j, sections, m.Config.MRoPEInterleaved)
|
||||
angle := float32(float64(positions[component][i]) * freqs[j])
|
||||
angles[base+j] = angle
|
||||
angles[base+half+j] = angle
|
||||
}
|
||||
}
|
||||
|
||||
arr := mlx.FromValues(angles, 1, 1, int(chunkLen), int(ropeDim))
|
||||
cos := mlx.Cos(arr)
|
||||
sin := mlx.Sin(arr)
|
||||
if dtype != 0 {
|
||||
cos = cos.AsType(dtype)
|
||||
sin = sin.AsType(dtype)
|
||||
}
|
||||
return cos, sin
|
||||
}
|
||||
|
||||
func (m *Model) tokenizePromptWithResolvedImages(
|
||||
prompt string,
|
||||
images []base.ImageInput,
|
||||
resolve func([]byte) (*VisionEmbeddings, error),
|
||||
) ([]int32, *promptVisionState, error) {
|
||||
if m == nil || m.tok == nil {
|
||||
return nil, nil, fmt.Errorf("qwen3_5: tokenizer not initialized")
|
||||
}
|
||||
|
||||
if m.Vision == nil || m.ImageProcessor == nil || resolve == nil {
|
||||
return m.tok.Encode(prompt, true), nil, nil
|
||||
}
|
||||
|
||||
parts := imageTagRE.Split(prompt, -1)
|
||||
matches := imageTagRE.FindAllStringSubmatch(prompt, -1)
|
||||
|
||||
resolved := make(map[int]*VisionEmbeddings, len(images))
|
||||
var out []int32
|
||||
state := &promptVisionState{}
|
||||
var p int32
|
||||
appendToken := func(tok, pos int32) {
|
||||
out = append(out, tok)
|
||||
state.PositionCache = append(state.PositionCache, pos)
|
||||
}
|
||||
for i, part := range parts {
|
||||
for _, tok := range m.tok.Encode(part, i == 0) {
|
||||
appendToken(tok, p)
|
||||
p++
|
||||
}
|
||||
|
||||
if i >= len(matches) {
|
||||
continue
|
||||
}
|
||||
|
||||
imageID, err := strconv.Atoi(matches[i][1])
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("qwen3_5: invalid image tag %q: %w", matches[i][0], err)
|
||||
}
|
||||
|
||||
img, ok := findImageByID(images, imageID)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("invalid image index: %d", imageID)
|
||||
}
|
||||
|
||||
embeds := resolved[imageID]
|
||||
if embeds == nil {
|
||||
embeds, err = resolve(img.Data)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
resolved[imageID] = embeds
|
||||
}
|
||||
if embeds == nil || embeds.Main == nil || !embeds.Main.Valid() || embeds.Main.NumDims() < 2 {
|
||||
return nil, nil, fmt.Errorf("qwen3_5: invalid vision embeddings")
|
||||
}
|
||||
|
||||
tokensPerImage := int32(embeds.Main.Dim(1))
|
||||
if tokensPerImage <= 0 {
|
||||
return nil, nil, fmt.Errorf("qwen3_5: invalid image token count: %d", tokensPerImage)
|
||||
}
|
||||
|
||||
appendToken(m.VisionStartToken, p)
|
||||
p++
|
||||
basePos := p
|
||||
spanStart := int32(len(out))
|
||||
for range tokensPerImage {
|
||||
appendToken(m.ImageTokenID, basePos)
|
||||
}
|
||||
spanEnd := int32(len(out))
|
||||
merge := int32(1)
|
||||
if m.Config != nil && m.Config.Vision != nil {
|
||||
merge = m.Config.Vision.SpatialMergeSize
|
||||
}
|
||||
gridSpan := promptVisionGridSpan(embeds.Grid, merge, tokensPerImage)
|
||||
p += gridSpan
|
||||
appendToken(m.VisionEndToken, p)
|
||||
p++
|
||||
|
||||
state.Spans = append(state.Spans, promptVisionSpan{
|
||||
Start: spanStart,
|
||||
End: spanEnd,
|
||||
Main: embeds.Main,
|
||||
Grid: embeds.Grid,
|
||||
})
|
||||
}
|
||||
|
||||
return out, state, nil
|
||||
}
|
||||
|
||||
func (m *Model) TokenizePromptWithImagesState(prompt string, images []base.ImageInput) (*base.PromptTokenization, error) {
|
||||
tokens, state, err := m.tokenizePromptWithResolvedImages(prompt, images, m.EncodeVisionImage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &base.PromptTokenization{Tokens: tokens, State: state}, nil
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
@@ -22,16 +23,26 @@ func init() {
|
||||
base.Register("Qwen3NextForConditionalGeneration", NewModel)
|
||||
}
|
||||
|
||||
var (
|
||||
_ base.MultimodalPromptTokenizerWithState = (*Model)(nil)
|
||||
_ base.ForwardWithStateModel = (*Model)(nil)
|
||||
)
|
||||
|
||||
// RopeParameters carries optional rope metadata embedded under rope_parameters.
|
||||
type RopeParameters struct {
|
||||
Type string `json:"type"`
|
||||
RopeType string `json:"rope_type"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
MRoPEInterleaved bool `json:"mrope_interleaved"`
|
||||
MRoPESection []int32 `json:"mrope_section"`
|
||||
DimensionSections []int32 `json:"dimension_sections"`
|
||||
}
|
||||
|
||||
// Config holds Qwen 3.5 text config (top-level or nested text_config).
|
||||
type Config struct {
|
||||
// TextConfig holds the Qwen 3.5 text-model architecture fields.
|
||||
// In VLM configs these live under the "text_config" key; in text-only
|
||||
// configs they appear at the top level.
|
||||
type TextConfig struct {
|
||||
ModelType string `json:"model_type"`
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
@@ -67,6 +78,19 @@ type Config struct {
|
||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||
RopeScaling map[string]any `json:"rope_scaling"`
|
||||
RopeParameters *RopeParameters `json:"rope_parameters"`
|
||||
MRoPESections []int32 `json:"mrope_sections"`
|
||||
MRoPEInterleaved bool `json:"mrope_interleaved"`
|
||||
}
|
||||
|
||||
// Config is the full model config. It embeds TextConfig for the text-model
|
||||
// fields and adds top-level-only fields (vision, token IDs, quantization).
|
||||
type Config struct {
|
||||
TextConfig
|
||||
|
||||
Vision *VisionConfig `json:"vision_config"`
|
||||
ImageTokenID int32 `json:"image_token_id"`
|
||||
VisionStartToken int32 `json:"vision_start_token_id"`
|
||||
VisionEndToken int32 `json:"vision_end_token_id"`
|
||||
|
||||
// Quantization metadata.
|
||||
QuantGroupSize int `json:"-"`
|
||||
@@ -90,6 +114,9 @@ type Model struct {
|
||||
*Config
|
||||
|
||||
weightPrefix string
|
||||
|
||||
Vision *VisionModel
|
||||
ImageProcessor *VisionImageProcessor
|
||||
}
|
||||
|
||||
// Layer is a transformer decoder layer.
|
||||
@@ -190,17 +217,24 @@ func parseConfig(configData []byte) (Config, error) {
|
||||
|
||||
var cfg Config
|
||||
activeRaw := rawTop
|
||||
|
||||
// First pass: unmarshal the full config to pick up top-level fields
|
||||
// (vision_config, image_token_id, etc.) and text fields for text-only models.
|
||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||
return Config{}, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
// Second pass: if text_config exists, unmarshal it into TextConfig so
|
||||
// text-model fields from text_config take priority over any top-level
|
||||
// duplicates. Top-level-only fields (Vision, token IDs) are unaffected
|
||||
// because they live on Config, not TextConfig.
|
||||
if textRaw, ok := rawTop["text_config"]; ok {
|
||||
if err := json.Unmarshal(textRaw, &cfg); err != nil {
|
||||
if err := json.Unmarshal(textRaw, &cfg.TextConfig); err != nil {
|
||||
return Config{}, fmt.Errorf("parse text_config: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(textRaw, &activeRaw); err != nil {
|
||||
return Config{}, fmt.Errorf("parse text_config envelope: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||
return Config{}, fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.HiddenSize <= 0 {
|
||||
@@ -225,12 +259,8 @@ func parseConfig(configData []byte) (Config, error) {
|
||||
return Config{}, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
|
||||
}
|
||||
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-6
|
||||
}
|
||||
if cfg.LinearConvKernelDim <= 0 {
|
||||
cfg.LinearConvKernelDim = 4
|
||||
}
|
||||
cfg.RMSNormEps = cmp.Or(cfg.RMSNormEps, 1e-6)
|
||||
cfg.LinearConvKernelDim = cmp.Or(cfg.LinearConvKernelDim, 4)
|
||||
if cfg.LinearNumKeyHeads <= 0 || cfg.LinearNumValueHeads <= 0 || cfg.LinearKeyHeadDim <= 0 || cfg.LinearValueHeadDim <= 0 {
|
||||
return Config{}, fmt.Errorf("invalid linear attention config (k_heads=%d v_heads=%d k_dim=%d v_dim=%d)",
|
||||
cfg.LinearNumKeyHeads, cfg.LinearNumValueHeads, cfg.LinearKeyHeadDim, cfg.LinearValueHeadDim)
|
||||
@@ -246,14 +276,21 @@ func parseConfig(configData []byte) (Config, error) {
|
||||
if cfg.RopeParameters.PartialRotaryFactor > 0 {
|
||||
cfg.PartialRotaryFactor = cfg.RopeParameters.PartialRotaryFactor
|
||||
}
|
||||
if len(cfg.MRoPESections) == 0 {
|
||||
switch {
|
||||
case len(cfg.RopeParameters.MRoPESection) > 0:
|
||||
cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.MRoPESection...)
|
||||
case len(cfg.RopeParameters.DimensionSections) > 0:
|
||||
cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.DimensionSections...)
|
||||
}
|
||||
}
|
||||
cfg.MRoPEInterleaved = cmp.Or(cfg.MRoPEInterleaved, cfg.RopeParameters.MRoPEInterleaved)
|
||||
}
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 100000.0
|
||||
if len(cfg.MRoPESections) > 4 {
|
||||
cfg.MRoPESections = cfg.MRoPESections[:4]
|
||||
}
|
||||
if cfg.PartialRotaryFactor == 0 {
|
||||
cfg.PartialRotaryFactor = 0.25
|
||||
}
|
||||
if cfg.PartialRotaryFactor < 0 {
|
||||
cfg.RopeTheta = cmp.Or(cfg.RopeTheta, 100000.0)
|
||||
if cfg.PartialRotaryFactor <= 0 {
|
||||
cfg.PartialRotaryFactor = 0.25
|
||||
}
|
||||
ropeDim := int32(float32(cfg.HeadDim) * cfg.PartialRotaryFactor)
|
||||
@@ -281,24 +318,23 @@ func parseConfig(configData []byte) (Config, error) {
|
||||
}
|
||||
|
||||
if cfg.NumExperts > 0 {
|
||||
if cfg.NumExpertsPerTok <= 0 {
|
||||
cfg.NumExpertsPerTok = 1
|
||||
}
|
||||
if cfg.MoeIntermediateSize <= 0 {
|
||||
cfg.MoeIntermediateSize = cfg.IntermediateSize
|
||||
}
|
||||
if cfg.SharedExpertIntermediateSize <= 0 {
|
||||
cfg.SharedExpertIntermediateSize = cfg.IntermediateSize
|
||||
}
|
||||
cfg.NumExpertsPerTok = cmp.Or(cfg.NumExpertsPerTok, int32(1))
|
||||
cfg.MoeIntermediateSize = cmp.Or(cfg.MoeIntermediateSize, cfg.IntermediateSize)
|
||||
cfg.SharedExpertIntermediateSize = cmp.Or(cfg.SharedExpertIntermediateSize, cfg.IntermediateSize)
|
||||
if _, ok := activeRaw["norm_topk_prob"]; !ok {
|
||||
cfg.NormTopKProb = true
|
||||
}
|
||||
if cfg.DecoderSparseStep <= 0 {
|
||||
cfg.DecoderSparseStep = 1
|
||||
}
|
||||
cfg.DecoderSparseStep = cmp.Or(cfg.DecoderSparseStep, int32(1))
|
||||
}
|
||||
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
if cfg.Vision != nil {
|
||||
cfg.Vision.applyDefaults()
|
||||
}
|
||||
cfg.ImageTokenID = cmp.Or(cfg.ImageTokenID, int32(151655))
|
||||
cfg.VisionStartToken = cmp.Or(cfg.VisionStartToken, int32(151652))
|
||||
cfg.VisionEndToken = cmp.Or(cfg.VisionEndToken, int32(151653))
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
@@ -364,6 +400,11 @@ func NewModel(root *model.Root) (base.Model, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.Vision != nil {
|
||||
if preprocessorData, err := root.Manifest.ReadConfig("preprocessor_config.json"); err == nil {
|
||||
cfg.Vision.applyPreprocessorConfig(preprocessorData)
|
||||
}
|
||||
}
|
||||
|
||||
if qt := root.QuantType(); qt != "" {
|
||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||
@@ -1060,6 +1101,15 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||
m.Layers[i] = layer
|
||||
}
|
||||
|
||||
if cfg.Vision != nil && cfg.Vision.Depth > 0 {
|
||||
vision, processor, err := loadVisionComponents(tensors, linears, cfg, m.weightPrefix)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Vision = vision
|
||||
m.ImageProcessor = processor
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1117,7 +1167,51 @@ func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k,
|
||||
return q, k, v, z, b, a
|
||||
}
|
||||
|
||||
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func rotateHalf(x *mlx.Array) *mlx.Array {
|
||||
shape := x.Dims()
|
||||
last := int32(shape[len(shape)-1])
|
||||
half := last / 2
|
||||
if half <= 0 {
|
||||
return x
|
||||
}
|
||||
|
||||
x1 := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), half})
|
||||
x2 := mlx.SliceStartStop(x, []int32{0, 0, 0, half}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last})
|
||||
return mlx.Concatenate([]*mlx.Array{mlx.Neg(x2), x1}, -1)
|
||||
}
|
||||
|
||||
func applyTextRoPE(x, cos, sin *mlx.Array, ropeDim int32) *mlx.Array {
|
||||
if x == nil || cos == nil || sin == nil || ropeDim <= 0 {
|
||||
return x
|
||||
}
|
||||
|
||||
shape := x.Dims()
|
||||
if len(shape) != 4 {
|
||||
return x
|
||||
}
|
||||
|
||||
last := int32(shape[len(shape)-1])
|
||||
if ropeDim > last {
|
||||
ropeDim = last
|
||||
}
|
||||
if ropeDim%2 != 0 {
|
||||
ropeDim--
|
||||
}
|
||||
if ropeDim <= 0 {
|
||||
return x
|
||||
}
|
||||
|
||||
rot := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), ropeDim})
|
||||
rot = mlx.Add(mlx.Mul(rot, cos), mlx.Mul(rotateHalf(rot), sin))
|
||||
if ropeDim == last {
|
||||
return rot
|
||||
}
|
||||
|
||||
tail := mlx.SliceStartStop(x, []int32{0, 0, 0, ropeDim}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last})
|
||||
return mlx.Concatenate([]*mlx.Array{rot, tail}, -1)
|
||||
}
|
||||
|
||||
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, ropeSin *mlx.Array) *mlx.Array {
|
||||
qg := a.QProj.Forward(x)
|
||||
qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2)
|
||||
q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim})
|
||||
@@ -1140,8 +1234,13 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
||||
if c != nil {
|
||||
offset = c.Offset()
|
||||
}
|
||||
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
if ropeCos != nil && ropeSin != nil {
|
||||
q = applyTextRoPE(q, ropeCos, ropeSin, cfg.RopeDim)
|
||||
k = applyTextRoPE(k, ropeCos, ropeSin, cfg.RopeDim)
|
||||
} else {
|
||||
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||
}
|
||||
|
||||
if c != nil {
|
||||
k, v = c.Update(k, v)
|
||||
@@ -1323,13 +1422,13 @@ func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
||||
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, ropeSin *mlx.Array) *mlx.Array {
|
||||
var r *mlx.Array
|
||||
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
||||
if l.IsLinear {
|
||||
r = l.Linear.Forward(normed, c, B, L, cfg)
|
||||
} else {
|
||||
r = l.FullAttn.Forward(normed, c, B, L, cfg)
|
||||
r = l.FullAttn.Forward(normed, c, B, L, cfg, ropeCos, ropeSin)
|
||||
}
|
||||
h := mlx.Add(x, r)
|
||||
r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||
@@ -1337,16 +1436,27 @@ func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *m
|
||||
}
|
||||
|
||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
return m.ForwardWithState(tokens, caches, nil)
|
||||
}
|
||||
|
||||
func (m *Model) ForwardWithState(tokens *mlx.Array, caches []cache.Cache, state any) *mlx.Array {
|
||||
dims := tokens.Dims()
|
||||
B, L := int32(dims[0]), int32(dims[1])
|
||||
|
||||
startPos := promptStartPosFromCaches(caches)
|
||||
promptState := promptVisionStateFromState(state)
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h = m.applyPromptVisionEmbeddings(h, startPos, promptState)
|
||||
var ropeCos, ropeSin *mlx.Array
|
||||
if len(m.MRoPESections) > 0 {
|
||||
ropeCos, ropeSin = m.buildPromptMRoPECosSin(promptState, startPos, L, h.DType())
|
||||
}
|
||||
for i, layer := range m.Layers {
|
||||
var c cache.Cache
|
||||
if caches != nil && i < len(caches) {
|
||||
c = caches[i]
|
||||
}
|
||||
h = layer.Forward(h, c, B, L, m.Config)
|
||||
h = layer.Forward(h, c, B, L, m.Config, ropeCos, ropeSin)
|
||||
}
|
||||
out := m.Norm.Forward(h, m.RMSNormEps)
|
||||
return out
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
@@ -60,13 +64,13 @@ func TestParseConfigNestedDefaults(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLayerSelectionHelpers(t *testing.T) {
|
||||
cfg := &Config{
|
||||
cfg := &Config{TextConfig: TextConfig{
|
||||
NumHiddenLayers: 6,
|
||||
FullAttentionInterval: 3,
|
||||
NumExperts: 8,
|
||||
DecoderSparseStep: 2,
|
||||
MLPOnlyLayers: []int32{1},
|
||||
}
|
||||
}}
|
||||
|
||||
if !layerIsLinear(cfg, 0) {
|
||||
t.Fatalf("layer 0 should be linear")
|
||||
@@ -133,13 +137,13 @@ func TestResolveTensorPathLayout(t *testing.T) {
|
||||
|
||||
func TestNewCachesLayout(t *testing.T) {
|
||||
m := &Model{
|
||||
Config: &Config{
|
||||
Config: &Config{TextConfig: TextConfig{
|
||||
LinearConvKernelDim: 4,
|
||||
LinearNumKeyHeads: 2,
|
||||
LinearKeyHeadDim: 8,
|
||||
LinearNumValueHeads: 4,
|
||||
LinearValueHeadDim: 16,
|
||||
},
|
||||
}},
|
||||
Layers: []*Layer{
|
||||
{IsLinear: true},
|
||||
{IsLinear: false},
|
||||
@@ -166,7 +170,7 @@ func TestNewCachesLayout(t *testing.T) {
|
||||
func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
cfg := &Config{
|
||||
cfg := &Config{TextConfig: TextConfig{
|
||||
HiddenSize: 4,
|
||||
IntermediateSize: 8,
|
||||
NumHiddenLayers: 2,
|
||||
@@ -182,7 +186,7 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
||||
LinearValueHeadDim: 2,
|
||||
LinearConvKernelDim: 4,
|
||||
FullAttentionInterval: 2,
|
||||
}
|
||||
}}
|
||||
|
||||
m := &Model{
|
||||
Config: cfg,
|
||||
@@ -343,3 +347,389 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
||||
t.Fatalf("k norm dtype = %v, want %v", got, f32)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfigVisionFields(t *testing.T) {
|
||||
data := []byte(`{
|
||||
"text_config": {
|
||||
"hidden_size": 4096,
|
||||
"intermediate_size": 14336,
|
||||
"num_hidden_layers": 4,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"linear_num_value_heads": 64,
|
||||
"linear_num_key_heads": 16,
|
||||
"linear_key_head_dim": 128,
|
||||
"linear_value_head_dim": 128,
|
||||
"linear_conv_kernel_dim": 4,
|
||||
"rope_parameters": {
|
||||
"rope_theta": 10000000
|
||||
}
|
||||
},
|
||||
"vision_config": {
|
||||
"depth": 2,
|
||||
"hidden_size": 256,
|
||||
"num_heads": 8,
|
||||
"in_channels": 3,
|
||||
"patch_size": 14,
|
||||
"spatial_merge_size": 2,
|
||||
"layer_norm_epsilon": 0.000001,
|
||||
"temporal_patch_size": 2,
|
||||
"num_position_embeddings": 2304
|
||||
},
|
||||
"image_token_id": 111,
|
||||
"vision_start_token_id": 112,
|
||||
"vision_end_token_id": 113
|
||||
}`)
|
||||
|
||||
cfg, err := parseConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Vision == nil {
|
||||
t.Fatal("vision config should be parsed")
|
||||
}
|
||||
if cfg.Vision.Depth != 2 {
|
||||
t.Fatalf("vision.depth mismatch: got %d", cfg.Vision.Depth)
|
||||
}
|
||||
if cfg.Vision.GridPerSide != 48 {
|
||||
t.Fatalf("vision grid-per-side mismatch: got %d want 48", cfg.Vision.GridPerSide)
|
||||
}
|
||||
if cfg.Vision.RopeTheta != 10000 {
|
||||
t.Fatalf("vision rope_theta should default to 10000, got %v", cfg.Vision.RopeTheta)
|
||||
}
|
||||
if cfg.RopeTheta != 10000000 {
|
||||
t.Fatalf("text rope_theta mismatch: got %v", cfg.RopeTheta)
|
||||
}
|
||||
if cfg.ImageTokenID != 111 || cfg.VisionStartToken != 112 || cfg.VisionEndToken != 113 {
|
||||
t.Fatalf("vision token ids mismatch: got image=%d start=%d end=%d", cfg.ImageTokenID, cfg.VisionStartToken, cfg.VisionEndToken)
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseConfigMRoPEFromRopeParameters verifies that MRoPE settings nested
// under text_config.rope_parameters are surfaced on the top-level config:
// mrope_interleaved and mrope_section are copied through, and RopeDim honors
// partial_rotary_factor (head_dim 256 * 0.25 = 64).
func TestParseConfigMRoPEFromRopeParameters(t *testing.T) {
	data := []byte(`{
		"text_config": {
			"hidden_size": 2048,
			"intermediate_size": 8192,
			"num_hidden_layers": 4,
			"num_attention_heads": 16,
			"num_key_value_heads": 2,
			"head_dim": 256,
			"linear_num_value_heads": 32,
			"linear_num_key_heads": 16,
			"linear_key_head_dim": 128,
			"linear_value_head_dim": 128,
			"linear_conv_kernel_dim": 4,
			"rope_parameters": {
				"rope_theta": 10000000,
				"partial_rotary_factor": 0.25,
				"mrope_interleaved": true,
				"mrope_section": [11, 11, 10]
			}
		}
	}`)

	cfg, err := parseConfig(data)
	if err != nil {
		t.Fatalf("parseConfig failed: %v", err)
	}

	if !cfg.MRoPEInterleaved {
		t.Fatal("mrope_interleaved should be parsed from rope_parameters")
	}
	if !slices.Equal(cfg.MRoPESections, []int32{11, 11, 10}) {
		t.Fatalf("mrope sections mismatch: got %v", cfg.MRoPESections)
	}
	// 256 * 0.25 = 64 rotary dimensions.
	if cfg.RopeDim != 64 {
		t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
	}
}
|
||||
|
||||
// TestParseConfigVisionTokenDefaults verifies that when the config omits the
// vision token IDs entirely, parseConfig falls back to the standard Qwen
// vocabulary IDs (<|image_pad|> 151655, <|vision_start|> 151652,
// <|vision_end|> 151653).
func TestParseConfigVisionTokenDefaults(t *testing.T) {
	data := []byte(`{
		"text_config": {
			"hidden_size": 4096,
			"intermediate_size": 14336,
			"num_hidden_layers": 2,
			"num_attention_heads": 32,
			"num_key_value_heads": 8,
			"head_dim": 128,
			"linear_num_value_heads": 64,
			"linear_num_key_heads": 16,
			"linear_key_head_dim": 128,
			"linear_value_head_dim": 128,
			"linear_conv_kernel_dim": 4
		}
	}`)

	cfg, err := parseConfig(data)
	if err != nil {
		t.Fatalf("parseConfig failed: %v", err)
	}

	if cfg.ImageTokenID != 151655 {
		t.Fatalf("default image token mismatch: got %d", cfg.ImageTokenID)
	}
	if cfg.VisionStartToken != 151652 {
		t.Fatalf("default vision start token mismatch: got %d", cfg.VisionStartToken)
	}
	if cfg.VisionEndToken != 151653 {
		t.Fatalf("default vision end token mismatch: got %d", cfg.VisionEndToken)
	}
}
|
||||
|
||||
func TestResolveVisionPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tensors map[string]*mlx.Array
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "legacy visual prefix",
|
||||
tensors: map[string]*mlx.Array{
|
||||
"model.visual.patch_embed.proj.weight": mlx.New("patch"),
|
||||
},
|
||||
want: "model.visual",
|
||||
},
|
||||
{
|
||||
name: "imported vision tower prefix",
|
||||
tensors: map[string]*mlx.Array{
|
||||
"vision_tower.blocks.0.attn.qkv.weight": mlx.New("qkv"),
|
||||
},
|
||||
want: "vision_tower",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveVisionPrefix(tt.tensors, "language_model."); got != tt.want {
|
||||
t.Fatalf("resolveVisionPrefix() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestVisionPreprocessorOverridesDefaults verifies that every field of
// preprocessor_config.json (patch/temporal/merge sizes, size bounds, and
// normalization statistics) overrides the built-in defaults when present.
func TestVisionPreprocessorOverridesDefaults(t *testing.T) {
	v := &VisionConfig{}
	v.applyDefaults()

	v.applyPreprocessorConfig([]byte(`{
		"patch_size": 16,
		"temporal_patch_size": 3,
		"merge_size": 4,
		"size": {
			"shortest_edge": 1024,
			"longest_edge": 8192
		},
		"image_mean": [0.1, 0.2, 0.3],
		"image_std": [0.9, 0.8, 0.7]
	}`))

	if v.PatchSize != 16 {
		t.Fatalf("patch_size mismatch: got %d want 16", v.PatchSize)
	}
	if v.TemporalPatchSize != 3 {
		t.Fatalf("temporal_patch_size mismatch: got %d want 3", v.TemporalPatchSize)
	}
	// The preprocessor key is "merge_size" but it maps onto SpatialMergeSize.
	if v.SpatialMergeSize != 4 {
		t.Fatalf("merge_size mismatch: got %d want 4", v.SpatialMergeSize)
	}
	if v.Size.ShortestEdge != 1024 || v.Size.LongestEdge != 8192 {
		t.Fatalf("size mismatch: got shortest=%d longest=%d", v.Size.ShortestEdge, v.Size.LongestEdge)
	}
	if v.ImageMean[0] != 0.1 || v.ImageStd[2] != 0.7 {
		t.Fatalf("image preprocessing stats mismatch: mean=%v std=%v", v.ImageMean, v.ImageStd)
	}
}
|
||||
|
||||
// TestVisionImageProcessorUsesPreprocessorSize ensures that size bounds taken
// from preprocessor_config.json flow through VisionConfig into the
// constructed image processor (rather than the processor re-applying its own
// defaults).
func TestVisionImageProcessorUsesPreprocessorSize(t *testing.T) {
	v := &VisionConfig{}
	v.applyDefaults()

	v.applyPreprocessorConfig([]byte(`{
		"size": {
			"shortest_edge": 65536,
			"longest_edge": 16777216
		},
		"patch_size": 16,
		"temporal_patch_size": 2,
		"merge_size": 2,
		"image_mean": [0.5, 0.5, 0.5],
		"image_std": [0.5, 0.5, 0.5]
	}`))

	p := newVisionImageProcessor(v)
	if p == nil {
		t.Fatal("newVisionImageProcessor returned nil")
	}

	if p.shortestEdge != 65536 || p.longestEdge != 16777216 {
		t.Fatalf("processor size mismatch: shortest=%d longest=%d", p.shortestEdge, p.longestEdge)
	}
}
|
||||
|
||||
// testTokenizer builds a minimal single-token BPE tokenizer ("a" -> 0) so
// prompt-tokenization tests can run without real model assets.
func testTokenizer(t *testing.T) *tokenizer.Tokenizer {
	t.Helper()

	tok, err := tokenizer.LoadFromBytes([]byte(`{
		"model": {
			"type": "BPE",
			"vocab": {"a": 0},
			"merges": []
		}
	}`))
	if err != nil {
		t.Fatalf("failed to load test tokenizer: %v", err)
	}

	return tok
}
|
||||
|
||||
// TestTokenizePromptWithResolvedImagesStoresVisionSpans exercises the prompt
// expansion path: each "[img-N]" placeholder is resolved exactly once (the
// second occurrence of the same image ID reuses the cached embeddings),
// expanded to vision_start + image tokens + vision_end, and recorded as a
// span in the prompt vision state together with a position cache in which
// all image tokens of a span share one position.
func TestTokenizePromptWithResolvedImagesStoresVisionSpans(t *testing.T) {
	skipIfNoMLX(t)

	m := &Model{
		tok: testTokenizer(t),
		Config: &Config{
			ImageTokenID:     101,
			VisionStartToken: 102,
			VisionEndToken:   103,
			Vision:           &VisionConfig{SpatialMergeSize: 2},
		},
		Vision:         &VisionModel{},
		ImageProcessor: &VisionImageProcessor{},
	}

	// Two merged-patch embeddings of width 2, shaped [1, 2, 2].
	main := mlx.FromValues([]float32{
		10, 11,
		20, 21,
	}, 1, 2, 2)

	resolveCalls := 0
	got, state, err := m.tokenizePromptWithResolvedImages(
		"a[img-7][img-7]a",
		[]base.ImageInput{{ID: 7, Data: []byte("img7")}},
		func(data []byte) (*VisionEmbeddings, error) {
			if string(data) != "img7" {
				return nil, fmt.Errorf("unexpected data: %q", string(data))
			}
			resolveCalls++
			return &VisionEmbeddings{
				Main: main,
				Grid: &VisionGrid{Height: 2, Width: 2, Temporal: 1},
			}, nil
		},
	)
	if err != nil {
		t.Fatalf("tokenizePromptWithResolvedImages returned error: %v", err)
	}
	// Both placeholders share ID 7, so the resolver runs only once.
	if resolveCalls != 1 {
		t.Fatalf("resolve calls mismatch: got %d want 1", resolveCalls)
	}

	// "a" tokenizes to 0; each image expands to start(102) pad pad end(103).
	want := []int32{
		0,
		102, 101, 101, 103,
		102, 101, 101, 103,
		0,
	}
	if !slices.Equal(got, want) {
		t.Fatalf("expanded tokens mismatch: got %v want %v", got, want)
	}

	if state == nil {
		t.Fatal("expected prompt vision state")
	}
	if len(state.Spans) != 2 {
		t.Fatalf("prompt span count mismatch: got %d want 2", len(state.Spans))
	}
	// Spans cover only the image-pad tokens, not the start/end markers.
	if state.Spans[0].Start != 2 || state.Spans[0].End != 4 {
		t.Fatalf("first span mismatch: got [%d,%d)", state.Spans[0].Start, state.Spans[0].End)
	}
	if state.Spans[1].Start != 6 || state.Spans[1].End != 8 {
		t.Fatalf("second span mismatch: got [%d,%d)", state.Spans[1].Start, state.Spans[1].End)
	}
	// All image tokens within one span share a single position.
	wantPos := []int32{0, 1, 2, 2, 3, 4, 5, 5, 6, 7}
	if !slices.Equal(state.PositionCache, wantPos) {
		t.Fatalf("position cache mismatch: got %v want %v", state.PositionCache, wantPos)
	}
}
|
||||
|
||||
// TestBuildPromptMRoPEPositions checks that a vision span is expanded into
// the three MRoPE axes: the time axis follows the position cache unchanged,
// while the height and width axes enumerate the rows/columns of the merged
// grid (4x6 patches with merge 2 -> a 2x3 merged grid) inside the span.
func TestBuildPromptMRoPEPositions(t *testing.T) {
	m := &Model{
		Config: &Config{
			Vision: &VisionConfig{SpatialMergeSize: 2},
		},
	}
	state := &promptVisionState{
		// Tokens 2..7 form one image span; they all share cached position 2.
		PositionCache: []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6},
		Spans: []promptVisionSpan{
			{
				Start: 2,
				End:   8,
				Grid:  &VisionGrid{Height: 4, Width: 6, Temporal: 1},
			},
		},
	}

	pos := m.buildPromptMRoPEPositions(state, 0, 10)
	if got, want := pos[0], []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6}; !slices.Equal(got, want) {
		t.Fatalf("time positions mismatch: got %v want %v", got, want)
	}
	// Height axis: base 2, row index advances every 3 tokens (merged width).
	if got, want := pos[1], []int32{0, 1, 2, 2, 2, 3, 3, 3, 5, 6}; !slices.Equal(got, want) {
		t.Fatalf("height positions mismatch: got %v want %v", got, want)
	}
	// Width axis: base 2, column index cycles 0,1,2 within each row.
	if got, want := pos[2], []int32{0, 1, 2, 3, 4, 2, 3, 4, 5, 6}; !slices.Equal(got, want) {
		t.Fatalf("width positions mismatch: got %v want %v", got, want)
	}
}
|
||||
|
||||
func TestMapPromptPositionContinuesAfterCache(t *testing.T) {
|
||||
state := &promptVisionState{PositionCache: []int32{0, 1, 2, 2, 3}}
|
||||
|
||||
if got := mapPromptPosition(state, 3); got != 2 {
|
||||
t.Fatalf("mapPromptPosition(3) = %d, want 2", got)
|
||||
}
|
||||
if got := mapPromptPosition(state, 5); got != 4 {
|
||||
t.Fatalf("mapPromptPosition(5) = %d, want 4", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyPromptVisionEmbeddings verifies that the hidden-state rows covered
// by a vision span ([Start, End) = [1, 3)) are replaced with the span's
// vision embeddings while rows outside the span pass through unchanged.
func TestApplyPromptVisionEmbeddings(t *testing.T) {
	skipIfNoMLX(t)

	m := &Model{}
	state := &promptVisionState{
		Spans: []promptVisionSpan{
			{
				Start: 1,
				End:   3,
				// Two replacement rows, shaped [1, 2, 2].
				Main: mlx.FromValues([]float32{
					10, 11,
					20, 21,
				}, 1, 2, 2),
			},
		},
	}

	// Four hidden rows, shaped [1, 4, 2].
	h := mlx.FromValues([]float32{
		0, 1,
		2, 3,
		4, 5,
		6, 7,
	}, 1, 4, 2)

	got := m.applyPromptVisionEmbeddings(h, 0, state)
	mlx.Eval(got)

	// Rows 1 and 2 are replaced; rows 0 and 3 are preserved.
	want := []float32{
		0, 1,
		10, 11,
		20, 21,
		6, 7,
	}
	if !slices.Equal(got.Floats(), want) {
		t.Fatalf("embedding replacement mismatch: got %v want %v", got.Floats(), want)
	}
}
|
||||
|
||||
854
x/models/qwen3_5/vision.go
Normal file
854
x/models/qwen3_5/vision.go
Normal file
@@ -0,0 +1,854 @@
|
||||
package qwen3_5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
mlxmodel "github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// errNoVisionModel is returned when vision inference is requested but the
// loaded checkpoint has no vision tower or image processor.
var errNoVisionModel = errors.New("qwen3_5: no vision model")
|
||||
|
||||
// VisionConfig mirrors Qwen3.5/Qwen3-Next vision_config.
type VisionConfig struct {
	Depth                 int32   `json:"depth"`                   // number of encoder transformer blocks
	HiddenSize            int32   `json:"hidden_size"`             // encoder embedding width
	NumHeads              int32   `json:"num_heads"`               // attention heads per block
	InChannels            int32   `json:"in_channels"`             // input image channels (RGB = 3)
	PatchSize             int32   `json:"patch_size"`              // ViT patch side length in pixels
	SpatialMergeSize      int32   `json:"spatial_merge_size"`      // patches merged per side before projection
	LayerNormEpsilon      float32 `json:"layer_norm_epsilon"`
	RopeTheta             float32 `json:"rope_theta"`              // rotary base for the vision 2D RoPE
	TemporalPatchSize     int32   `json:"temporal_patch_size"`     // frames per temporal patch (replicated for stills)
	NumPositionEmbeddings int32   `json:"num_position_embeddings"` // size of the learned position table (square grid)

	// Size bounds the total pixel count accepted by smart resizing.
	Size struct {
		ShortestEdge int32 `json:"shortest_edge"`
		LongestEdge  int32 `json:"longest_edge"`
	} `json:"size"`

	// Per-channel normalization statistics applied during preprocessing.
	ImageMean []float32 `json:"image_mean"`
	ImageStd  []float32 `json:"image_std"`

	// GridPerSide is derived (sqrt of NumPositionEmbeddings) rather than
	// read from JSON.
	GridPerSide int32 `json:"-"`
}
|
||||
|
||||
func (v *VisionConfig) applyDefaults() {
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
if v.HiddenSize <= 0 {
|
||||
v.HiddenSize = 1280
|
||||
}
|
||||
if v.NumHeads <= 0 {
|
||||
v.NumHeads = 16
|
||||
}
|
||||
if v.InChannels <= 0 {
|
||||
v.InChannels = 3
|
||||
}
|
||||
if v.PatchSize <= 0 {
|
||||
v.PatchSize = 14
|
||||
}
|
||||
if v.SpatialMergeSize <= 0 {
|
||||
v.SpatialMergeSize = 2
|
||||
}
|
||||
if v.LayerNormEpsilon == 0 {
|
||||
v.LayerNormEpsilon = 1e-6
|
||||
}
|
||||
if v.RopeTheta == 0 {
|
||||
v.RopeTheta = 10000
|
||||
}
|
||||
if v.TemporalPatchSize <= 0 {
|
||||
v.TemporalPatchSize = 2
|
||||
}
|
||||
if v.NumPositionEmbeddings <= 0 {
|
||||
v.NumPositionEmbeddings = 2304
|
||||
}
|
||||
if len(v.ImageMean) < 3 {
|
||||
v.ImageMean = []float32{0.5, 0.5, 0.5}
|
||||
}
|
||||
if len(v.ImageStd) < 3 {
|
||||
v.ImageStd = []float32{0.5, 0.5, 0.5}
|
||||
}
|
||||
if v.Size.ShortestEdge <= 0 {
|
||||
v.Size.ShortestEdge = 64 << 10
|
||||
}
|
||||
if v.Size.LongestEdge <= 0 {
|
||||
v.Size.LongestEdge = 2 << 20
|
||||
}
|
||||
|
||||
grid := int32(math.Sqrt(float64(v.NumPositionEmbeddings)))
|
||||
if grid <= 0 {
|
||||
grid = 48
|
||||
}
|
||||
v.GridPerSide = grid
|
||||
}
|
||||
|
||||
func (v *VisionConfig) applyPreprocessorConfig(data []byte) {
|
||||
if v == nil || len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var pre struct {
|
||||
Size struct {
|
||||
ShortestEdge int32 `json:"shortest_edge"`
|
||||
LongestEdge int32 `json:"longest_edge"`
|
||||
} `json:"size"`
|
||||
PatchSize int32 `json:"patch_size"`
|
||||
TemporalPatchSize int32 `json:"temporal_patch_size"`
|
||||
MergeSize int32 `json:"merge_size"`
|
||||
ImageMean []float32 `json:"image_mean"`
|
||||
ImageStd []float32 `json:"image_std"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &pre); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if pre.PatchSize > 0 {
|
||||
v.PatchSize = pre.PatchSize
|
||||
}
|
||||
if pre.TemporalPatchSize > 0 {
|
||||
v.TemporalPatchSize = pre.TemporalPatchSize
|
||||
}
|
||||
if pre.MergeSize > 0 {
|
||||
v.SpatialMergeSize = pre.MergeSize
|
||||
}
|
||||
if pre.Size.ShortestEdge > 0 {
|
||||
v.Size.ShortestEdge = pre.Size.ShortestEdge
|
||||
}
|
||||
if pre.Size.LongestEdge > 0 {
|
||||
v.Size.LongestEdge = pre.Size.LongestEdge
|
||||
}
|
||||
if len(pre.ImageMean) >= 3 {
|
||||
v.ImageMean = pre.ImageMean
|
||||
}
|
||||
if len(pre.ImageStd) >= 3 {
|
||||
v.ImageStd = pre.ImageStd
|
||||
}
|
||||
v.applyDefaults()
|
||||
}
|
||||
|
||||
// VisionGrid tracks patch-grid dimensions for an image.
type VisionGrid struct {
	Height   int32 // patches along the vertical axis
	Width    int32 // patches along the horizontal axis
	Temporal int32 // temporal patches; always 1 for static images
}
|
||||
|
||||
// VisionImageProcessor reproduces qwen3vl image preprocessing.
type VisionImageProcessor struct {
	numChannels       int32 // input channels (RGB = 3)
	patchSize         int32 // patch side length in pixels
	temporalPatchSize int32 // temporal replication factor for still images
	mergeSize         int32 // spatial merge factor
	// Despite the "edge" names, smartResize compares these against the total
	// pixel area (h*w), mirroring the upstream preprocessor.
	shortestEdge int32
	longestEdge  int32
	factor       int32 // patchSize * mergeSize; resize targets snap to this
	imageMean    [3]float32
	imageStd     [3]float32
}
|
||||
|
||||
func newVisionImageProcessor(cfg *VisionConfig) *VisionImageProcessor {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &VisionImageProcessor{
|
||||
numChannels: cfg.InChannels,
|
||||
patchSize: cfg.PatchSize,
|
||||
temporalPatchSize: cfg.TemporalPatchSize,
|
||||
mergeSize: cfg.SpatialMergeSize,
|
||||
shortestEdge: cfg.Size.ShortestEdge,
|
||||
longestEdge: cfg.Size.LongestEdge,
|
||||
factor: cfg.PatchSize * cfg.SpatialMergeSize,
|
||||
imageMean: [3]float32{cfg.ImageMean[0], cfg.ImageMean[1], cfg.ImageMean[2]},
|
||||
imageStd: [3]float32{cfg.ImageStd[0], cfg.ImageStd[1], cfg.ImageStd[2]},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *VisionImageProcessor) smartResize(height, width int) (int, int, error) {
|
||||
factor := int(p.factor)
|
||||
if factor <= 0 {
|
||||
return 0, 0, fmt.Errorf("invalid factor: %d", factor)
|
||||
}
|
||||
|
||||
if height < factor || width < factor {
|
||||
return 0, 0, fmt.Errorf("height (%d) or width (%d) must be >= factor (%d)", height, width, factor)
|
||||
}
|
||||
if min(height, width) == 0 {
|
||||
return 0, 0, fmt.Errorf("invalid dimensions: %dx%d", width, height)
|
||||
}
|
||||
if max(height, width)/min(height, width) > 200 {
|
||||
return 0, 0, fmt.Errorf("aspect ratio too large: %dx%d", width, height)
|
||||
}
|
||||
|
||||
roundEven := func(x float64) int { return int(math.RoundToEven(x)) }
|
||||
|
||||
hBar := roundEven(float64(height)/float64(factor)) * factor
|
||||
wBar := roundEven(float64(width)/float64(factor)) * factor
|
||||
|
||||
if hBar*wBar > int(p.longestEdge) {
|
||||
beta := math.Sqrt(float64(height*width) / float64(p.longestEdge))
|
||||
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
|
||||
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
|
||||
} else if hBar*wBar < int(p.shortestEdge) {
|
||||
beta := math.Sqrt(float64(p.shortestEdge) / float64(height*width))
|
||||
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||
}
|
||||
|
||||
return hBar, wBar, nil
|
||||
}
|
||||
|
||||
// ProcessImage converts a decoded image into flattened patch pixel values
// plus the patch-grid dimensions. The image is composited onto an opaque
// background, smart-resized to multiples of patchSize*mergeSize, normalized
// per channel, and re-packed into per-patch vectors of length
// numChannels*temporalPatchSize*patchSize*patchSize. The returned tensor is
// shaped [1, numPatches, patchDim].
func (p *VisionImageProcessor) ProcessImage(img image.Image) (*mlx.Array, *VisionGrid, error) {
	if p == nil {
		return nil, nil, errNoVisionModel
	}

	// Flatten any alpha channel before measuring and resizing.
	img = imageproc.Composite(img)
	origW := img.Bounds().Dx()
	origH := img.Bounds().Dy()

	resizedH, resizedW, err := p.smartResize(origH, origW)
	if err != nil {
		return nil, nil, err
	}

	resized := imageproc.Resize(
		img,
		image.Point{X: resizedW, Y: resizedH},
		imageproc.ResizeBilinear,
	)
	// NOTE(review): createPatches indexes pixels as c*height*width + y*width
	// + x, so this assumes Normalize returns channel-first (CHW) float32
	// data — confirm against imageproc.Normalize.
	pixels := imageproc.Normalize(resized, p.imageMean, p.imageStd, true, true)

	// Resized dims are multiples of patchSize, so the divisions are exact.
	grid := &VisionGrid{
		Height:   int32(resizedH / int(p.patchSize)),
		Width:    int32(resizedW / int(p.patchSize)),
		Temporal: 1,
	}

	patches := p.createPatches(pixels, resizedH, resizedW, grid)

	patchDim := int(p.numChannels * p.temporalPatchSize * p.patchSize * p.patchSize)
	numPatches := int(grid.Height * grid.Width)
	// Leading axis of 1 is the batch dimension expected by the vision tower.
	pixelValues := mlx.FromValues(patches, numPatches, patchDim).ExpandDims(0)
	return pixelValues, grid, nil
}
|
||||
|
||||
// createPatches re-packs CHW pixel data into per-patch vectors, emitting
// patches in merge-group order (mergeSize x mergeSize groups, row-major
// within each group) so the layout matches mergedPatchCoordinates. For still
// images the single frame is replicated across the temporal patch dimension.
//
// pixels is expected in channel-first layout (c*height*width + y*width + x);
// height and width are the resized image dimensions in pixels.
func (p *VisionImageProcessor) createPatches(pixels []float32, height, width int, grid *VisionGrid) []float32 {
	channels := int(p.numChannels)
	patchSize := int(p.patchSize)
	mergeSize := int(p.mergeSize)
	temporalPatchSize := int(p.temporalPatchSize)

	// Temporal is always 1 for static images; only spatial patches are created.
	numPatches := int(grid.Height * grid.Width)
	patchDim := channels * temporalPatchSize * patchSize * patchSize
	result := make([]float32, numPatches*patchDim)

	patchIndex := 0
	// Outer loops walk merge groups; inner mh/mw loops walk patches within a
	// group so consecutive output patches belong to the same merged token.
	for h := 0; h < int(grid.Height); h += mergeSize {
		for w := 0; w < int(grid.Width); w += mergeSize {
			for mh := range mergeSize {
				for mw := range mergeSize {
					baseOffset := patchIndex * patchDim

					// Copy the first temporal frame for every channel.
					for c := range channels {
						channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
						for py := range patchSize {
							for px := range patchSize {
								y := (h+mh)*patchSize + py
								x := (w+mw)*patchSize + px
								srcIdx := c*height*width + y*width + x
								dstIdx := channelOffset + py*patchSize + px
								// Bounds guard: out-of-range pixels stay zero.
								if srcIdx < len(pixels) && dstIdx < len(result) {
									result[dstIdx] = pixels[srcIdx]
								}
							}
						}
					}

					// Replicate frame 0 into the remaining temporal slots.
					if temporalPatchSize > 1 {
						for c := range channels {
							channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
							frameSize := patchSize * patchSize
							for tp := 1; tp < temporalPatchSize; tp++ {
								cur := channelOffset + tp*frameSize
								copy(result[cur:cur+frameSize], result[channelOffset:channelOffset+frameSize])
							}
						}
					}

					patchIndex++
				}
			}
		}
	}

	return result
}
|
||||
|
||||
// VisionAttention runs one self-attention block inside the vision encoder.
// Exactly one projection layout is populated per checkpoint: either the
// fused QKV projection, or the separate Query/Key/Value projections.
type VisionAttention struct {
	QKV    nn.LinearLayer // fused qkv projection (legacy layout)
	Query  nn.LinearLayer // separate projections (imported layout)
	Key    nn.LinearLayer
	Value  nn.LinearLayer
	Output nn.LinearLayer // output projection after attention
}
|
||||
|
||||
func applyVisionRoPE(x, cos, sin *mlx.Array) *mlx.Array {
|
||||
return mlx.Add(mlx.Mul(x, cos), mlx.Mul(rotateHalf(x), sin))
|
||||
}
|
||||
|
||||
// Forward runs self-attention over a [B, L, D] sequence of patch embeddings,
// applying 2D RoPE to queries and keys. It supports both the fused QKV and
// the split Query/Key/Value checkpoint layouts.
func (a *VisionAttention) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
	shape := x.Dims()
	if len(shape) != 3 {
		return nil, fmt.Errorf("vision attention expects [B,L,D], got %v", shape)
	}
	B, L, hidden := int32(shape[0]), int32(shape[1]), int32(shape[2])
	headDim := cfg.HiddenSize / cfg.NumHeads
	if headDim <= 0 {
		return nil, fmt.Errorf("invalid vision head dim: %d", headDim)
	}

	var q, k, v *mlx.Array
	if a.QKV != nil {
		// Fused layout: project once, then slice the q/k/v planes out of
		// axis 2 of [B, L, 3, heads, headDim].
		qkv := a.QKV.Forward(x)
		qkv = mlx.Reshape(qkv, B, L, 3, cfg.NumHeads, headDim)
		q = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 0, 0, 0}, []int32{B, L, 1, cfg.NumHeads, headDim}), 2)
		k = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 1, 0, 0}, []int32{B, L, 2, cfg.NumHeads, headDim}), 2)
		v = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 2, 0, 0}, []int32{B, L, 3, cfg.NumHeads, headDim}), 2)
	} else {
		if a.Query == nil || a.Key == nil || a.Value == nil {
			return nil, errors.New("vision attention is missing q/k/v projections")
		}
		q = mlx.Reshape(a.Query.Forward(x), B, L, cfg.NumHeads, headDim)
		k = mlx.Reshape(a.Key.Forward(x), B, L, cfg.NumHeads, headDim)
		v = mlx.Reshape(a.Value.Forward(x), B, L, cfg.NumHeads, headDim)
	}

	// Rotary embeddings are applied in [B, L, heads, headDim] layout.
	q = applyVisionRoPE(q, cos, sin)
	k = applyVisionRoPE(k, cos, sin)

	// Move heads before sequence for the attention kernel: [B, heads, L, d].
	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	// NOTE(review): the trailing false presumably disables the causal mask —
	// vision attention is bidirectional. Confirm against the mlx binding.
	attn := mlx.ScaledDotProductAttentionCausal(q, k, v, scale, false)
	attn = mlx.Reshape(mlx.Transpose(attn, 0, 2, 1, 3), B, L, hidden)
	if a.Output == nil {
		return nil, errors.New("vision attention is missing output projection")
	}
	return a.Output.Forward(attn), nil
}
|
||||
|
||||
// VisionMLP is the vision feed-forward block.
type VisionMLP struct {
	FC1 nn.LinearLayer // expansion projection
	FC2 nn.LinearLayer // contraction projection back to the hidden size
}
|
||||
|
||||
func (m *VisionMLP) Forward(x *mlx.Array) (*mlx.Array, error) {
|
||||
if m.FC1 == nil || m.FC2 == nil {
|
||||
return nil, errors.New("vision mlp is missing fc1/fc2")
|
||||
}
|
||||
return m.FC2.Forward(mlx.GELUApprox(m.FC1.Forward(x))), nil
|
||||
}
|
||||
|
||||
// VisionEncoderLayer is one transformer block in the vision encoder:
// pre-norm attention followed by a pre-norm MLP, each with a residual add.
type VisionEncoderLayer struct {
	Norm1 *nn.LayerNorm
	Attn  *VisionAttention
	Norm2 *nn.LayerNorm
	MLP   *VisionMLP
}
|
||||
|
||||
func (l *VisionEncoderLayer) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
|
||||
if l.Norm1 == nil || l.Norm2 == nil || l.Attn == nil || l.MLP == nil {
|
||||
return nil, errors.New("vision layer is incomplete")
|
||||
}
|
||||
|
||||
r := x
|
||||
a, err := l.Attn.Forward(l.Norm1.Forward(x), cos, sin, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x = mlx.Add(r, a)
|
||||
|
||||
r = x
|
||||
m, err := l.MLP.Forward(l.Norm2.Forward(x))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mlx.Add(r, m), nil
|
||||
}
|
||||
|
||||
// VisionPatchMerger projects merged spatial groups into language embedding space.
type VisionPatchMerger struct {
	Norm *nn.LayerNorm  // applied before grouping
	FC1  nn.LinearLayer // first projection over each concatenated group
	FC2  nn.LinearLayer // final projection into the language hidden size
}
|
||||
|
||||
func groupMergedTokens(x *mlx.Array, merge int32) (*mlx.Array, error) {
|
||||
shape := x.Dims()
|
||||
if len(shape) != 3 {
|
||||
return nil, fmt.Errorf("expected [B,L,D], got %v", shape)
|
||||
}
|
||||
if merge <= 0 {
|
||||
merge = 1
|
||||
}
|
||||
B, L, D := int32(shape[0]), int32(shape[1]), int32(shape[2])
|
||||
group := merge * merge
|
||||
if group <= 0 || L%group != 0 {
|
||||
return nil, fmt.Errorf("invalid merge layout: L=%d merge=%d", L, merge)
|
||||
}
|
||||
|
||||
x = mlx.Reshape(x, B, L/group, group, D)
|
||||
x = mlx.Reshape(x, B, L/group, group*D)
|
||||
return x, nil
|
||||
}
|
||||
|
||||
func (m *VisionPatchMerger) Forward(x *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
|
||||
if m == nil || m.Norm == nil || m.FC1 == nil || m.FC2 == nil {
|
||||
return nil, errors.New("vision patch merger is incomplete")
|
||||
}
|
||||
|
||||
x = m.Norm.Forward(x)
|
||||
|
||||
var err error
|
||||
x, err = groupMergedTokens(x, cfg.SpatialMergeSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
x = m.FC2.Forward(mlx.GELUApprox(m.FC1.Forward(x)))
|
||||
return x, nil
|
||||
}
|
||||
|
||||
// VisionModel contains the full Qwen vision tower.
type VisionModel struct {
	PatchProjection nn.LinearLayer        // flattened patch pixels -> hidden size
	PositionEmbed   *nn.Embedding         // learned grid position table; may be nil
	Layers          []*VisionEncoderLayer // transformer encoder stack
	PatchMerger     *VisionPatchMerger    // final merge + projection stage

	cfg *VisionConfig
}
|
||||
|
||||
func mergedPatchCoordinates(grid *VisionGrid, merge int32) [][2]int32 {
|
||||
if merge <= 0 {
|
||||
merge = 1
|
||||
}
|
||||
// Temporal is always 1 for static images; only spatial coordinates are generated.
|
||||
coords := make([][2]int32, 0, grid.Height*grid.Width)
|
||||
for h := int32(0); h < grid.Height; h += merge {
|
||||
for w := int32(0); w < grid.Width; w += merge {
|
||||
for mh := range merge {
|
||||
for mw := range merge {
|
||||
coords = append(coords, [2]int32{h + mh, w + mw})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return coords
|
||||
}
|
||||
|
||||
// addPositionEmbedding adds learned position embeddings to the patch
// sequence. The fixed GridPerSide x GridPerSide embedding table is
// bilinearly interpolated onto the image's actual patch grid: each patch
// coordinate maps to a fractional table position, and the four surrounding
// table entries are blended by their bilinear weights. A nil PositionEmbed
// makes this a pass-through.
func (m *VisionModel) addPositionEmbedding(x *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
	if m.PositionEmbed == nil {
		return x, nil
	}
	shape := x.Dims()
	if len(shape) != 3 {
		return nil, fmt.Errorf("vision embeddings expect [B,L,D], got %v", shape)
	}
	B, D := int32(shape[0]), int32(shape[2])
	// Coordinates must be enumerated in the same merge-group order the patch
	// tensor was packed in.
	coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
	L := int32(len(coords))
	if L != int32(shape[1]) {
		return nil, fmt.Errorf("vision sequence mismatch: hidden L=%d coords=%d", shape[1], L)
	}

	// Map patch rows/cols onto the [0, GridPerSide-1] table range; a
	// single-row or single-column grid collapses to table position 0.
	stepH := float32(0)
	if grid.Height > 1 {
		stepH = float32(m.cfg.GridPerSide-1) / float32(grid.Height-1)
	}
	stepW := float32(0)
	if grid.Width > 1 {
		stepW = float32(m.cfg.GridPerSide-1) / float32(grid.Width-1)
	}

	// For each patch: the four surrounding table cells and their bilinear
	// blend weights, in (floor,floor), (floor,ceil), (ceil,floor),
	// (ceil,ceil) order.
	indices := make([]int32, 0, L*4)
	weights := make([]float32, 0, L*4)
	for _, c := range coords {
		y := float32(c[0]) * stepH
		x0 := float32(c[1]) * stepW

		fy := int32(y)
		fx := int32(x0)
		// Clamp the ceiling neighbors to the table edge.
		cy := min(fy+1, m.cfg.GridPerSide-1)
		cx := min(fx+1, m.cfg.GridPerSide-1)

		indices = append(indices,
			fy*m.cfg.GridPerSide+fx,
			fy*m.cfg.GridPerSide+cx,
			cy*m.cfg.GridPerSide+fx,
			cy*m.cfg.GridPerSide+cx,
		)

		dy := y - float32(fy)
		dx := x0 - float32(fx)
		weights = append(weights,
			(1-dy)*(1-dx),
			(1-dy)*dx,
			dy*(1-dx),
			dy*dx,
		)
	}

	idxArr := mlx.FromValues(indices, int(L), 4)
	wArr := mlx.FromValues(weights, int(L), 4, 1)

	// Gather the four table rows per patch and blend: [L,4,D] * [L,4,1]
	// summed over axis 1 -> [L,D].
	pos := m.PositionEmbed.Forward(idxArr)
	wArr = wArr.AsType(pos.DType())
	pos = mlx.Sum(mlx.Mul(pos, wArr), 1, false)
	if D != int32(pos.Dim(1)) {
		return nil, fmt.Errorf("position embedding dim mismatch: hidden=%d pos=%d", D, pos.Dim(1))
	}

	// Broadcast the [1,L,D] embeddings across the batch.
	pos = mlx.ExpandDims(pos, 0)
	if B > 1 {
		pos = mlx.Tile(pos, []int32{B, 1, 1})
	}

	return mlx.Add(x, pos), nil
}
|
||||
|
||||
// rotaryEmbeddings precomputes the 2D RoPE cos/sin tables for a patch grid.
// Each head dimension is split in half; the first half encodes positions and
// the second half duplicates it (the rotate-half convention). Within the
// first half, the first quarter carries row-coordinate angles and the second
// quarter column-coordinate angles. Returned arrays are shaped [1, L, 1,
// headDim] for broadcasting over batch and heads.
func (m *VisionModel) rotaryEmbeddings(grid *VisionGrid) (*mlx.Array, *mlx.Array, error) {
	headDim := m.cfg.HiddenSize / m.cfg.NumHeads
	if headDim <= 0 {
		return nil, nil, fmt.Errorf("invalid vision head dim: %d", headDim)
	}

	// Same merge-group ordering as the packed patch tensor.
	coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
	L := int32(len(coords))
	half := headDim / 2
	quarter := half / 2
	if quarter <= 0 {
		return nil, nil, fmt.Errorf("invalid vision rotary layout: head_dim=%d", headDim)
	}

	angles := make([]float32, L*headDim)
	for i, c := range coords {
		base := int32(i) * headDim
		for j := range quarter {
			// Standard RoPE frequency schedule over the half-dimension.
			freq := 1.0 / math.Pow(float64(m.cfg.RopeTheta), float64(2*j)/float64(half))
			angles[base+j] = float32(float64(c[0]) * freq)         // row angles
			angles[base+quarter+j] = float32(float64(c[1]) * freq) // column angles
		}
		// Duplicate the first half into the second (rotate-half layout).
		for j := range half {
			angles[base+half+j] = angles[base+j]
		}
	}

	arr := mlx.FromValues(angles, int(L), int(headDim))
	cos := mlx.ExpandDims(mlx.ExpandDims(mlx.Cos(arr), 0), 2)
	sin := mlx.ExpandDims(mlx.ExpandDims(mlx.Sin(arr), 0), 2)
	return cos, sin, nil
}
|
||||
|
||||
func (m *VisionModel) Forward(pixelValues *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
|
||||
if m == nil || pixelValues == nil || grid == nil {
|
||||
return nil, errNoVisionModel
|
||||
}
|
||||
if m.PatchProjection == nil || m.PatchMerger == nil {
|
||||
return nil, errors.New("vision model is missing required projections")
|
||||
}
|
||||
|
||||
x := m.PatchProjection.Forward(pixelValues)
|
||||
var err error
|
||||
x, err = m.addPositionEmbedding(x, grid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cos, sin, err := m.rotaryEmbeddings(grid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
x, err = layer.Forward(x, cos, sin, m.cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("vision layer %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
main, err := m.PatchMerger.Forward(x, m.cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("vision patch merger: %w", err)
|
||||
}
|
||||
return main, nil
|
||||
}
|
||||
|
||||
// VisionEmbeddings is the result of encoding one image: the merged patch
// embeddings and the patch grid they were produced from.
type VisionEmbeddings struct {
	Main *mlx.Array
	Grid *VisionGrid
}
|
||||
|
||||
func (m *Model) EncodeVisionImage(multimodalData []byte) (*VisionEmbeddings, error) {
|
||||
if m == nil || m.Vision == nil || m.ImageProcessor == nil {
|
||||
return nil, errNoVisionModel
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues, grid, err := m.ImageProcessor.ProcessImage(img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
main, err := m.Vision.Forward(pixelValues, grid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &VisionEmbeddings{Main: main, Grid: grid}, nil
|
||||
}
|
||||
|
||||
func resolveVisionPrefix(tensors map[string]*mlx.Array, weightPrefix string) string {
|
||||
candidates := []string{
|
||||
"vision_tower",
|
||||
weightPrefix + "vision_tower",
|
||||
"model.visual",
|
||||
"visual",
|
||||
weightPrefix + "model.visual",
|
||||
weightPrefix + "visual",
|
||||
}
|
||||
|
||||
hasTensor := func(prefix string) bool {
|
||||
for _, suffix := range []string{
|
||||
".patch_embed.proj.weight",
|
||||
".patch_embed.weight",
|
||||
".pos_embed.weight",
|
||||
".blocks.0.attn.qkv.weight",
|
||||
".merger.linear_fc1.weight",
|
||||
".merger.mlp.0.weight",
|
||||
} {
|
||||
if tensors[prefix+suffix] != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
for _, prefix := range candidates {
|
||||
if hasTensor(prefix) {
|
||||
return prefix
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func firstLinear(linears mlxmodel.LinearFactory, paths ...string) nn.LinearLayer {
|
||||
for _, p := range paths {
|
||||
if l := linears.Make(p); l != nil {
|
||||
return l
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadLayerNorm(tensors map[string]*mlx.Array, eps float32, bases ...string) *nn.LayerNorm {
|
||||
for _, base := range bases {
|
||||
if w := tensors[base+".weight"]; w != nil {
|
||||
return &nn.LayerNorm{Weight: w, Bias: tensors[base+".bias"], Eps: eps}
|
||||
}
|
||||
if w := tensors[base]; w != nil {
|
||||
return &nn.LayerNorm{Weight: w, Bias: tensors[base+"_bias"], Eps: eps}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadVisionPatchMerger(
|
||||
tensors map[string]*mlx.Array,
|
||||
linears mlxmodel.LinearFactory,
|
||||
eps float32,
|
||||
bases ...string,
|
||||
) *VisionPatchMerger {
|
||||
for _, base := range bases {
|
||||
norm := loadLayerNorm(tensors, eps, base+".norm", base+".ln_q")
|
||||
fc1 := firstLinear(linears, base+".linear_fc1", base+".mlp.0")
|
||||
fc2 := firstLinear(linears, base+".linear_fc2", base+".mlp.2")
|
||||
if norm != nil && fc1 != nil && fc2 != nil {
|
||||
return &VisionPatchMerger{Norm: norm, FC1: fc1, FC2: fc2}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func flattenPatchEmbeddingWeight(w *mlx.Array) (*mlx.Array, error) {
|
||||
if w == nil || !w.Valid() {
|
||||
return nil, errors.New("missing patch embedding weight")
|
||||
}
|
||||
if w.NumDims() < 2 {
|
||||
return nil, fmt.Errorf("patch embedding weight must be >=2D, got %dD", w.NumDims())
|
||||
}
|
||||
if w.NumDims() == 2 {
|
||||
return w, nil
|
||||
}
|
||||
|
||||
out := int32(w.Dim(0))
|
||||
in := int32(w.Size() / w.Dim(0))
|
||||
return mlx.Reshape(w, out, in), nil
|
||||
}
|
||||
|
||||
// loadVisionComponents builds the vision tower and its image processor from
// the loaded tensor map. It returns (nil, nil, nil) when the config carries
// no vision section (text-only model), and a non-nil error when vision is
// configured but the expected tensors are missing or shaped inconsistently.
func loadVisionComponents(
	tensors map[string]*mlx.Array,
	linears mlxmodel.LinearFactory,
	cfg *Config,
	weightPrefix string,
) (*VisionModel, *VisionImageProcessor, error) {
	// Vision is optional: an absent or zero-depth config means text-only.
	if cfg == nil || cfg.Vision == nil || cfg.Vision.Depth <= 0 {
		return nil, nil, nil
	}
	cfg.Vision.applyDefaults()

	// Determine which naming scheme this checkpoint used for the vision tower.
	visionPrefix := resolveVisionPrefix(tensors, weightPrefix)
	if visionPrefix == "" {
		return nil, nil, errors.New("vision enabled in config but vision tensors were not found")
	}

	// Patch embedding weight: accept the conv-style ".proj" layout or the
	// bare layout, then flatten conv-shaped weights to 2D for a linear proj.
	patchW, _ := tensorAny(
		tensors,
		visionPrefix+".patch_embed.proj.weight",
		visionPrefix+".patch_embed.weight",
	)
	if patchW == nil {
		return nil, nil, fmt.Errorf("missing vision patch embedding weight under %s", visionPrefix)
	}
	patchW, err := flattenPatchEmbeddingWeight(patchW)
	if err != nil {
		return nil, nil, err
	}
	// The bias is optional; a nil bias is accepted by nn.NewLinear below.
	patchB, _ := tensorAny(
		tensors,
		visionPrefix+".patch_embed.proj.bias",
		visionPrefix+".patch_embed.bias",
	)

	patchProj := nn.NewLinear(patchW, patchB)
	// Sanity-check that the flattened input dim matches what the config
	// implies (channels x temporal patch x spatial patch x spatial patch).
	if got := int32(patchW.Dim(1)); got != cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize {
		return nil, nil, fmt.Errorf(
			"vision patch embedding input dim mismatch: got %d expected %d",
			got,
			cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize,
		)
	}

	// Position embedding: two naming layouts are accepted.
	posW, _ := tensorAny(
		tensors,
		visionPrefix+".pos_embed.weight",
		visionPrefix+".position_embedding.weight",
	)
	if posW == nil {
		return nil, nil, fmt.Errorf("missing vision position embedding under %s", visionPrefix)
	}
	// The number of positions comes from the tensor itself; re-apply defaults
	// so any settings derived from it are recomputed.
	cfg.Vision.NumPositionEmbeddings = int32(posW.Dim(0))
	cfg.Vision.applyDefaults()

	vm := &VisionModel{
		PatchProjection: patchProj,
		PositionEmbed:   nn.NewEmbedding(posW),
		Layers:          make([]*VisionEncoderLayer, cfg.Vision.Depth),
		cfg:             cfg.Vision,
	}

	// Load each encoder block. For every projection, multiple tensor-name
	// layouts are tried; attention may be stored fused (qkv) or split (q/k/v).
	for i := range cfg.Vision.Depth {
		layerPrefix := fmt.Sprintf("%s.blocks.%d", visionPrefix, i)
		layer := &VisionEncoderLayer{
			Norm1: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm1"),
			Norm2: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm2"),
			Attn: &VisionAttention{
				QKV: firstLinear(
					linears,
					layerPrefix+".attn.qkv",
					layerPrefix+".attn_qkv",
				),
				Query: firstLinear(
					linears,
					layerPrefix+".attn.q_proj",
					layerPrefix+".attn_q",
				),
				Key: firstLinear(
					linears,
					layerPrefix+".attn.k_proj",
					layerPrefix+".attn_k",
				),
				Value: firstLinear(
					linears,
					layerPrefix+".attn.v_proj",
					layerPrefix+".attn_v",
				),
				Output: firstLinear(
					linears,
					layerPrefix+".attn.proj",
					layerPrefix+".attn_out",
					layerPrefix+".attn.o_proj",
				),
			},
			MLP: &VisionMLP{
				FC1: firstLinear(
					linears,
					layerPrefix+".mlp.fc1",
					layerPrefix+".mlp.linear_fc1",
				),
				FC2: firstLinear(
					linears,
					layerPrefix+".mlp.fc2",
					layerPrefix+".mlp.linear_fc2",
				),
			},
		}

		// Validate the block: norms are mandatory, attention needs either a
		// fused QKV or all three split projections plus the output proj, and
		// the MLP needs both linears.
		if layer.Norm1 == nil || layer.Norm2 == nil {
			return nil, nil, fmt.Errorf("vision layer %d: missing norm1/norm2", i)
		}
		if layer.Attn.Output == nil || (layer.Attn.QKV == nil && (layer.Attn.Query == nil || layer.Attn.Key == nil || layer.Attn.Value == nil)) {
			return nil, nil, fmt.Errorf("vision layer %d: missing attention projections", i)
		}
		if layer.MLP.FC1 == nil || layer.MLP.FC2 == nil {
			return nil, nil, fmt.Errorf("vision layer %d: missing mlp projections", i)
		}

		vm.Layers[i] = layer
	}

	// The patch merger is required for a complete vision tower.
	vm.PatchMerger = loadVisionPatchMerger(
		tensors,
		linears,
		cfg.Vision.LayerNormEpsilon,
		visionPrefix+".merger",
	)
	if vm.PatchMerger == nil {
		return nil, nil, fmt.Errorf("missing vision patch merger under %s", visionPrefix)
	}

	return vm, newVisionImageProcessor(cfg.Vision), nil
}
|
||||
Reference in New Issue
Block a user