mirror of
https://github.com/ollama/ollama.git
synced 2026-04-21 16:25:42 +02:00
Compare commits
16 Commits
v0.20.7-rc
...
hoyyeva/op
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a3ed0a1b4 | ||
|
|
03f9e57274 | ||
|
|
30d9100fff | ||
|
|
698e04a14b | ||
|
|
1d9537bc33 | ||
|
|
120424d832 | ||
|
|
5818001610 | ||
|
|
2cba7756c5 | ||
|
|
bf2a421727 | ||
|
|
f3cf6b75fb | ||
|
|
5dfac387a6 | ||
|
|
a99e5d9c22 | ||
|
|
0abf3aca36 | ||
|
|
ee0266462a | ||
|
|
c88fb286ec | ||
|
|
d3da29cbfc |
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/config"
|
||||
)
|
||||
|
||||
type stubEditorRunner struct {
|
||||
@@ -722,6 +723,59 @@ func TestLauncherClientFilterDisabledCloudModels_ChecksStatusOncePerInvocation(t
|
||||
}
|
||||
}
|
||||
|
||||
func TestSavedMatchesModels(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
saved *config.IntegrationConfig
|
||||
models []string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil saved",
|
||||
saved: nil,
|
||||
models: []string{"llama3.2"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "identical order",
|
||||
saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}},
|
||||
models: []string{"llama3.2", "qwen3:8b"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "different order",
|
||||
saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}},
|
||||
models: []string{"qwen3:8b", "llama3.2"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "subset",
|
||||
saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}},
|
||||
models: []string{"llama3.2"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil models in saved with non-nil models",
|
||||
saved: &config.IntegrationConfig{Models: nil},
|
||||
models: []string{"llama3.2"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty both",
|
||||
saved: &config.IntegrationConfig{Models: nil},
|
||||
models: nil,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := savedMatchesModels(tt.saved, tt.models); got != tt.want {
|
||||
t.Fatalf("savedMatchesModels = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareEditorIntegration_SavesOnlyAfterSuccessfulEdit(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
@@ -500,7 +501,7 @@ func (c *launcherClient) launchEditorIntegration(ctx context.Context, name strin
|
||||
return nil
|
||||
}
|
||||
|
||||
if needsConfigure || req.ModelOverride != "" {
|
||||
if (needsConfigure || req.ModelOverride != "") && !savedMatchesModels(saved, models) {
|
||||
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -846,6 +847,13 @@ func firstModel(models []string) string {
|
||||
return models[0]
|
||||
}
|
||||
|
||||
func savedMatchesModels(saved *config.IntegrationConfig, models []string) bool {
|
||||
if saved == nil {
|
||||
return false
|
||||
}
|
||||
return slices.Equal(saved.Models, models)
|
||||
}
|
||||
|
||||
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
|
||||
if override == "" {
|
||||
if saved == nil {
|
||||
|
||||
@@ -186,6 +186,11 @@ func (c *Openclaw) runChannelSetupPreflight(bin string) error {
|
||||
if !isInteractiveSession() {
|
||||
return nil
|
||||
}
|
||||
// --yes is headless; channel setup spawns an interactive picker we can't
|
||||
// auto-answer, so skip it. Users can run `openclaw channels add` later.
|
||||
if currentLaunchConfirmPolicy.yes {
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
if c.channelsConfigured() {
|
||||
|
||||
@@ -1304,6 +1304,46 @@ func TestOpenclawChannelSetupPreflight(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("--yes skips preflight without channels configured", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
t.Setenv("PATH", tmpDir)
|
||||
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Empty config = no channels configured. Without the --yes skip, the
|
||||
// preflight would prompt and (on confirm) spawn `openclaw channels add`.
|
||||
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
bin := filepath.Join(tmpDir, "openclaw")
|
||||
if err := os.WriteFile(bin, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \"$HOME/invocations.log\"\n"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
oldInteractive := isInteractiveSession
|
||||
isInteractiveSession = func() bool { return true }
|
||||
defer func() { isInteractiveSession = oldInteractive }()
|
||||
|
||||
restore := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
|
||||
defer restore()
|
||||
|
||||
oldConfirmPrompt := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
t.Fatalf("did not expect prompt in --yes mode: %s", prompt)
|
||||
return false, nil
|
||||
}
|
||||
defer func() { DefaultConfirmPrompt = oldConfirmPrompt }()
|
||||
|
||||
if err := c.runChannelSetupPreflight("openclaw"); err != nil {
|
||||
t.Fatalf("runChannelSetupPreflight() error = %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(tmpDir, "invocations.log")); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected no channels add invocation in --yes mode, got err=%v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("set up later prompts once and exits", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -11,12 +12,18 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
modeltype "github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// OpenCode implements Runner and Editor for OpenCode integration
|
||||
type OpenCode struct{}
|
||||
// OpenCode implements Runner and Editor for OpenCode integration.
|
||||
// Config is passed via OPENCODE_CONFIG_CONTENT env var at launch time
|
||||
// instead of writing to opencode's config files.
|
||||
type OpenCode struct {
|
||||
configContent string // JSON config built by Edit, passed to Run via env var
|
||||
}
|
||||
|
||||
func (o *OpenCode) String() string { return "OpenCode" }
|
||||
|
||||
@@ -51,25 +58,51 @@ func (o *OpenCode) Run(model string, args []string) error {
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Env = os.Environ()
|
||||
if content := o.resolveContent(model); content != "" {
|
||||
cmd.Env = append(cmd.Env, "OPENCODE_CONFIG_CONTENT="+content)
|
||||
}
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// resolveContent returns the inline config to send via OPENCODE_CONFIG_CONTENT.
|
||||
// Returns content built by Edit if available, otherwise builds from model.json
|
||||
// with the requested model as primary (e.g. re-launch with saved config).
|
||||
func (o *OpenCode) resolveContent(model string) string {
|
||||
if o.configContent != "" {
|
||||
return o.configContent
|
||||
}
|
||||
models := readModelJSONModels()
|
||||
if !slices.Contains(models, model) {
|
||||
models = append([]string{model}, models...)
|
||||
}
|
||||
content, err := buildInlineConfig(model, models)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func (o *OpenCode) Paths() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
sp, err := openCodeStatePath()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var paths []string
|
||||
p := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
paths = append(paths, p)
|
||||
}
|
||||
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if _, err := os.Stat(sp); err == nil {
|
||||
paths = append(paths, sp)
|
||||
return []string{sp}
|
||||
}
|
||||
return paths
|
||||
return nil
|
||||
}
|
||||
|
||||
// openCodeStatePath returns the path to opencode's model state file.
|
||||
// TODO: this hardcodes the Linux/macOS XDG path. On Windows, opencode stores
|
||||
// state under %LOCALAPPDATA% (or similar) — verify and branch on runtime.GOOS.
|
||||
func openCodeStatePath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".local", "state", "opencode", "model.json"), nil
|
||||
}
|
||||
|
||||
func (o *OpenCode) Edit(modelList []string) error {
|
||||
@@ -77,110 +110,17 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
content, err := buildInlineConfig(modelList[0], modelList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.configContent = content
|
||||
|
||||
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config := make(map[string]any)
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
|
||||
}
|
||||
|
||||
config["$schema"] = "https://opencode.ai/config.json"
|
||||
|
||||
provider, ok := config["provider"].(map[string]any)
|
||||
if !ok {
|
||||
provider = make(map[string]any)
|
||||
}
|
||||
|
||||
ollama, ok := provider["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
ollama = map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": map[string]any{
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Migrate legacy provider name
|
||||
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
|
||||
ollama["name"] = "Ollama"
|
||||
}
|
||||
|
||||
models, ok := ollama["models"].(map[string]any)
|
||||
if !ok {
|
||||
models = make(map[string]any)
|
||||
}
|
||||
|
||||
selectedSet := make(map[string]bool)
|
||||
for _, m := range modelList {
|
||||
selectedSet[m] = true
|
||||
}
|
||||
|
||||
for name, cfg := range models {
|
||||
if cfgMap, ok := cfg.(map[string]any); ok {
|
||||
if isOllamaModel(cfgMap) && !selectedSet[name] {
|
||||
delete(models, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, model := range modelList {
|
||||
if existing, ok := models[model].(map[string]any); ok {
|
||||
// migrate existing models without _launch marker
|
||||
if isOllamaModel(existing) {
|
||||
existing["_launch"] = true
|
||||
if name, ok := existing["name"].(string); ok {
|
||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
||||
}
|
||||
}
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
existing["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
entry := map[string]any{
|
||||
"name": model,
|
||||
"_launch": true,
|
||||
}
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
models[model] = entry
|
||||
}
|
||||
|
||||
ollama["models"] = models
|
||||
provider["ollama"] = ollama
|
||||
config["provider"] = provider
|
||||
config["model"] = "ollama/" + modelList[0]
|
||||
|
||||
configData, err := json.MarshalIndent(config, "", " ")
|
||||
// Write model state file so models appear in OpenCode's model picker
|
||||
statePath, err := openCodeStatePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -232,33 +172,127 @@ func (o *OpenCode) Edit(modelList []string) error {
|
||||
}
|
||||
|
||||
func (o *OpenCode) Models() []string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
config, err := fileutil.ReadJSON(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
provider, _ := config["provider"].(map[string]any)
|
||||
ollama, _ := provider["ollama"].(map[string]any)
|
||||
models, _ := ollama["models"].(map[string]any)
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := slices.Collect(maps.Keys(models))
|
||||
slices.Sort(keys)
|
||||
return keys
|
||||
return nil
|
||||
}
|
||||
|
||||
// isOllamaModel reports whether a model config entry is managed by us
|
||||
func isOllamaModel(cfg map[string]any) bool {
|
||||
if v, ok := cfg["_launch"].(bool); ok && v {
|
||||
return true
|
||||
// buildInlineConfig produces the JSON string for OPENCODE_CONFIG_CONTENT.
|
||||
// primary is the model to launch with, models is the full list of available models.
|
||||
func buildInlineConfig(primary string, models []string) (string, error) {
|
||||
if primary == "" || len(models) == 0 {
|
||||
return "", fmt.Errorf("buildInlineConfig: primary and models are required")
|
||||
}
|
||||
// previously used [Ollama] as a suffix for the model managed by ollama launch
|
||||
if name, ok := cfg["name"].(string); ok {
|
||||
return strings.HasSuffix(name, "[Ollama]")
|
||||
config := map[string]any{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": map[string]any{
|
||||
"ollama": map[string]any{
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": map[string]any{
|
||||
"baseURL": envconfig.Host().String() + "/v1",
|
||||
},
|
||||
"models": buildModelEntries(models),
|
||||
},
|
||||
},
|
||||
"model": "ollama/" + primary,
|
||||
}
|
||||
data, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// readModelJSONModels reads ollama model IDs from the opencode model.json state file
|
||||
func readModelJSONModels() []string {
|
||||
statePath, err := openCodeStatePath()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
data, err := os.ReadFile(statePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
return nil
|
||||
}
|
||||
recent, _ := state["recent"].([]any)
|
||||
var models []string
|
||||
for _, entry := range recent {
|
||||
e, ok := entry.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if e["providerID"] != "ollama" {
|
||||
continue
|
||||
}
|
||||
if id, ok := e["modelID"].(string); ok && id != "" {
|
||||
models = append(models, id)
|
||||
}
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
func buildModelEntries(modelList []string) map[string]any {
|
||||
client := api.NewClient(envconfig.Host(), http.DefaultClient)
|
||||
ctx := context.Background()
|
||||
|
||||
models := make(map[string]any)
|
||||
for _, model := range modelList {
|
||||
entry := map[string]any{
|
||||
"name": model,
|
||||
}
|
||||
if isCloudModelName(model) {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
entry["limit"] = map[string]any{
|
||||
"context": l.Context,
|
||||
"output": l.Output,
|
||||
}
|
||||
}
|
||||
}
|
||||
applyOpenCodeReasoning(ctx, client, model, entry)
|
||||
models[model] = entry
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
// applyOpenCodeReasoning detects thinking capability and sets reasoning config
|
||||
// on the model entry. When the model supports thinking, it sets "reasoning": true
|
||||
// and configures variants for the OpenCode TUI:
|
||||
// - GPT-OSS: supports variable effort levels (low/medium/high) and defaults to
|
||||
// medium via options. Thinking cannot be turned off.
|
||||
// - Other models: only support on/off. Disables built-in low/medium/high variants
|
||||
// and adds a "none" variant so users can toggle thinking off via Ctrl+T.
|
||||
//
|
||||
// When the model does not support thinking, no reasoning config is set.
|
||||
func applyOpenCodeReasoning(ctx context.Context, client *api.Client, modelName string, entry map[string]any) {
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if slices.Contains(resp.Capabilities, modeltype.CapabilityThinking) {
|
||||
entry["reasoning"] = true
|
||||
|
||||
if strings.Contains(modelName, "gpt-oss") {
|
||||
// GPT-OSS models support variable thinking effort levels
|
||||
// and cannot turn thinking off. Keep the built-in
|
||||
// low/medium/high variants as-is and default to medium.
|
||||
options, ok := entry["options"].(map[string]any)
|
||||
if !ok {
|
||||
options = make(map[string]any)
|
||||
}
|
||||
options["reasoningEffort"] = "medium"
|
||||
entry["options"] = options
|
||||
} else {
|
||||
// Most models only support thinking on or off.
|
||||
// Disable the built-in low/medium/high variants and add none.
|
||||
entry["variants"] = map[string]any{
|
||||
"none": map[string]any{"reasoningEffort": "none"},
|
||||
"low": map[string]any{"disabled": true},
|
||||
"medium": map[string]any{"disabled": true},
|
||||
"high": map[string]any{"disabled": true},
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,7 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
||||
binary: "opencode",
|
||||
runner: &OpenCode{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".config", "opencode", "opencode.json")
|
||||
return filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -28,79 +28,4 @@ To configure without launching:
|
||||
ollama launch opencode --config
|
||||
```
|
||||
|
||||
### Manual setup
|
||||
|
||||
Add a configuration block to `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": {
|
||||
"baseURL": "http://localhost:11434/v1"
|
||||
},
|
||||
"models": {
|
||||
"qwen3-coder": {
|
||||
"name": "qwen3-coder"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Cloud Models
|
||||
|
||||
`glm-4.7:cloud` is the recommended model for use with OpenCode.
|
||||
|
||||
Add the cloud configuration to `~/.config/opencode/opencode.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama",
|
||||
"options": {
|
||||
"baseURL": "http://localhost:11434/v1"
|
||||
},
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Connecting to ollama.com
|
||||
|
||||
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
|
||||
2. Update `~/.config/opencode/opencode.json` to point to ollama.com:
|
||||
|
||||
```json
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"provider": {
|
||||
"ollama": {
|
||||
"npm": "@ai-sdk/openai-compatible",
|
||||
"name": "Ollama Cloud",
|
||||
"options": {
|
||||
"baseURL": "https://ollama.com/v1"
|
||||
},
|
||||
"models": {
|
||||
"glm-4.7:cloud": {
|
||||
"name": "glm-4.7:cloud"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Run `opencode` in a new terminal to load the new settings.
|
||||
<Note>`ollama launch opencode` passes its configuration to OpenCode inline via the `OPENCODE_CONFIG_CONTENT` environment variable. OpenCode deep-merges its config sources on startup, so anything you declare in `~/.config/opencode/opencode.json` is still respected and available inside OpenCode. Models declared only in `opencode.json` won't appear in `ollama launch`'s model-selection menu.</Note>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package common
|
||||
|
||||
// #cgo CXXFLAGS: -std=c++17
|
||||
// #cgo CXXFLAGS: -std=c++17 -Wno-deprecated-declarations
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/../include -I${SRCDIR}/../vendor
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/../../../ml/backend/ggml/ggml/include
|
||||
import "C"
|
||||
|
||||
@@ -40,7 +40,6 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
||||
|
||||
// Emit system turn if there's a system/developer role, tools, or thinking.
|
||||
hasThink := thinkValue != nil && thinkValue.Bool()
|
||||
thinkingExplicitlyDisabled := thinkValue != nil && thinkValue.IsBool() && !thinkValue.Bool()
|
||||
if hasSystemRole || len(tools) > 0 || hasThink {
|
||||
sb.WriteString("<|turn>system\n")
|
||||
if hasThink {
|
||||
@@ -125,9 +124,6 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
||||
// Generation prompt.
|
||||
if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
|
||||
sb.WriteString("<|turn>model\n")
|
||||
if !hasThink && !thinkingExplicitlyDisabled {
|
||||
sb.WriteString("<|channel>thought\n<channel|>")
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
package renderers
|
||||
|
||||
// TestGemma4RendererMatchesReference verifies our renderer matches the HF
|
||||
// Jinja2 chat template exactly.
|
||||
// TestGemma4RendererMatchesReference verifies our renderer matches the checked-in
|
||||
// Gemma 4 reference template.
|
||||
//
|
||||
// To regenerate expected values, save gemma4Jinja2Template (below) to
|
||||
// gemma4_chat_template.jinja2 and run:
|
||||
// Current upstream Gemma 4 chat templates differ by model size, so the checked-in
|
||||
// reference intentionally uses the shared baseline without an empty generation-time
|
||||
// thought channel until renderer selection is split by size.
|
||||
//
|
||||
// To regenerate expected values, save the E2B template to
|
||||
// gemma4_e2b_chat_template.jinja2 and run:
|
||||
//
|
||||
// python3 -c "
|
||||
// from jinja2 import Environment; import json
|
||||
// tmpl = Environment().from_string(open('gemma4_chat_template.jinja2').read())
|
||||
// tmpl = Environment().from_string(open('gemma4_e2b_chat_template.jinja2').read())
|
||||
// msgs = [{'role':'user','content':'Hello'}]
|
||||
// print(repr(tmpl.render(messages=msgs, bos_token='<bos>', add_generation_prompt=True)))
|
||||
// "
|
||||
@@ -26,8 +30,13 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// The full Jinja2 template is committed as testdata/gemma4_chat_template.jinja2.
|
||||
// Run with VERIFY_JINJA2=1 to verify expected values against the template using uv + Python.
|
||||
const (
|
||||
gemma4E2BTemplate = "testdata/gemma4_e2b_chat_template.jinja2"
|
||||
gemma431BTemplate = "testdata/gemma4_31b_chat_template.jinja2"
|
||||
)
|
||||
|
||||
// The upstream Gemma 4 chat templates are committed by size under testdata/.
|
||||
// Run with VERIFY_JINJA2=1 to verify expected values against the E2B template using uv + Python.
|
||||
|
||||
func bashRefTool() []api.Tool {
|
||||
return []api.Tool{{
|
||||
@@ -665,7 +674,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
{
|
||||
name: "user_only",
|
||||
messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "system_user",
|
||||
@@ -673,7 +682,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "developer_user",
|
||||
@@ -681,13 +690,13 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
{Role: "developer", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "tools_no_system",
|
||||
messages: []api.Message{{Role: "user", Content: "Hi"}},
|
||||
tools: bashRefTool(),
|
||||
expected: "<bos><|turn>system\n" + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>system\n" + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "system_tools",
|
||||
@@ -696,7 +705,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
tools: bashRefTool(),
|
||||
expected: "<bos><|turn>system\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>system\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "thinking_no_system",
|
||||
@@ -704,13 +713,6 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
think: thinkTrue(),
|
||||
expected: "<bos><|turn>system\n<|think|>\n<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "nothink_no_system",
|
||||
messages: []api.Message{{Role: "user", Content: "Hi"}},
|
||||
think: thinkFalse(),
|
||||
expected: "<bos><|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
skipJinja2: true,
|
||||
},
|
||||
{
|
||||
name: "thinking_system",
|
||||
messages: []api.Message{
|
||||
@@ -737,6 +739,12 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
think: thinkTrue(),
|
||||
expected: "<bos><|turn>system\n<|think|>\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "thinking_explicitly_disabled",
|
||||
messages: []api.Message{{Role: "user", Content: "Hi"}},
|
||||
think: thinkFalse(),
|
||||
expected: "<bos><|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
|
||||
// === Message loop paths ===
|
||||
{
|
||||
@@ -751,7 +759,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
"<|turn>user\nHi<turn|>\n" +
|
||||
"<|turn>model\nHello!<turn|>\n" +
|
||||
"<|turn>user\nMore<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Tool call with structured args → tool response as separate <|turn>tool turn
|
||||
@@ -813,7 +821,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
"<|tool_response>response:bash{value:" + q + "file1.txt\nfile2.txt" + q + "}<tool_response|>" +
|
||||
"Here are the files.<turn|>\n" +
|
||||
"<|turn>user\nRead file1.txt<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Multiple tool calls + multiple tool responses
|
||||
@@ -848,7 +856,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
expected: "<bos><|turn>user\nWhat is 2+2?<turn|>\n" +
|
||||
"<|turn>model\n4<turn|>\n" +
|
||||
"<|turn>user\nAnd 3+3?<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
// === Additional edge cases ported from original tests ===
|
||||
{
|
||||
@@ -906,17 +914,17 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
messages: []api.Message{{Role: "user", Content: "Test"}},
|
||||
tools: modeTool(),
|
||||
expected: "<bos><|turn>system\n" + modeDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nTest<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nTest<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "unicode_content",
|
||||
messages: []api.Message{{Role: "user", Content: "こんにちは"}},
|
||||
expected: "<bos><|turn>user\nこんにちは<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>user\nこんにちは<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
name: "newlines_in_content",
|
||||
messages: []api.Message{{Role: "user", Content: "Line 1\nLine 2\nLine 3"}},
|
||||
expected: "<bos><|turn>user\nLine 1\nLine 2\nLine 3<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>user\nLine 1\nLine 2\nLine 3<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Tool response (raw JSON) followed by user message
|
||||
@@ -935,7 +943,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
"<|turn>model\n<|tool_call>call:get_weather{city:" + q + "Tokyo" + q + "}<tool_call|>" +
|
||||
"<|tool_response>response:get_weather{value:" + q + `{"temperature": 15, "weather": "sunny"}` + q + "}<tool_response|>" +
|
||||
"<|turn>user\nThanks!<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
// === Ordering and whitespace edge cases ===
|
||||
{
|
||||
@@ -958,7 +966,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
// User content with whitespace is trimmed
|
||||
name: "user_content_trimmed",
|
||||
messages: []api.Message{{Role: "user", Content: " hello "}},
|
||||
expected: "<bos><|turn>user\nhello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: "<bos><|turn>user\nhello<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Empty tool call arguments
|
||||
@@ -982,7 +990,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
messages: []api.Message{{Role: "user", Content: "Create"}},
|
||||
tools: nestedTool(),
|
||||
expected: "<bos><|turn>system\n" + nestedDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nCreate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nCreate<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Array type in tool declaration
|
||||
@@ -990,7 +998,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
messages: []api.Message{{Role: "user", Content: "Batch"}},
|
||||
tools: arrayTool(),
|
||||
expected: "<bos><|turn>system\n" + arrayDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nBatch<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nBatch<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Top-level typed union follows the template's odd stringified-list form.
|
||||
@@ -1002,8 +1010,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
||||
<|turn>user
|
||||
Hi<turn|>
|
||||
<|turn>model
|
||||
<|channel>thought
|
||||
<channel|>`,
|
||||
`,
|
||||
},
|
||||
{
|
||||
// Assistant whitespace is trimmed (strip_thinking includes | trim)
|
||||
@@ -1016,7 +1023,7 @@ Hi<turn|>
|
||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||
"<|turn>model\nspaced<turn|>\n" +
|
||||
"<|turn>user\nMore<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Three sequential tool responses
|
||||
@@ -1071,7 +1078,7 @@ Hi<turn|>
|
||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||
"<|turn>model\nMiddleDone<turn|>\n" +
|
||||
"<|turn>user\nMore<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Property with no description — just type
|
||||
@@ -1079,7 +1086,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Count"}},
|
||||
tools: countTool(),
|
||||
expected: "<bos><|turn>system\n" + countDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nCount<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nCount<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// System message with leading/trailing whitespace is trimmed
|
||||
@@ -1089,7 +1096,7 @@ Hi<turn|>
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n" +
|
||||
"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Deeply nested map in tool call arguments (3 levels)
|
||||
@@ -1151,7 +1158,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Set"}},
|
||||
tools: enumNoDescTool(),
|
||||
expected: "<bos><|turn>system\n" + enumNoDescDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nSet<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nSet<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// System message that is only whitespace (trims to empty)
|
||||
@@ -1161,7 +1168,7 @@ Hi<turn|>
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
expected: "<bos><|turn>system\n<turn|>\n" +
|
||||
"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Empty assistant content (empty string, not nil)
|
||||
@@ -1174,7 +1181,7 @@ Hi<turn|>
|
||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||
"<|turn>model\n<turn|>\n" +
|
||||
"<|turn>user\nMore<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Map argument with string keys (keys NOT escaped with <|"|>)
|
||||
@@ -1200,7 +1207,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Search"}},
|
||||
tools: searchTool(),
|
||||
expected: "<bos><|turn>system\n" + searchDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nSearch<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nSearch<turn|>\n<|turn>model\n",
|
||||
},
|
||||
|
||||
// === Round 3 coverage gaps ===
|
||||
@@ -1228,7 +1235,7 @@ Hi<turn|>
|
||||
{Role: "user", Content: "Hi"},
|
||||
},
|
||||
expected: "<bos><|turn>system\n<turn|>\n" +
|
||||
"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Nested OBJECT property with required field
|
||||
@@ -1236,7 +1243,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Create"}},
|
||||
tools: nestedRequiredTool(),
|
||||
expected: "<bos><|turn>system\n" + nestedRequiredDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nCreate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nCreate<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Non-integer float in tool call argument
|
||||
@@ -1263,7 +1270,7 @@ Hi<turn|>
|
||||
},
|
||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||
"<|turn>model\nResult<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Tool content with newlines and leading/trailing whitespace trimmed
|
||||
@@ -1287,7 +1294,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Raw"}},
|
||||
tools: rawTool(),
|
||||
expected: "<bos><|turn>system\n" + rawDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nRaw<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nRaw<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Multiple required fields at top level
|
||||
@@ -1295,7 +1302,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Move"}},
|
||||
tools: moveTool(),
|
||||
expected: "<bos><|turn>system\n" + moveDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nMove<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nMove<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Assistant content that is ONLY thinking (strips to empty)
|
||||
@@ -1308,7 +1315,7 @@ Hi<turn|>
|
||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||
"<|turn>model\n<turn|>\n" +
|
||||
"<|turn>user\nMore<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
|
||||
// === Round 4: final coverage gaps ===
|
||||
@@ -1341,7 +1348,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Tag"}},
|
||||
tools: arrayNoItemsTool(),
|
||||
expected: "<bos><|turn>system\n" + arrayNoItemsDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nTag<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nTag<turn|>\n<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// OBJECT property without description but with nested properties
|
||||
@@ -1349,7 +1356,7 @@ Hi<turn|>
|
||||
messages: []api.Message{{Role: "user", Content: "Update"}},
|
||||
tools: objectNoDescTool(),
|
||||
expected: "<bos><|turn>system\n" + objectNoDescDeclRef + "<turn|>\n" +
|
||||
"<|turn>user\nUpdate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>user\nUpdate<turn|>\n<|turn>model\n",
|
||||
},
|
||||
|
||||
// === Round 5: coding agent patterns ===
|
||||
@@ -1379,7 +1386,7 @@ Hi<turn|>
|
||||
"<|tool_response>response:bash{value:" + q + q + "}<tool_response|>" +
|
||||
"Done.<turn|>\n" +
|
||||
"<|turn>user\nThanks<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Tool call with thinking that strips to real remaining content
|
||||
@@ -1399,7 +1406,7 @@ Hi<turn|>
|
||||
"<|tool_response>response:bash{value:" + q + "main.go\ngo.mod" + q + "}<tool_response|>" +
|
||||
"Let me list the files.<turn|>\n" +
|
||||
"<|turn>user\nOK<turn|>\n" +
|
||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
||||
"<|turn>model\n",
|
||||
},
|
||||
{
|
||||
// Argument value containing newlines (multi-line script)
|
||||
@@ -1635,7 +1642,6 @@ func TestGemma4RendererKnownJinja2Differences(t *testing.T) {
|
||||
name string
|
||||
messages []api.Message
|
||||
tools []api.Tool
|
||||
think *api.ThinkValue
|
||||
wantJinjaFrag string
|
||||
wantRenderFrag string
|
||||
}{
|
||||
@@ -1684,22 +1690,15 @@ func TestGemma4RendererKnownJinja2Differences(t *testing.T) {
|
||||
wantJinjaFrag: `response:read{value:<|"|>payload<|"|>}`,
|
||||
wantRenderFrag: `response:unknown{value:<|"|>payload<|"|>}`,
|
||||
},
|
||||
{
|
||||
name: "explicit_nothink_skips_empty_thought_channel",
|
||||
messages: []api.Message{{Role: "user", Content: "Hi"}},
|
||||
think: thinkFalse(),
|
||||
wantJinjaFrag: "<|turn>model\n<|channel>thought\n<channel|>",
|
||||
wantRenderFrag: "<|turn>model\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
renderer := &Gemma4Renderer{useImgTags: RenderImgTags}
|
||||
got, err := renderer.Render(tt.messages, tt.tools, tt.think)
|
||||
got, err := renderer.Render(tt.messages, tt.tools, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, tt.think)
|
||||
jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, nil)
|
||||
assert.NotEqual(t, jinja2Output, got, "case no longer differs from Jinja2 output")
|
||||
assert.Contains(t, jinja2Output, tt.wantJinjaFrag)
|
||||
assert.Contains(t, got, tt.wantRenderFrag)
|
||||
@@ -1735,12 +1734,35 @@ func TestGemma4RendererToolResponseWithoutNameOrIDUsesUnknown(t *testing.T) {
|
||||
assert.NotContains(t, got, `response:read{value:<|"|>payload<|"|>}`)
|
||||
}
|
||||
|
||||
func TestGemma4SizeTemplateFixturesDifferAtGenerationPrompt(t *testing.T) {
|
||||
e2b, err := os.ReadFile(gemma4E2BTemplate)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read %s: %v", gemma4E2BTemplate, err)
|
||||
}
|
||||
|
||||
thirtyOneB, err := os.ReadFile(gemma431BTemplate)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read %s: %v", gemma431BTemplate, err)
|
||||
}
|
||||
|
||||
assert.Contains(t, string(e2b), "{{- '<|turn>model\\n' -}}")
|
||||
assert.NotContains(t, string(e2b), "{{- '<|channel>thought\\n<channel|>' -}}")
|
||||
assert.Contains(t, string(thirtyOneB), "{{- '<|turn>model\\n' -}}")
|
||||
assert.Contains(t, string(thirtyOneB), "{{- '<|channel>thought\\n<channel|>' -}}")
|
||||
}
|
||||
|
||||
// renderWithJinja2 shells out to uv + Python to render messages through the
|
||||
// Jinja2 chat template. Returns the rendered string.
|
||||
// E2B Jinja2 chat template. Returns the rendered string.
|
||||
func renderWithJinja2(t *testing.T, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string {
|
||||
return renderWithJinja2Template(t, gemma4E2BTemplate, messages, tools, think)
|
||||
}
|
||||
|
||||
// renderWithJinja2Template shells out to uv + Python to render messages through
|
||||
// the named Jinja2 chat template. Returns the rendered string.
|
||||
func renderWithJinja2Template(t *testing.T, templateRelPath string, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string {
|
||||
t.Helper()
|
||||
|
||||
templatePath, err := filepath.Abs("testdata/gemma4_chat_template.jinja2")
|
||||
templatePath, err := filepath.Abs(templateRelPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get template path: %v", err)
|
||||
}
|
||||
|
||||
344
model/renderers/testdata/gemma4_e2b_chat_template.jinja2
vendored
Normal file
344
model/renderers/testdata/gemma4_e2b_chat_template.jinja2
vendored
Normal file
@@ -0,0 +1,344 @@
|
||||
{%- macro format_parameters(properties, required) -%}
|
||||
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in properties | dictsort -%}
|
||||
{%- set add_comma = false -%}
|
||||
{%- if key not in standard_keys -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{{ key }}:{
|
||||
{%- if value['description'] -%}
|
||||
description:<|"|>{{ value['description'] }}<|"|>
|
||||
{%- set add_comma = true -%}
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'STRING' -%}
|
||||
{%- if value['enum'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
enum:{{ format_argument(value['enum']) }}
|
||||
{%- endif -%}
|
||||
{%- elif value['type'] | upper == 'ARRAY' -%}
|
||||
{%- if value['items'] is mapping and value['items'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
items:{
|
||||
{%- set ns_items = namespace(found_first=false) -%}
|
||||
{%- for item_key, item_value in value['items'] | dictsort -%}
|
||||
{%- if item_value is not none -%}
|
||||
{%- if ns_items.found_first %},{% endif -%}
|
||||
{%- set ns_items.found_first = true -%}
|
||||
{%- if item_key == 'properties' -%}
|
||||
properties:{
|
||||
{%- if item_value is mapping -%}
|
||||
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- elif item_key == 'required' -%}
|
||||
required:[
|
||||
{%- for req_item in item_value -%}
|
||||
<|"|>{{- req_item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- elif item_key == 'type' -%}
|
||||
{%- if item_value is string -%}
|
||||
type:{{ format_argument(item_value | upper) }}
|
||||
{%- else -%}
|
||||
type:{{ format_argument(item_value | map('upper') | list) }}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{{ item_key }}:{{ format_argument(item_value) }}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if value['nullable'] %}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
nullable:true
|
||||
{%- endif -%}
|
||||
{%- if value['type'] | upper == 'OBJECT' -%}
|
||||
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
properties:{
|
||||
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||
}
|
||||
{%- elif value is mapping -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
properties:{
|
||||
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||
}
|
||||
{%- endif -%}
|
||||
{%- if value['required'] -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
required:[
|
||||
{%- for item in value['required'] | default([]) -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
]
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||
type:<|"|>{{ value['type'] | upper }}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_function_declaration(tool_data) -%}
|
||||
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
|
||||
{%- set params = tool_data['function']['parameters'] -%}
|
||||
{%- if params -%}
|
||||
,parameters:{
|
||||
{%- if params['properties'] -%}
|
||||
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
|
||||
{%- endif -%}
|
||||
{%- if params['required'] -%}
|
||||
required:[
|
||||
{%- for item in params['required'] -%}
|
||||
<|"|>{{- item -}}<|"|>
|
||||
{{- ',' if not loop.last -}}
|
||||
{%- endfor -%}
|
||||
],
|
||||
{%- endif -%}
|
||||
{%- if params['type'] -%}
|
||||
type:<|"|>{{- params['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- if 'response' in tool_data['function'] -%}
|
||||
{%- set response_declaration = tool_data['function']['response'] -%}
|
||||
,response:{
|
||||
{%- if response_declaration['description'] -%}
|
||||
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
|
||||
{%- endif -%}
|
||||
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
|
||||
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
}
|
||||
{%- endmacro -%}
|
||||
{%- macro format_argument(argument, escape_keys=True) -%}
|
||||
{%- if argument is string -%}
|
||||
{{- '<|"|>' + argument + '<|"|>' -}}
|
||||
{%- elif argument is boolean -%}
|
||||
{{- 'true' if argument else 'false' -}}
|
||||
{%- elif argument is mapping -%}
|
||||
{{- '{' -}}
|
||||
{%- set ns = namespace(found_first=false) -%}
|
||||
{%- for key, value in argument | dictsort -%}
|
||||
{%- if ns.found_first %},{% endif -%}
|
||||
{%- set ns.found_first = true -%}
|
||||
{%- if escape_keys -%}
|
||||
{{- '<|"|>' + key + '<|"|>' -}}
|
||||
{%- else -%}
|
||||
{{- key -}}
|
||||
{%- endif -%}
|
||||
:{{- format_argument(value, escape_keys=escape_keys) -}}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- elif argument is sequence -%}
|
||||
{{- '[' -}}
|
||||
{%- for item in argument -%}
|
||||
{{- format_argument(item, escape_keys=escape_keys) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ']' -}}
|
||||
{%- else -%}
|
||||
{{- argument -}}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro strip_thinking(text) -%}
|
||||
{%- set ns = namespace(result='') -%}
|
||||
{%- for part in text.split('<channel|>') -%}
|
||||
{%- if '<|channel>' in part -%}
|
||||
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
|
||||
{%- else -%}
|
||||
{%- set ns.result = ns.result + part -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ns.result | trim -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- macro format_tool_response_block(tool_name, response) -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- if response is mapping -%}
|
||||
{{- 'response:' + tool_name + '{' -}}
|
||||
{%- for key, value in response | dictsort -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- if not loop.last %},{% endif -%}
|
||||
{%- endfor -%}
|
||||
{{- '}' -}}
|
||||
{%- else -%}
|
||||
{{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}}
|
||||
{%- endif -%}
|
||||
{{- '<tool_response|>' -}}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- set ns = namespace(prev_message_type=None) -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{{- bos_token -}}
|
||||
{#- Handle System/Tool Definitions Block -#}
|
||||
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- '<|turn>system\n' -}}
|
||||
|
||||
{#- Inject Thinking token at the very top of the FIRST system turn -#}
|
||||
{%- if enable_thinking is defined and enable_thinking -%}
|
||||
{{- '<|think|>\n' -}}
|
||||
{%- set ns.prev_message_type = 'think' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if messages[0]['role'] in ['system', 'developer'] -%}
|
||||
{{- messages[0]['content'] | trim -}}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if tools -%}
|
||||
{%- for tool in tools %}
|
||||
{{- '<|tool>' -}}
|
||||
{{- format_function_declaration(tool) | trim -}}
|
||||
{{- '<tool|>' -}}
|
||||
{%- endfor %}
|
||||
{%- set ns.prev_message_type = 'tool' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif %}
|
||||
|
||||
{#- Pre-scan: find last user message index for reasoning guard -#}
|
||||
{%- set ns_turn = namespace(last_user_idx=-1) -%}
|
||||
{%- for i in range(loop_messages | length) -%}
|
||||
{%- if loop_messages[i]['role'] == 'user' -%}
|
||||
{%- set ns_turn.last_user_idx = i -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{#- Loop through messages -#}
|
||||
{%- for message in loop_messages -%}
|
||||
{%- if message['role'] != 'tool' -%}
|
||||
{%- set ns.prev_message_type = None -%}
|
||||
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
|
||||
{#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#}
|
||||
{%- set prev_nt = namespace(role=None, found=false) -%}
|
||||
{%- if loop.index0 > 0 -%}
|
||||
{%- for j in range(loop.index0 - 1, -1, -1) -%}
|
||||
{%- if not prev_nt.found -%}
|
||||
{%- if loop_messages[j]['role'] != 'tool' -%}
|
||||
{%- set prev_nt.role = loop_messages[j]['role'] -%}
|
||||
{%- set prev_nt.found = true -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%}
|
||||
{%- if not continue_same_model_turn -%}
|
||||
{{- '<|turn>' + role + '\n' }}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Render reasoning/reasoning_content as thinking channel -#}
|
||||
{%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%}
|
||||
{%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%}
|
||||
{{- '<|channel>thought\n' + thinking_text + '\n<channel|>' -}}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['tool_calls'] -%}
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{%- set function = tool_call['function'] -%}
|
||||
{{- '<|tool_call>call:' + function['name'] + '{' -}}
|
||||
{%- if function['arguments'] is mapping -%}
|
||||
{%- set ns_args = namespace(found_first=false) -%}
|
||||
{%- for key, value in function['arguments'] | dictsort -%}
|
||||
{%- if ns_args.found_first %},{% endif -%}
|
||||
{%- set ns_args.found_first = true -%}
|
||||
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||
{%- endfor -%}
|
||||
{%- elif function['arguments'] is string -%}
|
||||
{{- function['arguments'] -}}
|
||||
{%- endif -%}
|
||||
{{- '}<tool_call|>' -}}
|
||||
{%- endfor -%}
|
||||
{%- set ns.prev_message_type = 'tool_call' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- set ns_tr_out = namespace(flag=false) -%}
|
||||
{%- if message.get('tool_responses') -%}
|
||||
{#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#}
|
||||
{%- for tool_response in message['tool_responses'] -%}
|
||||
{{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}}
|
||||
{%- set ns_tr_out.flag = true -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endfor -%}
|
||||
{%- elif message.get('tool_calls') -%}
|
||||
{#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#}
|
||||
{%- set ns_tool_scan = namespace(stopped=false) -%}
|
||||
{%- for k in range(loop.index0 + 1, loop_messages | length) -%}
|
||||
{%- if ns_tool_scan.stopped -%}
|
||||
{%- elif loop_messages[k]['role'] != 'tool' -%}
|
||||
{%- set ns_tool_scan.stopped = true -%}
|
||||
{%- else -%}
|
||||
{%- set follow = loop_messages[k] -%}
|
||||
{#- Resolve tool_call_id to function name -#}
|
||||
{%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%}
|
||||
{%- for tc in message['tool_calls'] -%}
|
||||
{%- if tc.get('id') == follow.get('tool_call_id') -%}
|
||||
{%- set ns_tname.name = tc['function']['name'] -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{#- Handle content as string or content-parts array -#}
|
||||
{%- set tool_body = follow.get('content') -%}
|
||||
{%- if tool_body is string -%}
|
||||
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||
{%- elif tool_body is sequence and tool_body is not string -%}
|
||||
{%- set ns_txt = namespace(s='') -%}
|
||||
{%- for part in tool_body -%}
|
||||
{%- if part.get('type') == 'text' -%}
|
||||
{%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- format_tool_response_block(ns_tname.name, ns_txt.s) -}}
|
||||
{%- else -%}
|
||||
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||
{%- endif -%}
|
||||
{%- set ns_tr_out.flag = true -%}
|
||||
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['content'] is string -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(message['content']) -}}
|
||||
{%- else -%}
|
||||
{{- message['content'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif message['content'] is sequence -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'text' -%}
|
||||
{%- if role == 'model' -%}
|
||||
{{- strip_thinking(item['text']) -}}
|
||||
{%- else -%}
|
||||
{{- item['text'] | trim -}}
|
||||
{%- endif -%}
|
||||
{%- elif item['type'] == 'image' -%}
|
||||
{{- '<|image|>' -}}
|
||||
{%- set ns.prev_message_type = 'image' -%}
|
||||
{%- elif item['type'] == 'audio' -%}
|
||||
{{- '<|audio|>' -}}
|
||||
{%- set ns.prev_message_type = 'audio' -%}
|
||||
{%- elif item['type'] == 'video' -%}
|
||||
{{- '<|video|>' -}}
|
||||
{%- set ns.prev_message_type = 'video' -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
|
||||
{%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%}
|
||||
{{- '<|tool_response>' -}}
|
||||
{%- elif not (ns_tr_out.flag and not message.get('content')) -%}
|
||||
{{- '<turn|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
{%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%}
|
||||
{{- '<|turn>model\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
@@ -191,6 +191,10 @@ func inferSafetensorsCapabilities(modelDir string) []string {
|
||||
capabilities = append(capabilities, "vision")
|
||||
}
|
||||
|
||||
if supportsAudio(modelDir) {
|
||||
capabilities = append(capabilities, "audio")
|
||||
}
|
||||
|
||||
if supportsThinking(modelDir) {
|
||||
capabilities = append(capabilities, "thinking")
|
||||
}
|
||||
@@ -496,32 +500,38 @@ func supportsThinking(modelDir string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// supportsVision checks if the model supports image input based on its architecture.
|
||||
// Qwen3.5 multimodal checkpoints are published as ConditionalGeneration architectures.
|
||||
// supportsVision checks if the model has a vision encoder by looking for
|
||||
// vision_config in config.json.
|
||||
func supportsVision(modelDir string) bool {
|
||||
configPath := filepath.Join(modelDir, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
VisionConfig *map[string]any `json:"vision_config"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "qwen3") && strings.Contains(archLower, "conditionalgeneration") {
|
||||
return true
|
||||
}
|
||||
return cfg.VisionConfig != nil
|
||||
}
|
||||
|
||||
func supportsAudio(modelDir string) bool {
|
||||
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
return strings.Contains(typeLower, "qwen3") && strings.Contains(typeLower, "conditionalgeneration")
|
||||
var cfg struct {
|
||||
AudioConfig *map[string]any `json:"audio_config"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return cfg.AudioConfig != nil
|
||||
}
|
||||
|
||||
// getParserName returns the parser name for a model based on its architecture.
|
||||
@@ -550,6 +560,9 @@ func getParserName(modelDir string) string {
|
||||
if strings.Contains(archLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(archLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
if strings.Contains(archLower, "qwen3") {
|
||||
return "qwen3"
|
||||
}
|
||||
@@ -564,6 +577,9 @@ func getParserName(modelDir string) string {
|
||||
if strings.Contains(typeLower, "deepseek") {
|
||||
return "deepseek3"
|
||||
}
|
||||
if strings.Contains(typeLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
if strings.Contains(typeLower, "qwen3") {
|
||||
return "qwen3"
|
||||
}
|
||||
@@ -592,6 +608,9 @@ func getRendererName(modelDir string) string {
|
||||
// Check architectures for known renderers
|
||||
for _, arch := range cfg.Architectures {
|
||||
archLower := strings.ToLower(arch)
|
||||
if strings.Contains(archLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
@@ -606,6 +625,9 @@ func getRendererName(modelDir string) string {
|
||||
// Also check model_type
|
||||
if cfg.ModelType != "" {
|
||||
typeLower := strings.ToLower(cfg.ModelType)
|
||||
if strings.Contains(typeLower, "gemma4") {
|
||||
return "gemma4"
|
||||
}
|
||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
||||
return "glm-4.7"
|
||||
}
|
||||
|
||||
@@ -311,10 +311,30 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
|
||||
name: "qwen3.5 multimodal model",
|
||||
configJSON: `{
|
||||
"architectures": ["Qwen3_5ForConditionalGeneration"],
|
||||
"model_type": "qwen3"
|
||||
"model_type": "qwen3",
|
||||
"vision_config": {"hidden_size": 1024}
|
||||
}`,
|
||||
want: []string{"completion", "vision", "thinking"},
|
||||
},
|
||||
{
|
||||
name: "model with audio config",
|
||||
configJSON: `{
|
||||
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||
"model_type": "gemma4",
|
||||
"vision_config": {"hidden_size": 1024},
|
||||
"audio_config": {"num_mel_bins": 128}
|
||||
}`,
|
||||
want: []string{"completion", "vision", "audio"},
|
||||
},
|
||||
{
|
||||
name: "model with audio but no vision",
|
||||
configJSON: `{
|
||||
"architectures": ["SomeAudioModel"],
|
||||
"model_type": "other",
|
||||
"audio_config": {"num_mel_bins": 128}
|
||||
}`,
|
||||
want: []string{"completion", "audio"},
|
||||
},
|
||||
{
|
||||
name: "non-qwen conditional generation model",
|
||||
configJSON: `{
|
||||
@@ -339,6 +359,74 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePerExpertInputs(t *testing.T) {
|
||||
makeInput := func(name, quantize string) create.PackedTensorInput {
|
||||
return create.PackedTensorInput{Name: name, Quantize: quantize}
|
||||
}
|
||||
|
||||
t.Run("uniform quant across projections", func(t *testing.T) {
|
||||
inputs := []create.PackedTensorInput{
|
||||
makeInput("layer.moe.experts.0.gate_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.1.gate_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.0.down_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.1.down_proj.weight", "int4"),
|
||||
}
|
||||
groups, projQ := parsePerExpertInputs("layer.moe.experts", inputs)
|
||||
if groups == nil {
|
||||
t.Fatal("expected non-nil groups")
|
||||
}
|
||||
if len(groups) != 2 {
|
||||
t.Fatalf("expected 2 projection groups, got %d", len(groups))
|
||||
}
|
||||
if projQ["gate_proj.weight"] != "int4" {
|
||||
t.Errorf("gate_proj quant = %q, want int4", projQ["gate_proj.weight"])
|
||||
}
|
||||
if projQ["down_proj.weight"] != "int4" {
|
||||
t.Errorf("down_proj quant = %q, want int4", projQ["down_proj.weight"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mixed quant across projections", func(t *testing.T) {
|
||||
inputs := []create.PackedTensorInput{
|
||||
makeInput("layer.moe.experts.0.gate_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.1.gate_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.0.down_proj.weight", "int8"),
|
||||
makeInput("layer.moe.experts.1.down_proj.weight", "int8"),
|
||||
}
|
||||
groups, projQ := parsePerExpertInputs("layer.moe.experts", inputs)
|
||||
if groups == nil {
|
||||
t.Fatal("expected non-nil groups for mixed cross-projection quant")
|
||||
}
|
||||
if projQ["gate_proj.weight"] != "int4" {
|
||||
t.Errorf("gate_proj quant = %q, want int4", projQ["gate_proj.weight"])
|
||||
}
|
||||
if projQ["down_proj.weight"] != "int8" {
|
||||
t.Errorf("down_proj quant = %q, want int8", projQ["down_proj.weight"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("mixed quant within same projection rejected", func(t *testing.T) {
|
||||
inputs := []create.PackedTensorInput{
|
||||
makeInput("layer.moe.experts.0.down_proj.weight", "int4"),
|
||||
makeInput("layer.moe.experts.1.down_proj.weight", "int8"),
|
||||
}
|
||||
groups, _ := parsePerExpertInputs("layer.moe.experts", inputs)
|
||||
if groups != nil {
|
||||
t.Fatal("expected nil for mixed quant within same projection")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-experts group rejected", func(t *testing.T) {
|
||||
inputs := []create.PackedTensorInput{
|
||||
makeInput("layer.mlp.gate_proj.weight", "int4"),
|
||||
}
|
||||
groups, _ := parsePerExpertInputs("layer.mlp", inputs)
|
||||
if groups != nil {
|
||||
t.Fatal("expected nil for non-experts group")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestQuantizeSupported(t *testing.T) {
|
||||
// This just verifies the function exists and returns a boolean
|
||||
// The actual value depends on build tags (mlx vs non-mlx)
|
||||
|
||||
@@ -97,6 +97,20 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
|
||||
groupSize, bits, mode := model.QuantizationParams(quantize)
|
||||
qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
|
||||
|
||||
// Validate quantization produced non-empty output. MLX quantize may return
|
||||
// empty arrays for unsupported mode/bits combinations without raising an error.
|
||||
mlx.Eval(qweight, scales)
|
||||
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
|
||||
name, quantize, groupSize, bits, mode)
|
||||
}
|
||||
if len(scales.Dims()) == 0 || scales.Dims()[0] == 0 {
|
||||
st.Free()
|
||||
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty scales for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
|
||||
name, quantize, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
qweight = mlx.Contiguous(qweight, false)
|
||||
scales = mlx.Contiguous(scales, false)
|
||||
arrays[name] = qweight
|
||||
@@ -174,8 +188,8 @@ func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quanti
|
||||
// Returns the blob bytes.
|
||||
func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
|
||||
// Check if inputs are per-expert tensors that should be stacked into 3D
|
||||
if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
|
||||
return stackAndQuantizeExpertGroup(groupName, projGroups, quantize)
|
||||
if projGroups, projQuantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
|
||||
return stackAndQuantizeExpertGroup(groupName, projGroups, projQuantize)
|
||||
}
|
||||
|
||||
allArrays := make(map[string]*mlx.Array)
|
||||
@@ -224,6 +238,17 @@ func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([
|
||||
mlx.Pin(finalArrays...)
|
||||
pinned = append(pinned, finalArrays...)
|
||||
|
||||
// Record per-tensor quant type so the model can resolve params at load time.
|
||||
if input.Quantize != "" {
|
||||
if groupSize, _, _ := model.QuantizationParams(input.Quantize); groupSize > 0 {
|
||||
if metadata == nil {
|
||||
metadata = make(map[string]string)
|
||||
}
|
||||
metadata[input.Name+".quant_type"] = input.Quantize
|
||||
metadata[input.Name+".group_size"] = strconv.Itoa(groupSize)
|
||||
}
|
||||
}
|
||||
|
||||
if st != nil {
|
||||
st.Free()
|
||||
}
|
||||
@@ -279,57 +304,60 @@ type expertTensorInfo struct {
|
||||
}
|
||||
|
||||
// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
|
||||
// and returns the uniform quantization type shared by all inputs.
|
||||
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D)
|
||||
// or if the inputs have mixed quantization types.
|
||||
// and returns per-projection quantization types. Different projections may use
|
||||
// different quant types (e.g., gate_up=int4, down=int8) but all experts within
|
||||
// a projection must share the same type.
|
||||
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D).
|
||||
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
|
||||
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) {
|
||||
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, map[string]string) {
|
||||
if !strings.HasSuffix(groupName, ".experts") {
|
||||
return nil, ""
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
quantize := inputs[0].Quantize
|
||||
groups := make(map[string][]expertTensorInfo)
|
||||
projQuantize := make(map[string]string) // projection -> quant type
|
||||
for _, input := range inputs {
|
||||
if input.Quantize != quantize {
|
||||
return nil, "" // mixed quantization types
|
||||
}
|
||||
suffix := strings.TrimPrefix(input.Name, groupName)
|
||||
m := perExpertSuffix.FindStringSubmatch(suffix)
|
||||
if m == nil {
|
||||
return nil, "" // not a per-expert pattern
|
||||
return nil, nil // not a per-expert pattern
|
||||
}
|
||||
index, err := strconv.Atoi(m[1])
|
||||
if err != nil {
|
||||
return nil, ""
|
||||
return nil, nil
|
||||
}
|
||||
groups[m[2]] = append(groups[m[2]], expertTensorInfo{
|
||||
proj := m[2]
|
||||
if existing, ok := projQuantize[proj]; ok {
|
||||
if input.Quantize != existing {
|
||||
return nil, nil // mixed quant within same projection
|
||||
}
|
||||
} else {
|
||||
projQuantize[proj] = input.Quantize
|
||||
}
|
||||
groups[proj] = append(groups[proj], expertTensorInfo{
|
||||
index: index,
|
||||
proj: m[2],
|
||||
proj: proj,
|
||||
input: input,
|
||||
})
|
||||
}
|
||||
if len(groups) == 0 {
|
||||
return nil, ""
|
||||
return nil, nil
|
||||
}
|
||||
return groups, quantize
|
||||
return groups, projQuantize
|
||||
}
|
||||
|
||||
// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
|
||||
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
|
||||
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) {
|
||||
// projQuantize maps projection name to its quantization type (may differ per projection).
|
||||
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, projQuantize map[string]string) ([]byte, error) {
|
||||
groupBase := strings.TrimSuffix(groupName, ".experts")
|
||||
|
||||
allArrays := make(map[string]*mlx.Array)
|
||||
var pinned []*mlx.Array
|
||||
|
||||
var metadata map[string]string
|
||||
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" {
|
||||
metadata = map[string]string{
|
||||
"quant_type": quantize,
|
||||
"group_size": strconv.Itoa(groupSize),
|
||||
}
|
||||
}
|
||||
// Build metadata: if all projections use the same quant type, set global metadata.
|
||||
// Otherwise record per-tensor quant info.
|
||||
metadata := make(map[string]string)
|
||||
|
||||
// Sort projection names for deterministic output
|
||||
projNames := make([]string, 0, len(projGroups))
|
||||
@@ -339,7 +367,11 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
||||
sort.Strings(projNames)
|
||||
|
||||
cleanup := func() {
|
||||
mlx.Unpin(pinned...)
|
||||
for _, p := range pinned {
|
||||
if p != nil {
|
||||
mlx.Unpin(p)
|
||||
}
|
||||
}
|
||||
mlx.Sweep()
|
||||
}
|
||||
|
||||
@@ -382,11 +414,27 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
||||
mlx.Pin(stacked)
|
||||
pinned = append(pinned, stacked)
|
||||
|
||||
// Free individual decoded arrays
|
||||
// Free individual decoded arrays (remove from pinned to avoid double-unpin in cleanup)
|
||||
for i, p := range pinned {
|
||||
for _, d := range decoded {
|
||||
if p == d {
|
||||
pinned[i] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
mlx.Unpin(decoded...)
|
||||
mlx.Sweep()
|
||||
|
||||
stackedName := groupBase + ".switch_mlp." + proj
|
||||
quantize := projQuantize[proj]
|
||||
|
||||
// Record per-tensor quant metadata so the model can resolve params at load time.
|
||||
if quantize != "" {
|
||||
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 {
|
||||
metadata[stackedName+".quant_type"] = quantize
|
||||
metadata[stackedName+".group_size"] = strconv.Itoa(groupSize)
|
||||
}
|
||||
}
|
||||
|
||||
// Quantize the stacked tensor
|
||||
if quantize != "" {
|
||||
@@ -394,6 +442,14 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
||||
|
||||
qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)
|
||||
|
||||
// Validate quantization produced non-empty output.
|
||||
mlx.Eval(qweight, scales)
|
||||
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
|
||||
cleanup()
|
||||
return nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
|
||||
stackedName, quantize, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
qweight = mlx.Contiguous(qweight, false)
|
||||
scales = mlx.Contiguous(scales, false)
|
||||
allArrays[stackedName] = qweight
|
||||
@@ -409,12 +465,19 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
||||
mlx.Pin(toEval...)
|
||||
pinned = append(pinned, toEval...)
|
||||
|
||||
// Free stacked source array
|
||||
// Free stacked source array (remove from pinned to avoid double-unpin in cleanup)
|
||||
for i, p := range pinned {
|
||||
if p == stacked {
|
||||
pinned[i] = nil
|
||||
}
|
||||
}
|
||||
mlx.Unpin(stacked)
|
||||
mlx.Sweep()
|
||||
} else {
|
||||
stacked = mlx.Contiguous(stacked, false)
|
||||
mlx.Eval(stacked)
|
||||
mlx.Pin(stacked)
|
||||
pinned = append(pinned, stacked)
|
||||
allArrays[stackedName] = stacked
|
||||
}
|
||||
}
|
||||
@@ -529,7 +592,7 @@ func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
|
||||
padBottom := blockRows*scaleShape[0] - rows
|
||||
padSide := blockCols*scaleShape[1] - cols
|
||||
if padBottom > 0 || padSide > 0 {
|
||||
decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
|
||||
decoded = mlx.PadConstant(decoded, []int{0, 1}, []int{0, 0}, []int{padBottom, padSide})
|
||||
}
|
||||
|
||||
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
|
||||
|
||||
@@ -246,6 +246,11 @@ func ShouldQuantize(name, component string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip audio encoder tensors (highly sensitive to quantization)
|
||||
if strings.Contains(name, "audio_tower") || strings.Contains(name, "embed_audio") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip embeddings
|
||||
if strings.Contains(name, "embed") {
|
||||
return false
|
||||
@@ -291,6 +296,22 @@ func normalizeQuantType(quantize string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// isAligned checks if a tensor's last dimension is divisible by the
|
||||
// group size required for the given quantization type.
|
||||
func isAligned(shape []int32, quantType string) bool {
|
||||
if len(shape) == 0 {
|
||||
return false
|
||||
}
|
||||
groupSize := int32(32)
|
||||
switch normalizeQuantType(quantType) {
|
||||
case "nvfp4":
|
||||
groupSize = 16
|
||||
case "int4", "int8":
|
||||
groupSize = 64
|
||||
}
|
||||
return shape[len(shape)-1]%groupSize == 0
|
||||
}
|
||||
|
||||
func isStackedExpertWeight(name string) bool {
|
||||
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
|
||||
// or "...proj" (pre-stacked packed tensor).
|
||||
@@ -300,16 +321,16 @@ func isStackedExpertWeight(name string) bool {
|
||||
|
||||
return strings.Contains(name, ".mlp.switch_mlp.") ||
|
||||
strings.Contains(name, ".mlp.experts.") ||
|
||||
strings.Contains(name, ".mlp.shared_experts.")
|
||||
strings.Contains(name, ".mlp.shared_experts.") ||
|
||||
strings.Contains(name, ".moe.experts.")
|
||||
}
|
||||
|
||||
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
||||
// Returns "" if the tensor should not be quantized.
|
||||
// This implements mixed-precision quantization:
|
||||
// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
|
||||
// - Output projection, gate/up weights: int4 (less sensitive)
|
||||
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
||||
// - v_proj, k_proj, down_proj: promoted to INT8 when base is INT4
|
||||
// - Norms, embeddings, biases, routing gates: no quantization
|
||||
// - All other eligible weights: use requested quantization type
|
||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
stackedExpert := isStackedExpertWeight(name)
|
||||
|
||||
@@ -336,60 +357,35 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||
// Normalize quantization type to canonical form
|
||||
quantNorm := normalizeQuantType(quantize)
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size
|
||||
// nvfp4: 16, mxfp4/mxfp8: 32, int4/int8: 64
|
||||
groupSize := int32(32)
|
||||
switch quantNorm {
|
||||
case "nvfp4":
|
||||
groupSize = 16
|
||||
case "int4", "int8":
|
||||
groupSize = 64
|
||||
}
|
||||
if shape[len(shape)-1]%groupSize != 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Skip routing gate weights (should stay high precision)
|
||||
// In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
|
||||
if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
|
||||
return ""
|
||||
}
|
||||
|
||||
// MLX quantization requires last dimension to be divisible by group size.
|
||||
if !isAligned(shape, quantNorm) {
|
||||
return ""
|
||||
}
|
||||
|
||||
// For non-affine modes, use the same quantization for all eligible tensors.
|
||||
if quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8" {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// Attention MLA weights - keep unquantized (bf16)
|
||||
// These are highly sensitive: errors accumulate in the KV cache over time
|
||||
// q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj
|
||||
if strings.Contains(name, "q_a_proj") ||
|
||||
strings.Contains(name, "q_b_proj") ||
|
||||
strings.Contains(name, "kv_a_proj") ||
|
||||
strings.Contains(name, "kv_b_proj") {
|
||||
return "" // No quantization - keep bf16
|
||||
// Value projection weights directly determine attention output quality.
|
||||
// Down projection weights feed directly into the residual stream where
|
||||
// errors accumulate across layers. Both benefit from higher precision.
|
||||
// Promote to INT8 when base is INT4 (same affine mode, compatible with
|
||||
// GatherQMM for MoE expert tensors).
|
||||
if quantNorm == "int4" {
|
||||
if strings.Contains(name, ".v_proj") || strings.Contains(name, ".k_proj") || strings.Contains(name, "down_proj") {
|
||||
if isAligned(shape, "int8") {
|
||||
return "int8"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Down projection weights - use INT8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
|
||||
// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
|
||||
if strings.Contains(name, "down_proj") {
|
||||
return "int8"
|
||||
}
|
||||
|
||||
// Output projection, gate/up weights - use requested quantization (INT4)
|
||||
// o_proj, gate_proj, up_proj
|
||||
if strings.Contains(name, "o_proj") ||
|
||||
strings.Contains(name, "gate_proj") ||
|
||||
strings.Contains(name, "up_proj") {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// LM head - use requested quantization
|
||||
if strings.Contains(name, "lm_head") {
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
// Default to requested quantization for other weights
|
||||
return quantNorm
|
||||
}
|
||||
|
||||
@@ -411,6 +407,7 @@ func ExpertGroupPrefix(tensorName string) string {
|
||||
".mlp.experts.",
|
||||
".mlp.shared_experts.",
|
||||
".mlp.switch_mlp.",
|
||||
".moe.experts.",
|
||||
} {
|
||||
idx := strings.Index(tensorName, marker)
|
||||
if idx == -1 {
|
||||
@@ -637,6 +634,8 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{
|
||||
"Qwen3_5MoeForConditionalGeneration": newQwen35ImportTransform,
|
||||
"Qwen3NextMoeForCausalLM": newQwen35ImportTransform,
|
||||
"Qwen3NextMoeForConditionalGeneration": newQwen35ImportTransform,
|
||||
"Gemma4ForCausalLM": newGemma4ImportTransform,
|
||||
"Gemma4ForConditionalGeneration": newGemma4ImportTransform,
|
||||
}
|
||||
|
||||
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
|
||||
|
||||
@@ -1169,6 +1169,11 @@ func TestShouldQuantize(t *testing.T) {
|
||||
{"ln prefix", "ln_1.weight", "", false},
|
||||
{"layernorm in name", "input_layernorm.weight", "", false},
|
||||
|
||||
// Audio encoder tensors should not be quantized
|
||||
{"audio tower weight", "model.audio_tower.layers.0.weight", "", false},
|
||||
{"audio tower norm", "model.audio_tower.norm.weight", "", false},
|
||||
{"embed audio weight", "embed_audio.weight", "", false},
|
||||
|
||||
// Biases should not be quantized
|
||||
{"bias tensor", "attention.bias", "", false},
|
||||
{"proj bias", "o_proj.bias", "", false},
|
||||
@@ -1262,6 +1267,11 @@ func TestExpertGroupPrefix(t *testing.T) {
|
||||
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
|
||||
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
|
||||
|
||||
// MoE expert tensors (Gemma-style .moe.experts.)
|
||||
{"model.layers.0.moe.experts.0.gate_proj.weight", "model.layers.0.moe.experts"},
|
||||
{"model.layers.1.moe.experts.42.down_proj.weight", "model.layers.1.moe.experts"},
|
||||
{"language_model.model.layers.2.moe.experts.127.up_proj.weight", "language_model.model.layers.2.moe.experts"},
|
||||
|
||||
// Expert tensors with language_model prefix should also match
|
||||
{"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"},
|
||||
{"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"},
|
||||
@@ -1369,6 +1379,94 @@ func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAligned(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
shape []int32
|
||||
quantType string
|
||||
want bool
|
||||
}{
|
||||
// int4/int8: group_size=64
|
||||
{"int4 aligned", []int32{1024, 4096}, "int4", true},
|
||||
{"int4 unaligned", []int32{1024, 48}, "int4", false},
|
||||
{"int8 aligned", []int32{1024, 128}, "int8", true},
|
||||
{"int8 unaligned", []int32{1024, 32}, "int8", false},
|
||||
|
||||
// nvfp4: group_size=16
|
||||
{"nvfp4 aligned", []int32{1024, 48}, "nvfp4", true},
|
||||
{"nvfp4 unaligned", []int32{1024, 24}, "nvfp4", false},
|
||||
{"nvfp4 aligned 16", []int32{1024, 16}, "nvfp4", true},
|
||||
|
||||
// mxfp4/mxfp8: group_size=32
|
||||
{"mxfp4 aligned", []int32{1024, 64}, "mxfp4", true},
|
||||
{"mxfp4 unaligned", []int32{1024, 48}, "mxfp4", false},
|
||||
{"mxfp8 aligned", []int32{1024, 32}, "mxfp8", true},
|
||||
{"mxfp8 unaligned", []int32{1024, 24}, "mxfp8", false},
|
||||
|
||||
// Edge cases
|
||||
{"empty shape", []int32{}, "int4", false},
|
||||
{"1D tensor", []int32{4096}, "int4", true},
|
||||
{"3D stacked expert", []int32{128, 4096, 2816}, "int4", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isAligned(tt.shape, tt.quantType)
|
||||
if got != tt.want {
|
||||
t.Errorf("isAligned(%v, %q) = %v, want %v", tt.shape, tt.quantType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTensorQuantization_MixedPrecisionPromotion(t *testing.T) {
|
||||
aligned := []int32{4096, 4096} // divisible by 64
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tensor string
|
||||
shape []int32
|
||||
quantize string
|
||||
want string
|
||||
}{
|
||||
// int4 → int8 promotion for sensitive tensors
|
||||
{"v_proj int4 promoted", "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"},
|
||||
{"k_proj int4 promoted", "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
|
||||
{"down_proj int4 promoted", "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"},
|
||||
|
||||
// Non-sensitive int4 tensors stay int4
|
||||
{"q_proj int4 stays", "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"},
|
||||
{"o_proj int4 stays", "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"},
|
||||
{"gate_proj int4 stays", "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"},
|
||||
{"up_proj int4 stays", "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"},
|
||||
|
||||
// nvfp4/mxfp4/mxfp8: no promotion (uniform quantization)
|
||||
{"v_proj nvfp4 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"},
|
||||
{"down_proj mxfp4 uniform", "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"},
|
||||
{"v_proj mxfp8 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"},
|
||||
|
||||
// int8: already 8-bit, no promotion
|
||||
{"v_proj int8 stays", "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"},
|
||||
|
||||
// Expert tensors: down_proj also promoted for int4
|
||||
{"expert down_proj int4", "model.layers.0.mlp.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"},
|
||||
{"moe expert down_proj int4", "model.layers.0.moe.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"},
|
||||
|
||||
// Unaligned: falls back to bf16 (empty string)
|
||||
{"v_proj int4 unaligned", "model.layers.0.self_attn.v_proj.weight", []int32{1024, 48}, "int4", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GetTensorQuantization(tt.tensor, tt.shape, tt.quantize)
|
||||
if got != tt.want {
|
||||
t.Errorf("GetTensorQuantization(%q, %v, %q) = %q, want %q",
|
||||
tt.tensor, tt.shape, tt.quantize, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
|
||||
264
x/create/gemma4.go
Normal file
264
x/create/gemma4.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/safetensors"
|
||||
)
|
||||
|
||||
type gemma4ImportTransform struct {
|
||||
numLayers int
|
||||
numExperts int
|
||||
}
|
||||
|
||||
// gemma4Config is a minimal subset of the Gemma 4 config.json used for quant decisions.
|
||||
type gemma4Config struct {
|
||||
NumHiddenLayers int `json:"num_hidden_layers"`
|
||||
NumExperts int `json:"num_experts"`
|
||||
TextConfig struct {
|
||||
NumHiddenLayers int `json:"num_hidden_layers"`
|
||||
NumExperts int `json:"num_experts"`
|
||||
} `json:"text_config"`
|
||||
}
|
||||
|
||||
func newGemma4ImportTransform(modelDir string, _ sourceModelConfig) (tensorImportTransform, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
|
||||
if err != nil {
|
||||
return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic
|
||||
}
|
||||
var cfg gemma4Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic
|
||||
}
|
||||
|
||||
numLayers := cfg.NumHiddenLayers
|
||||
if numLayers == 0 {
|
||||
numLayers = cfg.TextConfig.NumHiddenLayers
|
||||
}
|
||||
numExperts := cfg.NumExperts
|
||||
if numExperts == 0 {
|
||||
numExperts = cfg.TextConfig.NumExperts
|
||||
}
|
||||
|
||||
return gemma4ImportTransform{numLayers: numLayers, numExperts: numExperts}, nil
|
||||
}
|
||||
|
||||
func (t gemma4ImportTransform) skipTensor(name string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// layerIndexRe extracts the layer index from tensor names like
|
||||
// "model.language_model.layers.5.self_attn.v_proj.weight" or
|
||||
// "model.language_model.layers.5.moe.experts.42.down_proj.weight"
|
||||
var layerIndexRe = regexp.MustCompile(`\.layers\.(\d+)\.`)
|
||||
|
||||
// useMoreBits returns true for layers where quantization-sensitive tensors
|
||||
// should use higher precision: the first and last 1/8 of layers (which handle
|
||||
// input grounding and final output refinement), plus every 3rd layer in between
|
||||
// to limit error accumulation through the residual stream.
|
||||
func useMoreBits(layerIdx, numLayers int) bool {
|
||||
return layerIdx < numLayers/8 ||
|
||||
layerIdx >= 7*numLayers/8 ||
|
||||
(layerIdx-numLayers/8)%3 == 2
|
||||
}
|
||||
|
||||
func (t gemma4ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
|
||||
quantNorm := normalizeQuantType(quantize)
|
||||
|
||||
// Embedding: quantize to 8-bit variant for bandwidth efficiency.
|
||||
// The embedding serves double duty: lookup (via QuantizedEmbedding) and
|
||||
// lm_head projection (via AsLinear). Using 8-bit matches GGUF Q6_K quality
|
||||
// (strictly higher at 8 bpw vs 6.5 bpw) while saving ~2.8 GB on 31B vs bf16.
|
||||
if isEmbedTokensWeight(name) {
|
||||
switch quantNorm {
|
||||
case "int4", "int8":
|
||||
if isAligned(shape, "int8") {
|
||||
return "int8"
|
||||
}
|
||||
case "mxfp4", "nvfp4", "mxfp8":
|
||||
if isAligned(shape, "mxfp8") {
|
||||
return "mxfp8"
|
||||
}
|
||||
}
|
||||
if isAligned(shape, quantNorm) {
|
||||
return quantNorm
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Mixed-precision quantization: sensitive tensors get higher precision.
|
||||
//
|
||||
// Value projections (v_proj) directly determine attention output quality.
|
||||
// Down projections (down_proj) are the final MLP output and errors there
|
||||
// propagate directly to the residual stream. Both benefit from higher
|
||||
// precision at early layers, late layers, and periodically in between
|
||||
// (the "useMoreBits" heuristic).
|
||||
//
|
||||
// For int4: promote → int8 (same affine family, GatherQMM compatible).
|
||||
// For mxfp4/nvfp4: promote → mxfp8. MLX quantized_matmul handles mixed
|
||||
// nvfp4+mxfp8 modes within the same model — each tensor carries its own
|
||||
// quant metadata and the kernel dispatches per-tensor.
|
||||
if t.numLayers > 0 {
|
||||
layerIdx := -1
|
||||
if m := layerIndexRe.FindStringSubmatch(name); m != nil {
|
||||
if idx, err := strconv.Atoi(m[1]); err == nil {
|
||||
layerIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
// Determine promotion target for sensitive tensors.
|
||||
// "int8" = int4 base → int8 (affine family)
|
||||
// "mxfp8" = mxfp4/nvfp4 base → mxfp8
|
||||
// "" = no promotion (int8/mxfp8, already 8-bit)
|
||||
promote := ""
|
||||
switch quantNorm {
|
||||
case "int4":
|
||||
promote = "int8"
|
||||
case "mxfp4", "nvfp4":
|
||||
promote = "mxfp8"
|
||||
}
|
||||
|
||||
// Only apply to language model tensors — audio/vision tower tensors
|
||||
// should pass through to GetTensorQuantization which skips them.
|
||||
isModelTensor := !strings.Contains(name, "audio_tower") &&
|
||||
!strings.Contains(name, "vision_tower")
|
||||
isSensitive := isModelTensor &&
|
||||
(strings.Contains(name, ".v_proj") || strings.Contains(name, "down_proj"))
|
||||
isSensitiveK := isModelTensor && strings.Contains(name, "k_proj")
|
||||
|
||||
if promote != "" && (isSensitive || isSensitiveK) {
|
||||
shouldPromote := false
|
||||
|
||||
// 8-expert models: v_proj and k_proj share very few KV heads,
|
||||
// so quantization errors are amplified. Always promote.
|
||||
if t.numExperts == 8 && (strings.Contains(name, ".v_proj") || isSensitiveK) {
|
||||
shouldPromote = true
|
||||
}
|
||||
|
||||
// Layer-position heuristic for v_proj and down_proj.
|
||||
if isSensitive && layerIdx >= 0 && useMoreBits(layerIdx, t.numLayers) {
|
||||
shouldPromote = true
|
||||
}
|
||||
|
||||
if shouldPromote && isAligned(shape, promote) {
|
||||
return promote
|
||||
}
|
||||
|
||||
// Sensitive tensor at a non-promoted layer: use base quant type.
|
||||
// Return directly to bypass GetTensorQuantization's uniform
|
||||
// promotion — the layer-position heuristic is authoritative here.
|
||||
if !isAligned(shape, quantNorm) {
|
||||
return ""
|
||||
}
|
||||
return quantNorm
|
||||
}
|
||||
}
|
||||
|
||||
return GetTensorQuantization(name, shape, quantize)
|
||||
}
|
||||
|
||||
// isEmbedTokensWeight returns true for the main token embedding weight.
|
||||
func isEmbedTokensWeight(name string) bool {
|
||||
return strings.HasSuffix(name, "embed_tokens.weight") &&
|
||||
!strings.Contains(name, "per_layer")
|
||||
}
|
||||
|
||||
func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
|
||||
if td == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Split pre-stacked MoE expert tensors [N, out, in] into per-expert
|
||||
// [out, in] tensors so they go through the standard expert packing and
|
||||
// quantization flow (ExpertGroupPrefix matching, per-expert quantize).
|
||||
if isGemma4StackedMoETensor(td.Name, td.Shape) {
|
||||
return splitStackedMoETensor(td)
|
||||
}
|
||||
|
||||
return []*safetensors.TensorData{td}, nil
|
||||
}
|
||||
|
||||
// isGemma4StackedMoETensor checks if this is a pre-stacked MoE expert weight.
|
||||
// Gemma 4 HF weights come in two layouts depending on the model version:
|
||||
// - Older: model.language_model.layers.N.moe.{gate,up,down}_proj [experts, dim1, dim2]
|
||||
// - Newer: model.language_model.layers.N.experts.{gate_up,down}_proj [experts, dim1, dim2]
|
||||
//
|
||||
// The newer layout has gate+up already fused. We keep it fused (no splitting)
|
||||
// so the tensors flow through the standard expert packing and quantization path.
|
||||
func isGemma4StackedMoETensor(name string, shape []int32) bool {
|
||||
if len(shape) != 3 {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(name, ".moe.") || strings.Contains(name, ".experts.") {
|
||||
return strings.HasSuffix(name, "_proj") || strings.HasSuffix(name, "_proj.weight")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// splitStackedMoETensor splits a [N, out, in] stacked expert tensor into
|
||||
// N individual [out, in] tensors named with the per-expert convention that
|
||||
// ExpertGroupPrefix expects: prefix.moe.experts.{E}.{proj}.weight
|
||||
func splitStackedMoETensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
|
||||
raw, err := io.ReadAll(td.Reader())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err)
|
||||
}
|
||||
|
||||
numExperts := int(td.Shape[0])
|
||||
rows := int(td.Shape[1]) // out_features in HF layout
|
||||
cols := int(td.Shape[2]) // in_features in HF layout
|
||||
|
||||
elemSize, err := DTypeSize(td.Dtype)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dtype size for %s: %w", td.Dtype, err)
|
||||
}
|
||||
|
||||
perExpertBytes := rows * cols * elemSize
|
||||
if len(raw) != numExperts*perExpertBytes {
|
||||
return nil, fmt.Errorf("tensor %s: raw byte length %d does not match shape %v and dtype %s",
|
||||
td.Name, len(raw), td.Shape, td.Dtype)
|
||||
}
|
||||
|
||||
// Determine the per-expert name pattern.
|
||||
// Two source layouts:
|
||||
// Old: model.language_model.layers.N.moe.gate_proj
|
||||
// -> model.language_model.layers.N.moe.experts.E.gate_proj.weight
|
||||
// New: model.language_model.layers.N.experts.gate_up_proj
|
||||
// -> model.language_model.layers.N.moe.experts.E.gate_up_proj.weight
|
||||
baseName := td.Name
|
||||
baseName = strings.TrimSuffix(baseName, ".weight")
|
||||
lastDot := strings.LastIndex(baseName, ".")
|
||||
if lastDot < 0 {
|
||||
return nil, fmt.Errorf("tensor %s: unexpected name format", td.Name)
|
||||
}
|
||||
parentPrefix := baseName[:lastDot] // "...layers.N.moe" or "...layers.N.experts"
|
||||
projName := baseName[lastDot+1:] // "gate_proj" or "gate_up_proj"
|
||||
|
||||
// Normalize: if parent already ends with ".experts", use the grandparent + ".moe"
|
||||
// so we get a consistent "layers.N.moe.experts.E" pattern.
|
||||
var moePrefix string
|
||||
if cut, ok := strings.CutSuffix(parentPrefix, ".experts"); ok {
|
||||
moePrefix = cut + ".moe"
|
||||
} else {
|
||||
moePrefix = parentPrefix
|
||||
}
|
||||
|
||||
transposedShape := []int32{td.Shape[1], td.Shape[2]}
|
||||
|
||||
results := make([]*safetensors.TensorData, numExperts)
|
||||
for e := range numExperts {
|
||||
expertName := fmt.Sprintf("%s.experts.%d.%s.weight", moePrefix, e, projName)
|
||||
start := e * perExpertBytes
|
||||
end := start + perExpertBytes
|
||||
results[e] = safetensors.NewTensorDataFromBytes(expertName, td.Dtype, transposedShape, raw[start:end])
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
191
x/create/gemma4_test.go
Normal file
191
x/create/gemma4_test.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGemma4QuantizationType(t *testing.T) {
|
||||
// 26B MoE: 30 layers, 128 experts
|
||||
transform26B := gemma4ImportTransform{numLayers: 30, numExperts: 128}
|
||||
// 8-expert model (hypothetical)
|
||||
transform8E := gemma4ImportTransform{numLayers: 30, numExperts: 8}
|
||||
|
||||
aligned := []int32{2816, 2816} // divisible by 64 (int4/int8 group size) and 16 (nvfp4)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
transform gemma4ImportTransform
|
||||
tensor string
|
||||
shape []int32
|
||||
quantize string
|
||||
want string
|
||||
}{
|
||||
// === embed_tokens: quantize to 8-bit variant (serves as both embed and lm_head) ===
|
||||
{"embed_tokens int4", transform26B, "model.embed_tokens.weight", aligned, "int4", "int8"},
|
||||
{"embed_tokens nvfp4", transform26B, "model.embed_tokens.weight", aligned, "nvfp4", "mxfp8"},
|
||||
{"embed_tokens mxfp4", transform26B, "model.embed_tokens.weight", aligned, "mxfp4", "mxfp8"},
|
||||
{"embed_tokens int8", transform26B, "model.embed_tokens.weight", aligned, "int8", "int8"},
|
||||
{"embed_tokens mxfp8", transform26B, "model.embed_tokens.weight", aligned, "mxfp8", "mxfp8"},
|
||||
|
||||
// === v_proj: layer-position heuristic for int4/nvfp4 ===
|
||||
// Layer 0 is in first 1/8 (30/8=3) → promoted
|
||||
{"v_proj int4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"},
|
||||
// Layer 4 is NOT in useMoreBits → base quant
|
||||
{"v_proj int4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "int4", "int4"},
|
||||
// Layer 29 is in last 1/8 → promoted
|
||||
{"v_proj int4 last layer promoted", transform26B, "model.layers.29.self_attn.v_proj.weight", aligned, "int4", "int8"},
|
||||
// nvfp4: promote to mxfp8 (cross-family, validated by MLX quantized_matmul)
|
||||
{"v_proj nvfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "mxfp8"},
|
||||
{"v_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"},
|
||||
// mxfp4: promoted to mxfp8 at promoted layers (same mxfp family)
|
||||
{"v_proj mxfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp8"},
|
||||
{"v_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp4"},
|
||||
// int8/mxfp8: no promotion (already 8-bit)
|
||||
{"v_proj int8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"},
|
||||
{"v_proj mxfp8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"},
|
||||
|
||||
// === down_proj (dense MLP): same heuristic as v_proj ===
|
||||
{"dense down_proj int4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"},
|
||||
{"dense down_proj int4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "int4", "int4"},
|
||||
{"dense down_proj nvfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "nvfp4", "mxfp8"},
|
||||
{"dense down_proj nvfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "nvfp4", "nvfp4"},
|
||||
{"dense down_proj mxfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp8"},
|
||||
{"dense down_proj mxfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"},
|
||||
|
||||
// === Expert down_proj: int4→int8, nvfp4→nvfp8 at promoted layers ===
|
||||
{"expert down_proj int4 promoted", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "int4", "int8"},
|
||||
{"expert down_proj int4 non-promoted", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "int4", "int4"},
|
||||
// nvfp4 experts: promote to mxfp8 (all experts at a layer get same treatment,
|
||||
// so GatherQMM sees uniform quant per projection per layer)
|
||||
{"expert down_proj nvfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "nvfp4", "mxfp8"},
|
||||
{"expert down_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "nvfp4", "nvfp4"},
|
||||
// mxfp4 experts: promote to mxfp8 (same mxfp family, GatherQMM compatible)
|
||||
{"expert down_proj mxfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp8"},
|
||||
{"expert down_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp4"},
|
||||
|
||||
// === Expert gate_up_proj: always base quant (not a sensitive tensor) ===
|
||||
{"expert gate_up int4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "int4", "int4"},
|
||||
{"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"},
|
||||
{"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"},
|
||||
|
||||
// === k_proj: promoted only for 8-expert models ===
|
||||
{"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"},
|
||||
{"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
|
||||
{"k_proj 8 experts nvfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "nvfp4", "mxfp8"},
|
||||
{"k_proj 8 experts mxfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "mxfp4", "mxfp8"},
|
||||
|
||||
// === q_proj, o_proj, gate_proj, up_proj: always base quant ===
|
||||
{"q_proj int4", transform26B, "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"},
|
||||
{"o_proj int4", transform26B, "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"},
|
||||
{"gate_proj int4", transform26B, "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"},
|
||||
{"up_proj int4", transform26B, "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"},
|
||||
|
||||
// === Non-quantizable tensors: always bf16 ===
|
||||
{"embed_tokens per_layer skip", transform26B, "model.embed_tokens_per_layer.weight", aligned, "int4", ""},
|
||||
{"norm", transform26B, "model.layers.0.input_layernorm.weight", []int32{2816}, "int4", ""},
|
||||
{"router scale", transform26B, "model.layers.0.router.scale", []int32{2816}, "int4", ""},
|
||||
|
||||
// === Audio/vision tower tensors: must pass through unquantized for all quant types ===
|
||||
// These contain .v_proj and down_proj but should NOT be intercepted by
|
||||
// the sensitive-tensor promotion logic.
|
||||
{"audio norm int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int4", ""},
|
||||
{"audio norm nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "nvfp4", ""},
|
||||
{"audio norm int8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int8", ""},
|
||||
{"audio norm mxfp8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "mxfp8", ""},
|
||||
{"audio conv int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "int4", ""},
|
||||
{"audio conv nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "nvfp4", ""},
|
||||
{"audio linear int4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "int4", ""},
|
||||
{"audio linear nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "nvfp4", ""},
|
||||
// Audio tower v_proj — must NOT be promoted despite containing .v_proj
|
||||
{"audio v_proj int4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", ""},
|
||||
{"audio v_proj nvfp4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", ""},
|
||||
// Vision tower v_proj — vision tower IS quantized (unlike audio tower),
|
||||
// but not intercepted by gemma4's layer-position heuristic.
|
||||
// Falls through to GetTensorQuantization which applies uniform promotion.
|
||||
{"vision v_proj int4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", "int8"},
|
||||
{"vision v_proj nvfp4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", "nvfp4"},
|
||||
// Audio tower down_proj
|
||||
{"audio down_proj int4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "int4", ""},
|
||||
{"audio down_proj nvfp4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "nvfp4", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.transform.quantizationType(tt.tensor, tt.shape, tt.quantize)
|
||||
if got != tt.want {
|
||||
t.Errorf("quantizationType(%q, %v, %q) = %q, want %q",
|
||||
tt.tensor, tt.shape, tt.quantize, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseMoreBits(t *testing.T) {
|
||||
// 30 layers: first 1/8 = layers 0-2, last 1/8 = layers 27-29
|
||||
// In between: every 3rd from offset (i - n/8) % 3 == 2
|
||||
n := 30
|
||||
promoted := map[int]bool{}
|
||||
for i := range n {
|
||||
if useMoreBits(i, n) {
|
||||
promoted[i] = true
|
||||
}
|
||||
}
|
||||
|
||||
// First 1/8 (30/8 = 3): layers 0, 1, 2
|
||||
for _, i := range []int{0, 1, 2} {
|
||||
if !promoted[i] {
|
||||
t.Errorf("layer %d should be promoted (first 1/8)", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Last 1/8: layers 26, 27, 28, 29 (>= 7*30/8 = 26)
|
||||
for _, i := range []int{26, 27, 28, 29} {
|
||||
if !promoted[i] {
|
||||
t.Errorf("layer %d should be promoted (last 1/8)", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Some middle layers should NOT be promoted
|
||||
for _, i := range []int{3, 4, 6, 7} {
|
||||
if promoted[i] {
|
||||
t.Errorf("layer %d should NOT be promoted", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Layer 5 should be promoted: (5 - 3) % 3 == 2
|
||||
if !promoted[5] {
|
||||
t.Errorf("layer 5 should be promoted (periodic)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGemma4StackedMoETensor(t *testing.T) {
|
||||
tests := []struct {
|
||||
label string
|
||||
tensorName string
|
||||
shape []int32
|
||||
want bool
|
||||
}{
|
||||
// New-style: .experts.gate_up_proj
|
||||
{"experts gate_up_proj 3D", "model.layers.0.experts.gate_up_proj", []int32{128, 1408, 2816}, true},
|
||||
{"experts down_proj 3D", "model.layers.0.experts.down_proj", []int32{128, 2816, 704}, true},
|
||||
// Old-style: .moe.gate_proj
|
||||
{"moe gate_proj 3D", "model.layers.0.moe.gate_proj", []int32{128, 2112, 2816}, true},
|
||||
{"moe down_proj 3D", "model.layers.0.moe.down_proj.weight", []int32{128, 2816, 2112}, true},
|
||||
// Not stacked: 2D
|
||||
{"2D weight", "model.layers.0.experts.gate_up_proj", []int32{1408, 2816}, false},
|
||||
// Not expert
|
||||
{"non-expert 3D", "model.layers.0.mlp.gate_proj", []int32{3, 2816, 2816}, false},
|
||||
// Not a projection
|
||||
{"expert non-proj", "model.layers.0.experts.scale", []int32{128, 1, 1}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.label, func(t *testing.T) {
|
||||
got := isGemma4StackedMoETensor(tt.tensorName, tt.shape)
|
||||
if got != tt.want {
|
||||
t.Errorf("isGemma4StackedMoETensor(%q, %v) = %v, want %v",
|
||||
tt.tensorName, tt.shape, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
_ "github.com/ollama/ollama/x/models/gemma3"
|
||||
_ "github.com/ollama/ollama/x/models/gemma4"
|
||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
_ "github.com/ollama/ollama/x/models/llama"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||
|
||||
@@ -4,16 +4,57 @@ package mlx
|
||||
import "C"
|
||||
import "math"
|
||||
|
||||
func GELUApprox(t *Array) *Array {
|
||||
return t.Multiply(
|
||||
FromValue[float32](0.5),
|
||||
).Multiply(
|
||||
t.Add(
|
||||
t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
|
||||
).Multiply(
|
||||
FromValue(float32(math.Sqrt(2 / math.Pi))),
|
||||
).Tanh().Add(FromValue[float32](1.0)),
|
||||
).AsType(t.DType())
|
||||
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
|
||||
|
||||
// GELUApprox matches mlx.nn.gelu_approx:
|
||||
//
|
||||
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
func GELUApprox(x *Array) *Array {
|
||||
// Use dtype-matched scalars to avoid implicit upcasts on bf16 inputs.
|
||||
half := scalarWithDtype(0.5, x)
|
||||
defer C.mlx_array_free(half)
|
||||
coeff := scalarWithDtype(geluCoeff, x)
|
||||
defer C.mlx_array_free(coeff)
|
||||
c := scalarWithDtype(0.044715, x)
|
||||
defer C.mlx_array_free(c)
|
||||
|
||||
// x^3 via x*x*x (avoids general Power which is slower)
|
||||
x3 := New("GELU_X3")
|
||||
C.mlx_multiply(&x3.ctx, x.ctx, x.ctx, DefaultStream().ctx)
|
||||
tmp := New("GELU_X3b")
|
||||
C.mlx_multiply(&tmp.ctx, x3.ctx, x.ctx, DefaultStream().ctx)
|
||||
x3 = tmp
|
||||
|
||||
// 0.044715 * x^3
|
||||
cx3 := New("GELU_CX3")
|
||||
C.mlx_multiply(&cx3.ctx, c, x3.ctx, DefaultStream().ctx)
|
||||
|
||||
// x + 0.044715 * x^3
|
||||
inner := New("GELU_INNER")
|
||||
C.mlx_add(&inner.ctx, x.ctx, cx3.ctx, DefaultStream().ctx)
|
||||
|
||||
// sqrt(2/pi) * (x + 0.044715 * x^3)
|
||||
scaled := New("GELU_SCALED")
|
||||
C.mlx_multiply(&scaled.ctx, coeff, inner.ctx, DefaultStream().ctx)
|
||||
|
||||
// tanh(...)
|
||||
th := New("GELU_TANH")
|
||||
C.mlx_tanh(&th.ctx, scaled.ctx, DefaultStream().ctx)
|
||||
|
||||
// 1 + tanh(...)
|
||||
one := scalarWithDtype(1.0, x)
|
||||
defer C.mlx_array_free(one)
|
||||
onePlusTanh := New("GELU_1PT")
|
||||
C.mlx_add(&onePlusTanh.ctx, one, th.ctx, DefaultStream().ctx)
|
||||
|
||||
// 0.5 * x
|
||||
halfX := New("GELU_HALFX")
|
||||
C.mlx_multiply(&halfX.ctx, half, x.ctx, DefaultStream().ctx)
|
||||
|
||||
// 0.5 * x * (1 + tanh(...))
|
||||
out := New("GELU_APPROX")
|
||||
C.mlx_multiply(&out.ctx, halfX.ctx, onePlusTanh.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func SILU(t *Array) *Array {
|
||||
|
||||
@@ -90,3 +90,10 @@ func AsyncEval(outputs ...*Array) {
|
||||
func Eval(outputs ...*Array) {
|
||||
doEval(outputs, false)
|
||||
}
|
||||
|
||||
// MetalIsAvailable returns true if a Metal GPU is available.
|
||||
func MetalIsAvailable() bool {
|
||||
var available C._Bool
|
||||
C.mlx_metal_is_available(&available)
|
||||
return bool(available)
|
||||
}
|
||||
|
||||
@@ -149,45 +149,132 @@ func Contiguous(a *Array, allowColMajor bool) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func Pad(a *Array, paddings []int32) *Array {
|
||||
numAxes := len(paddings) / 2
|
||||
axes := make([]C.int, numAxes)
|
||||
lowPad := make([]C.int, numAxes)
|
||||
highPad := make([]C.int, numAxes)
|
||||
for i := range numAxes {
|
||||
axes[i] = C.int(i)
|
||||
lowPad[i] = C.int(paddings[i*2])
|
||||
highPad[i] = C.int(paddings[i*2+1])
|
||||
// Conv2d performs 2D convolution: x [N,H,W,C_in], weight [C_out,kH,kW,C_in].
|
||||
// MLX uses NHWC layout.
|
||||
func Conv2d(x, weight *Array, strideH, strideW, padH, padW, dilationH, dilationW, groups int32) *Array {
|
||||
out := New("CONV2D")
|
||||
C.mlx_conv2d(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
weight.ctx,
|
||||
C.int(strideH), C.int(strideW),
|
||||
C.int(padH), C.int(padW),
|
||||
C.int(dilationH), C.int(dilationW),
|
||||
C.int(groups),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// Pad pads array a along the given axes with specified low/high pad sizes.
|
||||
// mode should be "constant", "edge", or "reflect".
|
||||
func Pad(a *Array, axes []int, lowPad, highPad []int, padValue *Array, mode string) *Array {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
cLow := make([]C.int, len(lowPad))
|
||||
cHigh := make([]C.int, len(highPad))
|
||||
for i := range axes {
|
||||
cAxes[i] = C.int(axes[i])
|
||||
cLow[i] = C.int(lowPad[i])
|
||||
cHigh[i] = C.int(highPad[i])
|
||||
}
|
||||
|
||||
padValue := C.mlx_array_new_float(C.float(0))
|
||||
defer C.mlx_array_free(padValue)
|
||||
|
||||
cMode := C.CString("constant")
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("PAD")
|
||||
C.mlx_pad(
|
||||
&out.ctx,
|
||||
a.ctx,
|
||||
unsafe.SliceData(axes),
|
||||
C.size_t(len(axes)),
|
||||
unsafe.SliceData(lowPad),
|
||||
C.size_t(len(lowPad)),
|
||||
unsafe.SliceData(highPad),
|
||||
C.size_t(len(highPad)),
|
||||
padValue,
|
||||
unsafe.SliceData(cAxes), C.size_t(len(cAxes)),
|
||||
unsafe.SliceData(cLow), C.size_t(len(cLow)),
|
||||
unsafe.SliceData(cHigh), C.size_t(len(cHigh)),
|
||||
padValue.ctx,
|
||||
cMode,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// PadConstant pads with zeros along the given axes.
|
||||
func PadConstant(a *Array, axes []int, lowPad, highPad []int) *Array {
|
||||
zero := NewScalarArray(float32(0))
|
||||
return Pad(a, axes, lowPad, highPad, zero, "constant")
|
||||
}
|
||||
|
||||
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||
groups := int32(x.Dim(x.NumDims() - 1))
|
||||
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||
}
|
||||
|
||||
// Maximum returns element-wise maximum of two arrays.
|
||||
func Maximum(a, b *Array) *Array {
|
||||
out := New("MAXIMUM")
|
||||
C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Minimum returns element-wise minimum of two arrays.
|
||||
func Minimum(a, b *Array) *Array {
|
||||
out := New("MINIMUM")
|
||||
C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Softplus computes log(1 + exp(x)) using logaddexp for numerical stability.
|
||||
func Softplus(a *Array) *Array {
|
||||
return Logaddexp(a, Zeros(a.DType(), a.Dims()...))
|
||||
}
|
||||
|
||||
// ReLU computes max(0, x).
|
||||
func ReLU(a *Array) *Array {
|
||||
return Maximum(a, NewScalarArray(float32(0)))
|
||||
}
|
||||
|
||||
// GLU applies Gated Linear Unit: splits x along last dim into two halves,
|
||||
// returns first * sigmoid(second).
|
||||
func GLU(a *Array) *Array {
|
||||
lastDim := a.NumDims() - 1
|
||||
halfSize := a.Dim(lastDim) / 2
|
||||
first := SliceStartStop(a,
|
||||
make([]int32, lastDim+1), // all zeros for start
|
||||
appendDims(a, lastDim, int32(halfSize)),
|
||||
)
|
||||
second := SliceStartStop(a,
|
||||
appendDimsStart(a, lastDim, int32(halfSize)),
|
||||
appendDims(a, lastDim, int32(a.Dim(lastDim))),
|
||||
)
|
||||
return first.Multiply(second.Sigmoid())
|
||||
}
|
||||
|
||||
// helper: builds stop array for SliceStartStop where the target axis = val
|
||||
func appendDims(a *Array, targetAxis int, val int32) []int32 {
|
||||
n := a.NumDims()
|
||||
out := make([]int32, n)
|
||||
for i := range n {
|
||||
if i == targetAxis {
|
||||
out[i] = val
|
||||
} else {
|
||||
out[i] = int32(a.Dim(i))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// helper: builds start array for SliceStartStop where the target axis = val
|
||||
func appendDimsStart(a *Array, targetAxis int, val int32) []int32 {
|
||||
n := a.NumDims()
|
||||
out := make([]int32, n)
|
||||
for i := range n {
|
||||
if i == targetAxis {
|
||||
out[i] = val
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Clamp clamps array values to [min, max].
|
||||
func Clamp(a *Array, minVal, maxVal float32) *Array {
|
||||
return Minimum(Maximum(a, NewScalarArray(minVal)), NewScalarArray(maxVal))
|
||||
}
|
||||
|
||||
// Convenience wrappers (function-style for the model code)
|
||||
|
||||
func Stack(arrays []*Array, axis int) *Array {
|
||||
@@ -323,20 +410,37 @@ func SiLU(a *Array) *Array {
|
||||
}
|
||||
|
||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
||||
freqs := New("")
|
||||
return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
|
||||
}
|
||||
|
||||
// RoPEWithFreqs applies RoPE with optional custom frequencies.
|
||||
// When freqs is non-nil, it is used instead of computing from base.
|
||||
// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass
|
||||
// the actual frequencies (base^(2i/dim)), not the inverse frequencies.
|
||||
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offset int, freqs *Array) *Array {
|
||||
var freqsCtx C.mlx_array
|
||||
var optBase C.mlx_optional_float
|
||||
if freqs != nil {
|
||||
freqsCtx = freqs.ctx
|
||||
optBase = C.mlx_optional_float{has_value: C.bool(false)}
|
||||
} else {
|
||||
empty := New("")
|
||||
freqsCtx = empty.ctx
|
||||
optBase = C.mlx_optional_float{
|
||||
value: C.float(base),
|
||||
has_value: C.bool(func() bool { return base != 0 }()),
|
||||
}
|
||||
}
|
||||
out := New("FAST_ROPE")
|
||||
C.mlx_fast_rope(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
C.int(dims),
|
||||
C.bool(traditional),
|
||||
C.mlx_optional_float{
|
||||
value: C.float(base),
|
||||
has_value: C.bool(func() bool { return base != 0 }()),
|
||||
},
|
||||
optBase,
|
||||
C.float(scale),
|
||||
C.int(offset),
|
||||
freqs.ctx,
|
||||
freqsCtx,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
@@ -358,6 +462,24 @@ func Log(a *Array) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func Sin(a *Array) *Array {
|
||||
out := New("SIN")
|
||||
C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Cos(a *Array) *Array {
|
||||
out := New("COS")
|
||||
C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Clip(a, aMin, aMax *Array) *Array {
|
||||
out := New("CLIP")
|
||||
C.mlx_clip(&out.ctx, a.ctx, aMin.ctx, aMax.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Logaddexp(a, b *Array) *Array {
|
||||
out := New("LOGADDEXP")
|
||||
C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
@@ -385,6 +507,20 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
|
||||
return out
|
||||
}
|
||||
|
||||
// ScaledDotProductAttentionMasked runs the fast SDPA kernel with an explicit
|
||||
// additive mask. The mask is broadcast to [B, H, Q, K] and added to scores
|
||||
// before softmax. Pass mode="array" so MLX actually consults mask_arr; the
|
||||
// empty string is "no mask" and silently ignores the array argument.
|
||||
func ScaledDotProductAttentionMasked(q, k, v *Array, scale float32, mask *Array) *Array {
|
||||
sinks := New("")
|
||||
cMode := C.CString("array")
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("FAST_SDPA")
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
|
||||
out := New("FAST_LAYERNORM")
|
||||
var w, b C.mlx_array
|
||||
|
||||
@@ -131,6 +131,12 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string,
|
||||
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
|
||||
globalQuantType = strings.ToUpper(globalQuantType)
|
||||
|
||||
// Parse full metadata for per-tensor quant info
|
||||
var metaMap map[string]string
|
||||
if metaRaw, ok := header["__metadata__"]; ok {
|
||||
json.Unmarshal(metaRaw, &metaMap)
|
||||
}
|
||||
|
||||
mainNames := mainTensorNames(header)
|
||||
infos := make(map[string]*TensorQuantInfo)
|
||||
for _, name := range mainNames {
|
||||
@@ -141,6 +147,18 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string,
|
||||
quantType := globalQuantType
|
||||
groupSize := globalGroupSize
|
||||
|
||||
// Check per-tensor metadata (e.g. from packed expert blobs with mixed precision)
|
||||
if metaMap != nil {
|
||||
if qt, ok := metaMap[name+".quant_type"]; ok && qt != "" {
|
||||
quantType = strings.ToUpper(qt)
|
||||
}
|
||||
if gs, ok := metaMap[name+".group_size"]; ok && gs != "" {
|
||||
if v, err := strconv.Atoi(gs); err == nil {
|
||||
groupSize = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
|
||||
if quantType == "" {
|
||||
quantType = inferredType
|
||||
|
||||
1514
x/models/gemma4/gemma4.go
Normal file
1514
x/models/gemma4/gemma4.go
Normal file
File diff suppressed because it is too large
Load Diff
228
x/models/gemma4/gemma4_moe_test.go
Normal file
228
x/models/gemma4/gemma4_moe_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// onesLike creates a tensor of the given shape filled with a small constant.
|
||||
func onesLike(shape ...int) *mlx.Array {
|
||||
return mlx.AddScalar(mlx.Zeros(mlx.DTypeBFloat16, shape...), 0.01)
|
||||
}
|
||||
|
||||
func TestMoEForward(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Small config matching 26b architecture pattern.
|
||||
cfg := &TextConfig{
|
||||
HiddenSize: 16, // tiny for testing
|
||||
NumAttentionHeads: 2,
|
||||
NumKeyValueHeads: 1,
|
||||
NumGlobalKeyValueHeads: 1,
|
||||
HeadDim: 8,
|
||||
GlobalHeadDim: 8,
|
||||
NumExperts: 4,
|
||||
TopKExperts: 2,
|
||||
ExpertIntermediateSize: 8,
|
||||
EnableMoeBlock: true,
|
||||
AttentionKEqV: false,
|
||||
RMSNormEps: 1e-6,
|
||||
SlidingScale: 1.0,
|
||||
FullScale: 1.0,
|
||||
}
|
||||
|
||||
B, L := int32(1), int32(3)
|
||||
x := onesLike(int(B), int(L), int(cfg.HiddenSize))
|
||||
|
||||
// Test Router.Forward.
|
||||
router := &Router{
|
||||
Proj: linearFromWeight(onesLike(int(cfg.NumExperts), int(cfg.HiddenSize))),
|
||||
Scale: onesLike(int(cfg.HiddenSize)),
|
||||
}
|
||||
|
||||
t.Run("Router", func(t *testing.T) {
|
||||
scores, inds := router.Forward(x, cfg)
|
||||
mlx.Eval(scores, inds)
|
||||
|
||||
sDims := scores.Dims()
|
||||
iDims := inds.Dims()
|
||||
t.Logf("scores shape: %v, inds shape: %v", sDims, iDims)
|
||||
|
||||
if len(sDims) != 2 || sDims[0] != int(B*L) || sDims[1] != int(cfg.TopKExperts) {
|
||||
t.Errorf("scores shape = %v, want [%d, %d]", sDims, B*L, cfg.TopKExperts)
|
||||
}
|
||||
if len(iDims) != 2 || iDims[0] != int(B*L) || iDims[1] != int(cfg.TopKExperts) {
|
||||
t.Errorf("inds shape = %v, want [%d, %d]", iDims, B*L, cfg.TopKExperts)
|
||||
}
|
||||
})
|
||||
|
||||
// Test MoEBlock.Forward.
|
||||
moe := &MoEBlock{
|
||||
GateWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
|
||||
UpWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
|
||||
DownWeight: onesLike(int(cfg.NumExperts), int(cfg.ExpertIntermediateSize), int(cfg.HiddenSize)),
|
||||
PerExpertScale: onesLike(int(cfg.NumExperts)),
|
||||
}
|
||||
|
||||
t.Run("MoEBlock", func(t *testing.T) {
|
||||
scores, inds := router.Forward(x, cfg)
|
||||
mlx.Eval(scores, inds)
|
||||
|
||||
out := moe.Forward(x, scores, inds, cfg)
|
||||
mlx.Eval(out)
|
||||
|
||||
outDims := out.Dims()
|
||||
t.Logf("MoE output shape: %v", outDims)
|
||||
|
||||
if len(outDims) != 3 || outDims[0] != int(B) || outDims[1] != int(L) || outDims[2] != int(cfg.HiddenSize) {
|
||||
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, B, L, cfg.HiddenSize)
|
||||
}
|
||||
})
|
||||
|
||||
// Test with larger batch to exercise the sorted GatherMM path (B*L >= 64).
|
||||
t.Run("MoEBlock_sorted", func(t *testing.T) {
|
||||
bigB, bigL := int32(1), int32(128)
|
||||
bigX := onesLike(int(bigB), int(bigL), int(cfg.HiddenSize))
|
||||
|
||||
scores, inds := router.Forward(bigX, cfg)
|
||||
mlx.Eval(scores, inds)
|
||||
|
||||
out := moe.Forward(bigX, scores, inds, cfg)
|
||||
mlx.Eval(out)
|
||||
|
||||
outDims := out.Dims()
|
||||
t.Logf("MoE sorted output shape: %v", outDims)
|
||||
|
||||
if len(outDims) != 3 || outDims[0] != int(bigB) || outDims[1] != int(bigL) || outDims[2] != int(cfg.HiddenSize) {
|
||||
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, bigB, bigL, cfg.HiddenSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestRouterForwardMatchesLegacy verifies the optimized Router.Forward —
|
||||
// which takes the top-k of the raw logits and softmaxes only the selected
|
||||
// values — produces the same indices and (within tolerance) the same
|
||||
// normalized scores as the legacy path that softmaxes over every expert
|
||||
// first, gathers the top-k probabilities, then renormalizes.
|
||||
func TestRouterForwardMatchesLegacy(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
cfg := &TextConfig{
|
||||
HiddenSize: 8,
|
||||
NumExperts: 4,
|
||||
TopKExperts: 2,
|
||||
RMSNormEps: 1e-6,
|
||||
RouterScale: 0.5,
|
||||
}
|
||||
|
||||
// Distinct per-expert weight rows so top-k has a well-defined ordering
|
||||
// (tied scores would let argpartition pick either tied expert and make
|
||||
// the index comparison below flaky).
|
||||
projWeight := mlx.FromValues([]float32{
|
||||
0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, // expert 0
|
||||
0.30, 0.29, 0.28, 0.27, 0.26, 0.25, 0.24, 0.23, // expert 1
|
||||
-0.05, -0.06, -0.07, -0.08, -0.09, -0.10, -0.11, -0.12, // expert 2
|
||||
0.50, 0.48, 0.46, 0.44, 0.42, 0.40, 0.38, 0.36, // expert 3
|
||||
}, int(cfg.NumExperts), int(cfg.HiddenSize))
|
||||
|
||||
scale := mlx.FromValues([]float32{
|
||||
1.0, 0.9, 1.1, 1.0, 1.2, 0.8, 1.0, 1.05,
|
||||
}, int(cfg.HiddenSize))
|
||||
|
||||
r := &Router{
|
||||
Proj: linearFromWeight(projWeight),
|
||||
Scale: scale,
|
||||
}
|
||||
|
||||
// Varied x so different positions potentially hit different top-k.
|
||||
x := mlx.FromValues([]float32{
|
||||
0.2, -0.1, 0.3, 0.0, 0.4, -0.2, 0.1, 0.05,
|
||||
-0.3, 0.2, -0.1, 0.4, -0.05, 0.3, 0.0, 0.2,
|
||||
0.5, 0.4, -0.2, 0.1, -0.3, 0.0, 0.3, -0.1,
|
||||
}, 1, 3, int(cfg.HiddenSize))
|
||||
|
||||
gotScores, gotInds := r.Forward(x, cfg)
|
||||
wantScores, wantInds := legacyRouterForward(r, x, cfg)
|
||||
mlx.Eval(gotScores, gotInds, wantScores, wantInds)
|
||||
|
||||
if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {
|
||||
t.Fatalf("indices mismatch:\n got %v\n want %v", got, want)
|
||||
}
|
||||
if got, want := gotScores.Floats(), wantScores.Floats(); !floatSlicesClose(got, want, 1e-5) {
|
||||
t.Fatalf("scores mismatch:\n got %v\n want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// legacyRouterForward implements the pre-optimization router: full softmax
|
||||
// over every expert, gather the top-k probabilities, then renormalize them
|
||||
// to sum to 1. Algebraically identical to the fused form in Router.Forward.
|
||||
func legacyRouterForward(r *Router, x *mlx.Array, cfg *TextConfig) (*mlx.Array, *mlx.Array) {
|
||||
dims := x.Dims()
|
||||
BL := int32(dims[0]) * int32(dims[1])
|
||||
|
||||
xFlat := mlx.Reshape(x, BL, cfg.HiddenSize)
|
||||
normed := mlx.RMSNormFn(xFlat, nil, cfg.RMSNormEps)
|
||||
normed = mlx.MulScalar(normed, cfg.RouterScale)
|
||||
normed = mlx.Mul(normed, r.Scale)
|
||||
|
||||
expertScores := r.Proj.Forward(normed)
|
||||
probs := mlx.SoftmaxAxis(expertScores, -1, true)
|
||||
|
||||
neg := mlx.Neg(expertScores)
|
||||
inds := mlx.Argpartition(neg, int(cfg.TopKExperts)-1, -1)
|
||||
inds = mlx.SliceStartStop(inds,
|
||||
[]int32{0, 0},
|
||||
[]int32{BL, cfg.TopKExperts},
|
||||
)
|
||||
|
||||
scores := mlx.TakeAlongAxis(probs, inds, -1)
|
||||
sumScores := mlx.Sum(scores, -1, true)
|
||||
scores = mlx.Div(scores, sumScores)
|
||||
return scores, inds
|
||||
}
|
||||
|
||||
func intSlicesEqual(a, b []int) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func floatSlicesClose(a, b []float32, tol float32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
d := a[i] - b[i]
|
||||
if d < 0 {
|
||||
d = -d
|
||||
}
|
||||
if d > tol {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// linearFromWeight creates a simple nn.LinearLayer from a weight tensor (no bias).
|
||||
func linearFromWeight(w *mlx.Array) *simpleLinear {
|
||||
return &simpleLinear{weight: w}
|
||||
}
|
||||
|
||||
type simpleLinear struct {
|
||||
weight *mlx.Array
|
||||
}
|
||||
|
||||
func (l *simpleLinear) Forward(x *mlx.Array) *mlx.Array {
|
||||
return x.Matmul(mlx.Transpose(l.weight, 1, 0))
|
||||
}
|
||||
|
||||
func (l *simpleLinear) OutputDim() int32 {
|
||||
return int32(l.weight.Dims()[0])
|
||||
}
|
||||
503
x/models/gemma4/gemma4_test.go
Normal file
503
x/models/gemma4/gemma4_test.go
Normal file
@@ -0,0 +1,503 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func TestParseTextConfigE2B(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
data := []byte(`{
|
||||
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||
"text_config": {
|
||||
"hidden_size": 1536,
|
||||
"num_hidden_layers": 35,
|
||||
"intermediate_size": 6144,
|
||||
"num_attention_heads": 8,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 256,
|
||||
"global_head_dim": 512,
|
||||
"vocab_size": 262144,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"max_position_embeddings": 131072,
|
||||
"sliding_window": 512,
|
||||
"sliding_window_pattern": 5,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"use_double_wide_mlp": true,
|
||||
"num_kv_shared_layers": 20,
|
||||
"hidden_size_per_layer_input": 256,
|
||||
"vocab_size_per_layer_input": 262144,
|
||||
"attention_k_eq_v": false,
|
||||
"tie_word_embeddings": true,
|
||||
"layer_types": [
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
|
||||
],
|
||||
"rope_parameters": {
|
||||
"full_attention": {
|
||||
"partial_rotary_factor": 0.25,
|
||||
"rope_theta": 1000000.0,
|
||||
"rope_type": "proportional"
|
||||
},
|
||||
"sliding_attention": {
|
||||
"rope_theta": 10000.0,
|
||||
"rope_type": "default"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseTextConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseTextConfig failed: %v", err)
|
||||
}
|
||||
|
||||
// Basic fields.
|
||||
if cfg.HiddenSize != 1536 {
|
||||
t.Errorf("HiddenSize = %d, want 1536", cfg.HiddenSize)
|
||||
}
|
||||
if cfg.NumHiddenLayers != 35 {
|
||||
t.Errorf("NumHiddenLayers = %d, want 35", cfg.NumHiddenLayers)
|
||||
}
|
||||
if cfg.GlobalHeadDim != 512 {
|
||||
t.Errorf("GlobalHeadDim = %d, want 512", cfg.GlobalHeadDim)
|
||||
}
|
||||
if cfg.FinalLogitSoftcapping != 30.0 {
|
||||
t.Errorf("FinalLogitSoftcapping = %f, want 30.0", cfg.FinalLogitSoftcapping)
|
||||
}
|
||||
if cfg.NumKVSharedLayers != 20 {
|
||||
t.Errorf("NumKVSharedLayers = %d, want 20", cfg.NumKVSharedLayers)
|
||||
}
|
||||
if cfg.HiddenSizePerLayer != 256 {
|
||||
t.Errorf("HiddenSizePerLayer = %d, want 256", cfg.HiddenSizePerLayer)
|
||||
}
|
||||
|
||||
// RoPE settings.
|
||||
if cfg.SlidingRopeDims != 256 {
|
||||
t.Errorf("SlidingRopeDims = %d, want 256", cfg.SlidingRopeDims)
|
||||
}
|
||||
if cfg.FullRopeDims != 512 {
|
||||
t.Errorf("FullRopeDims = %d, want 512 (GlobalHeadDim, partial rotation handled via custom freqs)", cfg.FullRopeDims)
|
||||
}
|
||||
if cfg.SlidingRopeBase != 10000 {
|
||||
t.Errorf("SlidingRopeBase = %f, want 10000", cfg.SlidingRopeBase)
|
||||
}
|
||||
if cfg.FullRopeBase != 1000000 {
|
||||
t.Errorf("FullRopeBase = %f, want 1000000", cfg.FullRopeBase)
|
||||
}
|
||||
|
||||
// Attention scale.
|
||||
if cfg.SlidingScale == 0 || cfg.FullScale == 0 {
|
||||
t.Error("attention scales should be non-zero")
|
||||
}
|
||||
|
||||
// KV sharing map.
|
||||
// First shared layer is 35 - 20 = 15.
|
||||
if donor, ok := cfg.KVShareMap[15]; !ok || donor != 13 {
|
||||
t.Errorf("KVShareMap[15] = %d, ok=%v; want 13, true", donor, ok)
|
||||
}
|
||||
if donor, ok := cfg.KVShareMap[19]; !ok || donor != 14 {
|
||||
t.Errorf("KVShareMap[19] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
|
||||
}
|
||||
if donor, ok := cfg.KVShareMap[34]; !ok || donor != 14 {
|
||||
t.Errorf("KVShareMap[34] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
|
||||
}
|
||||
// Layer 14 should not be shared.
|
||||
if _, ok := cfg.KVShareMap[14]; ok {
|
||||
t.Error("layer 14 should not be in KVShareMap (non-shared)")
|
||||
}
|
||||
|
||||
// Donors.
|
||||
if !cfg.KVDonors[13] {
|
||||
t.Error("layer 13 should be a KV donor")
|
||||
}
|
||||
if !cfg.KVDonors[14] {
|
||||
t.Error("layer 14 should be a KV donor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTextConfig26B(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
data := []byte(`{
|
||||
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||
"text_config": {
|
||||
"hidden_size": 2816,
|
||||
"num_hidden_layers": 30,
|
||||
"intermediate_size": 2112,
|
||||
"num_attention_heads": 16,
|
||||
"num_key_value_heads": 8,
|
||||
"num_global_key_value_heads": 2,
|
||||
"head_dim": 256,
|
||||
"global_head_dim": 512,
|
||||
"vocab_size": 262144,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"max_position_embeddings": 131072,
|
||||
"sliding_window": 1024,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"use_double_wide_mlp": false,
|
||||
"num_kv_shared_layers": 0,
|
||||
"hidden_size_per_layer_input": null,
|
||||
"attention_k_eq_v": true,
|
||||
"enable_moe_block": true,
|
||||
"num_experts": 128,
|
||||
"top_k_experts": 8,
|
||||
"moe_intermediate_size": 704,
|
||||
"tie_word_embeddings": true,
|
||||
"layer_types": [
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
|
||||
],
|
||||
"rope_parameters": {
|
||||
"full_attention": {
|
||||
"partial_rotary_factor": 0.25,
|
||||
"rope_theta": 1000000.0,
|
||||
"rope_type": "proportional"
|
||||
},
|
||||
"sliding_attention": {
|
||||
"rope_theta": 10000.0,
|
||||
"rope_type": "default"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseTextConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseTextConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.HiddenSize != 2816 {
|
||||
t.Errorf("HiddenSize = %d, want 2816", cfg.HiddenSize)
|
||||
}
|
||||
if !cfg.AttentionKEqV {
|
||||
t.Error("AttentionKEqV should be true")
|
||||
}
|
||||
if cfg.NumGlobalKeyValueHeads != 2 {
|
||||
t.Errorf("NumGlobalKeyValueHeads = %d, want 2", cfg.NumGlobalKeyValueHeads)
|
||||
}
|
||||
if !cfg.EnableMoeBlock {
|
||||
t.Error("EnableMoeBlock should be true")
|
||||
}
|
||||
if cfg.NumExperts != 128 {
|
||||
t.Errorf("NumExperts = %d, want 128", cfg.NumExperts)
|
||||
}
|
||||
if cfg.TopKExperts != 8 {
|
||||
t.Errorf("TopKExperts = %d, want 8", cfg.TopKExperts)
|
||||
}
|
||||
if cfg.ExpertIntermediateSize != 704 {
|
||||
t.Errorf("ExpertIntermediateSize = %d, want 704", cfg.ExpertIntermediateSize)
|
||||
}
|
||||
if cfg.HiddenSizePerLayer != 0 {
|
||||
t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTextConfig31B(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
data := []byte(`{
|
||||
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||
"text_config": {
|
||||
"hidden_size": 5376,
|
||||
"num_hidden_layers": 60,
|
||||
"intermediate_size": 21504,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 16,
|
||||
"num_global_key_value_heads": 4,
|
||||
"head_dim": 256,
|
||||
"global_head_dim": 512,
|
||||
"vocab_size": 262144,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"max_position_embeddings": 131072,
|
||||
"sliding_window": 1024,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"use_double_wide_mlp": false,
|
||||
"num_kv_shared_layers": 0,
|
||||
"hidden_size_per_layer_input": null,
|
||||
"attention_k_eq_v": true,
|
||||
"tie_word_embeddings": true,
|
||||
"layer_types": [
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
|
||||
],
|
||||
"rope_parameters": {
|
||||
"full_attention": {
|
||||
"partial_rotary_factor": 0.25,
|
||||
"rope_theta": 1000000.0,
|
||||
"rope_type": "proportional"
|
||||
},
|
||||
"sliding_attention": {
|
||||
"rope_theta": 10000.0,
|
||||
"rope_type": "default"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseTextConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseTextConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.HiddenSize != 5376 {
|
||||
t.Errorf("HiddenSize = %d, want 5376", cfg.HiddenSize)
|
||||
}
|
||||
if cfg.NumHiddenLayers != 60 {
|
||||
t.Errorf("NumHiddenLayers = %d, want 60", cfg.NumHiddenLayers)
|
||||
}
|
||||
if !cfg.AttentionKEqV {
|
||||
t.Error("AttentionKEqV should be true")
|
||||
}
|
||||
if cfg.NumGlobalKeyValueHeads != 4 {
|
||||
t.Errorf("NumGlobalKeyValueHeads = %d, want 4", cfg.NumGlobalKeyValueHeads)
|
||||
}
|
||||
if cfg.NumKeyValueHeads != 16 {
|
||||
t.Errorf("NumKeyValueHeads = %d, want 16", cfg.NumKeyValueHeads)
|
||||
}
|
||||
if cfg.NumKVSharedLayers != 0 {
|
||||
t.Errorf("NumKVSharedLayers = %d, want 0", cfg.NumKVSharedLayers)
|
||||
}
|
||||
if cfg.HiddenSizePerLayer != 0 {
|
||||
t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
|
||||
}
|
||||
if cfg.SlidingWindow != 1024 {
|
||||
t.Errorf("SlidingWindow = %d, want 1024", cfg.SlidingWindow)
|
||||
}
|
||||
|
||||
// KV sharing should be empty (no shared layers).
|
||||
if len(cfg.KVShareMap) != 0 {
|
||||
t.Errorf("KVShareMap should be empty, got %d entries", len(cfg.KVShareMap))
|
||||
}
|
||||
|
||||
// Layer types: pattern is 5 sliding + 1 full, repeating 10 times.
|
||||
if !isLayerSliding(0, &cfg) {
|
||||
t.Error("layer 0 should be sliding")
|
||||
}
|
||||
if isLayerSliding(5, &cfg) {
|
||||
t.Error("layer 5 should be full attention")
|
||||
}
|
||||
if !isLayerSliding(6, &cfg) {
|
||||
t.Error("layer 6 should be sliding")
|
||||
}
|
||||
if isLayerSliding(59, &cfg) {
|
||||
t.Error("layer 59 should be full attention")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTextConfigE4B(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
data := []byte(`{
|
||||
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||
"text_config": {
|
||||
"hidden_size": 2560,
|
||||
"num_hidden_layers": 42,
|
||||
"intermediate_size": 10240,
|
||||
"num_attention_heads": 8,
|
||||
"num_key_value_heads": 2,
|
||||
"head_dim": 256,
|
||||
"global_head_dim": 512,
|
||||
"vocab_size": 262144,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"max_position_embeddings": 131072,
|
||||
"sliding_window": 512,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"use_double_wide_mlp": false,
|
||||
"num_kv_shared_layers": 18,
|
||||
"hidden_size_per_layer_input": 256,
|
||||
"vocab_size_per_layer_input": 262144,
|
||||
"attention_k_eq_v": false,
|
||||
"enable_moe_block": false,
|
||||
"tie_word_embeddings": true,
|
||||
"layer_types": [
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
|
||||
],
|
||||
"rope_parameters": {
|
||||
"full_attention": {
|
||||
"partial_rotary_factor": 0.25,
|
||||
"rope_theta": 1000000.0,
|
||||
"rope_type": "proportional"
|
||||
},
|
||||
"sliding_attention": {
|
||||
"rope_theta": 10000.0,
|
||||
"rope_type": "default"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
cfg, err := parseTextConfig(data)
|
||||
if err != nil {
|
||||
t.Fatalf("parseTextConfig failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.HiddenSize != 2560 {
|
||||
t.Errorf("HiddenSize = %d, want 2560", cfg.HiddenSize)
|
||||
}
|
||||
if cfg.NumHiddenLayers != 42 {
|
||||
t.Errorf("NumHiddenLayers = %d, want 42", cfg.NumHiddenLayers)
|
||||
}
|
||||
if cfg.IntermediateSize != 10240 {
|
||||
t.Errorf("IntermediateSize = %d, want 10240", cfg.IntermediateSize)
|
||||
}
|
||||
if cfg.NumKeyValueHeads != 2 {
|
||||
t.Errorf("NumKeyValueHeads = %d, want 2", cfg.NumKeyValueHeads)
|
||||
}
|
||||
if cfg.UseDoubleWideMLP {
|
||||
t.Error("UseDoubleWideMLP should be false")
|
||||
}
|
||||
if cfg.NumKVSharedLayers != 18 {
|
||||
t.Errorf("NumKVSharedLayers = %d, want 18", cfg.NumKVSharedLayers)
|
||||
}
|
||||
if cfg.HiddenSizePerLayer != 256 {
|
||||
t.Errorf("HiddenSizePerLayer = %d, want 256 (has PLE)", cfg.HiddenSizePerLayer)
|
||||
}
|
||||
if cfg.AttentionKEqV {
|
||||
t.Error("AttentionKEqV should be false")
|
||||
}
|
||||
if cfg.EnableMoeBlock {
|
||||
t.Error("EnableMoeBlock should be false")
|
||||
}
|
||||
if cfg.SlidingWindow != 512 {
|
||||
t.Errorf("SlidingWindow = %d, want 512", cfg.SlidingWindow)
|
||||
}
|
||||
|
||||
// Layer types: pattern is 5 sliding + 1 full, repeating 7 times = 42 layers.
|
||||
if !isLayerSliding(0, &cfg) {
|
||||
t.Error("layer 0 should be sliding")
|
||||
}
|
||||
if isLayerSliding(5, &cfg) {
|
||||
t.Error("layer 5 should be full attention")
|
||||
}
|
||||
if !isLayerSliding(6, &cfg) {
|
||||
t.Error("layer 6 should be sliding")
|
||||
}
|
||||
if isLayerSliding(41, &cfg) {
|
||||
t.Error("layer 41 should be full attention")
|
||||
}
|
||||
|
||||
// KV sharing: first shared = 42 - 18 = 24.
|
||||
// Layer 24 is sliding, its donor should be the last non-shared sliding layer.
|
||||
// Non-shared layers: 0-23. Last sliding in 0-23 is layer 22 (23=full).
|
||||
if donor, ok := cfg.KVShareMap[24]; !ok {
|
||||
t.Error("layer 24 should be in KVShareMap")
|
||||
} else {
|
||||
t.Logf("layer 24 donor = %d", donor)
|
||||
}
|
||||
// Layer 29 is full_attention (5th full), donor should be the last non-shared full layer.
|
||||
// Non-shared full layers: 5, 11, 17, 23.
|
||||
if donor, ok := cfg.KVShareMap[29]; !ok || donor != 23 {
|
||||
t.Errorf("KVShareMap[29] = %d, ok=%v; want 23, true (full attn donor)", donor, ok)
|
||||
}
|
||||
// Layer 23 should NOT be shared (it's the last non-shared layer).
|
||||
if _, ok := cfg.KVShareMap[23]; ok {
|
||||
t.Error("layer 23 should not be in KVShareMap (non-shared)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLayerTypeDetection(t *testing.T) {
|
||||
cfg := &TextConfig{
|
||||
LayerTypes: []string{
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||
},
|
||||
}
|
||||
|
||||
if !isLayerSliding(0, cfg) {
|
||||
t.Error("layer 0 should be sliding")
|
||||
}
|
||||
if !isLayerSliding(3, cfg) {
|
||||
t.Error("layer 3 should be sliding")
|
||||
}
|
||||
if isLayerSliding(4, cfg) {
|
||||
t.Error("layer 4 should be full attention")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCachesOmitsSharedKVLayers(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []*DecoderLayer{
|
||||
{IsSliding: true, KVShareDonor: -1},
|
||||
{IsSliding: false, KVShareDonor: -1},
|
||||
{IsSliding: true, KVShareDonor: 0},
|
||||
{IsSliding: false, KVShareDonor: 1},
|
||||
},
|
||||
TextConfig: &TextConfig{SlidingWindow: 512},
|
||||
}
|
||||
|
||||
caches := m.NewCaches()
|
||||
if got, want := len(caches), 2; got != want {
|
||||
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCachesIncludesAllNonSharedLayers(t *testing.T) {
|
||||
m := &Model{
|
||||
Layers: []*DecoderLayer{
|
||||
{IsSliding: true, KVShareDonor: -1},
|
||||
{IsSliding: false, KVShareDonor: -1},
|
||||
{IsSliding: true, KVShareDonor: -1},
|
||||
},
|
||||
TextConfig: &TextConfig{SlidingWindow: 512},
|
||||
}
|
||||
|
||||
caches := m.NewCaches()
|
||||
if got, want := len(caches), len(m.Layers); got != want {
|
||||
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveWeightPrefix(t *testing.T) {
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
wantPfx string
|
||||
}{
|
||||
{"bare", "embed_tokens.weight", ""},
|
||||
{"language_model", "model.language_model.embed_tokens.weight", "model.language_model."},
|
||||
{"with_model", "model.embed_tokens.weight", "model."},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dummy := mlx.FromValue(float32(1.0))
|
||||
mlx.Eval(dummy)
|
||||
tensors := map[string]*mlx.Array{tt.key: dummy}
|
||||
got := resolveWeightPrefix(tensors)
|
||||
if got != tt.wantPfx {
|
||||
t.Errorf("resolveWeightPrefix(%q) = %q, want %q", tt.key, got, tt.wantPfx)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user