mirror of
https://github.com/ollama/ollama.git
synced 2026-04-21 08:15:42 +02:00
Compare commits
11 Commits
v0.18.2
...
pdevine/qw
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
578c32e42e | ||
|
|
a10d2625ca | ||
|
|
b960d769ad | ||
|
|
455a6099d1 | ||
|
|
7e6e8377eb | ||
|
|
126d8db7f3 | ||
|
|
3f3a24b418 | ||
|
|
96e36c0d90 | ||
|
|
6f8ddbb26b | ||
|
|
b5e7888414 | ||
|
|
eab4d22269 |
@@ -155,7 +155,7 @@ func (s *Server) ollamaProxy() http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
target := envconfig.Host()
|
target := envconfig.ConnectableHost()
|
||||||
s.log().Info("configuring ollama proxy", "target", target.String())
|
s.log().Info("configuring ollama proxy", "target", target.String())
|
||||||
|
|
||||||
newProxy := httputil.NewSingleHostReverseProxy(target)
|
newProxy := httputil.NewSingleHostReverseProxy(target)
|
||||||
|
|||||||
@@ -413,9 +413,6 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
fmt.Fprintf(os.Stderr, "Headless mode: auto-selected last used model %q\n", current)
|
fmt.Fprintf(os.Stderr, "Headless mode: auto-selected last used model %q\n", current)
|
||||||
if err := config.SetLastModel(current); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return current, nil
|
return current, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -428,9 +425,6 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
|
|||||||
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
|
if err := c.ensureModelsReady(ctx, []string{current}); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if err := config.SetLastModel(current); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return current, nil
|
return current, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -439,9 +433,11 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
if model != current {
|
||||||
if err := config.SetLastModel(model); err != nil {
|
if err := config.SetLastModel(model); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return model, nil
|
return model, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -475,9 +471,11 @@ func (c *launcherClient) launchSingleIntegration(ctx context.Context, name strin
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if target != current {
|
||||||
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
||||||
return fmt.Errorf("failed to save: %w", err)
|
return fmt.Errorf("failed to save: %w", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return launchAfterConfiguration(name, runner, target, req)
|
return launchAfterConfiguration(name, runner, target, req)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,6 +48,20 @@ ollama launch claude --model kimi-k2.5:cloud
|
|||||||
|
|
||||||
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||||
|
|
||||||
|
## Non-interactive (headless) mode
|
||||||
|
|
||||||
|
Run Claude Code without interaction for use in Docker, CI/CD, or scripts:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch claude --model kimi-k2.5:cloud --yes -- -p "how does this repository work?"
|
||||||
|
```
|
||||||
|
|
||||||
|
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified. Arguments after `--` are passed directly to Claude Code.
|
||||||
|
|
||||||
|
## Web search
|
||||||
|
|
||||||
|
Claude Code can search the web through Ollama's web search API. See the [web search documentation](/capabilities/web-search) for setup and usage.
|
||||||
|
|
||||||
## Scheduled Tasks with `/loop`
|
## Scheduled Tasks with `/loop`
|
||||||
|
|
||||||
The `/loop` command runs a prompt or slash command on a recurring schedule inside Claude Code. This is useful for automating repetitive tasks like checking PRs, running research, or setting reminders.
|
The `/loop` command runs a prompt or slash command on a recurring schedule inside Claude Code. This is useful for automating repetitive tasks like checking PRs, running research, or setting reminders.
|
||||||
|
|||||||
@@ -15,13 +15,29 @@ Ollama handles everything automatically:
|
|||||||
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
|
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
|
||||||
2. **Security** — On the first launch, a security notice explains the risks of tool access
|
2. **Security** — On the first launch, a security notice explains the risks of tool access
|
||||||
3. **Model** — Pick a model from the selector (local or cloud)
|
3. **Model** — Pick a model from the selector (local or cloud)
|
||||||
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, and sets your model as the primary
|
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, sets your model as the primary, and installs the web search and fetch plugin
|
||||||
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
|
5. **Gateway** — Starts in the background and opens the OpenClaw TUI
|
||||||
|
|
||||||
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
|
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
|
||||||
|
|
||||||
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
|
<Note>Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.</Note>
|
||||||
|
|
||||||
|
## Web search and fetch
|
||||||
|
|
||||||
|
OpenClaw ships with a web search and fetch plugin that gives local or cloud models the ability to search the web and extract readable page content.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama launch openclaw
|
||||||
|
```
|
||||||
|
|
||||||
|
Web search and fetch is enabled automatically when launching OpenClaw through Ollama. To install the plugin directly:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
openclaw plugins install @ollama/openclaw-web-search
|
||||||
|
```
|
||||||
|
|
||||||
|
<Note>Web search for local models requires `ollama signin`.</Note>
|
||||||
|
|
||||||
## Configure without launching
|
## Configure without launching
|
||||||
|
|
||||||
To change the model without starting the gateway and TUI:
|
To change the model without starting the gateway and TUI:
|
||||||
@@ -52,6 +68,16 @@ If the gateway is already running, it restarts automatically to pick up the new
|
|||||||
|
|
||||||
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
|
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
|
||||||
|
|
||||||
|
## Non-interactive (headless) mode
|
||||||
|
|
||||||
|
Run OpenClaw without interaction for use in Docker, CI/CD, or scripts:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama launch openclaw --model kimi-k2.5:cloud --yes
|
||||||
|
```
|
||||||
|
|
||||||
|
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified.
|
||||||
|
|
||||||
## Connect messaging apps
|
## Connect messaging apps
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -59,6 +59,29 @@ func Host() *url.URL {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConnectableHost returns Host() with unspecified bind addresses (0.0.0.0, ::)
|
||||||
|
// replaced by the corresponding loopback address (127.0.0.1, ::1).
|
||||||
|
// Unspecified addresses are valid for binding a server socket but not for
|
||||||
|
// connecting as a client, which fails on Windows.
|
||||||
|
func ConnectableHost() *url.URL {
|
||||||
|
u := Host()
|
||||||
|
host, port, err := net.SplitHostPort(u.Host)
|
||||||
|
if err != nil {
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip := net.ParseIP(host); ip != nil && ip.IsUnspecified() {
|
||||||
|
if ip.To4() != nil {
|
||||||
|
host = "127.0.0.1"
|
||||||
|
} else {
|
||||||
|
host = "::1"
|
||||||
|
}
|
||||||
|
u.Host = net.JoinHostPort(host, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
|
// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
|
||||||
func AllowedOrigins() (origins []string) {
|
func AllowedOrigins() (origins []string) {
|
||||||
if s := Var("OLLAMA_ORIGINS"); s != "" {
|
if s := Var("OLLAMA_ORIGINS"); s != "" {
|
||||||
|
|||||||
@@ -52,6 +52,37 @@ func TestHost(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectableHost(t *testing.T) {
|
||||||
|
cases := map[string]struct {
|
||||||
|
value string
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
"empty": {"", "http://127.0.0.1:11434"},
|
||||||
|
"localhost": {"127.0.0.1", "http://127.0.0.1:11434"},
|
||||||
|
"localhost and port": {"127.0.0.1:1234", "http://127.0.0.1:1234"},
|
||||||
|
"ipv4 unspecified": {"0.0.0.0", "http://127.0.0.1:11434"},
|
||||||
|
"ipv4 unspecified + port": {"0.0.0.0:1234", "http://127.0.0.1:1234"},
|
||||||
|
"ipv6 unspecified": {"[::]", "http://[::1]:11434"},
|
||||||
|
"ipv6 unspecified + port": {"[::]:1234", "http://[::1]:1234"},
|
||||||
|
"ipv6 localhost": {"[::1]", "http://[::1]:11434"},
|
||||||
|
"ipv6 localhost + port": {"[::1]:1234", "http://[::1]:1234"},
|
||||||
|
"specific address": {"192.168.1.5", "http://192.168.1.5:11434"},
|
||||||
|
"specific address + port": {"192.168.1.5:8080", "http://192.168.1.5:8080"},
|
||||||
|
"hostname": {"example.com", "http://example.com:11434"},
|
||||||
|
"hostname and port": {"example.com:1234", "http://example.com:1234"},
|
||||||
|
"https unspecified + port": {"https://0.0.0.0:4321", "https://127.0.0.1:4321"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tt := range cases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", tt.value)
|
||||||
|
if host := ConnectableHost(); host.String() != tt.expect {
|
||||||
|
t.Errorf("%s: expected %s, got %s", name, tt.expect, host.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOrigins(t *testing.T) {
|
func TestOrigins(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
value string
|
value string
|
||||||
|
|||||||
@@ -345,44 +345,163 @@ func escapeGLM46Content(s string) string {
|
|||||||
return result.String()
|
return result.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// repairUnclosedArgValues inserts missing </arg_value> closing tags.
|
// repairPhase represents the expected next tag in the repair cycle.
|
||||||
// GLM models sometimes omit the closing tag, producing XML like:
|
type repairPhase int
|
||||||
|
|
||||||
|
const (
|
||||||
|
phaseArgKeyOpen repairPhase = iota // expecting <arg_key>
|
||||||
|
phaseArgKeyClose // expecting </arg_key>
|
||||||
|
phaseArgValOpen // expecting <arg_value>
|
||||||
|
phaseArgValClose // expecting </arg_value>
|
||||||
|
phaseCount // number of phases
|
||||||
|
)
|
||||||
|
|
||||||
|
// repairGLM46XML reconstructs well-formed XML from GLM model output that may
|
||||||
|
// have missing or mismatched tags. The expected structure is:
|
||||||
//
|
//
|
||||||
// <arg_value>value</tool_call>
|
// func_name
|
||||||
|
// <arg_key>key</arg_key>
|
||||||
|
// <arg_value>value</arg_value>
|
||||||
|
// ...
|
||||||
//
|
//
|
||||||
// instead of:
|
// GLM models frequently omit opening or closing tags. This function follows
|
||||||
//
|
// the expected tag cycle, scanning forward for each expected tag in sequence.
|
||||||
// <arg_value>value</arg_value></tool_call>
|
// When a tag is missing, it inserts the tag and consumes any text in between.
|
||||||
func repairUnclosedArgValues(s string) string {
|
func repairGLM46XML(s string) string {
|
||||||
|
// tagCycle is the repeating sequence of tags after the function name.
|
||||||
|
tagCycle := [phaseCount]string{"<arg_key>", "</arg_key>", "<arg_value>", "</arg_value>"}
|
||||||
|
|
||||||
|
// findNextTag returns the index and identity of the earliest known tag in s.
|
||||||
|
findNextTag := func(s string) (int, string) {
|
||||||
|
bestIdx := -1
|
||||||
|
bestTag := ""
|
||||||
|
for _, tag := range tagCycle {
|
||||||
|
if idx := strings.Index(s, tag); idx != -1 && (bestIdx == -1 || idx < bestIdx) {
|
||||||
|
bestIdx = idx
|
||||||
|
bestTag = tag
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bestIdx, bestTag
|
||||||
|
}
|
||||||
|
|
||||||
|
// tagIndex returns the phase corresponding to the given tag.
|
||||||
|
tagIndex := func(tag string) repairPhase {
|
||||||
|
for i, t := range tagCycle {
|
||||||
|
if t == tag {
|
||||||
|
return repairPhase(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
var result strings.Builder
|
var result strings.Builder
|
||||||
for {
|
|
||||||
openIdx := strings.Index(s, "<arg_value>")
|
idx, firstTag := findNextTag(s)
|
||||||
if openIdx == -1 {
|
if idx == -1 {
|
||||||
result.WriteString(s)
|
return s
|
||||||
|
}
|
||||||
|
prefix := s[:idx]
|
||||||
|
s = s[idx:]
|
||||||
|
|
||||||
|
// If the first tag is not <arg_key>, the text before it may contain both
|
||||||
|
// the function name and key content (e.g. "weather city</arg_key>").
|
||||||
|
// Function names cannot contain space, so split at the first space.
|
||||||
|
phase := phaseArgKeyOpen
|
||||||
|
if firstTag != "<arg_key>" {
|
||||||
|
if spIdx := strings.IndexFunc(prefix, unicode.IsSpace); spIdx != -1 {
|
||||||
|
result.WriteString(prefix[:spIdx])
|
||||||
|
keyContent := strings.TrimLeftFunc(prefix[spIdx:], unicode.IsSpace)
|
||||||
|
result.WriteString("<arg_key>")
|
||||||
|
result.WriteString(keyContent)
|
||||||
|
phase = phaseArgKeyClose
|
||||||
|
} else {
|
||||||
|
result.WriteString(prefix)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result.WriteString(prefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Walk through the expected tag cycle. At each step, look for the
|
||||||
|
// expected tag. If a different tag appears first, emit the missing
|
||||||
|
// tags to catch up, then continue.
|
||||||
|
for len(s) > 0 {
|
||||||
|
idx, found := findNextTag(s)
|
||||||
|
expected := tagCycle[phase]
|
||||||
|
isOpen := phase%2 == 0 // even phases are opening tags
|
||||||
|
|
||||||
|
if idx == -1 {
|
||||||
|
// No more tags — emit remaining text with fixups
|
||||||
|
if isOpen {
|
||||||
|
// Expecting an opening tag but nothing left — we're done
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
afterOpen := openIdx + len("<arg_value>")
|
// Expecting a closing tag — emit text then close
|
||||||
closeIdx := strings.Index(s[afterOpen:], "</arg_value>")
|
result.WriteString(s)
|
||||||
nextKeyIdx := strings.Index(s[afterOpen:], "<arg_key>")
|
result.WriteString(expected)
|
||||||
// Check if properly closed before the next <arg_key> (or no next key)
|
phase = (phase + 1) % phaseCount
|
||||||
if closeIdx != -1 && (nextKeyIdx == -1 || closeIdx < nextKeyIdx) {
|
break
|
||||||
end := afterOpen + closeIdx + len("</arg_value>")
|
}
|
||||||
result.WriteString(s[:end])
|
|
||||||
s = s[end:]
|
if found == expected {
|
||||||
|
// Found the expected tag — emit any text before it, then the tag
|
||||||
|
result.WriteString(s[:idx])
|
||||||
|
result.WriteString(expected)
|
||||||
|
s = s[idx+len(expected):]
|
||||||
|
phase = (phase + 1) % phaseCount
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Unclosed — insert </arg_value> before the next <arg_key> or at end
|
|
||||||
if nextKeyIdx != -1 {
|
// Found a different tag. Insert missing tags to catch up.
|
||||||
insertAt := afterOpen + nextKeyIdx
|
foundIdx := tagIndex(found)
|
||||||
result.WriteString(s[:insertAt])
|
|
||||||
result.WriteString("</arg_value>")
|
if isOpen && idx > 0 {
|
||||||
s = s[insertAt:]
|
// Text before the found tag while expecting an opening tag —
|
||||||
|
// the opening tag was omitted. Emit it before the text.
|
||||||
|
result.WriteString(expected)
|
||||||
|
// Advance to the next phase (text content) and then look
|
||||||
|
// for the closing tag — but the found tag might be that
|
||||||
|
// closing tag or something further ahead. Emit text up to
|
||||||
|
// the found tag and insert any missing tags between.
|
||||||
|
result.WriteString(s[:idx])
|
||||||
|
phase = (phase + 1) % phaseCount // now expecting closing
|
||||||
|
s = s[idx:]
|
||||||
|
// Fall through to re-evaluate with the closing tag expected
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit missing tags to advance from current phase to the found tag's phase
|
||||||
|
for phase != foundIdx {
|
||||||
|
tag := tagCycle[phase]
|
||||||
|
if phase%2 == 0 {
|
||||||
|
result.WriteString(tag)
|
||||||
} else {
|
} else {
|
||||||
result.WriteString(s)
|
// Closing tag — emit any text before the found tag first,
|
||||||
|
// but only if we're one step before the found tag
|
||||||
|
if (phase+1)%phaseCount == foundIdx && idx > 0 {
|
||||||
|
result.WriteString(s[:idx])
|
||||||
|
s = s[idx:]
|
||||||
|
idx = 0
|
||||||
|
}
|
||||||
|
result.WriteString(tag)
|
||||||
|
}
|
||||||
|
phase = (phase + 1) % phaseCount
|
||||||
|
}
|
||||||
|
// Now phase == foundIdx, re-process without advancing s
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we stopped mid-pair (after an opening tag), close it
|
||||||
|
switch phase {
|
||||||
|
case phaseArgKeyClose: // after <arg_key>, expecting text/</arg_key>
|
||||||
|
result.WriteString("</arg_key>")
|
||||||
|
result.WriteString("<arg_value>")
|
||||||
|
result.WriteString("</arg_value>")
|
||||||
|
case phaseArgValOpen: // after </arg_key>, expecting <arg_value>
|
||||||
|
result.WriteString("<arg_value>")
|
||||||
|
result.WriteString("</arg_value>")
|
||||||
|
case phaseArgValClose: // after <arg_value>, expecting text/</arg_value>
|
||||||
result.WriteString("</arg_value>")
|
result.WriteString("</arg_value>")
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result.String()
|
return result.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -398,7 +517,7 @@ func parseGLM46ToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCa
|
|||||||
var parsed GLMToolCallXML
|
var parsed GLMToolCallXML
|
||||||
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
|
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
|
||||||
parsed = GLMToolCallXML{}
|
parsed = GLMToolCallXML{}
|
||||||
repaired := "<tool_call>" + repairUnclosedArgValues(escaped) + "</tool_call>"
|
repaired := "<tool_call>" + repairGLM46XML(escaped) + "</tool_call>"
|
||||||
if err2 := xml.Unmarshal([]byte(repaired), &parsed); err2 != nil {
|
if err2 := xml.Unmarshal([]byte(repaired), &parsed); err2 != nil {
|
||||||
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
|
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -887,6 +887,28 @@ line3</arg_value>`,
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "unopened arg_value after arg_key",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: "get-weather\n<arg_key>city</arg_key>\nNew York</arg_value>\n<arg_key>unit</arg_key>\ncelsius</arg_value>",
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get-weather",
|
||||||
|
Arguments: args(`{"city": "New York", "unit": "celsius"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed unopened and valid arg_values",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: "get-weather\n<arg_key>city</arg_key>\n<arg_value>Paris</arg_value>\n<arg_key>unit</arg_key>\ncelsius</arg_value>",
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get-weather",
|
||||||
|
Arguments: args(`{"city": "Paris", "unit": "celsius"}`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tc := range cases {
|
for i, tc := range cases {
|
||||||
@@ -902,7 +924,7 @@ line3</arg_value>`,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRepairUnclosedArgValues(t *testing.T) {
|
func TestRepairGLM46XML(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
name string
|
name string
|
||||||
input string
|
input string
|
||||||
@@ -910,33 +932,63 @@ func TestRepairUnclosedArgValues(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "already valid",
|
name: "already valid",
|
||||||
input: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
input: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||||
want: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "unclosed at end",
|
name: "missing </arg_value> at end",
|
||||||
input: `<arg_key>k</arg_key><arg_value>v`,
|
input: `func<arg_key>k</arg_key><arg_value>v`,
|
||||||
want: `<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "unclosed before next arg_key",
|
name: "missing </arg_value> before next arg_key",
|
||||||
input: `<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2</arg_value>`,
|
input: `func<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||||
want: `<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
|
want: `func<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no arg_value tags",
|
name: "no tags at all",
|
||||||
input: `just plain text`,
|
input: `just plain text`,
|
||||||
want: `just plain text`,
|
want: `just plain text`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple unclosed",
|
name: "missing <arg_value> open tag",
|
||||||
input: `<arg_key>a</arg_key><arg_value>1<arg_key>b</arg_key><arg_value>2`,
|
input: `func<arg_key>k</arg_key>v</arg_value>`,
|
||||||
want: `<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
|
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing </arg_key> close tag",
|
||||||
|
input: `func<arg_key>k<arg_value>v</arg_value>`,
|
||||||
|
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing <arg_key> open tag",
|
||||||
|
input: `func k</arg_key><arg_value>v</arg_value>`,
|
||||||
|
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all closing tags missing",
|
||||||
|
input: `func<arg_key>k<arg_value>v`,
|
||||||
|
want: `func<arg_key>k</arg_key><arg_value>v</arg_value>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all opening tags missing",
|
||||||
|
input: "func k</arg_key>v</arg_value>",
|
||||||
|
want: "func<arg_key>k</arg_key><arg_value>v</arg_value>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple pairs with mixed missing tags",
|
||||||
|
input: `func<arg_key>a</arg_key>1</arg_value><arg_key>b<arg_value>2</arg_value>`,
|
||||||
|
want: `func<arg_key>a</arg_key><arg_value>1</arg_value><arg_key>b</arg_key><arg_value>2</arg_value>`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "newlines preserved",
|
||||||
|
input: "func\n<arg_key>city</arg_key>\nNew York</arg_value>",
|
||||||
|
want: "func\n<arg_key>city</arg_key><arg_value>\nNew York</arg_value>",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
got := repairUnclosedArgValues(tc.input)
|
got := repairGLM46XML(tc.input)
|
||||||
if got != tc.want {
|
if got != tc.want {
|
||||||
t.Errorf("got %q, want %q", got, tc.want)
|
t.Errorf("got %q, want %q", got, tc.want)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,14 +134,18 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
|
|||||||
spinnerKey = "create"
|
spinnerKey = "create"
|
||||||
capabilities = []string{"completion"}
|
capabilities = []string{"completion"}
|
||||||
|
|
||||||
// Check if model supports thinking based on architecture
|
configData, _ := os.ReadFile(filepath.Join(opts.ModelDir, "config.json"))
|
||||||
if supportsThinking(opts.ModelDir) {
|
mcfg := parseModelConfig(configData)
|
||||||
|
|
||||||
|
if mcfg.supportsThinking() {
|
||||||
capabilities = append(capabilities, "thinking")
|
capabilities = append(capabilities, "thinking")
|
||||||
}
|
}
|
||||||
|
if mcfg.supportsVision() {
|
||||||
|
capabilities = append(capabilities, "vision")
|
||||||
|
}
|
||||||
|
|
||||||
// Set parser and renderer name based on architecture
|
parserName = mcfg.parserName()
|
||||||
parserName = getParserName(opts.ModelDir)
|
rendererName = mcfg.rendererName()
|
||||||
rendererName = getRendererName(opts.ModelDir)
|
|
||||||
} else {
|
} else {
|
||||||
modelType = "image generation model"
|
modelType = "image generation model"
|
||||||
spinnerKey = "imagegen"
|
spinnerKey = "imagegen"
|
||||||
@@ -438,145 +442,76 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
|||||||
return layers, nil
|
return layers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// supportsThinking checks if the model supports thinking mode based on its architecture.
|
// modelConfig holds the fields from config.json needed during model creation.
|
||||||
// This reads the config.json from the model directory and checks the architectures field.
|
type visionConfig struct {
|
||||||
func supportsThinking(modelDir string) bool {
|
Depth int32 `json:"depth"`
|
||||||
configPath := filepath.Join(modelDir, "config.json")
|
|
||||||
data, err := os.ReadFile(configPath)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var cfg struct {
|
type modelConfig struct {
|
||||||
Architectures []string `json:"architectures"`
|
Architectures []string `json:"architectures"`
|
||||||
ModelType string `json:"model_type"`
|
ModelType string `json:"model_type"`
|
||||||
}
|
VisionConfig *visionConfig `json:"vision_config"`
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
ImageTokenID *int32 `json:"image_token_id"`
|
||||||
return false
|
VisionStartTokenID *int32 `json:"vision_start_token_id"`
|
||||||
|
VisionEndTokenID *int32 `json:"vision_end_token_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check architectures that support thinking
|
func parseModelConfig(data []byte) modelConfig {
|
||||||
thinkingArchitectures := []string{
|
var cfg modelConfig
|
||||||
"glm4moe", // GLM-4 MoE models
|
_ = json.Unmarshal(data, &cfg)
|
||||||
"deepseek", // DeepSeek models
|
return cfg
|
||||||
"qwen3", // Qwen3 models
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the architecture list
|
// archOrTypeContains returns true if any architecture or the model_type
|
||||||
for _, arch := range cfg.Architectures {
|
// contains one of the given substrings (case-insensitive).
|
||||||
|
func (c *modelConfig) archOrTypeContains(substrs ...string) bool {
|
||||||
|
for _, arch := range c.Architectures {
|
||||||
archLower := strings.ToLower(arch)
|
archLower := strings.ToLower(arch)
|
||||||
for _, thinkArch := range thinkingArchitectures {
|
for _, s := range substrs {
|
||||||
if strings.Contains(archLower, thinkArch) {
|
if strings.Contains(archLower, s) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if c.ModelType != "" {
|
||||||
// Also check model_type
|
typeLower := strings.ToLower(c.ModelType)
|
||||||
if cfg.ModelType != "" {
|
for _, s := range substrs {
|
||||||
typeLower := strings.ToLower(cfg.ModelType)
|
if strings.Contains(typeLower, s) {
|
||||||
for _, thinkArch := range thinkingArchitectures {
|
|
||||||
if strings.Contains(typeLower, thinkArch) {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// getParserName returns the parser name for a model based on its architecture.
|
func (c *modelConfig) supportsThinking() bool {
|
||||||
// This reads the config.json from the model directory and determines the appropriate parser.
|
return c.archOrTypeContains("glm4moe", "deepseek", "qwen3")
|
||||||
func getParserName(modelDir string) string {
|
|
||||||
configPath := filepath.Join(modelDir, "config.json")
|
|
||||||
data, err := os.ReadFile(configPath)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var cfg struct {
|
func (c *modelConfig) supportsVision() bool {
|
||||||
Architectures []string `json:"architectures"`
|
return c.VisionConfig != nil || c.ImageTokenID != nil || c.VisionStartTokenID != nil || c.VisionEndTokenID != nil
|
||||||
ModelType string `json:"model_type"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check architectures for known parsers
|
func (c *modelConfig) parserName() string {
|
||||||
for _, arch := range cfg.Architectures {
|
switch {
|
||||||
archLower := strings.ToLower(arch)
|
case c.archOrTypeContains("glm4", "glm-4"):
|
||||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
|
||||||
return "glm-4.7"
|
return "glm-4.7"
|
||||||
}
|
case c.archOrTypeContains("deepseek"):
|
||||||
if strings.Contains(archLower, "deepseek") {
|
|
||||||
return "deepseek3"
|
return "deepseek3"
|
||||||
}
|
case c.archOrTypeContains("qwen3"):
|
||||||
if strings.Contains(archLower, "qwen3") {
|
|
||||||
return "qwen3"
|
return "qwen3"
|
||||||
}
|
}
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Also check model_type
|
func (c *modelConfig) rendererName() string {
|
||||||
if cfg.ModelType != "" {
|
switch {
|
||||||
typeLower := strings.ToLower(cfg.ModelType)
|
case c.archOrTypeContains("glm4", "glm-4"):
|
||||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
|
||||||
return "glm-4.7"
|
return "glm-4.7"
|
||||||
}
|
case c.archOrTypeContains("deepseek"):
|
||||||
if strings.Contains(typeLower, "deepseek") {
|
|
||||||
return "deepseek3"
|
return "deepseek3"
|
||||||
}
|
case c.archOrTypeContains("qwen3"):
|
||||||
if strings.Contains(typeLower, "qwen3") {
|
|
||||||
return "qwen3"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRendererName returns the renderer name for a model based on its architecture.
|
|
||||||
// This reads the config.json from the model directory and determines the appropriate renderer.
|
|
||||||
func getRendererName(modelDir string) string {
|
|
||||||
configPath := filepath.Join(modelDir, "config.json")
|
|
||||||
data, err := os.ReadFile(configPath)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
var cfg struct {
|
|
||||||
Architectures []string `json:"architectures"`
|
|
||||||
ModelType string `json:"model_type"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check architectures for known renderers
|
|
||||||
for _, arch := range cfg.Architectures {
|
|
||||||
archLower := strings.ToLower(arch)
|
|
||||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
|
||||||
return "glm-4.7"
|
|
||||||
}
|
|
||||||
if strings.Contains(archLower, "deepseek") {
|
|
||||||
return "deepseek3"
|
|
||||||
}
|
|
||||||
if strings.Contains(archLower, "qwen3") {
|
|
||||||
return "qwen3-coder"
|
return "qwen3-coder"
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Also check model_type
|
|
||||||
if cfg.ModelType != "" {
|
|
||||||
typeLower := strings.ToLower(cfg.ModelType)
|
|
||||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
|
||||||
return "glm-4.7"
|
|
||||||
}
|
|
||||||
if strings.Contains(typeLower, "deepseek") {
|
|
||||||
return "deepseek3"
|
|
||||||
}
|
|
||||||
if strings.Contains(typeLower, "qwen3") {
|
|
||||||
return "qwen3-coder"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -339,3 +339,34 @@ func TestCreateModelfileLayersIncludesParameters(t *testing.T) {
|
|||||||
t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
|
t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSupportsVision(t *testing.T) {
|
||||||
|
t.Run("vision_config present", func(t *testing.T) {
|
||||||
|
cfg := parseModelConfig([]byte(`{
|
||||||
|
"vision_config": {"depth": 2},
|
||||||
|
"image_token_id": 151655
|
||||||
|
}`))
|
||||||
|
if !cfg.supportsVision() {
|
||||||
|
t.Fatal("supportsVision() = false, want true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("token ids alone imply vision", func(t *testing.T) {
|
||||||
|
cfg := parseModelConfig([]byte(`{
|
||||||
|
"vision_start_token_id": 10,
|
||||||
|
"vision_end_token_id": 11
|
||||||
|
}`))
|
||||||
|
if !cfg.supportsVision() {
|
||||||
|
t.Fatal("supportsVision() = false, want true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("plain text model", func(t *testing.T) {
|
||||||
|
cfg := parseModelConfig([]byte(`{
|
||||||
|
"architectures": ["Qwen3_5ForCausalLM"]
|
||||||
|
}`))
|
||||||
|
if cfg.supportsVision() {
|
||||||
|
t.Fatal("supportsVision() = true, want false")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,8 +1,26 @@
|
|||||||
|
// cache.go manages a shared KV cache across conversations using a compressed
|
||||||
|
// prefix trie. Each trie node stores a token sequence (edge) and optional
|
||||||
|
// per-layer snapshots that can be paged in/out of the live MLX cache arrays.
|
||||||
|
//
|
||||||
|
// Key properties:
|
||||||
|
// - Only one path through the trie is "active" (backed by live MLX arrays)
|
||||||
|
// at a time. Switching paths pages out the frontier node and pages in the
|
||||||
|
// new path.
|
||||||
|
// - Snapshots are only captured at the frontier (end) of the active path.
|
||||||
|
// Intermediate node snapshots come from split prefill.
|
||||||
|
// - All cache layers must stay at the same token offset.
|
||||||
|
// - Sibling edges must not share a common token prefix (compressed trie
|
||||||
|
// invariant).
|
||||||
|
// - begin() always re-evaluates at least one token so the pipeline can seed
|
||||||
|
// generation, even on a full prefix match.
|
||||||
|
|
||||||
package mlxrunner
|
package mlxrunner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
@@ -10,10 +28,13 @@ import (
|
|||||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const maxPagedOutBytes int64 = 8 << 30 // 8 GiB eviction threshold for paged-out snapshot memory
|
||||||
|
|
||||||
type kvCache struct {
|
type kvCache struct {
|
||||||
// For now we only support a single entry, so this is just one sequence
|
root *trieNode // root of the prefix trie
|
||||||
tokens []int32
|
activePath []*trieNode // current root→leaf path with live MLX arrays
|
||||||
caches []cache.Cache
|
caches []cache.Cache
|
||||||
|
pagedOutBytes int64 // total bytes in paged-out snapshots across the trie
|
||||||
}
|
}
|
||||||
|
|
||||||
// cacheSession manages caches for a single pipeline run.
|
// cacheSession manages caches for a single pipeline run.
|
||||||
@@ -26,66 +47,15 @@ type cacheSession struct {
|
|||||||
|
|
||||||
caches []cache.Cache
|
caches []cache.Cache
|
||||||
remaining []int32
|
remaining []int32
|
||||||
|
|
||||||
|
// snapshotOffset, if > 0, is a trie node boundary where we need to
|
||||||
|
// capture a snapshot during prefill. This enables future requests
|
||||||
|
// branching at this point to restore non-rewindable caches (e.g.
|
||||||
|
// RecurrentCache) instead of re-evaluating from scratch.
|
||||||
|
snapshotOffset int
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendCacheState(dst []*mlx.Array, c cache.Cache) []*mlx.Array {
|
func (c *kvCache) ensureCaches(m base.Model) {
|
||||||
if c == nil {
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
keys, values := c.State()
|
|
||||||
if keys != nil && keys.Valid() {
|
|
||||||
dst = append(dst, keys)
|
|
||||||
}
|
|
||||||
if values != nil && values.Valid() {
|
|
||||||
dst = append(dst, values)
|
|
||||||
}
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *kvCache) free() {
|
|
||||||
for i, kv := range c.caches {
|
|
||||||
if kv == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
kv.Free()
|
|
||||||
c.caches[i] = nil
|
|
||||||
}
|
|
||||||
c.caches = nil
|
|
||||||
c.tokens = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *kvCache) cachesCanTrim() bool {
|
|
||||||
for _, kv := range c.caches {
|
|
||||||
if kv == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !kv.CanTrim() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *kvCache) trimToPrefix(prefix int) {
|
|
||||||
for _, kv := range c.caches {
|
|
||||||
if kv == nil || !kv.CanTrim() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if trim := kv.Offset() - prefix; trim > 0 {
|
|
||||||
kv.Trim(trim)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if prefix < len(c.tokens) {
|
|
||||||
c.tokens = c.tokens[:prefix]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// begin prepares caches for a new request. It finds the nearest
|
|
||||||
// matching cache or creates new caches if none match.
|
|
||||||
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
|
||||||
ensureCaches := func() {
|
|
||||||
if len(c.caches) != 0 {
|
if len(c.caches) != 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -98,104 +68,534 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
|||||||
c.caches[i] = cache.NewKVCache()
|
c.caches[i] = cache.NewKVCache()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ensureCaches()
|
|
||||||
|
|
||||||
remaining := c.findRemaining(inputs)
|
func (c *kvCache) ensureRoot() {
|
||||||
ensureCaches()
|
if c.root == nil {
|
||||||
|
c.root = &trieNode{
|
||||||
|
lastUsed: time.Now(),
|
||||||
|
}
|
||||||
|
c.activePath = []*trieNode{c.root}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// begin prepares caches for a new request. It finds the nearest
|
||||||
|
// matching cache or creates new caches if none match.
|
||||||
|
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
||||||
|
c.ensureCaches(m)
|
||||||
|
c.ensureRoot()
|
||||||
|
|
||||||
|
matchPath, matched := findBestMatch(c.root, inputs)
|
||||||
|
originalMatched := matched
|
||||||
|
|
||||||
|
// Always keep at least one token to re-evaluate so the
|
||||||
|
// pipeline can seed token generation from it.
|
||||||
|
if matched == len(inputs) && matched > 0 {
|
||||||
|
matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for partial match within a node's edge — truncate path
|
||||||
|
// to the parent boundary. snapshot() will split the node and
|
||||||
|
// create the branch point during prefill when caches are ready.
|
||||||
|
partialMatch := false
|
||||||
|
if len(matchPath) > 1 {
|
||||||
|
lastNode := matchPath[len(matchPath)-1]
|
||||||
|
matchedInEdge := matched - lastNode.startOffset()
|
||||||
|
if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
|
||||||
|
matchPath = matchPath[:len(matchPath)-1]
|
||||||
|
partialMatch = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch to the matched path, paging in/out as needed.
|
||||||
|
c.switchToPath(matchPath)
|
||||||
|
|
||||||
|
// switchToPath aligns caches to a common offset
|
||||||
|
prefix := c.minCacheOffset()
|
||||||
|
remaining := inputs[prefix:]
|
||||||
|
|
||||||
|
// Schedule a snapshot at the branch point during prefill so future
|
||||||
|
// requests diverging here can restore instead of re-evaluating.
|
||||||
|
var snapshotAt int
|
||||||
|
if partialMatch || (prefix == 0 && matched > 0) {
|
||||||
|
snapshotAt = matched
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []any{"total", len(inputs), "matched", originalMatched}
|
||||||
|
args = append(args, "cached", prefix, "left", len(remaining))
|
||||||
|
if snapshotAt > 0 {
|
||||||
|
args = append(args, "pending_snapshot", snapshotAt)
|
||||||
|
}
|
||||||
|
if prefix == 0 {
|
||||||
|
slog.Info("cache miss", args...)
|
||||||
|
} else {
|
||||||
|
slog.Info("cache hit", args...)
|
||||||
|
}
|
||||||
|
|
||||||
return &cacheSession{
|
return &cacheSession{
|
||||||
cache: c,
|
cache: c,
|
||||||
inputs: inputs,
|
inputs: inputs,
|
||||||
|
snapshotOffset: snapshotAt,
|
||||||
caches: c.caches,
|
caches: c.caches,
|
||||||
remaining: remaining,
|
remaining: remaining,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// close saves the token state if the forward pass ran.
|
// switchToPath transitions from the current active path to a new path,
|
||||||
func (s *cacheSession) close() {
|
// paging out diverging segments and paging in the new path.
|
||||||
if len(s.caches) == 0 {
|
func (c *kvCache) switchToPath(newPath []*trieNode) {
|
||||||
|
defer c.enforceEvictionPolicy()
|
||||||
|
|
||||||
|
// Find common ancestor index.
|
||||||
|
commonLen := 0
|
||||||
|
for commonLen < len(c.activePath) && commonLen < len(newPath) {
|
||||||
|
if c.activePath[commonLen] != newPath[commonLen] {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
commonLen++
|
||||||
|
}
|
||||||
|
|
||||||
|
ancestorOffset := 0
|
||||||
|
if commonLen > 0 {
|
||||||
|
ancestorOffset = c.activePath[commonLen-1].endOffset
|
||||||
|
}
|
||||||
|
|
||||||
|
var pageOutCount, pageInCount int
|
||||||
|
|
||||||
|
// Page out the leaf of the old path. Only the leaf's live cache
|
||||||
|
// state is correct — intermediate nodes already have snapshots
|
||||||
|
// captured during their creation (splitNode + prefill). Snapshotting
|
||||||
|
// non-leaf nodes here would produce wrong results for non-rewindable
|
||||||
|
// caches (e.g. RecurrentCache) whose state reflects the leaf, not
|
||||||
|
// the intermediate boundary.
|
||||||
|
if leaf := len(c.activePath) - 1; leaf >= commonLen {
|
||||||
|
node := c.activePath[leaf]
|
||||||
|
if !node.hasAllSnapshots() {
|
||||||
|
fromOffset := node.startOffset()
|
||||||
|
snaps := make([]cache.Snapshot, len(c.caches))
|
||||||
|
for j, kv := range c.caches {
|
||||||
|
if kv == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
snaps[j] = kv.Snapshot(fromOffset)
|
||||||
|
}
|
||||||
|
node.setSnapshots(snaps, &c.pagedOutBytes)
|
||||||
|
pageOutCount++
|
||||||
|
logutil.Trace(fmt.Sprintf("page out: [%d, %d)", fromOffset, node.endOffset))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rewind each cache to the ancestor offset or free it. Freed
|
||||||
|
// caches (e.g. RecurrentCache that can't rewind) will be restored
|
||||||
|
// from snapshots during page-in.
|
||||||
|
for _, kv := range c.caches {
|
||||||
|
if kv == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !kv.Restore(nil, ancestorOffset) {
|
||||||
|
kv.Free()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Page in — walk the full new path, restoring from snapshots.
|
||||||
|
// Freed caches naturally pick up the first available snapshot.
|
||||||
|
// Caches already past a node skip it via offset check.
|
||||||
|
for _, node := range newPath {
|
||||||
|
if len(node.snapshots) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for j, kv := range c.caches {
|
||||||
|
if kv == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if j >= len(node.snapshots) || node.snapshots[j] == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if kv.Offset() >= node.endOffset {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !kv.Restore(node.snapshots[j], node.endOffset) {
|
||||||
|
slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset())
|
||||||
|
c.freeAll()
|
||||||
|
c.activePath = []*trieNode{c.root}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if node.endOffset > ancestorOffset {
|
||||||
|
pageInCount++
|
||||||
|
logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Align all caches to the minimum offset.
|
||||||
|
c.activePath = newPath
|
||||||
|
minOff := c.minCacheOffset()
|
||||||
|
for _, kv := range c.caches {
|
||||||
|
if kv != nil && kv.Offset() != minOff {
|
||||||
|
if !kv.Restore(nil, minOff) {
|
||||||
|
slog.Warn("failed to restore cache, freeing all caches", "offset", minOff)
|
||||||
|
c.freeAll()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := len(c.activePath) - 1; i >= 0; i-- {
|
||||||
|
if c.activePath[i].endOffset <= minOff {
|
||||||
|
c.activePath = c.activePath[:i+1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pageOutCount > 0 || pageInCount > 0 {
|
||||||
|
slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// snapshot creates a snapshot at the current cache position. During prefill,
|
||||||
|
// it is called at branch points (user=false) to create restore points for
|
||||||
|
// future diverging requests and with user=true to mark an explicit reusable
|
||||||
|
// restore point.
|
||||||
|
func (s *cacheSession) snapshot(user bool) {
|
||||||
|
c := s.cache
|
||||||
|
cacheOffset := c.minCacheOffset()
|
||||||
|
if cacheOffset <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear pending intermediate snapshot if we've reached or passed it.
|
||||||
|
if s.snapshotOffset > 0 && cacheOffset >= s.snapshotOffset {
|
||||||
|
s.snapshotOffset = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// The last node in activePath is the frontier where caches are advancing.
|
||||||
|
// cacheOffset is always >= its endOffset: begin() restores caches to this
|
||||||
|
// boundary and prefill advances monotonically forward.
|
||||||
|
frontier := c.activePath[len(c.activePath)-1]
|
||||||
|
|
||||||
|
// If the frontier already ends at cacheOffset, just ensure it has snapshots.
|
||||||
|
if frontier.endOffset == cacheOffset {
|
||||||
|
if user {
|
||||||
|
frontier.user = true
|
||||||
|
}
|
||||||
|
if !frontier.hasAllSnapshots() {
|
||||||
|
s.attachSnapshots(frontier, cacheOffset)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if frontier.endOffset > cacheOffset {
|
||||||
|
slog.Warn("snapshot skipped: cacheOffset is behind frontier", "cacheOffset", cacheOffset, "frontierEndOffset", frontier.endOffset)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Advance the trie to cacheOffset — find or create a node there.
|
||||||
|
edgeTokens := append(s.inputs, s.outputs...)[frontier.endOffset:cacheOffset]
|
||||||
|
frontier = c.advancePath(frontier, edgeTokens, cacheOffset)
|
||||||
|
|
||||||
|
// Attach fresh snapshots from the live caches. Always use fresh
|
||||||
|
// snapshots even if the node already has some (e.g. from splitNode's
|
||||||
|
// Cache.Split which may be incomplete for non-splittable caches
|
||||||
|
// like RecurrentCache).
|
||||||
|
if user {
|
||||||
|
frontier.user = true
|
||||||
|
}
|
||||||
|
s.attachSnapshots(frontier, cacheOffset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// advancePath advances the active path from the current frontier by matching
|
||||||
|
// tokens against existing trie children, splitting partial matches, and
|
||||||
|
// appending any remaining tokens as new nodes. Returns the new frontier.
|
||||||
|
func (c *kvCache) advancePath(frontier *trieNode, tokens []int32, endOffset int) *trieNode {
|
||||||
|
// Check if existing children already cover some or all of tokens.
|
||||||
|
// tokens may span multiple trie nodes when extending a previous run's
|
||||||
|
// leaf and this snapshot now overlaps that same range.
|
||||||
|
matchPath, matched := findBestMatch(frontier, tokens)
|
||||||
|
// matchPath[0] is frontier itself; the rest are newly traversed nodes.
|
||||||
|
remaining := tokens[matched:]
|
||||||
|
|
||||||
|
// Check for a partial match within the last node's edge — if so, split it.
|
||||||
|
if len(matchPath) > 1 {
|
||||||
|
lastNode := matchPath[len(matchPath)-1]
|
||||||
|
matchedInEdge := frontier.endOffset + matched - lastNode.startOffset()
|
||||||
|
if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
|
||||||
|
matchPath[len(matchPath)-1] = splitNode(lastNode, matchedInEdge, c.caches, &c.pagedOutBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append traversed nodes (excluding frontier) to the active path.
|
||||||
|
c.activePath = append(c.activePath, matchPath[1:]...)
|
||||||
|
dest := matchPath[len(matchPath)-1]
|
||||||
|
|
||||||
|
if len(remaining) > 0 {
|
||||||
|
// Drop non-user snapshots so appendTokens can extend in-place
|
||||||
|
// rather than creating a new child node.
|
||||||
|
if len(dest.children) == 0 && !dest.user {
|
||||||
|
dest.setSnapshots(nil, &c.pagedOutBytes)
|
||||||
|
}
|
||||||
|
newDest := dest.appendTokens(c.root, remaining, endOffset)
|
||||||
|
if newDest != dest {
|
||||||
|
c.activePath = append(c.activePath, newDest)
|
||||||
|
}
|
||||||
|
dest = newDest
|
||||||
|
}
|
||||||
|
return dest
|
||||||
|
}
|
||||||
|
|
||||||
|
// attachSnapshots attaches cache snapshots to a trie node at the given offset.
|
||||||
|
// The node must be on the active path (and thus protected from eviction;
|
||||||
|
// lastUsed is updated in close()). All non-nil caches must be at the same
|
||||||
|
// offset (cacheOffset); a mismatch indicates a bug in the caller.
|
||||||
|
func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
|
||||||
|
c := s.cache
|
||||||
|
|
||||||
|
if c.activePath[len(c.activePath)-1] != node {
|
||||||
|
slog.Warn("attachSnapshots skipped: node is not the active frontier", "nodeEndOffset", node.endOffset)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
snaps := make([]cache.Snapshot, len(c.caches))
|
||||||
|
for i, kv := range c.caches {
|
||||||
|
if kv != nil {
|
||||||
|
if kv.Offset() != cacheOffset {
|
||||||
|
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, kv.Offset()))
|
||||||
|
}
|
||||||
|
snaps[i] = kv.Snapshot(node.startOffset())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node.setSnapshots(snaps, &c.pagedOutBytes)
|
||||||
|
slog.Debug("created snapshot", "offset", cacheOffset)
|
||||||
|
c.enforceEvictionPolicy()
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear releases live caches and drops the trie so future requests cannot
|
||||||
|
// reuse prompt state keyed only by token IDs.
|
||||||
|
func (c *kvCache) clear() {
|
||||||
|
c.freeAll()
|
||||||
|
walkNodes(c.root, func(n *trieNode) bool {
|
||||||
|
for _, s := range n.snapshots {
|
||||||
|
if s != nil {
|
||||||
|
s.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n.snapshots = nil
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
c.root = nil
|
||||||
|
c.activePath = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// freeAll releases all cache layers.
|
||||||
|
func (c *kvCache) freeAll() {
|
||||||
|
for _, kv := range c.caches {
|
||||||
|
if kv != nil {
|
||||||
|
kv.Free()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *kvCache) minCacheOffset() int {
|
||||||
|
offset := 0
|
||||||
|
found := false
|
||||||
|
for _, kv := range c.caches {
|
||||||
|
if kv == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if off := kv.Offset(); !found || off < offset {
|
||||||
|
offset = off
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return offset
|
||||||
|
}
|
||||||
|
|
||||||
|
// close saves the token state if the forward pass ran.
|
||||||
|
func (s *cacheSession) close() {
|
||||||
|
offset := s.cache.minCacheOffset()
|
||||||
|
if offset <= 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
offset := -1
|
|
||||||
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
||||||
for _, kv := range s.caches {
|
for _, kv := range s.caches {
|
||||||
if kv == nil {
|
if kv == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Mixed cache types (e.g. recurrent + KV) can transiently report different
|
arrays = append(arrays, kv.State()...)
|
||||||
// offsets, so use the minimum as the safe reusable token prefix.
|
|
||||||
if off := kv.Offset(); offset < 0 || off < offset {
|
|
||||||
offset = off
|
|
||||||
}
|
|
||||||
arrays = appendCacheState(arrays, kv)
|
|
||||||
}
|
|
||||||
if offset <= 0 {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure that if we have run the forward pass and set the metadata
|
// Ensure that if we have run the forward pass and set the metadata
|
||||||
// that we also actually have the data.
|
// that we also actually have the data.
|
||||||
mlx.AsyncEval(arrays...)
|
mlx.AsyncEval(arrays...)
|
||||||
|
|
||||||
|
// Advance the trie frontier with any newly generated tokens.
|
||||||
|
c := s.cache
|
||||||
|
if len(c.activePath) > 0 {
|
||||||
|
frontier := c.activePath[len(c.activePath)-1]
|
||||||
stored := append(s.inputs, s.outputs...)
|
stored := append(s.inputs, s.outputs...)
|
||||||
if offset > len(stored) {
|
|
||||||
offset = len(stored)
|
|
||||||
}
|
|
||||||
s.cache.tokens = stored[:offset]
|
|
||||||
}
|
|
||||||
|
|
||||||
// findRemaining finds the longest common prefix between tokens and the cached
|
if offset > frontier.endOffset {
|
||||||
// sequence, trims stale cache entries, and returns the remaining tokens.
|
newTokens := stored[frontier.endOffset:offset]
|
||||||
func (c *kvCache) findRemaining(tokens []int32) []int32 {
|
c.advancePath(frontier, newTokens, offset)
|
||||||
prefix := 0
|
|
||||||
for prefix < len(tokens) && prefix < len(c.tokens) && tokens[prefix] == c.tokens[prefix] {
|
|
||||||
prefix++
|
|
||||||
}
|
}
|
||||||
|
now := time.Now()
|
||||||
// Always keep at least one token to re-evaluate so the
|
for _, node := range c.activePath {
|
||||||
// pipeline can seed token generation from it.
|
node.lastUsed = now
|
||||||
if prefix == len(tokens) && prefix > 0 {
|
|
||||||
prefix--
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if prefix < len(c.tokens) {
|
|
||||||
if c.cachesCanTrim() {
|
|
||||||
c.trimToPrefix(prefix)
|
|
||||||
} else {
|
|
||||||
c.free()
|
|
||||||
slog.Info("Cache miss", "left", len(tokens), "matched", prefix, "reason", "non_trimmable_divergence")
|
|
||||||
return tokens
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if prefix == 0 {
|
// enforceEvictionPolicy evicts eligible nodes until paged-out memory is within limits.
|
||||||
slog.Info("Cache miss", "left", len(tokens))
|
func (c *kvCache) enforceEvictionPolicy() {
|
||||||
} else {
|
if c.pagedOutBytes <= maxPagedOutBytes {
|
||||||
slog.Info("Cache hit", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]))
|
|
||||||
}
|
|
||||||
return tokens[prefix:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *kvCache) log() {
|
|
||||||
if len(c.caches) == 0 {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
offset := -1
|
|
||||||
var totalBytes int
|
activeSet := make(map[*trieNode]bool, len(c.activePath))
|
||||||
|
for _, n := range c.activePath {
|
||||||
|
activeSet[n] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for c.pagedOutBytes > maxPagedOutBytes {
|
||||||
|
var best *trieNode
|
||||||
|
walkNodes(c.root, func(n *trieNode) bool {
|
||||||
|
if n == c.root || activeSet[n] || !n.hasSnapshots() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// Evict: oldest, then deepest, then largest.
|
||||||
|
if best == nil || cmp.Or(
|
||||||
|
n.lastUsed.Compare(best.lastUsed),
|
||||||
|
cmp.Compare(best.endOffset, n.endOffset),
|
||||||
|
cmp.Compare(best.snapshotBytes(), n.snapshotBytes()),
|
||||||
|
) < 0 {
|
||||||
|
best = n
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if best == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
c.evictNode(best)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// evictNode evicts a single node from the trie, freeing its snapshot memory.
|
||||||
|
func (c *kvCache) evictNode(node *trieNode) {
|
||||||
|
if len(node.children) == 0 {
|
||||||
|
// Leaf: remove entirely.
|
||||||
|
parent := node.parent
|
||||||
|
kind := "evicting leaf"
|
||||||
|
if node.user {
|
||||||
|
kind = "evicting user snapshot"
|
||||||
|
}
|
||||||
|
slog.Debug(kind, "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
||||||
|
removeNode(node, &c.pagedOutBytes)
|
||||||
|
|
||||||
|
// If parent is a regular (non-user-snapshot) node with one remaining child, auto-merge.
|
||||||
|
if parent != nil && !parent.user && len(parent.children) == 1 && parent != c.root {
|
||||||
|
logutil.Trace(fmt.Sprintf("auto-merging parent at offset %d with single child", parent.endOffset))
|
||||||
|
mergeWithChild(parent, c.caches, &c.pagedOutBytes)
|
||||||
|
}
|
||||||
|
} else if len(node.children) == 1 {
|
||||||
|
// Interior snapshot node with one child: merge with child.
|
||||||
|
slog.Debug("evicting snapshot node", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
||||||
|
mergeWithChild(node, c.caches, &c.pagedOutBytes)
|
||||||
|
} else {
|
||||||
|
// Multi-child branch point: drop snapshots but keep the node.
|
||||||
|
slog.Debug("evicting branch snapshot", "offset", node.endOffset, "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
||||||
|
node.setSnapshots(nil, &c.pagedOutBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *kvCache) dumpTree() {
|
||||||
|
// Summary stats
|
||||||
|
var cacheBytes int
|
||||||
for _, kv := range c.caches {
|
for _, kv := range c.caches {
|
||||||
if kv == nil {
|
if kv == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if off := kv.Offset(); offset < 0 || off < offset {
|
for _, a := range kv.State() {
|
||||||
offset = off
|
if a != nil {
|
||||||
}
|
cacheBytes += a.NumBytes()
|
||||||
for _, a := range appendCacheState(nil, kv) {
|
|
||||||
totalBytes += a.NumBytes()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if offset < 0 {
|
}
|
||||||
|
|
||||||
|
// Build active path set for marking.
|
||||||
|
active := make(map[*trieNode]bool, len(c.activePath))
|
||||||
|
for _, n := range c.activePath {
|
||||||
|
active[n] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var nodeCount, snapshotCount int
|
||||||
|
var pagedBytes int64
|
||||||
|
var lines []string
|
||||||
|
var dump func(n *trieNode, prefix string, isLast bool)
|
||||||
|
dump = func(n *trieNode, prefix string, isLast bool) {
|
||||||
|
if n == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logutil.Trace(fmt.Sprintf("kv cache tokens: %d, size: %s", offset, mlx.PrettyBytes(totalBytes)))
|
nodeCount++
|
||||||
|
|
||||||
|
// Build connector
|
||||||
|
var connector string
|
||||||
|
if n.parent == nil {
|
||||||
|
connector = ""
|
||||||
|
} else if isLast {
|
||||||
|
connector = prefix + "`-- "
|
||||||
|
} else {
|
||||||
|
connector = prefix + "|-- "
|
||||||
|
}
|
||||||
|
|
||||||
|
// Node label
|
||||||
|
nodeBytes := n.snapshotBytes()
|
||||||
|
pagedBytes += nodeBytes
|
||||||
|
|
||||||
|
label := fmt.Sprintf("[%d,%d) %dt", n.startOffset(), n.endOffset, len(n.tokens))
|
||||||
|
if nodeBytes > 0 {
|
||||||
|
label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
|
||||||
|
}
|
||||||
|
var flags []string
|
||||||
|
if n.user {
|
||||||
|
flags = append(flags, "user")
|
||||||
|
}
|
||||||
|
if n.hasAllSnapshots() {
|
||||||
|
snapshotCount++
|
||||||
|
flags = append(flags, "snap")
|
||||||
|
}
|
||||||
|
if active[n] {
|
||||||
|
flags = append(flags, "active")
|
||||||
|
}
|
||||||
|
if len(flags) > 0 {
|
||||||
|
label += " (" + flags[0]
|
||||||
|
for _, f := range flags[1:] {
|
||||||
|
label += ", " + f
|
||||||
|
}
|
||||||
|
label += ")"
|
||||||
|
}
|
||||||
|
lines = append(lines, connector+label)
|
||||||
|
|
||||||
|
// Recurse children
|
||||||
|
childPrefix := prefix
|
||||||
|
if n.parent != nil {
|
||||||
|
if isLast {
|
||||||
|
childPrefix += " "
|
||||||
|
} else {
|
||||||
|
childPrefix += "| "
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i, child := range n.children {
|
||||||
|
dump(child, childPrefix, i == len(n.children)-1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dump(c.root, "", true)
|
||||||
|
|
||||||
|
offset := c.minCacheOffset()
|
||||||
|
logutil.Trace(fmt.Sprintf("kv cache active_tokens: %d, active_size: %s, paged_out: %s, trie: nodes=%d, snapshots=%d",
|
||||||
|
offset, mlx.PrettyBytes(cacheBytes), mlx.PrettyBytes(int(pagedBytes)), nodeCount, snapshotCount))
|
||||||
|
for i, l := range lines {
|
||||||
|
if i == 0 {
|
||||||
|
logutil.Trace("cache trie: " + l)
|
||||||
|
} else {
|
||||||
|
logutil.Trace(" " + l)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
282
x/mlxrunner/cache/cache.go
vendored
282
x/mlxrunner/cache/cache.go
vendored
@@ -8,13 +8,34 @@ import (
|
|||||||
type Cache interface {
|
type Cache interface {
|
||||||
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
|
||||||
// State returns the cache-owned state roots that should be kept/evaluated.
|
// State returns the cache-owned state roots that should be kept/evaluated.
|
||||||
State() (keys, values *mlx.Array)
|
State() []*mlx.Array
|
||||||
CanTrim() bool
|
|
||||||
Trim(int) int
|
|
||||||
Clone() Cache
|
|
||||||
Free()
|
Free()
|
||||||
Offset() int
|
Offset() int
|
||||||
Len() int
|
|
||||||
|
// Snapshot copies cache state from fromOffset to current offset into
|
||||||
|
// pinned VRAM arrays. The active cache is unchanged.
|
||||||
|
Snapshot(fromOffset int) Snapshot
|
||||||
|
|
||||||
|
// Restore brings the cache to target. If snapshot is nil, rewinds
|
||||||
|
// using the cache's own live state.
|
||||||
|
Restore(snapshot Snapshot, target int) bool
|
||||||
|
|
||||||
|
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
|
||||||
|
// Takes ownership of both inputs.
|
||||||
|
Merge(parent, child Snapshot) Snapshot
|
||||||
|
|
||||||
|
// Split divides a snapshot [a,c) at offset b into [a,b) and [b,c).
|
||||||
|
// Takes ownership of the input. Cache types that cannot split
|
||||||
|
// (e.g. recurrent) return (nil, snapshot).
|
||||||
|
Split(snapshot Snapshot, at int) (parent, child Snapshot)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Snapshot is paged-out cache state that can be restored later.
|
||||||
|
type Snapshot interface {
|
||||||
|
// Size returns the byte size of the paged-out data (in VRAM).
|
||||||
|
Size() int
|
||||||
|
// Close unpins the snapshot's arrays so they can be freed by Sweep.
|
||||||
|
Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
type KVCache struct {
|
type KVCache struct {
|
||||||
@@ -59,40 +80,148 @@ func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
|||||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
|
func (c *KVCache) State() []*mlx.Array {
|
||||||
if c.keys == nil || c.values == nil {
|
if c.keys == nil || c.values == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []*mlx.Array{
|
||||||
|
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||||
|
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// kvSnapshot holds paged-out KV data for a range [fromOffset, toOffset).
|
||||||
|
type kvSnapshot struct {
|
||||||
|
keys, values *mlx.Array
|
||||||
|
fromOffset, toOffset int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *kvSnapshot) Size() int { return s.keys.NumBytes() + s.values.NumBytes() }
|
||||||
|
func (s *kvSnapshot) Close() { mlx.Unpin(s.keys, s.values) }
|
||||||
|
|
||||||
|
func (c *KVCache) Snapshot(fromOffset int) Snapshot {
|
||||||
|
if c.keys == nil || c.offset <= fromOffset {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
from := max(0, fromOffset)
|
||||||
|
to := c.offset
|
||||||
|
|
||||||
|
kSlice := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
|
||||||
|
vSlice := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
|
||||||
|
kCopy := mlx.Copy(kSlice)
|
||||||
|
vCopy := mlx.Copy(vSlice)
|
||||||
|
mlx.Pin(kCopy, vCopy)
|
||||||
|
mlx.AsyncEval(kCopy, vCopy)
|
||||||
|
|
||||||
|
return &kvSnapshot{
|
||||||
|
keys: kCopy,
|
||||||
|
values: vCopy,
|
||||||
|
fromOffset: from,
|
||||||
|
toOffset: to,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
|
||||||
|
if snapshot == nil {
|
||||||
|
// Rewind using live state — just clamp offset.
|
||||||
|
target = max(0, min(target, c.offset))
|
||||||
|
c.offset = target
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
snap := snapshot.(*kvSnapshot)
|
||||||
|
|
||||||
|
// Check that the cache has data up to the snapshot's starting point.
|
||||||
|
if c.offset < snap.fromOffset {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rewind to snapshot start, then feed snapshot data through Update.
|
||||||
|
c.offset = snap.fromOffset
|
||||||
|
c.Update(snap.keys, snap.values)
|
||||||
|
|
||||||
|
// Clamp to target if needed (target may be less than full snapshot).
|
||||||
|
if target < c.offset {
|
||||||
|
c.offset = target
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *KVCache) Merge(parent, child Snapshot) Snapshot {
|
||||||
|
if parent == nil || child == nil {
|
||||||
|
if parent != nil {
|
||||||
|
parent.Close()
|
||||||
|
}
|
||||||
|
if child != nil {
|
||||||
|
child.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p := parent.(*kvSnapshot)
|
||||||
|
ch := child.(*kvSnapshot)
|
||||||
|
|
||||||
|
mk := p.keys.Concatenate(2, ch.keys)
|
||||||
|
mv := p.values.Concatenate(2, ch.values)
|
||||||
|
mlx.Pin(mk, mv)
|
||||||
|
mlx.AsyncEval(mk, mv)
|
||||||
|
|
||||||
|
p.Close()
|
||||||
|
ch.Close()
|
||||||
|
|
||||||
|
return &kvSnapshot{
|
||||||
|
keys: mk,
|
||||||
|
values: mv,
|
||||||
|
fromOffset: p.fromOffset,
|
||||||
|
toOffset: ch.toOffset,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *KVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
||||||
|
if snapshot == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
snap := snapshot.(*kvSnapshot)
|
||||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
splitIdx := at - snap.fromOffset
|
||||||
|
seqLen := snap.toOffset - snap.fromOffset
|
||||||
|
if splitIdx <= 0 {
|
||||||
|
return nil, snapshot
|
||||||
|
}
|
||||||
|
if splitIdx >= seqLen {
|
||||||
|
return snapshot, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KVCache) CanTrim() bool { return true }
|
pk := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()))
|
||||||
|
pv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()))
|
||||||
|
ck := mlx.Copy(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()))
|
||||||
|
cv := mlx.Copy(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()))
|
||||||
|
mlx.Pin(pk, pv, ck, cv)
|
||||||
|
mlx.AsyncEval(pk, pv, ck, cv)
|
||||||
|
|
||||||
func (c *KVCache) Trim(n int) int {
|
snap.Close()
|
||||||
n = min(c.offset, n)
|
|
||||||
c.offset -= n
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *KVCache) Clone() Cache {
|
p := &kvSnapshot{
|
||||||
clone := &KVCache{
|
keys: pk,
|
||||||
keys: c.keys.Clone(),
|
values: pv,
|
||||||
values: c.values.Clone(),
|
fromOffset: snap.fromOffset,
|
||||||
offset: c.offset,
|
toOffset: at,
|
||||||
step: c.step,
|
|
||||||
}
|
}
|
||||||
mlx.Pin(clone.keys, clone.values)
|
ch := &kvSnapshot{
|
||||||
return clone
|
keys: ck,
|
||||||
|
values: cv,
|
||||||
|
fromOffset: at,
|
||||||
|
toOffset: snap.toOffset,
|
||||||
|
}
|
||||||
|
return p, ch
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KVCache) Free() {
|
func (c *KVCache) Free() {
|
||||||
mlx.Unpin(c.keys, c.values)
|
mlx.Unpin(c.keys, c.values)
|
||||||
c.keys, c.values = nil, nil
|
c.keys, c.values = nil, nil
|
||||||
|
c.offset = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KVCache) Offset() int { return c.offset }
|
func (c *KVCache) Offset() int { return c.offset }
|
||||||
func (c *KVCache) Len() int { return c.offset }
|
|
||||||
|
|
||||||
// RotatingKVCache implements sliding window attention with bounded memory
|
// RotatingKVCache implements sliding window attention with bounded memory
|
||||||
type RotatingKVCache struct {
|
type RotatingKVCache struct {
|
||||||
@@ -184,29 +313,104 @@ func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Arra
|
|||||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
|
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
|
func (c *RotatingKVCache) State() []*mlx.Array {
|
||||||
if c.keys == nil || c.values == nil {
|
if c.keys == nil || c.values == nil {
|
||||||
return nil, nil
|
return nil
|
||||||
|
}
|
||||||
|
return []*mlx.Array{
|
||||||
|
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||||
|
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||||
}
|
}
|
||||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
|
||||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RotatingKVCache) CanTrim() bool { return true }
|
// rotatingSnapshot holds paged-out data for a RotatingKVCache.
|
||||||
|
type rotatingSnapshot struct {
|
||||||
func (c *RotatingKVCache) Trim(n int) int {
|
kvSnapshot // embedded KV data
|
||||||
n = min(c.offset, n)
|
idx int // buffer write position at snapshot time
|
||||||
c.offset -= n
|
|
||||||
c.idx -= n
|
|
||||||
return n
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RotatingKVCache) Clone() Cache {
|
func (s *rotatingSnapshot) Size() int { return s.kvSnapshot.Size() }
|
||||||
return &RotatingKVCache{
|
func (s *rotatingSnapshot) Close() { s.kvSnapshot.Close() }
|
||||||
maxSize: c.maxSize,
|
|
||||||
|
func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot {
|
||||||
|
if c.keys == nil || c.offset <= fromOffset {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
state := c.State()
|
||||||
|
k := state[0].Clone()
|
||||||
|
v := state[1].Clone()
|
||||||
|
mlx.Pin(k, v)
|
||||||
|
|
||||||
|
return &rotatingSnapshot{
|
||||||
|
kvSnapshot: kvSnapshot{
|
||||||
|
keys: k,
|
||||||
|
values: v,
|
||||||
|
fromOffset: fromOffset,
|
||||||
|
toOffset: c.offset,
|
||||||
|
},
|
||||||
idx: c.idx,
|
idx: c.idx,
|
||||||
KVCache: c.KVCache.Clone().(*KVCache),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
|
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
||||||
|
if snapshot == nil {
|
||||||
|
// Live rewind is only safe when the buffer hasn't filled yet
|
||||||
|
// (offset <= maxSize). Once the window has shifted, rewinding
|
||||||
|
// leaves fewer than maxSize trailing tokens to attend to —
|
||||||
|
// a snapshot is required to restore the full window.
|
||||||
|
if c.offset > c.maxSize {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
target = max(0, min(target, c.offset))
|
||||||
|
c.offset = target
|
||||||
|
c.idx = target
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
snap := snapshot.(*rotatingSnapshot)
|
||||||
|
|
||||||
|
// Reject if clamping would leave an incomplete window.
|
||||||
|
if target < snap.toOffset && snap.toOffset > c.maxSize {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore from snapshot: rebuild buffer state.
|
||||||
|
// Free existing state first.
|
||||||
|
if c.keys != nil {
|
||||||
|
mlx.Unpin(c.keys, c.values)
|
||||||
|
}
|
||||||
|
c.keys = snap.keys.Clone()
|
||||||
|
c.values = snap.values.Clone()
|
||||||
|
mlx.Pin(c.keys, c.values)
|
||||||
|
c.offset = snap.toOffset
|
||||||
|
c.idx = snap.idx
|
||||||
|
|
||||||
|
// Clamp to target if needed.
|
||||||
|
if target < c.offset {
|
||||||
|
target = max(0, target)
|
||||||
|
c.offset = target
|
||||||
|
c.idx = target
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RotatingKVCache) Merge(parent, child Snapshot) Snapshot {
|
||||||
|
// For rotating caches, the child snapshot supersedes the parent
|
||||||
|
// since it contains the full window state.
|
||||||
|
if parent != nil {
|
||||||
|
parent.Close()
|
||||||
|
}
|
||||||
|
return child
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RotatingKVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
||||||
|
// Rotating cache snapshots contain the full window state.
|
||||||
|
// Cannot cleanly split a ring buffer at an arbitrary point.
|
||||||
|
return nil, snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RotatingKVCache) Free() {
|
||||||
|
c.KVCache.Free()
|
||||||
|
c.idx = 0
|
||||||
|
}
|
||||||
|
|||||||
271
x/mlxrunner/cache/cache_test.go
vendored
Normal file
271
x/mlxrunner/cache/cache_test.go
vendored
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func skipIfNoMLX(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
if err := mlx.CheckInit(); err != nil {
|
||||||
|
t.Skipf("MLX not available: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
c := NewKVCache()
|
||||||
|
|
||||||
|
for range 10 {
|
||||||
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Snapshot [5, 10).
|
||||||
|
snap := c.Snapshot(5)
|
||||||
|
|
||||||
|
// Free the cache completely — offset is now 0.
|
||||||
|
c.Free()
|
||||||
|
|
||||||
|
// Restore should fail because cache doesn't have data up to fromOffset=5.
|
||||||
|
if c.Restore(snap, 10) {
|
||||||
|
t.Fatal("expected Restore to fail with no base data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestKVCacheDataSurvivesSnapshotRestore verifies that actual array data
|
||||||
|
// is preserved through a snapshot→free→restore cycle.
|
||||||
|
func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
c := NewKVCache()
|
||||||
|
|
||||||
|
for range 10 {
|
||||||
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
snap := c.Snapshot(0)
|
||||||
|
if snap == nil {
|
||||||
|
t.Fatal("Snapshot returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free and restore to a fresh cache.
|
||||||
|
c2 := NewKVCache()
|
||||||
|
if !c2.Restore(snap, 10) {
|
||||||
|
t.Fatal("Restore failed")
|
||||||
|
}
|
||||||
|
if c2.Offset() != 10 {
|
||||||
|
t.Fatalf("offset = %d, want 10", c2.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify State() returns arrays with correct sequence dimension.
|
||||||
|
state := c2.State()
|
||||||
|
if len(state) != 2 {
|
||||||
|
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
||||||
|
}
|
||||||
|
// keys shape: [B, H, seqLen, Dk]
|
||||||
|
if state[0].Dim(2) != 10 {
|
||||||
|
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
|
||||||
|
}
|
||||||
|
if state[1].Dim(2) != 10 {
|
||||||
|
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestKVCacheSplitPreservesData verifies that split produces two snapshots
|
||||||
|
// that can be sequentially restored to rebuild the original cache state.
|
||||||
|
func TestKVCacheSplitPreservesData(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
c := NewKVCache()
|
||||||
|
|
||||||
|
for range 10 {
|
||||||
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
snap := c.Snapshot(0)
|
||||||
|
parent, child := c.Split(snap, 5)
|
||||||
|
if parent == nil || child == nil {
|
||||||
|
t.Fatal("Split returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore parent → offset=5, seq dim=5.
|
||||||
|
c2 := NewKVCache()
|
||||||
|
if !c2.Restore(parent, 5) {
|
||||||
|
t.Fatal("Restore(parent) failed")
|
||||||
|
}
|
||||||
|
if c2.Offset() != 5 {
|
||||||
|
t.Fatalf("offset after parent = %d, want 5", c2.Offset())
|
||||||
|
}
|
||||||
|
state := c2.State()
|
||||||
|
if state[0].Dim(2) != 5 {
|
||||||
|
t.Fatalf("keys seq dim after parent = %d, want 5", state[0].Dim(2))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore child on top → offset=10, seq dim=10.
|
||||||
|
if !c2.Restore(child, 10) {
|
||||||
|
t.Fatal("Restore(child) failed")
|
||||||
|
}
|
||||||
|
if c2.Offset() != 10 {
|
||||||
|
t.Fatalf("offset after child = %d, want 10", c2.Offset())
|
||||||
|
}
|
||||||
|
state = c2.State()
|
||||||
|
if state[0].Dim(2) != 10 {
|
||||||
|
t.Fatalf("keys seq dim after child = %d, want 10", state[0].Dim(2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestKVCacheSplitMergeRoundTripData verifies that splitting and merging back
|
||||||
|
// produces a snapshot equivalent to the original.
|
||||||
|
func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
c := NewKVCache()
|
||||||
|
|
||||||
|
for range 10 {
|
||||||
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
snap := c.Snapshot(0)
|
||||||
|
parent, child := c.Split(snap, 6)
|
||||||
|
merged := c.Merge(parent, child)
|
||||||
|
if merged == nil {
|
||||||
|
t.Fatal("Merge returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
c2 := NewKVCache()
|
||||||
|
if !c2.Restore(merged, 10) {
|
||||||
|
t.Fatal("Restore(merged) failed")
|
||||||
|
}
|
||||||
|
if c2.Offset() != 10 {
|
||||||
|
t.Fatalf("offset = %d, want 10", c2.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
state := c2.State()
|
||||||
|
if state[0].Dim(2) != 10 {
|
||||||
|
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
|
||||||
|
}
|
||||||
|
if state[1].Dim(2) != 10 {
|
||||||
|
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
c := NewRotatingKVCache(4)
|
||||||
|
|
||||||
|
// Feed 10 tokens (window size 4, so positions 0-5 are evicted).
|
||||||
|
for range 10 {
|
||||||
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Offset 3 is outside the window.
|
||||||
|
if c.Restore(nil, 3) {
|
||||||
|
t.Fatal("Restore(nil, 3) should fail when outside window")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheSnapshotPreservesWindow verifies that after restoring
|
||||||
|
// from a snapshot, the rotating cache has the correct window of data.
|
||||||
|
func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
c := NewRotatingKVCache(4)
|
||||||
|
|
||||||
|
// Feed 10 tokens one at a time. Window size 4, so only last 4 are kept.
|
||||||
|
for range 10 {
|
||||||
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
snap := c.Snapshot(0)
|
||||||
|
if snap == nil {
|
||||||
|
t.Fatal("Snapshot returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Feed 5 more tokens.
|
||||||
|
for range 5 {
|
||||||
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore to offset 10.
|
||||||
|
if !c.Restore(snap, 10) {
|
||||||
|
t.Fatal("Restore failed")
|
||||||
|
}
|
||||||
|
if c.Offset() != 10 {
|
||||||
|
t.Fatalf("offset = %d, want 10", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
state := c.State()
|
||||||
|
if len(state) != 2 {
|
||||||
|
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
||||||
|
}
|
||||||
|
// Seq dim should be min(offset, maxSize) = min(10, 4) = 4.
|
||||||
|
seqDim := state[0].Dim(2)
|
||||||
|
if seqDim != 4 {
|
||||||
|
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheRestoreFromSnapshot verifies that restoring from a
|
||||||
|
// snapshot correctly preserves the write position (idx), so subsequent
|
||||||
|
// single-token updates land in the right buffer slot.
|
||||||
|
func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
c := NewRotatingKVCache(4)
|
||||||
|
|
||||||
|
// Fill the window: 6 tokens into a size-4 window.
|
||||||
|
// After this, idx has wrapped and the buffer has rotated.
|
||||||
|
for range 6 {
|
||||||
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
c.Update(k, v)
|
||||||
|
}
|
||||||
|
if c.Offset() != 6 {
|
||||||
|
t.Fatalf("offset = %d, want 6", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
snap := c.Snapshot(0)
|
||||||
|
|
||||||
|
// Mutate the cache further so live state diverges from snapshot.
|
||||||
|
for range 3 {
|
||||||
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore to snapshot state.
|
||||||
|
if !c.Restore(snap, 6) {
|
||||||
|
t.Fatal("Restore failed")
|
||||||
|
}
|
||||||
|
if c.Offset() != 6 {
|
||||||
|
t.Fatalf("offset after restore = %d, want 6", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Feed one more token. If idx was restored correctly, this should
|
||||||
|
// produce a valid window of size 4 at offset 7.
|
||||||
|
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||||
|
c.Update(k, v)
|
||||||
|
|
||||||
|
if c.Offset() != 7 {
|
||||||
|
t.Fatalf("offset after post-restore update = %d, want 7", c.Offset())
|
||||||
|
}
|
||||||
|
state := c.State()
|
||||||
|
if len(state) != 2 {
|
||||||
|
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
||||||
|
}
|
||||||
|
seqDim := state[0].Dim(2)
|
||||||
|
if seqDim != 4 {
|
||||||
|
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
|
||||||
|
}
|
||||||
|
}
|
||||||
88
x/mlxrunner/cache/recurrent.go
vendored
88
x/mlxrunner/cache/recurrent.go
vendored
@@ -56,16 +56,6 @@ func (c *RecurrentCache) setStateDetached(old, v *mlx.Array, ensureContiguous bo
|
|||||||
return detached
|
return detached
|
||||||
}
|
}
|
||||||
|
|
||||||
func snapshotPinned(a *mlx.Array) *mlx.Array {
|
|
||||||
if a == nil || !a.Valid() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
snap := mlx.Copy(a)
|
|
||||||
mlx.Eval(snap)
|
|
||||||
mlx.Pin(snap)
|
|
||||||
return snap
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
|
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
|
||||||
return &RecurrentCache{
|
return &RecurrentCache{
|
||||||
convTail: int(convTail),
|
convTail: int(convTail),
|
||||||
@@ -123,30 +113,69 @@ func (c *RecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array
|
|||||||
return keys, values
|
return keys, values
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RecurrentCache) State() (*mlx.Array, *mlx.Array) {
|
func (c *RecurrentCache) State() []*mlx.Array {
|
||||||
return c.convState, c.deltaState
|
return []*mlx.Array{c.convState, c.deltaState}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RecurrentCache) CanTrim() bool { return false }
|
// recurrentSnapshot holds paged-out recurrent state. Self-contained —
|
||||||
|
// does not depend on any parent state.
|
||||||
func (c *RecurrentCache) Trim(n int) int {
|
type recurrentSnapshot struct {
|
||||||
// Recurrent state is not directly trimmable. Divergent prefixes must drop the cache.
|
convState, deltaState *mlx.Array
|
||||||
_ = n
|
offset int
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RecurrentCache) Clone() Cache {
|
func (s *recurrentSnapshot) Size() int { return s.convState.NumBytes() + s.deltaState.NumBytes() }
|
||||||
clone := &RecurrentCache{
|
func (s *recurrentSnapshot) Close() { mlx.Unpin(s.convState, s.deltaState) }
|
||||||
offset: c.offset,
|
|
||||||
convTail: c.convTail,
|
func (c *RecurrentCache) Snapshot(fromOffset int) Snapshot {
|
||||||
convDim: c.convDim,
|
// Recurrent state is not position-sliceable — always snapshot the full state.
|
||||||
numVHeads: c.numVHeads,
|
if c.convState == nil && c.deltaState == nil {
|
||||||
headVDim: c.headVDim,
|
return nil
|
||||||
headKDim: c.headKDim,
|
|
||||||
convState: snapshotPinned(c.convState),
|
|
||||||
deltaState: snapshotPinned(c.deltaState),
|
|
||||||
}
|
}
|
||||||
return clone
|
|
||||||
|
snap := &recurrentSnapshot{offset: c.offset}
|
||||||
|
snap.convState = c.convState.Clone()
|
||||||
|
snap.deltaState = c.deltaState.Clone()
|
||||||
|
mlx.Pin(snap.convState, snap.deltaState)
|
||||||
|
|
||||||
|
return snap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool {
|
||||||
|
if snapshot == nil {
|
||||||
|
// Recurrent state is cumulative and can't rewind. Only succeed
|
||||||
|
// if we're already at the target (no-op).
|
||||||
|
return target == c.offset
|
||||||
|
}
|
||||||
|
|
||||||
|
snap := snapshot.(*recurrentSnapshot)
|
||||||
|
|
||||||
|
// Recurrent state encodes all tokens up to snap.offset. Restoring
|
||||||
|
// to a target before that would leave stale state from tokens
|
||||||
|
// [target, snap.offset) baked in. Only allow restoring forward.
|
||||||
|
if target < snap.offset {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
c.convState = c.setStateRaw(c.convState, snap.convState)
|
||||||
|
c.deltaState = c.setStateRaw(c.deltaState, snap.deltaState)
|
||||||
|
c.offset = snap.offset
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Merge(parent, child Snapshot) Snapshot {
|
||||||
|
// Recurrent snapshots are self-contained — child supersedes parent.
|
||||||
|
if parent != nil {
|
||||||
|
parent.Close()
|
||||||
|
}
|
||||||
|
return child
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RecurrentCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
||||||
|
// Recurrent state is cumulative and not position-sliceable.
|
||||||
|
// Cannot recover intermediate state at the split point.
|
||||||
|
return nil, snapshot
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RecurrentCache) Free() {
|
func (c *RecurrentCache) Free() {
|
||||||
@@ -156,4 +185,3 @@ func (c *RecurrentCache) Free() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *RecurrentCache) Offset() int { return c.offset }
|
func (c *RecurrentCache) Offset() int { return c.offset }
|
||||||
func (c *RecurrentCache) Len() int { return c.offset }
|
|
||||||
|
|||||||
44
x/mlxrunner/cache/recurrent_test.go
vendored
Normal file
44
x/mlxrunner/cache/recurrent_test.go
vendored
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestRecurrentCacheRestoreDirectionality verifies that RecurrentCache only
|
||||||
|
// allows restoring forward (target >= snapshot offset), never backward.
|
||||||
|
func TestRecurrentCacheRestoreDirectionality(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
c := NewRecurrentCache(3, 12, 4, 8, 8)
|
||||||
|
_ = c.ConvState(1, mlx.DTypeFloat16)
|
||||||
|
_ = c.DeltaState(1, mlx.DTypeFloat16)
|
||||||
|
c.Advance(10)
|
||||||
|
|
||||||
|
snap := c.Snapshot(0)
|
||||||
|
|
||||||
|
c.Advance(5) // now at 15
|
||||||
|
|
||||||
|
// Restore backward should fail.
|
||||||
|
if c.Restore(snap, 5) {
|
||||||
|
t.Fatal("Restore(snap, 5) should fail — target < snap.offset")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore to exact snap offset should succeed.
|
||||||
|
if !c.Restore(snap, 10) {
|
||||||
|
t.Fatal("Restore(snap, 10) should succeed")
|
||||||
|
}
|
||||||
|
if c.Offset() != 10 {
|
||||||
|
t.Fatalf("offset = %d, want 10", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore forward (target > snap offset) should succeed, offset = snap.offset.
|
||||||
|
snap2 := c.Snapshot(0)
|
||||||
|
if !c.Restore(snap2, 15) {
|
||||||
|
t.Fatal("Restore(snap, 15) should succeed")
|
||||||
|
}
|
||||||
|
// Recurrent state is at snap.offset (10), not target (15).
|
||||||
|
if c.Offset() != 10 {
|
||||||
|
t.Fatalf("offset = %d, want 10 (snap offset)", c.Offset())
|
||||||
|
}
|
||||||
|
}
|
||||||
859
x/mlxrunner/cache_test.go
Normal file
859
x/mlxrunner/cache_test.go
Normal file
@@ -0,0 +1,859 @@
|
|||||||
|
package mlxrunner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// snapshotTracker records every fakeSnapshot created and every Close() call
|
||||||
|
// so tests can detect leaked (created but never closed) or double-closed snapshots.
|
||||||
|
type snapshotTracker struct {
|
||||||
|
all []*fakeSnapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr *snapshotTracker) track(s *fakeSnapshot) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.tracker = tr
|
||||||
|
tr.all = append(tr.all, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fake caches that store actual token sequences so tests can verify the right
|
||||||
|
// data was restored, not just the right offset.
|
||||||
|
|
||||||
|
// fakeSnapshot stores a copy of the token sub-sequence it covers.
|
||||||
|
type fakeSnapshot struct {
|
||||||
|
tokens []int32
|
||||||
|
from, to int
|
||||||
|
byteSize int // configurable for eviction tests
|
||||||
|
|
||||||
|
tracker *snapshotTracker
|
||||||
|
closeCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fakeSnapshot) Size() int { return s.byteSize }
|
||||||
|
func (s *fakeSnapshot) Close() {
|
||||||
|
s.closeCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeRewindableCache tracks the full token sequence and supports
|
||||||
|
// arbitrary rewind via Restore(nil, target).
|
||||||
|
type fakeRewindableCache struct {
|
||||||
|
tokens []int32
|
||||||
|
tracker *snapshotTracker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeRewindableCache) feed(tokens []int32) {
|
||||||
|
c.tokens = append(c.tokens, tokens...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeRewindableCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }
|
||||||
|
func (c *fakeRewindableCache) Offset() int { return len(c.tokens) }
|
||||||
|
|
||||||
|
func (c *fakeRewindableCache) Free() {
|
||||||
|
c.tokens = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||||
|
if fromOffset >= len(c.tokens) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
from := fromOffset
|
||||||
|
if from < 0 {
|
||||||
|
from = 0
|
||||||
|
}
|
||||||
|
s := &fakeSnapshot{
|
||||||
|
tokens: slices.Clone(c.tokens[from:]),
|
||||||
|
from: from,
|
||||||
|
to: len(c.tokens),
|
||||||
|
}
|
||||||
|
c.tracker.track(s)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||||
|
if snapshot == nil {
|
||||||
|
// Rewind live state.
|
||||||
|
if target < 0 {
|
||||||
|
target = 0
|
||||||
|
}
|
||||||
|
if target > len(c.tokens) {
|
||||||
|
target = len(c.tokens)
|
||||||
|
}
|
||||||
|
c.tokens = c.tokens[:target]
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
s := snapshot.(*fakeSnapshot)
|
||||||
|
if len(c.tokens) < s.from {
|
||||||
|
return false // don't have base data up to snapshot start
|
||||||
|
}
|
||||||
|
c.tokens = append(c.tokens[:s.from], s.tokens...)
|
||||||
|
if target < len(c.tokens) {
|
||||||
|
c.tokens = c.tokens[:target]
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeRewindableCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
|
||||||
|
if parent == nil || child == nil {
|
||||||
|
if parent != nil {
|
||||||
|
parent.Close()
|
||||||
|
}
|
||||||
|
if child != nil {
|
||||||
|
child.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p := parent.(*fakeSnapshot)
|
||||||
|
ch := child.(*fakeSnapshot)
|
||||||
|
merged := make([]int32, len(p.tokens)+len(ch.tokens))
|
||||||
|
copy(merged, p.tokens)
|
||||||
|
copy(merged[len(p.tokens):], ch.tokens)
|
||||||
|
s := &fakeSnapshot{
|
||||||
|
tokens: merged,
|
||||||
|
from: p.from,
|
||||||
|
to: ch.to,
|
||||||
|
byteSize: p.byteSize + ch.byteSize,
|
||||||
|
}
|
||||||
|
c.tracker.track(s)
|
||||||
|
p.Close()
|
||||||
|
ch.Close()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeRewindableCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
|
||||||
|
if snapshot == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
s := snapshot.(*fakeSnapshot)
|
||||||
|
relAt := at - s.from
|
||||||
|
if relAt <= 0 {
|
||||||
|
return nil, snapshot
|
||||||
|
}
|
||||||
|
if relAt >= len(s.tokens) {
|
||||||
|
return snapshot, nil
|
||||||
|
}
|
||||||
|
p := &fakeSnapshot{
|
||||||
|
tokens: slices.Clone(s.tokens[:relAt]),
|
||||||
|
from: s.from,
|
||||||
|
to: at,
|
||||||
|
byteSize: s.byteSize,
|
||||||
|
}
|
||||||
|
ch := &fakeSnapshot{
|
||||||
|
tokens: slices.Clone(s.tokens[relAt:]),
|
||||||
|
from: at,
|
||||||
|
to: s.to,
|
||||||
|
byteSize: s.byteSize,
|
||||||
|
}
|
||||||
|
c.tracker.track(p)
|
||||||
|
c.tracker.track(ch)
|
||||||
|
s.Close()
|
||||||
|
return p, ch
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeSlidingWindowCache models RotatingKVCache semantics: stores the full
// token sequence but only the trailing maxSize tokens are "live" in the window.
// Once the window fills, live rewind is impossible without a snapshot.
type fakeSlidingWindowCache struct {
	tokens  []int32          // full token log; window semantics enforced in Restore
	maxSize int              // number of trailing tokens considered "live"
	tracker *snapshotTracker // records snapshots for leak checking in tests
}
|
||||||
|
|
||||||
|
func (c *fakeSlidingWindowCache) feed(tokens []int32) {
|
||||||
|
c.tokens = append(c.tokens, tokens...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeSlidingWindowCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }
|
||||||
|
func (c *fakeSlidingWindowCache) Offset() int { return len(c.tokens) }
|
||||||
|
|
||||||
|
func (c *fakeSlidingWindowCache) Free() {
|
||||||
|
c.tokens = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||||
|
if len(c.tokens) == 0 || len(c.tokens) <= fromOffset {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Snapshot captures the full window state (like RotatingKVCache.Snapshot).
|
||||||
|
s := &fakeSnapshot{
|
||||||
|
tokens: slices.Clone(c.tokens),
|
||||||
|
from: 0,
|
||||||
|
to: len(c.tokens),
|
||||||
|
}
|
||||||
|
c.tracker.track(s)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||||
|
if snapshot == nil {
|
||||||
|
if target == len(c.tokens) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// Live rewind only works when buffer hasn't filled (offset <= maxSize).
|
||||||
|
if len(c.tokens) > c.maxSize {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.tokens = c.tokens[:target]
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
s := snapshot.(*fakeSnapshot)
|
||||||
|
c.tokens = slices.Clone(s.tokens)
|
||||||
|
if target < len(c.tokens) {
|
||||||
|
c.tokens = c.tokens[:target]
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeSlidingWindowCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
|
||||||
|
// Child supersedes parent for sliding window (full window state).
|
||||||
|
if parent != nil {
|
||||||
|
parent.Close()
|
||||||
|
}
|
||||||
|
return child
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeSlidingWindowCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
|
||||||
|
// Can't split a ring buffer at an arbitrary point.
|
||||||
|
return nil, snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeRecurrentCache models RecurrentCache semantics: stores tokens
// but cannot rewind without a snapshot.
type fakeRecurrentCache struct {
	tokens  []int32          // cumulative token log (recurrent state proxy)
	tracker *snapshotTracker // records snapshots for leak checking in tests
}
|
||||||
|
|
||||||
|
func (c *fakeRecurrentCache) feed(tokens []int32) {
|
||||||
|
c.tokens = append(c.tokens, tokens...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeRecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }
|
||||||
|
func (c *fakeRecurrentCache) Offset() int { return len(c.tokens) }
|
||||||
|
|
||||||
|
func (c *fakeRecurrentCache) Free() {
|
||||||
|
c.tokens = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeRecurrentCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||||
|
// Recurrent state is cumulative; snapshot captures the full state.
|
||||||
|
if len(c.tokens) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s := &fakeSnapshot{
|
||||||
|
tokens: slices.Clone(c.tokens),
|
||||||
|
from: 0,
|
||||||
|
to: len(c.tokens),
|
||||||
|
}
|
||||||
|
c.tracker.track(s)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||||
|
if snapshot == nil {
|
||||||
|
return target == len(c.tokens) // can only no-op
|
||||||
|
}
|
||||||
|
s := snapshot.(*fakeSnapshot)
|
||||||
|
if target < s.to {
|
||||||
|
return false // can't go backward
|
||||||
|
}
|
||||||
|
c.tokens = slices.Clone(s.tokens)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeRecurrentCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
|
||||||
|
// Child supersedes parent for cumulative state.
|
||||||
|
if parent != nil {
|
||||||
|
parent.Close()
|
||||||
|
}
|
||||||
|
return child
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeRecurrentCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
|
||||||
|
return nil, snapshot // can't split cumulative state
|
||||||
|
}
|
||||||
|
|
||||||
|
// feedableCache is a cache.Cache that tests can push raw tokens into,
// simulating forward passes without running a model.
type feedableCache interface {
	cache.Cache
	feed(tokens []int32)
}

// testEnv encapsulates a kvCache and its fake caches for a test scenario.
type testEnv struct {
	kvc     *kvCache
	caches  []cache.Cache    // typed references for assertions
	tracker *snapshotTracker // records every snapshot for leak checking
}
|
||||||
|
|
||||||
|
// newTransformerEnv creates a test environment with a single rewindable cache
|
||||||
|
// (pure transformer model).
|
||||||
|
func newTransformerEnv() *testEnv {
|
||||||
|
tracker := &snapshotTracker{}
|
||||||
|
caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}}
|
||||||
|
return &testEnv{
|
||||||
|
kvc: &kvCache{caches: caches},
|
||||||
|
caches: caches,
|
||||||
|
tracker: tracker,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newSlidingWindowEnv creates a test environment with one rewindable cache and
|
||||||
|
// one sliding window cache (Mistral-style architecture).
|
||||||
|
func newSlidingWindowEnv() *testEnv {
|
||||||
|
tr := &snapshotTracker{}
|
||||||
|
rc := &fakeRewindableCache{tracker: tr}
|
||||||
|
sw := &fakeSlidingWindowCache{maxSize: 32, tracker: tr}
|
||||||
|
caches := []cache.Cache{rc, sw}
|
||||||
|
return &testEnv{
|
||||||
|
kvc: &kvCache{caches: caches},
|
||||||
|
caches: caches,
|
||||||
|
tracker: tr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newRecurrentEnv creates a test environment with one rewindable cache and one
|
||||||
|
// non-rewindable cache (Jamba-style architecture).
|
||||||
|
func newRecurrentEnv() *testEnv {
|
||||||
|
tr := &snapshotTracker{}
|
||||||
|
rc := &fakeRewindableCache{tracker: tr}
|
||||||
|
nrc := &fakeRecurrentCache{tracker: tr}
|
||||||
|
caches := []cache.Cache{rc, nrc}
|
||||||
|
return &testEnv{
|
||||||
|
kvc: &kvCache{caches: caches},
|
||||||
|
caches: caches,
|
||||||
|
tracker: tr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertAllTokens checks that every cache in the environment contains exactly
|
||||||
|
// the expected token sequence.
|
||||||
|
func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32) {
|
||||||
|
t.Helper()
|
||||||
|
for i, c := range e.caches {
|
||||||
|
assertTokens(t, label, c, expected)
|
||||||
|
// Verify all caches report the same offset.
|
||||||
|
if i > 0 && c.Offset() != e.caches[0].Offset() {
|
||||||
|
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d",
|
||||||
|
label, i, c.Offset(), e.caches[0].Offset())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// simulateRequest mirrors the production pipeline lifecycle:
// begin -> prefill with snapshot(false) at branch points -> generate -> close

// requestResult captures what begin() reported for a simulated request.
type requestResult struct {
	remaining      []int32 // tokens the session still had to evaluate after begin
	snapshotOffset int     // offset where begin asked for an automatic snapshot (0 = none)
}
|
||||||
|
|
||||||
|
// simulateRequest runs a request through the harness. If userSnapshotAt > 0,
// a user snapshot (snapshot(true)) is created at that offset during prefill.
//
// The lifecycle mirrors production: begin -> prefill (pausing at each
// snapshot point to call session.snapshot) -> generate -> close. Cache
// offset alignment across all layers is asserted at every stage.
func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult {
	t.Helper()

	// Optional variadic arg: offset for an explicit user restore point.
	userSnapAt := 0
	if len(userSnapshotAt) > 0 {
		userSnapAt = userSnapshotAt[0]
	}

	session := kvc.begin(nil, inputs)
	result := requestResult{
		remaining:      slices.Clone(session.remaining),
		snapshotOffset: session.snapshotOffset,
	}

	assertCacheOffsetAlignment(t, kvc, "after begin")

	// begin() may have restored a cached prefix; only the rest needs feeding.
	baseOffset := kvc.minCacheOffset()
	remaining := inputs[baseOffset:]

	// Collect snapshot points (offset -> user flag) in ascending order.
	type snapPoint struct {
		offset int
		user   bool
	}
	var points []snapPoint
	if session.snapshotOffset > 0 && session.snapshotOffset > baseOffset {
		points = append(points, snapPoint{session.snapshotOffset, false})
	}
	if userSnapAt > 0 && userSnapAt > baseOffset {
		points = append(points, snapPoint{userSnapAt, true})
	}
	slices.SortFunc(points, func(a, b snapPoint) int { return a.offset - b.offset })

	// Prefill: feed tokens, pausing at each snapshot point.
	for _, sp := range points {
		count := sp.offset - baseOffset
		if count > len(remaining) {
			break // snapshot point beyond the prompt; stop
		}
		if count > 0 {
			feedAll(kvc.caches, remaining[:count])
			remaining = remaining[count:]
			baseOffset = sp.offset
		}
		assertCacheOffsetAlignment(t, kvc, "at snapshot point")
		session.snapshot(sp.user)
	}

	// Feed rest of input tokens.
	if len(remaining) > 0 {
		feedAll(kvc.caches, remaining)
	}

	assertCacheOffsetAlignment(t, kvc, "after prefill")

	// Generate tokens.
	if len(generated) > 0 {
		session.outputs = generated
		feedAll(kvc.caches, generated)
	}

	assertCacheOffsetAlignment(t, kvc, "before close")
	session.close()
	return result
}
|
||||||
|
|
||||||
|
func feedAll(caches []cache.Cache, tokens []int32) {
|
||||||
|
for _, c := range caches {
|
||||||
|
if fc, ok := c.(feedableCache); ok {
|
||||||
|
fc.feed(tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertCacheOffsetAlignment verifies all caches report the same offset.
|
||||||
|
func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) {
|
||||||
|
t.Helper()
|
||||||
|
if len(kvc.caches) < 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expected := kvc.caches[0].Offset()
|
||||||
|
for i := 1; i < len(kvc.caches); i++ {
|
||||||
|
if got := kvc.caches[i].Offset(); got != expected {
|
||||||
|
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, got, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertTokens checks that a feedable cache contains the expected token sequence.
// For sliding window caches, only the trailing maxSize tokens are checked.
//
// Unknown cache types are a test-authoring error and fail immediately.
func assertTokens(t *testing.T, label string, c cache.Cache, expected []int32) {
	t.Helper()
	switch fc := c.(type) {
	case *fakeRewindableCache:
		if !slices.Equal(fc.tokens, expected) {
			t.Errorf("%s: rewindable tokens = %v, want %v", label, fc.tokens, expected)
		}
	case *fakeSlidingWindowCache:
		// Sliding window stores full history but only trailing maxSize are live.
		// Verify the full token sequence matches (the window semantics are
		// enforced by Snapshot/Restore, not by the token log).
		if !slices.Equal(fc.tokens, expected) {
			t.Errorf("%s: sliding window tokens = %v, want %v", label, fc.tokens, expected)
		}
	case *fakeRecurrentCache:
		if !slices.Equal(fc.tokens, expected) {
			t.Errorf("%s: non-rewindable tokens = %v, want %v", label, fc.tokens, expected)
		}
	default:
		t.Fatalf("%s: unknown cache type %T", label, c)
	}
}
|
||||||
|
|
||||||
|
// checkTrieInvariants walks the trie and checks structural invariants:
// contiguous offsets between parent and child, token counts matching offset
// spans, correct parent back-pointers, and unique first tokens among siblings.
func checkTrieInvariants(t *testing.T, root *trieNode) {
	t.Helper()
	walkNodes(root, func(n *trieNode) bool {
		// A child's edge must begin exactly where its parent's ends.
		if n.parent != nil {
			if n.startOffset() != n.parent.endOffset {
				t.Errorf("node [%d,%d): startOffset %d != parent endOffset %d",
					n.startOffset(), n.endOffset, n.startOffset(), n.parent.endOffset)
			}
		}
		// The edge's token count must equal its offset span.
		if len(n.tokens) != n.endOffset-n.startOffset() {
			t.Errorf("node [%d,%d): token count %d != offset span %d",
				n.startOffset(), n.endOffset, len(n.tokens), n.endOffset-n.startOffset())
		}
		// Every child must point back at this node.
		for _, c := range n.children {
			if c.parent != n {
				t.Errorf("child [%d,%d) parent mismatch", c.startOffset(), c.endOffset)
			}
		}
		// No two siblings should start with the same token.
		seen := make(map[int32]bool)
		for _, c := range n.children {
			if len(c.tokens) > 0 {
				first := c.tokens[0]
				if seen[first] {
					t.Errorf("node [%d,%d): duplicate sibling first token %d",
						n.startOffset(), n.endOffset, first)
				}
				seen[first] = true
			}
		}
		return true
	})
}
|
||||||
|
|
||||||
|
// checkSnapshotLeaks verifies that every tracked snapshot is either still live
// in the trie (closeCount == 0) or has been closed exactly once. It reports
// leaked snapshots (not in trie, never closed) and double-closes.
func checkSnapshotLeaks(t *testing.T, tracker *snapshotTracker, root *trieNode) {
	t.Helper()
	if tracker == nil {
		return
	}

	// Collect all live snapshots still referenced by trie nodes.
	live := make(map[*fakeSnapshot]bool)
	walkNodes(root, func(n *trieNode) bool {
		for _, s := range n.snapshots {
			if s != nil {
				if fs, ok := s.(*fakeSnapshot); ok {
					live[fs] = true
				}
			}
		}
		return true
	})

	// Classify every snapshot the tracker ever saw.
	for i, s := range tracker.all {
		if live[s] {
			// Still referenced by the trie: must never have been closed.
			if s.closeCount != 0 {
				t.Errorf("snapshot #%d [%d,%d) is still in trie but was closed %d time(s)",
					i, s.from, s.to, s.closeCount)
			}
		} else {
			// Not in the trie: must have been closed exactly once.
			if s.closeCount == 0 {
				t.Errorf("snapshot #%d [%d,%d) leaked: created but never closed and not in trie",
					i, s.from, s.to)
			} else if s.closeCount > 1 {
				t.Errorf("snapshot #%d [%d,%d) double-closed: closed %d times",
					i, s.from, s.to, s.closeCount)
			}
		}
	}
}
|
||||||
|
|
||||||
|
// forEachEnv runs fn as subtests for three realistic model configurations:
|
||||||
|
// pure transformer, transformer + sliding window (Mistral-style), and
|
||||||
|
// transformer + recurrent (Jamba-style). Leak checking runs automatically
|
||||||
|
// at the end of each subtest.
|
||||||
|
func forEachEnv(t *testing.T, fn func(t *testing.T, env *testEnv)) {
|
||||||
|
t.Helper()
|
||||||
|
run := func(t *testing.T, env *testEnv) {
|
||||||
|
t.Cleanup(func() {
|
||||||
|
checkSnapshotLeaks(t, env.tracker, env.kvc.root)
|
||||||
|
})
|
||||||
|
fn(t, env)
|
||||||
|
}
|
||||||
|
t.Run("Transformer", func(t *testing.T) { run(t, newTransformerEnv()) })
|
||||||
|
t.Run("SlidingWindow", func(t *testing.T) { run(t, newSlidingWindowEnv()) })
|
||||||
|
t.Run("Recurrent", func(t *testing.T) { run(t, newRecurrentEnv()) })
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBranchCreationAndReuse exercises the core multi-conversation lifecycle:
// two conversations share a prefix and diverge, creating a branch point.
// A third conversation extends the first. Verifies trie structure, cache
// hit lengths, and that semantic caches contain the correct token sequences.
func TestBranchCreationAndReuse(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc

		// Request A: [1,2,3,4,5,6,7,8] + generate [20,21] — full miss.
		resA := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8}, []int32{20, 21})
		if len(resA.remaining) != 8 {
			t.Fatalf("A: remaining = %d, want 8 (full miss)", len(resA.remaining))
		}
		env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})

		// Verify trie was populated by close().
		_, mA := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
		if mA != 10 {
			t.Fatalf("A findable: expected 10 matched, got %d", mA)
		}

		// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
		// Partial match in A's edge triggers snapshotOffset.
		resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
		if resB.snapshotOffset != 5 {
			t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset)
		}
		// Cache was rewound to 0 (partial match truncates path to root),
		// so all tokens were re-evaluated.
		if len(resB.remaining) != 8 {
			t.Fatalf("B: remaining = %d, want 8", len(resB.remaining))
		}
		env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})

		// Both A and B should be findable in the trie.
		_, mA2 := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
		if mA2 < 5 {
			t.Fatalf("A still findable: expected >= 5 matched, got %d", mA2)
		}
		_, mB := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
		if mB < 5 {
			t.Fatalf("B findable: expected >= 5 matched, got %d", mB)
		}

		// Request C: [1,2,3,4,5,6,7,8,40,41] — extends A's prefix.
		// Should get a cache hit for the shared prefix.
		resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41}, nil)
		if len(resC.remaining) >= 10 {
			t.Fatalf("C: remaining = %d, want < 10 (should get cache hit)", len(resC.remaining))
		}
		env.assertAllTokens(t, "after C", []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41})

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||||
|
|
||||||
|
// TestExactMatchSeedBehavior verifies the holdback mechanism: when the exact
// same prompt is requested twice, the cache does not overclaim cached work.
// The last token must be re-evaluated to seed generation.
func TestExactMatchSeedBehavior(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc

		// Request A: first time.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})

		// Request B: identical prompt. Holdback means matched=4, partial in
		// the 5-token edge, so path truncates to root and all tokens are
		// re-evaluated. snapshotOffset should be set at the holdback point.
		resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
		if len(resB.remaining) != 5 {
			t.Fatalf("B: remaining = %d, want 5 (full re-eval due to holdback)", len(resB.remaining))
		}
		if resB.snapshotOffset != 4 {
			t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset)
		}
		env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||||
|
|
||||||
|
// TestConversationResumption tests the most common pattern: user sends a message,
// gets a response, then sends a follow-up. The follow-up should reuse the cached
// prefix (system prompt + first turn + assistant response).
func TestConversationResumption(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc

		// Turn 1: system prompt + user message, assistant generates response.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11, 12})
		env.assertAllTokens(t, "turn 1", []int32{1, 2, 3, 4, 5, 10, 11, 12})

		// Turn 2: full history + new user message. Should get a cache hit on
		// the prefix [1,2,3,4,5,10,11,12] and only need to evaluate [20,21].
		resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21}, []int32{30})
		if len(resB.remaining) > 5 {
			t.Fatalf("turn 2: remaining = %d, want <= 5 (should reuse most of history)", len(resB.remaining))
		}
		env.assertAllTokens(t, "turn 2", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30})

		// Turn 3: even longer history; prefix reuse should continue to hold.
		resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41}, nil)
		if len(resC.remaining) > 5 {
			t.Fatalf("turn 3: remaining = %d, want <= 5", len(resC.remaining))
		}
		env.assertAllTokens(t, "turn 3", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41})

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||||
|
|
||||||
|
// TestEvictionPreservesActiveConversations creates multiple conversations sharing
// a system prompt, triggers eviction via large snapshot sizes, and verifies the
// active path and shared prefix survive while memory stays bounded.
func TestEvictionPreservesActiveConversations(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc
		systemPrompt := []int32{1, 2, 3, 4, 5}

		// Create 5 conversations with unique suffixes.
		for i := range 5 {
			suffix := []int32{int32(100 + i*10), int32(101 + i*10), int32(102 + i*10)}
			inputs := append(slices.Clone(systemPrompt), suffix...)
			simulateRequest(t, kvc, inputs, []int32{int32(200 + i)})
		}

		// Inflate snapshot sizes to trigger eviction.
		walkNodes(kvc.root, func(n *trieNode) bool {
			if !n.hasSnapshots() {
				return true
			}
			snaps := make([]cache.Snapshot, len(n.snapshots))
			for i, s := range n.snapshots {
				if s != nil {
					snaps[i] = &fakeSnapshot{byteSize: 2 * 1024 * 1024 * 1024} // 2 GiB per snapshot
				}
			}
			// setSnapshots closes the replaced snapshots and updates the byte counter.
			n.setSnapshots(snaps, &kvc.pagedOutBytes)
			return true
		})

		// Run eviction.
		kvc.enforceEvictionPolicy()

		// Memory should be within limits.
		if kvc.pagedOutBytes > maxPagedOutBytes {
			t.Fatalf("pagedOutBytes = %d, want <= %d", kvc.pagedOutBytes, maxPagedOutBytes)
		}

		// Active path should be untouched.
		if len(kvc.activePath) < 2 {
			t.Fatalf("activePath should have >= 2 nodes, got %d", len(kvc.activePath))
		}

		// System prompt prefix should still be findable (evicting a
		// multi-child branch point only drops snapshots, not the node).
		_, matched := findBestMatch(kvc.root, systemPrompt)
		if matched < len(systemPrompt) {
			t.Fatalf("system prompt match = %d, want %d", matched, len(systemPrompt))
		}

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||||
|
|
||||||
|
// TestUserSnapshotPreservesRestorePoint verifies that user-created snapshots
// (snapshot(true)) resist structural changes that would destroy them:
//   - A user node forces new tokens into a child instead of extending in-place
//   - The snapshot remains restorable after other branches are added
func TestUserSnapshotPreservesRestorePoint(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc

		// Request A: user snapshot at offset 5, then generate.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}, 5)

		assertUserNodeExists(t, kvc, "after A")

		// Request B: extends A's prefix. The user node at offset 5 should
		// force tokens into a child rather than extending in-place.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21}, nil)
		env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21})
		assertUserNodeExists(t, kvc, "after B")

		// Request C: diverge from the user node.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 30, 31}, []int32{40})

		// Request D: switch back to A's branch — user snapshot still restorable.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50}, nil)
		env.assertAllTokens(t, "back to A", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50})

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||||
|
|
||||||
|
// TestUserSnapshotResistsAutoMerge verifies that when a sibling leaf is evicted,
// a user-marked parent node is not auto-merged with its remaining single child.
func TestUserSnapshotResistsAutoMerge(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc

		// Request A: user snapshot at offset 3, then continue to offset 5.
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10}, 3)

		// Request B: diverges at the user node, creating a second child.
		simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{20})

		userNode := findUserNode(t, kvc)
		if len(userNode.children) != 2 {
			t.Fatalf("user node children = %d, want 2", len(userNode.children))
		}

		// Inflate snapshot sizes and evict. The non-active branch should be
		// evicted, leaving the user node with one child.
		walkNodes(kvc.root, func(n *trieNode) bool {
			if !n.hasSnapshots() {
				return true
			}
			snaps := make([]cache.Snapshot, len(n.snapshots))
			for i, s := range n.snapshots {
				if s != nil {
					snaps[i] = &fakeSnapshot{byteSize: 5 * 1024 * 1024 * 1024}
				}
			}
			n.setSnapshots(snaps, &kvc.pagedOutBytes)
			return true
		})
		kvc.enforceEvictionPolicy()

		// The user node should still exist (not auto-merged) even with one child.
		assertUserNodeExists(t, kvc, "after eviction")

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||||
|
|
||||||
|
func findUserNode(t *testing.T, kvc *kvCache) *trieNode {
|
||||||
|
t.Helper()
|
||||||
|
var found *trieNode
|
||||||
|
walkNodes(kvc.root, func(n *trieNode) bool {
|
||||||
|
if n.user {
|
||||||
|
found = n
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if found == nil {
|
||||||
|
t.Fatal("no user-marked node found")
|
||||||
|
}
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertUserNodeExists(t *testing.T, kvc *kvCache, label string) {
|
||||||
|
t.Helper()
|
||||||
|
var exists bool
|
||||||
|
walkNodes(kvc.root, func(n *trieNode) bool {
|
||||||
|
if n.user {
|
||||||
|
exists = true
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if !exists {
|
||||||
|
t.Fatalf("%s: no user-marked node found", label)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBranchSwitchRestoresCorrectState exercises switching back to an older
// branch after working on a different one, verifying that the restored cache
// state contains the correct token sequence for both rewindable and
// non-rewindable caches.
func TestBranchSwitchRestoresCorrectState(t *testing.T) {
	forEachEnv(t, func(t *testing.T, env *testEnv) {
		kvc := env.kvc

		// Request A: [1,2,3,4,5] + generate [10,11]
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
		env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 10, 11})

		// Request B: [1,2,3,6,7] — diverges at token 4
		simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{12, 13})
		env.assertAllTokens(t, "after B", []int32{1, 2, 3, 6, 7, 12, 13})

		// Request C: switch back to A's branch [1,2,3,4,5,10,11,20]
		simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20}, nil)
		env.assertAllTokens(t, "after C (back to A)", []int32{1, 2, 3, 4, 5, 10, 11, 20})

		checkTrieInvariants(t, kvc.root)
	})
}
|
||||||
296
x/mlxrunner/cache_trie.go
Normal file
296
x/mlxrunner/cache_trie.go
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
package mlxrunner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
// trieNode represents a node in the compressed prefix trie for KV cache branching.
// Each node stores a compressed edge (multiple tokens) and optional paged-out
// snapshot data per cache layer.
type trieNode struct {
	tokens    []int32          // compressed edge — multiple tokens per node
	endOffset int              // cumulative tokens from root to end of this node
	parent    *trieNode        // nil for the root
	children  []*trieNode      // divergent continuations of this prefix
	lastUsed  time.Time        // for LRU eviction
	snapshots []cache.Snapshot // per-layer paged-out snapshot data (nil if not paged out)
	user      bool             // true = explicit restore point (resist auto-merge)
}
|
||||||
|
|
||||||
|
// startOffset returns the cumulative token offset at the start of this node's edge.
|
||||||
|
func (n *trieNode) startOffset() int {
|
||||||
|
return n.endOffset - len(n.tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// snapshotBytes returns the total bytes of paged-out snapshots on this node.
|
||||||
|
func (n *trieNode) snapshotBytes() int64 {
|
||||||
|
var total int64
|
||||||
|
for _, s := range n.snapshots {
|
||||||
|
if s != nil {
|
||||||
|
total += int64(s.Size())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSnapshots replaces this node's snapshots with snaps and closes the old ones.
|
||||||
|
// If counter is non-nil, the net byte delta is applied to it.
|
||||||
|
func (n *trieNode) setSnapshots(snaps []cache.Snapshot, counter *int64) {
|
||||||
|
old := n.swapSnapshots(snaps, counter)
|
||||||
|
for _, s := range old {
|
||||||
|
if s != nil {
|
||||||
|
s.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// swapSnapshots is like setSnapshots but returns the previous snapshots
|
||||||
|
// without closing them. Use this when the old snapshots will be consumed
|
||||||
|
// (e.g. by Split/Merge).
|
||||||
|
func (n *trieNode) swapSnapshots(snaps []cache.Snapshot, counter *int64) []cache.Snapshot {
|
||||||
|
old := n.snapshots
|
||||||
|
if counter != nil {
|
||||||
|
*counter -= n.snapshotBytes()
|
||||||
|
}
|
||||||
|
n.snapshots = snaps
|
||||||
|
if counter != nil {
|
||||||
|
*counter += n.snapshotBytes()
|
||||||
|
}
|
||||||
|
return old
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasSnapshots returns true if any layer has snapshot data.
|
||||||
|
func (n *trieNode) hasSnapshots() bool {
|
||||||
|
return slices.ContainsFunc(n.snapshots, func(s cache.Snapshot) bool { return s != nil })
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasAllSnapshots returns true if every layer has snapshot data.
|
||||||
|
func (n *trieNode) hasAllSnapshots() bool {
|
||||||
|
return len(n.snapshots) > 0 && !slices.Contains(n.snapshots, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findBestMatch walks the trie matching input tokens, returning the path of
// nodes traversed and the total number of tokens matched.
//
// The returned path always starts with root. The match may stop partway
// through the last node's edge, in which case matched falls strictly
// between that node's startOffset and endOffset.
func findBestMatch(root *trieNode, tokens []int32) (path []*trieNode, matched int) {
	if root == nil {
		return nil, 0
	}

	path = []*trieNode{root}
	pos := 0 // number of input tokens matched so far

	node := root
	for pos < len(tokens) {
		// When multiple children share the same first token (e.g. after
		// a split), prefer the child whose full edge matches over one
		// that only partially matches. This is just being defensive - it
		// shouldn't actually happen.
		var best *trieNode
		bestMatched := 0
		bestFull := false
		for _, child := range node.children {
			edge := child.tokens
			if len(edge) == 0 {
				continue
			}
			if edge[0] != tokens[pos] {
				continue
			}
			// Count matching tokens in this child's edge.
			j := 0
			for j < len(edge) && pos+j < len(tokens) && edge[j] == tokens[pos+j] {
				j++
			}
			full := j == len(edge)
			// Prefer full edge matches; among same type, prefer longer.
			if best == nil || (full && !bestFull) || (full == bestFull && j > bestMatched) {
				best = child
				bestMatched = j
				bestFull = full
			}
		}
		if best == nil {
			// No child starts with the next input token.
			break
		}

		pos += bestMatched
		path = append(path, best)

		if !bestFull {
			// Partial match within this edge
			break
		}
		node = best
	}

	return path, pos
}
|
||||||
|
|
||||||
|
// appendTokens either creates a new child node or extends the leaf in place,
// returning the node that now holds the tokens.
//
// A new child is created when n cannot safely grow its own edge: it is the
// root, it already has children, or it holds snapshot data (whose token
// range must stay fixed). Otherwise the tokens are appended to n's edge and
// n.endOffset is advanced to endOffset.
func (n *trieNode) appendTokens(root *trieNode, tokens []int32, endOffset int) *trieNode {
	if n == root || len(n.children) > 0 || n.hasSnapshots() {
		child := &trieNode{
			tokens:    make([]int32, len(tokens)),
			endOffset: endOffset,
			parent:    n,
			// New child inherits the parent's recency rather than time.Now().
			lastUsed: n.lastUsed,
		}
		copy(child.tokens, tokens)
		n.children = append(n.children, child)
		return child
	}
	// Bare leaf with no snapshots: extend the edge in place.
	n.tokens = append(n.tokens, tokens...)
	n.endOffset = endOffset
	return n
}
|
||||||
|
|
||||||
|
// removeNode removes a leaf node from the trie.
|
||||||
|
func removeNode(node *trieNode, counter *int64) {
|
||||||
|
if node.parent == nil {
|
||||||
|
panic("removeNode called on root")
|
||||||
|
}
|
||||||
|
if len(node.children) != 0 {
|
||||||
|
panic("removeNode called on non-leaf node")
|
||||||
|
}
|
||||||
|
p := node.parent
|
||||||
|
for i, child := range p.children {
|
||||||
|
if child == node {
|
||||||
|
p.children = append(p.children[:i], p.children[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node.parent = nil
|
||||||
|
node.setSnapshots(nil, counter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitNode splits a node at the given token offset within its edge,
// creating a new parent node. Returns the new parent.
// `at` is relative to the node's edge (0-based index into node.tokens).
// If caches are provided, snapshots are split between parent and child
// using Cache.Split; otherwise snapshots are invalidated.
//
// NOTE(review): when caches is nil the snapshot-split loop would index an
// empty slice, so callers passing nil caches must ensure the node has no
// snapshots — confirm this invariant at call sites.
func splitNode(node *trieNode, at int, caches []cache.Cache, counter *int64) *trieNode {
	// A split must leave both halves non-empty.
	if at <= 0 || at >= len(node.tokens) {
		panic(fmt.Sprintf("splitNode: invalid split offset %d for node with %d tokens", at, len(node.tokens)))
	}

	// Create new parent with the prefix of the edge.
	// The new parent is never a user restore point and keeps the original
	// node's recency.
	newParent := &trieNode{
		tokens:    make([]int32, at),
		endOffset: node.startOffset() + at,
		parent:    node.parent,
		children:  []*trieNode{node},
		lastUsed:  node.lastUsed,
	}
	copy(newParent.tokens, node.tokens[:at])

	// Update the original node to have only the suffix.
	node.tokens = node.tokens[at:]
	// endOffset stays the same for the original node.

	// Split snapshots between parent and child using Cache.Split.
	// Split consumes the old snapshots, so we remove them first (adjusting
	// the counter), then assign the split halves (adjusting it back).
	if node.hasSnapshots() {
		oldSnaps := node.swapSnapshots(nil, counter)
		parentSnaps := make([]cache.Snapshot, len(oldSnaps))
		childSnaps := make([]cache.Snapshot, len(oldSnaps))
		for i, snap := range oldSnaps {
			if snap != nil {
				// Split at the new parent's end offset: prefix tokens go to
				// the parent snapshot, the remainder to the child.
				parentSnaps[i], childSnaps[i] = caches[i].Split(snap, newParent.endOffset)
			}
		}
		newParent.setSnapshots(parentSnaps, counter)
		node.setSnapshots(childSnaps, counter)
	}

	// Reparent: replace node with newParent in the old parent's children.
	if node.parent != nil {
		for i, child := range node.parent.children {
			if child == node {
				node.parent.children[i] = newParent
				break
			}
		}
	}
	node.parent = newParent

	return newParent
}
|
||||||
|
|
||||||
|
// mergeWithChild merges a node with its single child: concatenates tokens,
// merges snapshot data via Cache.Merge, and removes the child.
//
// After the merge the node spans both edges (endOffset becomes the child's),
// adopts the child's children, takes the child's user flag, and keeps the
// more recent lastUsed of the two. The child is fully detached.
func mergeWithChild(node *trieNode, caches []cache.Cache, counter *int64) {
	if len(node.children) != 1 {
		panic(fmt.Sprintf("mergeWithChild called on node with %d children", len(node.children)))
	}

	child := node.children[0]

	// Concatenate tokens.
	node.tokens = append(node.tokens, child.tokens...)
	node.endOffset = child.endOffset

	// Merge snapshots per layer. Merge consumes the old snapshots, so we
	// remove them first (adjusting the counter), then assign the merged
	// result (adjusting it back).
	if len(node.snapshots) > 0 || len(child.snapshots) > 0 {
		nodeSnaps := node.swapSnapshots(nil, counter)
		childSnaps := child.swapSnapshots(nil, counter)
		merged := make([]cache.Snapshot, len(caches))
		for i := range caches {
			// Either side may lack a snapshot for this layer; Merge receives
			// nil for the missing half.
			var ps, cs cache.Snapshot
			if nodeSnaps != nil {
				ps = nodeSnaps[i]
			}
			if childSnaps != nil {
				cs = childSnaps[i]
			}

			merged[i] = caches[i].Merge(ps, cs)
		}
		node.setSnapshots(merged, counter)
	}

	// Adopt grandchildren.
	node.children = child.children
	for _, gc := range node.children {
		gc.parent = node
	}

	// Inherit user flag from child if child was a user-created snapshot node.
	// NOTE(review): this overwrites node.user with the child's value; the
	// merged node now ends at the child's offset, so the child's flag wins.
	node.user = child.user

	// Update lastUsed to the more recent of the two.
	if child.lastUsed.After(node.lastUsed) {
		node.lastUsed = child.lastUsed
	}

	// Detach the consumed child so it cannot be reached again.
	child.parent = nil
	child.children = nil
}
|
||||||
|
|
||||||
|
// walkNodes calls fn for every node in the trie (depth-first).
|
||||||
|
// If fn returns false, the walk stops.
|
||||||
|
func walkNodes(root *trieNode, fn func(*trieNode) bool) {
|
||||||
|
if root == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var walk func(*trieNode) bool
|
||||||
|
walk = func(n *trieNode) bool {
|
||||||
|
if !fn(n) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, child := range n.children {
|
||||||
|
if !walk(child) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
walk(root)
|
||||||
|
}
|
||||||
455
x/mlxrunner/cache_trie_test.go
Normal file
455
x/mlxrunner/cache_trie_test.go
Normal file
@@ -0,0 +1,455 @@
|
|||||||
|
package mlxrunner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTestTrie builds a fresh root node and, when tokens is non-empty,
// attaches a single child edge holding all of tokens (endOffset equals
// the token count).
func newTestTrie(tokens []int32) *trieNode {
	root := &trieNode{lastUsed: time.Now()}
	if len(tokens) > 0 {
		child := &trieNode{
			tokens:    slices.Clone(tokens),
			endOffset: len(tokens),
			parent:    root,
			lastUsed:  time.Now(),
		}
		root.children = []*trieNode{child}
	}
	return root
}
|
||||||
|
|
||||||
|
// TestFindBestMatchMultipleBranches verifies that findBestMatch selects
// the correct branch when the root has two disjoint children, and matches
// nothing when no branch shares a first token with the input.
func TestFindBestMatchMultipleBranches(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}

	branch1 := &trieNode{
		tokens:    []int32{1, 2, 3},
		endOffset: 3,
		parent:    root,
		lastUsed:  time.Now(),
	}
	branch2 := &trieNode{
		tokens:    []int32{4, 5, 6},
		endOffset: 3,
		parent:    root,
		lastUsed:  time.Now(),
	}
	root.children = []*trieNode{branch1, branch2}

	// Match branch 1.
	path, matched := findBestMatch(root, []int32{1, 2, 3, 7})
	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}
	if len(path) != 2 || path[1] != branch1 {
		t.Fatal("expected to match branch1")
	}

	// Match branch 2.
	path, matched = findBestMatch(root, []int32{4, 5, 6, 8})
	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}
	if len(path) != 2 || path[1] != branch2 {
		t.Fatal("expected to match branch2")
	}

	// Match neither.
	_, matched = findBestMatch(root, []int32{7, 8, 9})
	if matched != 0 {
		t.Fatalf("expected 0 matched, got %d", matched)
	}
}
|
||||||
|
|
||||||
|
// TestFindBestMatchPrefersFullEdge verifies that when two siblings share
// a first token, findBestMatch picks the child whose entire edge matches
// over a sibling with a longer but only partial edge match.
func TestFindBestMatchPrefersFullEdge(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}

	shared := &trieNode{
		tokens:    []int32{1, 2, 3},
		endOffset: 3,
		parent:    root,
		lastUsed:  time.Now(),
	}
	root.children = []*trieNode{shared}

	longer := &trieNode{
		tokens:    []int32{10, 11, 12, 13, 14},
		endOffset: 8,
		parent:    shared,
		lastUsed:  time.Now(),
	}
	shorter := &trieNode{
		tokens:    []int32{10, 11, 12},
		endOffset: 6,
		parent:    shared,
		lastUsed:  time.Now(),
	}
	// Put longer first so naive first-match would pick it.
	shared.children = []*trieNode{longer, shorter}

	// Input fully covers shorter's edge but only part of longer's.
	input := []int32{1, 2, 3, 10, 11, 12, 99, 100}
	path, matched := findBestMatch(root, input)

	if matched != 6 {
		t.Fatalf("expected 6 matched, got %d", matched)
	}
	if len(path) != 3 {
		t.Fatalf("expected 3 nodes in path, got %d", len(path))
	}
	if path[2] != shorter {
		t.Fatal("expected findBestMatch to pick shorter (full edge match), not longer (partial)")
	}
}
|
||||||
|
|
||||||
|
// TestFindBestMatchPrefersLongerPartial verifies that among siblings where
// neither edge matches fully, findBestMatch picks the one with the longer
// partial match.
func TestFindBestMatchPrefersLongerPartial(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}

	child1 := &trieNode{
		tokens:    []int32{1, 2, 3, 4, 5},
		endOffset: 5,
		parent:    root,
		lastUsed:  time.Now(),
	}
	child2 := &trieNode{
		tokens:    []int32{1, 2, 9},
		endOffset: 3,
		parent:    root,
		lastUsed:  time.Now(),
	}
	// child2 listed first so ordering alone cannot explain the result.
	root.children = []*trieNode{child2, child1}

	// Matches 3 tokens of child1's edge but only 2 of child2's.
	input := []int32{1, 2, 3, 7, 8}
	path, matched := findBestMatch(root, input)

	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}
	if path[1] != child1 {
		t.Fatal("expected findBestMatch to pick child1 (longer partial match)")
	}
}
|
||||||
|
|
||||||
|
// TestSplitNodeWithSnapshots verifies that splitting a snapshot-bearing
// node distributes snapshot data to both halves, and that the user flag
// stays on the original (suffix) node rather than the new parent.
func TestSplitNodeWithSnapshots(t *testing.T) {
	root := newTestTrie([]int32{1, 2, 3, 4, 5})
	child := root.children[0]

	rc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
	child.snapshots = []cache.Snapshot{rc.Snapshot(0)}
	child.user = true

	caches := []cache.Cache{rc}

	newParent := splitNode(child, 3, caches, nil)

	if !newParent.hasSnapshots() {
		t.Fatal("newParent should have snapshots after split")
	}
	if newParent.user {
		t.Fatal("newParent should not be a user snapshot after splitNode")
	}
	if !child.hasSnapshots() {
		t.Fatal("child should have snapshots after split")
	}
	if !child.user {
		t.Fatal("child should remain a user snapshot")
	}
}
|
||||||
|
|
||||||
|
// TestFindSplitAppendSequence exercises the match -> split -> append
// sequence used to insert a diverging suffix, then checks all three
// lookup outcomes against the resulting trie shape.
func TestFindSplitAppendSequence(t *testing.T) {
	root := newTestTrie([]int32{1, 2, 3, 4, 5})

	// Input diverges after the shared prefix [1,2,3].
	path, matched := findBestMatch(root, []int32{1, 2, 3, 6, 7})
	if matched != 3 {
		t.Fatalf("expected 3 matched, got %d", matched)
	}

	// Split the last matched node at the divergence point and append the
	// new suffix under the split point.
	lastNode := path[len(path)-1]
	matchedInEdge := matched - lastNode.startOffset()
	split := splitNode(lastNode, matchedInEdge, nil, nil)

	split.appendTokens(root, []int32{6, 7}, 5)

	if len(root.children) != 1 {
		t.Fatalf("root should have 1 child, got %d", len(root.children))
	}
	shared := root.children[0]
	if !slices.Equal(shared.tokens, []int32{1, 2, 3}) {
		t.Fatalf("shared tokens = %v, want [1,2,3]", shared.tokens)
	}
	if len(shared.children) != 2 {
		t.Fatalf("shared should have 2 children, got %d", len(shared.children))
	}

	_, m1 := findBestMatch(root, []int32{1, 2, 3, 4, 5})
	if m1 != 5 {
		t.Fatalf("original branch: expected 5 matched, got %d", m1)
	}
	_, m2 := findBestMatch(root, []int32{1, 2, 3, 6, 7})
	if m2 != 5 {
		t.Fatalf("new branch: expected 5 matched, got %d", m2)
	}
	_, m3 := findBestMatch(root, []int32{1, 2, 3, 9, 9})
	if m3 != 3 {
		t.Fatalf("unrelated input: expected 3 matched, got %d", m3)
	}
}
|
||||||
|
|
||||||
|
// TestRepeatedBranching builds a trie by splitting twice at successively
// shorter prefixes (A=[1..5], B diverges at 3, C diverges at 2) and checks
// that all three sequences still resolve fully afterwards.
func TestRepeatedBranching(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}

	// Sequence A.
	root.appendTokens(root, []int32{1, 2, 3, 4, 5}, 5)

	// Sequence B shares [1,2,3] with A.
	_, matchedB := findBestMatch(root, []int32{1, 2, 3, 6, 7})
	if matchedB != 3 {
		t.Fatalf("B: expected 3 matched, got %d", matchedB)
	}
	nodeA := root.children[0]
	split1 := splitNode(nodeA, 3, nil, nil)
	split1.appendTokens(root, []int32{6, 7}, 5)

	// Sequence C shares only [1,2] with the existing branches.
	_, matchedC := findBestMatch(root, []int32{1, 2, 8, 9})
	if matchedC != 2 {
		t.Fatalf("C: expected 2 matched, got %d", matchedC)
	}
	split2 := splitNode(split1, 2, nil, nil)
	split2.appendTokens(root, []int32{8, 9}, 4)

	_, mA := findBestMatch(root, []int32{1, 2, 3, 4, 5})
	if mA != 5 {
		t.Fatalf("A: expected 5 matched, got %d", mA)
	}
	_, mB := findBestMatch(root, []int32{1, 2, 3, 6, 7})
	if mB != 5 {
		t.Fatalf("B: expected 5 matched, got %d", mB)
	}
	_, mC := findBestMatch(root, []int32{1, 2, 8, 9})
	if mC != 4 {
		t.Fatalf("C: expected 4 matched, got %d", mC)
	}

	checkTrieInvariants(t, root)
}
|
||||||
|
|
||||||
|
// TestMergeWithChild covers the merge operation: token concatenation,
// grandchild adoption, snapshot merging, user-flag inheritance, lastUsed
// selection, and the panic guard for nodes with multiple children.
func TestMergeWithChild(t *testing.T) {
	t.Run("Basic", func(t *testing.T) {
		// root -> A[1,2,3] -> B[4,5] -> {C[6], D[7]}
		now := time.Now()
		root := &trieNode{lastUsed: now}
		a := &trieNode{
			tokens:    []int32{1, 2, 3},
			endOffset: 3,
			parent:    root,
			lastUsed:  now,
			snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3}, from: 0, to: 3}},
		}
		b := &trieNode{
			tokens:    []int32{4, 5},
			endOffset: 5,
			parent:    a,
			lastUsed:  now,
			snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{4, 5}, from: 3, to: 5}},
		}
		c := &trieNode{tokens: []int32{6}, endOffset: 6, parent: b, lastUsed: now}
		d := &trieNode{tokens: []int32{7}, endOffset: 6, parent: b, lastUsed: now}
		root.children = []*trieNode{a}
		a.children = []*trieNode{b}
		b.children = []*trieNode{c, d}

		mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
		mergeWithChild(a, []cache.Cache{mc}, nil)

		// Tokens concatenated.
		if !slices.Equal(a.tokens, []int32{1, 2, 3, 4, 5}) {
			t.Fatalf("merged tokens = %v, want [1,2,3,4,5]", a.tokens)
		}
		if a.endOffset != 5 {
			t.Fatalf("merged endOffset = %d, want 5", a.endOffset)
		}
		// Grandchildren reparented.
		if len(a.children) != 2 {
			t.Fatalf("merged children count = %d, want 2", len(a.children))
		}
		if c.parent != a || d.parent != a {
			t.Fatal("grandchildren should be reparented to merged node")
		}
		// B detached.
		if b.parent != nil || b.children != nil || b.snapshots != nil {
			t.Fatal("child B should be fully detached after merge")
		}
		// Merged snapshot should cover [0,5).
		if !a.hasSnapshots() {
			t.Fatal("merged node should have snapshots")
		}
		ms := a.snapshots[0].(*fakeSnapshot)
		if ms.from != 0 || ms.to != 5 {
			t.Fatalf("merged snapshot = [%d,%d), want [0,5)", ms.from, ms.to)
		}

		checkTrieInvariants(t, root)
	})

	t.Run("UserFlag", func(t *testing.T) {
		root := &trieNode{lastUsed: time.Now()}
		parent := &trieNode{
			tokens: []int32{1, 2}, endOffset: 2, parent: root,
			lastUsed: time.Now(), user: false,
		}
		child := &trieNode{
			tokens: []int32{3, 4}, endOffset: 4, parent: parent,
			lastUsed: time.Now(), user: true,
		}
		root.children = []*trieNode{parent}
		parent.children = []*trieNode{child}

		mergeWithChild(parent, nil, nil)

		if !parent.user {
			t.Fatal("merged node should inherit user=true from child")
		}
	})

	t.Run("LastUsed", func(t *testing.T) {
		now := time.Now()
		root := &trieNode{lastUsed: now}
		parent := &trieNode{
			tokens: []int32{1}, endOffset: 1, parent: root,
			lastUsed: now.Add(-1 * time.Hour),
		}
		child := &trieNode{
			tokens: []int32{2}, endOffset: 2, parent: parent,
			lastUsed: now.Add(1 * time.Hour),
		}
		root.children = []*trieNode{parent}
		parent.children = []*trieNode{child}

		mergeWithChild(parent, nil, nil)

		if !parent.lastUsed.Equal(now.Add(1 * time.Hour)) {
			t.Fatal("merged node should pick the more recent lastUsed")
		}
	})

	t.Run("PanicOnMultipleChildren", func(t *testing.T) {
		defer func() {
			if r := recover(); r == nil {
				t.Fatal("expected panic on node with 2 children")
			}
		}()
		root := &trieNode{lastUsed: time.Now()}
		node := &trieNode{
			tokens: []int32{1}, endOffset: 1, parent: root, lastUsed: time.Now(),
			children: []*trieNode{
				{tokens: []int32{2}, endOffset: 2, lastUsed: time.Now()},
				{tokens: []int32{3}, endOffset: 2, lastUsed: time.Now()},
			},
		}
		root.children = []*trieNode{node}
		mergeWithChild(node, nil, nil)
	})
}
|
||||||
|
|
||||||
|
// TestSplitMergeRoundTrip splits a snapshot-bearing leaf and merges it
// back, checking that tokens, endOffset, children, and the snapshot range
// are fully restored.
func TestSplitMergeRoundTrip(t *testing.T) {
	root := &trieNode{lastUsed: time.Now()}
	leaf := &trieNode{
		tokens:    []int32{1, 2, 3, 4, 5},
		endOffset: 5,
		parent:    root,
		lastUsed:  time.Now(),
		snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3, 4, 5}, from: 0, to: 5}},
	}
	root.children = []*trieNode{leaf}

	mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
	caches := []cache.Cache{mc}

	// Split at 3: [1,2,3] -> [4,5]
	newParent := splitNode(leaf, 3, caches, nil)
	if !slices.Equal(newParent.tokens, []int32{1, 2, 3}) {
		t.Fatalf("after split: parent tokens = %v, want [1,2,3]", newParent.tokens)
	}
	if !slices.Equal(leaf.tokens, []int32{4, 5}) {
		t.Fatalf("after split: child tokens = %v, want [4,5]", leaf.tokens)
	}
	checkTrieInvariants(t, root)

	// Merge back: should restore [1,2,3,4,5]
	mergeWithChild(newParent, caches, nil)
	if !slices.Equal(newParent.tokens, []int32{1, 2, 3, 4, 5}) {
		t.Fatalf("after merge: tokens = %v, want [1,2,3,4,5]", newParent.tokens)
	}
	if newParent.endOffset != 5 {
		t.Fatalf("after merge: endOffset = %d, want 5", newParent.endOffset)
	}
	if len(newParent.children) != 0 {
		t.Fatalf("after merge: children count = %d, want 0", len(newParent.children))
	}
	// Merged snapshot should cover [0,5).
	if !newParent.hasSnapshots() {
		t.Fatal("after merge: should have snapshots")
	}
	ms := newParent.snapshots[0].(*fakeSnapshot)
	if ms.from != 0 || ms.to != 5 {
		t.Fatalf("after merge: snapshot = [%d,%d), want [0,5)", ms.from, ms.to)
	}

	checkTrieInvariants(t, root)
}
|
||||||
|
|
||||||
|
// TestRemoveNode covers leaf removal (detachment, snapshot cleanup) and
// the panic guards for removing the root or a non-leaf node.
func TestRemoveNode(t *testing.T) {
	t.Run("Leaf", func(t *testing.T) {
		root := &trieNode{lastUsed: time.Now()}
		shared := &trieNode{
			tokens: []int32{1, 2, 3}, endOffset: 3, parent: root, lastUsed: time.Now(),
		}
		leafA := &trieNode{
			tokens: []int32{4, 5}, endOffset: 5, parent: shared, lastUsed: time.Now(),
			snapshots: []cache.Snapshot{&fakeSnapshot{from: 3, to: 5}},
		}
		leafB := &trieNode{
			tokens: []int32{6, 7}, endOffset: 5, parent: shared, lastUsed: time.Now(),
			snapshots: []cache.Snapshot{&fakeSnapshot{from: 3, to: 5}},
		}
		root.children = []*trieNode{shared}
		shared.children = []*trieNode{leafA, leafB}

		removeNode(leafA, nil)

		if len(shared.children) != 1 {
			t.Fatalf("parent should have 1 child, got %d", len(shared.children))
		}
		if shared.children[0] != leafB {
			t.Fatal("remaining child should be leafB")
		}
		if leafA.parent != nil {
			t.Fatal("removed node parent should be nil")
		}
		if leafA.snapshots != nil {
			t.Fatal("removed node snapshots should be nil")
		}

		checkTrieInvariants(t, root)
	})

	t.Run("PanicOnRoot", func(t *testing.T) {
		defer func() {
			if r := recover(); r == nil {
				t.Fatal("expected panic when removing root")
			}
		}()
		removeNode(&trieNode{}, nil)
	})

	t.Run("PanicOnNonLeaf", func(t *testing.T) {
		defer func() {
			if r := recover(); r == nil {
				t.Fatal("expected panic when removing non-leaf")
			}
		}()
		parent := &trieNode{parent: &trieNode{}}
		parent.children = []*trieNode{{}}
		removeNode(parent, nil)
	})
}
|
||||||
@@ -106,6 +106,7 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
|||||||
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
|
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
|
||||||
type completionRequest struct {
|
type completionRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
Images []llm.ImageData `json:"images,omitempty"`
|
||||||
Options *completionOpts `json:"options,omitempty"`
|
Options *completionOpts `json:"options,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,6 +156,7 @@ func (c *Client) Close() error {
|
|||||||
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||||
creq := completionRequest{
|
creq := completionRequest{
|
||||||
Prompt: req.Prompt,
|
Prompt: req.Prompt,
|
||||||
|
Images: req.Images,
|
||||||
}
|
}
|
||||||
if req.Options != nil {
|
if req.Options != nil {
|
||||||
creq.Options = &completionOpts{
|
creq.Options = &completionOpts{
|
||||||
|
|||||||
@@ -18,6 +18,10 @@ func Version() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func doEval(outputs []*Array, async bool) {
|
func doEval(outputs []*Array, async bool) {
|
||||||
|
if len(outputs) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
vector := C.mlx_vector_array_new()
|
vector := C.mlx_vector_array_new()
|
||||||
defer C.mlx_vector_array_free(vector)
|
defer C.mlx_vector_array_free(vector)
|
||||||
|
|
||||||
|
|||||||
@@ -304,6 +304,18 @@ func Exp(a *Array) *Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sin returns the element-wise sine of a, computed on the default stream.
func Sin(a *Array) *Array {
	out := New("SIN")
	C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}
|
||||||
|
|
||||||
|
// Cos returns the element-wise cosine of a, computed on the default stream.
func Cos(a *Array) *Array {
	out := New("COS")
	C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}
|
||||||
|
|
||||||
func Log(a *Array) *Array {
|
func Log(a *Array) *Array {
|
||||||
out := New("LOG")
|
out := New("LOG")
|
||||||
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||||
|
|||||||
@@ -4,10 +4,14 @@ package mlx
|
|||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"math"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// End is a sentinel value meaning "to the end of the dimension",
|
||||||
|
// equivalent to an omitted stop in Python (e.g. a[i:]).
|
||||||
|
const End = math.MaxInt32
|
||||||
|
|
||||||
type slice struct {
|
type slice struct {
|
||||||
args []int
|
args []int
|
||||||
}
|
}
|
||||||
@@ -16,6 +20,16 @@ func Slice(args ...int) slice {
|
|||||||
return slice{args: args}
|
return slice{args: args}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolve converts a Python-style slice bound into an absolute index for
// a dimension of size dim: the End sentinel maps to dim, and a negative
// value counts back from the end (dim + val). Other values pass through.
func resolve(val, dim int) C.int {
	if val == End {
		return C.int(dim)
	}
	if val < 0 {
		return C.int(dim + val)
	}
	return C.int(val)
}
|
||||||
|
|
||||||
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
|
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
|
||||||
if len(slices) != len(dims) {
|
if len(slices) != len(dims) {
|
||||||
panic("number of slice arguments must match number of tensor dimensions")
|
panic("number of slice arguments must match number of tensor dimensions")
|
||||||
@@ -28,26 +42,28 @@ func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, s := range slices {
|
for i, s := range slices {
|
||||||
|
dim := dims[i]
|
||||||
switch len(s.args) {
|
switch len(s.args) {
|
||||||
case 0:
|
case 0:
|
||||||
// slice[:]
|
// slice[:]
|
||||||
args[0][i] = C.int(0)
|
args[0][i] = C.int(0)
|
||||||
args[1][i] = C.int(dims[i])
|
args[1][i] = C.int(dim)
|
||||||
args[2][i] = C.int(1)
|
args[2][i] = C.int(1)
|
||||||
case 1:
|
case 1:
|
||||||
// slice[i]
|
// slice[i]
|
||||||
args[0][i] = C.int(s.args[0])
|
start := resolve(s.args[0], dim)
|
||||||
args[1][i] = C.int(s.args[0] + 1)
|
args[0][i] = start
|
||||||
|
args[1][i] = start + 1
|
||||||
args[2][i] = C.int(1)
|
args[2][i] = C.int(1)
|
||||||
case 2:
|
case 2:
|
||||||
// slice[i:j]
|
// slice[i:j]
|
||||||
args[0][i] = C.int(s.args[0])
|
args[0][i] = resolve(s.args[0], dim)
|
||||||
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
|
args[1][i] = resolve(s.args[1], dim)
|
||||||
args[2][i] = C.int(1)
|
args[2][i] = C.int(1)
|
||||||
case 3:
|
case 3:
|
||||||
// slice[i:j:k]
|
// slice[i:j:k]
|
||||||
args[0][i] = C.int(s.args[0])
|
args[0][i] = resolve(s.args[0], dim)
|
||||||
args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
|
args[1][i] = resolve(s.args[1], dim)
|
||||||
args[2][i] = C.int(s.args[2])
|
args[2][i] = C.int(s.args[2])
|
||||||
default:
|
default:
|
||||||
panic("invalid slice arguments")
|
panic("invalid slice arguments")
|
||||||
|
|||||||
32
x/mlxrunner/model/base/multimodal.go
Normal file
32
x/mlxrunner/model/base/multimodal.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package base
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ImageInput is a single image attached to a prompt.
type ImageInput struct {
	// ID identifies the image within the request.
	ID int
	// Data is the raw image bytes.
	Data []byte
}

// PromptTokenization contains tokenized prompt IDs plus optional request-scoped
// model metadata needed during forward.
type PromptTokenization struct {
	// Tokens are the expanded prompt token IDs.
	Tokens []int32
	// State is opaque request-scoped data passed through to the forward pass.
	State any
}

// MultimodalPromptTokenizerWithState is an optional model interface used by
// mlxrunner to expand tagged multimodal prompts into token IDs, returning
// request-scoped state to be attached to the forward pass.
type MultimodalPromptTokenizerWithState interface {
	TokenizePromptWithImagesState(prompt string, images []ImageInput) (*PromptTokenization, error)
}

// ForwardWithStateModel is an optional model interface for request-scoped
// forward metadata that should not be stored in shared caches.
type ForwardWithStateModel interface {
	ForwardWithState(inputs *mlx.Array, cache []cache.Cache, state any) *mlx.Array
}
|
||||||
@@ -12,12 +12,42 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
)
|
)
|
||||||
|
|
||||||
func prefillChunkSize() int {
|
func prefillChunkSize() int {
|
||||||
return 2 << 10
|
return 2 << 10
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Runner) tokenizeRequest(request Request) ([]int32, any, error) {
|
||||||
|
if len(request.Images) > 0 {
|
||||||
|
// The shared trie cache keys only on token IDs today, so multimodal
|
||||||
|
// prompts must not reuse snapshots across distinct image inputs.
|
||||||
|
r.cache.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
if multimodalTokenizer, ok := r.Model.(base.MultimodalPromptTokenizerWithState); ok && len(request.Images) > 0 {
|
||||||
|
images := make([]base.ImageInput, len(request.Images))
|
||||||
|
for i := range request.Images {
|
||||||
|
images[i] = base.ImageInput{
|
||||||
|
ID: request.Images[i].ID,
|
||||||
|
Data: request.Images[i].Data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := multimodalTokenizer.TokenizePromptWithImagesState(request.Prompt, images)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if out == nil {
|
||||||
|
return nil, nil, errors.New("empty multimodal tokenization result")
|
||||||
|
}
|
||||||
|
return out.Tokens, out.State, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.Tokenizer.Encode(request.Prompt, true), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Runner) TextGenerationPipeline(request Request) error {
|
func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||||
if r.Model == nil {
|
if r.Model == nil {
|
||||||
return errors.New("model not loaded")
|
return errors.New("model not loaded")
|
||||||
@@ -50,12 +80,15 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
|
|
||||||
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
|
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
|
||||||
mlx.LogArrays()
|
mlx.LogArrays()
|
||||||
r.cache.log()
|
r.cache.dumpTree()
|
||||||
}
|
}
|
||||||
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
|
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
inputs, promptState, err := r.tokenizeRequest(request)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if len(inputs) == 0 {
|
if len(inputs) == 0 {
|
||||||
return errors.New("empty prompt")
|
return errors.New("empty prompt")
|
||||||
}
|
}
|
||||||
@@ -83,10 +116,17 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
tokens := session.remaining
|
tokens := session.remaining
|
||||||
prefillChunk := prefillChunkSize()
|
prefillChunk := prefillChunkSize()
|
||||||
|
|
||||||
|
modelForward := func(tokens *mlx.Array) *mlx.Array {
|
||||||
|
if withState, ok := r.Model.(base.ForwardWithStateModel); ok {
|
||||||
|
return withState.ForwardWithState(tokens, caches, promptState)
|
||||||
|
}
|
||||||
|
return r.Model.Forward(tokens, caches)
|
||||||
|
}
|
||||||
|
|
||||||
materializeCaches := func() {
|
materializeCaches := func() {
|
||||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||||
for _, c := range caches {
|
for _, c := range caches {
|
||||||
state = appendCacheState(state, c)
|
state = append(state, c.State()...)
|
||||||
}
|
}
|
||||||
if len(state) == 0 {
|
if len(state) == 0 {
|
||||||
return
|
return
|
||||||
@@ -102,16 +142,37 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
n := min(prefillChunk, total-processed-1)
|
n := min(prefillChunk, total-processed-1)
|
||||||
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
|
|
||||||
|
// If there's a pending intermediate snapshot, split the batch
|
||||||
|
// so we can capture it at the exact offset. The cache offset
|
||||||
|
// after this batch will be: baseOffset + processed + n.
|
||||||
|
if session.snapshotOffset > 0 {
|
||||||
|
baseOffset := len(session.inputs) - len(tokens)
|
||||||
|
tokensUntilSnapshot := session.snapshotOffset - (baseOffset + processed)
|
||||||
|
if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
|
||||||
|
n = tokensUntilSnapshot
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
modelForward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0))
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
materializeCaches()
|
materializeCaches()
|
||||||
processed += n
|
processed += n
|
||||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||||
|
|
||||||
|
// Create snapshot at branch point for future diverging requests.
|
||||||
|
if session.snapshotOffset > 0 {
|
||||||
|
baseOffset := len(session.inputs) - len(tokens)
|
||||||
|
if baseOffset+processed >= session.snapshotOffset {
|
||||||
|
session.snapshot(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||||
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
fwd := modelForward(token.ExpandDims(0))
|
||||||
logits := r.Model.Unembed(fwd)
|
logits := r.Model.Unembed(fwd)
|
||||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
@@ -30,6 +31,7 @@ type Request struct {
|
|||||||
|
|
||||||
type TextCompletionsRequest struct {
|
type TextCompletionsRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
Images []llm.ImageData `json:"images,omitempty"`
|
||||||
Options struct {
|
Options struct {
|
||||||
Temperature float32 `json:"temperature"`
|
Temperature float32 `json:"temperature"`
|
||||||
TopP float32 `json:"top_p"`
|
TopP float32 `json:"top_p"`
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ func topK(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
|||||||
return logprobs
|
return logprobs
|
||||||
}
|
}
|
||||||
|
|
||||||
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, 0))
|
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
||||||
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
354
x/models/qwen3_5/multimodal.go
Normal file
354
x/models/qwen3_5/multimodal.go
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
package qwen3_5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
|
)
|
||||||
|
|
||||||
|
var imageTagRE = regexp.MustCompile(`\[img-(\d+)\]`)
|
||||||
|
|
||||||
|
type promptVisionSpan struct {
|
||||||
|
Start int32
|
||||||
|
End int32
|
||||||
|
|
||||||
|
Main *mlx.Array
|
||||||
|
Grid *VisionGrid
|
||||||
|
}
|
||||||
|
|
||||||
|
type promptVisionState struct {
|
||||||
|
Spans []promptVisionSpan
|
||||||
|
PositionCache []int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func promptStartPosFromCaches(caches []cache.Cache) int32 {
|
||||||
|
offset := -1
|
||||||
|
for _, c := range caches {
|
||||||
|
if c == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
off := c.Offset()
|
||||||
|
if offset < 0 || off < offset {
|
||||||
|
offset = off
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if offset < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return int32(offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
func promptVisionStateFromState(state any) *promptVisionState {
|
||||||
|
typed, _ := state.(*promptVisionState)
|
||||||
|
return typed
|
||||||
|
}
|
||||||
|
|
||||||
|
func overlapRange(chunkStart, chunkLen, spanStart, spanEnd int32) (int32, int32, int32, int32, bool) {
|
||||||
|
chunkEnd := chunkStart + chunkLen
|
||||||
|
overlapStart := max(chunkStart, spanStart)
|
||||||
|
overlapEnd := min(chunkEnd, spanEnd)
|
||||||
|
if overlapStart >= overlapEnd {
|
||||||
|
return 0, 0, 0, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkLo := overlapStart - chunkStart
|
||||||
|
chunkHi := overlapEnd - chunkStart
|
||||||
|
spanLo := overlapStart - spanStart
|
||||||
|
spanHi := overlapEnd - spanStart
|
||||||
|
return chunkLo, chunkHi, spanLo, spanHi, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) applyPromptVisionEmbeddings(h *mlx.Array, startPos int32, state *promptVisionState) *mlx.Array {
|
||||||
|
if m == nil || h == nil || state == nil || len(state.Spans) == 0 {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
dims := h.Dims()
|
||||||
|
if len(dims) != 3 {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
L := int32(dims[1])
|
||||||
|
for _, span := range state.Spans {
|
||||||
|
chunkLo, chunkHi, spanLo, spanHi, ok := overlapRange(startPos, L, span.Start, span.End)
|
||||||
|
if !ok || span.Main == nil || !span.Main.Valid() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
repl := span.Main.Slice(
|
||||||
|
mlx.Slice(),
|
||||||
|
mlx.Slice(int(spanLo), int(spanHi)),
|
||||||
|
mlx.Slice(),
|
||||||
|
)
|
||||||
|
repl = repl.AsType(h.DType())
|
||||||
|
h = h.SliceUpdate(
|
||||||
|
repl,
|
||||||
|
mlx.Slice(),
|
||||||
|
mlx.Slice(int(chunkLo), int(chunkHi)),
|
||||||
|
mlx.Slice(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func findImageByID(images []base.ImageInput, id int) (base.ImageInput, bool) {
|
||||||
|
for i := range images {
|
||||||
|
if images[i].ID == id {
|
||||||
|
return images[i], true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return base.ImageInput{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapPromptPosition(state *promptVisionState, id int32) int32 {
|
||||||
|
if state == nil {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
if id < int32(len(state.PositionCache)) {
|
||||||
|
return state.PositionCache[id]
|
||||||
|
}
|
||||||
|
if len(state.PositionCache) > 0 {
|
||||||
|
return id - int32(len(state.PositionCache)) + state.PositionCache[len(state.PositionCache)-1] + 1
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func promptVisionGridSpan(grid *VisionGrid, merge int32, fallback int32) int32 {
|
||||||
|
if fallback <= 0 {
|
||||||
|
fallback = 1
|
||||||
|
}
|
||||||
|
if grid == nil {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
if merge <= 0 {
|
||||||
|
merge = 1
|
||||||
|
}
|
||||||
|
return max(max(int32(1), grid.Width/merge), max(int32(1), grid.Height/merge))
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeMRoPESections(sections []int32) [4]int32 {
|
||||||
|
var out [4]int32
|
||||||
|
for i := range min(4, len(sections)) {
|
||||||
|
if sections[i] > 0 {
|
||||||
|
out[i] = sections[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func mropePairComponent(pair int32, sections [4]int32, interleaved bool) int {
|
||||||
|
if interleaved {
|
||||||
|
if pair%3 == 1 && pair < 1+3*sections[1] {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if pair%3 == 2 && pair < 2+3*sections[2] {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
if pair%3 == 0 && pair < 3*sections[0] {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return 3
|
||||||
|
}
|
||||||
|
|
||||||
|
secW := sections[0] + sections[1]
|
||||||
|
secE := secW + sections[2]
|
||||||
|
switch {
|
||||||
|
case pair < sections[0]:
|
||||||
|
return 0
|
||||||
|
case pair < secW:
|
||||||
|
return 1
|
||||||
|
case pair < secE:
|
||||||
|
return 2
|
||||||
|
default:
|
||||||
|
return 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) buildPromptMRoPEPositions(state *promptVisionState, startPos, chunkLen int32) [4][]int32 {
|
||||||
|
var positions [4][]int32
|
||||||
|
for i := range positions {
|
||||||
|
positions[i] = make([]int32, chunkLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// positions[3] stays zero — it covers RoPE dims beyond the 3 MRoPE sections.
|
||||||
|
for i := range chunkLen {
|
||||||
|
p := mapPromptPosition(state, startPos+i)
|
||||||
|
positions[0][i] = p
|
||||||
|
positions[1][i] = p
|
||||||
|
positions[2][i] = p
|
||||||
|
}
|
||||||
|
|
||||||
|
merge := int32(1)
|
||||||
|
if m != nil && m.Config != nil && m.Config.Vision != nil {
|
||||||
|
merge = m.Config.Vision.SpatialMergeSize
|
||||||
|
}
|
||||||
|
for _, span := range state.Spans {
|
||||||
|
if span.Grid == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkLo, chunkHi, spanLo, _, ok := overlapRange(startPos, chunkLen, span.Start, span.End)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
w := max(int32(1), span.Grid.Width/merge)
|
||||||
|
for i := chunkLo; i < chunkHi; i++ {
|
||||||
|
rel := spanLo + (i - chunkLo)
|
||||||
|
positions[1][i] += rel / w
|
||||||
|
positions[2][i] += rel % w
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return positions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) buildPromptMRoPECosSin(state *promptVisionState, startPos, chunkLen int32, dtype mlx.DType) (*mlx.Array, *mlx.Array) {
|
||||||
|
if m == nil || m.Config == nil || state == nil || chunkLen <= 0 || len(m.Config.MRoPESections) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ropeDim := m.Config.RopeDim
|
||||||
|
if ropeDim%2 != 0 {
|
||||||
|
ropeDim--
|
||||||
|
}
|
||||||
|
if ropeDim <= 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
half := ropeDim / 2
|
||||||
|
positions := m.buildPromptMRoPEPositions(state, startPos, chunkLen)
|
||||||
|
sections := normalizeMRoPESections(m.Config.MRoPESections)
|
||||||
|
theta := m.Config.RopeTheta
|
||||||
|
if theta <= 0 {
|
||||||
|
theta = 100000.0
|
||||||
|
}
|
||||||
|
|
||||||
|
freqs := make([]float64, half)
|
||||||
|
for j := range half {
|
||||||
|
freqs[j] = math.Pow(float64(theta), -2.0*float64(j)/float64(ropeDim))
|
||||||
|
}
|
||||||
|
|
||||||
|
angles := make([]float32, chunkLen*ropeDim)
|
||||||
|
for i := range chunkLen {
|
||||||
|
base := i * ropeDim
|
||||||
|
for j := range half {
|
||||||
|
component := mropePairComponent(j, sections, m.Config.MRoPEInterleaved)
|
||||||
|
angle := float32(float64(positions[component][i]) * freqs[j])
|
||||||
|
angles[base+j] = angle
|
||||||
|
angles[base+half+j] = angle
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
arr := mlx.FromValues(angles, 1, 1, int(chunkLen), int(ropeDim))
|
||||||
|
cos := mlx.Cos(arr)
|
||||||
|
sin := mlx.Sin(arr)
|
||||||
|
if dtype != 0 {
|
||||||
|
cos = cos.AsType(dtype)
|
||||||
|
sin = sin.AsType(dtype)
|
||||||
|
}
|
||||||
|
return cos, sin
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) tokenizePromptWithResolvedImages(
|
||||||
|
prompt string,
|
||||||
|
images []base.ImageInput,
|
||||||
|
resolve func([]byte) (*VisionEmbeddings, error),
|
||||||
|
) ([]int32, *promptVisionState, error) {
|
||||||
|
if m == nil || m.tok == nil {
|
||||||
|
return nil, nil, fmt.Errorf("qwen3_5: tokenizer not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Vision == nil || m.ImageProcessor == nil || resolve == nil {
|
||||||
|
return m.tok.Encode(prompt, true), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := imageTagRE.Split(prompt, -1)
|
||||||
|
matches := imageTagRE.FindAllStringSubmatch(prompt, -1)
|
||||||
|
|
||||||
|
resolved := make(map[int]*VisionEmbeddings, len(images))
|
||||||
|
var out []int32
|
||||||
|
state := &promptVisionState{}
|
||||||
|
var p int32
|
||||||
|
appendToken := func(tok, pos int32) {
|
||||||
|
out = append(out, tok)
|
||||||
|
state.PositionCache = append(state.PositionCache, pos)
|
||||||
|
}
|
||||||
|
for i, part := range parts {
|
||||||
|
for _, tok := range m.tok.Encode(part, i == 0) {
|
||||||
|
appendToken(tok, p)
|
||||||
|
p++
|
||||||
|
}
|
||||||
|
|
||||||
|
if i >= len(matches) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
imageID, err := strconv.Atoi(matches[i][1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("qwen3_5: invalid image tag %q: %w", matches[i][0], err)
|
||||||
|
}
|
||||||
|
|
||||||
|
img, ok := findImageByID(images, imageID)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, fmt.Errorf("invalid image index: %d", imageID)
|
||||||
|
}
|
||||||
|
|
||||||
|
embeds := resolved[imageID]
|
||||||
|
if embeds == nil {
|
||||||
|
embeds, err = resolve(img.Data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
resolved[imageID] = embeds
|
||||||
|
}
|
||||||
|
if embeds == nil || embeds.Main == nil || !embeds.Main.Valid() || embeds.Main.NumDims() < 2 {
|
||||||
|
return nil, nil, fmt.Errorf("qwen3_5: invalid vision embeddings")
|
||||||
|
}
|
||||||
|
|
||||||
|
tokensPerImage := int32(embeds.Main.Dim(1))
|
||||||
|
if tokensPerImage <= 0 {
|
||||||
|
return nil, nil, fmt.Errorf("qwen3_5: invalid image token count: %d", tokensPerImage)
|
||||||
|
}
|
||||||
|
|
||||||
|
appendToken(m.VisionStartToken, p)
|
||||||
|
p++
|
||||||
|
basePos := p
|
||||||
|
spanStart := int32(len(out))
|
||||||
|
for range tokensPerImage {
|
||||||
|
appendToken(m.ImageTokenID, basePos)
|
||||||
|
}
|
||||||
|
spanEnd := int32(len(out))
|
||||||
|
merge := int32(1)
|
||||||
|
if m.Config != nil && m.Config.Vision != nil {
|
||||||
|
merge = m.Config.Vision.SpatialMergeSize
|
||||||
|
}
|
||||||
|
gridSpan := promptVisionGridSpan(embeds.Grid, merge, tokensPerImage)
|
||||||
|
p += gridSpan
|
||||||
|
appendToken(m.VisionEndToken, p)
|
||||||
|
p++
|
||||||
|
|
||||||
|
state.Spans = append(state.Spans, promptVisionSpan{
|
||||||
|
Start: spanStart,
|
||||||
|
End: spanEnd,
|
||||||
|
Main: embeds.Main,
|
||||||
|
Grid: embeds.Grid,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) TokenizePromptWithImagesState(prompt string, images []base.ImageInput) (*base.PromptTokenization, error) {
|
||||||
|
tokens, state, err := m.tokenizePromptWithResolvedImages(prompt, images, m.EncodeVisionImage)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &base.PromptTokenization{Tokens: tokens, State: state}, nil
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
package qwen3_5
|
package qwen3_5
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
@@ -22,16 +23,26 @@ func init() {
|
|||||||
base.Register("Qwen3NextForConditionalGeneration", NewModel)
|
base.Register("Qwen3NextForConditionalGeneration", NewModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ base.MultimodalPromptTokenizerWithState = (*Model)(nil)
|
||||||
|
_ base.ForwardWithStateModel = (*Model)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
// RopeParameters carries optional rope metadata embedded under rope_parameters.
|
// RopeParameters carries optional rope metadata embedded under rope_parameters.
|
||||||
type RopeParameters struct {
|
type RopeParameters struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
RopeType string `json:"rope_type"`
|
RopeType string `json:"rope_type"`
|
||||||
RopeTheta float32 `json:"rope_theta"`
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||||
|
MRoPEInterleaved bool `json:"mrope_interleaved"`
|
||||||
|
MRoPESection []int32 `json:"mrope_section"`
|
||||||
|
DimensionSections []int32 `json:"dimension_sections"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config holds Qwen 3.5 text config (top-level or nested text_config).
|
// TextConfig holds the Qwen 3.5 text-model architecture fields.
|
||||||
type Config struct {
|
// In VLM configs these live under the "text_config" key; in text-only
|
||||||
|
// configs they appear at the top level.
|
||||||
|
type TextConfig struct {
|
||||||
ModelType string `json:"model_type"`
|
ModelType string `json:"model_type"`
|
||||||
HiddenSize int32 `json:"hidden_size"`
|
HiddenSize int32 `json:"hidden_size"`
|
||||||
IntermediateSize int32 `json:"intermediate_size"`
|
IntermediateSize int32 `json:"intermediate_size"`
|
||||||
@@ -67,6 +78,19 @@ type Config struct {
|
|||||||
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
|
||||||
RopeScaling map[string]any `json:"rope_scaling"`
|
RopeScaling map[string]any `json:"rope_scaling"`
|
||||||
RopeParameters *RopeParameters `json:"rope_parameters"`
|
RopeParameters *RopeParameters `json:"rope_parameters"`
|
||||||
|
MRoPESections []int32 `json:"mrope_sections"`
|
||||||
|
MRoPEInterleaved bool `json:"mrope_interleaved"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config is the full model config. It embeds TextConfig for the text-model
|
||||||
|
// fields and adds top-level-only fields (vision, token IDs, quantization).
|
||||||
|
type Config struct {
|
||||||
|
TextConfig
|
||||||
|
|
||||||
|
Vision *VisionConfig `json:"vision_config"`
|
||||||
|
ImageTokenID int32 `json:"image_token_id"`
|
||||||
|
VisionStartToken int32 `json:"vision_start_token_id"`
|
||||||
|
VisionEndToken int32 `json:"vision_end_token_id"`
|
||||||
|
|
||||||
// Quantization metadata.
|
// Quantization metadata.
|
||||||
QuantGroupSize int `json:"-"`
|
QuantGroupSize int `json:"-"`
|
||||||
@@ -90,6 +114,9 @@ type Model struct {
|
|||||||
*Config
|
*Config
|
||||||
|
|
||||||
weightPrefix string
|
weightPrefix string
|
||||||
|
|
||||||
|
Vision *VisionModel
|
||||||
|
ImageProcessor *VisionImageProcessor
|
||||||
}
|
}
|
||||||
|
|
||||||
// Layer is a transformer decoder layer.
|
// Layer is a transformer decoder layer.
|
||||||
@@ -190,17 +217,24 @@ func parseConfig(configData []byte) (Config, error) {
|
|||||||
|
|
||||||
var cfg Config
|
var cfg Config
|
||||||
activeRaw := rawTop
|
activeRaw := rawTop
|
||||||
|
|
||||||
|
// First pass: unmarshal the full config to pick up top-level fields
|
||||||
|
// (vision_config, image_token_id, etc.) and text fields for text-only models.
|
||||||
|
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||||
|
return Config{}, fmt.Errorf("parse config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second pass: if text_config exists, unmarshal it into TextConfig so
|
||||||
|
// text-model fields from text_config take priority over any top-level
|
||||||
|
// duplicates. Top-level-only fields (Vision, token IDs) are unaffected
|
||||||
|
// because they live on Config, not TextConfig.
|
||||||
if textRaw, ok := rawTop["text_config"]; ok {
|
if textRaw, ok := rawTop["text_config"]; ok {
|
||||||
if err := json.Unmarshal(textRaw, &cfg); err != nil {
|
if err := json.Unmarshal(textRaw, &cfg.TextConfig); err != nil {
|
||||||
return Config{}, fmt.Errorf("parse text_config: %w", err)
|
return Config{}, fmt.Errorf("parse text_config: %w", err)
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(textRaw, &activeRaw); err != nil {
|
if err := json.Unmarshal(textRaw, &activeRaw); err != nil {
|
||||||
return Config{}, fmt.Errorf("parse text_config envelope: %w", err)
|
return Config{}, fmt.Errorf("parse text_config envelope: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
if err := json.Unmarshal(configData, &cfg); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("parse config: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.HiddenSize <= 0 {
|
if cfg.HiddenSize <= 0 {
|
||||||
@@ -225,12 +259,8 @@ func parseConfig(configData []byte) (Config, error) {
|
|||||||
return Config{}, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
|
return Config{}, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.RMSNormEps == 0 {
|
cfg.RMSNormEps = cmp.Or(cfg.RMSNormEps, 1e-6)
|
||||||
cfg.RMSNormEps = 1e-6
|
cfg.LinearConvKernelDim = cmp.Or(cfg.LinearConvKernelDim, 4)
|
||||||
}
|
|
||||||
if cfg.LinearConvKernelDim <= 0 {
|
|
||||||
cfg.LinearConvKernelDim = 4
|
|
||||||
}
|
|
||||||
if cfg.LinearNumKeyHeads <= 0 || cfg.LinearNumValueHeads <= 0 || cfg.LinearKeyHeadDim <= 0 || cfg.LinearValueHeadDim <= 0 {
|
if cfg.LinearNumKeyHeads <= 0 || cfg.LinearNumValueHeads <= 0 || cfg.LinearKeyHeadDim <= 0 || cfg.LinearValueHeadDim <= 0 {
|
||||||
return Config{}, fmt.Errorf("invalid linear attention config (k_heads=%d v_heads=%d k_dim=%d v_dim=%d)",
|
return Config{}, fmt.Errorf("invalid linear attention config (k_heads=%d v_heads=%d k_dim=%d v_dim=%d)",
|
||||||
cfg.LinearNumKeyHeads, cfg.LinearNumValueHeads, cfg.LinearKeyHeadDim, cfg.LinearValueHeadDim)
|
cfg.LinearNumKeyHeads, cfg.LinearNumValueHeads, cfg.LinearKeyHeadDim, cfg.LinearValueHeadDim)
|
||||||
@@ -246,14 +276,21 @@ func parseConfig(configData []byte) (Config, error) {
|
|||||||
if cfg.RopeParameters.PartialRotaryFactor > 0 {
|
if cfg.RopeParameters.PartialRotaryFactor > 0 {
|
||||||
cfg.PartialRotaryFactor = cfg.RopeParameters.PartialRotaryFactor
|
cfg.PartialRotaryFactor = cfg.RopeParameters.PartialRotaryFactor
|
||||||
}
|
}
|
||||||
|
if len(cfg.MRoPESections) == 0 {
|
||||||
|
switch {
|
||||||
|
case len(cfg.RopeParameters.MRoPESection) > 0:
|
||||||
|
cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.MRoPESection...)
|
||||||
|
case len(cfg.RopeParameters.DimensionSections) > 0:
|
||||||
|
cfg.MRoPESections = append([]int32(nil), cfg.RopeParameters.DimensionSections...)
|
||||||
}
|
}
|
||||||
if cfg.RopeTheta == 0 {
|
|
||||||
cfg.RopeTheta = 100000.0
|
|
||||||
}
|
}
|
||||||
if cfg.PartialRotaryFactor == 0 {
|
cfg.MRoPEInterleaved = cmp.Or(cfg.MRoPEInterleaved, cfg.RopeParameters.MRoPEInterleaved)
|
||||||
cfg.PartialRotaryFactor = 0.25
|
|
||||||
}
|
}
|
||||||
if cfg.PartialRotaryFactor < 0 {
|
if len(cfg.MRoPESections) > 4 {
|
||||||
|
cfg.MRoPESections = cfg.MRoPESections[:4]
|
||||||
|
}
|
||||||
|
cfg.RopeTheta = cmp.Or(cfg.RopeTheta, 100000.0)
|
||||||
|
if cfg.PartialRotaryFactor <= 0 {
|
||||||
cfg.PartialRotaryFactor = 0.25
|
cfg.PartialRotaryFactor = 0.25
|
||||||
}
|
}
|
||||||
ropeDim := int32(float32(cfg.HeadDim) * cfg.PartialRotaryFactor)
|
ropeDim := int32(float32(cfg.HeadDim) * cfg.PartialRotaryFactor)
|
||||||
@@ -281,24 +318,23 @@ func parseConfig(configData []byte) (Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if cfg.NumExperts > 0 {
|
if cfg.NumExperts > 0 {
|
||||||
if cfg.NumExpertsPerTok <= 0 {
|
cfg.NumExpertsPerTok = cmp.Or(cfg.NumExpertsPerTok, int32(1))
|
||||||
cfg.NumExpertsPerTok = 1
|
cfg.MoeIntermediateSize = cmp.Or(cfg.MoeIntermediateSize, cfg.IntermediateSize)
|
||||||
}
|
cfg.SharedExpertIntermediateSize = cmp.Or(cfg.SharedExpertIntermediateSize, cfg.IntermediateSize)
|
||||||
if cfg.MoeIntermediateSize <= 0 {
|
|
||||||
cfg.MoeIntermediateSize = cfg.IntermediateSize
|
|
||||||
}
|
|
||||||
if cfg.SharedExpertIntermediateSize <= 0 {
|
|
||||||
cfg.SharedExpertIntermediateSize = cfg.IntermediateSize
|
|
||||||
}
|
|
||||||
if _, ok := activeRaw["norm_topk_prob"]; !ok {
|
if _, ok := activeRaw["norm_topk_prob"]; !ok {
|
||||||
cfg.NormTopKProb = true
|
cfg.NormTopKProb = true
|
||||||
}
|
}
|
||||||
if cfg.DecoderSparseStep <= 0 {
|
cfg.DecoderSparseStep = cmp.Or(cfg.DecoderSparseStep, int32(1))
|
||||||
cfg.DecoderSparseStep = 1
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||||
|
|
||||||
|
if cfg.Vision != nil {
|
||||||
|
cfg.Vision.applyDefaults()
|
||||||
|
}
|
||||||
|
cfg.ImageTokenID = cmp.Or(cfg.ImageTokenID, int32(151655))
|
||||||
|
cfg.VisionStartToken = cmp.Or(cfg.VisionStartToken, int32(151652))
|
||||||
|
cfg.VisionEndToken = cmp.Or(cfg.VisionEndToken, int32(151653))
|
||||||
return cfg, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -364,6 +400,11 @@ func NewModel(root *model.Root) (base.Model, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if cfg.Vision != nil {
|
||||||
|
if preprocessorData, err := root.Manifest.ReadConfig("preprocessor_config.json"); err == nil {
|
||||||
|
cfg.Vision.applyPreprocessorConfig(preprocessorData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if qt := root.QuantType(); qt != "" {
|
if qt := root.QuantType(); qt != "" {
|
||||||
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||||
@@ -1060,6 +1101,15 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|||||||
m.Layers[i] = layer
|
m.Layers[i] = layer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.Vision != nil && cfg.Vision.Depth > 0 {
|
||||||
|
vision, processor, err := loadVisionComponents(tensors, linears, cfg, m.weightPrefix)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.Vision = vision
|
||||||
|
m.ImageProcessor = processor
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1117,7 +1167,51 @@ func splitQKVZBA(mixedQKVZ, mixedBA *mlx.Array, cfg *Config, B, L int32) (q, k,
|
|||||||
return q, k, v, z, b, a
|
return q, k, v, z, b, a
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
func rotateHalf(x *mlx.Array) *mlx.Array {
|
||||||
|
shape := x.Dims()
|
||||||
|
last := int32(shape[len(shape)-1])
|
||||||
|
half := last / 2
|
||||||
|
if half <= 0 {
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
x1 := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), half})
|
||||||
|
x2 := mlx.SliceStartStop(x, []int32{0, 0, 0, half}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last})
|
||||||
|
return mlx.Concatenate([]*mlx.Array{mlx.Neg(x2), x1}, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyTextRoPE(x, cos, sin *mlx.Array, ropeDim int32) *mlx.Array {
|
||||||
|
if x == nil || cos == nil || sin == nil || ropeDim <= 0 {
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
shape := x.Dims()
|
||||||
|
if len(shape) != 4 {
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
last := int32(shape[len(shape)-1])
|
||||||
|
if ropeDim > last {
|
||||||
|
ropeDim = last
|
||||||
|
}
|
||||||
|
if ropeDim%2 != 0 {
|
||||||
|
ropeDim--
|
||||||
|
}
|
||||||
|
if ropeDim <= 0 {
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
rot := mlx.SliceStartStop(x, []int32{0, 0, 0, 0}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), ropeDim})
|
||||||
|
rot = mlx.Add(mlx.Mul(rot, cos), mlx.Mul(rotateHalf(rot), sin))
|
||||||
|
if ropeDim == last {
|
||||||
|
return rot
|
||||||
|
}
|
||||||
|
|
||||||
|
tail := mlx.SliceStartStop(x, []int32{0, 0, 0, ropeDim}, []int32{int32(shape[0]), int32(shape[1]), int32(shape[2]), last})
|
||||||
|
return mlx.Concatenate([]*mlx.Array{rot, tail}, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, ropeSin *mlx.Array) *mlx.Array {
|
||||||
qg := a.QProj.Forward(x)
|
qg := a.QProj.Forward(x)
|
||||||
qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2)
|
qg = mlx.Reshape(qg, B, L, cfg.NumAttentionHeads, cfg.HeadDim*2)
|
||||||
q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim})
|
q := mlx.SliceStartStop(qg, []int32{0, 0, 0, 0}, []int32{B, L, cfg.NumAttentionHeads, cfg.HeadDim})
|
||||||
@@ -1140,8 +1234,13 @@ func (a *FullAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
|||||||
if c != nil {
|
if c != nil {
|
||||||
offset = c.Offset()
|
offset = c.Offset()
|
||||||
}
|
}
|
||||||
|
if ropeCos != nil && ropeSin != nil {
|
||||||
|
q = applyTextRoPE(q, ropeCos, ropeSin, cfg.RopeDim)
|
||||||
|
k = applyTextRoPE(k, ropeCos, ropeSin, cfg.RopeDim)
|
||||||
|
} else {
|
||||||
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
q = mlx.RoPEWithBase(q, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||||
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
k = mlx.RoPEWithBase(k, int(cfg.RopeDim), false, cfg.RopeTheta, 1.0, offset)
|
||||||
|
}
|
||||||
|
|
||||||
if c != nil {
|
if c != nil {
|
||||||
k, v = c.Update(k, v)
|
k, v = c.Update(k, v)
|
||||||
@@ -1323,13 +1422,13 @@ func (m *SparseMoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
|
|||||||
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
return mlx.Reshape(y, B, L, cfg.HiddenSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config, ropeCos, ropeSin *mlx.Array) *mlx.Array {
|
||||||
var r *mlx.Array
|
var r *mlx.Array
|
||||||
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
||||||
if l.IsLinear {
|
if l.IsLinear {
|
||||||
r = l.Linear.Forward(normed, c, B, L, cfg)
|
r = l.Linear.Forward(normed, c, B, L, cfg)
|
||||||
} else {
|
} else {
|
||||||
r = l.FullAttn.Forward(normed, c, B, L, cfg)
|
r = l.FullAttn.Forward(normed, c, B, L, cfg, ropeCos, ropeSin)
|
||||||
}
|
}
|
||||||
h := mlx.Add(x, r)
|
h := mlx.Add(x, r)
|
||||||
r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg)
|
r = l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps), cfg)
|
||||||
@@ -1337,16 +1436,27 @@ func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *m
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||||
|
return m.ForwardWithState(tokens, caches, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) ForwardWithState(tokens *mlx.Array, caches []cache.Cache, state any) *mlx.Array {
|
||||||
dims := tokens.Dims()
|
dims := tokens.Dims()
|
||||||
B, L := int32(dims[0]), int32(dims[1])
|
B, L := int32(dims[0]), int32(dims[1])
|
||||||
|
|
||||||
|
startPos := promptStartPosFromCaches(caches)
|
||||||
|
promptState := promptVisionStateFromState(state)
|
||||||
h := m.EmbedTokens.Forward(tokens)
|
h := m.EmbedTokens.Forward(tokens)
|
||||||
|
h = m.applyPromptVisionEmbeddings(h, startPos, promptState)
|
||||||
|
var ropeCos, ropeSin *mlx.Array
|
||||||
|
if len(m.MRoPESections) > 0 {
|
||||||
|
ropeCos, ropeSin = m.buildPromptMRoPECosSin(promptState, startPos, L, h.DType())
|
||||||
|
}
|
||||||
for i, layer := range m.Layers {
|
for i, layer := range m.Layers {
|
||||||
var c cache.Cache
|
var c cache.Cache
|
||||||
if caches != nil && i < len(caches) {
|
if caches != nil && i < len(caches) {
|
||||||
c = caches[i]
|
c = caches[i]
|
||||||
}
|
}
|
||||||
h = layer.Forward(h, c, B, L, m.Config)
|
h = layer.Forward(h, c, B, L, m.Config, ropeCos, ropeSin)
|
||||||
}
|
}
|
||||||
out := m.Norm.Forward(h, m.RMSNormEps)
|
out := m.Norm.Forward(h, m.RMSNormEps)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
package qwen3_5
|
package qwen3_5
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
|
"github.com/ollama/ollama/x/tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
func skipIfNoMLX(t *testing.T) {
|
func skipIfNoMLX(t *testing.T) {
|
||||||
@@ -60,13 +64,13 @@ func TestParseConfigNestedDefaults(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLayerSelectionHelpers(t *testing.T) {
|
func TestLayerSelectionHelpers(t *testing.T) {
|
||||||
cfg := &Config{
|
cfg := &Config{TextConfig: TextConfig{
|
||||||
NumHiddenLayers: 6,
|
NumHiddenLayers: 6,
|
||||||
FullAttentionInterval: 3,
|
FullAttentionInterval: 3,
|
||||||
NumExperts: 8,
|
NumExperts: 8,
|
||||||
DecoderSparseStep: 2,
|
DecoderSparseStep: 2,
|
||||||
MLPOnlyLayers: []int32{1},
|
MLPOnlyLayers: []int32{1},
|
||||||
}
|
}}
|
||||||
|
|
||||||
if !layerIsLinear(cfg, 0) {
|
if !layerIsLinear(cfg, 0) {
|
||||||
t.Fatalf("layer 0 should be linear")
|
t.Fatalf("layer 0 should be linear")
|
||||||
@@ -133,13 +137,13 @@ func TestResolveTensorPathLayout(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewCachesLayout(t *testing.T) {
|
func TestNewCachesLayout(t *testing.T) {
|
||||||
m := &Model{
|
m := &Model{
|
||||||
Config: &Config{
|
Config: &Config{TextConfig: TextConfig{
|
||||||
LinearConvKernelDim: 4,
|
LinearConvKernelDim: 4,
|
||||||
LinearNumKeyHeads: 2,
|
LinearNumKeyHeads: 2,
|
||||||
LinearKeyHeadDim: 8,
|
LinearKeyHeadDim: 8,
|
||||||
LinearNumValueHeads: 4,
|
LinearNumValueHeads: 4,
|
||||||
LinearValueHeadDim: 16,
|
LinearValueHeadDim: 16,
|
||||||
},
|
}},
|
||||||
Layers: []*Layer{
|
Layers: []*Layer{
|
||||||
{IsLinear: true},
|
{IsLinear: true},
|
||||||
{IsLinear: false},
|
{IsLinear: false},
|
||||||
@@ -166,7 +170,7 @@ func TestNewCachesLayout(t *testing.T) {
|
|||||||
func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
||||||
skipIfNoMLX(t)
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
cfg := &Config{
|
cfg := &Config{TextConfig: TextConfig{
|
||||||
HiddenSize: 4,
|
HiddenSize: 4,
|
||||||
IntermediateSize: 8,
|
IntermediateSize: 8,
|
||||||
NumHiddenLayers: 2,
|
NumHiddenLayers: 2,
|
||||||
@@ -182,7 +186,7 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
|||||||
LinearValueHeadDim: 2,
|
LinearValueHeadDim: 2,
|
||||||
LinearConvKernelDim: 4,
|
LinearConvKernelDim: 4,
|
||||||
FullAttentionInterval: 2,
|
FullAttentionInterval: 2,
|
||||||
}
|
}}
|
||||||
|
|
||||||
m := &Model{
|
m := &Model{
|
||||||
Config: cfg,
|
Config: cfg,
|
||||||
@@ -343,3 +347,389 @@ func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
|
|||||||
t.Fatalf("k norm dtype = %v, want %v", got, f32)
|
t.Fatalf("k norm dtype = %v, want %v", got, f32)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseConfigVisionFields(t *testing.T) {
|
||||||
|
data := []byte(`{
|
||||||
|
"text_config": {
|
||||||
|
"hidden_size": 4096,
|
||||||
|
"intermediate_size": 14336,
|
||||||
|
"num_hidden_layers": 4,
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"head_dim": 128,
|
||||||
|
"linear_num_value_heads": 64,
|
||||||
|
"linear_num_key_heads": 16,
|
||||||
|
"linear_key_head_dim": 128,
|
||||||
|
"linear_value_head_dim": 128,
|
||||||
|
"linear_conv_kernel_dim": 4,
|
||||||
|
"rope_parameters": {
|
||||||
|
"rope_theta": 10000000
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"vision_config": {
|
||||||
|
"depth": 2,
|
||||||
|
"hidden_size": 256,
|
||||||
|
"num_heads": 8,
|
||||||
|
"in_channels": 3,
|
||||||
|
"patch_size": 14,
|
||||||
|
"spatial_merge_size": 2,
|
||||||
|
"layer_norm_epsilon": 0.000001,
|
||||||
|
"temporal_patch_size": 2,
|
||||||
|
"num_position_embeddings": 2304
|
||||||
|
},
|
||||||
|
"image_token_id": 111,
|
||||||
|
"vision_start_token_id": 112,
|
||||||
|
"vision_end_token_id": 113
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cfg, err := parseConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Vision == nil {
|
||||||
|
t.Fatal("vision config should be parsed")
|
||||||
|
}
|
||||||
|
if cfg.Vision.Depth != 2 {
|
||||||
|
t.Fatalf("vision.depth mismatch: got %d", cfg.Vision.Depth)
|
||||||
|
}
|
||||||
|
if cfg.Vision.GridPerSide != 48 {
|
||||||
|
t.Fatalf("vision grid-per-side mismatch: got %d want 48", cfg.Vision.GridPerSide)
|
||||||
|
}
|
||||||
|
if cfg.Vision.RopeTheta != 10000 {
|
||||||
|
t.Fatalf("vision rope_theta should default to 10000, got %v", cfg.Vision.RopeTheta)
|
||||||
|
}
|
||||||
|
if cfg.RopeTheta != 10000000 {
|
||||||
|
t.Fatalf("text rope_theta mismatch: got %v", cfg.RopeTheta)
|
||||||
|
}
|
||||||
|
if cfg.ImageTokenID != 111 || cfg.VisionStartToken != 112 || cfg.VisionEndToken != 113 {
|
||||||
|
t.Fatalf("vision token ids mismatch: got image=%d start=%d end=%d", cfg.ImageTokenID, cfg.VisionStartToken, cfg.VisionEndToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseConfigMRoPEFromRopeParameters(t *testing.T) {
|
||||||
|
data := []byte(`{
|
||||||
|
"text_config": {
|
||||||
|
"hidden_size": 2048,
|
||||||
|
"intermediate_size": 8192,
|
||||||
|
"num_hidden_layers": 4,
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_key_value_heads": 2,
|
||||||
|
"head_dim": 256,
|
||||||
|
"linear_num_value_heads": 32,
|
||||||
|
"linear_num_key_heads": 16,
|
||||||
|
"linear_key_head_dim": 128,
|
||||||
|
"linear_value_head_dim": 128,
|
||||||
|
"linear_conv_kernel_dim": 4,
|
||||||
|
"rope_parameters": {
|
||||||
|
"rope_theta": 10000000,
|
||||||
|
"partial_rotary_factor": 0.25,
|
||||||
|
"mrope_interleaved": true,
|
||||||
|
"mrope_section": [11, 11, 10]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cfg, err := parseConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.MRoPEInterleaved {
|
||||||
|
t.Fatal("mrope_interleaved should be parsed from rope_parameters")
|
||||||
|
}
|
||||||
|
if !slices.Equal(cfg.MRoPESections, []int32{11, 11, 10}) {
|
||||||
|
t.Fatalf("mrope sections mismatch: got %v", cfg.MRoPESections)
|
||||||
|
}
|
||||||
|
if cfg.RopeDim != 64 {
|
||||||
|
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseConfigVisionTokenDefaults(t *testing.T) {
|
||||||
|
data := []byte(`{
|
||||||
|
"text_config": {
|
||||||
|
"hidden_size": 4096,
|
||||||
|
"intermediate_size": 14336,
|
||||||
|
"num_hidden_layers": 2,
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"head_dim": 128,
|
||||||
|
"linear_num_value_heads": 64,
|
||||||
|
"linear_num_key_heads": 16,
|
||||||
|
"linear_key_head_dim": 128,
|
||||||
|
"linear_value_head_dim": 128,
|
||||||
|
"linear_conv_kernel_dim": 4
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cfg, err := parseConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.ImageTokenID != 151655 {
|
||||||
|
t.Fatalf("default image token mismatch: got %d", cfg.ImageTokenID)
|
||||||
|
}
|
||||||
|
if cfg.VisionStartToken != 151652 {
|
||||||
|
t.Fatalf("default vision start token mismatch: got %d", cfg.VisionStartToken)
|
||||||
|
}
|
||||||
|
if cfg.VisionEndToken != 151653 {
|
||||||
|
t.Fatalf("default vision end token mismatch: got %d", cfg.VisionEndToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveVisionPrefix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tensors map[string]*mlx.Array
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "legacy visual prefix",
|
||||||
|
tensors: map[string]*mlx.Array{
|
||||||
|
"model.visual.patch_embed.proj.weight": mlx.New("patch"),
|
||||||
|
},
|
||||||
|
want: "model.visual",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "imported vision tower prefix",
|
||||||
|
tensors: map[string]*mlx.Array{
|
||||||
|
"vision_tower.blocks.0.attn.qkv.weight": mlx.New("qkv"),
|
||||||
|
},
|
||||||
|
want: "vision_tower",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := resolveVisionPrefix(tt.tensors, "language_model."); got != tt.want {
|
||||||
|
t.Fatalf("resolveVisionPrefix() = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVisionPreprocessorOverridesDefaults(t *testing.T) {
|
||||||
|
v := &VisionConfig{}
|
||||||
|
v.applyDefaults()
|
||||||
|
|
||||||
|
v.applyPreprocessorConfig([]byte(`{
|
||||||
|
"patch_size": 16,
|
||||||
|
"temporal_patch_size": 3,
|
||||||
|
"merge_size": 4,
|
||||||
|
"size": {
|
||||||
|
"shortest_edge": 1024,
|
||||||
|
"longest_edge": 8192
|
||||||
|
},
|
||||||
|
"image_mean": [0.1, 0.2, 0.3],
|
||||||
|
"image_std": [0.9, 0.8, 0.7]
|
||||||
|
}`))
|
||||||
|
|
||||||
|
if v.PatchSize != 16 {
|
||||||
|
t.Fatalf("patch_size mismatch: got %d want 16", v.PatchSize)
|
||||||
|
}
|
||||||
|
if v.TemporalPatchSize != 3 {
|
||||||
|
t.Fatalf("temporal_patch_size mismatch: got %d want 3", v.TemporalPatchSize)
|
||||||
|
}
|
||||||
|
if v.SpatialMergeSize != 4 {
|
||||||
|
t.Fatalf("merge_size mismatch: got %d want 4", v.SpatialMergeSize)
|
||||||
|
}
|
||||||
|
if v.Size.ShortestEdge != 1024 || v.Size.LongestEdge != 8192 {
|
||||||
|
t.Fatalf("size mismatch: got shortest=%d longest=%d", v.Size.ShortestEdge, v.Size.LongestEdge)
|
||||||
|
}
|
||||||
|
if v.ImageMean[0] != 0.1 || v.ImageStd[2] != 0.7 {
|
||||||
|
t.Fatalf("image preprocessing stats mismatch: mean=%v std=%v", v.ImageMean, v.ImageStd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVisionImageProcessorUsesPreprocessorSize(t *testing.T) {
|
||||||
|
v := &VisionConfig{}
|
||||||
|
v.applyDefaults()
|
||||||
|
|
||||||
|
v.applyPreprocessorConfig([]byte(`{
|
||||||
|
"size": {
|
||||||
|
"shortest_edge": 65536,
|
||||||
|
"longest_edge": 16777216
|
||||||
|
},
|
||||||
|
"patch_size": 16,
|
||||||
|
"temporal_patch_size": 2,
|
||||||
|
"merge_size": 2,
|
||||||
|
"image_mean": [0.5, 0.5, 0.5],
|
||||||
|
"image_std": [0.5, 0.5, 0.5]
|
||||||
|
}`))
|
||||||
|
|
||||||
|
p := newVisionImageProcessor(v)
|
||||||
|
if p == nil {
|
||||||
|
t.Fatal("newVisionImageProcessor returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.shortestEdge != 65536 || p.longestEdge != 16777216 {
|
||||||
|
t.Fatalf("processor size mismatch: shortest=%d longest=%d", p.shortestEdge, p.longestEdge)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testTokenizer(t *testing.T) *tokenizer.Tokenizer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tok, err := tokenizer.LoadFromBytes([]byte(`{
|
||||||
|
"model": {
|
||||||
|
"type": "BPE",
|
||||||
|
"vocab": {"a": 0},
|
||||||
|
"merges": []
|
||||||
|
}
|
||||||
|
}`))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load test tokenizer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tok
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenizePromptWithResolvedImagesStoresVisionSpans(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
m := &Model{
|
||||||
|
tok: testTokenizer(t),
|
||||||
|
Config: &Config{
|
||||||
|
ImageTokenID: 101,
|
||||||
|
VisionStartToken: 102,
|
||||||
|
VisionEndToken: 103,
|
||||||
|
Vision: &VisionConfig{SpatialMergeSize: 2},
|
||||||
|
},
|
||||||
|
Vision: &VisionModel{},
|
||||||
|
ImageProcessor: &VisionImageProcessor{},
|
||||||
|
}
|
||||||
|
|
||||||
|
main := mlx.FromValues([]float32{
|
||||||
|
10, 11,
|
||||||
|
20, 21,
|
||||||
|
}, 1, 2, 2)
|
||||||
|
|
||||||
|
resolveCalls := 0
|
||||||
|
got, state, err := m.tokenizePromptWithResolvedImages(
|
||||||
|
"a[img-7][img-7]a",
|
||||||
|
[]base.ImageInput{{ID: 7, Data: []byte("img7")}},
|
||||||
|
func(data []byte) (*VisionEmbeddings, error) {
|
||||||
|
if string(data) != "img7" {
|
||||||
|
return nil, fmt.Errorf("unexpected data: %q", string(data))
|
||||||
|
}
|
||||||
|
resolveCalls++
|
||||||
|
return &VisionEmbeddings{
|
||||||
|
Main: main,
|
||||||
|
Grid: &VisionGrid{Height: 2, Width: 2, Temporal: 1},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tokenizePromptWithResolvedImages returned error: %v", err)
|
||||||
|
}
|
||||||
|
if resolveCalls != 1 {
|
||||||
|
t.Fatalf("resolve calls mismatch: got %d want 1", resolveCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []int32{
|
||||||
|
0,
|
||||||
|
102, 101, 101, 103,
|
||||||
|
102, 101, 101, 103,
|
||||||
|
0,
|
||||||
|
}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("expanded tokens mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if state == nil {
|
||||||
|
t.Fatal("expected prompt vision state")
|
||||||
|
}
|
||||||
|
if len(state.Spans) != 2 {
|
||||||
|
t.Fatalf("prompt span count mismatch: got %d want 2", len(state.Spans))
|
||||||
|
}
|
||||||
|
if state.Spans[0].Start != 2 || state.Spans[0].End != 4 {
|
||||||
|
t.Fatalf("first span mismatch: got [%d,%d)", state.Spans[0].Start, state.Spans[0].End)
|
||||||
|
}
|
||||||
|
if state.Spans[1].Start != 6 || state.Spans[1].End != 8 {
|
||||||
|
t.Fatalf("second span mismatch: got [%d,%d)", state.Spans[1].Start, state.Spans[1].End)
|
||||||
|
}
|
||||||
|
wantPos := []int32{0, 1, 2, 2, 3, 4, 5, 5, 6, 7}
|
||||||
|
if !slices.Equal(state.PositionCache, wantPos) {
|
||||||
|
t.Fatalf("position cache mismatch: got %v want %v", state.PositionCache, wantPos)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildPromptMRoPEPositions(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Config: &Config{
|
||||||
|
Vision: &VisionConfig{SpatialMergeSize: 2},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
state := &promptVisionState{
|
||||||
|
PositionCache: []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6},
|
||||||
|
Spans: []promptVisionSpan{
|
||||||
|
{
|
||||||
|
Start: 2,
|
||||||
|
End: 8,
|
||||||
|
Grid: &VisionGrid{Height: 4, Width: 6, Temporal: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pos := m.buildPromptMRoPEPositions(state, 0, 10)
|
||||||
|
if got, want := pos[0], []int32{0, 1, 2, 2, 2, 2, 2, 2, 5, 6}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("time positions mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := pos[1], []int32{0, 1, 2, 2, 2, 3, 3, 3, 5, 6}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("height positions mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := pos[2], []int32{0, 1, 2, 3, 4, 2, 3, 4, 5, 6}; !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("width positions mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapPromptPositionContinuesAfterCache(t *testing.T) {
|
||||||
|
state := &promptVisionState{PositionCache: []int32{0, 1, 2, 2, 3}}
|
||||||
|
|
||||||
|
if got := mapPromptPosition(state, 3); got != 2 {
|
||||||
|
t.Fatalf("mapPromptPosition(3) = %d, want 2", got)
|
||||||
|
}
|
||||||
|
if got := mapPromptPosition(state, 5); got != 4 {
|
||||||
|
t.Fatalf("mapPromptPosition(5) = %d, want 4", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyPromptVisionEmbeddings(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
m := &Model{}
|
||||||
|
state := &promptVisionState{
|
||||||
|
Spans: []promptVisionSpan{
|
||||||
|
{
|
||||||
|
Start: 1,
|
||||||
|
End: 3,
|
||||||
|
Main: mlx.FromValues([]float32{
|
||||||
|
10, 11,
|
||||||
|
20, 21,
|
||||||
|
}, 1, 2, 2),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
h := mlx.FromValues([]float32{
|
||||||
|
0, 1,
|
||||||
|
2, 3,
|
||||||
|
4, 5,
|
||||||
|
6, 7,
|
||||||
|
}, 1, 4, 2)
|
||||||
|
|
||||||
|
got := m.applyPromptVisionEmbeddings(h, 0, state)
|
||||||
|
mlx.Eval(got)
|
||||||
|
|
||||||
|
want := []float32{
|
||||||
|
0, 1,
|
||||||
|
10, 11,
|
||||||
|
20, 21,
|
||||||
|
6, 7,
|
||||||
|
}
|
||||||
|
if !slices.Equal(got.Floats(), want) {
|
||||||
|
t.Fatalf("embedding replacement mismatch: got %v want %v", got.Floats(), want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
854
x/models/qwen3_5/vision.go
Normal file
854
x/models/qwen3_5/vision.go
Normal file
@@ -0,0 +1,854 @@
|
|||||||
|
package qwen3_5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
_ "image/jpeg"
|
||||||
|
_ "image/png"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/model/imageproc"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
mlxmodel "github.com/ollama/ollama/x/mlxrunner/model"
|
||||||
|
"github.com/ollama/ollama/x/models/nn"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errNoVisionModel = errors.New("qwen3_5: no vision model")
|
||||||
|
|
||||||
|
// VisionConfig mirrors Qwen3.5/Qwen3-Next vision_config.
|
||||||
|
type VisionConfig struct {
|
||||||
|
Depth int32 `json:"depth"`
|
||||||
|
HiddenSize int32 `json:"hidden_size"`
|
||||||
|
NumHeads int32 `json:"num_heads"`
|
||||||
|
InChannels int32 `json:"in_channels"`
|
||||||
|
PatchSize int32 `json:"patch_size"`
|
||||||
|
SpatialMergeSize int32 `json:"spatial_merge_size"`
|
||||||
|
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
TemporalPatchSize int32 `json:"temporal_patch_size"`
|
||||||
|
NumPositionEmbeddings int32 `json:"num_position_embeddings"`
|
||||||
|
|
||||||
|
Size struct {
|
||||||
|
ShortestEdge int32 `json:"shortest_edge"`
|
||||||
|
LongestEdge int32 `json:"longest_edge"`
|
||||||
|
} `json:"size"`
|
||||||
|
|
||||||
|
ImageMean []float32 `json:"image_mean"`
|
||||||
|
ImageStd []float32 `json:"image_std"`
|
||||||
|
|
||||||
|
GridPerSide int32 `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VisionConfig) applyDefaults() {
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v.HiddenSize <= 0 {
|
||||||
|
v.HiddenSize = 1280
|
||||||
|
}
|
||||||
|
if v.NumHeads <= 0 {
|
||||||
|
v.NumHeads = 16
|
||||||
|
}
|
||||||
|
if v.InChannels <= 0 {
|
||||||
|
v.InChannels = 3
|
||||||
|
}
|
||||||
|
if v.PatchSize <= 0 {
|
||||||
|
v.PatchSize = 14
|
||||||
|
}
|
||||||
|
if v.SpatialMergeSize <= 0 {
|
||||||
|
v.SpatialMergeSize = 2
|
||||||
|
}
|
||||||
|
if v.LayerNormEpsilon == 0 {
|
||||||
|
v.LayerNormEpsilon = 1e-6
|
||||||
|
}
|
||||||
|
if v.RopeTheta == 0 {
|
||||||
|
v.RopeTheta = 10000
|
||||||
|
}
|
||||||
|
if v.TemporalPatchSize <= 0 {
|
||||||
|
v.TemporalPatchSize = 2
|
||||||
|
}
|
||||||
|
if v.NumPositionEmbeddings <= 0 {
|
||||||
|
v.NumPositionEmbeddings = 2304
|
||||||
|
}
|
||||||
|
if len(v.ImageMean) < 3 {
|
||||||
|
v.ImageMean = []float32{0.5, 0.5, 0.5}
|
||||||
|
}
|
||||||
|
if len(v.ImageStd) < 3 {
|
||||||
|
v.ImageStd = []float32{0.5, 0.5, 0.5}
|
||||||
|
}
|
||||||
|
if v.Size.ShortestEdge <= 0 {
|
||||||
|
v.Size.ShortestEdge = 64 << 10
|
||||||
|
}
|
||||||
|
if v.Size.LongestEdge <= 0 {
|
||||||
|
v.Size.LongestEdge = 2 << 20
|
||||||
|
}
|
||||||
|
|
||||||
|
grid := int32(math.Sqrt(float64(v.NumPositionEmbeddings)))
|
||||||
|
if grid <= 0 {
|
||||||
|
grid = 48
|
||||||
|
}
|
||||||
|
v.GridPerSide = grid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *VisionConfig) applyPreprocessorConfig(data []byte) {
|
||||||
|
if v == nil || len(data) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var pre struct {
|
||||||
|
Size struct {
|
||||||
|
ShortestEdge int32 `json:"shortest_edge"`
|
||||||
|
LongestEdge int32 `json:"longest_edge"`
|
||||||
|
} `json:"size"`
|
||||||
|
PatchSize int32 `json:"patch_size"`
|
||||||
|
TemporalPatchSize int32 `json:"temporal_patch_size"`
|
||||||
|
MergeSize int32 `json:"merge_size"`
|
||||||
|
ImageMean []float32 `json:"image_mean"`
|
||||||
|
ImageStd []float32 `json:"image_std"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &pre); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if pre.PatchSize > 0 {
|
||||||
|
v.PatchSize = pre.PatchSize
|
||||||
|
}
|
||||||
|
if pre.TemporalPatchSize > 0 {
|
||||||
|
v.TemporalPatchSize = pre.TemporalPatchSize
|
||||||
|
}
|
||||||
|
if pre.MergeSize > 0 {
|
||||||
|
v.SpatialMergeSize = pre.MergeSize
|
||||||
|
}
|
||||||
|
if pre.Size.ShortestEdge > 0 {
|
||||||
|
v.Size.ShortestEdge = pre.Size.ShortestEdge
|
||||||
|
}
|
||||||
|
if pre.Size.LongestEdge > 0 {
|
||||||
|
v.Size.LongestEdge = pre.Size.LongestEdge
|
||||||
|
}
|
||||||
|
if len(pre.ImageMean) >= 3 {
|
||||||
|
v.ImageMean = pre.ImageMean
|
||||||
|
}
|
||||||
|
if len(pre.ImageStd) >= 3 {
|
||||||
|
v.ImageStd = pre.ImageStd
|
||||||
|
}
|
||||||
|
v.applyDefaults()
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionGrid tracks patch-grid dimensions for an image.
|
||||||
|
type VisionGrid struct {
|
||||||
|
Height int32
|
||||||
|
Width int32
|
||||||
|
Temporal int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionImageProcessor reproduces qwen3vl image preprocessing.
|
||||||
|
type VisionImageProcessor struct {
|
||||||
|
numChannels int32
|
||||||
|
patchSize int32
|
||||||
|
temporalPatchSize int32
|
||||||
|
mergeSize int32
|
||||||
|
shortestEdge int32
|
||||||
|
longestEdge int32
|
||||||
|
factor int32
|
||||||
|
imageMean [3]float32
|
||||||
|
imageStd [3]float32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newVisionImageProcessor(cfg *VisionConfig) *VisionImageProcessor {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &VisionImageProcessor{
|
||||||
|
numChannels: cfg.InChannels,
|
||||||
|
patchSize: cfg.PatchSize,
|
||||||
|
temporalPatchSize: cfg.TemporalPatchSize,
|
||||||
|
mergeSize: cfg.SpatialMergeSize,
|
||||||
|
shortestEdge: cfg.Size.ShortestEdge,
|
||||||
|
longestEdge: cfg.Size.LongestEdge,
|
||||||
|
factor: cfg.PatchSize * cfg.SpatialMergeSize,
|
||||||
|
imageMean: [3]float32{cfg.ImageMean[0], cfg.ImageMean[1], cfg.ImageMean[2]},
|
||||||
|
imageStd: [3]float32{cfg.ImageStd[0], cfg.ImageStd[1], cfg.ImageStd[2]},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *VisionImageProcessor) smartResize(height, width int) (int, int, error) {
|
||||||
|
factor := int(p.factor)
|
||||||
|
if factor <= 0 {
|
||||||
|
return 0, 0, fmt.Errorf("invalid factor: %d", factor)
|
||||||
|
}
|
||||||
|
|
||||||
|
if height < factor || width < factor {
|
||||||
|
return 0, 0, fmt.Errorf("height (%d) or width (%d) must be >= factor (%d)", height, width, factor)
|
||||||
|
}
|
||||||
|
if min(height, width) == 0 {
|
||||||
|
return 0, 0, fmt.Errorf("invalid dimensions: %dx%d", width, height)
|
||||||
|
}
|
||||||
|
if max(height, width)/min(height, width) > 200 {
|
||||||
|
return 0, 0, fmt.Errorf("aspect ratio too large: %dx%d", width, height)
|
||||||
|
}
|
||||||
|
|
||||||
|
roundEven := func(x float64) int { return int(math.RoundToEven(x)) }
|
||||||
|
|
||||||
|
hBar := roundEven(float64(height)/float64(factor)) * factor
|
||||||
|
wBar := roundEven(float64(width)/float64(factor)) * factor
|
||||||
|
|
||||||
|
if hBar*wBar > int(p.longestEdge) {
|
||||||
|
beta := math.Sqrt(float64(height*width) / float64(p.longestEdge))
|
||||||
|
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
|
||||||
|
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
|
||||||
|
} else if hBar*wBar < int(p.shortestEdge) {
|
||||||
|
beta := math.Sqrt(float64(p.shortestEdge) / float64(height*width))
|
||||||
|
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||||
|
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||||
|
}
|
||||||
|
|
||||||
|
return hBar, wBar, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *VisionImageProcessor) ProcessImage(img image.Image) (*mlx.Array, *VisionGrid, error) {
|
||||||
|
if p == nil {
|
||||||
|
return nil, nil, errNoVisionModel
|
||||||
|
}
|
||||||
|
|
||||||
|
img = imageproc.Composite(img)
|
||||||
|
origW := img.Bounds().Dx()
|
||||||
|
origH := img.Bounds().Dy()
|
||||||
|
|
||||||
|
resizedH, resizedW, err := p.smartResize(origH, origW)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resized := imageproc.Resize(
|
||||||
|
img,
|
||||||
|
image.Point{X: resizedW, Y: resizedH},
|
||||||
|
imageproc.ResizeBilinear,
|
||||||
|
)
|
||||||
|
pixels := imageproc.Normalize(resized, p.imageMean, p.imageStd, true, true)
|
||||||
|
|
||||||
|
grid := &VisionGrid{
|
||||||
|
Height: int32(resizedH / int(p.patchSize)),
|
||||||
|
Width: int32(resizedW / int(p.patchSize)),
|
||||||
|
Temporal: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
patches := p.createPatches(pixels, resizedH, resizedW, grid)
|
||||||
|
|
||||||
|
patchDim := int(p.numChannels * p.temporalPatchSize * p.patchSize * p.patchSize)
|
||||||
|
numPatches := int(grid.Height * grid.Width)
|
||||||
|
pixelValues := mlx.FromValues(patches, numPatches, patchDim).ExpandDims(0)
|
||||||
|
return pixelValues, grid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *VisionImageProcessor) createPatches(pixels []float32, height, width int, grid *VisionGrid) []float32 {
|
||||||
|
channels := int(p.numChannels)
|
||||||
|
patchSize := int(p.patchSize)
|
||||||
|
mergeSize := int(p.mergeSize)
|
||||||
|
temporalPatchSize := int(p.temporalPatchSize)
|
||||||
|
|
||||||
|
// Temporal is always 1 for static images; only spatial patches are created.
|
||||||
|
numPatches := int(grid.Height * grid.Width)
|
||||||
|
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
||||||
|
result := make([]float32, numPatches*patchDim)
|
||||||
|
|
||||||
|
patchIndex := 0
|
||||||
|
for h := 0; h < int(grid.Height); h += mergeSize {
|
||||||
|
for w := 0; w < int(grid.Width); w += mergeSize {
|
||||||
|
for mh := range mergeSize {
|
||||||
|
for mw := range mergeSize {
|
||||||
|
baseOffset := patchIndex * patchDim
|
||||||
|
|
||||||
|
for c := range channels {
|
||||||
|
channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
|
||||||
|
for py := range patchSize {
|
||||||
|
for px := range patchSize {
|
||||||
|
y := (h+mh)*patchSize + py
|
||||||
|
x := (w+mw)*patchSize + px
|
||||||
|
srcIdx := c*height*width + y*width + x
|
||||||
|
dstIdx := channelOffset + py*patchSize + px
|
||||||
|
if srcIdx < len(pixels) && dstIdx < len(result) {
|
||||||
|
result[dstIdx] = pixels[srcIdx]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if temporalPatchSize > 1 {
|
||||||
|
for c := range channels {
|
||||||
|
channelOffset := baseOffset + c*temporalPatchSize*patchSize*patchSize
|
||||||
|
frameSize := patchSize * patchSize
|
||||||
|
for tp := 1; tp < temporalPatchSize; tp++ {
|
||||||
|
cur := channelOffset + tp*frameSize
|
||||||
|
copy(result[cur:cur+frameSize], result[channelOffset:channelOffset+frameSize])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
patchIndex++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionAttention runs one self-attention block inside the vision encoder.
|
||||||
|
type VisionAttention struct {
|
||||||
|
QKV nn.LinearLayer
|
||||||
|
Query nn.LinearLayer
|
||||||
|
Key nn.LinearLayer
|
||||||
|
Value nn.LinearLayer
|
||||||
|
Output nn.LinearLayer
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyVisionRoPE(x, cos, sin *mlx.Array) *mlx.Array {
|
||||||
|
return mlx.Add(mlx.Mul(x, cos), mlx.Mul(rotateHalf(x), sin))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward computes multi-head self-attention over a [B, L, D] patch sequence,
// applying 2D rotary embeddings (cos/sin) to queries and keys before the
// attention kernel. Returns the output projection of the attended values.
func (a *VisionAttention) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
	shape := x.Dims()
	if len(shape) != 3 {
		return nil, fmt.Errorf("vision attention expects [B,L,D], got %v", shape)
	}
	B, L, hidden := int32(shape[0]), int32(shape[1]), int32(shape[2])
	headDim := cfg.HiddenSize / cfg.NumHeads
	if headDim <= 0 {
		return nil, fmt.Errorf("invalid vision head dim: %d", headDim)
	}

	// Produce q/k/v with shape [B, L, heads, headDim], either from one fused
	// QKV projection or from three separate projections.
	var q, k, v *mlx.Array
	if a.QKV != nil {
		qkv := a.QKV.Forward(x)
		qkv = mlx.Reshape(qkv, B, L, 3, cfg.NumHeads, headDim)
		// Slice out each of the three stacked projections along axis 2,
		// then squeeze that singleton axis away.
		q = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 0, 0, 0}, []int32{B, L, 1, cfg.NumHeads, headDim}), 2)
		k = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 1, 0, 0}, []int32{B, L, 2, cfg.NumHeads, headDim}), 2)
		v = mlx.Squeeze(mlx.SliceStartStop(qkv, []int32{0, 0, 2, 0, 0}, []int32{B, L, 3, cfg.NumHeads, headDim}), 2)
	} else {
		if a.Query == nil || a.Key == nil || a.Value == nil {
			return nil, errors.New("vision attention is missing q/k/v projections")
		}
		q = mlx.Reshape(a.Query.Forward(x), B, L, cfg.NumHeads, headDim)
		k = mlx.Reshape(a.Key.Forward(x), B, L, cfg.NumHeads, headDim)
		v = mlx.Reshape(a.Value.Forward(x), B, L, cfg.NumHeads, headDim)
	}

	// Rotary position embedding is applied to queries and keys only.
	q = applyVisionRoPE(q, cos, sin)
	k = applyVisionRoPE(k, cos, sin)

	// [B, L, H, Dh] -> [B, H, L, Dh] layout expected by the attention kernel.
	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)

	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	// NOTE(review): the trailing false presumably disables the causal mask so
	// vision patches attend bidirectionally — confirm against the mlx binding.
	attn := mlx.ScaledDotProductAttentionCausal(q, k, v, scale, false)
	attn = mlx.Reshape(mlx.Transpose(attn, 0, 2, 1, 3), B, L, hidden)
	if a.Output == nil {
		return nil, errors.New("vision attention is missing output projection")
	}
	return a.Output.Forward(attn), nil
}
|
||||||
|
|
||||||
|
// VisionMLP is the vision feed-forward block.
type VisionMLP struct {
	// FC1 expands the hidden dimension; FC2 projects back down. A
	// tanh-approximate GELU is applied between them (see Forward).
	FC1 nn.LinearLayer
	FC2 nn.LinearLayer
}
|
||||||
|
|
||||||
|
func (m *VisionMLP) Forward(x *mlx.Array) (*mlx.Array, error) {
|
||||||
|
if m.FC1 == nil || m.FC2 == nil {
|
||||||
|
return nil, errors.New("vision mlp is missing fc1/fc2")
|
||||||
|
}
|
||||||
|
return m.FC2.Forward(mlx.GELUApprox(m.FC1.Forward(x))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionEncoderLayer is one transformer block in the vision encoder.
// The block is pre-norm: each sub-layer normalizes its input and adds the
// result back to the residual stream (see Forward).
type VisionEncoderLayer struct {
	// Norm1 precedes the attention sub-layer.
	Norm1 *nn.LayerNorm
	Attn  *VisionAttention
	// Norm2 precedes the MLP sub-layer.
	Norm2 *nn.LayerNorm
	MLP   *VisionMLP
}
|
||||||
|
|
||||||
|
func (l *VisionEncoderLayer) Forward(x, cos, sin *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
|
||||||
|
if l.Norm1 == nil || l.Norm2 == nil || l.Attn == nil || l.MLP == nil {
|
||||||
|
return nil, errors.New("vision layer is incomplete")
|
||||||
|
}
|
||||||
|
|
||||||
|
r := x
|
||||||
|
a, err := l.Attn.Forward(l.Norm1.Forward(x), cos, sin, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
x = mlx.Add(r, a)
|
||||||
|
|
||||||
|
r = x
|
||||||
|
m, err := l.MLP.Forward(l.Norm2.Forward(x))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return mlx.Add(r, m), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionPatchMerger projects merged spatial groups into language embedding space.
type VisionPatchMerger struct {
	// Norm is applied before the spatial tokens are grouped.
	Norm *nn.LayerNorm
	// FC1 and FC2 form the projection MLP applied to each merged group,
	// with a tanh-approximate GELU between them (see Forward).
	FC1 nn.LinearLayer
	FC2 nn.LinearLayer
}
|
||||||
|
|
||||||
|
func groupMergedTokens(x *mlx.Array, merge int32) (*mlx.Array, error) {
|
||||||
|
shape := x.Dims()
|
||||||
|
if len(shape) != 3 {
|
||||||
|
return nil, fmt.Errorf("expected [B,L,D], got %v", shape)
|
||||||
|
}
|
||||||
|
if merge <= 0 {
|
||||||
|
merge = 1
|
||||||
|
}
|
||||||
|
B, L, D := int32(shape[0]), int32(shape[1]), int32(shape[2])
|
||||||
|
group := merge * merge
|
||||||
|
if group <= 0 || L%group != 0 {
|
||||||
|
return nil, fmt.Errorf("invalid merge layout: L=%d merge=%d", L, merge)
|
||||||
|
}
|
||||||
|
|
||||||
|
x = mlx.Reshape(x, B, L/group, group, D)
|
||||||
|
x = mlx.Reshape(x, B, L/group, group*D)
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VisionPatchMerger) Forward(x *mlx.Array, cfg *VisionConfig) (*mlx.Array, error) {
|
||||||
|
if m == nil || m.Norm == nil || m.FC1 == nil || m.FC2 == nil {
|
||||||
|
return nil, errors.New("vision patch merger is incomplete")
|
||||||
|
}
|
||||||
|
|
||||||
|
x = m.Norm.Forward(x)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
x, err = groupMergedTokens(x, cfg.SpatialMergeSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
x = m.FC2.Forward(mlx.GELUApprox(m.FC1.Forward(x)))
|
||||||
|
return x, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionModel contains the full Qwen vision tower.
type VisionModel struct {
	// PatchProjection maps flattened pixel patches into the encoder hidden size.
	PatchProjection nn.LinearLayer
	// PositionEmbed is the learned position table; it is bilinearly
	// interpolated per image grid in addPositionEmbedding.
	PositionEmbed *nn.Embedding
	// Layers are the transformer encoder blocks, applied in order.
	Layers []*VisionEncoderLayer
	// PatchMerger projects merged spatial groups into language embedding space.
	PatchMerger *VisionPatchMerger

	// cfg holds the vision hyperparameters shared by all components.
	cfg *VisionConfig
}
|
||||||
|
|
||||||
|
func mergedPatchCoordinates(grid *VisionGrid, merge int32) [][2]int32 {
|
||||||
|
if merge <= 0 {
|
||||||
|
merge = 1
|
||||||
|
}
|
||||||
|
// Temporal is always 1 for static images; only spatial coordinates are generated.
|
||||||
|
coords := make([][2]int32, 0, grid.Height*grid.Width)
|
||||||
|
for h := int32(0); h < grid.Height; h += merge {
|
||||||
|
for w := int32(0); w < grid.Width; w += merge {
|
||||||
|
for mh := range merge {
|
||||||
|
for mw := range merge {
|
||||||
|
coords = append(coords, [2]int32{h + mh, w + mw})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return coords
|
||||||
|
}
|
||||||
|
|
||||||
|
// addPositionEmbedding adds position information to the [B, L, D] patch
// embeddings by bilinearly interpolating the learned GridPerSide x GridPerSide
// position table onto the image's actual patch grid. A no-op (returns x) when
// the model has no position table.
func (m *VisionModel) addPositionEmbedding(x *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
	if m.PositionEmbed == nil {
		return x, nil
	}
	shape := x.Dims()
	if len(shape) != 3 {
		return nil, fmt.Errorf("vision embeddings expect [B,L,D], got %v", shape)
	}
	B, D := int32(shape[0]), int32(shape[2])
	coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
	L := int32(len(coords))
	if L != int32(shape[1]) {
		return nil, fmt.Errorf("vision sequence mismatch: hidden L=%d coords=%d", shape[1], L)
	}

	// Scale factors mapping grid coordinates onto the learned table's axes;
	// zero when the grid is a single row/column (no interpolation needed).
	stepH := float32(0)
	if grid.Height > 1 {
		stepH = float32(m.cfg.GridPerSide-1) / float32(grid.Height-1)
	}
	stepW := float32(0)
	if grid.Width > 1 {
		stepW = float32(m.cfg.GridPerSide-1) / float32(grid.Width-1)
	}

	// For each patch, gather its four nearest table entries (floor/ceil in
	// each axis, clamped to the table edge) and the bilinear weights.
	indices := make([]int32, 0, L*4)
	weights := make([]float32, 0, L*4)
	for _, c := range coords {
		y := float32(c[0]) * stepH
		x0 := float32(c[1]) * stepW

		fy := int32(y)
		fx := int32(x0)
		cy := min(fy+1, m.cfg.GridPerSide-1)
		cx := min(fx+1, m.cfg.GridPerSide-1)

		// Row-major flattened indices of the four corners: TL, TR, BL, BR.
		indices = append(indices,
			fy*m.cfg.GridPerSide+fx,
			fy*m.cfg.GridPerSide+cx,
			cy*m.cfg.GridPerSide+fx,
			cy*m.cfg.GridPerSide+cx,
		)

		// Fractional offsets give the standard bilinear corner weights.
		dy := y - float32(fy)
		dx := x0 - float32(fx)
		weights = append(weights,
			(1-dy)*(1-dx),
			(1-dy)*dx,
			dy*(1-dx),
			dy*dx,
		)
	}

	idxArr := mlx.FromValues(indices, int(L), 4)
	// Trailing singleton axis so the weights broadcast over the embedding dim.
	wArr := mlx.FromValues(weights, int(L), 4, 1)

	pos := m.PositionEmbed.Forward(idxArr)
	wArr = wArr.AsType(pos.DType())
	// Weighted sum over the 4 corners collapses [L, 4, D] -> [L, D].
	pos = mlx.Sum(mlx.Mul(pos, wArr), 1, false)
	if D != int32(pos.Dim(1)) {
		return nil, fmt.Errorf("position embedding dim mismatch: hidden=%d pos=%d", D, pos.Dim(1))
	}

	// Add a batch axis and replicate across the batch if needed.
	pos = mlx.ExpandDims(pos, 0)
	if B > 1 {
		pos = mlx.Tile(pos, []int32{B, 1, 1})
	}

	return mlx.Add(x, pos), nil
}
|
||||||
|
|
||||||
|
// rotaryEmbeddings builds 2D rotary cos/sin tables for the merged patch
// coordinates. Within each head, the first quarter of the dimensions encodes
// the row coordinate and the second quarter the column coordinate; the second
// half duplicates the first so rotate-half style RoPE can be applied.
// Returned shapes broadcast over batch and heads.
func (m *VisionModel) rotaryEmbeddings(grid *VisionGrid) (*mlx.Array, *mlx.Array, error) {
	headDim := m.cfg.HiddenSize / m.cfg.NumHeads
	if headDim <= 0 {
		return nil, nil, fmt.Errorf("invalid vision head dim: %d", headDim)
	}

	coords := mergedPatchCoordinates(grid, m.cfg.SpatialMergeSize)
	L := int32(len(coords))
	half := headDim / 2
	quarter := half / 2
	if quarter <= 0 {
		return nil, nil, fmt.Errorf("invalid vision rotary layout: head_dim=%d", headDim)
	}

	angles := make([]float32, L*headDim)
	for i, c := range coords {
		base := int32(i) * headDim
		for j := range quarter {
			// Inverse-frequency schedule over the half dimension.
			freq := 1.0 / math.Pow(float64(m.cfg.RopeTheta), float64(2*j)/float64(half))
			angles[base+j] = float32(float64(c[0]) * freq)         // row angle
			angles[base+quarter+j] = float32(float64(c[1]) * freq) // column angle
		}
		// Mirror the first half into the second half of the head dimension.
		for j := range half {
			angles[base+half+j] = angles[base+j]
		}
	}

	arr := mlx.FromValues(angles, int(L), int(headDim))
	// Expand [L, headDim] -> [1, L, 1, headDim] so the tables broadcast.
	cos := mlx.ExpandDims(mlx.ExpandDims(mlx.Cos(arr), 0), 2)
	sin := mlx.ExpandDims(mlx.ExpandDims(mlx.Sin(arr), 0), 2)
	return cos, sin, nil
}
|
||||||
|
|
||||||
|
func (m *VisionModel) Forward(pixelValues *mlx.Array, grid *VisionGrid) (*mlx.Array, error) {
|
||||||
|
if m == nil || pixelValues == nil || grid == nil {
|
||||||
|
return nil, errNoVisionModel
|
||||||
|
}
|
||||||
|
if m.PatchProjection == nil || m.PatchMerger == nil {
|
||||||
|
return nil, errors.New("vision model is missing required projections")
|
||||||
|
}
|
||||||
|
|
||||||
|
x := m.PatchProjection.Forward(pixelValues)
|
||||||
|
var err error
|
||||||
|
x, err = m.addPositionEmbedding(x, grid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cos, sin, err := m.rotaryEmbeddings(grid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
x, err = layer.Forward(x, cos, sin, m.cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("vision layer %d: %w", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
main, err := m.PatchMerger.Forward(x, m.cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("vision patch merger: %w", err)
|
||||||
|
}
|
||||||
|
return main, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VisionEmbeddings bundles the encoded image features with the grid layout
// they were produced from.
type VisionEmbeddings struct {
	// Main holds the merged vision embeddings returned by the vision tower.
	Main *mlx.Array
	// Grid describes the spatial patch layout of the processed image.
	Grid *VisionGrid
}
|
||||||
|
|
||||||
|
func (m *Model) EncodeVisionImage(multimodalData []byte) (*VisionEmbeddings, error) {
|
||||||
|
if m == nil || m.Vision == nil || m.ImageProcessor == nil {
|
||||||
|
return nil, errNoVisionModel
|
||||||
|
}
|
||||||
|
|
||||||
|
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pixelValues, grid, err := m.ImageProcessor.ProcessImage(img)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
main, err := m.Vision.Forward(pixelValues, grid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &VisionEmbeddings{Main: main, Grid: grid}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveVisionPrefix(tensors map[string]*mlx.Array, weightPrefix string) string {
|
||||||
|
candidates := []string{
|
||||||
|
"vision_tower",
|
||||||
|
weightPrefix + "vision_tower",
|
||||||
|
"model.visual",
|
||||||
|
"visual",
|
||||||
|
weightPrefix + "model.visual",
|
||||||
|
weightPrefix + "visual",
|
||||||
|
}
|
||||||
|
|
||||||
|
hasTensor := func(prefix string) bool {
|
||||||
|
for _, suffix := range []string{
|
||||||
|
".patch_embed.proj.weight",
|
||||||
|
".patch_embed.weight",
|
||||||
|
".pos_embed.weight",
|
||||||
|
".blocks.0.attn.qkv.weight",
|
||||||
|
".merger.linear_fc1.weight",
|
||||||
|
".merger.mlp.0.weight",
|
||||||
|
} {
|
||||||
|
if tensors[prefix+suffix] != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range candidates {
|
||||||
|
if hasTensor(prefix) {
|
||||||
|
return prefix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstLinear(linears mlxmodel.LinearFactory, paths ...string) nn.LinearLayer {
|
||||||
|
for _, p := range paths {
|
||||||
|
if l := linears.Make(p); l != nil {
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadLayerNorm(tensors map[string]*mlx.Array, eps float32, bases ...string) *nn.LayerNorm {
|
||||||
|
for _, base := range bases {
|
||||||
|
if w := tensors[base+".weight"]; w != nil {
|
||||||
|
return &nn.LayerNorm{Weight: w, Bias: tensors[base+".bias"], Eps: eps}
|
||||||
|
}
|
||||||
|
if w := tensors[base]; w != nil {
|
||||||
|
return &nn.LayerNorm{Weight: w, Bias: tensors[base+"_bias"], Eps: eps}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadVisionPatchMerger(
|
||||||
|
tensors map[string]*mlx.Array,
|
||||||
|
linears mlxmodel.LinearFactory,
|
||||||
|
eps float32,
|
||||||
|
bases ...string,
|
||||||
|
) *VisionPatchMerger {
|
||||||
|
for _, base := range bases {
|
||||||
|
norm := loadLayerNorm(tensors, eps, base+".norm", base+".ln_q")
|
||||||
|
fc1 := firstLinear(linears, base+".linear_fc1", base+".mlp.0")
|
||||||
|
fc2 := firstLinear(linears, base+".linear_fc2", base+".mlp.2")
|
||||||
|
if norm != nil && fc1 != nil && fc2 != nil {
|
||||||
|
return &VisionPatchMerger{Norm: norm, FC1: fc1, FC2: fc2}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func flattenPatchEmbeddingWeight(w *mlx.Array) (*mlx.Array, error) {
|
||||||
|
if w == nil || !w.Valid() {
|
||||||
|
return nil, errors.New("missing patch embedding weight")
|
||||||
|
}
|
||||||
|
if w.NumDims() < 2 {
|
||||||
|
return nil, fmt.Errorf("patch embedding weight must be >=2D, got %dD", w.NumDims())
|
||||||
|
}
|
||||||
|
if w.NumDims() == 2 {
|
||||||
|
return w, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
out := int32(w.Dim(0))
|
||||||
|
in := int32(w.Size() / w.Dim(0))
|
||||||
|
return mlx.Reshape(w, out, in), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadVisionComponents builds the vision tower and its image processor from
// checkpoint tensors. It returns (nil, nil, nil) when the config does not
// enable vision, and an error when vision is enabled but required weights
// are missing or malformed. For each weight it probes multiple known naming
// layouts, so both HF-style and flat checkpoints load.
func loadVisionComponents(
	tensors map[string]*mlx.Array,
	linears mlxmodel.LinearFactory,
	cfg *Config,
	weightPrefix string,
) (*VisionModel, *VisionImageProcessor, error) {
	// Vision is optional: absent config or zero depth means text-only.
	if cfg == nil || cfg.Vision == nil || cfg.Vision.Depth <= 0 {
		return nil, nil, nil
	}
	cfg.Vision.applyDefaults()

	visionPrefix := resolveVisionPrefix(tensors, weightPrefix)
	if visionPrefix == "" {
		return nil, nil, errors.New("vision enabled in config but vision tensors were not found")
	}

	// Patch embedding: conv-style ".proj" naming or flat naming.
	patchW, _ := tensorAny(
		tensors,
		visionPrefix+".patch_embed.proj.weight",
		visionPrefix+".patch_embed.weight",
	)
	if patchW == nil {
		return nil, nil, fmt.Errorf("missing vision patch embedding weight under %s", visionPrefix)
	}
	patchW, err := flattenPatchEmbeddingWeight(patchW)
	if err != nil {
		return nil, nil, err
	}
	patchB, _ := tensorAny(
		tensors,
		visionPrefix+".patch_embed.proj.bias",
		visionPrefix+".patch_embed.bias",
	)

	patchProj := nn.NewLinear(patchW, patchB)
	// Sanity-check the flattened input dim against channels * temporal *
	// patch_size^2 from the config before building anything else.
	if got := int32(patchW.Dim(1)); got != cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize {
		return nil, nil, fmt.Errorf(
			"vision patch embedding input dim mismatch: got %d expected %d",
			got,
			cfg.Vision.InChannels*cfg.Vision.TemporalPatchSize*cfg.Vision.PatchSize*cfg.Vision.PatchSize,
		)
	}

	posW, _ := tensorAny(
		tensors,
		visionPrefix+".pos_embed.weight",
		visionPrefix+".position_embedding.weight",
	)
	if posW == nil {
		return nil, nil, fmt.Errorf("missing vision position embedding under %s", visionPrefix)
	}
	// The table size comes from the weights; re-run defaults so any values
	// derived from NumPositionEmbeddings are recomputed.
	cfg.Vision.NumPositionEmbeddings = int32(posW.Dim(0))
	cfg.Vision.applyDefaults()

	vm := &VisionModel{
		PatchProjection: patchProj,
		PositionEmbed:   nn.NewEmbedding(posW),
		Layers:          make([]*VisionEncoderLayer, cfg.Vision.Depth),
		cfg:             cfg.Vision,
	}

	for i := range cfg.Vision.Depth {
		layerPrefix := fmt.Sprintf("%s.blocks.%d", visionPrefix, i)
		// Each projection probes its known naming variants; missing tensors
		// simply yield nil and are validated below.
		layer := &VisionEncoderLayer{
			Norm1: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm1"),
			Norm2: loadLayerNorm(tensors, cfg.Vision.LayerNormEpsilon, layerPrefix+".norm2"),
			Attn: &VisionAttention{
				QKV: firstLinear(
					linears,
					layerPrefix+".attn.qkv",
					layerPrefix+".attn_qkv",
				),
				Query: firstLinear(
					linears,
					layerPrefix+".attn.q_proj",
					layerPrefix+".attn_q",
				),
				Key: firstLinear(
					linears,
					layerPrefix+".attn.k_proj",
					layerPrefix+".attn_k",
				),
				Value: firstLinear(
					linears,
					layerPrefix+".attn.v_proj",
					layerPrefix+".attn_v",
				),
				Output: firstLinear(
					linears,
					layerPrefix+".attn.proj",
					layerPrefix+".attn_out",
					layerPrefix+".attn.o_proj",
				),
			},
			MLP: &VisionMLP{
				FC1: firstLinear(
					linears,
					layerPrefix+".mlp.fc1",
					layerPrefix+".mlp.linear_fc1",
				),
				FC2: firstLinear(
					linears,
					layerPrefix+".mlp.fc2",
					layerPrefix+".mlp.linear_fc2",
				),
			},
		}

		if layer.Norm1 == nil || layer.Norm2 == nil {
			return nil, nil, fmt.Errorf("vision layer %d: missing norm1/norm2", i)
		}
		// Attention is valid with either a fused QKV or all three separate
		// projections; the output projection is always required.
		if layer.Attn.Output == nil || (layer.Attn.QKV == nil && (layer.Attn.Query == nil || layer.Attn.Key == nil || layer.Attn.Value == nil)) {
			return nil, nil, fmt.Errorf("vision layer %d: missing attention projections", i)
		}
		if layer.MLP.FC1 == nil || layer.MLP.FC2 == nil {
			return nil, nil, fmt.Errorf("vision layer %d: missing mlp projections", i)
		}

		vm.Layers[i] = layer
	}

	vm.PatchMerger = loadVisionPatchMerger(
		tensors,
		linears,
		cfg.Vision.LayerNormEpsilon,
		visionPrefix+".merger",
	)
	if vm.PatchMerger == nil {
		return nil, nil, fmt.Errorf("missing vision patch merger under %s", visionPrefix)
	}

	return vm, newVisionImageProcessor(cfg.Vision), nil
}
|
||||||
Reference in New Issue
Block a user