Compare commits

...

16 Commits

Author SHA1 Message Date
Eva Ho
7a3ed0a1b4 adding test 2026-04-14 15:28:40 -07:00
Eva Ho
03f9e57274 add test 2026-04-14 15:28:40 -07:00
Eva Ho
30d9100fff launch: add thinking capability detection to opencode 2026-04-14 15:28:40 -07:00
Eva H
698e04a14b launch: OpenCode inline config (#15586) 2026-04-14 15:08:42 -07:00
Eva H
1d9537bc33 launch/openclaw: fix --yes flag behaviour to skip channels configuration (#15589) 2026-04-14 13:57:35 -07:00
Eva H
120424d832 Revert "launch/opencode: use inline config (#15462)" (#15568) 2026-04-13 18:40:17 -07:00
Eva H
5818001610 launch: skip unchanged integration rewrite configuration (#15491) 2026-04-13 17:18:56 -07:00
Daniel Hiltgen
2cba7756c5 Gemma4 on MLX (#15244)
* gemma4: implement Gemma 4 model for MLX (text-only runtime)

* gemma4: two MoE + SWA prefill perf fixes

Two performance optimizations in the gemma4 forward pass

1. Memoize the sliding-window prefill mask across layers.
2. Softmax only over the selected experts in Router.Forward.

* review comments
2026-04-13 16:36:51 -07:00
Devon Rifkin
bf2a421727 gemma4: restore e2b-style nothink prompt (#15560)
Gemma 4 prompts differ when thinking is disabled for different sized
models: 26b/31b emit an empty thought block, while e2b/e4b do not.

Before #15490, our shared Gemma 4 renderer effectively matched the
e2b behavior. #15490 changed it to always emit the empty thought block,
which regressed e2b/e4b nothink behavior and led to #15536 (and possibly others).

This change restores the previous shared behavior by removing the empty
trailing thought block. It also renames the checked-in upstream chat
templates so the e2b and 31b fixtures are tracked separately.

A follow-up will split Gemma 4 rendering by model size.

Fixes: #15536
2026-04-13 14:26:15 -07:00
Eva H
f3cf6b75fb launch/opencode: use inline config (#15462) 2026-04-13 13:41:31 -07:00
Devon Rifkin
5dfac387a6 Revert "gemma4: fix nothink case renderer (#15553)" (#15556)
This reverts commit 4d75f5da03.
2026-04-13 13:12:18 -07:00
Daniel Hiltgen
a99e5d9c22 mac: prevent generate on cross-compiles (#15120)
For some versions of Xcode, cmake builds are failing due to header problems in
cross-compiling during the generate phase.  Since generate is producing arch
independent generated output, we can skip this during cross-compiling.
2026-04-13 13:04:58 -07:00
Daniel Hiltgen
0abf3aca36 cgo: suppress deprecated warning to quiet down go build (#15438) 2026-04-13 13:04:11 -07:00
Devon Rifkin
ee0266462a Revert "gemma4: add nothink renderer tests (#15554)" (#15555)
This reverts commit 1b70bb8a10.
2026-04-13 13:00:59 -07:00
Daniel Hiltgen
c88fb286ec mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA (#14913)
* mlx: add op wrappers for Conv2d, Pad, activations, trig, and masked SDPA

Add Conv2d, flexible Pad (with axes/mode), PadConstant, Maximum,
Minimum, Softplus, ReLU, GLU, Clamp, Sin, Cos, Clip,
ScaledDotProductAttentionMasked, and RoPEWithFreqs. Refactor
RoPEWithBase to delegate to RoPEWithFreqs.

* review comments

* mlx: fix ScaledDotProductAttentionMasked to consult the mask argument
2026-04-13 11:43:24 -07:00
Daniel Hiltgen
d3da29cbfc mlx: mixed-precision quant and capability detection improvements (#15409)
Improve the MLX model creation pipeline with several model-agnostic changes:

- Rewrite supportsVision to use vision_config instead of architecture name
- Add supportsAudio for audio encoder detection
- Add alignment checking (isAligned) for quantization group sizes
- Support per-projection mixed quantization in MoE expert packing
- Record per-tensor quant metadata in safetensors blobs
- Parse per-tensor quant metadata at model load time
- Validate quantize output is non-empty before storing
- Fix pin/unpin cleanup in expert group quantization
- Promote v_proj/k_proj/down_proj to INT8 for INT4 base quant
- Add MetalIsAvailable() utility
- Skip audio encoder tensors from quantization
2026-04-13 11:43:07 -07:00
28 changed files with 4760 additions and 1070 deletions

View File

@@ -14,6 +14,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
)
type stubEditorRunner struct {
@@ -722,6 +723,59 @@ func TestLauncherClientFilterDisabledCloudModels_ChecksStatusOncePerInvocation(t
}
}
// TestSavedMatchesModels verifies that savedMatchesModels treats a nil saved
// config as a mismatch and otherwise requires an exact, order-sensitive match
// between the saved model list and the requested one.
func TestSavedMatchesModels(t *testing.T) {
	cases := []struct {
		name   string
		saved  *config.IntegrationConfig
		models []string
		want   bool
	}{
		{name: "nil saved", saved: nil, models: []string{"llama3.2"}, want: false},
		{name: "identical order", saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}}, models: []string{"llama3.2", "qwen3:8b"}, want: true},
		{name: "different order", saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}}, models: []string{"qwen3:8b", "llama3.2"}, want: false},
		{name: "subset", saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}}, models: []string{"llama3.2"}, want: false},
		{name: "nil models in saved with non-nil models", saved: &config.IntegrationConfig{Models: nil}, models: []string{"llama3.2"}, want: false},
		{name: "empty both", saved: &config.IntegrationConfig{Models: nil}, models: nil, want: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := savedMatchesModels(tc.saved, tc.models)
			if got != tc.want {
				t.Fatalf("savedMatchesModels = %v, want %v", got, tc.want)
			}
		})
	}
}
func TestPrepareEditorIntegration_SavesOnlyAfterSuccessfulEdit(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"os"
"slices"
"strings"
"github.com/ollama/ollama/api"
@@ -500,7 +501,7 @@ func (c *launcherClient) launchEditorIntegration(ctx context.Context, name strin
return nil
}
if needsConfigure || req.ModelOverride != "" {
if (needsConfigure || req.ModelOverride != "") && !savedMatchesModels(saved, models) {
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
return err
}
@@ -846,6 +847,13 @@ func firstModel(models []string) string {
return models[0]
}
// savedMatchesModels reports whether a saved integration config exists and
// its model list is exactly equal — same elements, same order — to models.
// A nil saved config never matches.
func savedMatchesModels(saved *config.IntegrationConfig, models []string) bool {
	return saved != nil && slices.Equal(saved.Models, models)
}
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
if override == "" {
if saved == nil {

View File

@@ -186,6 +186,11 @@ func (c *Openclaw) runChannelSetupPreflight(bin string) error {
if !isInteractiveSession() {
return nil
}
// --yes is headless; channel setup spawns an interactive picker we can't
// auto-answer, so skip it. Users can run `openclaw channels add` later.
if currentLaunchConfirmPolicy.yes {
return nil
}
for {
if c.channelsConfigured() {

View File

@@ -1304,6 +1304,46 @@ func TestOpenclawChannelSetupPreflight(t *testing.T) {
}
})
t.Run("--yes skips preflight without channels configured", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", tmpDir)
configDir := filepath.Join(tmpDir, ".openclaw")
if err := os.MkdirAll(configDir, 0o755); err != nil {
t.Fatal(err)
}
// Empty config = no channels configured. Without the --yes skip, the
// preflight would prompt and (on confirm) spawn `openclaw channels add`.
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644); err != nil {
t.Fatal(err)
}
bin := filepath.Join(tmpDir, "openclaw")
if err := os.WriteFile(bin, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \"$HOME/invocations.log\"\n"), 0o755); err != nil {
t.Fatal(err)
}
oldInteractive := isInteractiveSession
isInteractiveSession = func() bool { return true }
defer func() { isInteractiveSession = oldInteractive }()
restore := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
defer restore()
oldConfirmPrompt := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
t.Fatalf("did not expect prompt in --yes mode: %s", prompt)
return false, nil
}
defer func() { DefaultConfirmPrompt = oldConfirmPrompt }()
if err := c.runChannelSetupPreflight("openclaw"); err != nil {
t.Fatalf("runChannelSetupPreflight() error = %v", err)
}
if _, err := os.Stat(filepath.Join(tmpDir, "invocations.log")); !os.IsNotExist(err) {
t.Fatalf("expected no channels add invocation in --yes mode, got err=%v", err)
}
})
t.Run("set up later prompts once and exits", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)

View File

@@ -1,9 +1,10 @@
package launch
import (
"context"
"encoding/json"
"fmt"
"maps"
"net/http"
"os"
"os/exec"
"path/filepath"
@@ -11,12 +12,18 @@ import (
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
modeltype "github.com/ollama/ollama/types/model"
)
// OpenCode implements Runner and Editor for OpenCode integration
type OpenCode struct{}
// OpenCode implements Runner and Editor for OpenCode integration.
// Config is passed via OPENCODE_CONFIG_CONTENT env var at launch time
// instead of writing to opencode's config files.
type OpenCode struct {
configContent string // JSON config built by Edit, passed to Run via env var
}
func (o *OpenCode) String() string { return "OpenCode" }
@@ -51,25 +58,51 @@ func (o *OpenCode) Run(model string, args []string) error {
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = os.Environ()
if content := o.resolveContent(model); content != "" {
cmd.Env = append(cmd.Env, "OPENCODE_CONFIG_CONTENT="+content)
}
return cmd.Run()
}
// resolveContent returns the inline config JSON to pass via
// OPENCODE_CONFIG_CONTENT. It prefers the content Edit already built;
// otherwise it reconstructs a config from opencode's model.json state,
// making sure the requested model is present (prepended if missing) —
// e.g. on a re-launch with a previously saved configuration.
func (o *OpenCode) resolveContent(model string) string {
	if cached := o.configContent; cached != "" {
		return cached
	}
	known := readModelJSONModels()
	if !slices.Contains(known, model) {
		known = append([]string{model}, known...)
	}
	content, err := buildInlineConfig(model, known)
	if err != nil {
		// Best effort: an empty result simply skips setting the env var.
		return ""
	}
	return content
}
func (o *OpenCode) Paths() []string {
home, err := os.UserHomeDir()
sp, err := openCodeStatePath()
if err != nil {
return nil
}
var paths []string
p := filepath.Join(home, ".config", "opencode", "opencode.json")
if _, err := os.Stat(p); err == nil {
paths = append(paths, p)
}
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
if _, err := os.Stat(sp); err == nil {
paths = append(paths, sp)
return []string{sp}
}
return paths
return nil
}
// openCodeStatePath returns the path to opencode's model state file.
// TODO: this hardcodes the Linux/macOS XDG path. On Windows, opencode stores
// state under %LOCALAPPDATA% (or similar) — verify and branch on runtime.GOOS.
func openCodeStatePath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".local", "state", "opencode", "model.json"), nil
}
func (o *OpenCode) Edit(modelList []string) error {
@@ -77,110 +110,17 @@ func (o *OpenCode) Edit(modelList []string) error {
return nil
}
home, err := os.UserHomeDir()
content, err := buildInlineConfig(modelList[0], modelList)
if err != nil {
return err
}
o.configContent = content
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
}
config["$schema"] = "https://opencode.ai/config.json"
provider, ok := config["provider"].(map[string]any)
if !ok {
provider = make(map[string]any)
}
ollama, ok := provider["ollama"].(map[string]any)
if !ok {
ollama = map[string]any{
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": map[string]any{
"baseURL": envconfig.Host().String() + "/v1",
},
}
}
// Migrate legacy provider name
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
ollama["name"] = "Ollama"
}
models, ok := ollama["models"].(map[string]any)
if !ok {
models = make(map[string]any)
}
selectedSet := make(map[string]bool)
for _, m := range modelList {
selectedSet[m] = true
}
for name, cfg := range models {
if cfgMap, ok := cfg.(map[string]any); ok {
if isOllamaModel(cfgMap) && !selectedSet[name] {
delete(models, name)
}
}
}
for _, model := range modelList {
if existing, ok := models[model].(map[string]any); ok {
// migrate existing models without _launch marker
if isOllamaModel(existing) {
existing["_launch"] = true
if name, ok := existing["name"].(string); ok {
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
}
}
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
existing["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
continue
}
entry := map[string]any{
"name": model,
"_launch": true,
}
if isCloudModelName(model) {
if l, ok := lookupCloudModelLimit(model); ok {
entry["limit"] = map[string]any{
"context": l.Context,
"output": l.Output,
}
}
}
models[model] = entry
}
ollama["models"] = models
provider["ollama"] = ollama
config["provider"] = provider
config["model"] = "ollama/" + modelList[0]
configData, err := json.MarshalIndent(config, "", " ")
// Write model state file so models appear in OpenCode's model picker
statePath, err := openCodeStatePath()
if err != nil {
return err
}
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
return err
}
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
return err
}
@@ -232,33 +172,127 @@ func (o *OpenCode) Edit(modelList []string) error {
}
func (o *OpenCode) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := fileutil.ReadJSON(filepath.Join(home, ".config", "opencode", "opencode.json"))
if err != nil {
return nil
}
provider, _ := config["provider"].(map[string]any)
ollama, _ := provider["ollama"].(map[string]any)
models, _ := ollama["models"].(map[string]any)
if len(models) == 0 {
return nil
}
keys := slices.Collect(maps.Keys(models))
slices.Sort(keys)
return keys
return nil
}
// isOllamaModel reports whether a model config entry is managed by us
func isOllamaModel(cfg map[string]any) bool {
if v, ok := cfg["_launch"].(bool); ok && v {
return true
// buildInlineConfig produces the JSON string for OPENCODE_CONFIG_CONTENT.
// primary is the model to launch with, models is the full list of available models.
func buildInlineConfig(primary string, models []string) (string, error) {
if primary == "" || len(models) == 0 {
return "", fmt.Errorf("buildInlineConfig: primary and models are required")
}
// previously used [Ollama] as a suffix for the model managed by ollama launch
if name, ok := cfg["name"].(string); ok {
return strings.HasSuffix(name, "[Ollama]")
config := map[string]any{
"$schema": "https://opencode.ai/config.json",
"provider": map[string]any{
"ollama": map[string]any{
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": map[string]any{
"baseURL": envconfig.Host().String() + "/v1",
},
"models": buildModelEntries(models),
},
},
"model": "ollama/" + primary,
}
data, err := json.Marshal(config)
if err != nil {
return "", err
}
return string(data), nil
}
// readModelJSONModels extracts the Ollama model IDs recorded in opencode's
// model.json state file. It scans the "recent" list for entries whose
// providerID is "ollama" and collects their non-empty modelIDs. Any path,
// read, or parse failure yields nil (best effort).
func readModelJSONModels() []string {
	statePath, err := openCodeStatePath()
	if err != nil {
		return nil
	}
	raw, err := os.ReadFile(statePath)
	if err != nil {
		return nil
	}
	var state map[string]any
	if err := json.Unmarshal(raw, &state); err != nil {
		return nil
	}
	recent, _ := state["recent"].([]any)
	var ids []string
	for _, item := range recent {
		m, ok := item.(map[string]any)
		if !ok || m["providerID"] != "ollama" {
			continue
		}
		if id, ok := m["modelID"].(string); ok && id != "" {
			ids = append(ids, id)
		}
	}
	return ids
}
// buildModelEntries constructs the per-model entries for the inline OpenCode
// provider config. Cloud models get context/output limits when known, and
// every entry is augmented with reasoning settings based on the model's
// thinking capability (queried from the local Ollama server).
func buildModelEntries(modelList []string) map[string]any {
	client := api.NewClient(envconfig.Host(), http.DefaultClient)
	ctx := context.Background()
	entries := make(map[string]any, len(modelList))
	for _, name := range modelList {
		e := map[string]any{
			"name": name,
		}
		if isCloudModelName(name) {
			if l, ok := lookupCloudModelLimit(name); ok {
				e["limit"] = map[string]any{
					"context": l.Context,
					"output":  l.Output,
				}
			}
		}
		applyOpenCodeReasoning(ctx, client, name, e)
		entries[name] = e
	}
	return entries
}
// applyOpenCodeReasoning detects thinking capability and sets reasoning config
// on the model entry. When the model supports thinking, it sets "reasoning": true
// and configures variants for the OpenCode TUI:
// - GPT-OSS: supports variable effort levels (low/medium/high) and defaults to
// medium via options. Thinking cannot be turned off.
// - Other models: only support on/off. Disables built-in low/medium/high variants
// and adds a "none" variant so users can toggle thinking off via Ctrl+T.
//
// When the model does not support thinking, no reasoning config is set.
func applyOpenCodeReasoning(ctx context.Context, client *api.Client, modelName string, entry map[string]any) {
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
if err != nil {
return
}
if slices.Contains(resp.Capabilities, modeltype.CapabilityThinking) {
entry["reasoning"] = true
if strings.Contains(modelName, "gpt-oss") {
// GPT-OSS models support variable thinking effort levels
// and cannot turn thinking off. Keep the built-in
// low/medium/high variants as-is and default to medium.
options, ok := entry["options"].(map[string]any)
if !ok {
options = make(map[string]any)
}
options["reasoningEffort"] = "medium"
entry["options"] = options
} else {
// Most models only support thinking on or off.
// Disable the built-in low/medium/high variants and add none.
entry["variants"] = map[string]any{
"none": map[string]any{"reasoningEffort": "none"},
"low": map[string]any{"disabled": true},
"medium": map[string]any{"disabled": true},
"high": map[string]any{"disabled": true},
}
}
}
return false
}

File diff suppressed because it is too large Load Diff

View File

@@ -26,7 +26,7 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
binary: "opencode",
runner: &OpenCode{},
checkPath: func(home string) string {
return filepath.Join(home, ".config", "opencode", "opencode.json")
return filepath.Join(home, ".local", "state", "opencode", "model.json")
},
},
{

View File

@@ -28,79 +28,4 @@ To configure without launching:
ollama launch opencode --config
```
### Manual setup
Add a configuration block to `~/.config/opencode/opencode.json`:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": {
"baseURL": "http://localhost:11434/v1"
},
"models": {
"qwen3-coder": {
"name": "qwen3-coder"
}
}
}
}
}
```
## Cloud Models
`glm-4.7:cloud` is the recommended model for use with OpenCode.
Add the cloud configuration to `~/.config/opencode/opencode.json`:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama",
"options": {
"baseURL": "http://localhost:11434/v1"
},
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud"
}
}
}
}
}
```
## Connecting to ollama.com
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
2. Update `~/.config/opencode/opencode.json` to point to ollama.com:
```json
{
"$schema": "https://opencode.ai/config.json",
"provider": {
"ollama": {
"npm": "@ai-sdk/openai-compatible",
"name": "Ollama Cloud",
"options": {
"baseURL": "https://ollama.com/v1"
},
"models": {
"glm-4.7:cloud": {
"name": "glm-4.7:cloud"
}
}
}
}
}
```
Run `opencode` in a new terminal to load the new settings.
<Note>`ollama launch opencode` passes its configuration to OpenCode inline via the `OPENCODE_CONFIG_CONTENT` environment variable. OpenCode deep-merges its config sources on startup, so anything you declare in `~/.config/opencode/opencode.json` is still respected and available inside OpenCode. Models declared only in `opencode.json` won't appear in `ollama launch`'s model-selection menu.</Note>

View File

@@ -1,6 +1,6 @@
package common
// #cgo CXXFLAGS: -std=c++17
// #cgo CXXFLAGS: -std=c++17 -Wno-deprecated-declarations
// #cgo CPPFLAGS: -I${SRCDIR}/../include -I${SRCDIR}/../vendor
// #cgo CPPFLAGS: -I${SRCDIR}/../../../ml/backend/ggml/ggml/include
import "C"

View File

@@ -40,7 +40,6 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
// Emit system turn if there's a system/developer role, tools, or thinking.
hasThink := thinkValue != nil && thinkValue.Bool()
thinkingExplicitlyDisabled := thinkValue != nil && thinkValue.IsBool() && !thinkValue.Bool()
if hasSystemRole || len(tools) > 0 || hasThink {
sb.WriteString("<|turn>system\n")
if hasThink {
@@ -125,9 +124,6 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
// Generation prompt.
if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
sb.WriteString("<|turn>model\n")
if !hasThink && !thinkingExplicitlyDisabled {
sb.WriteString("<|channel>thought\n<channel|>")
}
}
return sb.String(), nil

View File

@@ -1,14 +1,18 @@
package renderers
// TestGemma4RendererMatchesReference verifies our renderer matches the HF
// Jinja2 chat template exactly.
// TestGemma4RendererMatchesReference verifies our renderer matches the checked-in
// Gemma 4 reference template.
//
// To regenerate expected values, save gemma4Jinja2Template (below) to
// gemma4_chat_template.jinja2 and run:
// Current upstream Gemma 4 chat templates differ by model size, so the checked-in
// reference intentionally uses the shared baseline without an empty generation-time
// thought channel until renderer selection is split by size.
//
// To regenerate expected values, save the E2B template to
// gemma4_e2b_chat_template.jinja2 and run:
//
// python3 -c "
// from jinja2 import Environment; import json
// tmpl = Environment().from_string(open('gemma4_chat_template.jinja2').read())
// tmpl = Environment().from_string(open('gemma4_e2b_chat_template.jinja2').read())
// msgs = [{'role':'user','content':'Hello'}]
// print(repr(tmpl.render(messages=msgs, bos_token='<bos>', add_generation_prompt=True)))
// "
@@ -26,8 +30,13 @@ import (
"github.com/stretchr/testify/assert"
)
// The full Jinja2 template is committed as testdata/gemma4_chat_template.jinja2.
// Run with VERIFY_JINJA2=1 to verify expected values against the template using uv + Python.
const (
gemma4E2BTemplate = "testdata/gemma4_e2b_chat_template.jinja2"
gemma431BTemplate = "testdata/gemma4_31b_chat_template.jinja2"
)
// The upstream Gemma 4 chat templates are committed by size under testdata/.
// Run with VERIFY_JINJA2=1 to verify expected values against the E2B template using uv + Python.
func bashRefTool() []api.Tool {
return []api.Tool{{
@@ -665,7 +674,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
{
name: "user_only",
messages: []api.Message{{Role: "user", Content: "Hello"}},
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
},
{
name: "system_user",
@@ -673,7 +682,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hi"},
},
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
name: "developer_user",
@@ -681,13 +690,13 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
{Role: "developer", Content: "You are helpful."},
{Role: "user", Content: "Hi"},
},
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
name: "tools_no_system",
messages: []api.Message{{Role: "user", Content: "Hi"}},
tools: bashRefTool(),
expected: "<bos><|turn>system\n" + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
expected: "<bos><|turn>system\n" + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
name: "system_tools",
@@ -696,7 +705,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
{Role: "user", Content: "Hi"},
},
tools: bashRefTool(),
expected: "<bos><|turn>system\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
expected: "<bos><|turn>system\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
name: "thinking_no_system",
@@ -704,13 +713,6 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
think: thinkTrue(),
expected: "<bos><|turn>system\n<|think|>\n<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
name: "nothink_no_system",
messages: []api.Message{{Role: "user", Content: "Hi"}},
think: thinkFalse(),
expected: "<bos><|turn>user\nHi<turn|>\n<|turn>model\n",
skipJinja2: true,
},
{
name: "thinking_system",
messages: []api.Message{
@@ -737,6 +739,12 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
think: thinkTrue(),
expected: "<bos><|turn>system\n<|think|>\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
name: "thinking_explicitly_disabled",
messages: []api.Message{{Role: "user", Content: "Hi"}},
think: thinkFalse(),
expected: "<bos><|turn>user\nHi<turn|>\n<|turn>model\n",
},
// === Message loop paths ===
{
@@ -751,7 +759,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
"<|turn>user\nHi<turn|>\n" +
"<|turn>model\nHello!<turn|>\n" +
"<|turn>user\nMore<turn|>\n" +
"<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>model\n",
},
{
// Tool call with structured args → tool response as separate <|turn>tool turn
@@ -813,7 +821,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
"<|tool_response>response:bash{value:" + q + "file1.txt\nfile2.txt" + q + "}<tool_response|>" +
"Here are the files.<turn|>\n" +
"<|turn>user\nRead file1.txt<turn|>\n" +
"<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>model\n",
},
{
// Multiple tool calls + multiple tool responses
@@ -848,7 +856,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
expected: "<bos><|turn>user\nWhat is 2+2?<turn|>\n" +
"<|turn>model\n4<turn|>\n" +
"<|turn>user\nAnd 3+3?<turn|>\n" +
"<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>model\n",
},
// === Additional edge cases ported from original tests ===
{
@@ -906,17 +914,17 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
messages: []api.Message{{Role: "user", Content: "Test"}},
tools: modeTool(),
expected: "<bos><|turn>system\n" + modeDeclRef + "<turn|>\n" +
"<|turn>user\nTest<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nTest<turn|>\n<|turn>model\n",
},
{
name: "unicode_content",
messages: []api.Message{{Role: "user", Content: "こんにちは"}},
expected: "<bos><|turn>user\nこんにちは<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
expected: "<bos><|turn>user\nこんにちは<turn|>\n<|turn>model\n",
},
{
name: "newlines_in_content",
messages: []api.Message{{Role: "user", Content: "Line 1\nLine 2\nLine 3"}},
expected: "<bos><|turn>user\nLine 1\nLine 2\nLine 3<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
expected: "<bos><|turn>user\nLine 1\nLine 2\nLine 3<turn|>\n<|turn>model\n",
},
{
// Tool response (raw JSON) followed by user message
@@ -935,7 +943,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
"<|turn>model\n<|tool_call>call:get_weather{city:" + q + "Tokyo" + q + "}<tool_call|>" +
"<|tool_response>response:get_weather{value:" + q + `{"temperature": 15, "weather": "sunny"}` + q + "}<tool_response|>" +
"<|turn>user\nThanks!<turn|>\n" +
"<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>model\n",
},
// === Ordering and whitespace edge cases ===
{
@@ -958,7 +966,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
// User content with whitespace is trimmed
name: "user_content_trimmed",
messages: []api.Message{{Role: "user", Content: " hello "}},
expected: "<bos><|turn>user\nhello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
expected: "<bos><|turn>user\nhello<turn|>\n<|turn>model\n",
},
{
// Empty tool call arguments
@@ -982,7 +990,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
messages: []api.Message{{Role: "user", Content: "Create"}},
tools: nestedTool(),
expected: "<bos><|turn>system\n" + nestedDeclRef + "<turn|>\n" +
"<|turn>user\nCreate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nCreate<turn|>\n<|turn>model\n",
},
{
// Array type in tool declaration
@@ -990,7 +998,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
messages: []api.Message{{Role: "user", Content: "Batch"}},
tools: arrayTool(),
expected: "<bos><|turn>system\n" + arrayDeclRef + "<turn|>\n" +
"<|turn>user\nBatch<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nBatch<turn|>\n<|turn>model\n",
},
{
// Top-level typed union follows the template's odd stringified-list form.
@@ -1002,8 +1010,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
<|turn>user
Hi<turn|>
<|turn>model
<|channel>thought
<channel|>`,
`,
},
{
// Assistant whitespace is trimmed (strip_thinking includes | trim)
@@ -1016,7 +1023,7 @@ Hi<turn|>
expected: "<bos><|turn>user\nHi<turn|>\n" +
"<|turn>model\nspaced<turn|>\n" +
"<|turn>user\nMore<turn|>\n" +
"<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>model\n",
},
{
// Three sequential tool responses
@@ -1071,7 +1078,7 @@ Hi<turn|>
expected: "<bos><|turn>user\nHi<turn|>\n" +
"<|turn>model\nMiddleDone<turn|>\n" +
"<|turn>user\nMore<turn|>\n" +
"<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>model\n",
},
{
// Property with no description — just type
@@ -1079,7 +1086,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Count"}},
tools: countTool(),
expected: "<bos><|turn>system\n" + countDeclRef + "<turn|>\n" +
"<|turn>user\nCount<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nCount<turn|>\n<|turn>model\n",
},
{
// System message with leading/trailing whitespace is trimmed
@@ -1089,7 +1096,7 @@ Hi<turn|>
{Role: "user", Content: "Hi"},
},
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n" +
"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
// Deeply nested map in tool call arguments (3 levels)
@@ -1151,7 +1158,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Set"}},
tools: enumNoDescTool(),
expected: "<bos><|turn>system\n" + enumNoDescDeclRef + "<turn|>\n" +
"<|turn>user\nSet<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nSet<turn|>\n<|turn>model\n",
},
{
// System message that is only whitespace (trims to empty)
@@ -1161,7 +1168,7 @@ Hi<turn|>
{Role: "user", Content: "Hi"},
},
expected: "<bos><|turn>system\n<turn|>\n" +
"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
// Empty assistant content (empty string, not nil)
@@ -1174,7 +1181,7 @@ Hi<turn|>
expected: "<bos><|turn>user\nHi<turn|>\n" +
"<|turn>model\n<turn|>\n" +
"<|turn>user\nMore<turn|>\n" +
"<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>model\n",
},
{
// Map argument with string keys (keys NOT escaped with <|"|>)
@@ -1200,7 +1207,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Search"}},
tools: searchTool(),
expected: "<bos><|turn>system\n" + searchDeclRef + "<turn|>\n" +
"<|turn>user\nSearch<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nSearch<turn|>\n<|turn>model\n",
},
// === Round 3 coverage gaps ===
@@ -1228,7 +1235,7 @@ Hi<turn|>
{Role: "user", Content: "Hi"},
},
expected: "<bos><|turn>system\n<turn|>\n" +
"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nHi<turn|>\n<|turn>model\n",
},
{
// Nested OBJECT property with required field
@@ -1236,7 +1243,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Create"}},
tools: nestedRequiredTool(),
expected: "<bos><|turn>system\n" + nestedRequiredDeclRef + "<turn|>\n" +
"<|turn>user\nCreate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nCreate<turn|>\n<|turn>model\n",
},
{
// Non-integer float in tool call argument
@@ -1263,7 +1270,7 @@ Hi<turn|>
},
expected: "<bos><|turn>user\nHi<turn|>\n" +
"<|turn>model\nResult<turn|>\n" +
"<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>model\n",
},
{
// Tool content with newlines and leading/trailing whitespace trimmed
@@ -1287,7 +1294,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Raw"}},
tools: rawTool(),
expected: "<bos><|turn>system\n" + rawDeclRef + "<turn|>\n" +
"<|turn>user\nRaw<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nRaw<turn|>\n<|turn>model\n",
},
{
// Multiple required fields at top level
@@ -1295,7 +1302,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Move"}},
tools: moveTool(),
expected: "<bos><|turn>system\n" + moveDeclRef + "<turn|>\n" +
"<|turn>user\nMove<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nMove<turn|>\n<|turn>model\n",
},
{
// Assistant content that is ONLY thinking (strips to empty)
@@ -1308,7 +1315,7 @@ Hi<turn|>
expected: "<bos><|turn>user\nHi<turn|>\n" +
"<|turn>model\n<turn|>\n" +
"<|turn>user\nMore<turn|>\n" +
"<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>model\n",
},
// === Round 4: final coverage gaps ===
@@ -1341,7 +1348,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Tag"}},
tools: arrayNoItemsTool(),
expected: "<bos><|turn>system\n" + arrayNoItemsDeclRef + "<turn|>\n" +
"<|turn>user\nTag<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nTag<turn|>\n<|turn>model\n",
},
{
// OBJECT property without description but with nested properties
@@ -1349,7 +1356,7 @@ Hi<turn|>
messages: []api.Message{{Role: "user", Content: "Update"}},
tools: objectNoDescTool(),
expected: "<bos><|turn>system\n" + objectNoDescDeclRef + "<turn|>\n" +
"<|turn>user\nUpdate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>user\nUpdate<turn|>\n<|turn>model\n",
},
// === Round 5: coding agent patterns ===
@@ -1379,7 +1386,7 @@ Hi<turn|>
"<|tool_response>response:bash{value:" + q + q + "}<tool_response|>" +
"Done.<turn|>\n" +
"<|turn>user\nThanks<turn|>\n" +
"<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>model\n",
},
{
// Tool call with thinking that strips to real remaining content
@@ -1399,7 +1406,7 @@ Hi<turn|>
"<|tool_response>response:bash{value:" + q + "main.go\ngo.mod" + q + "}<tool_response|>" +
"Let me list the files.<turn|>\n" +
"<|turn>user\nOK<turn|>\n" +
"<|turn>model\n<|channel>thought\n<channel|>",
"<|turn>model\n",
},
{
// Argument value containing newlines (multi-line script)
@@ -1635,7 +1642,6 @@ func TestGemma4RendererKnownJinja2Differences(t *testing.T) {
name string
messages []api.Message
tools []api.Tool
think *api.ThinkValue
wantJinjaFrag string
wantRenderFrag string
}{
@@ -1684,22 +1690,15 @@ func TestGemma4RendererKnownJinja2Differences(t *testing.T) {
wantJinjaFrag: `response:read{value:<|"|>payload<|"|>}`,
wantRenderFrag: `response:unknown{value:<|"|>payload<|"|>}`,
},
{
name: "explicit_nothink_skips_empty_thought_channel",
messages: []api.Message{{Role: "user", Content: "Hi"}},
think: thinkFalse(),
wantJinjaFrag: "<|turn>model\n<|channel>thought\n<channel|>",
wantRenderFrag: "<|turn>model\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
renderer := &Gemma4Renderer{useImgTags: RenderImgTags}
got, err := renderer.Render(tt.messages, tt.tools, tt.think)
got, err := renderer.Render(tt.messages, tt.tools, nil)
assert.NoError(t, err)
jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, tt.think)
jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, nil)
assert.NotEqual(t, jinja2Output, got, "case no longer differs from Jinja2 output")
assert.Contains(t, jinja2Output, tt.wantJinjaFrag)
assert.Contains(t, got, tt.wantRenderFrag)
@@ -1735,12 +1734,35 @@ func TestGemma4RendererToolResponseWithoutNameOrIDUsesUnknown(t *testing.T) {
assert.NotContains(t, got, `response:read{value:<|"|>payload<|"|>}`)
}
// TestGemma4SizeTemplateFixturesDifferAtGenerationPrompt verifies the
// size-specific Gemma 4 template fixtures: both emit the model turn header at
// the generation prompt, but only the 31B fixture appends the empty thought
// channel block (the e2b-style fixture omits it).
func TestGemma4SizeTemplateFixturesDifferAtGenerationPrompt(t *testing.T) {
	// read loads a fixture or aborts the test with the original failure message.
	read := func(path string) string {
		data, err := os.ReadFile(path)
		if err != nil {
			t.Fatalf("failed to read %s: %v", path, err)
		}
		return string(data)
	}
	const (
		modelTurn    = "{{- '<|turn>model\\n' -}}"
		thoughtBlock = "{{- '<|channel>thought\\n<channel|>' -}}"
	)
	small := read(gemma4E2BTemplate)
	large := read(gemma431BTemplate)
	assert.Contains(t, small, modelTurn)
	assert.NotContains(t, small, thoughtBlock)
	assert.Contains(t, large, modelTurn)
	assert.Contains(t, large, thoughtBlock)
}
// renderWithJinja2 shells out to uv + Python to render messages through the
// E2B Jinja2 chat template. Returns the rendered string.
func renderWithJinja2(t *testing.T, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string {
	// Thin wrapper: all size-specific rendering goes through
	// renderWithJinja2Template with the E2B fixture as the default template.
	return renderWithJinja2Template(t, gemma4E2BTemplate, messages, tools, think)
}
// renderWithJinja2Template shells out to uv + Python to render messages through
// the named Jinja2 chat template. Returns the rendered string.
func renderWithJinja2Template(t *testing.T, templateRelPath string, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string {
t.Helper()
templatePath, err := filepath.Abs("testdata/gemma4_chat_template.jinja2")
templatePath, err := filepath.Abs(templateRelPath)
if err != nil {
t.Fatalf("failed to get template path: %v", err)
}

View File

@@ -0,0 +1,344 @@
{%- macro format_parameters(properties, required) -%}
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in properties | dictsort -%}
{%- set add_comma = false -%}
{%- if key not in standard_keys -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{{ key }}:{
{%- if value['description'] -%}
description:<|"|>{{ value['description'] }}<|"|>
{%- set add_comma = true -%}
{%- endif -%}
{%- if value['type'] | upper == 'STRING' -%}
{%- if value['enum'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
enum:{{ format_argument(value['enum']) }}
{%- endif -%}
{%- elif value['type'] | upper == 'ARRAY' -%}
{%- if value['items'] is mapping and value['items'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
items:{
{%- set ns_items = namespace(found_first=false) -%}
{%- for item_key, item_value in value['items'] | dictsort -%}
{%- if item_value is not none -%}
{%- if ns_items.found_first %},{% endif -%}
{%- set ns_items.found_first = true -%}
{%- if item_key == 'properties' -%}
properties:{
{%- if item_value is mapping -%}
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
{%- endif -%}
}
{%- elif item_key == 'required' -%}
required:[
{%- for req_item in item_value -%}
<|"|>{{- req_item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- elif item_key == 'type' -%}
{%- if item_value is string -%}
type:{{ format_argument(item_value | upper) }}
{%- else -%}
type:{{ format_argument(item_value | map('upper') | list) }}
{%- endif -%}
{%- else -%}
{{ item_key }}:{{ format_argument(item_value) }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
}
{%- endif -%}
{%- endif -%}
{%- if value['nullable'] %}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
nullable:true
{%- endif -%}
{%- if value['type'] | upper == 'OBJECT' -%}
{%- if value['properties'] is defined and value['properties'] is mapping -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
properties:{
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
}
{%- elif value is mapping -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
properties:{
{{- format_parameters(value, value['required'] | default([])) -}}
}
{%- endif -%}
{%- if value['required'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
required:[
{%- for item in value['required'] | default([]) -%}
<|"|>{{- item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- endif -%}
{%- endif -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
type:<|"|>{{ value['type'] | upper }}<|"|>}
{%- endif -%}
{%- endfor -%}
{%- endmacro -%}
{%- macro format_function_declaration(tool_data) -%}
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
{%- set params = tool_data['function']['parameters'] -%}
{%- if params -%}
,parameters:{
{%- if params['properties'] -%}
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
{%- endif -%}
{%- if params['required'] -%}
required:[
{%- for item in params['required'] -%}
<|"|>{{- item -}}<|"|>
{{- ',' if not loop.last -}}
{%- endfor -%}
],
{%- endif -%}
{%- if params['type'] -%}
type:<|"|>{{- params['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
{%- if 'response' in tool_data['function'] -%}
{%- set response_declaration = tool_data['function']['response'] -%}
,response:{
{%- if response_declaration['description'] -%}
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
{%- endif -%}
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
}
{%- endmacro -%}
{%- macro format_argument(argument, escape_keys=True) -%}
{%- if argument is string -%}
{{- '<|"|>' + argument + '<|"|>' -}}
{%- elif argument is boolean -%}
{{- 'true' if argument else 'false' -}}
{%- elif argument is mapping -%}
{{- '{' -}}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in argument | dictsort -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{%- if escape_keys -%}
{{- '<|"|>' + key + '<|"|>' -}}
{%- else -%}
{{- key -}}
{%- endif -%}
:{{- format_argument(value, escape_keys=escape_keys) -}}
{%- endfor -%}
{{- '}' -}}
{%- elif argument is sequence -%}
{{- '[' -}}
{%- for item in argument -%}
{{- format_argument(item, escape_keys=escape_keys) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- ']' -}}
{%- else -%}
{{- argument -}}
{%- endif -%}
{%- endmacro -%}
{%- macro strip_thinking(text) -%}
{%- set ns = namespace(result='') -%}
{%- for part in text.split('<channel|>') -%}
{%- if '<|channel>' in part -%}
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
{%- else -%}
{%- set ns.result = ns.result + part -%}
{%- endif -%}
{%- endfor -%}
{{- ns.result | trim -}}
{%- endmacro -%}
{%- macro format_tool_response_block(tool_name, response) -%}
{{- '<|tool_response>' -}}
{%- if response is mapping -%}
{{- 'response:' + tool_name + '{' -}}
{%- for key, value in response | dictsort -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- '}' -}}
{%- else -%}
{{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}}
{%- endif -%}
{{- '<tool_response|>' -}}
{%- endmacro -%}
{%- set ns = namespace(prev_message_type=None) -%}
{%- set loop_messages = messages -%}
{{- bos_token -}}
{#- Handle System/Tool Definitions Block -#}
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
{{- '<|turn>system\n' -}}
{#- Inject Thinking token at the very top of the FIRST system turn -#}
{%- if enable_thinking is defined and enable_thinking -%}
{{- '<|think|>\n' -}}
{%- set ns.prev_message_type = 'think' -%}
{%- endif -%}
{%- if messages[0]['role'] in ['system', 'developer'] -%}
{{- messages[0]['content'] | trim -}}
{%- set loop_messages = messages[1:] -%}
{%- endif -%}
{%- if tools -%}
{%- for tool in tools %}
{{- '<|tool>' -}}
{{- format_function_declaration(tool) | trim -}}
{{- '<tool|>' -}}
{%- endfor %}
{%- set ns.prev_message_type = 'tool' -%}
{%- endif -%}
{{- '<turn|>\n' -}}
{%- endif %}
{#- Pre-scan: find last user message index for reasoning guard -#}
{%- set ns_turn = namespace(last_user_idx=-1) -%}
{%- for i in range(loop_messages | length) -%}
{%- if loop_messages[i]['role'] == 'user' -%}
{%- set ns_turn.last_user_idx = i -%}
{%- endif -%}
{%- endfor -%}
{#- Loop through messages -#}
{%- for message in loop_messages -%}
{%- if message['role'] != 'tool' -%}
{%- set ns.prev_message_type = None -%}
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
{#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#}
{%- set prev_nt = namespace(role=None, found=false) -%}
{%- if loop.index0 > 0 -%}
{%- for j in range(loop.index0 - 1, -1, -1) -%}
{%- if not prev_nt.found -%}
{%- if loop_messages[j]['role'] != 'tool' -%}
{%- set prev_nt.role = loop_messages[j]['role'] -%}
{%- set prev_nt.found = true -%}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%}
{%- if not continue_same_model_turn -%}
{{- '<|turn>' + role + '\n' }}
{%- endif -%}
{#- Render reasoning/reasoning_content as thinking channel -#}
{%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%}
{%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%}
{{- '<|channel>thought\n' + thinking_text + '\n<channel|>' -}}
{%- endif -%}
{%- if message['tool_calls'] -%}
{%- for tool_call in message['tool_calls'] -%}
{%- set function = tool_call['function'] -%}
{{- '<|tool_call>call:' + function['name'] + '{' -}}
{%- if function['arguments'] is mapping -%}
{%- set ns_args = namespace(found_first=false) -%}
{%- for key, value in function['arguments'] | dictsort -%}
{%- if ns_args.found_first %},{% endif -%}
{%- set ns_args.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{%- elif function['arguments'] is string -%}
{{- function['arguments'] -}}
{%- endif -%}
{{- '}<tool_call|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_call' -%}
{%- endif -%}
{%- set ns_tr_out = namespace(flag=false) -%}
{%- if message.get('tool_responses') -%}
{#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#}
{%- for tool_response in message['tool_responses'] -%}
{{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}}
{%- set ns_tr_out.flag = true -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endfor -%}
{%- elif message.get('tool_calls') -%}
{#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#}
{%- set ns_tool_scan = namespace(stopped=false) -%}
{%- for k in range(loop.index0 + 1, loop_messages | length) -%}
{%- if ns_tool_scan.stopped -%}
{%- elif loop_messages[k]['role'] != 'tool' -%}
{%- set ns_tool_scan.stopped = true -%}
{%- else -%}
{%- set follow = loop_messages[k] -%}
{#- Resolve tool_call_id to function name -#}
{%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%}
{%- for tc in message['tool_calls'] -%}
{%- if tc.get('id') == follow.get('tool_call_id') -%}
{%- set ns_tname.name = tc['function']['name'] -%}
{%- endif -%}
{%- endfor -%}
{#- Handle content as string or content-parts array -#}
{%- set tool_body = follow.get('content') -%}
{%- if tool_body is string -%}
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
{%- elif tool_body is sequence and tool_body is not string -%}
{%- set ns_txt = namespace(s='') -%}
{%- for part in tool_body -%}
{%- if part.get('type') == 'text' -%}
{%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%}
{%- endif -%}
{%- endfor -%}
{{- format_tool_response_block(ns_tname.name, ns_txt.s) -}}
{%- else -%}
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
{%- endif -%}
{%- set ns_tr_out.flag = true -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if message['content'] is string -%}
{%- if role == 'model' -%}
{{- strip_thinking(message['content']) -}}
{%- else -%}
{{- message['content'] | trim -}}
{%- endif -%}
{%- elif message['content'] is sequence -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'text' -%}
{%- if role == 'model' -%}
{{- strip_thinking(item['text']) -}}
{%- else -%}
{{- item['text'] | trim -}}
{%- endif -%}
{%- elif item['type'] == 'image' -%}
{{- '<|image|>' -}}
{%- set ns.prev_message_type = 'image' -%}
{%- elif item['type'] == 'audio' -%}
{{- '<|audio|>' -}}
{%- set ns.prev_message_type = 'audio' -%}
{%- elif item['type'] == 'video' -%}
{{- '<|video|>' -}}
{%- set ns.prev_message_type = 'video' -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%}
{{- '<|tool_response>' -}}
{%- elif not (ns_tr_out.flag and not message.get('content')) -%}
{{- '<turn|>\n' -}}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%}
{{- '<|turn>model\n' -}}
{%- endif -%}
{%- endif -%}

View File

@@ -191,6 +191,10 @@ func inferSafetensorsCapabilities(modelDir string) []string {
capabilities = append(capabilities, "vision")
}
if supportsAudio(modelDir) {
capabilities = append(capabilities, "audio")
}
if supportsThinking(modelDir) {
capabilities = append(capabilities, "thinking")
}
@@ -496,32 +500,38 @@ func supportsThinking(modelDir string) bool {
return false
}
// supportsVision checks if the model supports image input based on its architecture.
// Qwen3.5 multimodal checkpoints are published as ConditionalGeneration architectures.
// supportsVision checks if the model has a vision encoder by looking for
// vision_config in config.json.
func supportsVision(modelDir string) bool {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
if err != nil {
return false
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
VisionConfig *map[string]any `json:"vision_config"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "qwen3") && strings.Contains(archLower, "conditionalgeneration") {
return true
}
return cfg.VisionConfig != nil
}
func supportsAudio(modelDir string) bool {
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
if err != nil {
return false
}
typeLower := strings.ToLower(cfg.ModelType)
return strings.Contains(typeLower, "qwen3") && strings.Contains(typeLower, "conditionalgeneration")
var cfg struct {
AudioConfig *map[string]any `json:"audio_config"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
return cfg.AudioConfig != nil
}
// getParserName returns the parser name for a model based on its architecture.
@@ -550,6 +560,9 @@ func getParserName(modelDir string) string {
if strings.Contains(archLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(archLower, "gemma4") {
return "gemma4"
}
if strings.Contains(archLower, "qwen3") {
return "qwen3"
}
@@ -564,6 +577,9 @@ func getParserName(modelDir string) string {
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "gemma4") {
return "gemma4"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3"
}
@@ -592,6 +608,9 @@ func getRendererName(modelDir string) string {
// Check architectures for known renderers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "gemma4") {
return "gemma4"
}
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
return "glm-4.7"
}
@@ -606,6 +625,9 @@ func getRendererName(modelDir string) string {
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "gemma4") {
return "gemma4"
}
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}

View File

@@ -311,10 +311,30 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
name: "qwen3.5 multimodal model",
configJSON: `{
"architectures": ["Qwen3_5ForConditionalGeneration"],
"model_type": "qwen3"
"model_type": "qwen3",
"vision_config": {"hidden_size": 1024}
}`,
want: []string{"completion", "vision", "thinking"},
},
{
name: "model with audio config",
configJSON: `{
"architectures": ["Gemma4ForConditionalGeneration"],
"model_type": "gemma4",
"vision_config": {"hidden_size": 1024},
"audio_config": {"num_mel_bins": 128}
}`,
want: []string{"completion", "vision", "audio"},
},
{
name: "model with audio but no vision",
configJSON: `{
"architectures": ["SomeAudioModel"],
"model_type": "other",
"audio_config": {"num_mel_bins": 128}
}`,
want: []string{"completion", "audio"},
},
{
name: "non-qwen conditional generation model",
configJSON: `{
@@ -339,6 +359,74 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
}
}
// TestParsePerExpertInputs covers how per-expert tensor inputs are grouped by
// projection and how their quantization types are validated: quant types may
// differ across projections but must be uniform within a single projection,
// and only ".experts" groups with per-expert tensor names are eligible.
func TestParsePerExpertInputs(t *testing.T) {
	// in builds a single packed-tensor input for the table below.
	in := func(name, quantize string) create.PackedTensorInput {
		return create.PackedTensorInput{Name: name, Quantize: quantize}
	}

	t.Run("uniform quant across projections", func(t *testing.T) {
		groups, projQ := parsePerExpertInputs("layer.moe.experts", []create.PackedTensorInput{
			in("layer.moe.experts.0.gate_proj.weight", "int4"),
			in("layer.moe.experts.1.gate_proj.weight", "int4"),
			in("layer.moe.experts.0.down_proj.weight", "int4"),
			in("layer.moe.experts.1.down_proj.weight", "int4"),
		})
		if groups == nil {
			t.Fatal("expected non-nil groups")
		}
		if got := len(groups); got != 2 {
			t.Fatalf("expected 2 projection groups, got %d", got)
		}
		if got := projQ["gate_proj.weight"]; got != "int4" {
			t.Errorf("gate_proj quant = %q, want int4", got)
		}
		if got := projQ["down_proj.weight"]; got != "int4" {
			t.Errorf("down_proj quant = %q, want int4", got)
		}
	})

	t.Run("mixed quant across projections", func(t *testing.T) {
		// Different quant types per projection are allowed.
		groups, projQ := parsePerExpertInputs("layer.moe.experts", []create.PackedTensorInput{
			in("layer.moe.experts.0.gate_proj.weight", "int4"),
			in("layer.moe.experts.1.gate_proj.weight", "int4"),
			in("layer.moe.experts.0.down_proj.weight", "int8"),
			in("layer.moe.experts.1.down_proj.weight", "int8"),
		})
		if groups == nil {
			t.Fatal("expected non-nil groups for mixed cross-projection quant")
		}
		if got := projQ["gate_proj.weight"]; got != "int4" {
			t.Errorf("gate_proj quant = %q, want int4", got)
		}
		if got := projQ["down_proj.weight"]; got != "int8" {
			t.Errorf("down_proj quant = %q, want int8", got)
		}
	})

	t.Run("mixed quant within same projection rejected", func(t *testing.T) {
		// Two experts of the same projection with different quant types
		// must disqualify the whole group.
		groups, _ := parsePerExpertInputs("layer.moe.experts", []create.PackedTensorInput{
			in("layer.moe.experts.0.down_proj.weight", "int4"),
			in("layer.moe.experts.1.down_proj.weight", "int8"),
		})
		if groups != nil {
			t.Fatal("expected nil for mixed quant within same projection")
		}
	})

	t.Run("non-experts group rejected", func(t *testing.T) {
		// Group names not ending in ".experts" are left unpacked.
		groups, _ := parsePerExpertInputs("layer.mlp", []create.PackedTensorInput{
			in("layer.mlp.gate_proj.weight", "int4"),
		})
		if groups != nil {
			t.Fatal("expected nil for non-experts group")
		}
	})
}
func TestQuantizeSupported(t *testing.T) {
// This just verifies the function exists and returns a boolean
// The actual value depends on build tags (mlx vs non-mlx)

View File

@@ -97,6 +97,20 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
groupSize, bits, mode := model.QuantizationParams(quantize)
qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
// Validate quantization produced non-empty output. MLX quantize may return
// empty arrays for unsupported mode/bits combinations without raising an error.
mlx.Eval(qweight, scales)
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
name, quantize, groupSize, bits, mode)
}
if len(scales.Dims()) == 0 || scales.Dims()[0] == 0 {
st.Free()
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty scales for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
name, quantize, groupSize, bits, mode)
}
qweight = mlx.Contiguous(qweight, false)
scales = mlx.Contiguous(scales, false)
arrays[name] = qweight
@@ -174,8 +188,8 @@ func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quanti
// Returns the blob bytes.
func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
// Check if inputs are per-expert tensors that should be stacked into 3D
if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
return stackAndQuantizeExpertGroup(groupName, projGroups, quantize)
if projGroups, projQuantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
return stackAndQuantizeExpertGroup(groupName, projGroups, projQuantize)
}
allArrays := make(map[string]*mlx.Array)
@@ -224,6 +238,17 @@ func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([
mlx.Pin(finalArrays...)
pinned = append(pinned, finalArrays...)
// Record per-tensor quant type so the model can resolve params at load time.
if input.Quantize != "" {
if groupSize, _, _ := model.QuantizationParams(input.Quantize); groupSize > 0 {
if metadata == nil {
metadata = make(map[string]string)
}
metadata[input.Name+".quant_type"] = input.Quantize
metadata[input.Name+".group_size"] = strconv.Itoa(groupSize)
}
}
if st != nil {
st.Free()
}
@@ -279,57 +304,60 @@ type expertTensorInfo struct {
}
// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
// and returns the uniform quantization type shared by all inputs.
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D)
// or if the inputs have mixed quantization types.
// and returns per-projection quantization types. Different projections may use
// different quant types (e.g., gate_up=int4, down=int8) but all experts within
// a projection must share the same type.
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D).
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) {
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, map[string]string) {
if !strings.HasSuffix(groupName, ".experts") {
return nil, ""
return nil, nil
}
quantize := inputs[0].Quantize
groups := make(map[string][]expertTensorInfo)
projQuantize := make(map[string]string) // projection -> quant type
for _, input := range inputs {
if input.Quantize != quantize {
return nil, "" // mixed quantization types
}
suffix := strings.TrimPrefix(input.Name, groupName)
m := perExpertSuffix.FindStringSubmatch(suffix)
if m == nil {
return nil, "" // not a per-expert pattern
return nil, nil // not a per-expert pattern
}
index, err := strconv.Atoi(m[1])
if err != nil {
return nil, ""
return nil, nil
}
groups[m[2]] = append(groups[m[2]], expertTensorInfo{
proj := m[2]
if existing, ok := projQuantize[proj]; ok {
if input.Quantize != existing {
return nil, nil // mixed quant within same projection
}
} else {
projQuantize[proj] = input.Quantize
}
groups[proj] = append(groups[proj], expertTensorInfo{
index: index,
proj: m[2],
proj: proj,
input: input,
})
}
if len(groups) == 0 {
return nil, ""
return nil, nil
}
return groups, quantize
return groups, projQuantize
}
// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) {
// projQuantize maps projection name to its quantization type (may differ per projection).
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, projQuantize map[string]string) ([]byte, error) {
groupBase := strings.TrimSuffix(groupName, ".experts")
allArrays := make(map[string]*mlx.Array)
var pinned []*mlx.Array
var metadata map[string]string
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" {
metadata = map[string]string{
"quant_type": quantize,
"group_size": strconv.Itoa(groupSize),
}
}
// Build metadata: if all projections use the same quant type, set global metadata.
// Otherwise record per-tensor quant info.
metadata := make(map[string]string)
// Sort projection names for deterministic output
projNames := make([]string, 0, len(projGroups))
@@ -339,7 +367,11 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
sort.Strings(projNames)
cleanup := func() {
mlx.Unpin(pinned...)
for _, p := range pinned {
if p != nil {
mlx.Unpin(p)
}
}
mlx.Sweep()
}
@@ -382,11 +414,27 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
mlx.Pin(stacked)
pinned = append(pinned, stacked)
// Free individual decoded arrays
// Free individual decoded arrays (remove from pinned to avoid double-unpin in cleanup)
for i, p := range pinned {
for _, d := range decoded {
if p == d {
pinned[i] = nil
}
}
}
mlx.Unpin(decoded...)
mlx.Sweep()
stackedName := groupBase + ".switch_mlp." + proj
quantize := projQuantize[proj]
// Record per-tensor quant metadata so the model can resolve params at load time.
if quantize != "" {
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 {
metadata[stackedName+".quant_type"] = quantize
metadata[stackedName+".group_size"] = strconv.Itoa(groupSize)
}
}
// Quantize the stacked tensor
if quantize != "" {
@@ -394,6 +442,14 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)
// Validate quantization produced non-empty output.
mlx.Eval(qweight, scales)
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
cleanup()
return nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
stackedName, quantize, groupSize, bits, mode)
}
qweight = mlx.Contiguous(qweight, false)
scales = mlx.Contiguous(scales, false)
allArrays[stackedName] = qweight
@@ -409,12 +465,19 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
mlx.Pin(toEval...)
pinned = append(pinned, toEval...)
// Free stacked source array
// Free stacked source array (remove from pinned to avoid double-unpin in cleanup)
for i, p := range pinned {
if p == stacked {
pinned[i] = nil
}
}
mlx.Unpin(stacked)
mlx.Sweep()
} else {
stacked = mlx.Contiguous(stacked, false)
mlx.Eval(stacked)
mlx.Pin(stacked)
pinned = append(pinned, stacked)
allArrays[stackedName] = stacked
}
}
@@ -529,7 +592,7 @@ func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
padBottom := blockRows*scaleShape[0] - rows
padSide := blockCols*scaleShape[1] - cols
if padBottom > 0 || padSide > 0 {
decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
decoded = mlx.PadConstant(decoded, []int{0, 1}, []int{0, 0}, []int{padBottom, padSide})
}
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))

View File

@@ -246,6 +246,11 @@ func ShouldQuantize(name, component string) bool {
return false
}
// Skip audio encoder tensors (highly sensitive to quantization)
if strings.Contains(name, "audio_tower") || strings.Contains(name, "embed_audio") {
return false
}
// Skip embeddings
if strings.Contains(name, "embed") {
return false
@@ -291,6 +296,22 @@ func normalizeQuantType(quantize string) string {
}
}
// isAligned checks if a tensor's last dimension is divisible by the
// group size required for the given quantization type.
// isAligned reports whether a tensor's trailing dimension is a multiple of
// the quantization group size implied by quantType (nvfp4: 16, int4/int8: 64,
// everything else — mxfp4/mxfp8 — 32). An empty shape is never aligned.
func isAligned(shape []int32, quantType string) bool {
	if len(shape) == 0 {
		return false
	}
	var groupSize int32
	switch normalizeQuantType(quantType) {
	case "nvfp4":
		groupSize = 16
	case "int4", "int8":
		groupSize = 64
	default: // mxfp4 / mxfp8 and any other type
		groupSize = 32
	}
	last := shape[len(shape)-1]
	return last%groupSize == 0
}
// isStackedExpertWeight reports whether name refers to an expert tensor that
// is (or will be) stacked across experts.
// Combined/stacked expert tensors may be emitted either as "...proj.weight"
// (per-expert) or "...proj" (pre-stacked packed tensor).
func isStackedExpertWeight(name string) bool {
	return strings.Contains(name, ".mlp.switch_mlp.") ||
		strings.Contains(name, ".mlp.experts.") ||
		strings.Contains(name, ".mlp.shared_experts.") ||
		strings.Contains(name, ".moe.experts.")
}
// GetTensorQuantization returns the appropriate quantization type for a tensor.
// Returns "" if the tensor should not be quantized.
// This implements mixed-precision quantization:
// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
// - Output projection, gate/up weights: int4 (less sensitive)
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
// - v_proj, k_proj, down_proj: promoted to INT8 when base is INT4
// - Norms, embeddings, biases, routing gates: no quantization
// - All other eligible weights: use requested quantization type
func GetTensorQuantization(name string, shape []int32, quantize string) string {
stackedExpert := isStackedExpertWeight(name)
@@ -336,60 +357,35 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
// Normalize quantization type to canonical form
quantNorm := normalizeQuantType(quantize)
// MLX quantization requires last dimension to be divisible by group size
// nvfp4: 16, mxfp4/mxfp8: 32, int4/int8: 64
groupSize := int32(32)
switch quantNorm {
case "nvfp4":
groupSize = 16
case "int4", "int8":
groupSize = 64
}
if shape[len(shape)-1]%groupSize != 0 {
return ""
}
// Skip routing gate weights (should stay high precision)
// In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
return ""
}
// MLX quantization requires last dimension to be divisible by group size.
if !isAligned(shape, quantNorm) {
return ""
}
// For non-affine modes, use the same quantization for all eligible tensors.
if quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8" {
return quantNorm
}
// Attention MLA weights - keep unquantized (bf16)
// These are highly sensitive: errors accumulate in the KV cache over time
// q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj
if strings.Contains(name, "q_a_proj") ||
strings.Contains(name, "q_b_proj") ||
strings.Contains(name, "kv_a_proj") ||
strings.Contains(name, "kv_b_proj") {
return "" // No quantization - keep bf16
// Value projection weights directly determine attention output quality.
// Down projection weights feed directly into the residual stream where
// errors accumulate across layers. Both benefit from higher precision.
// Promote to INT8 when base is INT4 (same affine mode, compatible with
// GatherQMM for MoE expert tensors).
if quantNorm == "int4" {
if strings.Contains(name, ".v_proj") || strings.Contains(name, ".k_proj") || strings.Contains(name, "down_proj") {
if isAligned(shape, "int8") {
return "int8"
}
}
}
// Down projection weights - use INT8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
if strings.Contains(name, "down_proj") {
return "int8"
}
// Output projection, gate/up weights - use requested quantization (INT4)
// o_proj, gate_proj, up_proj
if strings.Contains(name, "o_proj") ||
strings.Contains(name, "gate_proj") ||
strings.Contains(name, "up_proj") {
return quantNorm
}
// LM head - use requested quantization
if strings.Contains(name, "lm_head") {
return quantNorm
}
// Default to requested quantization for other weights
return quantNorm
}
@@ -411,6 +407,7 @@ func ExpertGroupPrefix(tensorName string) string {
".mlp.experts.",
".mlp.shared_experts.",
".mlp.switch_mlp.",
".moe.experts.",
} {
idx := strings.Index(tensorName, marker)
if idx == -1 {
@@ -637,6 +634,8 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{
"Qwen3_5MoeForConditionalGeneration": newQwen35ImportTransform,
"Qwen3NextMoeForCausalLM": newQwen35ImportTransform,
"Qwen3NextMoeForConditionalGeneration": newQwen35ImportTransform,
"Gemma4ForCausalLM": newGemma4ImportTransform,
"Gemma4ForConditionalGeneration": newGemma4ImportTransform,
}
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {

View File

@@ -1169,6 +1169,11 @@ func TestShouldQuantize(t *testing.T) {
{"ln prefix", "ln_1.weight", "", false},
{"layernorm in name", "input_layernorm.weight", "", false},
// Audio encoder tensors should not be quantized
{"audio tower weight", "model.audio_tower.layers.0.weight", "", false},
{"audio tower norm", "model.audio_tower.norm.weight", "", false},
{"embed audio weight", "embed_audio.weight", "", false},
// Biases should not be quantized
{"bias tensor", "attention.bias", "", false},
{"proj bias", "o_proj.bias", "", false},
@@ -1262,6 +1267,11 @@ func TestExpertGroupPrefix(t *testing.T) {
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
// MoE expert tensors (Gemma-style .moe.experts.)
{"model.layers.0.moe.experts.0.gate_proj.weight", "model.layers.0.moe.experts"},
{"model.layers.1.moe.experts.42.down_proj.weight", "model.layers.1.moe.experts"},
{"language_model.model.layers.2.moe.experts.127.up_proj.weight", "language_model.model.layers.2.moe.experts"},
// Expert tensors with language_model prefix should also match
{"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"},
{"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"},
@@ -1369,6 +1379,94 @@ func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
}
}
// TestIsAligned verifies the per-quant-type alignment rule: the tensor's last
// dimension must be divisible by the group size for the requested type
// (int4/int8: 64, nvfp4: 16, mxfp4/mxfp8: 32), with an empty shape rejected.
func TestIsAligned(t *testing.T) {
	tests := []struct {
		name      string
		shape     []int32
		quantType string
		want      bool
	}{
		// int4/int8: group_size=64
		{"int4 aligned", []int32{1024, 4096}, "int4", true},
		{"int4 unaligned", []int32{1024, 48}, "int4", false},
		{"int8 aligned", []int32{1024, 128}, "int8", true},
		{"int8 unaligned", []int32{1024, 32}, "int8", false},
		// nvfp4: group_size=16
		{"nvfp4 aligned", []int32{1024, 48}, "nvfp4", true},
		{"nvfp4 unaligned", []int32{1024, 24}, "nvfp4", false},
		{"nvfp4 aligned 16", []int32{1024, 16}, "nvfp4", true},
		// mxfp4/mxfp8: group_size=32
		{"mxfp4 aligned", []int32{1024, 64}, "mxfp4", true},
		{"mxfp4 unaligned", []int32{1024, 48}, "mxfp4", false},
		{"mxfp8 aligned", []int32{1024, 32}, "mxfp8", true},
		{"mxfp8 unaligned", []int32{1024, 24}, "mxfp8", false},
		// Edge cases
		{"empty shape", []int32{}, "int4", false},
		{"1D tensor", []int32{4096}, "int4", true},
		{"3D stacked expert", []int32{128, 4096, 2816}, "int4", true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := isAligned(tt.shape, tt.quantType)
			if got != tt.want {
				t.Errorf("isAligned(%v, %q) = %v, want %v", tt.shape, tt.quantType, got, tt.want)
			}
		})
	}
}
// TestGetTensorQuantization_MixedPrecisionPromotion checks the generic
// mixed-precision policy: with an int4 base, sensitive tensors (v_proj,
// k_proj, down_proj) are promoted to int8, other tensors keep the base type,
// non-affine modes (nvfp4/mxfp4/mxfp8) stay uniform, and unaligned shapes
// fall back to bf16 (empty string).
func TestGetTensorQuantization_MixedPrecisionPromotion(t *testing.T) {
	aligned := []int32{4096, 4096} // divisible by 64
	tests := []struct {
		name     string
		tensor   string
		shape    []int32
		quantize string
		want     string
	}{
		// int4 → int8 promotion for sensitive tensors
		{"v_proj int4 promoted", "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"},
		{"k_proj int4 promoted", "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
		{"down_proj int4 promoted", "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"},
		// Non-sensitive int4 tensors stay int4
		{"q_proj int4 stays", "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"},
		{"o_proj int4 stays", "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"},
		{"gate_proj int4 stays", "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"},
		{"up_proj int4 stays", "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"},
		// nvfp4/mxfp4/mxfp8: no promotion (uniform quantization)
		{"v_proj nvfp4 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"},
		{"down_proj mxfp4 uniform", "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"},
		{"v_proj mxfp8 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"},
		// int8: already 8-bit, no promotion
		{"v_proj int8 stays", "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"},
		// Expert tensors: down_proj also promoted for int4
		{"expert down_proj int4", "model.layers.0.mlp.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"},
		{"moe expert down_proj int4", "model.layers.0.moe.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"},
		// Unaligned: falls back to bf16 (empty string)
		{"v_proj int4 unaligned", "model.layers.0.self_attn.v_proj.weight", []int32{1024, 48}, "int4", ""},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := GetTensorQuantization(tt.tensor, tt.shape, tt.quantize)
			if got != tt.want {
				t.Errorf("GetTensorQuantization(%q, %v, %q) = %q, want %q",
					tt.tensor, tt.shape, tt.quantize, got, tt.want)
			}
		})
	}
}
func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) {
dir := t.TempDir()

264
x/create/gemma4.go Normal file
View File

@@ -0,0 +1,264 @@
package create
import (
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"github.com/ollama/ollama/x/safetensors"
)
// gemma4ImportTransform applies Gemma 4-specific handling during safetensors
// import: splitting pre-stacked MoE expert tensors and selecting per-tensor
// mixed-precision quantization.
type gemma4ImportTransform struct {
	numLayers  int // hidden layer count from config.json; 0 when unavailable
	numExperts int // MoE expert count from config.json; 0 when unavailable
}

// gemma4Config is a minimal subset of the Gemma 4 config.json used for quant decisions.
type gemma4Config struct {
	NumHiddenLayers int `json:"num_hidden_layers"`
	NumExperts      int `json:"num_experts"`
	// Multimodal checkpoints nest the text-model settings under "text_config";
	// the loader falls back to these when the top-level fields are zero.
	TextConfig struct {
		NumHiddenLayers int `json:"num_hidden_layers"`
		NumExperts      int `json:"num_experts"`
	} `json:"text_config"`
}
// newGemma4ImportTransform builds the Gemma 4 transform from the model's
// config.json. A missing or malformed config is not an error: the transform
// simply runs without the layer/expert heuristics (zero values).
func newGemma4ImportTransform(modelDir string, _ sourceModelConfig) (tensorImportTransform, error) {
	var t gemma4ImportTransform
	data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
	if err != nil {
		return t, nil //nolint:nilerr // fallback to no heuristic
	}
	var cfg gemma4Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return t, nil //nolint:nilerr // fallback to no heuristic
	}
	// Prefer top-level fields; multimodal configs nest them under text_config.
	t.numLayers = cfg.NumHiddenLayers
	if t.numLayers == 0 {
		t.numLayers = cfg.TextConfig.NumHiddenLayers
	}
	t.numExperts = cfg.NumExperts
	if t.numExperts == 0 {
		t.numExperts = cfg.TextConfig.NumExperts
	}
	return t, nil
}
// skipTensor reports whether the named source tensor should be dropped
// during import. Gemma 4 keeps every tensor, so this always returns false.
func (t gemma4ImportTransform) skipTensor(name string) bool {
	return false
}
// layerIndexRe extracts the layer index from tensor names like
// "model.language_model.layers.5.self_attn.v_proj.weight" or
// "model.language_model.layers.5.moe.experts.42.down_proj.weight".
// Capture group 1 is the decimal layer index.
var layerIndexRe = regexp.MustCompile(`\.layers\.(\d+)\.`)
// useMoreBits reports whether layerIdx is one of the layers whose
// quantization-sensitive tensors should use higher precision: the first and
// last 1/8 of layers (which handle input grounding and final output
// refinement), plus every 3rd layer in between to limit error accumulation
// through the residual stream.
func useMoreBits(layerIdx, numLayers int) bool {
	eighth := numLayers / 8
	switch {
	case layerIdx < eighth:
		return true // first 1/8 of layers
	case layerIdx >= 7*numLayers/8:
		return true // last 1/8 (integer math: 7*n/8, not n-n/8)
	default:
		return (layerIdx-eighth)%3 == 2 // periodic promotion in the middle
	}
}
// quantizationType selects the quantization type for one tensor during a
// Gemma 4 import, layering two model-specific policies on top of the generic
// GetTensorQuantization fallback:
//  1. the main token embedding is forced to an 8-bit variant, and
//  2. sensitive language-model tensors (v_proj/down_proj, and k_proj on
//     8-expert models) are promoted from a 4-bit base to an 8-bit type at
//     layers chosen by useMoreBits.
//
// Returns "" when the tensor should stay unquantized (bf16).
func (t gemma4ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
	quantNorm := normalizeQuantType(quantize)
	// Embedding: quantize to 8-bit variant for bandwidth efficiency.
	// The embedding serves double duty: lookup (via QuantizedEmbedding) and
	// lm_head projection (via AsLinear). Using 8-bit matches GGUF Q6_K quality
	// (strictly higher at 8 bpw vs 6.5 bpw) while saving ~2.8 GB on 31B vs bf16.
	if isEmbedTokensWeight(name) {
		switch quantNorm {
		case "int4", "int8":
			if isAligned(shape, "int8") {
				return "int8"
			}
		case "mxfp4", "nvfp4", "mxfp8":
			if isAligned(shape, "mxfp8") {
				return "mxfp8"
			}
		}
		// 8-bit alignment failed: fall back to the base type if that aligns.
		if isAligned(shape, quantNorm) {
			return quantNorm
		}
		return ""
	}
	// Mixed-precision quantization: sensitive tensors get higher precision.
	//
	// Value projections (v_proj) directly determine attention output quality.
	// Down projections (down_proj) are the final MLP output and errors there
	// propagate directly to the residual stream. Both benefit from higher
	// precision at early layers, late layers, and periodically in between
	// (the "useMoreBits" heuristic).
	//
	// For int4: promote → int8 (same affine family, GatherQMM compatible).
	// For mxfp4/nvfp4: promote → mxfp8. MLX quantized_matmul handles mixed
	// nvfp4+mxfp8 modes within the same model — each tensor carries its own
	// quant metadata and the kernel dispatches per-tensor.
	if t.numLayers > 0 {
		// Parse the layer index out of the tensor name; -1 means "no layer"
		// (e.g. embeddings, final norm) and disables the position heuristic.
		layerIdx := -1
		if m := layerIndexRe.FindStringSubmatch(name); m != nil {
			if idx, err := strconv.Atoi(m[1]); err == nil {
				layerIdx = idx
			}
		}
		// Determine promotion target for sensitive tensors.
		// "int8"  = int4 base → int8 (affine family)
		// "mxfp8" = mxfp4/nvfp4 base → mxfp8
		// ""      = no promotion (int8/mxfp8, already 8-bit)
		promote := ""
		switch quantNorm {
		case "int4":
			promote = "int8"
		case "mxfp4", "nvfp4":
			promote = "mxfp8"
		}
		// Only apply to language model tensors — audio/vision tower tensors
		// should pass through to GetTensorQuantization which skips them.
		isModelTensor := !strings.Contains(name, "audio_tower") &&
			!strings.Contains(name, "vision_tower")
		isSensitive := isModelTensor &&
			(strings.Contains(name, ".v_proj") || strings.Contains(name, "down_proj"))
		isSensitiveK := isModelTensor && strings.Contains(name, "k_proj")
		if promote != "" && (isSensitive || isSensitiveK) {
			shouldPromote := false
			// 8-expert models: v_proj and k_proj share very few KV heads,
			// so quantization errors are amplified. Always promote.
			if t.numExperts == 8 && (strings.Contains(name, ".v_proj") || isSensitiveK) {
				shouldPromote = true
			}
			// Layer-position heuristic for v_proj and down_proj.
			if isSensitive && layerIdx >= 0 && useMoreBits(layerIdx, t.numLayers) {
				shouldPromote = true
			}
			if shouldPromote && isAligned(shape, promote) {
				return promote
			}
			// Sensitive tensor at a non-promoted layer: use base quant type.
			// Return directly to bypass GetTensorQuantization's uniform
			// promotion — the layer-position heuristic is authoritative here.
			if !isAligned(shape, quantNorm) {
				return ""
			}
			return quantNorm
		}
	}
	// Everything else (q/o/gate/up projections, tower tensors, norms, biases)
	// uses the generic policy.
	return GetTensorQuantization(name, shape, quantize)
}
// isEmbedTokensWeight returns true for the main token embedding weight.
// The per-layer embedding ("embed_tokens_per_layer") is deliberately excluded.
func isEmbedTokensWeight(name string) bool {
	if strings.Contains(name, "per_layer") {
		return false
	}
	return strings.HasSuffix(name, "embed_tokens.weight")
}
// transformTensor maps one source tensor to zero or more output tensors.
// Pre-stacked MoE expert weights [N, out, in] are split into N per-expert
// [out, in] tensors so they go through the standard expert packing and
// quantization flow (ExpertGroupPrefix matching, per-expert quantize);
// every other tensor passes through unchanged. A nil input yields nil.
func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
	switch {
	case td == nil:
		return nil, nil
	case isGemma4StackedMoETensor(td.Name, td.Shape):
		return splitStackedMoETensor(td)
	default:
		return []*safetensors.TensorData{td}, nil
	}
}
// isGemma4StackedMoETensor checks if this is a pre-stacked MoE expert weight.
// Gemma 4 HF weights come in two layouts depending on the model version:
//   - Older: model.language_model.layers.N.moe.{gate,up,down}_proj [experts, dim1, dim2]
//   - Newer: model.language_model.layers.N.experts.{gate_up,down}_proj [experts, dim1, dim2]
//
// The newer layout has gate+up already fused. We keep it fused (no splitting)
// so the tensors flow through the standard expert packing and quantization path.
func isGemma4StackedMoETensor(name string, shape []int32) bool {
	// Stacked expert tensors are always rank 3: [experts, dim1, dim2].
	if len(shape) != 3 {
		return false
	}
	inExpertScope := strings.Contains(name, ".moe.") || strings.Contains(name, ".experts.")
	if !inExpertScope {
		return false
	}
	// Only projection weights qualify, with or without the ".weight" suffix.
	return strings.HasSuffix(name, "_proj") || strings.HasSuffix(name, "_proj.weight")
}
// splitStackedMoETensor explodes a [N, out, in] stacked expert tensor into
// N individual [out, in] tensors named with the per-expert convention that
// ExpertGroupPrefix expects: prefix.moe.experts.{E}.{proj}.weight
func splitStackedMoETensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
	raw, err := io.ReadAll(td.Reader())
	if err != nil {
		return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err)
	}
	elemSize, err := DTypeSize(td.Dtype)
	if err != nil {
		return nil, fmt.Errorf("failed to get dtype size for %s: %w", td.Dtype, err)
	}
	// HF layout: Shape[0]=experts, Shape[1]=out_features, Shape[2]=in_features.
	experts := int(td.Shape[0])
	stride := int(td.Shape[1]) * int(td.Shape[2]) * elemSize // bytes per expert
	if len(raw) != experts*stride {
		return nil, fmt.Errorf("tensor %s: raw byte length %d does not match shape %v and dtype %s",
			td.Name, len(raw), td.Shape, td.Dtype)
	}
	// Determine the per-expert name pattern.
	// Two source layouts:
	//   Old: model.language_model.layers.N.moe.gate_proj
	//        -> model.language_model.layers.N.moe.experts.E.gate_proj.weight
	//   New: model.language_model.layers.N.experts.gate_up_proj
	//        -> model.language_model.layers.N.moe.experts.E.gate_up_proj.weight
	base := strings.TrimSuffix(td.Name, ".weight")
	dot := strings.LastIndex(base, ".")
	if dot < 0 {
		return nil, fmt.Errorf("tensor %s: unexpected name format", td.Name)
	}
	parent := base[:dot]  // "...layers.N.moe" or "...layers.N.experts"
	proj := base[dot+1:]  // "gate_proj" or "gate_up_proj"
	// Normalize: a parent ending in ".experts" becomes "<grandparent>.moe"
	// so every output follows the same "layers.N.moe.experts.E" pattern.
	moePrefix := parent
	if trimmed, ok := strings.CutSuffix(parent, ".experts"); ok {
		moePrefix = trimmed + ".moe"
	}
	perExpertShape := []int32{td.Shape[1], td.Shape[2]}
	out := make([]*safetensors.TensorData, experts)
	for e := range experts {
		expertName := fmt.Sprintf("%s.experts.%d.%s.weight", moePrefix, e, proj)
		// Each expert's bytes are a contiguous slice of the stacked payload.
		out[e] = safetensors.NewTensorDataFromBytes(expertName, td.Dtype, perExpertShape, raw[e*stride:(e+1)*stride])
	}
	return out, nil
}

191
x/create/gemma4_test.go Normal file
View File

@@ -0,0 +1,191 @@
package create
import (
"testing"
)
// TestGemma4QuantizationType exercises the Gemma 4 mixed-precision policy
// across embeddings, attention projections, dense-MLP and MoE-expert weights,
// and audio/vision tower tensors, for every supported base quant type.
func TestGemma4QuantizationType(t *testing.T) {
	// 26B MoE: 30 layers, 128 experts
	transform26B := gemma4ImportTransform{numLayers: 30, numExperts: 128}
	// 8-expert model (hypothetical)
	transform8E := gemma4ImportTransform{numLayers: 30, numExperts: 8}
	aligned := []int32{2816, 2816} // divisible by 64 (int4/int8 group size) and 16 (nvfp4)
	tests := []struct {
		name      string
		transform gemma4ImportTransform
		tensor    string
		shape     []int32
		quantize  string
		want      string
	}{
		// === embed_tokens: quantize to 8-bit variant (serves as both embed and lm_head) ===
		{"embed_tokens int4", transform26B, "model.embed_tokens.weight", aligned, "int4", "int8"},
		{"embed_tokens nvfp4", transform26B, "model.embed_tokens.weight", aligned, "nvfp4", "mxfp8"},
		{"embed_tokens mxfp4", transform26B, "model.embed_tokens.weight", aligned, "mxfp4", "mxfp8"},
		{"embed_tokens int8", transform26B, "model.embed_tokens.weight", aligned, "int8", "int8"},
		{"embed_tokens mxfp8", transform26B, "model.embed_tokens.weight", aligned, "mxfp8", "mxfp8"},
		// === v_proj: layer-position heuristic for int4/nvfp4 ===
		// Layer 0 is in first 1/8 (30/8=3) → promoted
		{"v_proj int4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"},
		// Layer 4 is NOT in useMoreBits → base quant
		{"v_proj int4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "int4", "int4"},
		// Layer 29 is in last 1/8 → promoted
		{"v_proj int4 last layer promoted", transform26B, "model.layers.29.self_attn.v_proj.weight", aligned, "int4", "int8"},
		// nvfp4: promote to mxfp8 (cross-family, validated by MLX quantized_matmul)
		{"v_proj nvfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"v_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"},
		// mxfp4: promoted to mxfp8 at promoted layers (same mxfp family)
		{"v_proj mxfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp8"},
		{"v_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp4"},
		// int8/mxfp8: no promotion (already 8-bit)
		{"v_proj int8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"},
		{"v_proj mxfp8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"},
		// === down_proj (dense MLP): same heuristic as v_proj ===
		{"dense down_proj int4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"},
		{"dense down_proj int4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "int4", "int4"},
		{"dense down_proj nvfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"dense down_proj nvfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "nvfp4", "nvfp4"},
		{"dense down_proj mxfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp8"},
		{"dense down_proj mxfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"},
		// === Expert down_proj: int4→int8, nvfp4→mxfp8 at promoted layers ===
		{"expert down_proj int4 promoted", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "int4", "int8"},
		{"expert down_proj int4 non-promoted", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "int4", "int4"},
		// nvfp4 experts: promote to mxfp8 (all experts at a layer get same treatment,
		// so GatherQMM sees uniform quant per projection per layer)
		{"expert down_proj nvfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"expert down_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "nvfp4", "nvfp4"},
		// mxfp4 experts: promote to mxfp8 (same mxfp family, GatherQMM compatible)
		{"expert down_proj mxfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp8"},
		{"expert down_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp4"},
		// === Expert gate_up_proj: always base quant (not a sensitive tensor) ===
		{"expert gate_up int4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "int4", "int4"},
		{"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"},
		{"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"},
		// === k_proj: promoted only for 8-expert models ===
		{"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"},
		{"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
		{"k_proj 8 experts nvfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "nvfp4", "mxfp8"},
		{"k_proj 8 experts mxfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "mxfp4", "mxfp8"},
		// === q_proj, o_proj, gate_proj, up_proj: always base quant ===
		{"q_proj int4", transform26B, "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"},
		{"o_proj int4", transform26B, "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"},
		{"gate_proj int4", transform26B, "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"},
		{"up_proj int4", transform26B, "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"},
		// === Non-quantizable tensors: always bf16 ===
		{"embed_tokens per_layer skip", transform26B, "model.embed_tokens_per_layer.weight", aligned, "int4", ""},
		{"norm", transform26B, "model.layers.0.input_layernorm.weight", []int32{2816}, "int4", ""},
		{"router scale", transform26B, "model.layers.0.router.scale", []int32{2816}, "int4", ""},
		// === Audio/vision tower tensors: must pass through unquantized for all quant types ===
		// These contain .v_proj and down_proj but should NOT be intercepted by
		// the sensitive-tensor promotion logic.
		{"audio norm int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int4", ""},
		{"audio norm nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "nvfp4", ""},
		{"audio norm int8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int8", ""},
		{"audio norm mxfp8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "mxfp8", ""},
		{"audio conv int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "int4", ""},
		{"audio conv nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "nvfp4", ""},
		{"audio linear int4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "int4", ""},
		{"audio linear nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "nvfp4", ""},
		// Audio tower v_proj — must NOT be promoted despite containing .v_proj
		{"audio v_proj int4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", ""},
		{"audio v_proj nvfp4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", ""},
		// Vision tower v_proj — vision tower IS quantized (unlike audio tower),
		// but not intercepted by gemma4's layer-position heuristic.
		// Falls through to GetTensorQuantization which applies uniform promotion.
		{"vision v_proj int4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", "int8"},
		{"vision v_proj nvfp4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", "nvfp4"},
		// Audio tower down_proj
		{"audio down_proj int4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "int4", ""},
		{"audio down_proj nvfp4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "nvfp4", ""},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.transform.quantizationType(tt.tensor, tt.shape, tt.quantize)
			if got != tt.want {
				t.Errorf("quantizationType(%q, %v, %q) = %q, want %q",
					tt.tensor, tt.shape, tt.quantize, got, tt.want)
			}
		})
	}
}
// TestUseMoreBits verifies the layer-promotion heuristic on a 30-layer model:
// first 1/8 = layers 0-2, last 1/8 = layers 26-29 (>= 7*30/8 = 26), and in
// between every 3rd layer where (i - n/8) % 3 == 2.
func TestUseMoreBits(t *testing.T) {
	const n = 30
	// First 1/8 (30/8 = 3): layers 0, 1, 2 must be promoted.
	for _, layer := range []int{0, 1, 2} {
		if !useMoreBits(layer, n) {
			t.Errorf("layer %d should be promoted (first 1/8)", layer)
		}
	}
	// Last 1/8: layers 26-29 must be promoted.
	for _, layer := range []int{26, 27, 28, 29} {
		if !useMoreBits(layer, n) {
			t.Errorf("layer %d should be promoted (last 1/8)", layer)
		}
	}
	// Representative middle layers that must NOT be promoted.
	for _, layer := range []int{3, 4, 6, 7} {
		if useMoreBits(layer, n) {
			t.Errorf("layer %d should NOT be promoted", layer)
		}
	}
	// Periodic middle promotion: (5 - 3) % 3 == 2.
	if !useMoreBits(5, n) {
		t.Errorf("layer 5 should be promoted (periodic)")
	}
}
// TestIsGemma4StackedMoETensor covers both pre-stacked MoE layouts (old
// ".moe.<proj>" and new ".experts.<proj>") plus shapes and names that must
// not be treated as stacked expert weights.
func TestIsGemma4StackedMoETensor(t *testing.T) {
	cases := []struct {
		label  string
		tensor string
		shape  []int32
		want   bool
	}{
		// New-style: .experts.gate_up_proj
		{"experts gate_up_proj 3D", "model.layers.0.experts.gate_up_proj", []int32{128, 1408, 2816}, true},
		{"experts down_proj 3D", "model.layers.0.experts.down_proj", []int32{128, 2816, 704}, true},
		// Old-style: .moe.gate_proj
		{"moe gate_proj 3D", "model.layers.0.moe.gate_proj", []int32{128, 2112, 2816}, true},
		{"moe down_proj 3D", "model.layers.0.moe.down_proj.weight", []int32{128, 2816, 2112}, true},
		// Not stacked: 2D
		{"2D weight", "model.layers.0.experts.gate_up_proj", []int32{1408, 2816}, false},
		// Not expert
		{"non-expert 3D", "model.layers.0.mlp.gate_proj", []int32{3, 2816, 2816}, false},
		// Not a projection
		{"expert non-proj", "model.layers.0.experts.scale", []int32{128, 1, 1}, false},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			if got := isGemma4StackedMoETensor(tc.tensor, tc.shape); got != tc.want {
				t.Errorf("isGemma4StackedMoETensor(%q, %v) = %v, want %v",
					tc.tensor, tc.shape, got, tc.want)
			}
		})
	}
}

View File

@@ -2,6 +2,7 @@ package mlxrunner
import (
_ "github.com/ollama/ollama/x/models/gemma3"
_ "github.com/ollama/ollama/x/models/gemma4"
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3"

View File

@@ -4,16 +4,57 @@ package mlx
import "C"
import "math"
func GELUApprox(t *Array) *Array {
return t.Multiply(
FromValue[float32](0.5),
).Multiply(
t.Add(
t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
).Multiply(
FromValue(float32(math.Sqrt(2 / math.Pi))),
).Tanh().Add(FromValue[float32](1.0)),
).AsType(t.DType())
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
// GELUApprox matches mlx.nn.gelu_approx:
//
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
func GELUApprox(x *Array) *Array {
// Use dtype-matched scalars to avoid implicit upcasts on bf16 inputs.
half := scalarWithDtype(0.5, x)
defer C.mlx_array_free(half)
coeff := scalarWithDtype(geluCoeff, x)
defer C.mlx_array_free(coeff)
c := scalarWithDtype(0.044715, x)
defer C.mlx_array_free(c)
// x^3 via x*x*x (avoids general Power which is slower)
x3 := New("GELU_X3")
C.mlx_multiply(&x3.ctx, x.ctx, x.ctx, DefaultStream().ctx)
tmp := New("GELU_X3b")
C.mlx_multiply(&tmp.ctx, x3.ctx, x.ctx, DefaultStream().ctx)
x3 = tmp
// 0.044715 * x^3
cx3 := New("GELU_CX3")
C.mlx_multiply(&cx3.ctx, c, x3.ctx, DefaultStream().ctx)
// x + 0.044715 * x^3
inner := New("GELU_INNER")
C.mlx_add(&inner.ctx, x.ctx, cx3.ctx, DefaultStream().ctx)
// sqrt(2/pi) * (x + 0.044715 * x^3)
scaled := New("GELU_SCALED")
C.mlx_multiply(&scaled.ctx, coeff, inner.ctx, DefaultStream().ctx)
// tanh(...)
th := New("GELU_TANH")
C.mlx_tanh(&th.ctx, scaled.ctx, DefaultStream().ctx)
// 1 + tanh(...)
one := scalarWithDtype(1.0, x)
defer C.mlx_array_free(one)
onePlusTanh := New("GELU_1PT")
C.mlx_add(&onePlusTanh.ctx, one, th.ctx, DefaultStream().ctx)
// 0.5 * x
halfX := New("GELU_HALFX")
C.mlx_multiply(&halfX.ctx, half, x.ctx, DefaultStream().ctx)
// 0.5 * x * (1 + tanh(...))
out := New("GELU_APPROX")
C.mlx_multiply(&out.ctx, halfX.ctx, onePlusTanh.ctx, DefaultStream().ctx)
return out
}
func SILU(t *Array) *Array {

View File

@@ -90,3 +90,10 @@ func AsyncEval(outputs ...*Array) {
func Eval(outputs ...*Array) {
doEval(outputs, false)
}
// MetalIsAvailable reports whether a Metal GPU device is available to MLX.
func MetalIsAvailable() bool {
	var ok C._Bool
	C.mlx_metal_is_available(&ok)
	return bool(ok)
}

View File

@@ -149,45 +149,132 @@ func Contiguous(a *Array, allowColMajor bool) *Array {
return out
}
func Pad(a *Array, paddings []int32) *Array {
numAxes := len(paddings) / 2
axes := make([]C.int, numAxes)
lowPad := make([]C.int, numAxes)
highPad := make([]C.int, numAxes)
for i := range numAxes {
axes[i] = C.int(i)
lowPad[i] = C.int(paddings[i*2])
highPad[i] = C.int(paddings[i*2+1])
// Conv2d performs 2D convolution: x [N,H,W,C_in], weight [C_out,kH,kW,C_in].
// MLX uses NHWC layout.
func Conv2d(x, weight *Array, strideH, strideW, padH, padW, dilationH, dilationW, groups int32) *Array {
out := New("CONV2D")
C.mlx_conv2d(
&out.ctx,
x.ctx,
weight.ctx,
C.int(strideH), C.int(strideW),
C.int(padH), C.int(padW),
C.int(dilationH), C.int(dilationW),
C.int(groups),
DefaultStream().ctx,
)
return out
}
// Pad pads array a along the given axes with specified low/high pad sizes.
// mode should be "constant", "edge", or "reflect"; padValue is consulted
// only in "constant" mode.
func Pad(a *Array, axes []int, lowPad, highPad []int, padValue *Array, mode string) *Array {
	// Convert the Go int slices to C ints before handing them to mlx_pad.
	cAxes := make([]C.int, len(axes))
	cLow := make([]C.int, len(lowPad))
	cHigh := make([]C.int, len(highPad))
	for i := range axes {
		cAxes[i] = C.int(axes[i])
		cLow[i] = C.int(lowPad[i])
		cHigh[i] = C.int(highPad[i])
	}
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	out := New("PAD")
	C.mlx_pad(
		&out.ctx,
		a.ctx,
		unsafe.SliceData(cAxes), C.size_t(len(cAxes)),
		unsafe.SliceData(cLow), C.size_t(len(cLow)),
		unsafe.SliceData(cHigh), C.size_t(len(cHigh)),
		padValue.ctx,
		cMode,
		DefaultStream().ctx,
	)
	return out
}
// PadConstant pads with zeros along the given axes.
func PadConstant(a *Array, axes []int, lowPad, highPad []int) *Array {
	return Pad(a, axes, lowPad, highPad, NewScalarArray(float32(0)), "constant")
}
// DepthwiseConv1d runs a 1D convolution with one group per input channel
// (stride 1, no padding, no dilation).
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
	channels := int32(x.Dim(x.NumDims() - 1))
	return Conv1d(x, weight, bias, 1, 0, 1, channels)
}
// Maximum returns the element-wise maximum of two arrays.
func Maximum(a, b *Array) *Array {
	result := New("MAXIMUM")
	C.mlx_maximum(&result.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return result
}
// Minimum returns the element-wise minimum of two arrays.
func Minimum(a, b *Array) *Array {
	result := New("MINIMUM")
	C.mlx_minimum(&result.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return result
}
// Softplus computes log(1 + exp(x)), expressed as logaddexp(x, 0) for
// numerical stability.
func Softplus(a *Array) *Array {
	zeros := Zeros(a.DType(), a.Dims()...)
	return Logaddexp(a, zeros)
}
// ReLU computes max(0, x).
func ReLU(a *Array) *Array {
	zero := NewScalarArray(float32(0))
	return Maximum(a, zero)
}
// GLU applies Gated Linear Unit: splits x along last dim into two halves,
// returns first * sigmoid(second).
func GLU(a *Array) *Array {
	lastDim := a.NumDims() - 1
	halfSize := a.Dim(lastDim) / 2
	// first = a[..., :halfSize]
	first := SliceStartStop(a,
		make([]int32, lastDim+1), // all zeros for start
		appendDims(a, lastDim, int32(halfSize)),
	)
	// second = a[..., halfSize:]
	second := SliceStartStop(a,
		appendDimsStart(a, lastDim, int32(halfSize)),
		appendDims(a, lastDim, int32(a.Dim(lastDim))),
	)
	return first.Multiply(second.Sigmoid())
}
// appendDims builds a stop vector for SliceStartStop: every axis keeps its
// full extent except targetAxis, which is set to val. An out-of-range
// targetAxis is ignored, matching the original loop-based behavior.
func appendDims(a *Array, targetAxis int, val int32) []int32 {
	dims := make([]int32, a.NumDims())
	for i := range dims {
		dims[i] = int32(a.Dim(i))
	}
	if targetAxis >= 0 && targetAxis < len(dims) {
		dims[targetAxis] = val
	}
	return dims
}
// appendDimsStart builds a start vector for SliceStartStop: all zeros except
// targetAxis, which is set to val. An out-of-range targetAxis is ignored,
// matching the original loop-based behavior.
func appendDimsStart(a *Array, targetAxis int, val int32) []int32 {
	starts := make([]int32, a.NumDims())
	if targetAxis >= 0 && targetAxis < len(starts) {
		starts[targetAxis] = val
	}
	return starts
}
// Clamp clamps array values to the inclusive range [minVal, maxVal].
func Clamp(a *Array, minVal, maxVal float32) *Array {
	lo := NewScalarArray(minVal)
	hi := NewScalarArray(maxVal)
	return Minimum(Maximum(a, lo), hi)
}
// Convenience wrappers (function-style for the model code)
func Stack(arrays []*Array, axis int) *Array {
@@ -323,20 +410,37 @@ func SiLU(a *Array) *Array {
}
// RoPEWithBase applies rotary position embedding with frequencies computed
// from base; it is a convenience wrapper over RoPEWithFreqs with nil freqs.
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
	return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
}
// RoPEWithFreqs applies RoPE with optional custom frequencies.
// When freqs is non-nil, it is used instead of computing from base.
// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass
// the actual frequencies (base^(2i/dim)), not the inverse frequencies.
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offset int, freqs *Array) *Array {
	var freqsCtx C.mlx_array
	var optBase C.mlx_optional_float
	if freqs != nil {
		// Custom frequencies win: leave the optional base unset.
		freqsCtx = freqs.ctx
		optBase = C.mlx_optional_float{has_value: C.bool(false)}
	} else {
		// No custom freqs: pass an empty array and the base value
		// (base == 0 is treated as "unset").
		empty := New("")
		freqsCtx = empty.ctx
		optBase = C.mlx_optional_float{
			value:     C.float(base),
			has_value: C.bool(func() bool { return base != 0 }()),
		}
	}
	out := New("FAST_ROPE")
	C.mlx_fast_rope(
		&out.ctx,
		x.ctx,
		C.int(dims),
		C.bool(traditional),
		optBase,
		C.float(scale),
		C.int(offset),
		freqsCtx,
		DefaultStream().ctx,
	)
	return out
@@ -358,6 +462,24 @@ func Log(a *Array) *Array {
return out
}
// Sin computes the element-wise sine of a.
func Sin(a *Array) *Array {
	result := New("SIN")
	C.mlx_sin(&result.ctx, a.ctx, DefaultStream().ctx)
	return result
}
// Cos computes the element-wise cosine of a.
func Cos(a *Array) *Array {
	result := New("COS")
	C.mlx_cos(&result.ctx, a.ctx, DefaultStream().ctx)
	return result
}
// Clip limits the values of a element-wise to lie between aMin and aMax.
func Clip(a, aMin, aMax *Array) *Array {
	result := New("CLIP")
	C.mlx_clip(&result.ctx, a.ctx, aMin.ctx, aMax.ctx, DefaultStream().ctx)
	return result
}
func Logaddexp(a, b *Array) *Array {
out := New("LOGADDEXP")
C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
@@ -385,6 +507,20 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
return out
}
// ScaledDotProductAttentionMasked runs the fast SDPA kernel with an explicit
// additive mask. The mask is broadcast to [B, H, Q, K] and added to scores
// before softmax. Pass mode="array" so MLX actually consults mask_arr; the
// empty string is "no mask" and silently ignores the array argument.
func ScaledDotProductAttentionMasked(q, k, v *Array, scale float32, mask *Array) *Array {
	cMode := C.CString("array")
	defer C.free(unsafe.Pointer(cMode))
	sinks := New("")
	result := New("FAST_SDPA")
	C.mlx_fast_scaled_dot_product_attention(&result.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
	return result
}
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
out := New("FAST_LAYERNORM")
var w, b C.mlx_array

View File

@@ -131,6 +131,12 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string,
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
globalQuantType = strings.ToUpper(globalQuantType)
// Parse full metadata for per-tensor quant info
var metaMap map[string]string
if metaRaw, ok := header["__metadata__"]; ok {
json.Unmarshal(metaRaw, &metaMap)
}
mainNames := mainTensorNames(header)
infos := make(map[string]*TensorQuantInfo)
for _, name := range mainNames {
@@ -141,6 +147,18 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string,
quantType := globalQuantType
groupSize := globalGroupSize
// Check per-tensor metadata (e.g. from packed expert blobs with mixed precision)
if metaMap != nil {
if qt, ok := metaMap[name+".quant_type"]; ok && qt != "" {
quantType = strings.ToUpper(qt)
}
if gs, ok := metaMap[name+".group_size"]; ok && gs != "" {
if v, err := strconv.Atoi(gs); err == nil {
groupSize = v
}
}
}
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
if quantType == "" {
quantType = inferredType

1514
x/models/gemma4/gemma4.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,228 @@
package gemma4
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// onesLike creates a tensor of the given shape filled with a small constant.
// NOTE(review): the fill is 0.01, not 1.0 — presumably to keep bf16 test
// activations small; the name is historical.
func onesLike(shape ...int) *mlx.Array {
	return mlx.AddScalar(mlx.Zeros(mlx.DTypeBFloat16, shape...), 0.01)
}
func TestMoEForward(t *testing.T) {
skipIfNoMLX(t)
// Small config matching 26b architecture pattern.
cfg := &TextConfig{
HiddenSize: 16, // tiny for testing
NumAttentionHeads: 2,
NumKeyValueHeads: 1,
NumGlobalKeyValueHeads: 1,
HeadDim: 8,
GlobalHeadDim: 8,
NumExperts: 4,
TopKExperts: 2,
ExpertIntermediateSize: 8,
EnableMoeBlock: true,
AttentionKEqV: false,
RMSNormEps: 1e-6,
SlidingScale: 1.0,
FullScale: 1.0,
}
B, L := int32(1), int32(3)
x := onesLike(int(B), int(L), int(cfg.HiddenSize))
// Test Router.Forward.
router := &Router{
Proj: linearFromWeight(onesLike(int(cfg.NumExperts), int(cfg.HiddenSize))),
Scale: onesLike(int(cfg.HiddenSize)),
}
t.Run("Router", func(t *testing.T) {
scores, inds := router.Forward(x, cfg)
mlx.Eval(scores, inds)
sDims := scores.Dims()
iDims := inds.Dims()
t.Logf("scores shape: %v, inds shape: %v", sDims, iDims)
if len(sDims) != 2 || sDims[0] != int(B*L) || sDims[1] != int(cfg.TopKExperts) {
t.Errorf("scores shape = %v, want [%d, %d]", sDims, B*L, cfg.TopKExperts)
}
if len(iDims) != 2 || iDims[0] != int(B*L) || iDims[1] != int(cfg.TopKExperts) {
t.Errorf("inds shape = %v, want [%d, %d]", iDims, B*L, cfg.TopKExperts)
}
})
// Test MoEBlock.Forward.
moe := &MoEBlock{
GateWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
UpWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
DownWeight: onesLike(int(cfg.NumExperts), int(cfg.ExpertIntermediateSize), int(cfg.HiddenSize)),
PerExpertScale: onesLike(int(cfg.NumExperts)),
}
t.Run("MoEBlock", func(t *testing.T) {
scores, inds := router.Forward(x, cfg)
mlx.Eval(scores, inds)
out := moe.Forward(x, scores, inds, cfg)
mlx.Eval(out)
outDims := out.Dims()
t.Logf("MoE output shape: %v", outDims)
if len(outDims) != 3 || outDims[0] != int(B) || outDims[1] != int(L) || outDims[2] != int(cfg.HiddenSize) {
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, B, L, cfg.HiddenSize)
}
})
// Test with larger batch to exercise the sorted GatherMM path (B*L >= 64).
t.Run("MoEBlock_sorted", func(t *testing.T) {
bigB, bigL := int32(1), int32(128)
bigX := onesLike(int(bigB), int(bigL), int(cfg.HiddenSize))
scores, inds := router.Forward(bigX, cfg)
mlx.Eval(scores, inds)
out := moe.Forward(bigX, scores, inds, cfg)
mlx.Eval(out)
outDims := out.Dims()
t.Logf("MoE sorted output shape: %v", outDims)
if len(outDims) != 3 || outDims[0] != int(bigB) || outDims[1] != int(bigL) || outDims[2] != int(cfg.HiddenSize) {
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, bigB, bigL, cfg.HiddenSize)
}
})
}
// TestRouterForwardMatchesLegacy verifies the optimized Router.Forward —
// which takes the top-k of the raw logits and softmaxes only the selected
// values — produces the same indices and (within tolerance) the same
// normalized scores as the legacy path that softmaxes over every expert
// first, gathers the top-k probabilities, then renormalizes.
func TestRouterForwardMatchesLegacy(t *testing.T) {
skipIfNoMLX(t)
cfg := &TextConfig{
HiddenSize: 8,
NumExperts: 4,
TopKExperts: 2,
RMSNormEps: 1e-6,
RouterScale: 0.5,
}
// Distinct per-expert weight rows so top-k has a well-defined ordering
// (tied scores would let argpartition pick either tied expert and make
// the index comparison below flaky).
projWeight := mlx.FromValues([]float32{
0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, // expert 0
0.30, 0.29, 0.28, 0.27, 0.26, 0.25, 0.24, 0.23, // expert 1
-0.05, -0.06, -0.07, -0.08, -0.09, -0.10, -0.11, -0.12, // expert 2
0.50, 0.48, 0.46, 0.44, 0.42, 0.40, 0.38, 0.36, // expert 3
}, int(cfg.NumExperts), int(cfg.HiddenSize))
scale := mlx.FromValues([]float32{
1.0, 0.9, 1.1, 1.0, 1.2, 0.8, 1.0, 1.05,
}, int(cfg.HiddenSize))
r := &Router{
Proj: linearFromWeight(projWeight),
Scale: scale,
}
// Varied x so different positions potentially hit different top-k.
x := mlx.FromValues([]float32{
0.2, -0.1, 0.3, 0.0, 0.4, -0.2, 0.1, 0.05,
-0.3, 0.2, -0.1, 0.4, -0.05, 0.3, 0.0, 0.2,
0.5, 0.4, -0.2, 0.1, -0.3, 0.0, 0.3, -0.1,
}, 1, 3, int(cfg.HiddenSize))
gotScores, gotInds := r.Forward(x, cfg)
wantScores, wantInds := legacyRouterForward(r, x, cfg)
mlx.Eval(gotScores, gotInds, wantScores, wantInds)
if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {
t.Fatalf("indices mismatch:\n got %v\n want %v", got, want)
}
if got, want := gotScores.Floats(), wantScores.Floats(); !floatSlicesClose(got, want, 1e-5) {
t.Fatalf("scores mismatch:\n got %v\n want %v", got, want)
}
}
// legacyRouterForward implements the pre-optimization router: full softmax
// over every expert, gather the top-k probabilities, then renormalize them
// to sum to 1. Algebraically identical to the fused form in Router.Forward.
func legacyRouterForward(r *Router, x *mlx.Array, cfg *TextConfig) (*mlx.Array, *mlx.Array) {
dims := x.Dims()
BL := int32(dims[0]) * int32(dims[1])
xFlat := mlx.Reshape(x, BL, cfg.HiddenSize)
normed := mlx.RMSNormFn(xFlat, nil, cfg.RMSNormEps)
normed = mlx.MulScalar(normed, cfg.RouterScale)
normed = mlx.Mul(normed, r.Scale)
expertScores := r.Proj.Forward(normed)
probs := mlx.SoftmaxAxis(expertScores, -1, true)
neg := mlx.Neg(expertScores)
inds := mlx.Argpartition(neg, int(cfg.TopKExperts)-1, -1)
inds = mlx.SliceStartStop(inds,
[]int32{0, 0},
[]int32{BL, cfg.TopKExperts},
)
scores := mlx.TakeAlongAxis(probs, inds, -1)
sumScores := mlx.Sum(scores, -1, true)
scores = mlx.Div(scores, sumScores)
return scores, inds
}
// intSlicesEqual reports whether a and b have the same length and identical
// elements in order.
func intSlicesEqual(a, b []int) bool {
	if len(a) != len(b) {
		return false
	}
	for i, v := range a {
		if v != b[i] {
			return false
		}
	}
	return true
}
// floatSlicesClose reports whether a and b have the same length and every
// element pair differs by at most tol in absolute value.
func floatSlicesClose(a, b []float32, tol float32) bool {
	if len(a) != len(b) {
		return false
	}
	for i, v := range a {
		diff := v - b[i]
		if diff < 0 {
			diff = -diff
		}
		if diff > tol {
			return false
		}
	}
	return true
}
// linearFromWeight creates a simple nn.LinearLayer from a weight tensor (no bias).
// The weight layout is [out_features, in_features]; see simpleLinear.Forward.
func linearFromWeight(w *mlx.Array) *simpleLinear {
	return &simpleLinear{weight: w}
}
// simpleLinear is a minimal bias-free linear layer used as a test stand-in.
type simpleLinear struct {
	weight *mlx.Array // [out_features, in_features]
}
// Forward computes x @ weight^T.
func (l *simpleLinear) Forward(x *mlx.Array) *mlx.Array {
	return x.Matmul(mlx.Transpose(l.weight, 1, 0))
}
// OutputDim returns the number of output features (rows of weight).
func (l *simpleLinear) OutputDim() int32 {
	return int32(l.weight.Dims()[0])
}

View File

@@ -0,0 +1,503 @@
package gemma4
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestParseTextConfigE2B(t *testing.T) {
skipIfNoMLX(t)
data := []byte(`{
"architectures": ["Gemma4ForConditionalGeneration"],
"text_config": {
"hidden_size": 1536,
"num_hidden_layers": 35,
"intermediate_size": 6144,
"num_attention_heads": 8,
"num_key_value_heads": 1,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 131072,
"sliding_window": 512,
"sliding_window_pattern": 5,
"final_logit_softcapping": 30.0,
"use_double_wide_mlp": true,
"num_kv_shared_layers": 20,
"hidden_size_per_layer_input": 256,
"vocab_size_per_layer_input": 262144,
"attention_k_eq_v": false,
"tie_word_embeddings": true,
"layer_types": [
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
}`)
cfg, err := parseTextConfig(data)
if err != nil {
t.Fatalf("parseTextConfig failed: %v", err)
}
// Basic fields.
if cfg.HiddenSize != 1536 {
t.Errorf("HiddenSize = %d, want 1536", cfg.HiddenSize)
}
if cfg.NumHiddenLayers != 35 {
t.Errorf("NumHiddenLayers = %d, want 35", cfg.NumHiddenLayers)
}
if cfg.GlobalHeadDim != 512 {
t.Errorf("GlobalHeadDim = %d, want 512", cfg.GlobalHeadDim)
}
if cfg.FinalLogitSoftcapping != 30.0 {
t.Errorf("FinalLogitSoftcapping = %f, want 30.0", cfg.FinalLogitSoftcapping)
}
if cfg.NumKVSharedLayers != 20 {
t.Errorf("NumKVSharedLayers = %d, want 20", cfg.NumKVSharedLayers)
}
if cfg.HiddenSizePerLayer != 256 {
t.Errorf("HiddenSizePerLayer = %d, want 256", cfg.HiddenSizePerLayer)
}
// RoPE settings.
if cfg.SlidingRopeDims != 256 {
t.Errorf("SlidingRopeDims = %d, want 256", cfg.SlidingRopeDims)
}
if cfg.FullRopeDims != 512 {
t.Errorf("FullRopeDims = %d, want 512 (GlobalHeadDim, partial rotation handled via custom freqs)", cfg.FullRopeDims)
}
if cfg.SlidingRopeBase != 10000 {
t.Errorf("SlidingRopeBase = %f, want 10000", cfg.SlidingRopeBase)
}
if cfg.FullRopeBase != 1000000 {
t.Errorf("FullRopeBase = %f, want 1000000", cfg.FullRopeBase)
}
// Attention scale.
if cfg.SlidingScale == 0 || cfg.FullScale == 0 {
t.Error("attention scales should be non-zero")
}
// KV sharing map.
// First shared layer is 35 - 20 = 15.
if donor, ok := cfg.KVShareMap[15]; !ok || donor != 13 {
t.Errorf("KVShareMap[15] = %d, ok=%v; want 13, true", donor, ok)
}
if donor, ok := cfg.KVShareMap[19]; !ok || donor != 14 {
t.Errorf("KVShareMap[19] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
}
if donor, ok := cfg.KVShareMap[34]; !ok || donor != 14 {
t.Errorf("KVShareMap[34] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
}
// Layer 14 should not be shared.
if _, ok := cfg.KVShareMap[14]; ok {
t.Error("layer 14 should not be in KVShareMap (non-shared)")
}
// Donors.
if !cfg.KVDonors[13] {
t.Error("layer 13 should be a KV donor")
}
if !cfg.KVDonors[14] {
t.Error("layer 14 should be a KV donor")
}
}
func TestParseTextConfig26B(t *testing.T) {
skipIfNoMLX(t)
data := []byte(`{
"architectures": ["Gemma4ForConditionalGeneration"],
"text_config": {
"hidden_size": 2816,
"num_hidden_layers": 30,
"intermediate_size": 2112,
"num_attention_heads": 16,
"num_key_value_heads": 8,
"num_global_key_value_heads": 2,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 131072,
"sliding_window": 1024,
"final_logit_softcapping": 30.0,
"use_double_wide_mlp": false,
"num_kv_shared_layers": 0,
"hidden_size_per_layer_input": null,
"attention_k_eq_v": true,
"enable_moe_block": true,
"num_experts": 128,
"top_k_experts": 8,
"moe_intermediate_size": 704,
"tie_word_embeddings": true,
"layer_types": [
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
}`)
cfg, err := parseTextConfig(data)
if err != nil {
t.Fatalf("parseTextConfig failed: %v", err)
}
if cfg.HiddenSize != 2816 {
t.Errorf("HiddenSize = %d, want 2816", cfg.HiddenSize)
}
if !cfg.AttentionKEqV {
t.Error("AttentionKEqV should be true")
}
if cfg.NumGlobalKeyValueHeads != 2 {
t.Errorf("NumGlobalKeyValueHeads = %d, want 2", cfg.NumGlobalKeyValueHeads)
}
if !cfg.EnableMoeBlock {
t.Error("EnableMoeBlock should be true")
}
if cfg.NumExperts != 128 {
t.Errorf("NumExperts = %d, want 128", cfg.NumExperts)
}
if cfg.TopKExperts != 8 {
t.Errorf("TopKExperts = %d, want 8", cfg.TopKExperts)
}
if cfg.ExpertIntermediateSize != 704 {
t.Errorf("ExpertIntermediateSize = %d, want 704", cfg.ExpertIntermediateSize)
}
if cfg.HiddenSizePerLayer != 0 {
t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
}
}
func TestParseTextConfig31B(t *testing.T) {
skipIfNoMLX(t)
data := []byte(`{
"architectures": ["Gemma4ForConditionalGeneration"],
"text_config": {
"hidden_size": 5376,
"num_hidden_layers": 60,
"intermediate_size": 21504,
"num_attention_heads": 32,
"num_key_value_heads": 16,
"num_global_key_value_heads": 4,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 131072,
"sliding_window": 1024,
"final_logit_softcapping": 30.0,
"use_double_wide_mlp": false,
"num_kv_shared_layers": 0,
"hidden_size_per_layer_input": null,
"attention_k_eq_v": true,
"tie_word_embeddings": true,
"layer_types": [
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
}`)
cfg, err := parseTextConfig(data)
if err != nil {
t.Fatalf("parseTextConfig failed: %v", err)
}
if cfg.HiddenSize != 5376 {
t.Errorf("HiddenSize = %d, want 5376", cfg.HiddenSize)
}
if cfg.NumHiddenLayers != 60 {
t.Errorf("NumHiddenLayers = %d, want 60", cfg.NumHiddenLayers)
}
if !cfg.AttentionKEqV {
t.Error("AttentionKEqV should be true")
}
if cfg.NumGlobalKeyValueHeads != 4 {
t.Errorf("NumGlobalKeyValueHeads = %d, want 4", cfg.NumGlobalKeyValueHeads)
}
if cfg.NumKeyValueHeads != 16 {
t.Errorf("NumKeyValueHeads = %d, want 16", cfg.NumKeyValueHeads)
}
if cfg.NumKVSharedLayers != 0 {
t.Errorf("NumKVSharedLayers = %d, want 0", cfg.NumKVSharedLayers)
}
if cfg.HiddenSizePerLayer != 0 {
t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
}
if cfg.SlidingWindow != 1024 {
t.Errorf("SlidingWindow = %d, want 1024", cfg.SlidingWindow)
}
// KV sharing should be empty (no shared layers).
if len(cfg.KVShareMap) != 0 {
t.Errorf("KVShareMap should be empty, got %d entries", len(cfg.KVShareMap))
}
// Layer types: pattern is 5 sliding + 1 full, repeating 10 times.
if !isLayerSliding(0, &cfg) {
t.Error("layer 0 should be sliding")
}
if isLayerSliding(5, &cfg) {
t.Error("layer 5 should be full attention")
}
if !isLayerSliding(6, &cfg) {
t.Error("layer 6 should be sliding")
}
if isLayerSliding(59, &cfg) {
t.Error("layer 59 should be full attention")
}
}
func TestParseTextConfigE4B(t *testing.T) {
skipIfNoMLX(t)
data := []byte(`{
"architectures": ["Gemma4ForConditionalGeneration"],
"text_config": {
"hidden_size": 2560,
"num_hidden_layers": 42,
"intermediate_size": 10240,
"num_attention_heads": 8,
"num_key_value_heads": 2,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 131072,
"sliding_window": 512,
"final_logit_softcapping": 30.0,
"use_double_wide_mlp": false,
"num_kv_shared_layers": 18,
"hidden_size_per_layer_input": 256,
"vocab_size_per_layer_input": 262144,
"attention_k_eq_v": false,
"enable_moe_block": false,
"tie_word_embeddings": true,
"layer_types": [
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
}`)
cfg, err := parseTextConfig(data)
if err != nil {
t.Fatalf("parseTextConfig failed: %v", err)
}
if cfg.HiddenSize != 2560 {
t.Errorf("HiddenSize = %d, want 2560", cfg.HiddenSize)
}
if cfg.NumHiddenLayers != 42 {
t.Errorf("NumHiddenLayers = %d, want 42", cfg.NumHiddenLayers)
}
if cfg.IntermediateSize != 10240 {
t.Errorf("IntermediateSize = %d, want 10240", cfg.IntermediateSize)
}
if cfg.NumKeyValueHeads != 2 {
t.Errorf("NumKeyValueHeads = %d, want 2", cfg.NumKeyValueHeads)
}
if cfg.UseDoubleWideMLP {
t.Error("UseDoubleWideMLP should be false")
}
if cfg.NumKVSharedLayers != 18 {
t.Errorf("NumKVSharedLayers = %d, want 18", cfg.NumKVSharedLayers)
}
if cfg.HiddenSizePerLayer != 256 {
t.Errorf("HiddenSizePerLayer = %d, want 256 (has PLE)", cfg.HiddenSizePerLayer)
}
if cfg.AttentionKEqV {
t.Error("AttentionKEqV should be false")
}
if cfg.EnableMoeBlock {
t.Error("EnableMoeBlock should be false")
}
if cfg.SlidingWindow != 512 {
t.Errorf("SlidingWindow = %d, want 512", cfg.SlidingWindow)
}
// Layer types: pattern is 5 sliding + 1 full, repeating 7 times = 42 layers.
if !isLayerSliding(0, &cfg) {
t.Error("layer 0 should be sliding")
}
if isLayerSliding(5, &cfg) {
t.Error("layer 5 should be full attention")
}
if !isLayerSliding(6, &cfg) {
t.Error("layer 6 should be sliding")
}
if isLayerSliding(41, &cfg) {
t.Error("layer 41 should be full attention")
}
// KV sharing: first shared = 42 - 18 = 24.
// Layer 24 is sliding, its donor should be the last non-shared sliding layer.
// Non-shared layers: 0-23. Last sliding in 0-23 is layer 22 (23=full).
if donor, ok := cfg.KVShareMap[24]; !ok {
t.Error("layer 24 should be in KVShareMap")
} else {
t.Logf("layer 24 donor = %d", donor)
}
// Layer 29 is full_attention (5th full), donor should be the last non-shared full layer.
// Non-shared full layers: 5, 11, 17, 23.
if donor, ok := cfg.KVShareMap[29]; !ok || donor != 23 {
t.Errorf("KVShareMap[29] = %d, ok=%v; want 23, true (full attn donor)", donor, ok)
}
// Layer 23 should NOT be shared (it's the last non-shared layer).
if _, ok := cfg.KVShareMap[23]; ok {
t.Error("layer 23 should not be in KVShareMap (non-shared)")
}
}
// TestLayerTypeDetection checks isLayerSliding against an explicit
// layer_types list; pure config logic, so no MLX runtime is needed.
func TestLayerTypeDetection(t *testing.T) {
	cfg := &TextConfig{
		LayerTypes: []string{
			"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
		},
	}
	if !isLayerSliding(0, cfg) {
		t.Error("layer 0 should be sliding")
	}
	if !isLayerSliding(3, cfg) {
		t.Error("layer 3 should be sliding")
	}
	if isLayerSliding(4, cfg) {
		t.Error("layer 4 should be full attention")
	}
}
// TestNewCachesOmitsSharedKVLayers verifies that layers with a KV-share
// donor (KVShareDonor >= 0) do not allocate their own cache: of the four
// layers below only the two with KVShareDonor == -1 should get one.
func TestNewCachesOmitsSharedKVLayers(t *testing.T) {
	m := &Model{
		Layers: []*DecoderLayer{
			{IsSliding: true, KVShareDonor: -1},
			{IsSliding: false, KVShareDonor: -1},
			{IsSliding: true, KVShareDonor: 0},  // reuses layer 0's cache
			{IsSliding: false, KVShareDonor: 1}, // reuses layer 1's cache
		},
		TextConfig: &TextConfig{SlidingWindow: 512},
	}
	caches := m.NewCaches()
	if got, want := len(caches), 2; got != want {
		t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
	}
}
// TestNewCachesIncludesAllNonSharedLayers verifies the complement case: when
// no layer has a donor (all KVShareDonor == -1), every layer gets its own cache.
func TestNewCachesIncludesAllNonSharedLayers(t *testing.T) {
	m := &Model{
		Layers: []*DecoderLayer{
			{IsSliding: true, KVShareDonor: -1},
			{IsSliding: false, KVShareDonor: -1},
			{IsSliding: true, KVShareDonor: -1},
		},
		TextConfig: &TextConfig{SlidingWindow: 512},
	}
	caches := m.NewCaches()
	if got, want := len(caches), len(m.Layers); got != want {
		t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
	}
}
// TestResolveWeightPrefix verifies prefix detection for the three tensor-name
// layouts: bare names, "model."-prefixed, and "model.language_model."-prefixed.
func TestResolveWeightPrefix(t *testing.T) {
	// Use the shared helper, consistent with the other tests in this file.
	skipIfNoMLX(t)
	tests := []struct {
		name    string
		key     string
		wantPfx string
	}{
		{"bare", "embed_tokens.weight", ""},
		{"language_model", "model.language_model.embed_tokens.weight", "model.language_model."},
		{"with_model", "model.embed_tokens.weight", "model."},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			dummy := mlx.FromValue(float32(1.0))
			mlx.Eval(dummy)
			tensors := map[string]*mlx.Array{tt.key: dummy}
			if got := resolveWeightPrefix(tensors); got != tt.wantPfx {
				t.Errorf("resolveWeightPrefix(%q) = %q, want %q", tt.key, got, tt.wantPfx)
			}
		})
	}
}
// skipIfNoMLX skips the calling test when the MLX runtime fails to
// initialize (mlx.CheckInit returns an error).
func skipIfNoMLX(t *testing.T) {
	t.Helper()
	if err := mlx.CheckInit(); err != nil {
		t.Skipf("MLX not available: %v", err)
	}
}