mirror of
https://github.com/ollama/ollama.git
synced 2026-04-22 08:45:53 +02:00
Compare commits
16 Commits
v0.20.7-rc
...
hoyyeva/op
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a3ed0a1b4 | ||
|
|
03f9e57274 | ||
|
|
30d9100fff | ||
|
|
698e04a14b | ||
|
|
1d9537bc33 | ||
|
|
120424d832 | ||
|
|
5818001610 | ||
|
|
2cba7756c5 | ||
|
|
bf2a421727 | ||
|
|
f3cf6b75fb | ||
|
|
5dfac387a6 | ||
|
|
a99e5d9c22 | ||
|
|
0abf3aca36 | ||
|
|
ee0266462a | ||
|
|
c88fb286ec | ||
|
|
d3da29cbfc |
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type stubEditorRunner struct {
|
type stubEditorRunner struct {
|
||||||
@@ -722,6 +723,59 @@ func TestLauncherClientFilterDisabledCloudModels_ChecksStatusOncePerInvocation(t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSavedMatchesModels(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
saved *config.IntegrationConfig
|
||||||
|
models []string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil saved",
|
||||||
|
saved: nil,
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "identical order",
|
||||||
|
saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}},
|
||||||
|
models: []string{"llama3.2", "qwen3:8b"},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different order",
|
||||||
|
saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}},
|
||||||
|
models: []string{"qwen3:8b", "llama3.2"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subset",
|
||||||
|
saved: &config.IntegrationConfig{Models: []string{"llama3.2", "qwen3:8b"}},
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil models in saved with non-nil models",
|
||||||
|
saved: &config.IntegrationConfig{Models: nil},
|
||||||
|
models: []string{"llama3.2"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty both",
|
||||||
|
saved: &config.IntegrationConfig{Models: nil},
|
||||||
|
models: nil,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := savedMatchesModels(tt.saved, tt.models); got != tt.want {
|
||||||
|
t.Fatalf("savedMatchesModels = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPrepareEditorIntegration_SavesOnlyAfterSuccessfulEdit(t *testing.T) {
|
func TestPrepareEditorIntegration_SavesOnlyAfterSuccessfulEdit(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
@@ -500,7 +501,7 @@ func (c *launcherClient) launchEditorIntegration(ctx context.Context, name strin
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if needsConfigure || req.ModelOverride != "" {
|
if (needsConfigure || req.ModelOverride != "") && !savedMatchesModels(saved, models) {
|
||||||
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
|
if err := prepareEditorIntegration(name, runner, editor, models); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -846,6 +847,13 @@ func firstModel(models []string) string {
|
|||||||
return models[0]
|
return models[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func savedMatchesModels(saved *config.IntegrationConfig, models []string) bool {
|
||||||
|
if saved == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return slices.Equal(saved.Models, models)
|
||||||
|
}
|
||||||
|
|
||||||
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
|
func editorPreCheckedModels(saved *config.IntegrationConfig, override string) []string {
|
||||||
if override == "" {
|
if override == "" {
|
||||||
if saved == nil {
|
if saved == nil {
|
||||||
|
|||||||
@@ -186,6 +186,11 @@ func (c *Openclaw) runChannelSetupPreflight(bin string) error {
|
|||||||
if !isInteractiveSession() {
|
if !isInteractiveSession() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
// --yes is headless; channel setup spawns an interactive picker we can't
|
||||||
|
// auto-answer, so skip it. Users can run `openclaw channels add` later.
|
||||||
|
if currentLaunchConfirmPolicy.yes {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if c.channelsConfigured() {
|
if c.channelsConfigured() {
|
||||||
|
|||||||
@@ -1304,6 +1304,46 @@ func TestOpenclawChannelSetupPreflight(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("--yes skips preflight without channels configured", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
configDir := filepath.Join(tmpDir, ".openclaw")
|
||||||
|
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Empty config = no channels configured. Without the --yes skip, the
|
||||||
|
// preflight would prompt and (on confirm) spawn `openclaw channels add`.
|
||||||
|
if err := os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
bin := filepath.Join(tmpDir, "openclaw")
|
||||||
|
if err := os.WriteFile(bin, []byte("#!/bin/sh\nprintf '%s\\n' \"$*\" >> \"$HOME/invocations.log\"\n"), 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
oldInteractive := isInteractiveSession
|
||||||
|
isInteractiveSession = func() bool { return true }
|
||||||
|
defer func() { isInteractiveSession = oldInteractive }()
|
||||||
|
|
||||||
|
restore := withLaunchConfirmPolicy(launchConfirmPolicy{yes: true})
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
oldConfirmPrompt := DefaultConfirmPrompt
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
t.Fatalf("did not expect prompt in --yes mode: %s", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
defer func() { DefaultConfirmPrompt = oldConfirmPrompt }()
|
||||||
|
|
||||||
|
if err := c.runChannelSetupPreflight("openclaw"); err != nil {
|
||||||
|
t.Fatalf("runChannelSetupPreflight() error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(filepath.Join(tmpDir, "invocations.log")); !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("expected no channels add invocation in --yes mode, got err=%v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("set up later prompts once and exits", func(t *testing.T) {
|
t.Run("set up later prompts once and exits", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setTestHome(t, tmpDir)
|
setTestHome(t, tmpDir)
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
package launch
|
package launch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -11,12 +12,18 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/cmd/internal/fileutil"
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
modeltype "github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenCode implements Runner and Editor for OpenCode integration
|
// OpenCode implements Runner and Editor for OpenCode integration.
|
||||||
type OpenCode struct{}
|
// Config is passed via OPENCODE_CONFIG_CONTENT env var at launch time
|
||||||
|
// instead of writing to opencode's config files.
|
||||||
|
type OpenCode struct {
|
||||||
|
configContent string // JSON config built by Edit, passed to Run via env var
|
||||||
|
}
|
||||||
|
|
||||||
func (o *OpenCode) String() string { return "OpenCode" }
|
func (o *OpenCode) String() string { return "OpenCode" }
|
||||||
|
|
||||||
@@ -51,25 +58,51 @@ func (o *OpenCode) Run(model string, args []string) error {
|
|||||||
cmd.Stdin = os.Stdin
|
cmd.Stdin = os.Stdin
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
cmd.Stderr = os.Stderr
|
cmd.Stderr = os.Stderr
|
||||||
|
cmd.Env = os.Environ()
|
||||||
|
if content := o.resolveContent(model); content != "" {
|
||||||
|
cmd.Env = append(cmd.Env, "OPENCODE_CONFIG_CONTENT="+content)
|
||||||
|
}
|
||||||
return cmd.Run()
|
return cmd.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveContent returns the inline config to send via OPENCODE_CONFIG_CONTENT.
|
||||||
|
// Returns content built by Edit if available, otherwise builds from model.json
|
||||||
|
// with the requested model as primary (e.g. re-launch with saved config).
|
||||||
|
func (o *OpenCode) resolveContent(model string) string {
|
||||||
|
if o.configContent != "" {
|
||||||
|
return o.configContent
|
||||||
|
}
|
||||||
|
models := readModelJSONModels()
|
||||||
|
if !slices.Contains(models, model) {
|
||||||
|
models = append([]string{model}, models...)
|
||||||
|
}
|
||||||
|
content, err := buildInlineConfig(model, models)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
func (o *OpenCode) Paths() []string {
|
func (o *OpenCode) Paths() []string {
|
||||||
home, err := os.UserHomeDir()
|
sp, err := openCodeStatePath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var paths []string
|
|
||||||
p := filepath.Join(home, ".config", "opencode", "opencode.json")
|
|
||||||
if _, err := os.Stat(p); err == nil {
|
|
||||||
paths = append(paths, p)
|
|
||||||
}
|
|
||||||
sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
|
||||||
if _, err := os.Stat(sp); err == nil {
|
if _, err := os.Stat(sp); err == nil {
|
||||||
paths = append(paths, sp)
|
return []string{sp}
|
||||||
}
|
}
|
||||||
return paths
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// openCodeStatePath returns the path to opencode's model state file.
|
||||||
|
// TODO: this hardcodes the Linux/macOS XDG path. On Windows, opencode stores
|
||||||
|
// state under %LOCALAPPDATA% (or similar) — verify and branch on runtime.GOOS.
|
||||||
|
func openCodeStatePath() (string, error) {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".local", "state", "opencode", "model.json"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OpenCode) Edit(modelList []string) error {
|
func (o *OpenCode) Edit(modelList []string) error {
|
||||||
@@ -77,110 +110,17 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
home, err := os.UserHomeDir()
|
content, err := buildInlineConfig(modelList[0], modelList)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
o.configContent = content
|
||||||
|
|
||||||
configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
|
// Write model state file so models appear in OpenCode's model picker
|
||||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
statePath, err := openCodeStatePath()
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
config := make(map[string]any)
|
|
||||||
if data, err := os.ReadFile(configPath); err == nil {
|
|
||||||
_ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
|
|
||||||
}
|
|
||||||
|
|
||||||
config["$schema"] = "https://opencode.ai/config.json"
|
|
||||||
|
|
||||||
provider, ok := config["provider"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
provider = make(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
ollama, ok := provider["ollama"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
ollama = map[string]any{
|
|
||||||
"npm": "@ai-sdk/openai-compatible",
|
|
||||||
"name": "Ollama",
|
|
||||||
"options": map[string]any{
|
|
||||||
"baseURL": envconfig.Host().String() + "/v1",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Migrate legacy provider name
|
|
||||||
if name, _ := ollama["name"].(string); name == "Ollama (local)" {
|
|
||||||
ollama["name"] = "Ollama"
|
|
||||||
}
|
|
||||||
|
|
||||||
models, ok := ollama["models"].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
models = make(map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
selectedSet := make(map[string]bool)
|
|
||||||
for _, m := range modelList {
|
|
||||||
selectedSet[m] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, cfg := range models {
|
|
||||||
if cfgMap, ok := cfg.(map[string]any); ok {
|
|
||||||
if isOllamaModel(cfgMap) && !selectedSet[name] {
|
|
||||||
delete(models, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, model := range modelList {
|
|
||||||
if existing, ok := models[model].(map[string]any); ok {
|
|
||||||
// migrate existing models without _launch marker
|
|
||||||
if isOllamaModel(existing) {
|
|
||||||
existing["_launch"] = true
|
|
||||||
if name, ok := existing["name"].(string); ok {
|
|
||||||
existing["name"] = strings.TrimSuffix(name, " [Ollama]")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if isCloudModelName(model) {
|
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
|
||||||
existing["limit"] = map[string]any{
|
|
||||||
"context": l.Context,
|
|
||||||
"output": l.Output,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
entry := map[string]any{
|
|
||||||
"name": model,
|
|
||||||
"_launch": true,
|
|
||||||
}
|
|
||||||
if isCloudModelName(model) {
|
|
||||||
if l, ok := lookupCloudModelLimit(model); ok {
|
|
||||||
entry["limit"] = map[string]any{
|
|
||||||
"context": l.Context,
|
|
||||||
"output": l.Output,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
models[model] = entry
|
|
||||||
}
|
|
||||||
|
|
||||||
ollama["models"] = models
|
|
||||||
provider["ollama"] = ollama
|
|
||||||
config["provider"] = provider
|
|
||||||
config["model"] = "ollama/" + modelList[0]
|
|
||||||
|
|
||||||
configData, err := json.MarshalIndent(config, "", " ")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := fileutil.WriteWithBackup(configPath, configData); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
|
|
||||||
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
|
if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -232,33 +172,127 @@ func (o *OpenCode) Edit(modelList []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *OpenCode) Models() []string {
|
func (o *OpenCode) Models() []string {
|
||||||
home, err := os.UserHomeDir()
|
return nil
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
config, err := fileutil.ReadJSON(filepath.Join(home, ".config", "opencode", "opencode.json"))
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
provider, _ := config["provider"].(map[string]any)
|
|
||||||
ollama, _ := provider["ollama"].(map[string]any)
|
|
||||||
models, _ := ollama["models"].(map[string]any)
|
|
||||||
if len(models) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
keys := slices.Collect(maps.Keys(models))
|
|
||||||
slices.Sort(keys)
|
|
||||||
return keys
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// isOllamaModel reports whether a model config entry is managed by us
|
// buildInlineConfig produces the JSON string for OPENCODE_CONFIG_CONTENT.
|
||||||
func isOllamaModel(cfg map[string]any) bool {
|
// primary is the model to launch with, models is the full list of available models.
|
||||||
if v, ok := cfg["_launch"].(bool); ok && v {
|
func buildInlineConfig(primary string, models []string) (string, error) {
|
||||||
return true
|
if primary == "" || len(models) == 0 {
|
||||||
|
return "", fmt.Errorf("buildInlineConfig: primary and models are required")
|
||||||
}
|
}
|
||||||
// previously used [Ollama] as a suffix for the model managed by ollama launch
|
config := map[string]any{
|
||||||
if name, ok := cfg["name"].(string); ok {
|
"$schema": "https://opencode.ai/config.json",
|
||||||
return strings.HasSuffix(name, "[Ollama]")
|
"provider": map[string]any{
|
||||||
|
"ollama": map[string]any{
|
||||||
|
"npm": "@ai-sdk/openai-compatible",
|
||||||
|
"name": "Ollama",
|
||||||
|
"options": map[string]any{
|
||||||
|
"baseURL": envconfig.Host().String() + "/v1",
|
||||||
|
},
|
||||||
|
"models": buildModelEntries(models),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"model": "ollama/" + primary,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(config)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readModelJSONModels reads ollama model IDs from the opencode model.json state file
|
||||||
|
func readModelJSONModels() []string {
|
||||||
|
statePath, err := openCodeStatePath()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(statePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var state map[string]any
|
||||||
|
if err := json.Unmarshal(data, &state); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
recent, _ := state["recent"].([]any)
|
||||||
|
var models []string
|
||||||
|
for _, entry := range recent {
|
||||||
|
e, ok := entry.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if e["providerID"] != "ollama" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if id, ok := e["modelID"].(string); ok && id != "" {
|
||||||
|
models = append(models, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildModelEntries(modelList []string) map[string]any {
|
||||||
|
client := api.NewClient(envconfig.Host(), http.DefaultClient)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
models := make(map[string]any)
|
||||||
|
for _, model := range modelList {
|
||||||
|
entry := map[string]any{
|
||||||
|
"name": model,
|
||||||
|
}
|
||||||
|
if isCloudModelName(model) {
|
||||||
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
|
entry["limit"] = map[string]any{
|
||||||
|
"context": l.Context,
|
||||||
|
"output": l.Output,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
applyOpenCodeReasoning(ctx, client, model, entry)
|
||||||
|
models[model] = entry
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyOpenCodeReasoning detects thinking capability and sets reasoning config
|
||||||
|
// on the model entry. When the model supports thinking, it sets "reasoning": true
|
||||||
|
// and configures variants for the OpenCode TUI:
|
||||||
|
// - GPT-OSS: supports variable effort levels (low/medium/high) and defaults to
|
||||||
|
// medium via options. Thinking cannot be turned off.
|
||||||
|
// - Other models: only support on/off. Disables built-in low/medium/high variants
|
||||||
|
// and adds a "none" variant so users can toggle thinking off via Ctrl+T.
|
||||||
|
//
|
||||||
|
// When the model does not support thinking, no reasoning config is set.
|
||||||
|
func applyOpenCodeReasoning(ctx context.Context, client *api.Client, modelName string, entry map[string]any) {
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(resp.Capabilities, modeltype.CapabilityThinking) {
|
||||||
|
entry["reasoning"] = true
|
||||||
|
|
||||||
|
if strings.Contains(modelName, "gpt-oss") {
|
||||||
|
// GPT-OSS models support variable thinking effort levels
|
||||||
|
// and cannot turn thinking off. Keep the built-in
|
||||||
|
// low/medium/high variants as-is and default to medium.
|
||||||
|
options, ok := entry["options"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
options = make(map[string]any)
|
||||||
|
}
|
||||||
|
options["reasoningEffort"] = "medium"
|
||||||
|
entry["options"] = options
|
||||||
|
} else {
|
||||||
|
// Most models only support thinking on or off.
|
||||||
|
// Disable the built-in low/medium/high variants and add none.
|
||||||
|
entry["variants"] = map[string]any{
|
||||||
|
"none": map[string]any{"reasoningEffort": "none"},
|
||||||
|
"low": map[string]any{"disabled": true},
|
||||||
|
"medium": map[string]any{"disabled": true},
|
||||||
|
"high": map[string]any{"disabled": true},
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,7 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
|||||||
binary: "opencode",
|
binary: "opencode",
|
||||||
runner: &OpenCode{},
|
runner: &OpenCode{},
|
||||||
checkPath: func(home string) string {
|
checkPath: func(home string) string {
|
||||||
return filepath.Join(home, ".config", "opencode", "opencode.json")
|
return filepath.Join(home, ".local", "state", "opencode", "model.json")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -28,79 +28,4 @@ To configure without launching:
|
|||||||
ollama launch opencode --config
|
ollama launch opencode --config
|
||||||
```
|
```
|
||||||
|
|
||||||
### Manual setup
|
<Note>`ollama launch opencode` passes its configuration to OpenCode inline via the `OPENCODE_CONFIG_CONTENT` environment variable. OpenCode deep-merges its config sources on startup, so anything you declare in `~/.config/opencode/opencode.json` is still respected and available inside OpenCode. Models declared only in `opencode.json` won't appear in `ollama launch`'s model-selection menu.</Note>
|
||||||
|
|
||||||
Add a configuration block to `~/.config/opencode/opencode.json`:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"$schema": "https://opencode.ai/config.json",
|
|
||||||
"provider": {
|
|
||||||
"ollama": {
|
|
||||||
"npm": "@ai-sdk/openai-compatible",
|
|
||||||
"name": "Ollama",
|
|
||||||
"options": {
|
|
||||||
"baseURL": "http://localhost:11434/v1"
|
|
||||||
},
|
|
||||||
"models": {
|
|
||||||
"qwen3-coder": {
|
|
||||||
"name": "qwen3-coder"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Cloud Models
|
|
||||||
|
|
||||||
`glm-4.7:cloud` is the recommended model for use with OpenCode.
|
|
||||||
|
|
||||||
Add the cloud configuration to `~/.config/opencode/opencode.json`:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"$schema": "https://opencode.ai/config.json",
|
|
||||||
"provider": {
|
|
||||||
"ollama": {
|
|
||||||
"npm": "@ai-sdk/openai-compatible",
|
|
||||||
"name": "Ollama",
|
|
||||||
"options": {
|
|
||||||
"baseURL": "http://localhost:11434/v1"
|
|
||||||
},
|
|
||||||
"models": {
|
|
||||||
"glm-4.7:cloud": {
|
|
||||||
"name": "glm-4.7:cloud"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Connecting to ollama.com
|
|
||||||
|
|
||||||
1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
|
|
||||||
2. Update `~/.config/opencode/opencode.json` to point to ollama.com:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"$schema": "https://opencode.ai/config.json",
|
|
||||||
"provider": {
|
|
||||||
"ollama": {
|
|
||||||
"npm": "@ai-sdk/openai-compatible",
|
|
||||||
"name": "Ollama Cloud",
|
|
||||||
"options": {
|
|
||||||
"baseURL": "https://ollama.com/v1"
|
|
||||||
},
|
|
||||||
"models": {
|
|
||||||
"glm-4.7:cloud": {
|
|
||||||
"name": "glm-4.7:cloud"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Run `opencode` in a new terminal to load the new settings.
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
// #cgo CXXFLAGS: -std=c++17
|
// #cgo CXXFLAGS: -std=c++17 -Wno-deprecated-declarations
|
||||||
// #cgo CPPFLAGS: -I${SRCDIR}/../include -I${SRCDIR}/../vendor
|
// #cgo CPPFLAGS: -I${SRCDIR}/../include -I${SRCDIR}/../vendor
|
||||||
// #cgo CPPFLAGS: -I${SRCDIR}/../../../ml/backend/ggml/ggml/include
|
// #cgo CPPFLAGS: -I${SRCDIR}/../../../ml/backend/ggml/ggml/include
|
||||||
import "C"
|
import "C"
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
|||||||
|
|
||||||
// Emit system turn if there's a system/developer role, tools, or thinking.
|
// Emit system turn if there's a system/developer role, tools, or thinking.
|
||||||
hasThink := thinkValue != nil && thinkValue.Bool()
|
hasThink := thinkValue != nil && thinkValue.Bool()
|
||||||
thinkingExplicitlyDisabled := thinkValue != nil && thinkValue.IsBool() && !thinkValue.Bool()
|
|
||||||
if hasSystemRole || len(tools) > 0 || hasThink {
|
if hasSystemRole || len(tools) > 0 || hasThink {
|
||||||
sb.WriteString("<|turn>system\n")
|
sb.WriteString("<|turn>system\n")
|
||||||
if hasThink {
|
if hasThink {
|
||||||
@@ -125,9 +124,6 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
|||||||
// Generation prompt.
|
// Generation prompt.
|
||||||
if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
|
if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
|
||||||
sb.WriteString("<|turn>model\n")
|
sb.WriteString("<|turn>model\n")
|
||||||
if !hasThink && !thinkingExplicitlyDisabled {
|
|
||||||
sb.WriteString("<|channel>thought\n<channel|>")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
|
|||||||
@@ -1,14 +1,18 @@
|
|||||||
package renderers
|
package renderers
|
||||||
|
|
||||||
// TestGemma4RendererMatchesReference verifies our renderer matches the HF
|
// TestGemma4RendererMatchesReference verifies our renderer matches the checked-in
|
||||||
// Jinja2 chat template exactly.
|
// Gemma 4 reference template.
|
||||||
//
|
//
|
||||||
// To regenerate expected values, save gemma4Jinja2Template (below) to
|
// Current upstream Gemma 4 chat templates differ by model size, so the checked-in
|
||||||
// gemma4_chat_template.jinja2 and run:
|
// reference intentionally uses the shared baseline without an empty generation-time
|
||||||
|
// thought channel until renderer selection is split by size.
|
||||||
|
//
|
||||||
|
// To regenerate expected values, save the E2B template to
|
||||||
|
// gemma4_e2b_chat_template.jinja2 and run:
|
||||||
//
|
//
|
||||||
// python3 -c "
|
// python3 -c "
|
||||||
// from jinja2 import Environment; import json
|
// from jinja2 import Environment; import json
|
||||||
// tmpl = Environment().from_string(open('gemma4_chat_template.jinja2').read())
|
// tmpl = Environment().from_string(open('gemma4_e2b_chat_template.jinja2').read())
|
||||||
// msgs = [{'role':'user','content':'Hello'}]
|
// msgs = [{'role':'user','content':'Hello'}]
|
||||||
// print(repr(tmpl.render(messages=msgs, bos_token='<bos>', add_generation_prompt=True)))
|
// print(repr(tmpl.render(messages=msgs, bos_token='<bos>', add_generation_prompt=True)))
|
||||||
// "
|
// "
|
||||||
@@ -26,8 +30,13 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
// The full Jinja2 template is committed as testdata/gemma4_chat_template.jinja2.
|
const (
|
||||||
// Run with VERIFY_JINJA2=1 to verify expected values against the template using uv + Python.
|
gemma4E2BTemplate = "testdata/gemma4_e2b_chat_template.jinja2"
|
||||||
|
gemma431BTemplate = "testdata/gemma4_31b_chat_template.jinja2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The upstream Gemma 4 chat templates are committed by size under testdata/.
|
||||||
|
// Run with VERIFY_JINJA2=1 to verify expected values against the E2B template using uv + Python.
|
||||||
|
|
||||||
func bashRefTool() []api.Tool {
|
func bashRefTool() []api.Tool {
|
||||||
return []api.Tool{{
|
return []api.Tool{{
|
||||||
@@ -665,7 +674,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "user_only",
|
name: "user_only",
|
||||||
messages: []api.Message{{Role: "user", Content: "Hello"}},
|
messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||||
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "system_user",
|
name: "system_user",
|
||||||
@@ -673,7 +682,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
{Role: "system", Content: "You are helpful."},
|
{Role: "system", Content: "You are helpful."},
|
||||||
{Role: "user", Content: "Hi"},
|
{Role: "user", Content: "Hi"},
|
||||||
},
|
},
|
||||||
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "developer_user",
|
name: "developer_user",
|
||||||
@@ -681,13 +690,13 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
{Role: "developer", Content: "You are helpful."},
|
{Role: "developer", Content: "You are helpful."},
|
||||||
{Role: "user", Content: "Hi"},
|
{Role: "user", Content: "Hi"},
|
||||||
},
|
},
|
||||||
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "tools_no_system",
|
name: "tools_no_system",
|
||||||
messages: []api.Message{{Role: "user", Content: "Hi"}},
|
messages: []api.Message{{Role: "user", Content: "Hi"}},
|
||||||
tools: bashRefTool(),
|
tools: bashRefTool(),
|
||||||
expected: "<bos><|turn>system\n" + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
expected: "<bos><|turn>system\n" + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "system_tools",
|
name: "system_tools",
|
||||||
@@ -696,7 +705,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
{Role: "user", Content: "Hi"},
|
{Role: "user", Content: "Hi"},
|
||||||
},
|
},
|
||||||
tools: bashRefTool(),
|
tools: bashRefTool(),
|
||||||
expected: "<bos><|turn>system\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
expected: "<bos><|turn>system\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "thinking_no_system",
|
name: "thinking_no_system",
|
||||||
@@ -704,13 +713,6 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
think: thinkTrue(),
|
think: thinkTrue(),
|
||||||
expected: "<bos><|turn>system\n<|think|>\n<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
expected: "<bos><|turn>system\n<|think|>\n<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "nothink_no_system",
|
|
||||||
messages: []api.Message{{Role: "user", Content: "Hi"}},
|
|
||||||
think: thinkFalse(),
|
|
||||||
expected: "<bos><|turn>user\nHi<turn|>\n<|turn>model\n",
|
|
||||||
skipJinja2: true,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "thinking_system",
|
name: "thinking_system",
|
||||||
messages: []api.Message{
|
messages: []api.Message{
|
||||||
@@ -737,6 +739,12 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
think: thinkTrue(),
|
think: thinkTrue(),
|
||||||
expected: "<bos><|turn>system\n<|think|>\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
expected: "<bos><|turn>system\n<|think|>\nYou are helpful." + bashDeclRef + "<turn|>\n<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "thinking_explicitly_disabled",
|
||||||
|
messages: []api.Message{{Role: "user", Content: "Hi"}},
|
||||||
|
think: thinkFalse(),
|
||||||
|
expected: "<bos><|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||||
|
},
|
||||||
|
|
||||||
// === Message loop paths ===
|
// === Message loop paths ===
|
||||||
{
|
{
|
||||||
@@ -751,7 +759,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
"<|turn>user\nHi<turn|>\n" +
|
"<|turn>user\nHi<turn|>\n" +
|
||||||
"<|turn>model\nHello!<turn|>\n" +
|
"<|turn>model\nHello!<turn|>\n" +
|
||||||
"<|turn>user\nMore<turn|>\n" +
|
"<|turn>user\nMore<turn|>\n" +
|
||||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Tool call with structured args → tool response as separate <|turn>tool turn
|
// Tool call with structured args → tool response as separate <|turn>tool turn
|
||||||
@@ -813,7 +821,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
"<|tool_response>response:bash{value:" + q + "file1.txt\nfile2.txt" + q + "}<tool_response|>" +
|
"<|tool_response>response:bash{value:" + q + "file1.txt\nfile2.txt" + q + "}<tool_response|>" +
|
||||||
"Here are the files.<turn|>\n" +
|
"Here are the files.<turn|>\n" +
|
||||||
"<|turn>user\nRead file1.txt<turn|>\n" +
|
"<|turn>user\nRead file1.txt<turn|>\n" +
|
||||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Multiple tool calls + multiple tool responses
|
// Multiple tool calls + multiple tool responses
|
||||||
@@ -848,7 +856,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
expected: "<bos><|turn>user\nWhat is 2+2?<turn|>\n" +
|
expected: "<bos><|turn>user\nWhat is 2+2?<turn|>\n" +
|
||||||
"<|turn>model\n4<turn|>\n" +
|
"<|turn>model\n4<turn|>\n" +
|
||||||
"<|turn>user\nAnd 3+3?<turn|>\n" +
|
"<|turn>user\nAnd 3+3?<turn|>\n" +
|
||||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>model\n",
|
||||||
},
|
},
|
||||||
// === Additional edge cases ported from original tests ===
|
// === Additional edge cases ported from original tests ===
|
||||||
{
|
{
|
||||||
@@ -906,17 +914,17 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
messages: []api.Message{{Role: "user", Content: "Test"}},
|
messages: []api.Message{{Role: "user", Content: "Test"}},
|
||||||
tools: modeTool(),
|
tools: modeTool(),
|
||||||
expected: "<bos><|turn>system\n" + modeDeclRef + "<turn|>\n" +
|
expected: "<bos><|turn>system\n" + modeDeclRef + "<turn|>\n" +
|
||||||
"<|turn>user\nTest<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nTest<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "unicode_content",
|
name: "unicode_content",
|
||||||
messages: []api.Message{{Role: "user", Content: "こんにちは"}},
|
messages: []api.Message{{Role: "user", Content: "こんにちは"}},
|
||||||
expected: "<bos><|turn>user\nこんにちは<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
expected: "<bos><|turn>user\nこんにちは<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "newlines_in_content",
|
name: "newlines_in_content",
|
||||||
messages: []api.Message{{Role: "user", Content: "Line 1\nLine 2\nLine 3"}},
|
messages: []api.Message{{Role: "user", Content: "Line 1\nLine 2\nLine 3"}},
|
||||||
expected: "<bos><|turn>user\nLine 1\nLine 2\nLine 3<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
expected: "<bos><|turn>user\nLine 1\nLine 2\nLine 3<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Tool response (raw JSON) followed by user message
|
// Tool response (raw JSON) followed by user message
|
||||||
@@ -935,7 +943,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
"<|turn>model\n<|tool_call>call:get_weather{city:" + q + "Tokyo" + q + "}<tool_call|>" +
|
"<|turn>model\n<|tool_call>call:get_weather{city:" + q + "Tokyo" + q + "}<tool_call|>" +
|
||||||
"<|tool_response>response:get_weather{value:" + q + `{"temperature": 15, "weather": "sunny"}` + q + "}<tool_response|>" +
|
"<|tool_response>response:get_weather{value:" + q + `{"temperature": 15, "weather": "sunny"}` + q + "}<tool_response|>" +
|
||||||
"<|turn>user\nThanks!<turn|>\n" +
|
"<|turn>user\nThanks!<turn|>\n" +
|
||||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>model\n",
|
||||||
},
|
},
|
||||||
// === Ordering and whitespace edge cases ===
|
// === Ordering and whitespace edge cases ===
|
||||||
{
|
{
|
||||||
@@ -958,7 +966,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
// User content with whitespace is trimmed
|
// User content with whitespace is trimmed
|
||||||
name: "user_content_trimmed",
|
name: "user_content_trimmed",
|
||||||
messages: []api.Message{{Role: "user", Content: " hello "}},
|
messages: []api.Message{{Role: "user", Content: " hello "}},
|
||||||
expected: "<bos><|turn>user\nhello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
expected: "<bos><|turn>user\nhello<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Empty tool call arguments
|
// Empty tool call arguments
|
||||||
@@ -982,7 +990,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
messages: []api.Message{{Role: "user", Content: "Create"}},
|
messages: []api.Message{{Role: "user", Content: "Create"}},
|
||||||
tools: nestedTool(),
|
tools: nestedTool(),
|
||||||
expected: "<bos><|turn>system\n" + nestedDeclRef + "<turn|>\n" +
|
expected: "<bos><|turn>system\n" + nestedDeclRef + "<turn|>\n" +
|
||||||
"<|turn>user\nCreate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nCreate<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Array type in tool declaration
|
// Array type in tool declaration
|
||||||
@@ -990,7 +998,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
messages: []api.Message{{Role: "user", Content: "Batch"}},
|
messages: []api.Message{{Role: "user", Content: "Batch"}},
|
||||||
tools: arrayTool(),
|
tools: arrayTool(),
|
||||||
expected: "<bos><|turn>system\n" + arrayDeclRef + "<turn|>\n" +
|
expected: "<bos><|turn>system\n" + arrayDeclRef + "<turn|>\n" +
|
||||||
"<|turn>user\nBatch<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nBatch<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Top-level typed union follows the template's odd stringified-list form.
|
// Top-level typed union follows the template's odd stringified-list form.
|
||||||
@@ -1002,8 +1010,7 @@ func TestGemma4RendererMatchesReference(t *testing.T) {
|
|||||||
<|turn>user
|
<|turn>user
|
||||||
Hi<turn|>
|
Hi<turn|>
|
||||||
<|turn>model
|
<|turn>model
|
||||||
<|channel>thought
|
`,
|
||||||
<channel|>`,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Assistant whitespace is trimmed (strip_thinking includes | trim)
|
// Assistant whitespace is trimmed (strip_thinking includes | trim)
|
||||||
@@ -1016,7 +1023,7 @@ Hi<turn|>
|
|||||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||||
"<|turn>model\nspaced<turn|>\n" +
|
"<|turn>model\nspaced<turn|>\n" +
|
||||||
"<|turn>user\nMore<turn|>\n" +
|
"<|turn>user\nMore<turn|>\n" +
|
||||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Three sequential tool responses
|
// Three sequential tool responses
|
||||||
@@ -1071,7 +1078,7 @@ Hi<turn|>
|
|||||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||||
"<|turn>model\nMiddleDone<turn|>\n" +
|
"<|turn>model\nMiddleDone<turn|>\n" +
|
||||||
"<|turn>user\nMore<turn|>\n" +
|
"<|turn>user\nMore<turn|>\n" +
|
||||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Property with no description — just type
|
// Property with no description — just type
|
||||||
@@ -1079,7 +1086,7 @@ Hi<turn|>
|
|||||||
messages: []api.Message{{Role: "user", Content: "Count"}},
|
messages: []api.Message{{Role: "user", Content: "Count"}},
|
||||||
tools: countTool(),
|
tools: countTool(),
|
||||||
expected: "<bos><|turn>system\n" + countDeclRef + "<turn|>\n" +
|
expected: "<bos><|turn>system\n" + countDeclRef + "<turn|>\n" +
|
||||||
"<|turn>user\nCount<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nCount<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// System message with leading/trailing whitespace is trimmed
|
// System message with leading/trailing whitespace is trimmed
|
||||||
@@ -1089,7 +1096,7 @@ Hi<turn|>
|
|||||||
{Role: "user", Content: "Hi"},
|
{Role: "user", Content: "Hi"},
|
||||||
},
|
},
|
||||||
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n" +
|
expected: "<bos><|turn>system\nYou are helpful.<turn|>\n" +
|
||||||
"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Deeply nested map in tool call arguments (3 levels)
|
// Deeply nested map in tool call arguments (3 levels)
|
||||||
@@ -1151,7 +1158,7 @@ Hi<turn|>
|
|||||||
messages: []api.Message{{Role: "user", Content: "Set"}},
|
messages: []api.Message{{Role: "user", Content: "Set"}},
|
||||||
tools: enumNoDescTool(),
|
tools: enumNoDescTool(),
|
||||||
expected: "<bos><|turn>system\n" + enumNoDescDeclRef + "<turn|>\n" +
|
expected: "<bos><|turn>system\n" + enumNoDescDeclRef + "<turn|>\n" +
|
||||||
"<|turn>user\nSet<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nSet<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// System message that is only whitespace (trims to empty)
|
// System message that is only whitespace (trims to empty)
|
||||||
@@ -1161,7 +1168,7 @@ Hi<turn|>
|
|||||||
{Role: "user", Content: "Hi"},
|
{Role: "user", Content: "Hi"},
|
||||||
},
|
},
|
||||||
expected: "<bos><|turn>system\n<turn|>\n" +
|
expected: "<bos><|turn>system\n<turn|>\n" +
|
||||||
"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Empty assistant content (empty string, not nil)
|
// Empty assistant content (empty string, not nil)
|
||||||
@@ -1174,7 +1181,7 @@ Hi<turn|>
|
|||||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||||
"<|turn>model\n<turn|>\n" +
|
"<|turn>model\n<turn|>\n" +
|
||||||
"<|turn>user\nMore<turn|>\n" +
|
"<|turn>user\nMore<turn|>\n" +
|
||||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Map argument with string keys (keys NOT escaped with <|"|>)
|
// Map argument with string keys (keys NOT escaped with <|"|>)
|
||||||
@@ -1200,7 +1207,7 @@ Hi<turn|>
|
|||||||
messages: []api.Message{{Role: "user", Content: "Search"}},
|
messages: []api.Message{{Role: "user", Content: "Search"}},
|
||||||
tools: searchTool(),
|
tools: searchTool(),
|
||||||
expected: "<bos><|turn>system\n" + searchDeclRef + "<turn|>\n" +
|
expected: "<bos><|turn>system\n" + searchDeclRef + "<turn|>\n" +
|
||||||
"<|turn>user\nSearch<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nSearch<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
|
|
||||||
// === Round 3 coverage gaps ===
|
// === Round 3 coverage gaps ===
|
||||||
@@ -1228,7 +1235,7 @@ Hi<turn|>
|
|||||||
{Role: "user", Content: "Hi"},
|
{Role: "user", Content: "Hi"},
|
||||||
},
|
},
|
||||||
expected: "<bos><|turn>system\n<turn|>\n" +
|
expected: "<bos><|turn>system\n<turn|>\n" +
|
||||||
"<|turn>user\nHi<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nHi<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Nested OBJECT property with required field
|
// Nested OBJECT property with required field
|
||||||
@@ -1236,7 +1243,7 @@ Hi<turn|>
|
|||||||
messages: []api.Message{{Role: "user", Content: "Create"}},
|
messages: []api.Message{{Role: "user", Content: "Create"}},
|
||||||
tools: nestedRequiredTool(),
|
tools: nestedRequiredTool(),
|
||||||
expected: "<bos><|turn>system\n" + nestedRequiredDeclRef + "<turn|>\n" +
|
expected: "<bos><|turn>system\n" + nestedRequiredDeclRef + "<turn|>\n" +
|
||||||
"<|turn>user\nCreate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nCreate<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Non-integer float in tool call argument
|
// Non-integer float in tool call argument
|
||||||
@@ -1263,7 +1270,7 @@ Hi<turn|>
|
|||||||
},
|
},
|
||||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||||
"<|turn>model\nResult<turn|>\n" +
|
"<|turn>model\nResult<turn|>\n" +
|
||||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Tool content with newlines and leading/trailing whitespace trimmed
|
// Tool content with newlines and leading/trailing whitespace trimmed
|
||||||
@@ -1287,7 +1294,7 @@ Hi<turn|>
|
|||||||
messages: []api.Message{{Role: "user", Content: "Raw"}},
|
messages: []api.Message{{Role: "user", Content: "Raw"}},
|
||||||
tools: rawTool(),
|
tools: rawTool(),
|
||||||
expected: "<bos><|turn>system\n" + rawDeclRef + "<turn|>\n" +
|
expected: "<bos><|turn>system\n" + rawDeclRef + "<turn|>\n" +
|
||||||
"<|turn>user\nRaw<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nRaw<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Multiple required fields at top level
|
// Multiple required fields at top level
|
||||||
@@ -1295,7 +1302,7 @@ Hi<turn|>
|
|||||||
messages: []api.Message{{Role: "user", Content: "Move"}},
|
messages: []api.Message{{Role: "user", Content: "Move"}},
|
||||||
tools: moveTool(),
|
tools: moveTool(),
|
||||||
expected: "<bos><|turn>system\n" + moveDeclRef + "<turn|>\n" +
|
expected: "<bos><|turn>system\n" + moveDeclRef + "<turn|>\n" +
|
||||||
"<|turn>user\nMove<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nMove<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Assistant content that is ONLY thinking (strips to empty)
|
// Assistant content that is ONLY thinking (strips to empty)
|
||||||
@@ -1308,7 +1315,7 @@ Hi<turn|>
|
|||||||
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
expected: "<bos><|turn>user\nHi<turn|>\n" +
|
||||||
"<|turn>model\n<turn|>\n" +
|
"<|turn>model\n<turn|>\n" +
|
||||||
"<|turn>user\nMore<turn|>\n" +
|
"<|turn>user\nMore<turn|>\n" +
|
||||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>model\n",
|
||||||
},
|
},
|
||||||
|
|
||||||
// === Round 4: final coverage gaps ===
|
// === Round 4: final coverage gaps ===
|
||||||
@@ -1341,7 +1348,7 @@ Hi<turn|>
|
|||||||
messages: []api.Message{{Role: "user", Content: "Tag"}},
|
messages: []api.Message{{Role: "user", Content: "Tag"}},
|
||||||
tools: arrayNoItemsTool(),
|
tools: arrayNoItemsTool(),
|
||||||
expected: "<bos><|turn>system\n" + arrayNoItemsDeclRef + "<turn|>\n" +
|
expected: "<bos><|turn>system\n" + arrayNoItemsDeclRef + "<turn|>\n" +
|
||||||
"<|turn>user\nTag<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nTag<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// OBJECT property without description but with nested properties
|
// OBJECT property without description but with nested properties
|
||||||
@@ -1349,7 +1356,7 @@ Hi<turn|>
|
|||||||
messages: []api.Message{{Role: "user", Content: "Update"}},
|
messages: []api.Message{{Role: "user", Content: "Update"}},
|
||||||
tools: objectNoDescTool(),
|
tools: objectNoDescTool(),
|
||||||
expected: "<bos><|turn>system\n" + objectNoDescDeclRef + "<turn|>\n" +
|
expected: "<bos><|turn>system\n" + objectNoDescDeclRef + "<turn|>\n" +
|
||||||
"<|turn>user\nUpdate<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>user\nUpdate<turn|>\n<|turn>model\n",
|
||||||
},
|
},
|
||||||
|
|
||||||
// === Round 5: coding agent patterns ===
|
// === Round 5: coding agent patterns ===
|
||||||
@@ -1379,7 +1386,7 @@ Hi<turn|>
|
|||||||
"<|tool_response>response:bash{value:" + q + q + "}<tool_response|>" +
|
"<|tool_response>response:bash{value:" + q + q + "}<tool_response|>" +
|
||||||
"Done.<turn|>\n" +
|
"Done.<turn|>\n" +
|
||||||
"<|turn>user\nThanks<turn|>\n" +
|
"<|turn>user\nThanks<turn|>\n" +
|
||||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Tool call with thinking that strips to real remaining content
|
// Tool call with thinking that strips to real remaining content
|
||||||
@@ -1399,7 +1406,7 @@ Hi<turn|>
|
|||||||
"<|tool_response>response:bash{value:" + q + "main.go\ngo.mod" + q + "}<tool_response|>" +
|
"<|tool_response>response:bash{value:" + q + "main.go\ngo.mod" + q + "}<tool_response|>" +
|
||||||
"Let me list the files.<turn|>\n" +
|
"Let me list the files.<turn|>\n" +
|
||||||
"<|turn>user\nOK<turn|>\n" +
|
"<|turn>user\nOK<turn|>\n" +
|
||||||
"<|turn>model\n<|channel>thought\n<channel|>",
|
"<|turn>model\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Argument value containing newlines (multi-line script)
|
// Argument value containing newlines (multi-line script)
|
||||||
@@ -1635,7 +1642,6 @@ func TestGemma4RendererKnownJinja2Differences(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
messages []api.Message
|
messages []api.Message
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
think *api.ThinkValue
|
|
||||||
wantJinjaFrag string
|
wantJinjaFrag string
|
||||||
wantRenderFrag string
|
wantRenderFrag string
|
||||||
}{
|
}{
|
||||||
@@ -1684,22 +1690,15 @@ func TestGemma4RendererKnownJinja2Differences(t *testing.T) {
|
|||||||
wantJinjaFrag: `response:read{value:<|"|>payload<|"|>}`,
|
wantJinjaFrag: `response:read{value:<|"|>payload<|"|>}`,
|
||||||
wantRenderFrag: `response:unknown{value:<|"|>payload<|"|>}`,
|
wantRenderFrag: `response:unknown{value:<|"|>payload<|"|>}`,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "explicit_nothink_skips_empty_thought_channel",
|
|
||||||
messages: []api.Message{{Role: "user", Content: "Hi"}},
|
|
||||||
think: thinkFalse(),
|
|
||||||
wantJinjaFrag: "<|turn>model\n<|channel>thought\n<channel|>",
|
|
||||||
wantRenderFrag: "<|turn>model\n",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
renderer := &Gemma4Renderer{useImgTags: RenderImgTags}
|
renderer := &Gemma4Renderer{useImgTags: RenderImgTags}
|
||||||
got, err := renderer.Render(tt.messages, tt.tools, tt.think)
|
got, err := renderer.Render(tt.messages, tt.tools, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, tt.think)
|
jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, nil)
|
||||||
assert.NotEqual(t, jinja2Output, got, "case no longer differs from Jinja2 output")
|
assert.NotEqual(t, jinja2Output, got, "case no longer differs from Jinja2 output")
|
||||||
assert.Contains(t, jinja2Output, tt.wantJinjaFrag)
|
assert.Contains(t, jinja2Output, tt.wantJinjaFrag)
|
||||||
assert.Contains(t, got, tt.wantRenderFrag)
|
assert.Contains(t, got, tt.wantRenderFrag)
|
||||||
@@ -1735,12 +1734,35 @@ func TestGemma4RendererToolResponseWithoutNameOrIDUsesUnknown(t *testing.T) {
|
|||||||
assert.NotContains(t, got, `response:read{value:<|"|>payload<|"|>}`)
|
assert.NotContains(t, got, `response:read{value:<|"|>payload<|"|>}`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGemma4SizeTemplateFixturesDifferAtGenerationPrompt(t *testing.T) {
|
||||||
|
e2b, err := os.ReadFile(gemma4E2BTemplate)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read %s: %v", gemma4E2BTemplate, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
thirtyOneB, err := os.ReadFile(gemma431BTemplate)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read %s: %v", gemma431BTemplate, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Contains(t, string(e2b), "{{- '<|turn>model\\n' -}}")
|
||||||
|
assert.NotContains(t, string(e2b), "{{- '<|channel>thought\\n<channel|>' -}}")
|
||||||
|
assert.Contains(t, string(thirtyOneB), "{{- '<|turn>model\\n' -}}")
|
||||||
|
assert.Contains(t, string(thirtyOneB), "{{- '<|channel>thought\\n<channel|>' -}}")
|
||||||
|
}
|
||||||
|
|
||||||
// renderWithJinja2 shells out to uv + Python to render messages through the
|
// renderWithJinja2 shells out to uv + Python to render messages through the
|
||||||
// Jinja2 chat template. Returns the rendered string.
|
// E2B Jinja2 chat template. Returns the rendered string.
|
||||||
func renderWithJinja2(t *testing.T, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string {
|
func renderWithJinja2(t *testing.T, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string {
|
||||||
|
return renderWithJinja2Template(t, gemma4E2BTemplate, messages, tools, think)
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderWithJinja2Template shells out to uv + Python to render messages through
|
||||||
|
// the named Jinja2 chat template. Returns the rendered string.
|
||||||
|
func renderWithJinja2Template(t *testing.T, templateRelPath string, messages []api.Message, tools []api.Tool, think *api.ThinkValue) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
templatePath, err := filepath.Abs("testdata/gemma4_chat_template.jinja2")
|
templatePath, err := filepath.Abs(templateRelPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to get template path: %v", err)
|
t.Fatalf("failed to get template path: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
344
model/renderers/testdata/gemma4_e2b_chat_template.jinja2
vendored
Normal file
344
model/renderers/testdata/gemma4_e2b_chat_template.jinja2
vendored
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
{%- macro format_parameters(properties, required) -%}
|
||||||
|
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
|
||||||
|
{%- set ns = namespace(found_first=false) -%}
|
||||||
|
{%- for key, value in properties | dictsort -%}
|
||||||
|
{%- set add_comma = false -%}
|
||||||
|
{%- if key not in standard_keys -%}
|
||||||
|
{%- if ns.found_first %},{% endif -%}
|
||||||
|
{%- set ns.found_first = true -%}
|
||||||
|
{{ key }}:{
|
||||||
|
{%- if value['description'] -%}
|
||||||
|
description:<|"|>{{ value['description'] }}<|"|>
|
||||||
|
{%- set add_comma = true -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if value['type'] | upper == 'STRING' -%}
|
||||||
|
{%- if value['enum'] -%}
|
||||||
|
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||||
|
enum:{{ format_argument(value['enum']) }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- elif value['type'] | upper == 'ARRAY' -%}
|
||||||
|
{%- if value['items'] is mapping and value['items'] -%}
|
||||||
|
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||||
|
items:{
|
||||||
|
{%- set ns_items = namespace(found_first=false) -%}
|
||||||
|
{%- for item_key, item_value in value['items'] | dictsort -%}
|
||||||
|
{%- if item_value is not none -%}
|
||||||
|
{%- if ns_items.found_first %},{% endif -%}
|
||||||
|
{%- set ns_items.found_first = true -%}
|
||||||
|
{%- if item_key == 'properties' -%}
|
||||||
|
properties:{
|
||||||
|
{%- if item_value is mapping -%}
|
||||||
|
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
|
||||||
|
{%- endif -%}
|
||||||
|
}
|
||||||
|
{%- elif item_key == 'required' -%}
|
||||||
|
required:[
|
||||||
|
{%- for req_item in item_value -%}
|
||||||
|
<|"|>{{- req_item -}}<|"|>
|
||||||
|
{%- if not loop.last %},{% endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
]
|
||||||
|
{%- elif item_key == 'type' -%}
|
||||||
|
{%- if item_value is string -%}
|
||||||
|
type:{{ format_argument(item_value | upper) }}
|
||||||
|
{%- else -%}
|
||||||
|
type:{{ format_argument(item_value | map('upper') | list) }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- else -%}
|
||||||
|
{{ item_key }}:{{ format_argument(item_value) }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if value['nullable'] %}
|
||||||
|
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||||
|
nullable:true
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if value['type'] | upper == 'OBJECT' -%}
|
||||||
|
{%- if value['properties'] is defined and value['properties'] is mapping -%}
|
||||||
|
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||||
|
properties:{
|
||||||
|
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
|
||||||
|
}
|
||||||
|
{%- elif value is mapping -%}
|
||||||
|
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||||
|
properties:{
|
||||||
|
{{- format_parameters(value, value['required'] | default([])) -}}
|
||||||
|
}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if value['required'] -%}
|
||||||
|
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||||
|
required:[
|
||||||
|
{%- for item in value['required'] | default([]) -%}
|
||||||
|
<|"|>{{- item -}}<|"|>
|
||||||
|
{%- if not loop.last %},{% endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
]
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
|
||||||
|
type:<|"|>{{ value['type'] | upper }}<|"|>}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- endmacro -%}
|
||||||
|
{%- macro format_function_declaration(tool_data) -%}
|
||||||
|
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
|
||||||
|
{%- set params = tool_data['function']['parameters'] -%}
|
||||||
|
{%- if params -%}
|
||||||
|
,parameters:{
|
||||||
|
{%- if params['properties'] -%}
|
||||||
|
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if params['required'] -%}
|
||||||
|
required:[
|
||||||
|
{%- for item in params['required'] -%}
|
||||||
|
<|"|>{{- item -}}<|"|>
|
||||||
|
{{- ',' if not loop.last -}}
|
||||||
|
{%- endfor -%}
|
||||||
|
],
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if params['type'] -%}
|
||||||
|
type:<|"|>{{- params['type'] | upper -}}<|"|>}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if 'response' in tool_data['function'] -%}
|
||||||
|
{%- set response_declaration = tool_data['function']['response'] -%}
|
||||||
|
,response:{
|
||||||
|
{%- if response_declaration['description'] -%}
|
||||||
|
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
|
||||||
|
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
|
}
|
||||||
|
{%- endmacro -%}
|
||||||
|
{%- macro format_argument(argument, escape_keys=True) -%}
|
||||||
|
{%- if argument is string -%}
|
||||||
|
{{- '<|"|>' + argument + '<|"|>' -}}
|
||||||
|
{%- elif argument is boolean -%}
|
||||||
|
{{- 'true' if argument else 'false' -}}
|
||||||
|
{%- elif argument is mapping -%}
|
||||||
|
{{- '{' -}}
|
||||||
|
{%- set ns = namespace(found_first=false) -%}
|
||||||
|
{%- for key, value in argument | dictsort -%}
|
||||||
|
{%- if ns.found_first %},{% endif -%}
|
||||||
|
{%- set ns.found_first = true -%}
|
||||||
|
{%- if escape_keys -%}
|
||||||
|
{{- '<|"|>' + key + '<|"|>' -}}
|
||||||
|
{%- else -%}
|
||||||
|
{{- key -}}
|
||||||
|
{%- endif -%}
|
||||||
|
:{{- format_argument(value, escape_keys=escape_keys) -}}
|
||||||
|
{%- endfor -%}
|
||||||
|
{{- '}' -}}
|
||||||
|
{%- elif argument is sequence -%}
|
||||||
|
{{- '[' -}}
|
||||||
|
{%- for item in argument -%}
|
||||||
|
{{- format_argument(item, escape_keys=escape_keys) -}}
|
||||||
|
{%- if not loop.last %},{% endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{{- ']' -}}
|
||||||
|
{%- else -%}
|
||||||
|
{{- argument -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endmacro -%}
|
||||||
|
{%- macro strip_thinking(text) -%}
|
||||||
|
{%- set ns = namespace(result='') -%}
|
||||||
|
{%- for part in text.split('<channel|>') -%}
|
||||||
|
{%- if '<|channel>' in part -%}
|
||||||
|
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set ns.result = ns.result + part -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{{- ns.result | trim -}}
|
||||||
|
{%- endmacro -%}
|
||||||
|
|
||||||
|
{%- macro format_tool_response_block(tool_name, response) -%}
|
||||||
|
{{- '<|tool_response>' -}}
|
||||||
|
{%- if response is mapping -%}
|
||||||
|
{{- 'response:' + tool_name + '{' -}}
|
||||||
|
{%- for key, value in response | dictsort -%}
|
||||||
|
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||||
|
{%- if not loop.last %},{% endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{{- '}' -}}
|
||||||
|
{%- else -%}
|
||||||
|
{{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{{- '<tool_response|>' -}}
|
||||||
|
{%- endmacro -%}
|
||||||
|
|
||||||
|
{%- set ns = namespace(prev_message_type=None) -%}
|
||||||
|
{%- set loop_messages = messages -%}
|
||||||
|
{{- bos_token -}}
|
||||||
|
{#- Handle System/Tool Definitions Block -#}
|
||||||
|
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
|
||||||
|
{{- '<|turn>system\n' -}}
|
||||||
|
|
||||||
|
{#- Inject Thinking token at the very top of the FIRST system turn -#}
|
||||||
|
{%- if enable_thinking is defined and enable_thinking -%}
|
||||||
|
{{- '<|think|>\n' -}}
|
||||||
|
{%- set ns.prev_message_type = 'think' -%}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{%- if messages[0]['role'] in ['system', 'developer'] -%}
|
||||||
|
{{- messages[0]['content'] | trim -}}
|
||||||
|
{%- set loop_messages = messages[1:] -%}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{%- if tools -%}
|
||||||
|
{%- for tool in tools %}
|
||||||
|
{{- '<|tool>' -}}
|
||||||
|
{{- format_function_declaration(tool) | trim -}}
|
||||||
|
{{- '<tool|>' -}}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- set ns.prev_message_type = 'tool' -%}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{{- '<turn|>\n' -}}
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
{#- Pre-scan: find last user message index for reasoning guard -#}
|
||||||
|
{%- set ns_turn = namespace(last_user_idx=-1) -%}
|
||||||
|
{%- for i in range(loop_messages | length) -%}
|
||||||
|
{%- if loop_messages[i]['role'] == 'user' -%}
|
||||||
|
{%- set ns_turn.last_user_idx = i -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
|
||||||
|
{#- Loop through messages -#}
|
||||||
|
{%- for message in loop_messages -%}
|
||||||
|
{%- if message['role'] != 'tool' -%}
|
||||||
|
{%- set ns.prev_message_type = None -%}
|
||||||
|
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
|
||||||
|
{#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#}
|
||||||
|
{%- set prev_nt = namespace(role=None, found=false) -%}
|
||||||
|
{%- if loop.index0 > 0 -%}
|
||||||
|
{%- for j in range(loop.index0 - 1, -1, -1) -%}
|
||||||
|
{%- if not prev_nt.found -%}
|
||||||
|
{%- if loop_messages[j]['role'] != 'tool' -%}
|
||||||
|
{%- set prev_nt.role = loop_messages[j]['role'] -%}
|
||||||
|
{%- set prev_nt.found = true -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%}
|
||||||
|
{%- if not continue_same_model_turn -%}
|
||||||
|
{{- '<|turn>' + role + '\n' }}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{#- Render reasoning/reasoning_content as thinking channel -#}
|
||||||
|
{%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%}
|
||||||
|
{%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%}
|
||||||
|
{{- '<|channel>thought\n' + thinking_text + '\n<channel|>' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{%- if message['tool_calls'] -%}
|
||||||
|
{%- for tool_call in message['tool_calls'] -%}
|
||||||
|
{%- set function = tool_call['function'] -%}
|
||||||
|
{{- '<|tool_call>call:' + function['name'] + '{' -}}
|
||||||
|
{%- if function['arguments'] is mapping -%}
|
||||||
|
{%- set ns_args = namespace(found_first=false) -%}
|
||||||
|
{%- for key, value in function['arguments'] | dictsort -%}
|
||||||
|
{%- if ns_args.found_first %},{% endif -%}
|
||||||
|
{%- set ns_args.found_first = true -%}
|
||||||
|
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- elif function['arguments'] is string -%}
|
||||||
|
{{- function['arguments'] -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{{- '}<tool_call|>' -}}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- set ns.prev_message_type = 'tool_call' -%}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{%- set ns_tr_out = namespace(flag=false) -%}
|
||||||
|
{%- if message.get('tool_responses') -%}
|
||||||
|
{#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#}
|
||||||
|
{%- for tool_response in message['tool_responses'] -%}
|
||||||
|
{{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}}
|
||||||
|
{%- set ns_tr_out.flag = true -%}
|
||||||
|
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- elif message.get('tool_calls') -%}
|
||||||
|
{#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#}
|
||||||
|
{%- set ns_tool_scan = namespace(stopped=false) -%}
|
||||||
|
{%- for k in range(loop.index0 + 1, loop_messages | length) -%}
|
||||||
|
{%- if ns_tool_scan.stopped -%}
|
||||||
|
{%- elif loop_messages[k]['role'] != 'tool' -%}
|
||||||
|
{%- set ns_tool_scan.stopped = true -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set follow = loop_messages[k] -%}
|
||||||
|
{#- Resolve tool_call_id to function name -#}
|
||||||
|
{%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%}
|
||||||
|
{%- for tc in message['tool_calls'] -%}
|
||||||
|
{%- if tc.get('id') == follow.get('tool_call_id') -%}
|
||||||
|
{%- set ns_tname.name = tc['function']['name'] -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{#- Handle content as string or content-parts array -#}
|
||||||
|
{%- set tool_body = follow.get('content') -%}
|
||||||
|
{%- if tool_body is string -%}
|
||||||
|
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||||
|
{%- elif tool_body is sequence and tool_body is not string -%}
|
||||||
|
{%- set ns_txt = namespace(s='') -%}
|
||||||
|
{%- for part in tool_body -%}
|
||||||
|
{%- if part.get('type') == 'text' -%}
|
||||||
|
{%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{{- format_tool_response_block(ns_tname.name, ns_txt.s) -}}
|
||||||
|
{%- else -%}
|
||||||
|
{{- format_tool_response_block(ns_tname.name, tool_body) -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- set ns_tr_out.flag = true -%}
|
||||||
|
{%- set ns.prev_message_type = 'tool_response' -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{%- if message['content'] is string -%}
|
||||||
|
{%- if role == 'model' -%}
|
||||||
|
{{- strip_thinking(message['content']) -}}
|
||||||
|
{%- else -%}
|
||||||
|
{{- message['content'] | trim -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- elif message['content'] is sequence -%}
|
||||||
|
{%- for item in message['content'] -%}
|
||||||
|
{%- if item['type'] == 'text' -%}
|
||||||
|
{%- if role == 'model' -%}
|
||||||
|
{{- strip_thinking(item['text']) -}}
|
||||||
|
{%- else -%}
|
||||||
|
{{- item['text'] | trim -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- elif item['type'] == 'image' -%}
|
||||||
|
{{- '<|image|>' -}}
|
||||||
|
{%- set ns.prev_message_type = 'image' -%}
|
||||||
|
{%- elif item['type'] == 'audio' -%}
|
||||||
|
{{- '<|audio|>' -}}
|
||||||
|
{%- set ns.prev_message_type = 'audio' -%}
|
||||||
|
{%- elif item['type'] == 'video' -%}
|
||||||
|
{{- '<|video|>' -}}
|
||||||
|
{%- set ns.prev_message_type = 'video' -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- endif -%}
|
||||||
|
|
||||||
|
{%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%}
|
||||||
|
{{- '<|tool_response>' -}}
|
||||||
|
{%- elif not (ns_tr_out.flag and not message.get('content')) -%}
|
||||||
|
{{- '<turn|>\n' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
|
||||||
|
{%- if add_generation_prompt -%}
|
||||||
|
{%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%}
|
||||||
|
{{- '<|turn>model\n' -}}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endif -%}
|
||||||
@@ -191,6 +191,10 @@ func inferSafetensorsCapabilities(modelDir string) []string {
|
|||||||
capabilities = append(capabilities, "vision")
|
capabilities = append(capabilities, "vision")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if supportsAudio(modelDir) {
|
||||||
|
capabilities = append(capabilities, "audio")
|
||||||
|
}
|
||||||
|
|
||||||
if supportsThinking(modelDir) {
|
if supportsThinking(modelDir) {
|
||||||
capabilities = append(capabilities, "thinking")
|
capabilities = append(capabilities, "thinking")
|
||||||
}
|
}
|
||||||
@@ -496,32 +500,38 @@ func supportsThinking(modelDir string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// supportsVision checks if the model supports image input based on its architecture.
|
// supportsVision checks if the model has a vision encoder by looking for
|
||||||
// Qwen3.5 multimodal checkpoints are published as ConditionalGeneration architectures.
|
// vision_config in config.json.
|
||||||
func supportsVision(modelDir string) bool {
|
func supportsVision(modelDir string) bool {
|
||||||
configPath := filepath.Join(modelDir, "config.json")
|
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
|
||||||
data, err := os.ReadFile(configPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
var cfg struct {
|
var cfg struct {
|
||||||
Architectures []string `json:"architectures"`
|
VisionConfig *map[string]any `json:"vision_config"`
|
||||||
ModelType string `json:"model_type"`
|
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, arch := range cfg.Architectures {
|
return cfg.VisionConfig != nil
|
||||||
archLower := strings.ToLower(arch)
|
}
|
||||||
if strings.Contains(archLower, "qwen3") && strings.Contains(archLower, "conditionalgeneration") {
|
|
||||||
return true
|
func supportsAudio(modelDir string) bool {
|
||||||
}
|
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
typeLower := strings.ToLower(cfg.ModelType)
|
var cfg struct {
|
||||||
return strings.Contains(typeLower, "qwen3") && strings.Contains(typeLower, "conditionalgeneration")
|
AudioConfig *map[string]any `json:"audio_config"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg.AudioConfig != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getParserName returns the parser name for a model based on its architecture.
|
// getParserName returns the parser name for a model based on its architecture.
|
||||||
@@ -550,6 +560,9 @@ func getParserName(modelDir string) string {
|
|||||||
if strings.Contains(archLower, "deepseek") {
|
if strings.Contains(archLower, "deepseek") {
|
||||||
return "deepseek3"
|
return "deepseek3"
|
||||||
}
|
}
|
||||||
|
if strings.Contains(archLower, "gemma4") {
|
||||||
|
return "gemma4"
|
||||||
|
}
|
||||||
if strings.Contains(archLower, "qwen3") {
|
if strings.Contains(archLower, "qwen3") {
|
||||||
return "qwen3"
|
return "qwen3"
|
||||||
}
|
}
|
||||||
@@ -564,6 +577,9 @@ func getParserName(modelDir string) string {
|
|||||||
if strings.Contains(typeLower, "deepseek") {
|
if strings.Contains(typeLower, "deepseek") {
|
||||||
return "deepseek3"
|
return "deepseek3"
|
||||||
}
|
}
|
||||||
|
if strings.Contains(typeLower, "gemma4") {
|
||||||
|
return "gemma4"
|
||||||
|
}
|
||||||
if strings.Contains(typeLower, "qwen3") {
|
if strings.Contains(typeLower, "qwen3") {
|
||||||
return "qwen3"
|
return "qwen3"
|
||||||
}
|
}
|
||||||
@@ -592,6 +608,9 @@ func getRendererName(modelDir string) string {
|
|||||||
// Check architectures for known renderers
|
// Check architectures for known renderers
|
||||||
for _, arch := range cfg.Architectures {
|
for _, arch := range cfg.Architectures {
|
||||||
archLower := strings.ToLower(arch)
|
archLower := strings.ToLower(arch)
|
||||||
|
if strings.Contains(archLower, "gemma4") {
|
||||||
|
return "gemma4"
|
||||||
|
}
|
||||||
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
|
||||||
return "glm-4.7"
|
return "glm-4.7"
|
||||||
}
|
}
|
||||||
@@ -606,6 +625,9 @@ func getRendererName(modelDir string) string {
|
|||||||
// Also check model_type
|
// Also check model_type
|
||||||
if cfg.ModelType != "" {
|
if cfg.ModelType != "" {
|
||||||
typeLower := strings.ToLower(cfg.ModelType)
|
typeLower := strings.ToLower(cfg.ModelType)
|
||||||
|
if strings.Contains(typeLower, "gemma4") {
|
||||||
|
return "gemma4"
|
||||||
|
}
|
||||||
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
|
||||||
return "glm-4.7"
|
return "glm-4.7"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -311,10 +311,30 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
|
|||||||
name: "qwen3.5 multimodal model",
|
name: "qwen3.5 multimodal model",
|
||||||
configJSON: `{
|
configJSON: `{
|
||||||
"architectures": ["Qwen3_5ForConditionalGeneration"],
|
"architectures": ["Qwen3_5ForConditionalGeneration"],
|
||||||
"model_type": "qwen3"
|
"model_type": "qwen3",
|
||||||
|
"vision_config": {"hidden_size": 1024}
|
||||||
}`,
|
}`,
|
||||||
want: []string{"completion", "vision", "thinking"},
|
want: []string{"completion", "vision", "thinking"},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "model with audio config",
|
||||||
|
configJSON: `{
|
||||||
|
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||||
|
"model_type": "gemma4",
|
||||||
|
"vision_config": {"hidden_size": 1024},
|
||||||
|
"audio_config": {"num_mel_bins": 128}
|
||||||
|
}`,
|
||||||
|
want: []string{"completion", "vision", "audio"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with audio but no vision",
|
||||||
|
configJSON: `{
|
||||||
|
"architectures": ["SomeAudioModel"],
|
||||||
|
"model_type": "other",
|
||||||
|
"audio_config": {"num_mel_bins": 128}
|
||||||
|
}`,
|
||||||
|
want: []string{"completion", "audio"},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "non-qwen conditional generation model",
|
name: "non-qwen conditional generation model",
|
||||||
configJSON: `{
|
configJSON: `{
|
||||||
@@ -339,6 +359,74 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParsePerExpertInputs(t *testing.T) {
|
||||||
|
makeInput := func(name, quantize string) create.PackedTensorInput {
|
||||||
|
return create.PackedTensorInput{Name: name, Quantize: quantize}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("uniform quant across projections", func(t *testing.T) {
|
||||||
|
inputs := []create.PackedTensorInput{
|
||||||
|
makeInput("layer.moe.experts.0.gate_proj.weight", "int4"),
|
||||||
|
makeInput("layer.moe.experts.1.gate_proj.weight", "int4"),
|
||||||
|
makeInput("layer.moe.experts.0.down_proj.weight", "int4"),
|
||||||
|
makeInput("layer.moe.experts.1.down_proj.weight", "int4"),
|
||||||
|
}
|
||||||
|
groups, projQ := parsePerExpertInputs("layer.moe.experts", inputs)
|
||||||
|
if groups == nil {
|
||||||
|
t.Fatal("expected non-nil groups")
|
||||||
|
}
|
||||||
|
if len(groups) != 2 {
|
||||||
|
t.Fatalf("expected 2 projection groups, got %d", len(groups))
|
||||||
|
}
|
||||||
|
if projQ["gate_proj.weight"] != "int4" {
|
||||||
|
t.Errorf("gate_proj quant = %q, want int4", projQ["gate_proj.weight"])
|
||||||
|
}
|
||||||
|
if projQ["down_proj.weight"] != "int4" {
|
||||||
|
t.Errorf("down_proj quant = %q, want int4", projQ["down_proj.weight"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("mixed quant across projections", func(t *testing.T) {
|
||||||
|
inputs := []create.PackedTensorInput{
|
||||||
|
makeInput("layer.moe.experts.0.gate_proj.weight", "int4"),
|
||||||
|
makeInput("layer.moe.experts.1.gate_proj.weight", "int4"),
|
||||||
|
makeInput("layer.moe.experts.0.down_proj.weight", "int8"),
|
||||||
|
makeInput("layer.moe.experts.1.down_proj.weight", "int8"),
|
||||||
|
}
|
||||||
|
groups, projQ := parsePerExpertInputs("layer.moe.experts", inputs)
|
||||||
|
if groups == nil {
|
||||||
|
t.Fatal("expected non-nil groups for mixed cross-projection quant")
|
||||||
|
}
|
||||||
|
if projQ["gate_proj.weight"] != "int4" {
|
||||||
|
t.Errorf("gate_proj quant = %q, want int4", projQ["gate_proj.weight"])
|
||||||
|
}
|
||||||
|
if projQ["down_proj.weight"] != "int8" {
|
||||||
|
t.Errorf("down_proj quant = %q, want int8", projQ["down_proj.weight"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("mixed quant within same projection rejected", func(t *testing.T) {
|
||||||
|
inputs := []create.PackedTensorInput{
|
||||||
|
makeInput("layer.moe.experts.0.down_proj.weight", "int4"),
|
||||||
|
makeInput("layer.moe.experts.1.down_proj.weight", "int8"),
|
||||||
|
}
|
||||||
|
groups, _ := parsePerExpertInputs("layer.moe.experts", inputs)
|
||||||
|
if groups != nil {
|
||||||
|
t.Fatal("expected nil for mixed quant within same projection")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-experts group rejected", func(t *testing.T) {
|
||||||
|
inputs := []create.PackedTensorInput{
|
||||||
|
makeInput("layer.mlp.gate_proj.weight", "int4"),
|
||||||
|
}
|
||||||
|
groups, _ := parsePerExpertInputs("layer.mlp", inputs)
|
||||||
|
if groups != nil {
|
||||||
|
t.Fatal("expected nil for non-experts group")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestQuantizeSupported(t *testing.T) {
|
func TestQuantizeSupported(t *testing.T) {
|
||||||
// This just verifies the function exists and returns a boolean
|
// This just verifies the function exists and returns a boolean
|
||||||
// The actual value depends on build tags (mlx vs non-mlx)
|
// The actual value depends on build tags (mlx vs non-mlx)
|
||||||
|
|||||||
@@ -97,6 +97,20 @@ func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]
|
|||||||
groupSize, bits, mode := model.QuantizationParams(quantize)
|
groupSize, bits, mode := model.QuantizationParams(quantize)
|
||||||
qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
|
qweight, scales, qbiases := mlx.Quantize(arr, groupSize, bits, mode)
|
||||||
|
|
||||||
|
// Validate quantization produced non-empty output. MLX quantize may return
|
||||||
|
// empty arrays for unsupported mode/bits combinations without raising an error.
|
||||||
|
mlx.Eval(qweight, scales)
|
||||||
|
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
|
||||||
|
st.Free()
|
||||||
|
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
|
||||||
|
name, quantize, groupSize, bits, mode)
|
||||||
|
}
|
||||||
|
if len(scales.Dims()) == 0 || scales.Dims()[0] == 0 {
|
||||||
|
st.Free()
|
||||||
|
return tmpPath, nil, nil, fmt.Errorf("mlx.Quantize produced empty scales for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
|
||||||
|
name, quantize, groupSize, bits, mode)
|
||||||
|
}
|
||||||
|
|
||||||
qweight = mlx.Contiguous(qweight, false)
|
qweight = mlx.Contiguous(qweight, false)
|
||||||
scales = mlx.Contiguous(scales, false)
|
scales = mlx.Contiguous(scales, false)
|
||||||
arrays[name] = qweight
|
arrays[name] = qweight
|
||||||
@@ -174,8 +188,8 @@ func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quanti
|
|||||||
// Returns the blob bytes.
|
// Returns the blob bytes.
|
||||||
func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
|
func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([]byte, error) {
|
||||||
// Check if inputs are per-expert tensors that should be stacked into 3D
|
// Check if inputs are per-expert tensors that should be stacked into 3D
|
||||||
if projGroups, quantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
|
if projGroups, projQuantize := parsePerExpertInputs(groupName, inputs); projGroups != nil {
|
||||||
return stackAndQuantizeExpertGroup(groupName, projGroups, quantize)
|
return stackAndQuantizeExpertGroup(groupName, projGroups, projQuantize)
|
||||||
}
|
}
|
||||||
|
|
||||||
allArrays := make(map[string]*mlx.Array)
|
allArrays := make(map[string]*mlx.Array)
|
||||||
@@ -224,6 +238,17 @@ func quantizePackedGroup(groupName string, inputs []create.PackedTensorInput) ([
|
|||||||
mlx.Pin(finalArrays...)
|
mlx.Pin(finalArrays...)
|
||||||
pinned = append(pinned, finalArrays...)
|
pinned = append(pinned, finalArrays...)
|
||||||
|
|
||||||
|
// Record per-tensor quant type so the model can resolve params at load time.
|
||||||
|
if input.Quantize != "" {
|
||||||
|
if groupSize, _, _ := model.QuantizationParams(input.Quantize); groupSize > 0 {
|
||||||
|
if metadata == nil {
|
||||||
|
metadata = make(map[string]string)
|
||||||
|
}
|
||||||
|
metadata[input.Name+".quant_type"] = input.Quantize
|
||||||
|
metadata[input.Name+".group_size"] = strconv.Itoa(groupSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if st != nil {
|
if st != nil {
|
||||||
st.Free()
|
st.Free()
|
||||||
}
|
}
|
||||||
@@ -279,57 +304,60 @@ type expertTensorInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
|
// parsePerExpertInputs groups per-expert 2D tensor inputs by projection type
|
||||||
// and returns the uniform quantization type shared by all inputs.
|
// and returns per-projection quantization types. Different projections may use
|
||||||
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D)
|
// different quant types (e.g., gate_up=int4, down=int8) but all experts within
|
||||||
// or if the inputs have mixed quantization types.
|
// a projection must share the same type.
|
||||||
|
// Returns nil if the inputs are not per-expert tensors (e.g., already stacked 3D).
|
||||||
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
|
// Only handles ".experts" groups; ".shared_experts" groups are left unpacked.
|
||||||
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, string) {
|
func parsePerExpertInputs(groupName string, inputs []create.PackedTensorInput) (map[string][]expertTensorInfo, map[string]string) {
|
||||||
if !strings.HasSuffix(groupName, ".experts") {
|
if !strings.HasSuffix(groupName, ".experts") {
|
||||||
return nil, ""
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
quantize := inputs[0].Quantize
|
|
||||||
groups := make(map[string][]expertTensorInfo)
|
groups := make(map[string][]expertTensorInfo)
|
||||||
|
projQuantize := make(map[string]string) // projection -> quant type
|
||||||
for _, input := range inputs {
|
for _, input := range inputs {
|
||||||
if input.Quantize != quantize {
|
|
||||||
return nil, "" // mixed quantization types
|
|
||||||
}
|
|
||||||
suffix := strings.TrimPrefix(input.Name, groupName)
|
suffix := strings.TrimPrefix(input.Name, groupName)
|
||||||
m := perExpertSuffix.FindStringSubmatch(suffix)
|
m := perExpertSuffix.FindStringSubmatch(suffix)
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return nil, "" // not a per-expert pattern
|
return nil, nil // not a per-expert pattern
|
||||||
}
|
}
|
||||||
index, err := strconv.Atoi(m[1])
|
index, err := strconv.Atoi(m[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ""
|
return nil, nil
|
||||||
}
|
}
|
||||||
groups[m[2]] = append(groups[m[2]], expertTensorInfo{
|
proj := m[2]
|
||||||
|
if existing, ok := projQuantize[proj]; ok {
|
||||||
|
if input.Quantize != existing {
|
||||||
|
return nil, nil // mixed quant within same projection
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
projQuantize[proj] = input.Quantize
|
||||||
|
}
|
||||||
|
groups[proj] = append(groups[proj], expertTensorInfo{
|
||||||
index: index,
|
index: index,
|
||||||
proj: m[2],
|
proj: proj,
|
||||||
input: input,
|
input: input,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if len(groups) == 0 {
|
if len(groups) == 0 {
|
||||||
return nil, ""
|
return nil, nil
|
||||||
}
|
}
|
||||||
return groups, quantize
|
return groups, projQuantize
|
||||||
}
|
}
|
||||||
|
|
||||||
// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
|
// stackAndQuantizeExpertGroup decodes per-expert tensors, stacks them into 3D
|
||||||
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
|
// switch_mlp tensors, quantizes, and returns the combined safetensors blob.
|
||||||
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, quantize string) ([]byte, error) {
|
// projQuantize maps projection name to its quantization type (may differ per projection).
|
||||||
|
func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]expertTensorInfo, projQuantize map[string]string) ([]byte, error) {
|
||||||
groupBase := strings.TrimSuffix(groupName, ".experts")
|
groupBase := strings.TrimSuffix(groupName, ".experts")
|
||||||
|
|
||||||
allArrays := make(map[string]*mlx.Array)
|
allArrays := make(map[string]*mlx.Array)
|
||||||
var pinned []*mlx.Array
|
var pinned []*mlx.Array
|
||||||
|
|
||||||
var metadata map[string]string
|
// Build metadata: if all projections use the same quant type, set global metadata.
|
||||||
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 && quantize != "" {
|
// Otherwise record per-tensor quant info.
|
||||||
metadata = map[string]string{
|
metadata := make(map[string]string)
|
||||||
"quant_type": quantize,
|
|
||||||
"group_size": strconv.Itoa(groupSize),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort projection names for deterministic output
|
// Sort projection names for deterministic output
|
||||||
projNames := make([]string, 0, len(projGroups))
|
projNames := make([]string, 0, len(projGroups))
|
||||||
@@ -339,7 +367,11 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
|||||||
sort.Strings(projNames)
|
sort.Strings(projNames)
|
||||||
|
|
||||||
cleanup := func() {
|
cleanup := func() {
|
||||||
mlx.Unpin(pinned...)
|
for _, p := range pinned {
|
||||||
|
if p != nil {
|
||||||
|
mlx.Unpin(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -382,11 +414,27 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
|||||||
mlx.Pin(stacked)
|
mlx.Pin(stacked)
|
||||||
pinned = append(pinned, stacked)
|
pinned = append(pinned, stacked)
|
||||||
|
|
||||||
// Free individual decoded arrays
|
// Free individual decoded arrays (remove from pinned to avoid double-unpin in cleanup)
|
||||||
|
for i, p := range pinned {
|
||||||
|
for _, d := range decoded {
|
||||||
|
if p == d {
|
||||||
|
pinned[i] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
mlx.Unpin(decoded...)
|
mlx.Unpin(decoded...)
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
|
|
||||||
stackedName := groupBase + ".switch_mlp." + proj
|
stackedName := groupBase + ".switch_mlp." + proj
|
||||||
|
quantize := projQuantize[proj]
|
||||||
|
|
||||||
|
// Record per-tensor quant metadata so the model can resolve params at load time.
|
||||||
|
if quantize != "" {
|
||||||
|
if groupSize, _, _ := model.QuantizationParams(quantize); groupSize > 0 {
|
||||||
|
metadata[stackedName+".quant_type"] = quantize
|
||||||
|
metadata[stackedName+".group_size"] = strconv.Itoa(groupSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Quantize the stacked tensor
|
// Quantize the stacked tensor
|
||||||
if quantize != "" {
|
if quantize != "" {
|
||||||
@@ -394,6 +442,14 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
|||||||
|
|
||||||
qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)
|
qweight, scales, qbiases := mlx.Quantize(stacked, groupSize, bits, mode)
|
||||||
|
|
||||||
|
// Validate quantization produced non-empty output.
|
||||||
|
mlx.Eval(qweight, scales)
|
||||||
|
if len(qweight.Dims()) == 0 || qweight.Dims()[0] == 0 {
|
||||||
|
cleanup()
|
||||||
|
return nil, fmt.Errorf("mlx.Quantize produced empty weight for %s (quantize=%s, groupSize=%d, bits=%d, mode=%s)",
|
||||||
|
stackedName, quantize, groupSize, bits, mode)
|
||||||
|
}
|
||||||
|
|
||||||
qweight = mlx.Contiguous(qweight, false)
|
qweight = mlx.Contiguous(qweight, false)
|
||||||
scales = mlx.Contiguous(scales, false)
|
scales = mlx.Contiguous(scales, false)
|
||||||
allArrays[stackedName] = qweight
|
allArrays[stackedName] = qweight
|
||||||
@@ -409,12 +465,19 @@ func stackAndQuantizeExpertGroup(groupName string, projGroups map[string][]exper
|
|||||||
mlx.Pin(toEval...)
|
mlx.Pin(toEval...)
|
||||||
pinned = append(pinned, toEval...)
|
pinned = append(pinned, toEval...)
|
||||||
|
|
||||||
// Free stacked source array
|
// Free stacked source array (remove from pinned to avoid double-unpin in cleanup)
|
||||||
|
for i, p := range pinned {
|
||||||
|
if p == stacked {
|
||||||
|
pinned[i] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
mlx.Unpin(stacked)
|
mlx.Unpin(stacked)
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
} else {
|
} else {
|
||||||
stacked = mlx.Contiguous(stacked, false)
|
stacked = mlx.Contiguous(stacked, false)
|
||||||
mlx.Eval(stacked)
|
mlx.Eval(stacked)
|
||||||
|
mlx.Pin(stacked)
|
||||||
|
pinned = append(pinned, stacked)
|
||||||
allArrays[stackedName] = stacked
|
allArrays[stackedName] = stacked
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -529,7 +592,7 @@ func decodeSourceFP8Tensor(weight, scaleInv *mlx.Array) (*mlx.Array, error) {
|
|||||||
padBottom := blockRows*scaleShape[0] - rows
|
padBottom := blockRows*scaleShape[0] - rows
|
||||||
padSide := blockCols*scaleShape[1] - cols
|
padSide := blockCols*scaleShape[1] - cols
|
||||||
if padBottom > 0 || padSide > 0 {
|
if padBottom > 0 || padSide > 0 {
|
||||||
decoded = mlx.Pad(decoded, []int32{0, int32(padBottom), 0, int32(padSide)})
|
decoded = mlx.PadConstant(decoded, []int{0, 1}, []int{0, 0}, []int{padBottom, padSide})
|
||||||
}
|
}
|
||||||
|
|
||||||
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
|
decoded = mlx.Reshape(decoded, int32(scaleShape[0]), int32(blockRows), int32(scaleShape[1]), int32(blockCols))
|
||||||
|
|||||||
@@ -246,6 +246,11 @@ func ShouldQuantize(name, component string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip audio encoder tensors (highly sensitive to quantization)
|
||||||
|
if strings.Contains(name, "audio_tower") || strings.Contains(name, "embed_audio") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Skip embeddings
|
// Skip embeddings
|
||||||
if strings.Contains(name, "embed") {
|
if strings.Contains(name, "embed") {
|
||||||
return false
|
return false
|
||||||
@@ -291,6 +296,22 @@ func normalizeQuantType(quantize string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isAligned checks if a tensor's last dimension is divisible by the
|
||||||
|
// group size required for the given quantization type.
|
||||||
|
func isAligned(shape []int32, quantType string) bool {
|
||||||
|
if len(shape) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
groupSize := int32(32)
|
||||||
|
switch normalizeQuantType(quantType) {
|
||||||
|
case "nvfp4":
|
||||||
|
groupSize = 16
|
||||||
|
case "int4", "int8":
|
||||||
|
groupSize = 64
|
||||||
|
}
|
||||||
|
return shape[len(shape)-1]%groupSize == 0
|
||||||
|
}
|
||||||
|
|
||||||
func isStackedExpertWeight(name string) bool {
|
func isStackedExpertWeight(name string) bool {
|
||||||
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
|
// Combined/stacked expert tensors may be emitted either as "...proj.weight" (per-expert)
|
||||||
// or "...proj" (pre-stacked packed tensor).
|
// or "...proj" (pre-stacked packed tensor).
|
||||||
@@ -300,16 +321,16 @@ func isStackedExpertWeight(name string) bool {
|
|||||||
|
|
||||||
return strings.Contains(name, ".mlp.switch_mlp.") ||
|
return strings.Contains(name, ".mlp.switch_mlp.") ||
|
||||||
strings.Contains(name, ".mlp.experts.") ||
|
strings.Contains(name, ".mlp.experts.") ||
|
||||||
strings.Contains(name, ".mlp.shared_experts.")
|
strings.Contains(name, ".mlp.shared_experts.") ||
|
||||||
|
strings.Contains(name, ".moe.experts.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
// GetTensorQuantization returns the appropriate quantization type for a tensor.
|
||||||
// Returns "" if the tensor should not be quantized.
|
// Returns "" if the tensor should not be quantized.
|
||||||
// This implements mixed-precision quantization:
|
// This implements mixed-precision quantization:
|
||||||
// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
|
// - v_proj, k_proj, down_proj: promoted to INT8 when base is INT4
|
||||||
// - Output projection, gate/up weights: int4 (less sensitive)
|
|
||||||
// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
|
|
||||||
// - Norms, embeddings, biases, routing gates: no quantization
|
// - Norms, embeddings, biases, routing gates: no quantization
|
||||||
|
// - All other eligible weights: use requested quantization type
|
||||||
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
||||||
stackedExpert := isStackedExpertWeight(name)
|
stackedExpert := isStackedExpertWeight(name)
|
||||||
|
|
||||||
@@ -336,60 +357,35 @@ func GetTensorQuantization(name string, shape []int32, quantize string) string {
|
|||||||
// Normalize quantization type to canonical form
|
// Normalize quantization type to canonical form
|
||||||
quantNorm := normalizeQuantType(quantize)
|
quantNorm := normalizeQuantType(quantize)
|
||||||
|
|
||||||
// MLX quantization requires last dimension to be divisible by group size
|
|
||||||
// nvfp4: 16, mxfp4/mxfp8: 32, int4/int8: 64
|
|
||||||
groupSize := int32(32)
|
|
||||||
switch quantNorm {
|
|
||||||
case "nvfp4":
|
|
||||||
groupSize = 16
|
|
||||||
case "int4", "int8":
|
|
||||||
groupSize = 64
|
|
||||||
}
|
|
||||||
if shape[len(shape)-1]%groupSize != 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip routing gate weights (should stay high precision)
|
// Skip routing gate weights (should stay high precision)
|
||||||
// In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
|
// In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
|
||||||
if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
|
if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MLX quantization requires last dimension to be divisible by group size.
|
||||||
|
if !isAligned(shape, quantNorm) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// For non-affine modes, use the same quantization for all eligible tensors.
|
// For non-affine modes, use the same quantization for all eligible tensors.
|
||||||
if quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8" {
|
if quantNorm == "nvfp4" || quantNorm == "mxfp4" || quantNorm == "mxfp8" {
|
||||||
return quantNorm
|
return quantNorm
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attention MLA weights - keep unquantized (bf16)
|
// Value projection weights directly determine attention output quality.
|
||||||
// These are highly sensitive: errors accumulate in the KV cache over time
|
// Down projection weights feed directly into the residual stream where
|
||||||
// q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj
|
// errors accumulate across layers. Both benefit from higher precision.
|
||||||
if strings.Contains(name, "q_a_proj") ||
|
// Promote to INT8 when base is INT4 (same affine mode, compatible with
|
||||||
strings.Contains(name, "q_b_proj") ||
|
// GatherQMM for MoE expert tensors).
|
||||||
strings.Contains(name, "kv_a_proj") ||
|
if quantNorm == "int4" {
|
||||||
strings.Contains(name, "kv_b_proj") {
|
if strings.Contains(name, ".v_proj") || strings.Contains(name, ".k_proj") || strings.Contains(name, "down_proj") {
|
||||||
return "" // No quantization - keep bf16
|
if isAligned(shape, "int8") {
|
||||||
|
return "int8"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Down projection weights - use INT8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
|
|
||||||
// mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
|
|
||||||
if strings.Contains(name, "down_proj") {
|
|
||||||
return "int8"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Output projection, gate/up weights - use requested quantization (INT4)
|
|
||||||
// o_proj, gate_proj, up_proj
|
|
||||||
if strings.Contains(name, "o_proj") ||
|
|
||||||
strings.Contains(name, "gate_proj") ||
|
|
||||||
strings.Contains(name, "up_proj") {
|
|
||||||
return quantNorm
|
|
||||||
}
|
|
||||||
|
|
||||||
// LM head - use requested quantization
|
|
||||||
if strings.Contains(name, "lm_head") {
|
|
||||||
return quantNorm
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default to requested quantization for other weights
|
|
||||||
return quantNorm
|
return quantNorm
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -411,6 +407,7 @@ func ExpertGroupPrefix(tensorName string) string {
|
|||||||
".mlp.experts.",
|
".mlp.experts.",
|
||||||
".mlp.shared_experts.",
|
".mlp.shared_experts.",
|
||||||
".mlp.switch_mlp.",
|
".mlp.switch_mlp.",
|
||||||
|
".moe.experts.",
|
||||||
} {
|
} {
|
||||||
idx := strings.Index(tensorName, marker)
|
idx := strings.Index(tensorName, marker)
|
||||||
if idx == -1 {
|
if idx == -1 {
|
||||||
@@ -637,6 +634,8 @@ var tensorImportTransformRegistry = map[string]tensorImportTransformFactory{
|
|||||||
"Qwen3_5MoeForConditionalGeneration": newQwen35ImportTransform,
|
"Qwen3_5MoeForConditionalGeneration": newQwen35ImportTransform,
|
||||||
"Qwen3NextMoeForCausalLM": newQwen35ImportTransform,
|
"Qwen3NextMoeForCausalLM": newQwen35ImportTransform,
|
||||||
"Qwen3NextMoeForConditionalGeneration": newQwen35ImportTransform,
|
"Qwen3NextMoeForConditionalGeneration": newQwen35ImportTransform,
|
||||||
|
"Gemma4ForCausalLM": newGemma4ImportTransform,
|
||||||
|
"Gemma4ForConditionalGeneration": newGemma4ImportTransform,
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
|
func newTensorImportTransform(modelDir string, cfg sourceModelConfig) (tensorImportTransform, error) {
|
||||||
|
|||||||
@@ -1169,6 +1169,11 @@ func TestShouldQuantize(t *testing.T) {
|
|||||||
{"ln prefix", "ln_1.weight", "", false},
|
{"ln prefix", "ln_1.weight", "", false},
|
||||||
{"layernorm in name", "input_layernorm.weight", "", false},
|
{"layernorm in name", "input_layernorm.weight", "", false},
|
||||||
|
|
||||||
|
// Audio encoder tensors should not be quantized
|
||||||
|
{"audio tower weight", "model.audio_tower.layers.0.weight", "", false},
|
||||||
|
{"audio tower norm", "model.audio_tower.norm.weight", "", false},
|
||||||
|
{"embed audio weight", "embed_audio.weight", "", false},
|
||||||
|
|
||||||
// Biases should not be quantized
|
// Biases should not be quantized
|
||||||
{"bias tensor", "attention.bias", "", false},
|
{"bias tensor", "attention.bias", "", false},
|
||||||
{"proj bias", "o_proj.bias", "", false},
|
{"proj bias", "o_proj.bias", "", false},
|
||||||
@@ -1262,6 +1267,11 @@ func TestExpertGroupPrefix(t *testing.T) {
|
|||||||
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
|
{"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
|
||||||
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
|
{"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
|
||||||
|
|
||||||
|
// MoE expert tensors (Gemma-style .moe.experts.)
|
||||||
|
{"model.layers.0.moe.experts.0.gate_proj.weight", "model.layers.0.moe.experts"},
|
||||||
|
{"model.layers.1.moe.experts.42.down_proj.weight", "model.layers.1.moe.experts"},
|
||||||
|
{"language_model.model.layers.2.moe.experts.127.up_proj.weight", "language_model.model.layers.2.moe.experts"},
|
||||||
|
|
||||||
// Expert tensors with language_model prefix should also match
|
// Expert tensors with language_model prefix should also match
|
||||||
{"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"},
|
{"language_model.model.layers.0.mlp.experts.0.gate_proj.weight", "language_model.model.layers.0.mlp.experts"},
|
||||||
{"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"},
|
{"language_model.model.layers.1.mlp.experts.255.down_proj.weight", "language_model.model.layers.1.mlp.experts"},
|
||||||
@@ -1369,6 +1379,94 @@ func TestGetTensorQuantization_StackedExpert3D(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsAligned(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
shape []int32
|
||||||
|
quantType string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
// int4/int8: group_size=64
|
||||||
|
{"int4 aligned", []int32{1024, 4096}, "int4", true},
|
||||||
|
{"int4 unaligned", []int32{1024, 48}, "int4", false},
|
||||||
|
{"int8 aligned", []int32{1024, 128}, "int8", true},
|
||||||
|
{"int8 unaligned", []int32{1024, 32}, "int8", false},
|
||||||
|
|
||||||
|
// nvfp4: group_size=16
|
||||||
|
{"nvfp4 aligned", []int32{1024, 48}, "nvfp4", true},
|
||||||
|
{"nvfp4 unaligned", []int32{1024, 24}, "nvfp4", false},
|
||||||
|
{"nvfp4 aligned 16", []int32{1024, 16}, "nvfp4", true},
|
||||||
|
|
||||||
|
// mxfp4/mxfp8: group_size=32
|
||||||
|
{"mxfp4 aligned", []int32{1024, 64}, "mxfp4", true},
|
||||||
|
{"mxfp4 unaligned", []int32{1024, 48}, "mxfp4", false},
|
||||||
|
{"mxfp8 aligned", []int32{1024, 32}, "mxfp8", true},
|
||||||
|
{"mxfp8 unaligned", []int32{1024, 24}, "mxfp8", false},
|
||||||
|
|
||||||
|
// Edge cases
|
||||||
|
{"empty shape", []int32{}, "int4", false},
|
||||||
|
{"1D tensor", []int32{4096}, "int4", true},
|
||||||
|
{"3D stacked expert", []int32{128, 4096, 2816}, "int4", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := isAligned(tt.shape, tt.quantType)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isAligned(%v, %q) = %v, want %v", tt.shape, tt.quantType, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetTensorQuantization_MixedPrecisionPromotion(t *testing.T) {
|
||||||
|
aligned := []int32{4096, 4096} // divisible by 64
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tensor string
|
||||||
|
shape []int32
|
||||||
|
quantize string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// int4 → int8 promotion for sensitive tensors
|
||||||
|
{"v_proj int4 promoted", "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"},
|
||||||
|
{"k_proj int4 promoted", "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
|
||||||
|
{"down_proj int4 promoted", "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"},
|
||||||
|
|
||||||
|
// Non-sensitive int4 tensors stay int4
|
||||||
|
{"q_proj int4 stays", "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"},
|
||||||
|
{"o_proj int4 stays", "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"},
|
||||||
|
{"gate_proj int4 stays", "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"},
|
||||||
|
{"up_proj int4 stays", "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"},
|
||||||
|
|
||||||
|
// nvfp4/mxfp4/mxfp8: no promotion (uniform quantization)
|
||||||
|
{"v_proj nvfp4 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"},
|
||||||
|
{"down_proj mxfp4 uniform", "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"},
|
||||||
|
{"v_proj mxfp8 uniform", "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"},
|
||||||
|
|
||||||
|
// int8: already 8-bit, no promotion
|
||||||
|
{"v_proj int8 stays", "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"},
|
||||||
|
|
||||||
|
// Expert tensors: down_proj also promoted for int4
|
||||||
|
{"expert down_proj int4", "model.layers.0.mlp.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"},
|
||||||
|
{"moe expert down_proj int4", "model.layers.0.moe.experts.down_proj.weight", []int32{128, 4096, 2816}, "int4", "int8"},
|
||||||
|
|
||||||
|
// Unaligned: falls back to bf16 (empty string)
|
||||||
|
{"v_proj int4 unaligned", "model.layers.0.self_attn.v_proj.weight", []int32{1024, 48}, "int4", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := GetTensorQuantization(tt.tensor, tt.shape, tt.quantize)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("GetTensorQuantization(%q, %v, %q) = %q, want %q",
|
||||||
|
tt.tensor, tt.shape, tt.quantize, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) {
|
func TestCreateSafetensorsModel_Qwen35NVFP4PacksSwitchMLPExperts(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
|||||||
264
x/create/gemma4.go
Normal file
264
x/create/gemma4.go
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
package create
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/safetensors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type gemma4ImportTransform struct {
|
||||||
|
numLayers int
|
||||||
|
numExperts int
|
||||||
|
}
|
||||||
|
|
||||||
|
// gemma4Config is a minimal subset of the Gemma 4 config.json used for quant decisions.
|
||||||
|
type gemma4Config struct {
|
||||||
|
NumHiddenLayers int `json:"num_hidden_layers"`
|
||||||
|
NumExperts int `json:"num_experts"`
|
||||||
|
TextConfig struct {
|
||||||
|
NumHiddenLayers int `json:"num_hidden_layers"`
|
||||||
|
NumExperts int `json:"num_experts"`
|
||||||
|
} `json:"text_config"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGemma4ImportTransform(modelDir string, _ sourceModelConfig) (tensorImportTransform, error) {
|
||||||
|
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
|
||||||
|
if err != nil {
|
||||||
|
return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic
|
||||||
|
}
|
||||||
|
var cfg gemma4Config
|
||||||
|
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return gemma4ImportTransform{}, nil //nolint:nilerr // fallback to no heuristic
|
||||||
|
}
|
||||||
|
|
||||||
|
numLayers := cfg.NumHiddenLayers
|
||||||
|
if numLayers == 0 {
|
||||||
|
numLayers = cfg.TextConfig.NumHiddenLayers
|
||||||
|
}
|
||||||
|
numExperts := cfg.NumExperts
|
||||||
|
if numExperts == 0 {
|
||||||
|
numExperts = cfg.TextConfig.NumExperts
|
||||||
|
}
|
||||||
|
|
||||||
|
return gemma4ImportTransform{numLayers: numLayers, numExperts: numExperts}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t gemma4ImportTransform) skipTensor(name string) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// layerIndexRe extracts the layer index from tensor names like
|
||||||
|
// "model.language_model.layers.5.self_attn.v_proj.weight" or
|
||||||
|
// "model.language_model.layers.5.moe.experts.42.down_proj.weight"
|
||||||
|
var layerIndexRe = regexp.MustCompile(`\.layers\.(\d+)\.`)
|
||||||
|
|
||||||
|
// useMoreBits returns true for layers where quantization-sensitive tensors
|
||||||
|
// should use higher precision: the first and last 1/8 of layers (which handle
|
||||||
|
// input grounding and final output refinement), plus every 3rd layer in between
|
||||||
|
// to limit error accumulation through the residual stream.
|
||||||
|
func useMoreBits(layerIdx, numLayers int) bool {
|
||||||
|
return layerIdx < numLayers/8 ||
|
||||||
|
layerIdx >= 7*numLayers/8 ||
|
||||||
|
(layerIdx-numLayers/8)%3 == 2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t gemma4ImportTransform) quantizationType(name string, shape []int32, quantize string) string {
|
||||||
|
quantNorm := normalizeQuantType(quantize)
|
||||||
|
|
||||||
|
// Embedding: quantize to 8-bit variant for bandwidth efficiency.
|
||||||
|
// The embedding serves double duty: lookup (via QuantizedEmbedding) and
|
||||||
|
// lm_head projection (via AsLinear). Using 8-bit matches GGUF Q6_K quality
|
||||||
|
// (strictly higher at 8 bpw vs 6.5 bpw) while saving ~2.8 GB on 31B vs bf16.
|
||||||
|
if isEmbedTokensWeight(name) {
|
||||||
|
switch quantNorm {
|
||||||
|
case "int4", "int8":
|
||||||
|
if isAligned(shape, "int8") {
|
||||||
|
return "int8"
|
||||||
|
}
|
||||||
|
case "mxfp4", "nvfp4", "mxfp8":
|
||||||
|
if isAligned(shape, "mxfp8") {
|
||||||
|
return "mxfp8"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isAligned(shape, quantNorm) {
|
||||||
|
return quantNorm
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mixed-precision quantization: sensitive tensors get higher precision.
|
||||||
|
//
|
||||||
|
// Value projections (v_proj) directly determine attention output quality.
|
||||||
|
// Down projections (down_proj) are the final MLP output and errors there
|
||||||
|
// propagate directly to the residual stream. Both benefit from higher
|
||||||
|
// precision at early layers, late layers, and periodically in between
|
||||||
|
// (the "useMoreBits" heuristic).
|
||||||
|
//
|
||||||
|
// For int4: promote → int8 (same affine family, GatherQMM compatible).
|
||||||
|
// For mxfp4/nvfp4: promote → mxfp8. MLX quantized_matmul handles mixed
|
||||||
|
// nvfp4+mxfp8 modes within the same model — each tensor carries its own
|
||||||
|
// quant metadata and the kernel dispatches per-tensor.
|
||||||
|
if t.numLayers > 0 {
|
||||||
|
layerIdx := -1
|
||||||
|
if m := layerIndexRe.FindStringSubmatch(name); m != nil {
|
||||||
|
if idx, err := strconv.Atoi(m[1]); err == nil {
|
||||||
|
layerIdx = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine promotion target for sensitive tensors.
|
||||||
|
// "int8" = int4 base → int8 (affine family)
|
||||||
|
// "mxfp8" = mxfp4/nvfp4 base → mxfp8
|
||||||
|
// "" = no promotion (int8/mxfp8, already 8-bit)
|
||||||
|
promote := ""
|
||||||
|
switch quantNorm {
|
||||||
|
case "int4":
|
||||||
|
promote = "int8"
|
||||||
|
case "mxfp4", "nvfp4":
|
||||||
|
promote = "mxfp8"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only apply to language model tensors — audio/vision tower tensors
|
||||||
|
// should pass through to GetTensorQuantization which skips them.
|
||||||
|
isModelTensor := !strings.Contains(name, "audio_tower") &&
|
||||||
|
!strings.Contains(name, "vision_tower")
|
||||||
|
isSensitive := isModelTensor &&
|
||||||
|
(strings.Contains(name, ".v_proj") || strings.Contains(name, "down_proj"))
|
||||||
|
isSensitiveK := isModelTensor && strings.Contains(name, "k_proj")
|
||||||
|
|
||||||
|
if promote != "" && (isSensitive || isSensitiveK) {
|
||||||
|
shouldPromote := false
|
||||||
|
|
||||||
|
// 8-expert models: v_proj and k_proj share very few KV heads,
|
||||||
|
// so quantization errors are amplified. Always promote.
|
||||||
|
if t.numExperts == 8 && (strings.Contains(name, ".v_proj") || isSensitiveK) {
|
||||||
|
shouldPromote = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Layer-position heuristic for v_proj and down_proj.
|
||||||
|
if isSensitive && layerIdx >= 0 && useMoreBits(layerIdx, t.numLayers) {
|
||||||
|
shouldPromote = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldPromote && isAligned(shape, promote) {
|
||||||
|
return promote
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sensitive tensor at a non-promoted layer: use base quant type.
|
||||||
|
// Return directly to bypass GetTensorQuantization's uniform
|
||||||
|
// promotion — the layer-position heuristic is authoritative here.
|
||||||
|
if !isAligned(shape, quantNorm) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return quantNorm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return GetTensorQuantization(name, shape, quantize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isEmbedTokensWeight returns true for the main token embedding weight.
|
||||||
|
func isEmbedTokensWeight(name string) bool {
|
||||||
|
return strings.HasSuffix(name, "embed_tokens.weight") &&
|
||||||
|
!strings.Contains(name, "per_layer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
|
||||||
|
if td == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split pre-stacked MoE expert tensors [N, out, in] into per-expert
|
||||||
|
// [out, in] tensors so they go through the standard expert packing and
|
||||||
|
// quantization flow (ExpertGroupPrefix matching, per-expert quantize).
|
||||||
|
if isGemma4StackedMoETensor(td.Name, td.Shape) {
|
||||||
|
return splitStackedMoETensor(td)
|
||||||
|
}
|
||||||
|
|
||||||
|
return []*safetensors.TensorData{td}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isGemma4StackedMoETensor checks if this is a pre-stacked MoE expert weight.
|
||||||
|
// Gemma 4 HF weights come in two layouts depending on the model version:
|
||||||
|
// - Older: model.language_model.layers.N.moe.{gate,up,down}_proj [experts, dim1, dim2]
|
||||||
|
// - Newer: model.language_model.layers.N.experts.{gate_up,down}_proj [experts, dim1, dim2]
|
||||||
|
//
|
||||||
|
// The newer layout has gate+up already fused. We keep it fused (no splitting)
|
||||||
|
// so the tensors flow through the standard expert packing and quantization path.
|
||||||
|
func isGemma4StackedMoETensor(name string, shape []int32) bool {
|
||||||
|
if len(shape) != 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.Contains(name, ".moe.") || strings.Contains(name, ".experts.") {
|
||||||
|
return strings.HasSuffix(name, "_proj") || strings.HasSuffix(name, "_proj.weight")
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitStackedMoETensor splits a [N, out, in] stacked expert tensor into
|
||||||
|
// N individual [out, in] tensors named with the per-expert convention that
|
||||||
|
// ExpertGroupPrefix expects: prefix.moe.experts.{E}.{proj}.weight
|
||||||
|
func splitStackedMoETensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
|
||||||
|
raw, err := io.ReadAll(td.Reader())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read tensor %s: %w", td.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
numExperts := int(td.Shape[0])
|
||||||
|
rows := int(td.Shape[1]) // out_features in HF layout
|
||||||
|
cols := int(td.Shape[2]) // in_features in HF layout
|
||||||
|
|
||||||
|
elemSize, err := DTypeSize(td.Dtype)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get dtype size for %s: %w", td.Dtype, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
perExpertBytes := rows * cols * elemSize
|
||||||
|
if len(raw) != numExperts*perExpertBytes {
|
||||||
|
return nil, fmt.Errorf("tensor %s: raw byte length %d does not match shape %v and dtype %s",
|
||||||
|
td.Name, len(raw), td.Shape, td.Dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the per-expert name pattern.
|
||||||
|
// Two source layouts:
|
||||||
|
// Old: model.language_model.layers.N.moe.gate_proj
|
||||||
|
// -> model.language_model.layers.N.moe.experts.E.gate_proj.weight
|
||||||
|
// New: model.language_model.layers.N.experts.gate_up_proj
|
||||||
|
// -> model.language_model.layers.N.moe.experts.E.gate_up_proj.weight
|
||||||
|
baseName := td.Name
|
||||||
|
baseName = strings.TrimSuffix(baseName, ".weight")
|
||||||
|
lastDot := strings.LastIndex(baseName, ".")
|
||||||
|
if lastDot < 0 {
|
||||||
|
return nil, fmt.Errorf("tensor %s: unexpected name format", td.Name)
|
||||||
|
}
|
||||||
|
parentPrefix := baseName[:lastDot] // "...layers.N.moe" or "...layers.N.experts"
|
||||||
|
projName := baseName[lastDot+1:] // "gate_proj" or "gate_up_proj"
|
||||||
|
|
||||||
|
// Normalize: if parent already ends with ".experts", use the grandparent + ".moe"
|
||||||
|
// so we get a consistent "layers.N.moe.experts.E" pattern.
|
||||||
|
var moePrefix string
|
||||||
|
if cut, ok := strings.CutSuffix(parentPrefix, ".experts"); ok {
|
||||||
|
moePrefix = cut + ".moe"
|
||||||
|
} else {
|
||||||
|
moePrefix = parentPrefix
|
||||||
|
}
|
||||||
|
|
||||||
|
transposedShape := []int32{td.Shape[1], td.Shape[2]}
|
||||||
|
|
||||||
|
results := make([]*safetensors.TensorData, numExperts)
|
||||||
|
for e := range numExperts {
|
||||||
|
expertName := fmt.Sprintf("%s.experts.%d.%s.weight", moePrefix, e, projName)
|
||||||
|
start := e * perExpertBytes
|
||||||
|
end := start + perExpertBytes
|
||||||
|
results[e] = safetensors.NewTensorDataFromBytes(expertName, td.Dtype, transposedShape, raw[start:end])
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
191
x/create/gemma4_test.go
Normal file
191
x/create/gemma4_test.go
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
package create
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGemma4QuantizationType(t *testing.T) {
|
||||||
|
// 26B MoE: 30 layers, 128 experts
|
||||||
|
transform26B := gemma4ImportTransform{numLayers: 30, numExperts: 128}
|
||||||
|
// 8-expert model (hypothetical)
|
||||||
|
transform8E := gemma4ImportTransform{numLayers: 30, numExperts: 8}
|
||||||
|
|
||||||
|
aligned := []int32{2816, 2816} // divisible by 64 (int4/int8 group size) and 16 (nvfp4)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
transform gemma4ImportTransform
|
||||||
|
tensor string
|
||||||
|
shape []int32
|
||||||
|
quantize string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// === embed_tokens: quantize to 8-bit variant (serves as both embed and lm_head) ===
|
||||||
|
{"embed_tokens int4", transform26B, "model.embed_tokens.weight", aligned, "int4", "int8"},
|
||||||
|
{"embed_tokens nvfp4", transform26B, "model.embed_tokens.weight", aligned, "nvfp4", "mxfp8"},
|
||||||
|
{"embed_tokens mxfp4", transform26B, "model.embed_tokens.weight", aligned, "mxfp4", "mxfp8"},
|
||||||
|
{"embed_tokens int8", transform26B, "model.embed_tokens.weight", aligned, "int8", "int8"},
|
||||||
|
{"embed_tokens mxfp8", transform26B, "model.embed_tokens.weight", aligned, "mxfp8", "mxfp8"},
|
||||||
|
|
||||||
|
// === v_proj: layer-position heuristic for int4/nvfp4 ===
|
||||||
|
// Layer 0 is in first 1/8 (30/8=3) → promoted
|
||||||
|
{"v_proj int4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int4", "int8"},
|
||||||
|
// Layer 4 is NOT in useMoreBits → base quant
|
||||||
|
{"v_proj int4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "int4", "int4"},
|
||||||
|
// Layer 29 is in last 1/8 → promoted
|
||||||
|
{"v_proj int4 last layer promoted", transform26B, "model.layers.29.self_attn.v_proj.weight", aligned, "int4", "int8"},
|
||||||
|
// nvfp4: promote to mxfp8 (cross-family, validated by MLX quantized_matmul)
|
||||||
|
{"v_proj nvfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "nvfp4", "mxfp8"},
|
||||||
|
{"v_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "nvfp4", "nvfp4"},
|
||||||
|
// mxfp4: promoted to mxfp8 at promoted layers (same mxfp family)
|
||||||
|
{"v_proj mxfp4 promoted layer", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp8"},
|
||||||
|
{"v_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.self_attn.v_proj.weight", aligned, "mxfp4", "mxfp4"},
|
||||||
|
// int8/mxfp8: no promotion (already 8-bit)
|
||||||
|
{"v_proj int8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "int8", "int8"},
|
||||||
|
{"v_proj mxfp8 base", transform26B, "model.layers.0.self_attn.v_proj.weight", aligned, "mxfp8", "mxfp8"},
|
||||||
|
|
||||||
|
// === down_proj (dense MLP): same heuristic as v_proj ===
|
||||||
|
{"dense down_proj int4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "int4", "int8"},
|
||||||
|
{"dense down_proj int4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "int4", "int4"},
|
||||||
|
{"dense down_proj nvfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "nvfp4", "mxfp8"},
|
||||||
|
{"dense down_proj nvfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "nvfp4", "nvfp4"},
|
||||||
|
{"dense down_proj mxfp4 promoted", transform26B, "model.layers.0.mlp.down_proj.weight", aligned, "mxfp4", "mxfp8"},
|
||||||
|
{"dense down_proj mxfp4 non-promoted", transform26B, "model.layers.4.mlp.down_proj.weight", aligned, "mxfp4", "mxfp4"},
|
||||||
|
|
||||||
|
// === Expert down_proj: int4→int8, nvfp4→nvfp8 at promoted layers ===
|
||||||
|
{"expert down_proj int4 promoted", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "int4", "int8"},
|
||||||
|
{"expert down_proj int4 non-promoted", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "int4", "int4"},
|
||||||
|
// nvfp4 experts: promote to mxfp8 (all experts at a layer get same treatment,
|
||||||
|
// so GatherQMM sees uniform quant per projection per layer)
|
||||||
|
{"expert down_proj nvfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "nvfp4", "mxfp8"},
|
||||||
|
{"expert down_proj nvfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "nvfp4", "nvfp4"},
|
||||||
|
// mxfp4 experts: promote to mxfp8 (same mxfp family, GatherQMM compatible)
|
||||||
|
{"expert down_proj mxfp4 promoted layer", transform26B, "model.layers.0.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp8"},
|
||||||
|
{"expert down_proj mxfp4 non-promoted layer", transform26B, "model.layers.4.moe.experts.42.down_proj.weight", aligned, "mxfp4", "mxfp4"},
|
||||||
|
|
||||||
|
// === Expert gate_up_proj: always base quant (not a sensitive tensor) ===
|
||||||
|
{"expert gate_up int4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "int4", "int4"},
|
||||||
|
{"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"},
|
||||||
|
{"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"},
|
||||||
|
|
||||||
|
// === k_proj: promoted only for 8-expert models ===
|
||||||
|
{"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"},
|
||||||
|
{"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
|
||||||
|
{"k_proj 8 experts nvfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "nvfp4", "mxfp8"},
|
||||||
|
{"k_proj 8 experts mxfp4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "mxfp4", "mxfp8"},
|
||||||
|
|
||||||
|
// === q_proj, o_proj, gate_proj, up_proj: always base quant ===
|
||||||
|
{"q_proj int4", transform26B, "model.layers.0.self_attn.q_proj.weight", aligned, "int4", "int4"},
|
||||||
|
{"o_proj int4", transform26B, "model.layers.0.self_attn.o_proj.weight", aligned, "int4", "int4"},
|
||||||
|
{"gate_proj int4", transform26B, "model.layers.0.mlp.gate_proj.weight", aligned, "int4", "int4"},
|
||||||
|
{"up_proj int4", transform26B, "model.layers.0.mlp.up_proj.weight", aligned, "int4", "int4"},
|
||||||
|
|
||||||
|
// === Non-quantizable tensors: always bf16 ===
|
||||||
|
{"embed_tokens per_layer skip", transform26B, "model.embed_tokens_per_layer.weight", aligned, "int4", ""},
|
||||||
|
{"norm", transform26B, "model.layers.0.input_layernorm.weight", []int32{2816}, "int4", ""},
|
||||||
|
{"router scale", transform26B, "model.layers.0.router.scale", []int32{2816}, "int4", ""},
|
||||||
|
|
||||||
|
// === Audio/vision tower tensors: must pass through unquantized for all quant types ===
|
||||||
|
// These contain .v_proj and down_proj but should NOT be intercepted by
|
||||||
|
// the sensitive-tensor promotion logic.
|
||||||
|
{"audio norm int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int4", ""},
|
||||||
|
{"audio norm nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "nvfp4", ""},
|
||||||
|
{"audio norm int8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "int8", ""},
|
||||||
|
{"audio norm mxfp8", transform26B, "model.audio_tower.subsample_conv_projection.layer0.norm.weight", []int32{128}, "mxfp8", ""},
|
||||||
|
{"audio conv int4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "int4", ""},
|
||||||
|
{"audio conv nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.layer0.conv.weight", []int32{128, 1, 3, 3}, "nvfp4", ""},
|
||||||
|
{"audio linear int4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "int4", ""},
|
||||||
|
{"audio linear nvfp4", transform26B, "model.audio_tower.subsample_conv_projection.input_proj_linear.weight", aligned, "nvfp4", ""},
|
||||||
|
// Audio tower v_proj — must NOT be promoted despite containing .v_proj
|
||||||
|
{"audio v_proj int4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", ""},
|
||||||
|
{"audio v_proj nvfp4", transform26B, "model.audio_tower.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", ""},
|
||||||
|
// Vision tower v_proj — vision tower IS quantized (unlike audio tower),
|
||||||
|
// but not intercepted by gemma4's layer-position heuristic.
|
||||||
|
// Falls through to GetTensorQuantization which applies uniform promotion.
|
||||||
|
{"vision v_proj int4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "int4", "int8"},
|
||||||
|
{"vision v_proj nvfp4", transform26B, "model.vision_tower.encoder.layers.0.self_attn.v_proj.linear.weight", aligned, "nvfp4", "nvfp4"},
|
||||||
|
// Audio tower down_proj
|
||||||
|
{"audio down_proj int4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "int4", ""},
|
||||||
|
{"audio down_proj nvfp4", transform26B, "model.audio_tower.layers.0.mlp.down_proj.linear.weight", aligned, "nvfp4", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := tt.transform.quantizationType(tt.tensor, tt.shape, tt.quantize)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("quantizationType(%q, %v, %q) = %q, want %q",
|
||||||
|
tt.tensor, tt.shape, tt.quantize, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUseMoreBits(t *testing.T) {
|
||||||
|
// 30 layers: first 1/8 = layers 0-2, last 1/8 = layers 27-29
|
||||||
|
// In between: every 3rd from offset (i - n/8) % 3 == 2
|
||||||
|
n := 30
|
||||||
|
promoted := map[int]bool{}
|
||||||
|
for i := range n {
|
||||||
|
if useMoreBits(i, n) {
|
||||||
|
promoted[i] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// First 1/8 (30/8 = 3): layers 0, 1, 2
|
||||||
|
for _, i := range []int{0, 1, 2} {
|
||||||
|
if !promoted[i] {
|
||||||
|
t.Errorf("layer %d should be promoted (first 1/8)", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last 1/8: layers 26, 27, 28, 29 (>= 7*30/8 = 26)
|
||||||
|
for _, i := range []int{26, 27, 28, 29} {
|
||||||
|
if !promoted[i] {
|
||||||
|
t.Errorf("layer %d should be promoted (last 1/8)", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some middle layers should NOT be promoted
|
||||||
|
for _, i := range []int{3, 4, 6, 7} {
|
||||||
|
if promoted[i] {
|
||||||
|
t.Errorf("layer %d should NOT be promoted", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Layer 5 should be promoted: (5 - 3) % 3 == 2
|
||||||
|
if !promoted[5] {
|
||||||
|
t.Errorf("layer 5 should be promoted (periodic)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsGemma4StackedMoETensor(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
label string
|
||||||
|
tensorName string
|
||||||
|
shape []int32
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
// New-style: .experts.gate_up_proj
|
||||||
|
{"experts gate_up_proj 3D", "model.layers.0.experts.gate_up_proj", []int32{128, 1408, 2816}, true},
|
||||||
|
{"experts down_proj 3D", "model.layers.0.experts.down_proj", []int32{128, 2816, 704}, true},
|
||||||
|
// Old-style: .moe.gate_proj
|
||||||
|
{"moe gate_proj 3D", "model.layers.0.moe.gate_proj", []int32{128, 2112, 2816}, true},
|
||||||
|
{"moe down_proj 3D", "model.layers.0.moe.down_proj.weight", []int32{128, 2816, 2112}, true},
|
||||||
|
// Not stacked: 2D
|
||||||
|
{"2D weight", "model.layers.0.experts.gate_up_proj", []int32{1408, 2816}, false},
|
||||||
|
// Not expert
|
||||||
|
{"non-expert 3D", "model.layers.0.mlp.gate_proj", []int32{3, 2816, 2816}, false},
|
||||||
|
// Not a projection
|
||||||
|
{"expert non-proj", "model.layers.0.experts.scale", []int32{128, 1, 1}, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.label, func(t *testing.T) {
|
||||||
|
got := isGemma4StackedMoETensor(tt.tensorName, tt.shape)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isGemma4StackedMoETensor(%q, %v) = %v, want %v",
|
||||||
|
tt.tensorName, tt.shape, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package mlxrunner
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
_ "github.com/ollama/ollama/x/models/gemma3"
|
_ "github.com/ollama/ollama/x/models/gemma3"
|
||||||
|
_ "github.com/ollama/ollama/x/models/gemma4"
|
||||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||||
_ "github.com/ollama/ollama/x/models/llama"
|
_ "github.com/ollama/ollama/x/models/llama"
|
||||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||||
|
|||||||
@@ -4,16 +4,57 @@ package mlx
|
|||||||
import "C"
|
import "C"
|
||||||
import "math"
|
import "math"
|
||||||
|
|
||||||
func GELUApprox(t *Array) *Array {
|
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
|
||||||
return t.Multiply(
|
|
||||||
FromValue[float32](0.5),
|
// GELUApprox matches mlx.nn.gelu_approx:
|
||||||
).Multiply(
|
//
|
||||||
t.Add(
|
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||||
t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
|
func GELUApprox(x *Array) *Array {
|
||||||
).Multiply(
|
// Use dtype-matched scalars to avoid implicit upcasts on bf16 inputs.
|
||||||
FromValue(float32(math.Sqrt(2 / math.Pi))),
|
half := scalarWithDtype(0.5, x)
|
||||||
).Tanh().Add(FromValue[float32](1.0)),
|
defer C.mlx_array_free(half)
|
||||||
).AsType(t.DType())
|
coeff := scalarWithDtype(geluCoeff, x)
|
||||||
|
defer C.mlx_array_free(coeff)
|
||||||
|
c := scalarWithDtype(0.044715, x)
|
||||||
|
defer C.mlx_array_free(c)
|
||||||
|
|
||||||
|
// x^3 via x*x*x (avoids general Power which is slower)
|
||||||
|
x3 := New("GELU_X3")
|
||||||
|
C.mlx_multiply(&x3.ctx, x.ctx, x.ctx, DefaultStream().ctx)
|
||||||
|
tmp := New("GELU_X3b")
|
||||||
|
C.mlx_multiply(&tmp.ctx, x3.ctx, x.ctx, DefaultStream().ctx)
|
||||||
|
x3 = tmp
|
||||||
|
|
||||||
|
// 0.044715 * x^3
|
||||||
|
cx3 := New("GELU_CX3")
|
||||||
|
C.mlx_multiply(&cx3.ctx, c, x3.ctx, DefaultStream().ctx)
|
||||||
|
|
||||||
|
// x + 0.044715 * x^3
|
||||||
|
inner := New("GELU_INNER")
|
||||||
|
C.mlx_add(&inner.ctx, x.ctx, cx3.ctx, DefaultStream().ctx)
|
||||||
|
|
||||||
|
// sqrt(2/pi) * (x + 0.044715 * x^3)
|
||||||
|
scaled := New("GELU_SCALED")
|
||||||
|
C.mlx_multiply(&scaled.ctx, coeff, inner.ctx, DefaultStream().ctx)
|
||||||
|
|
||||||
|
// tanh(...)
|
||||||
|
th := New("GELU_TANH")
|
||||||
|
C.mlx_tanh(&th.ctx, scaled.ctx, DefaultStream().ctx)
|
||||||
|
|
||||||
|
// 1 + tanh(...)
|
||||||
|
one := scalarWithDtype(1.0, x)
|
||||||
|
defer C.mlx_array_free(one)
|
||||||
|
onePlusTanh := New("GELU_1PT")
|
||||||
|
C.mlx_add(&onePlusTanh.ctx, one, th.ctx, DefaultStream().ctx)
|
||||||
|
|
||||||
|
// 0.5 * x
|
||||||
|
halfX := New("GELU_HALFX")
|
||||||
|
C.mlx_multiply(&halfX.ctx, half, x.ctx, DefaultStream().ctx)
|
||||||
|
|
||||||
|
// 0.5 * x * (1 + tanh(...))
|
||||||
|
out := New("GELU_APPROX")
|
||||||
|
C.mlx_multiply(&out.ctx, halfX.ctx, onePlusTanh.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func SILU(t *Array) *Array {
|
func SILU(t *Array) *Array {
|
||||||
|
|||||||
@@ -90,3 +90,10 @@ func AsyncEval(outputs ...*Array) {
|
|||||||
func Eval(outputs ...*Array) {
|
func Eval(outputs ...*Array) {
|
||||||
doEval(outputs, false)
|
doEval(outputs, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MetalIsAvailable returns true if a Metal GPU is available.
|
||||||
|
func MetalIsAvailable() bool {
|
||||||
|
var available C._Bool
|
||||||
|
C.mlx_metal_is_available(&available)
|
||||||
|
return bool(available)
|
||||||
|
}
|
||||||
|
|||||||
@@ -149,45 +149,132 @@ func Contiguous(a *Array, allowColMajor bool) *Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func Pad(a *Array, paddings []int32) *Array {
|
// Conv2d performs 2D convolution: x [N,H,W,C_in], weight [C_out,kH,kW,C_in].
|
||||||
numAxes := len(paddings) / 2
|
// MLX uses NHWC layout.
|
||||||
axes := make([]C.int, numAxes)
|
func Conv2d(x, weight *Array, strideH, strideW, padH, padW, dilationH, dilationW, groups int32) *Array {
|
||||||
lowPad := make([]C.int, numAxes)
|
out := New("CONV2D")
|
||||||
highPad := make([]C.int, numAxes)
|
C.mlx_conv2d(
|
||||||
for i := range numAxes {
|
&out.ctx,
|
||||||
axes[i] = C.int(i)
|
x.ctx,
|
||||||
lowPad[i] = C.int(paddings[i*2])
|
weight.ctx,
|
||||||
highPad[i] = C.int(paddings[i*2+1])
|
C.int(strideH), C.int(strideW),
|
||||||
|
C.int(padH), C.int(padW),
|
||||||
|
C.int(dilationH), C.int(dilationW),
|
||||||
|
C.int(groups),
|
||||||
|
DefaultStream().ctx,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pad pads array a along the given axes with specified low/high pad sizes.
|
||||||
|
// mode should be "constant", "edge", or "reflect".
|
||||||
|
func Pad(a *Array, axes []int, lowPad, highPad []int, padValue *Array, mode string) *Array {
|
||||||
|
cAxes := make([]C.int, len(axes))
|
||||||
|
cLow := make([]C.int, len(lowPad))
|
||||||
|
cHigh := make([]C.int, len(highPad))
|
||||||
|
for i := range axes {
|
||||||
|
cAxes[i] = C.int(axes[i])
|
||||||
|
cLow[i] = C.int(lowPad[i])
|
||||||
|
cHigh[i] = C.int(highPad[i])
|
||||||
}
|
}
|
||||||
|
cMode := C.CString(mode)
|
||||||
padValue := C.mlx_array_new_float(C.float(0))
|
|
||||||
defer C.mlx_array_free(padValue)
|
|
||||||
|
|
||||||
cMode := C.CString("constant")
|
|
||||||
defer C.free(unsafe.Pointer(cMode))
|
defer C.free(unsafe.Pointer(cMode))
|
||||||
|
|
||||||
out := New("PAD")
|
out := New("PAD")
|
||||||
C.mlx_pad(
|
C.mlx_pad(
|
||||||
&out.ctx,
|
&out.ctx,
|
||||||
a.ctx,
|
a.ctx,
|
||||||
unsafe.SliceData(axes),
|
unsafe.SliceData(cAxes), C.size_t(len(cAxes)),
|
||||||
C.size_t(len(axes)),
|
unsafe.SliceData(cLow), C.size_t(len(cLow)),
|
||||||
unsafe.SliceData(lowPad),
|
unsafe.SliceData(cHigh), C.size_t(len(cHigh)),
|
||||||
C.size_t(len(lowPad)),
|
padValue.ctx,
|
||||||
unsafe.SliceData(highPad),
|
|
||||||
C.size_t(len(highPad)),
|
|
||||||
padValue,
|
|
||||||
cMode,
|
cMode,
|
||||||
DefaultStream().ctx,
|
DefaultStream().ctx,
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PadConstant pads with zeros along the given axes.
|
||||||
|
func PadConstant(a *Array, axes []int, lowPad, highPad []int) *Array {
|
||||||
|
zero := NewScalarArray(float32(0))
|
||||||
|
return Pad(a, axes, lowPad, highPad, zero, "constant")
|
||||||
|
}
|
||||||
|
|
||||||
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||||
groups := int32(x.Dim(x.NumDims() - 1))
|
groups := int32(x.Dim(x.NumDims() - 1))
|
||||||
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Maximum returns element-wise maximum of two arrays.
|
||||||
|
func Maximum(a, b *Array) *Array {
|
||||||
|
out := New("MAXIMUM")
|
||||||
|
C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Minimum returns element-wise minimum of two arrays.
|
||||||
|
func Minimum(a, b *Array) *Array {
|
||||||
|
out := New("MINIMUM")
|
||||||
|
C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Softplus computes log(1 + exp(x)) using logaddexp for numerical stability.
|
||||||
|
func Softplus(a *Array) *Array {
|
||||||
|
return Logaddexp(a, Zeros(a.DType(), a.Dims()...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReLU computes max(0, x).
|
||||||
|
func ReLU(a *Array) *Array {
|
||||||
|
return Maximum(a, NewScalarArray(float32(0)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GLU applies Gated Linear Unit: splits x along last dim into two halves,
|
||||||
|
// returns first * sigmoid(second).
|
||||||
|
func GLU(a *Array) *Array {
|
||||||
|
lastDim := a.NumDims() - 1
|
||||||
|
halfSize := a.Dim(lastDim) / 2
|
||||||
|
first := SliceStartStop(a,
|
||||||
|
make([]int32, lastDim+1), // all zeros for start
|
||||||
|
appendDims(a, lastDim, int32(halfSize)),
|
||||||
|
)
|
||||||
|
second := SliceStartStop(a,
|
||||||
|
appendDimsStart(a, lastDim, int32(halfSize)),
|
||||||
|
appendDims(a, lastDim, int32(a.Dim(lastDim))),
|
||||||
|
)
|
||||||
|
return first.Multiply(second.Sigmoid())
|
||||||
|
}
|
||||||
|
|
||||||
|
// helper: builds stop array for SliceStartStop where the target axis = val
|
||||||
|
func appendDims(a *Array, targetAxis int, val int32) []int32 {
|
||||||
|
n := a.NumDims()
|
||||||
|
out := make([]int32, n)
|
||||||
|
for i := range n {
|
||||||
|
if i == targetAxis {
|
||||||
|
out[i] = val
|
||||||
|
} else {
|
||||||
|
out[i] = int32(a.Dim(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// helper: builds start array for SliceStartStop where the target axis = val
|
||||||
|
func appendDimsStart(a *Array, targetAxis int, val int32) []int32 {
|
||||||
|
n := a.NumDims()
|
||||||
|
out := make([]int32, n)
|
||||||
|
for i := range n {
|
||||||
|
if i == targetAxis {
|
||||||
|
out[i] = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clamp clamps array values to [min, max].
|
||||||
|
func Clamp(a *Array, minVal, maxVal float32) *Array {
|
||||||
|
return Minimum(Maximum(a, NewScalarArray(minVal)), NewScalarArray(maxVal))
|
||||||
|
}
|
||||||
|
|
||||||
// Convenience wrappers (function-style for the model code)
|
// Convenience wrappers (function-style for the model code)
|
||||||
|
|
||||||
func Stack(arrays []*Array, axis int) *Array {
|
func Stack(arrays []*Array, axis int) *Array {
|
||||||
@@ -323,20 +410,37 @@ func SiLU(a *Array) *Array {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
||||||
freqs := New("")
|
return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoPEWithFreqs applies RoPE with optional custom frequencies.
|
||||||
|
// When freqs is non-nil, it is used instead of computing from base.
|
||||||
|
// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass
|
||||||
|
// the actual frequencies (base^(2i/dim)), not the inverse frequencies.
|
||||||
|
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offset int, freqs *Array) *Array {
|
||||||
|
var freqsCtx C.mlx_array
|
||||||
|
var optBase C.mlx_optional_float
|
||||||
|
if freqs != nil {
|
||||||
|
freqsCtx = freqs.ctx
|
||||||
|
optBase = C.mlx_optional_float{has_value: C.bool(false)}
|
||||||
|
} else {
|
||||||
|
empty := New("")
|
||||||
|
freqsCtx = empty.ctx
|
||||||
|
optBase = C.mlx_optional_float{
|
||||||
|
value: C.float(base),
|
||||||
|
has_value: C.bool(func() bool { return base != 0 }()),
|
||||||
|
}
|
||||||
|
}
|
||||||
out := New("FAST_ROPE")
|
out := New("FAST_ROPE")
|
||||||
C.mlx_fast_rope(
|
C.mlx_fast_rope(
|
||||||
&out.ctx,
|
&out.ctx,
|
||||||
x.ctx,
|
x.ctx,
|
||||||
C.int(dims),
|
C.int(dims),
|
||||||
C.bool(traditional),
|
C.bool(traditional),
|
||||||
C.mlx_optional_float{
|
optBase,
|
||||||
value: C.float(base),
|
|
||||||
has_value: C.bool(func() bool { return base != 0 }()),
|
|
||||||
},
|
|
||||||
C.float(scale),
|
C.float(scale),
|
||||||
C.int(offset),
|
C.int(offset),
|
||||||
freqs.ctx,
|
freqsCtx,
|
||||||
DefaultStream().ctx,
|
DefaultStream().ctx,
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
@@ -358,6 +462,24 @@ func Log(a *Array) *Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Sin(a *Array) *Array {
|
||||||
|
out := New("SIN")
|
||||||
|
C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func Cos(a *Array) *Array {
|
||||||
|
out := New("COS")
|
||||||
|
C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func Clip(a, aMin, aMax *Array) *Array {
|
||||||
|
out := New("CLIP")
|
||||||
|
C.mlx_clip(&out.ctx, a.ctx, aMin.ctx, aMax.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func Logaddexp(a, b *Array) *Array {
|
func Logaddexp(a, b *Array) *Array {
|
||||||
out := New("LOGADDEXP")
|
out := New("LOGADDEXP")
|
||||||
C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||||
@@ -385,6 +507,20 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScaledDotProductAttentionMasked runs the fast SDPA kernel with an explicit
|
||||||
|
// additive mask. The mask is broadcast to [B, H, Q, K] and added to scores
|
||||||
|
// before softmax. Pass mode="array" so MLX actually consults mask_arr; the
|
||||||
|
// empty string is "no mask" and silently ignores the array argument.
|
||||||
|
func ScaledDotProductAttentionMasked(q, k, v *Array, scale float32, mask *Array) *Array {
|
||||||
|
sinks := New("")
|
||||||
|
cMode := C.CString("array")
|
||||||
|
defer C.free(unsafe.Pointer(cMode))
|
||||||
|
|
||||||
|
out := New("FAST_SDPA")
|
||||||
|
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
|
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
|
||||||
out := New("FAST_LAYERNORM")
|
out := New("FAST_LAYERNORM")
|
||||||
var w, b C.mlx_array
|
var w, b C.mlx_array
|
||||||
|
|||||||
@@ -131,6 +131,12 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string,
|
|||||||
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
|
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
|
||||||
globalQuantType = strings.ToUpper(globalQuantType)
|
globalQuantType = strings.ToUpper(globalQuantType)
|
||||||
|
|
||||||
|
// Parse full metadata for per-tensor quant info
|
||||||
|
var metaMap map[string]string
|
||||||
|
if metaRaw, ok := header["__metadata__"]; ok {
|
||||||
|
json.Unmarshal(metaRaw, &metaMap)
|
||||||
|
}
|
||||||
|
|
||||||
mainNames := mainTensorNames(header)
|
mainNames := mainTensorNames(header)
|
||||||
infos := make(map[string]*TensorQuantInfo)
|
infos := make(map[string]*TensorQuantInfo)
|
||||||
for _, name := range mainNames {
|
for _, name := range mainNames {
|
||||||
@@ -141,6 +147,18 @@ func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string,
|
|||||||
quantType := globalQuantType
|
quantType := globalQuantType
|
||||||
groupSize := globalGroupSize
|
groupSize := globalGroupSize
|
||||||
|
|
||||||
|
// Check per-tensor metadata (e.g. from packed expert blobs with mixed precision)
|
||||||
|
if metaMap != nil {
|
||||||
|
if qt, ok := metaMap[name+".quant_type"]; ok && qt != "" {
|
||||||
|
quantType = strings.ToUpper(qt)
|
||||||
|
}
|
||||||
|
if gs, ok := metaMap[name+".group_size"]; ok && gs != "" {
|
||||||
|
if v, err := strconv.Atoi(gs); err == nil {
|
||||||
|
groupSize = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
|
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
|
||||||
if quantType == "" {
|
if quantType == "" {
|
||||||
quantType = inferredType
|
quantType = inferredType
|
||||||
|
|||||||
1514
x/models/gemma4/gemma4.go
Normal file
1514
x/models/gemma4/gemma4.go
Normal file
File diff suppressed because it is too large
Load Diff
228
x/models/gemma4/gemma4_moe_test.go
Normal file
228
x/models/gemma4/gemma4_moe_test.go
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
package gemma4
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// onesLike creates a tensor of the given shape filled with a small constant.
|
||||||
|
func onesLike(shape ...int) *mlx.Array {
|
||||||
|
return mlx.AddScalar(mlx.Zeros(mlx.DTypeBFloat16, shape...), 0.01)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMoEForward(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
// Small config matching 26b architecture pattern.
|
||||||
|
cfg := &TextConfig{
|
||||||
|
HiddenSize: 16, // tiny for testing
|
||||||
|
NumAttentionHeads: 2,
|
||||||
|
NumKeyValueHeads: 1,
|
||||||
|
NumGlobalKeyValueHeads: 1,
|
||||||
|
HeadDim: 8,
|
||||||
|
GlobalHeadDim: 8,
|
||||||
|
NumExperts: 4,
|
||||||
|
TopKExperts: 2,
|
||||||
|
ExpertIntermediateSize: 8,
|
||||||
|
EnableMoeBlock: true,
|
||||||
|
AttentionKEqV: false,
|
||||||
|
RMSNormEps: 1e-6,
|
||||||
|
SlidingScale: 1.0,
|
||||||
|
FullScale: 1.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
B, L := int32(1), int32(3)
|
||||||
|
x := onesLike(int(B), int(L), int(cfg.HiddenSize))
|
||||||
|
|
||||||
|
// Test Router.Forward.
|
||||||
|
router := &Router{
|
||||||
|
Proj: linearFromWeight(onesLike(int(cfg.NumExperts), int(cfg.HiddenSize))),
|
||||||
|
Scale: onesLike(int(cfg.HiddenSize)),
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Router", func(t *testing.T) {
|
||||||
|
scores, inds := router.Forward(x, cfg)
|
||||||
|
mlx.Eval(scores, inds)
|
||||||
|
|
||||||
|
sDims := scores.Dims()
|
||||||
|
iDims := inds.Dims()
|
||||||
|
t.Logf("scores shape: %v, inds shape: %v", sDims, iDims)
|
||||||
|
|
||||||
|
if len(sDims) != 2 || sDims[0] != int(B*L) || sDims[1] != int(cfg.TopKExperts) {
|
||||||
|
t.Errorf("scores shape = %v, want [%d, %d]", sDims, B*L, cfg.TopKExperts)
|
||||||
|
}
|
||||||
|
if len(iDims) != 2 || iDims[0] != int(B*L) || iDims[1] != int(cfg.TopKExperts) {
|
||||||
|
t.Errorf("inds shape = %v, want [%d, %d]", iDims, B*L, cfg.TopKExperts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test MoEBlock.Forward.
|
||||||
|
moe := &MoEBlock{
|
||||||
|
GateWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
|
||||||
|
UpWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
|
||||||
|
DownWeight: onesLike(int(cfg.NumExperts), int(cfg.ExpertIntermediateSize), int(cfg.HiddenSize)),
|
||||||
|
PerExpertScale: onesLike(int(cfg.NumExperts)),
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("MoEBlock", func(t *testing.T) {
|
||||||
|
scores, inds := router.Forward(x, cfg)
|
||||||
|
mlx.Eval(scores, inds)
|
||||||
|
|
||||||
|
out := moe.Forward(x, scores, inds, cfg)
|
||||||
|
mlx.Eval(out)
|
||||||
|
|
||||||
|
outDims := out.Dims()
|
||||||
|
t.Logf("MoE output shape: %v", outDims)
|
||||||
|
|
||||||
|
if len(outDims) != 3 || outDims[0] != int(B) || outDims[1] != int(L) || outDims[2] != int(cfg.HiddenSize) {
|
||||||
|
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, B, L, cfg.HiddenSize)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with larger batch to exercise the sorted GatherMM path (B*L >= 64).
|
||||||
|
t.Run("MoEBlock_sorted", func(t *testing.T) {
|
||||||
|
bigB, bigL := int32(1), int32(128)
|
||||||
|
bigX := onesLike(int(bigB), int(bigL), int(cfg.HiddenSize))
|
||||||
|
|
||||||
|
scores, inds := router.Forward(bigX, cfg)
|
||||||
|
mlx.Eval(scores, inds)
|
||||||
|
|
||||||
|
out := moe.Forward(bigX, scores, inds, cfg)
|
||||||
|
mlx.Eval(out)
|
||||||
|
|
||||||
|
outDims := out.Dims()
|
||||||
|
t.Logf("MoE sorted output shape: %v", outDims)
|
||||||
|
|
||||||
|
if len(outDims) != 3 || outDims[0] != int(bigB) || outDims[1] != int(bigL) || outDims[2] != int(cfg.HiddenSize) {
|
||||||
|
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, bigB, bigL, cfg.HiddenSize)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouterForwardMatchesLegacy verifies the optimized Router.Forward —
|
||||||
|
// which takes the top-k of the raw logits and softmaxes only the selected
|
||||||
|
// values — produces the same indices and (within tolerance) the same
|
||||||
|
// normalized scores as the legacy path that softmaxes over every expert
|
||||||
|
// first, gathers the top-k probabilities, then renormalizes.
|
||||||
|
func TestRouterForwardMatchesLegacy(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
cfg := &TextConfig{
|
||||||
|
HiddenSize: 8,
|
||||||
|
NumExperts: 4,
|
||||||
|
TopKExperts: 2,
|
||||||
|
RMSNormEps: 1e-6,
|
||||||
|
RouterScale: 0.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Distinct per-expert weight rows so top-k has a well-defined ordering
|
||||||
|
// (tied scores would let argpartition pick either tied expert and make
|
||||||
|
// the index comparison below flaky).
|
||||||
|
projWeight := mlx.FromValues([]float32{
|
||||||
|
0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, // expert 0
|
||||||
|
0.30, 0.29, 0.28, 0.27, 0.26, 0.25, 0.24, 0.23, // expert 1
|
||||||
|
-0.05, -0.06, -0.07, -0.08, -0.09, -0.10, -0.11, -0.12, // expert 2
|
||||||
|
0.50, 0.48, 0.46, 0.44, 0.42, 0.40, 0.38, 0.36, // expert 3
|
||||||
|
}, int(cfg.NumExperts), int(cfg.HiddenSize))
|
||||||
|
|
||||||
|
scale := mlx.FromValues([]float32{
|
||||||
|
1.0, 0.9, 1.1, 1.0, 1.2, 0.8, 1.0, 1.05,
|
||||||
|
}, int(cfg.HiddenSize))
|
||||||
|
|
||||||
|
r := &Router{
|
||||||
|
Proj: linearFromWeight(projWeight),
|
||||||
|
Scale: scale,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Varied x so different positions potentially hit different top-k.
|
||||||
|
x := mlx.FromValues([]float32{
|
||||||
|
0.2, -0.1, 0.3, 0.0, 0.4, -0.2, 0.1, 0.05,
|
||||||
|
-0.3, 0.2, -0.1, 0.4, -0.05, 0.3, 0.0, 0.2,
|
||||||
|
0.5, 0.4, -0.2, 0.1, -0.3, 0.0, 0.3, -0.1,
|
||||||
|
}, 1, 3, int(cfg.HiddenSize))
|
||||||
|
|
||||||
|
gotScores, gotInds := r.Forward(x, cfg)
|
||||||
|
wantScores, wantInds := legacyRouterForward(r, x, cfg)
|
||||||
|
mlx.Eval(gotScores, gotInds, wantScores, wantInds)
|
||||||
|
|
||||||
|
if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {
|
||||||
|
t.Fatalf("indices mismatch:\n got %v\n want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := gotScores.Floats(), wantScores.Floats(); !floatSlicesClose(got, want, 1e-5) {
|
||||||
|
t.Fatalf("scores mismatch:\n got %v\n want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// legacyRouterForward implements the pre-optimization router: full softmax
|
||||||
|
// over every expert, gather the top-k probabilities, then renormalize them
|
||||||
|
// to sum to 1. Algebraically identical to the fused form in Router.Forward.
|
||||||
|
func legacyRouterForward(r *Router, x *mlx.Array, cfg *TextConfig) (*mlx.Array, *mlx.Array) {
|
||||||
|
dims := x.Dims()
|
||||||
|
BL := int32(dims[0]) * int32(dims[1])
|
||||||
|
|
||||||
|
xFlat := mlx.Reshape(x, BL, cfg.HiddenSize)
|
||||||
|
normed := mlx.RMSNormFn(xFlat, nil, cfg.RMSNormEps)
|
||||||
|
normed = mlx.MulScalar(normed, cfg.RouterScale)
|
||||||
|
normed = mlx.Mul(normed, r.Scale)
|
||||||
|
|
||||||
|
expertScores := r.Proj.Forward(normed)
|
||||||
|
probs := mlx.SoftmaxAxis(expertScores, -1, true)
|
||||||
|
|
||||||
|
neg := mlx.Neg(expertScores)
|
||||||
|
inds := mlx.Argpartition(neg, int(cfg.TopKExperts)-1, -1)
|
||||||
|
inds = mlx.SliceStartStop(inds,
|
||||||
|
[]int32{0, 0},
|
||||||
|
[]int32{BL, cfg.TopKExperts},
|
||||||
|
)
|
||||||
|
|
||||||
|
scores := mlx.TakeAlongAxis(probs, inds, -1)
|
||||||
|
sumScores := mlx.Sum(scores, -1, true)
|
||||||
|
scores = mlx.Div(scores, sumScores)
|
||||||
|
return scores, inds
|
||||||
|
}
|
||||||
|
|
||||||
|
func intSlicesEqual(a, b []int) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func floatSlicesClose(a, b []float32, tol float32) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
d := a[i] - b[i]
|
||||||
|
if d < 0 {
|
||||||
|
d = -d
|
||||||
|
}
|
||||||
|
if d > tol {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// linearFromWeight creates a simple nn.LinearLayer from a weight tensor (no bias).
|
||||||
|
func linearFromWeight(w *mlx.Array) *simpleLinear {
|
||||||
|
return &simpleLinear{weight: w}
|
||||||
|
}
|
||||||
|
|
||||||
|
type simpleLinear struct {
|
||||||
|
weight *mlx.Array
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *simpleLinear) Forward(x *mlx.Array) *mlx.Array {
|
||||||
|
return x.Matmul(mlx.Transpose(l.weight, 1, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *simpleLinear) OutputDim() int32 {
|
||||||
|
return int32(l.weight.Dims()[0])
|
||||||
|
}
|
||||||
503
x/models/gemma4/gemma4_test.go
Normal file
503
x/models/gemma4/gemma4_test.go
Normal file
@@ -0,0 +1,503 @@
|
|||||||
|
package gemma4
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseTextConfigE2B(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
data := []byte(`{
|
||||||
|
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||||
|
"text_config": {
|
||||||
|
"hidden_size": 1536,
|
||||||
|
"num_hidden_layers": 35,
|
||||||
|
"intermediate_size": 6144,
|
||||||
|
"num_attention_heads": 8,
|
||||||
|
"num_key_value_heads": 1,
|
||||||
|
"head_dim": 256,
|
||||||
|
"global_head_dim": 512,
|
||||||
|
"vocab_size": 262144,
|
||||||
|
"rms_norm_eps": 1e-6,
|
||||||
|
"max_position_embeddings": 131072,
|
||||||
|
"sliding_window": 512,
|
||||||
|
"sliding_window_pattern": 5,
|
||||||
|
"final_logit_softcapping": 30.0,
|
||||||
|
"use_double_wide_mlp": true,
|
||||||
|
"num_kv_shared_layers": 20,
|
||||||
|
"hidden_size_per_layer_input": 256,
|
||||||
|
"vocab_size_per_layer_input": 262144,
|
||||||
|
"attention_k_eq_v": false,
|
||||||
|
"tie_word_embeddings": true,
|
||||||
|
"layer_types": [
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
|
||||||
|
],
|
||||||
|
"rope_parameters": {
|
||||||
|
"full_attention": {
|
||||||
|
"partial_rotary_factor": 0.25,
|
||||||
|
"rope_theta": 1000000.0,
|
||||||
|
"rope_type": "proportional"
|
||||||
|
},
|
||||||
|
"sliding_attention": {
|
||||||
|
"rope_theta": 10000.0,
|
||||||
|
"rope_type": "default"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cfg, err := parseTextConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseTextConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic fields.
|
||||||
|
if cfg.HiddenSize != 1536 {
|
||||||
|
t.Errorf("HiddenSize = %d, want 1536", cfg.HiddenSize)
|
||||||
|
}
|
||||||
|
if cfg.NumHiddenLayers != 35 {
|
||||||
|
t.Errorf("NumHiddenLayers = %d, want 35", cfg.NumHiddenLayers)
|
||||||
|
}
|
||||||
|
if cfg.GlobalHeadDim != 512 {
|
||||||
|
t.Errorf("GlobalHeadDim = %d, want 512", cfg.GlobalHeadDim)
|
||||||
|
}
|
||||||
|
if cfg.FinalLogitSoftcapping != 30.0 {
|
||||||
|
t.Errorf("FinalLogitSoftcapping = %f, want 30.0", cfg.FinalLogitSoftcapping)
|
||||||
|
}
|
||||||
|
if cfg.NumKVSharedLayers != 20 {
|
||||||
|
t.Errorf("NumKVSharedLayers = %d, want 20", cfg.NumKVSharedLayers)
|
||||||
|
}
|
||||||
|
if cfg.HiddenSizePerLayer != 256 {
|
||||||
|
t.Errorf("HiddenSizePerLayer = %d, want 256", cfg.HiddenSizePerLayer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoPE settings.
|
||||||
|
if cfg.SlidingRopeDims != 256 {
|
||||||
|
t.Errorf("SlidingRopeDims = %d, want 256", cfg.SlidingRopeDims)
|
||||||
|
}
|
||||||
|
if cfg.FullRopeDims != 512 {
|
||||||
|
t.Errorf("FullRopeDims = %d, want 512 (GlobalHeadDim, partial rotation handled via custom freqs)", cfg.FullRopeDims)
|
||||||
|
}
|
||||||
|
if cfg.SlidingRopeBase != 10000 {
|
||||||
|
t.Errorf("SlidingRopeBase = %f, want 10000", cfg.SlidingRopeBase)
|
||||||
|
}
|
||||||
|
if cfg.FullRopeBase != 1000000 {
|
||||||
|
t.Errorf("FullRopeBase = %f, want 1000000", cfg.FullRopeBase)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attention scale.
|
||||||
|
if cfg.SlidingScale == 0 || cfg.FullScale == 0 {
|
||||||
|
t.Error("attention scales should be non-zero")
|
||||||
|
}
|
||||||
|
|
||||||
|
// KV sharing map.
|
||||||
|
// First shared layer is 35 - 20 = 15.
|
||||||
|
if donor, ok := cfg.KVShareMap[15]; !ok || donor != 13 {
|
||||||
|
t.Errorf("KVShareMap[15] = %d, ok=%v; want 13, true", donor, ok)
|
||||||
|
}
|
||||||
|
if donor, ok := cfg.KVShareMap[19]; !ok || donor != 14 {
|
||||||
|
t.Errorf("KVShareMap[19] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
|
||||||
|
}
|
||||||
|
if donor, ok := cfg.KVShareMap[34]; !ok || donor != 14 {
|
||||||
|
t.Errorf("KVShareMap[34] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
|
||||||
|
}
|
||||||
|
// Layer 14 should not be shared.
|
||||||
|
if _, ok := cfg.KVShareMap[14]; ok {
|
||||||
|
t.Error("layer 14 should not be in KVShareMap (non-shared)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Donors.
|
||||||
|
if !cfg.KVDonors[13] {
|
||||||
|
t.Error("layer 13 should be a KV donor")
|
||||||
|
}
|
||||||
|
if !cfg.KVDonors[14] {
|
||||||
|
t.Error("layer 14 should be a KV donor")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTextConfig26B(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
data := []byte(`{
|
||||||
|
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||||
|
"text_config": {
|
||||||
|
"hidden_size": 2816,
|
||||||
|
"num_hidden_layers": 30,
|
||||||
|
"intermediate_size": 2112,
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"num_global_key_value_heads": 2,
|
||||||
|
"head_dim": 256,
|
||||||
|
"global_head_dim": 512,
|
||||||
|
"vocab_size": 262144,
|
||||||
|
"rms_norm_eps": 1e-6,
|
||||||
|
"max_position_embeddings": 131072,
|
||||||
|
"sliding_window": 1024,
|
||||||
|
"final_logit_softcapping": 30.0,
|
||||||
|
"use_double_wide_mlp": false,
|
||||||
|
"num_kv_shared_layers": 0,
|
||||||
|
"hidden_size_per_layer_input": null,
|
||||||
|
"attention_k_eq_v": true,
|
||||||
|
"enable_moe_block": true,
|
||||||
|
"num_experts": 128,
|
||||||
|
"top_k_experts": 8,
|
||||||
|
"moe_intermediate_size": 704,
|
||||||
|
"tie_word_embeddings": true,
|
||||||
|
"layer_types": [
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
|
||||||
|
],
|
||||||
|
"rope_parameters": {
|
||||||
|
"full_attention": {
|
||||||
|
"partial_rotary_factor": 0.25,
|
||||||
|
"rope_theta": 1000000.0,
|
||||||
|
"rope_type": "proportional"
|
||||||
|
},
|
||||||
|
"sliding_attention": {
|
||||||
|
"rope_theta": 10000.0,
|
||||||
|
"rope_type": "default"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cfg, err := parseTextConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseTextConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.HiddenSize != 2816 {
|
||||||
|
t.Errorf("HiddenSize = %d, want 2816", cfg.HiddenSize)
|
||||||
|
}
|
||||||
|
if !cfg.AttentionKEqV {
|
||||||
|
t.Error("AttentionKEqV should be true")
|
||||||
|
}
|
||||||
|
if cfg.NumGlobalKeyValueHeads != 2 {
|
||||||
|
t.Errorf("NumGlobalKeyValueHeads = %d, want 2", cfg.NumGlobalKeyValueHeads)
|
||||||
|
}
|
||||||
|
if !cfg.EnableMoeBlock {
|
||||||
|
t.Error("EnableMoeBlock should be true")
|
||||||
|
}
|
||||||
|
if cfg.NumExperts != 128 {
|
||||||
|
t.Errorf("NumExperts = %d, want 128", cfg.NumExperts)
|
||||||
|
}
|
||||||
|
if cfg.TopKExperts != 8 {
|
||||||
|
t.Errorf("TopKExperts = %d, want 8", cfg.TopKExperts)
|
||||||
|
}
|
||||||
|
if cfg.ExpertIntermediateSize != 704 {
|
||||||
|
t.Errorf("ExpertIntermediateSize = %d, want 704", cfg.ExpertIntermediateSize)
|
||||||
|
}
|
||||||
|
if cfg.HiddenSizePerLayer != 0 {
|
||||||
|
t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTextConfig31B(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
data := []byte(`{
|
||||||
|
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||||
|
"text_config": {
|
||||||
|
"hidden_size": 5376,
|
||||||
|
"num_hidden_layers": 60,
|
||||||
|
"intermediate_size": 21504,
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_key_value_heads": 16,
|
||||||
|
"num_global_key_value_heads": 4,
|
||||||
|
"head_dim": 256,
|
||||||
|
"global_head_dim": 512,
|
||||||
|
"vocab_size": 262144,
|
||||||
|
"rms_norm_eps": 1e-6,
|
||||||
|
"max_position_embeddings": 131072,
|
||||||
|
"sliding_window": 1024,
|
||||||
|
"final_logit_softcapping": 30.0,
|
||||||
|
"use_double_wide_mlp": false,
|
||||||
|
"num_kv_shared_layers": 0,
|
||||||
|
"hidden_size_per_layer_input": null,
|
||||||
|
"attention_k_eq_v": true,
|
||||||
|
"tie_word_embeddings": true,
|
||||||
|
"layer_types": [
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
|
||||||
|
],
|
||||||
|
"rope_parameters": {
|
||||||
|
"full_attention": {
|
||||||
|
"partial_rotary_factor": 0.25,
|
||||||
|
"rope_theta": 1000000.0,
|
||||||
|
"rope_type": "proportional"
|
||||||
|
},
|
||||||
|
"sliding_attention": {
|
||||||
|
"rope_theta": 10000.0,
|
||||||
|
"rope_type": "default"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cfg, err := parseTextConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseTextConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.HiddenSize != 5376 {
|
||||||
|
t.Errorf("HiddenSize = %d, want 5376", cfg.HiddenSize)
|
||||||
|
}
|
||||||
|
if cfg.NumHiddenLayers != 60 {
|
||||||
|
t.Errorf("NumHiddenLayers = %d, want 60", cfg.NumHiddenLayers)
|
||||||
|
}
|
||||||
|
if !cfg.AttentionKEqV {
|
||||||
|
t.Error("AttentionKEqV should be true")
|
||||||
|
}
|
||||||
|
if cfg.NumGlobalKeyValueHeads != 4 {
|
||||||
|
t.Errorf("NumGlobalKeyValueHeads = %d, want 4", cfg.NumGlobalKeyValueHeads)
|
||||||
|
}
|
||||||
|
if cfg.NumKeyValueHeads != 16 {
|
||||||
|
t.Errorf("NumKeyValueHeads = %d, want 16", cfg.NumKeyValueHeads)
|
||||||
|
}
|
||||||
|
if cfg.NumKVSharedLayers != 0 {
|
||||||
|
t.Errorf("NumKVSharedLayers = %d, want 0", cfg.NumKVSharedLayers)
|
||||||
|
}
|
||||||
|
if cfg.HiddenSizePerLayer != 0 {
|
||||||
|
t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
|
||||||
|
}
|
||||||
|
if cfg.SlidingWindow != 1024 {
|
||||||
|
t.Errorf("SlidingWindow = %d, want 1024", cfg.SlidingWindow)
|
||||||
|
}
|
||||||
|
|
||||||
|
// KV sharing should be empty (no shared layers).
|
||||||
|
if len(cfg.KVShareMap) != 0 {
|
||||||
|
t.Errorf("KVShareMap should be empty, got %d entries", len(cfg.KVShareMap))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Layer types: pattern is 5 sliding + 1 full, repeating 10 times.
|
||||||
|
if !isLayerSliding(0, &cfg) {
|
||||||
|
t.Error("layer 0 should be sliding")
|
||||||
|
}
|
||||||
|
if isLayerSliding(5, &cfg) {
|
||||||
|
t.Error("layer 5 should be full attention")
|
||||||
|
}
|
||||||
|
if !isLayerSliding(6, &cfg) {
|
||||||
|
t.Error("layer 6 should be sliding")
|
||||||
|
}
|
||||||
|
if isLayerSliding(59, &cfg) {
|
||||||
|
t.Error("layer 59 should be full attention")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTextConfigE4B(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
data := []byte(`{
|
||||||
|
"architectures": ["Gemma4ForConditionalGeneration"],
|
||||||
|
"text_config": {
|
||||||
|
"hidden_size": 2560,
|
||||||
|
"num_hidden_layers": 42,
|
||||||
|
"intermediate_size": 10240,
|
||||||
|
"num_attention_heads": 8,
|
||||||
|
"num_key_value_heads": 2,
|
||||||
|
"head_dim": 256,
|
||||||
|
"global_head_dim": 512,
|
||||||
|
"vocab_size": 262144,
|
||||||
|
"rms_norm_eps": 1e-6,
|
||||||
|
"max_position_embeddings": 131072,
|
||||||
|
"sliding_window": 512,
|
||||||
|
"final_logit_softcapping": 30.0,
|
||||||
|
"use_double_wide_mlp": false,
|
||||||
|
"num_kv_shared_layers": 18,
|
||||||
|
"hidden_size_per_layer_input": 256,
|
||||||
|
"vocab_size_per_layer_input": 262144,
|
||||||
|
"attention_k_eq_v": false,
|
||||||
|
"enable_moe_block": false,
|
||||||
|
"tie_word_embeddings": true,
|
||||||
|
"layer_types": [
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
|
||||||
|
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
|
||||||
|
],
|
||||||
|
"rope_parameters": {
|
||||||
|
"full_attention": {
|
||||||
|
"partial_rotary_factor": 0.25,
|
||||||
|
"rope_theta": 1000000.0,
|
||||||
|
"rope_type": "proportional"
|
||||||
|
},
|
||||||
|
"sliding_attention": {
|
||||||
|
"rope_theta": 10000.0,
|
||||||
|
"rope_type": "default"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
cfg, err := parseTextConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseTextConfig failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.HiddenSize != 2560 {
|
||||||
|
t.Errorf("HiddenSize = %d, want 2560", cfg.HiddenSize)
|
||||||
|
}
|
||||||
|
if cfg.NumHiddenLayers != 42 {
|
||||||
|
t.Errorf("NumHiddenLayers = %d, want 42", cfg.NumHiddenLayers)
|
||||||
|
}
|
||||||
|
if cfg.IntermediateSize != 10240 {
|
||||||
|
t.Errorf("IntermediateSize = %d, want 10240", cfg.IntermediateSize)
|
||||||
|
}
|
||||||
|
if cfg.NumKeyValueHeads != 2 {
|
||||||
|
t.Errorf("NumKeyValueHeads = %d, want 2", cfg.NumKeyValueHeads)
|
||||||
|
}
|
||||||
|
if cfg.UseDoubleWideMLP {
|
||||||
|
t.Error("UseDoubleWideMLP should be false")
|
||||||
|
}
|
||||||
|
if cfg.NumKVSharedLayers != 18 {
|
||||||
|
t.Errorf("NumKVSharedLayers = %d, want 18", cfg.NumKVSharedLayers)
|
||||||
|
}
|
||||||
|
if cfg.HiddenSizePerLayer != 256 {
|
||||||
|
t.Errorf("HiddenSizePerLayer = %d, want 256 (has PLE)", cfg.HiddenSizePerLayer)
|
||||||
|
}
|
||||||
|
if cfg.AttentionKEqV {
|
||||||
|
t.Error("AttentionKEqV should be false")
|
||||||
|
}
|
||||||
|
if cfg.EnableMoeBlock {
|
||||||
|
t.Error("EnableMoeBlock should be false")
|
||||||
|
}
|
||||||
|
if cfg.SlidingWindow != 512 {
|
||||||
|
t.Errorf("SlidingWindow = %d, want 512", cfg.SlidingWindow)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Layer types: pattern is 5 sliding + 1 full, repeating 7 times = 42 layers.
|
||||||
|
if !isLayerSliding(0, &cfg) {
|
||||||
|
t.Error("layer 0 should be sliding")
|
||||||
|
}
|
||||||
|
if isLayerSliding(5, &cfg) {
|
||||||
|
t.Error("layer 5 should be full attention")
|
||||||
|
}
|
||||||
|
if !isLayerSliding(6, &cfg) {
|
||||||
|
t.Error("layer 6 should be sliding")
|
||||||
|
}
|
||||||
|
if isLayerSliding(41, &cfg) {
|
||||||
|
t.Error("layer 41 should be full attention")
|
||||||
|
}
|
||||||
|
|
||||||
|
// KV sharing: first shared = 42 - 18 = 24.
|
||||||
|
// Layer 24 is sliding, its donor should be the last non-shared sliding layer.
|
||||||
|
// Non-shared layers: 0-23. Last sliding in 0-23 is layer 22 (23=full).
|
||||||
|
if donor, ok := cfg.KVShareMap[24]; !ok {
|
||||||
|
t.Error("layer 24 should be in KVShareMap")
|
||||||
|
} else {
|
||||||
|
t.Logf("layer 24 donor = %d", donor)
|
||||||
|
}
|
||||||
|
// Layer 29 is full_attention (5th full), donor should be the last non-shared full layer.
|
||||||
|
// Non-shared full layers: 5, 11, 17, 23.
|
||||||
|
if donor, ok := cfg.KVShareMap[29]; !ok || donor != 23 {
|
||||||
|
t.Errorf("KVShareMap[29] = %d, ok=%v; want 23, true (full attn donor)", donor, ok)
|
||||||
|
}
|
||||||
|
// Layer 23 should NOT be shared (it's the last non-shared layer).
|
||||||
|
if _, ok := cfg.KVShareMap[23]; ok {
|
||||||
|
t.Error("layer 23 should not be in KVShareMap (non-shared)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLayerTypeDetection(t *testing.T) {
|
||||||
|
cfg := &TextConfig{
|
||||||
|
LayerTypes: []string{
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isLayerSliding(0, cfg) {
|
||||||
|
t.Error("layer 0 should be sliding")
|
||||||
|
}
|
||||||
|
if !isLayerSliding(3, cfg) {
|
||||||
|
t.Error("layer 3 should be sliding")
|
||||||
|
}
|
||||||
|
if isLayerSliding(4, cfg) {
|
||||||
|
t.Error("layer 4 should be full attention")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCachesOmitsSharedKVLayers(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Layers: []*DecoderLayer{
|
||||||
|
{IsSliding: true, KVShareDonor: -1},
|
||||||
|
{IsSliding: false, KVShareDonor: -1},
|
||||||
|
{IsSliding: true, KVShareDonor: 0},
|
||||||
|
{IsSliding: false, KVShareDonor: 1},
|
||||||
|
},
|
||||||
|
TextConfig: &TextConfig{SlidingWindow: 512},
|
||||||
|
}
|
||||||
|
|
||||||
|
caches := m.NewCaches()
|
||||||
|
if got, want := len(caches), 2; got != want {
|
||||||
|
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCachesIncludesAllNonSharedLayers(t *testing.T) {
|
||||||
|
m := &Model{
|
||||||
|
Layers: []*DecoderLayer{
|
||||||
|
{IsSliding: true, KVShareDonor: -1},
|
||||||
|
{IsSliding: false, KVShareDonor: -1},
|
||||||
|
{IsSliding: true, KVShareDonor: -1},
|
||||||
|
},
|
||||||
|
TextConfig: &TextConfig{SlidingWindow: 512},
|
||||||
|
}
|
||||||
|
|
||||||
|
caches := m.NewCaches()
|
||||||
|
if got, want := len(caches), len(m.Layers); got != want {
|
||||||
|
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveWeightPrefix(t *testing.T) {
|
||||||
|
if err := mlx.CheckInit(); err != nil {
|
||||||
|
t.Skipf("MLX not available: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
wantPfx string
|
||||||
|
}{
|
||||||
|
{"bare", "embed_tokens.weight", ""},
|
||||||
|
{"language_model", "model.language_model.embed_tokens.weight", "model.language_model."},
|
||||||
|
{"with_model", "model.embed_tokens.weight", "model."},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
dummy := mlx.FromValue(float32(1.0))
|
||||||
|
mlx.Eval(dummy)
|
||||||
|
tensors := map[string]*mlx.Array{tt.key: dummy}
|
||||||
|
got := resolveWeightPrefix(tensors)
|
||||||
|
if got != tt.wantPfx {
|
||||||
|
t.Errorf("resolveWeightPrefix(%q) = %q, want %q", tt.key, got, tt.wantPfx)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func skipIfNoMLX(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
if err := mlx.CheckInit(); err != nil {
|
||||||
|
t.Skipf("MLX not available: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user