Compare commits

...

13 Commits

Author SHA1 Message Date
Bruce MacDonald
9ac1300805 fix lint 2026-02-17 16:04:10 -07:00
Bruce MacDonald
43d9907dd6 fix tests 2026-02-17 16:04:10 -07:00
Bruce MacDonald
91dc088e8b server: usage api
Add a new /api/usage endpoint that shows aggregate usage statistics per model since the server started.
2026-02-17 15:59:52 -07:00
Patrick Devine
9aefd2dfee model: add qwen3 support to mlxrunner (#14293) 2026-02-17 13:58:49 -08:00
Patrick Devine
d07e4a1dd3 bugfix: better mlx model scheduling (#14290)
This fixes a bug with current MLX-based models, which don't get loaded and unloaded correctly: the first model gets loaded, and subsequent model starts get shunted to the first runner, resulting in the wrong model being run.
2026-02-17 13:57:05 -08:00
Parth Sareen
8a257ec00a docs: make integrations more discoverable (#14301)
* docs: add Pi integration page

* docs: flatten integration sidebar with expanded subheadings

* docs: add OpenClaw and Claude Code to quickstart
2026-02-17 13:27:25 -08:00
Parth Sareen
2f4de1acf7 cmd: ollama launch always show model picker (#14299) 2026-02-17 12:02:14 -08:00
Parth Sareen
ec95c45f70 cmd/config: ollama launch cline CLI (#14294) 2026-02-17 11:37:53 -08:00
Patrick Devine
3a88f7eb20 bugfix: add missing linear layer factory (#14289) 2026-02-16 17:22:20 -08:00
Patrick Devine
0d5da826d4 bugfix: display the parameter count correctly in mlx for ollama show (#14285) 2026-02-16 13:03:34 -08:00
Patrick Devine
9b795698b8 model: add llama3 architecture to mlxrunner (#14277) 2026-02-15 23:06:28 -08:00
Patrick Devine
041fb77639 model: add gemma3 to the mlxrunner (#14276)
This change adds the gemma3 model to the mlxrunner and simplifies some of the quantization
code for loading weights.
2026-02-15 22:47:59 -08:00
Saumil Shah
8224cce583 readme: update download link for macOS (#1) (#14271) 2026-02-15 15:25:15 -08:00
41 changed files with 3607 additions and 215 deletions

View File

@@ -16,7 +16,7 @@ Start building with open models.
 curl -fsSL https://ollama.com/install.sh | sh
 ```
-or [download manually](http://localhost:8080/download/Ollama.dmg)
+or [download manually](https://ollama.com/download/Ollama.dmg)
 ### Windows

View File

@@ -922,6 +922,19 @@ type UserResponse struct {
 	Plan string `json:"plan,omitempty"`
 }
 
+type UsageResponse struct {
+	// Start is the time the server started tracking usage (UTC, RFC 3339).
+	Start time.Time       `json:"start"`
+	Usage []ModelUsageData `json:"usage"`
+}
+
+type ModelUsageData struct {
+	Model            string `json:"model"`
+	Requests         int64  `json:"requests"`
+	PromptTokens     int64  `json:"prompt_tokens"`
+	CompletionTokens int64  `json:"completion_tokens"`
+}
+
 // Tensor describes the metadata for a given tensor.
 type Tensor struct {
 	Name string `json:"name"`

View File

@@ -57,9 +57,9 @@ import (
 func init() {
 	// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
-	config.DefaultSingleSelector = func(title string, items []config.ModelItem) (string, error) {
+	config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
 		tuiItems := tui.ReorderItems(tui.ConvertItems(items))
-		result, err := tui.SelectSingle(title, tuiItems)
+		result, err := tui.SelectSingle(title, tuiItems, current)
 		if errors.Is(err, tui.ErrCancelled) {
 			return "", config.ErrCancelled
 		}
@@ -182,6 +182,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		mfConfig.System = cmd.Args
 	case "license":
 		mfConfig.License = cmd.Args
+	case "parser":
+		mfConfig.Parser = cmd.Args
+	case "renderer":
+		mfConfig.Renderer = cmd.Args
 	}
 }
@@ -1897,9 +1901,9 @@ func runInteractiveTUI(cmd *cobra.Command) {
 	}
 	// Selector adapters for tui
-	singleSelector := func(title string, items []config.ModelItem) (string, error) {
+	singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
 		tuiItems := tui.ReorderItems(tui.ConvertItems(items))
-		result, err := tui.SelectSingle(title, tuiItems)
+		result, err := tui.SelectSingle(title, tuiItems, current)
 		if errors.Is(err, tui.ErrCancelled) {
 			return "", config.ErrCancelled
 		}

View File

@@ -126,7 +126,7 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
 	fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
 	if aliases["primary"] == "" || force {
-		primary, err := DefaultSingleSelector("Select model:", items)
+		primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
 		if err != nil {
 			return nil, false, err
 		}

123
cmd/config/cline.go Normal file
View File

@@ -0,0 +1,123 @@
package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"github.com/ollama/ollama/envconfig"
)
// Cline implements Runner and Editor for the Cline CLI integration
type Cline struct{}
func (c *Cline) String() string { return "Cline" }
func (c *Cline) Run(model string, args []string) error {
if _, err := exec.LookPath("cline"); err != nil {
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
}
models := []string{model}
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
models = config.Models
}
var err error
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
return selectModels(context.Background(), "cline", "")
})
if errors.Is(err, errCancelled) {
return nil
}
if err != nil {
return err
}
if err := c.Edit(models); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
cmd := exec.Command("cline", args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func (c *Cline) Paths() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
p := filepath.Join(home, ".cline", "data", "globalState.json")
if _, err := os.Stat(p); err == nil {
return []string{p}
}
return nil
}
func (c *Cline) Edit(models []string) error {
if len(models) == 0 {
return nil
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
configPath := filepath.Join(home, ".cline", "data", "globalState.json")
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
config := make(map[string]any)
if data, err := os.ReadFile(configPath); err == nil {
if err := json.Unmarshal(data, &config); err != nil {
return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
}
}
// Set Ollama as the provider for both act and plan modes
baseURL := envconfig.Host().String()
config["ollamaBaseUrl"] = baseURL
config["actModeApiProvider"] = "ollama"
config["actModeOllamaModelId"] = models[0]
config["actModeOllamaBaseUrl"] = baseURL
config["planModeApiProvider"] = "ollama"
config["planModeOllamaModelId"] = models[0]
config["planModeOllamaBaseUrl"] = baseURL
config["welcomeViewCompleted"] = true
data, err := json.MarshalIndent(config, "", " ")
if err != nil {
return err
}
return writeWithBackup(configPath, data)
}
func (c *Cline) Models() []string {
home, err := os.UserHomeDir()
if err != nil {
return nil
}
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
if err != nil {
return nil
}
if config["actModeApiProvider"] != "ollama" {
return nil
}
modelID, _ := config["actModeOllamaModelId"].(string)
if modelID == "" {
return nil
}
return []string{modelID}
}

204
cmd/config/cline_test.go Normal file
View File

@@ -0,0 +1,204 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
func TestClineIntegration(t *testing.T) {
c := &Cline{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Cline" {
t.Errorf("String() = %q, want %q", got, "Cline")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
t.Run("implements Editor", func(t *testing.T) {
var _ Editor = c
})
}
func TestClineEdit(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".cline", "data")
configPath := filepath.Join(configDir, "globalState.json")
readConfig := func() map[string]any {
data, _ := os.ReadFile(configPath)
var config map[string]any
json.Unmarshal(data, &config)
return config
}
t.Run("creates config from scratch", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeApiProvider"] != "ollama" {
t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
}
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
}
if config["planModeApiProvider"] != "ollama" {
t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
}
if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
}
if config["welcomeViewCompleted"] != true {
t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
}
})
t.Run("preserves existing fields", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
os.MkdirAll(configDir, 0o755)
existing := map[string]any{
"remoteRulesToggles": map[string]any{},
"remoteWorkflowToggles": map[string]any{},
"customSetting": "keep-me",
}
data, _ := json.Marshal(existing)
os.WriteFile(configPath, data, 0o644)
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["customSetting"] != "keep-me" {
t.Errorf("customSetting was not preserved")
}
if config["actModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
}
})
t.Run("updates model on re-edit", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
t.Fatal(err)
}
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
}
if config["planModeOllamaModelId"] != "glm-5:cloud" {
t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
}
})
t.Run("empty models is no-op", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit(nil); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
t.Error("expected no config file to be created for empty models")
}
})
t.Run("uses first model as primary", func(t *testing.T) {
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
if err := c.Edit([]string{"kimi-k2.5:cloud", "glm-5:cloud"}); err != nil {
t.Fatal(err)
}
config := readConfig()
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
}
})
}
func TestClineModels(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
configDir := filepath.Join(tmpDir, ".cline", "data")
configPath := filepath.Join(configDir, "globalState.json")
t.Run("returns nil when no config", func(t *testing.T) {
if models := c.Models(); models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns nil when provider is not ollama", func(t *testing.T) {
os.MkdirAll(configDir, 0o755)
config := map[string]any{
"actModeApiProvider": "anthropic",
"actModeOllamaModelId": "some-model",
}
data, _ := json.Marshal(config)
os.WriteFile(configPath, data, 0o644)
if models := c.Models(); models != nil {
t.Errorf("Models() = %v, want nil", models)
}
})
t.Run("returns model when ollama is configured", func(t *testing.T) {
os.MkdirAll(configDir, 0o755)
config := map[string]any{
"actModeApiProvider": "ollama",
"actModeOllamaModelId": "kimi-k2.5:cloud",
}
data, _ := json.Marshal(config)
os.WriteFile(configPath, data, 0o644)
models := c.Models()
if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
}
})
}
func TestClinePaths(t *testing.T) {
c := &Cline{}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Run("returns nil when no config exists", func(t *testing.T) {
if paths := c.Paths(); paths != nil {
t.Errorf("Paths() = %v, want nil", paths)
}
})
t.Run("returns path when config exists", func(t *testing.T) {
configDir := filepath.Join(tmpDir, ".cline", "data")
os.MkdirAll(configDir, 0o755)
configPath := filepath.Join(configDir, "globalState.json")
os.WriteFile(configPath, []byte("{}"), 0o644)
paths := c.Paths()
if len(paths) != 1 || paths[0] != configPath {
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
}
})
}

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"maps"
"net/http" "net/http"
"os" "os"
"os/exec" "os/exec"
@@ -54,6 +53,7 @@ type AliasConfigurer interface {
 var integrations = map[string]Runner{
 	"claude":   &Claude{},
 	"clawdbot": &Openclaw{},
+	"cline":    &Cline{},
 	"codex":    &Codex{},
 	"moltbot":  &Openclaw{},
 	"droid":    &Droid{},
@@ -102,16 +102,17 @@ var recommendedVRAM = map[string]string{
 var integrationAliases = map[string]bool{
 	"clawdbot": true,
 	"moltbot":  true,
+	"pi":       true,
 }
 
 // integrationInstallHints maps integration names to install URLs.
 var integrationInstallHints = map[string]string{
 	"claude":   "https://code.claude.com/docs/en/quickstart",
+	"cline":    "https://cline.bot/cli",
 	"openclaw": "https://docs.openclaw.ai",
 	"codex":    "https://developers.openai.com/codex/cli/",
 	"droid":    "https://docs.factory.ai/cli/getting-started/quickstart",
 	"opencode": "https://opencode.ai",
+	"pi":       "https://github.com/badlogic/pi-mono",
 }
 
 // hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable.
@@ -129,13 +130,21 @@ type IntegrationInfo struct {
 // integrationDescriptions maps integration names to short descriptions.
 var integrationDescriptions = map[string]string{
 	"claude":   "Anthropic's coding tool with subagents",
+	"cline":    "Autonomous coding agent with parallel execution",
 	"codex":    "OpenAI's open-source coding agent",
 	"openclaw": "Personal AI with 100+ skills",
 	"droid":    "Factory's coding agent across terminal and IDEs",
 	"opencode": "Anomaly's open-source coding agent",
+	"pi":       "Minimal AI agent toolkit with plugin support",
 }
 
-// ListIntegrationInfos returns all non-alias registered integrations, sorted by name.
+// integrationOrder defines a custom display order for integrations.
+// Integrations listed here are placed at the end in the given order;
+// all others appear first, sorted alphabetically.
+var integrationOrder = []string{"opencode", "droid", "pi", "cline"}
+
+// ListIntegrationInfos returns all non-alias registered integrations, sorted by name
+// with integrationOrder entries placed at the end.
 func ListIntegrationInfos() []IntegrationInfo {
 	var result []IntegrationInfo
 	for name, r := range integrations {
@@ -148,7 +157,26 @@ func ListIntegrationInfos() []IntegrationInfo {
 			Description: integrationDescriptions[name],
 		})
 	}
+	orderRank := make(map[string]int, len(integrationOrder))
+	for i, name := range integrationOrder {
+		orderRank[name] = i + 1 // 1-indexed so 0 means "not in the list"
+	}
 	slices.SortFunc(result, func(a, b IntegrationInfo) int {
+		aRank, bRank := orderRank[a.Name], orderRank[b.Name]
+		// Both have custom order: sort by their rank
+		if aRank > 0 && bRank > 0 {
+			return aRank - bRank
+		}
+		// Only one has custom order: it goes last
+		if aRank > 0 {
+			return 1
+		}
+		if bRank > 0 {
+			return -1
+		}
+		// Neither has custom order: alphabetical
 		return strings.Compare(a.Name, b.Name)
 	})
 	return result
@@ -186,9 +214,15 @@ func IsIntegrationInstalled(name string) bool {
case "droid": case "droid":
_, err := exec.LookPath("droid") _, err := exec.LookPath("droid")
return err == nil return err == nil
case "cline":
_, err := exec.LookPath("cline")
return err == nil
case "opencode": case "opencode":
_, err := exec.LookPath("opencode") _, err := exec.LookPath("opencode")
return err == nil return err == nil
case "pi":
_, err := exec.LookPath("pi")
return err == nil
default: default:
return true // Assume installed for unknown integrations return true // Assume installed for unknown integrations
} }
@@ -214,7 +248,8 @@ type ModelItem struct {
 }
 
 // SingleSelector is a function type for single item selection.
-type SingleSelector func(title string, items []ModelItem) (string, error)
+// current is the name of the previously selected item to highlight; empty means no pre-selection.
+type SingleSelector func(title string, items []ModelItem, current string) (string, error)
 
 // MultiSelector is a function type for multi item selection.
 type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
@@ -257,7 +292,7 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first") return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
} }
selected, err := selector("Select model to run:", items) selected, err := selector("Select model to run:", items, "")
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -367,13 +402,11 @@ func selectIntegration() (string, error) {
return "", fmt.Errorf("no integrations available") return "", fmt.Errorf("no integrations available")
} }
names := slices.Sorted(maps.Keys(integrations))
var items []ModelItem var items []ModelItem
for _, name := range names { for name, r := range integrations {
if integrationAliases[name] { if integrationAliases[name] {
continue continue
} }
r := integrations[name]
description := r.String() description := r.String()
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 { if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0]) description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
@@ -381,7 +414,25 @@ func selectIntegration() (string, error) {
 		items = append(items, ModelItem{Name: name, Description: description})
 	}
-	return DefaultSingleSelector("Select integration:", items)
+	orderRank := make(map[string]int, len(integrationOrder))
+	for i, name := range integrationOrder {
+		orderRank[name] = i + 1
+	}
+	slices.SortFunc(items, func(a, b ModelItem) int {
+		aRank, bRank := orderRank[a.Name], orderRank[b.Name]
+		if aRank > 0 && bRank > 0 {
+			return aRank - bRank
+		}
+		if aRank > 0 {
+			return 1
+		}
+		if bRank > 0 {
+			return -1
+		}
+		return strings.Compare(a.Name, b.Name)
+	})
+	return DefaultSingleSelector("Select integration:", items, "")
 }
 
 // selectModelsWithSelectors lets the user select models for an integration using provided selectors.
@@ -439,7 +490,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
 	if _, ok := r.(AliasConfigurer); ok {
 		prompt = fmt.Sprintf("Select Primary model for %s:", r)
 	}
-	model, err := single(prompt, items)
+	model, err := single(prompt, items, current)
 	if err != nil {
 		return nil, err
 	}
@@ -812,10 +863,12 @@ Without arguments, this is equivalent to running 'ollama' directly.
 Supported integrations:
   claude      Claude Code
+  cline       Cline
   codex       Codex
   droid       Droid
   opencode    OpenCode
   openclaw    OpenClaw (aliases: clawdbot, moltbot)
+  pi          Pi
 
 Examples:
   ollama launch
@@ -915,11 +968,9 @@ Examples:
 	}
 
 	// Validate saved model still exists
-	cloudCleared := false
 	if model != "" && modelFlag == "" {
 		if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled && isCloudModelName(model) {
 			model = ""
-			cloudCleared = true
 		} else if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
 			fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
 			if err := ShowOrPull(cmd.Context(), client, model); err != nil {
@@ -928,18 +979,16 @@ Examples:
 		}
 	}
 
-	// If no valid model or --config flag, show picker
-	if model == "" || configFlag {
-		aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag || cloudCleared)
-		if errors.Is(err, errCancelled) {
-			return nil
-		}
-		if err != nil {
-			return err
-		}
-		model = aliases["primary"]
-		existingAliases = aliases
-	}
+	// Show picker so user can change model (skip when --model flag provided)
+	aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, modelFlag == "")
+	if errors.Is(err, errCancelled) {
+		return nil
+	}
+	if err != nil {
+		return err
+	}
+	model = aliases["primary"]
+	existingAliases = aliases
 
 	// Ensure cloud models are authenticated
 	if isCloudModel(cmd.Context(), client, model) {
@@ -1001,27 +1050,13 @@ Examples:
 				return err
 			}
 		}
-	} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
-		savedModels := filterDisabledCloudModels(saved.Models)
-		if len(savedModels) != len(saved.Models) {
-			_ = SaveIntegration(name, savedModels)
-		}
-		if len(savedModels) == 0 {
-			// All saved models were cloud — fall through to picker
-			models, err = selectModels(cmd.Context(), name, "")
-			if errors.Is(err, errCancelled) {
-				return nil
-			}
-			if err != nil {
-				return err
-			}
-		} else {
-			models = savedModels
-			return runIntegration(name, models[0], passArgs)
-		}
 	} else {
+		current := ""
+		if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 {
+			current = saved.Models[0]
+		}
 		var err error
-		models, err = selectModels(cmd.Context(), name, "")
+		models, err = selectModels(cmd.Context(), name, current)
 		if errors.Is(err, errCancelled) {
 			return nil
 		}

View File

@@ -1248,10 +1248,26 @@ func TestListIntegrationInfos(t *testing.T) {
 	}
 	})
 
-	t.Run("sorted by name", func(t *testing.T) {
+	t.Run("sorted with custom order at end", func(t *testing.T) {
+		// integrationOrder entries (cline, opencode) should appear last, in that order.
+		// All other entries should be sorted alphabetically before them.
+		orderRank := make(map[string]int)
+		for i, name := range integrationOrder {
+			orderRank[name] = i + 1
+		}
 		for i := 1; i < len(infos); i++ {
-			if infos[i-1].Name >= infos[i].Name {
-				t.Errorf("not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
+			aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
+			switch {
+			case aRank == 0 && bRank == 0:
+				if infos[i-1].Name >= infos[i].Name {
+					t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
+				}
+			case aRank > 0 && bRank == 0:
+				t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
+			case aRank > 0 && bRank > 0:
+				if aRank >= bRank {
+					t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
+				}
 			}
 		}
 	})

View File

@@ -365,14 +365,27 @@ func (m selectorModel) View() string {
 	return s
 }
 
-func SelectSingle(title string, items []SelectItem) (string, error) {
+// cursorForCurrent returns the item index matching current, or 0 if not found.
+func cursorForCurrent(items []SelectItem, current string) int {
+	if current != "" {
+		for i, item := range items {
+			if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
+				return i
+			}
+		}
+	}
+	return 0
+}
+
+func SelectSingle(title string, items []SelectItem, current string) (string, error) {
 	if len(items) == 0 {
 		return "", fmt.Errorf("no items to select from")
 	}
 	m := selectorModel{
 		title:  title,
 		items:  items,
+		cursor: cursorForCurrent(items, current),
 	}
 	p := tea.NewProgram(m)

View File

@@ -382,6 +382,42 @@ func TestUpdateNavigation_Backspace(t *testing.T) {
 	}
 }
 
+// --- cursorForCurrent ---
+
+func TestCursorForCurrent(t *testing.T) {
+	testItems := []SelectItem{
+		{Name: "llama3.2", Recommended: true},
+		{Name: "qwen3:8b", Recommended: true},
+		{Name: "gemma3:latest"},
+		{Name: "deepseek-r1"},
+		{Name: "glm-5:cloud"},
+	}
+	tests := []struct {
+		name    string
+		current string
+		want    int
+	}{
+		{"empty current", "", 0},
+		{"exact match", "qwen3:8b", 1},
+		{"no match returns 0", "nonexistent", 0},
+		{"bare name matches with :latest suffix", "gemma3", 2},
+		{"full tag matches bare item", "llama3.2:latest", 0},
+		{"cloud model exact match", "glm-5:cloud", 4},
+		{"cloud model bare name", "glm-5", 4},
+		{"recommended item exact match", "llama3.2", 0},
+		{"recommended item with tag", "qwen3", 1},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := cursorForCurrent(testItems, tt.current); got != tt.want {
+				t.Errorf("cursorForCurrent(%q) = %d, want %d", tt.current, got, tt.want)
+			}
+		})
+	}
+}
+
 // --- ReorderItems ---
 func TestReorderItems(t *testing.T) {

View File

@@ -15,6 +15,7 @@
 - [Push a Model](#push-a-model)
 - [Generate Embeddings](#generate-embeddings)
 - [List Running Models](#list-running-models)
+- [Usage](#usage)
 - [Version](#version)
 - [Experimental: Image Generation](#image-generation-experimental)
@@ -1854,6 +1855,53 @@ curl http://localhost:11434/api/embeddings -d '{
 }
 ```
 
+## Usage
+
+```
+GET /api/usage
+```
+
+Show aggregate usage statistics per model since the server started. All timestamps are UTC in RFC 3339 format.
+
+### Examples
+
+#### Request
+
+```shell
+curl http://localhost:11434/api/usage
+```
+
+#### Response
+
+```json
+{
+  "start": "2025-01-27T20:00:00Z",
+  "usage": [
+    {
+      "model": "llama3.2",
+      "requests": 5,
+      "prompt_tokens": 130,
+      "completion_tokens": 890
+    },
+    {
+      "model": "deepseek-r1",
+      "requests": 2,
+      "prompt_tokens": 48,
+      "completion_tokens": 312
+    }
+  ]
+}
+```
+
+#### Response fields
+
+- `start`: when the server started tracking usage (UTC, RFC 3339)
+- `usage`: list of per-model usage statistics
+  - `model`: model name
+  - `requests`: total number of completed requests
+  - `prompt_tokens`: total prompt tokens evaluated
+  - `completion_tokens`: total completion tokens generated
+
 ## Version
 
 ```
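For reference, the endpoint can also be consumed from Go. A minimal sketch, assuming a local server on the default port and decoding into the `UsageResponse` type added in this change:

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/ollama/ollama/api"
)

func main() {
	// Query the new usage endpoint on a locally running server.
	resp, err := http.Get("http://localhost:11434/api/usage")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var usage api.UsageResponse
	if err := json.NewDecoder(resp.Body).Decode(&usage); err != nil {
		panic(err)
	}

	fmt.Println("tracking since:", usage.Start)
	for _, m := range usage.Usage {
		fmt.Printf("%s: %d requests, %d prompt tokens, %d completion tokens\n",
			m.Model, m.Requests, m.PromptTokens, m.CompletionTokens)
	}
}
```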

View File

@@ -106,20 +106,23 @@
"group": "Integrations", "group": "Integrations",
"pages": [ "pages": [
"/integrations/index", "/integrations/index",
{
"group": "Assistants",
"expanded": true,
"pages": [
"/integrations/openclaw"
]
},
{ {
"group": "Coding", "group": "Coding",
"expanded": true,
"pages": [ "pages": [
"/integrations/claude-code", "/integrations/claude-code",
"/integrations/codex", "/integrations/codex",
"/integrations/opencode", "/integrations/opencode",
"/integrations/droid", "/integrations/droid",
"/integrations/goose" "/integrations/goose",
] "/integrations/pi"
},
{
"group": "Assistants",
"pages": [
"/integrations/openclaw"
] ]
}, },
{ {

View File

@@ -13,6 +13,7 @@ Coding assistants that can read, modify, and execute code in your projects.
 - [OpenCode](/integrations/opencode)
 - [Droid](/integrations/droid)
 - [Goose](/integrations/goose)
+- [Pi](/integrations/pi)
 
 ## Assistants

57
docs/integrations/pi.mdx Normal file
View File

@@ -0,0 +1,57 @@
---
title: Pi
---
Pi is a minimal AI agent toolkit with plugin support.
## Install
Install [Pi](https://github.com/badlogic/pi-mono):
```bash
npm install -g @mariozechner/pi-coding-agent
```
## Usage with Ollama
### Quick setup
```bash
ollama launch pi
```
To configure without launching:
```shell
ollama launch pi --config
```
### Manual setup
Add a configuration block to `~/.pi/agent/models.json`:
```json
{
"providers": {
"ollama": {
"baseUrl": "http://localhost:11434/v1",
"api": "openai-completions",
"apiKey": "ollama",
"models": [
{
"id": "qwen3-coder"
}
]
}
}
}
```
Update `~/.pi/agent/settings.json` to set the default provider:
```json
{
"defaultProvider": "ollama",
"defaultModel": "qwen3-coder"
}
```

View File

@@ -27,9 +27,17 @@ The menu provides quick access to:
 - **Launch tools** - Claude Code, Codex, OpenClaw, and more
 - **Additional integrations** - Available under "More..."
 
+## Assistants
+
+Launch [OpenClaw](/integrations/openclaw), a personal AI with 100+ skills:
+
+```sh
+ollama launch openclaw
+```
+
 ## Coding
 
-Launch coding tools with Ollama models:
+Launch [Claude Code](/integrations/claude-code) and other coding tools with Ollama models:
 
 ```sh
 ollama launch claude

View File

@@ -45,6 +45,10 @@ func ParserForName(name string) Parser {
 	var p Parser
 	switch name {
+	case "qwen3":
+		p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
+	case "qwen3-thinking":
+		p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
 	case "qwen3-coder":
 		p = &Qwen3CoderParser{}
 	case "qwen3-vl-instruct":

View File

@@ -54,6 +54,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
 		name string
 	}{
 		{"passthrough"},
+		{"qwen3"},
+		{"qwen3-thinking"},
 		{"qwen3-coder"},
 		{"harmony"},
 	}

335
model/parsers/qwen3.go Normal file
View File

@@ -0,0 +1,335 @@
package parsers
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type qwen3ParserState int
const (
qwen3ParserStateLookingForThinkingOpen qwen3ParserState = iota
qwen3ParserStateThinkingStartedEatingWhitespace
qwen3ParserStateCollectingThinking
qwen3ParserStateThinkingDoneEatingWhitespace
qwen3ParserStateCollectingContent
qwen3ParserStateToolStartedEatingWhitespace
qwen3ParserStateCollectingToolContent
)
const (
qwen3ThinkingOpenTag = "<think>"
qwen3ThinkingCloseTag = "</think>"
qwen3ToolOpenTag = "<tool_call>"
qwen3ToolCloseTag = "</tool_call>"
)
// Qwen3Parser parses Qwen3 output to extract thinking and tool calls.
// Qwen3 prompts end with <think> when thinking is enabled, so output begins
// with thinking content directly (without an opening tag).
type Qwen3Parser struct {
state qwen3ParserState
buffer strings.Builder
tools []api.Tool
hasThinkingSupport bool
defaultThinking bool
maybeThinkingOpenAtBOL bool
}
func (p *Qwen3Parser) HasToolSupport() bool {
return true
}
func (p *Qwen3Parser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.buffer.Reset()
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
if thinkValue == nil {
thinkingEnabled = p.defaultThinking
}
if p.hasThinkingSupport && thinkingEnabled {
p.state = qwen3ParserStateCollectingThinking
p.maybeThinkingOpenAtBOL = true
} else {
p.state = qwen3ParserStateCollectingContent
p.maybeThinkingOpenAtBOL = false
}
return tools
}
type qwen3Event interface {
isQwen3Event()
}
type qwen3EventContent struct {
content string
}
func (qwen3EventContent) isQwen3Event() {}
type qwen3EventRawToolCall struct {
raw string
}
func (qwen3EventRawToolCall) isQwen3Event() {}
type qwen3EventThinkingContent struct {
content string
}
func (qwen3EventThinkingContent) isQwen3Event() {}
func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case qwen3EventRawToolCall:
toolCall, err := parseQwen3ToolCall(event, p.tools)
if err != nil {
slog.Warn("qwen3 tool call parsing failed", "error", err)
return "", "", nil, err
}
calls = append(calls, toolCall)
case qwen3EventThinkingContent:
thinkingSb.WriteString(event.content)
case qwen3EventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), calls, nil
}
func (p *Qwen3Parser) parseEvents() []qwen3Event {
var all []qwen3Event
keepLooping := true
for keepLooping {
var events []qwen3Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "qwen3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
func (p *Qwen3Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen3ParserState) ([]qwen3Event, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true
}
func (p *Qwen3Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
return splitAtTag(&p.buffer, tag, trimAfter)
}
func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
var events []qwen3Event
switch p.state {
case qwen3ParserStateLookingForThinkingOpen:
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
p.state = qwen3ParserStateThinkingStartedEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingThinking
}
return events, true
} else if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
return events, false
} else if trimmed == "" {
return events, false
}
p.state = qwen3ParserStateCollectingContent
return events, true
case qwen3ParserStateThinkingStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingThinking)
case qwen3ParserStateCollectingThinking:
acc := p.buffer.String()
// Some qwen3 checkpoints emit an explicit opening <think> tag even
// though the prompt already ended with <think>. Strip exactly one
// leading opening tag if present.
if p.maybeThinkingOpenAtBOL {
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
return events, false
}
p.maybeThinkingOpenAtBOL = false
return events, true
}
if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
return events, false
}
p.maybeThinkingOpenAtBOL = false
}
if strings.Contains(acc, qwen3ThinkingCloseTag) {
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, qwen3EventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = qwen3ParserStateThinkingDoneEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingContent
}
return events, true
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventThinkingContent{content: unambiguous})
}
return events, false
}
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventThinkingContent{content: unambiguous})
}
return events, false
case qwen3ParserStateThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingContent)
case qwen3ParserStateCollectingContent:
acc := p.buffer.String()
if strings.Contains(acc, qwen3ToolOpenTag) {
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
if len(before) > 0 {
events = append(events, qwen3EventContent{content: before})
}
if after == "" {
p.state = qwen3ParserStateToolStartedEatingWhitespace
} else {
p.state = qwen3ParserStateCollectingToolContent
}
return events, true
} else if overlapLen := overlap(acc, qwen3ToolOpenTag); overlapLen > 0 {
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventContent{content: unambiguous})
}
return events, false
}
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwen3EventContent{content: unambiguous})
}
return events, false
case qwen3ParserStateToolStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingToolContent)
case qwen3ParserStateCollectingToolContent:
acc := p.buffer.String()
if strings.Contains(acc, qwen3ToolCloseTag) {
toolContent, _ := p.splitAtTag(qwen3ToolCloseTag, true)
if len(toolContent) == 0 {
slog.Warn("qwen3 tool call closing tag found but no content before it")
}
events = append(events, qwen3EventRawToolCall{raw: toolContent})
p.state = qwen3ParserStateCollectingContent
return events, true
}
return events, false
default:
panic("unreachable")
}
}
func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
var parsed struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse JSON: %w", err)
}
if parsed.Name == "" {
return api.ToolCall{}, fmt.Errorf("empty function name")
}
_ = tools // qwen3 uses direct JSON args and does not require schema coercion here.
toolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: parsed.Name,
Arguments: api.NewToolCallFunctionArguments(),
},
}
for key, value := range parsed.Arguments {
toolCall.Function.Arguments.Set(key, value)
}
return toolCall, nil
}

147
model/parsers/qwen3_test.go Normal file
View File

@@ -0,0 +1,147 @@
package parsers
import (
"testing"
"github.com/ollama/ollama/api"
)
func TestQwen3ParserThinkingEnabled(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("Let me think...</think>Answer.", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "Let me think..." {
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
}
if content != "Answer." {
t.Fatalf("expected content %q, got %q", "Answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserThinkingEnabledWithExplicitOpeningTag(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "Let me think..." {
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
}
if content != "Answer." {
t.Fatalf("expected content %q, got %q", "Answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserThinkingEnabledWithSplitOpeningTag(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
content, thinking, calls, err := parser.Add("<thi", false)
if err != nil {
t.Fatalf("parse failed on first chunk: %v", err)
}
if content != "" || thinking != "" || len(calls) != 0 {
t.Fatalf("expected no output for first chunk, got content=%q thinking=%q calls=%d", content, thinking, len(calls))
}
content, thinking, calls, err = parser.Add("nk>Let me think...</think>Answer.", true)
if err != nil {
t.Fatalf("parse failed on second chunk: %v", err)
}
if thinking != "Let me think..." {
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
}
if content != "Answer." {
t.Fatalf("expected content %q, got %q", "Answer.", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserThinkingDisabled(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
content, thinking, calls, err := parser.Add("Direct answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected no thinking, got %q", thinking)
}
if content != "Direct answer" {
t.Fatalf("expected content %q, got %q", "Direct answer", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserNilThinkDefaultsToContentForInstructParser(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, nil)
content, thinking, calls, err := parser.Add("Direct answer", true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if thinking != "" {
t.Fatalf("expected no thinking, got %q", thinking)
}
if content != "Direct answer" {
t.Fatalf("expected content %q, got %q", "Direct answer", content)
}
if len(calls) != 0 {
t.Fatalf("expected no tool calls, got %d", len(calls))
}
}
func TestQwen3ParserToolCall(t *testing.T) {
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
parser.Init(nil, nil, &api.ThinkValue{Value: false})
input := "<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
content, thinking, calls, err := parser.Add(input, true)
if err != nil {
t.Fatalf("parse failed: %v", err)
}
if content != "" {
t.Fatalf("expected empty content, got %q", content)
}
if thinking != "" {
t.Fatalf("expected empty thinking, got %q", thinking)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "get_weather" {
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
}
location, ok := calls[0].Function.Arguments.Get("location")
if !ok || location != "San Francisco" {
t.Fatalf("expected location %q, got %v", "San Francisco", location)
}
unit, ok := calls[0].Function.Arguments.Get("unit")
if !ok || unit != "celsius" {
t.Fatalf("expected unit %q, got %v", "celsius", unit)
}
}

View File

@@ -91,6 +91,8 @@ type Server struct {
 	aliasesOnce sync.Once
 	aliases     *store
 	aliasesErr  error
+	lowVRAM     bool
+	usage       *UsageTracker
 }
 
 func init() {
@@ -289,6 +291,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		c.Header("Content-Type", contentType)
 		fn := func(resp api.GenerateResponse) error {
+			if resp.Done {
+				s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
+			}
+
 			resp.Model = origModel
 			resp.RemoteModel = m.Config.RemoteModel
 			resp.RemoteHost = m.Config.RemoteHost
@@ -595,6 +601,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			}
 			res.Context = tokens
 		}
+
+		s.usage.Record(req.Model, cr.PromptEvalCount, cr.EvalCount)
 	}
 
 	if builtinParser != nil {
@@ -1622,6 +1630,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.POST("/api/experimental/aliases", s.CreateAliasHandler)
 	r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
+	r.GET("/api/usage", s.UsageHandler)
+
 	// Inference
 	r.GET("/api/ps", s.PsHandler)
 	r.POST("/api/generate", s.GenerateHandler)
@@ -1692,7 +1702,7 @@ func Serve(ln net.Listener) error {
 		}
 	}
 
-	s := &Server{addr: ln.Addr()}
+	s := &Server{addr: ln.Addr(), usage: NewUsageTracker()}
 
 	var rc *ollama.Registry
 	if useClient2 {
@@ -1927,6 +1937,10 @@ func (s *Server) SignoutHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, nil)
 }
 
+func (s *Server) UsageHandler(c *gin.Context) {
+	c.JSON(http.StatusOK, s.usage.Stats())
+}
+
 func (s *Server) PsHandler(c *gin.Context) {
 	models := []api.ProcessModelResponse{}
@@ -2097,6 +2111,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		c.Header("Content-Type", contentType)
 		fn := func(resp api.ChatResponse) error {
+			if resp.Done {
+				s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
+			}
+
 			resp.Model = origModel
 			resp.RemoteModel = m.Config.RemoteModel
 			resp.RemoteHost = m.Config.RemoteHost
@@ -2317,6 +2335,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		res.DoneReason = r.DoneReason.String()
 		res.TotalDuration = time.Since(checkpointStart)
 		res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+
+		s.usage.Record(req.Model, r.PromptEvalCount, r.EvalCount)
 	}
 
 	if builtinParser != nil {
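The `UsageTracker` type itself is not included in this compare view. A minimal sketch consistent with the `NewUsageTracker`, `Record`, and `Stats` calls above (the field names and locking strategy are assumptions, not the committed implementation):

```go
package server

import (
	"sync"
	"time"

	"github.com/ollama/ollama/api"
)

// UsageTracker accumulates per-model request and token counts.
// Hypothetical sketch; the real implementation may differ.
type UsageTracker struct {
	mu     sync.Mutex
	start  time.Time
	models map[string]*api.ModelUsageData
}

func NewUsageTracker() *UsageTracker {
	return &UsageTracker{
		start:  time.Now().UTC(),
		models: make(map[string]*api.ModelUsageData),
	}
}

// Record adds the token counts of one completed request to the model's totals.
func (t *UsageTracker) Record(model string, promptTokens, completionTokens int) {
	t.mu.Lock()
	defer t.mu.Unlock()
	u, ok := t.models[model]
	if !ok {
		u = &api.ModelUsageData{Model: model}
		t.models[model] = u
	}
	u.Requests++
	u.PromptTokens += int64(promptTokens)
	u.CompletionTokens += int64(completionTokens)
}

// Stats snapshots the counters for the /api/usage handler.
func (t *UsageTracker) Stats() api.UsageResponse {
	t.mu.Lock()
	defer t.mu.Unlock()
	resp := api.UsageResponse{Start: t.start, Usage: make([]api.ModelUsageData, 0, len(t.models))}
	for _, u := range t.models {
		resp.Usage = append(resp.Usage, *u)
	}
	return resp
}
```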

View File

@@ -30,6 +30,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
 	}
 
 	s := Server{
+		usage: NewUsageTracker(),
 		sched: &Scheduler{
 			pendingReqCh:  make(chan *LlmRequest, 1),
 			finishedReqCh: make(chan *LlmRequest, 1),
@@ -224,6 +225,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
 	}
 
 	s := Server{
+		usage: NewUsageTracker(),
 		sched: &Scheduler{
 			pendingReqCh:  make(chan *LlmRequest, 1),
 			finishedReqCh: make(chan *LlmRequest, 1),

View File

@@ -35,6 +35,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
 	}
 
 	s := Server{
+		usage: NewUsageTracker(),
 		sched: &Scheduler{
 			pendingReqCh:  make(chan *LlmRequest, 1),
 			finishedReqCh: make(chan *LlmRequest, 1),
@@ -220,6 +221,7 @@ func TestGenerateWithDebugRenderOnly(t *testing.T) {
 	}
 
 	s := Server{
+		usage: NewUsageTracker(),
 		sched: &Scheduler{
 			pendingReqCh:  make(chan *LlmRequest, 1),
 			finishedReqCh: make(chan *LlmRequest, 1),

View File

@@ -88,19 +88,39 @@ func TestGenerateChatRemote(t *testing.T) {
 		if r.Method != http.MethodPost {
 			t.Errorf("Expected POST request, got %s", r.Method)
 		}
-		if r.URL.Path != "/api/chat" {
-			t.Errorf("Expected path '/api/chat', got %s", r.URL.Path)
-		}
 		w.WriteHeader(http.StatusOK)
 		w.Header().Set("Content-Type", "application/json")
-		resp := api.ChatResponse{
-			Model:      "test",
-			Done:       true,
-			DoneReason: "load",
-		}
-		if err := json.NewEncoder(w).Encode(&resp); err != nil {
-			t.Fatal(err)
+
+		switch r.URL.Path {
+		case "/api/chat":
+			resp := api.ChatResponse{
+				Model:      "test",
+				Done:       true,
+				DoneReason: "load",
+				Metrics: api.Metrics{
+					PromptEvalCount: 10,
+					EvalCount:       20,
+				},
+			}
+			if err := json.NewEncoder(w).Encode(&resp); err != nil {
+				t.Fatal(err)
+			}
+		case "/api/generate":
+			resp := api.GenerateResponse{
+				Model:      "test",
+				Done:       true,
+				DoneReason: "stop",
+				Metrics: api.Metrics{
+					PromptEvalCount: 5,
+					EvalCount:       15,
+				},
+			}
+			if err := json.NewEncoder(w).Encode(&resp); err != nil {
+				t.Fatal(err)
+			}
+		default:
+			t.Errorf("unexpected path %s", r.URL.Path)
 		}
 	}))
 	defer rs.Close()
@@ -111,7 +131,7 @@ func TestGenerateChatRemote(t *testing.T) {
 	}
 
 	t.Setenv("OLLAMA_REMOTES", p.Hostname())
 
-	s := Server{}
+	s := Server{usage: NewUsageTracker()}
 
 	w := createRequest(t, s.CreateHandler, api.CreateRequest{
 		Model:      "test-cloud",
 		RemoteHost: rs.URL,
@@ -159,6 +179,61 @@ func TestGenerateChatRemote(t *testing.T) {
t.Errorf("expected done reason load, got %s", actual.DoneReason) t.Errorf("expected done reason load, got %s", actual.DoneReason)
} }
}) })
t.Run("remote chat usage tracking", func(t *testing.T) {
stats := s.usage.Stats()
found := false
for _, m := range stats.Usage {
if m.Model == "test-cloud" {
found = true
if m.Requests != 1 {
t.Errorf("expected 1 request, got %d", m.Requests)
}
if m.PromptTokens != 10 {
t.Errorf("expected 10 prompt tokens, got %d", m.PromptTokens)
}
if m.CompletionTokens != 20 {
t.Errorf("expected 20 completion tokens, got %d", m.CompletionTokens)
}
}
}
if !found {
t.Error("expected usage entry for test-cloud")
}
})
t.Run("remote generate usage tracking", func(t *testing.T) {
// Reset the tracker for a clean test
s.usage = NewUsageTracker()
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-cloud",
Prompt: "hello",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
stats := s.usage.Stats()
found := false
for _, m := range stats.Usage {
if m.Model == "test-cloud" {
found = true
if m.Requests != 1 {
t.Errorf("expected 1 request, got %d", m.Requests)
}
if m.PromptTokens != 5 {
t.Errorf("expected 5 prompt tokens, got %d", m.PromptTokens)
}
if m.CompletionTokens != 15 {
t.Errorf("expected 15 completion tokens, got %d", m.CompletionTokens)
}
}
}
if !found {
t.Error("expected usage entry for test-cloud")
}
})
} }
func TestGenerateChat(t *testing.T) { func TestGenerateChat(t *testing.T) {
@@ -177,6 +252,7 @@ func TestGenerateChat(t *testing.T) {
 	}
 
 	s := Server{
+		usage: NewUsageTracker(),
 		sched: &Scheduler{
 			pendingReqCh:  make(chan *LlmRequest, 1),
 			finishedReqCh: make(chan *LlmRequest, 1),
@@ -894,6 +970,7 @@ func TestGenerate(t *testing.T) {
 	}
 
 	s := Server{
+		usage: NewUsageTracker(),
 		sched: &Scheduler{
 			pendingReqCh:  make(chan *LlmRequest, 1),
 			finishedReqCh: make(chan *LlmRequest, 1),
@@ -1378,6 +1455,7 @@ func TestGenerateLogprobs(t *testing.T) {
 	}
 
 	s := &Server{
+		usage: NewUsageTracker(),
 		sched: &Scheduler{
 			pendingReqCh:  make(chan *LlmRequest, 1),
 			finishedReqCh: make(chan *LlmRequest, 1),
@@ -1558,6 +1636,7 @@ func TestChatLogprobs(t *testing.T) {
 	}
 
 	s := &Server{
+		usage: NewUsageTracker(),
 		sched: &Scheduler{
 			pendingReqCh:  make(chan *LlmRequest, 1),
 			finishedReqCh: make(chan *LlmRequest, 1),
@@ -1668,6 +1747,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
 	}
 
 	s := &Server{
+		usage: NewUsageTracker(),
 		sched: &Scheduler{
 			pendingReqCh:  make(chan *LlmRequest, 1),
 			finishedReqCh: make(chan *LlmRequest, 1),
@@ -2114,6 +2194,7 @@ func TestGenerateUnload(t *testing.T) {
 	var loadFnCalled bool
 
 	s := Server{
+		usage: NewUsageTracker(),
 		sched: &Scheduler{
 			pendingReqCh:  make(chan *LlmRequest, 1),
 			finishedReqCh: make(chan *LlmRequest, 1),
@@ -2215,6 +2296,7 @@ func TestGenerateWithImages(t *testing.T) {
 	}
 
 	s := Server{
+		usage: NewUsageTracker(),
 		sched: &Scheduler{
 			pendingReqCh:  make(chan *LlmRequest, 1),
 			finishedReqCh: make(chan *LlmRequest, 1),
@@ -2371,30 +2453,6 @@ func TestImageGenerateStreamFalse(t *testing.T) {
 		return nil
 	}
 
-	opts := api.DefaultOptions()
-	s := Server{
-		sched: &Scheduler{
-			pendingReqCh:  make(chan *LlmRequest, 1),
-			finishedReqCh: make(chan *LlmRequest, 1),
-			expiredCh:     make(chan *runnerRef, 1),
-			unloadedCh:    make(chan any, 1),
-			loaded: map[string]*runnerRef{
-				"": {
-					llama:       &mock,
-					Options:     &opts,
-					model:       &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
-					isImagegen:  true,
-					numParallel: 1,
-				},
-			},
-			newServerFn:     newMockServer(&mock),
-			getGpuFn:        getGpuFn,
-			getSystemInfoFn: getSystemInfoFn,
-		},
-	}
-	go s.sched.Run(t.Context())
-
 	// Create model manifest with image capability
 	n := model.ParseName("test-image")
 	cfg := model.ConfigV2{Capabilities: []string{"image"}}
@@ -2410,6 +2468,36 @@ func TestImageGenerateStreamFalse(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	loadedModel, err := GetModel("test-image")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	opts := api.DefaultOptions()
+	s := Server{
+		usage: NewUsageTracker(),
+		sched: &Scheduler{
+			pendingReqCh:  make(chan *LlmRequest, 1),
+			finishedReqCh: make(chan *LlmRequest, 1),
+			expiredCh:     make(chan *runnerRef, 1),
+			unloadedCh:    make(chan any, 1),
+			loaded: map[string]*runnerRef{
+				schedulerModelKey(loadedModel): {
+					llama:       &mock,
+					Options:     &opts,
+					model:       loadedModel,
+					isImagegen:  true,
+					numParallel: 1,
+				},
+			},
+			newServerFn:     newMockServer(&mock),
+			getGpuFn:        getGpuFn,
+			getSystemInfoFn: getSystemInfoFn,
+		},
+	}
+	go s.sched.Run(t.Context())
+
 	streamFalse := false
 	w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
 		Model: "test-image",

View File

@@ -255,6 +255,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
} }
s := Server{ s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{ sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1), pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1), finishedReqCh: make(chan *LlmRequest, 1),
@@ -406,6 +407,7 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
} }
s := Server{ s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{ sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1), pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1), finishedReqCh: make(chan *LlmRequest, 1),
@@ -588,6 +590,7 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
} }
s := Server{ s := Server{
usage: NewUsageTracker(),
sched: &Scheduler{ sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1), pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1), finishedReqCh: make(chan *LlmRequest, 1),

View File

@@ -83,6 +83,28 @@ func InitScheduler(ctx context.Context) *Scheduler {
return sched return sched
} }
// schedulerModelKey returns the scheduler map key for a model.
// GGUF-backed models use ModelPath; safetensors/image models without a
// ModelPath use manifest digest so distinct models don't collide.
func schedulerModelKey(m *Model) string {
if m == nil {
return ""
}
if m.ModelPath != "" {
return m.ModelPath
}
if m.Digest != "" {
return "digest:" + m.Digest
}
if m.Name != "" {
return "name:" + m.Name
}
if m.ShortName != "" {
return "short:" + m.ShortName
}
return ""
}
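As a sketch of the fallback order, using a stand-in Model struct (the real type lives in the server package and has more fields):

package main

import "fmt"

// Model stands in for server.Model; only the key-relevant fields are shown.
type Model struct {
	ModelPath, Digest, Name, ShortName string
}

// key mirrors schedulerModelKey's fallback chain.
func key(m *Model) string {
	switch {
	case m == nil:
		return ""
	case m.ModelPath != "":
		return m.ModelPath
	case m.Digest != "":
		return "digest:" + m.Digest
	case m.Name != "":
		return "name:" + m.Name
	case m.ShortName != "":
		return "short:" + m.ShortName
	}
	return ""
}

func main() {
	fmt.Println(key(&Model{ModelPath: "/blobs/a.gguf", Digest: "sha-a"})) // /blobs/a.gguf
	fmt.Println(key(&Model{Digest: "sha-b"}))                            // digest:sha-b
	fmt.Println(key(&Model{Name: "safetensors-c"}))                      // name:safetensors-c
}

Two safetensors models that share neither a path nor a digest therefore map to distinct keys, which is what the scheduling bugfix relies on.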
// context must be canceled to decrement ref count and release the runner // context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) { func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 { if opts.NumCtx < 4 {
@@ -104,8 +126,9 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
useImagegen: useImagegen, useImagegen: useImagegen,
} }
key := schedulerModelKey(req.model)
s.loadedMu.Lock() s.loadedMu.Lock()
runner := s.loaded[req.model.ModelPath] runner := s.loaded[key]
s.loadedMu.Unlock() s.loadedMu.Unlock()
if runner != nil && !runner.needsReload(c, req) { if runner != nil && !runner.needsReload(c, req) {
req.useLoadedRunner(runner, s.finishedReqCh) req.useLoadedRunner(runner, s.finishedReqCh)
@@ -151,8 +174,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
for { for {
var runnerToExpire *runnerRef var runnerToExpire *runnerRef
pendingKey := schedulerModelKey(pending.model)
s.loadedMu.Lock() s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath] runner := s.loaded[pendingKey]
loadedCount := len(s.loaded) loadedCount := len(s.loaded)
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded)) runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
for _, r := range s.loaded { for _, r := range s.loaded {
@@ -166,7 +190,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
runnerToExpire = runner runnerToExpire = runner
} else { } else {
// Runner is usable, return it // Runner is usable, return it
logutil.Trace("using existing loaded runner", "model", pending.model.ModelPath) logutil.Trace("using existing loaded runner", "model", pendingKey)
pending.useLoadedRunner(runner, s.finishedReqCh) pending.useLoadedRunner(runner, s.finishedReqCh)
break break
} }
@@ -292,11 +316,12 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
slog.Debug("shutting down scheduler completed loop") slog.Debug("shutting down scheduler completed loop")
return return
case finished := <-s.finishedReqCh: case finished := <-s.finishedReqCh:
finishedKey := schedulerModelKey(finished.model)
s.loadedMu.Lock() s.loadedMu.Lock()
runner := s.loaded[finished.model.ModelPath] runner := s.loaded[finishedKey]
s.loadedMu.Unlock() s.loadedMu.Unlock()
if runner == nil { if runner == nil {
slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath) slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
continue continue
} }
runner.refMu.Lock() runner.refMu.Lock()
@@ -347,7 +372,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
s.loadedMu.Lock() s.loadedMu.Lock()
slog.Debug("got lock to unload expired event", "runner", runner) slog.Debug("got lock to unload expired event", "runner", runner)
runnerToUnload := s.loaded[runner.modelPath] runnerToUnload := s.loaded[runner.modelKey]
if runnerToUnload == nil { if runnerToUnload == nil {
// If runnerToUnload is nil, we already processed an event and // If runnerToUnload is nil, we already processed an event and
// unloaded it. This double unload can happen if the initial // unloaded it. This double unload can happen if the initial
@@ -376,7 +401,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
} }
finished := s.waitForVRAMRecovery(runner, runnersSnapshot) finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
runner.unload() runner.unload()
delete(s.loaded, runner.modelPath) delete(s.loaded, runner.modelKey)
s.loadedMu.Unlock() s.loadedMu.Unlock()
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner) slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
<-finished <-finished
@@ -514,6 +539,7 @@ iGPUScan:
runner := &runnerRef{ runner := &runnerRef{
model: req.model, model: req.model,
modelPath: req.model.ModelPath, modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: llama, llama: llama,
Options: &req.opts, Options: &req.opts,
sessionDuration: sessionDuration, sessionDuration: sessionDuration,
@@ -528,7 +554,7 @@ iGPUScan:
runner.refMu.Lock() // hold lock until running or aborted runner.refMu.Lock() // hold lock until running or aborted
s.loadedMu.Lock() s.loadedMu.Lock()
if oldRunner, ok := s.loaded[req.model.ModelPath]; ok { if oldRunner, ok := s.loaded[runner.modelKey]; ok {
// Shouldn't happen, but safeguard against leaking a runner // Shouldn't happen, but safeguard against leaking a runner
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner) slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
oldRunner.refMu.Lock() oldRunner.refMu.Lock()
@@ -536,7 +562,7 @@ iGPUScan:
oldRunner.refMu.Unlock() oldRunner.refMu.Unlock()
} }
s.activeLoading = nil s.activeLoading = nil
s.loaded[req.model.ModelPath] = runner s.loaded[runner.modelKey] = runner
slog.Info("loaded runners", "count", len(s.loaded)) slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock() s.loadedMu.Unlock()
@@ -596,6 +622,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
runner := &runnerRef{ runner := &runnerRef{
model: req.model, model: req.model,
modelPath: req.model.ModelPath, modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: server, llama: server,
Options: &req.opts, Options: &req.opts,
loading: false, loading: false,
@@ -606,7 +633,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
} }
s.loadedMu.Lock() s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner s.loaded[runner.modelKey] = runner
s.loadedMu.Unlock() s.loadedMu.Unlock()
// Set up expiration timer // Set up expiration timer
@@ -684,6 +711,7 @@ type runnerRef struct {
model *Model model *Model
modelPath string modelPath string
modelKey string
numParallel int numParallel int
*api.Options *api.Options
} }
@@ -703,7 +731,7 @@ func (runner *runnerRef) unload() {
} }
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool { func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
slog.Debug("evaluating already loaded", "model", req.model.ModelPath) slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
runner.refMu.Lock() runner.refMu.Lock()
defer runner.refMu.Unlock() defer runner.refMu.Unlock()
@@ -814,6 +842,10 @@ func (runner *runnerRef) LogValue() slog.Value {
if runner == nil { if runner == nil {
return slog.StringValue("nil") return slog.StringValue("nil")
} }
modelID := runner.modelPath
if modelID == "" {
modelID = runner.modelKey
}
attrs := []slog.Attr{} attrs := []slog.Attr{}
if runner.model != nil { if runner.model != nil {
attrs = append(attrs, slog.String("name", runner.model.Name)) attrs = append(attrs, slog.String("name", runner.model.Name))
@@ -828,7 +860,7 @@ func (runner *runnerRef) LogValue() slog.Value {
slog.String("vram", format.HumanBytes2(runner.vramSize)), slog.String("vram", format.HumanBytes2(runner.vramSize)),
slog.Int("parallel", runner.numParallel), slog.Int("parallel", runner.numParallel),
slog.Int("pid", runner.pid), slog.Int("pid", runner.pid),
slog.String("model", runner.modelPath), slog.String("model", modelID),
) )
if runner.Options != nil { if runner.Options != nil {
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx)) attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
@@ -873,8 +905,16 @@ func (a ByDurationAndName) Less(i, j int) bool {
if d1 != d2 { if d1 != d2 {
return d1 < d2 return d1 < d2
} }
// Secondary sort by model path lex order // Secondary sort by model key/path lex order
return a[i].modelPath < a[j].modelPath n1 := a[i].modelPath
if n1 == "" {
n1 = a[i].modelKey
}
n2 := a[j].modelPath
if n2 == "" {
n2 = a[j].modelKey
}
return n1 < n2
} }
// TODO - future consideration to pick runners based on size // TODO - future consideration to pick runners based on size
@@ -934,8 +974,9 @@ func (s *Scheduler) unloadAllRunners() {
} }
func (s *Scheduler) expireRunner(model *Model) { func (s *Scheduler) expireRunner(model *Model) {
modelKey := schedulerModelKey(model)
s.loadedMu.Lock() s.loadedMu.Lock()
runner, ok := s.loaded[model.ModelPath] runner, ok := s.loaded[modelKey]
s.loadedMu.Unlock() s.loadedMu.Unlock()
if ok { if ok {
runner.refMu.Lock() runner.refMu.Lock()

View File

@@ -448,6 +448,71 @@ func TestSchedGetRunner(t *testing.T) {
b.ctxDone() b.ctxDone()
} }
func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
s := InitScheduler(ctx)
opts := api.DefaultOptions()
opts.NumCtx = 4
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
loadedRunner := &runnerRef{
model: loadedModel,
modelKey: schedulerModelKey(loadedModel),
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
Options: &opts,
numParallel: 1,
}
s.loadedMu.Lock()
s.loaded[loadedRunner.modelKey] = loadedRunner
s.loadedMu.Unlock()
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
require.Empty(t, successCh)
require.Empty(t, errCh)
require.Len(t, s.pendingReqCh, 1)
}
func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
defer done()
s := InitScheduler(ctx)
opts := api.DefaultOptions()
opts.NumCtx = 4
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
loadedRunner := &runnerRef{
model: loadedModel,
modelKey: schedulerModelKey(loadedModel),
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
Options: &opts,
numParallel: 1,
}
s.loadedMu.Lock()
s.loaded[loadedRunner.modelKey] = loadedRunner
s.loadedMu.Unlock()
reqCtx, cancelReq := context.WithCancel(ctx)
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
cancelReq()
select {
case runner := <-successCh:
require.Equal(t, loadedRunner, runner)
default:
t.Fatal("expected existing runner to be reused")
}
require.Empty(t, errCh)
require.Empty(t, s.pendingReqCh)
}
func TestSchedExpireRunner(t *testing.T) { func TestSchedExpireRunner(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond) ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
defer done() defer done()

62
server/usage.go Normal file
View File

@@ -0,0 +1,62 @@
package server
import (
"sync"
"time"
"github.com/ollama/ollama/api"
)
type ModelUsage struct {
Requests int64
PromptTokens int64
CompletionTokens int64
}
type UsageTracker struct {
mu sync.Mutex
start time.Time
models map[string]*ModelUsage
}
func NewUsageTracker() *UsageTracker {
return &UsageTracker{
start: time.Now().UTC(),
models: make(map[string]*ModelUsage),
}
}
func (u *UsageTracker) Record(model string, promptTokens, completionTokens int) {
u.mu.Lock()
defer u.mu.Unlock()
m, ok := u.models[model]
if !ok {
m = &ModelUsage{}
u.models[model] = m
}
m.Requests++
m.PromptTokens += int64(promptTokens)
m.CompletionTokens += int64(completionTokens)
}
func (u *UsageTracker) Stats() api.UsageResponse {
u.mu.Lock()
defer u.mu.Unlock()
byModel := make([]api.ModelUsageData, 0, len(u.models))
for model, usage := range u.models {
byModel = append(byModel, api.ModelUsageData{
Model: model,
Requests: usage.Requests,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
})
}
return api.UsageResponse{
Start: u.start,
Usage: byModel,
}
}
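A minimal client sketch for the new endpoint, assuming a local server on the default port:

package main

import (
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/ollama/ollama/api"
)

func main() {
	resp, err := http.Get("http://localhost:11434/api/usage")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var usage api.UsageResponse
	if err := json.NewDecoder(resp.Body).Decode(&usage); err != nil {
		panic(err)
	}
	fmt.Println("tracking since:", usage.Start)
	for _, m := range usage.Usage {
		fmt.Printf("%s: %d requests, %d prompt + %d completion tokens\n",
			m.Model, m.Requests, m.PromptTokens, m.CompletionTokens)
	}
}

Counters reset when the server restarts, since the tracker lives in process memory.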

136
server/usage_test.go Normal file
View File

@@ -0,0 +1,136 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
)
func TestUsageTrackerRecord(t *testing.T) {
tracker := NewUsageTracker()
tracker.Record("model-a", 10, 20)
tracker.Record("model-a", 5, 15)
tracker.Record("model-b", 100, 200)
stats := tracker.Stats()
if len(stats.Usage) != 2 {
t.Fatalf("expected 2 models, got %d", len(stats.Usage))
}
lookup := make(map[string]api.ModelUsageData)
for _, m := range stats.Usage {
lookup[m.Model] = m
}
a := lookup["model-a"]
if a.Requests != 2 {
t.Errorf("model-a requests: expected 2, got %d", a.Requests)
}
if a.PromptTokens != 15 {
t.Errorf("model-a prompt tokens: expected 15, got %d", a.PromptTokens)
}
if a.CompletionTokens != 35 {
t.Errorf("model-a completion tokens: expected 35, got %d", a.CompletionTokens)
}
b := lookup["model-b"]
if b.Requests != 1 {
t.Errorf("model-b requests: expected 1, got %d", b.Requests)
}
if b.PromptTokens != 100 {
t.Errorf("model-b prompt tokens: expected 100, got %d", b.PromptTokens)
}
if b.CompletionTokens != 200 {
t.Errorf("model-b completion tokens: expected 200, got %d", b.CompletionTokens)
}
}
func TestUsageTrackerConcurrent(t *testing.T) {
tracker := NewUsageTracker()
var wg sync.WaitGroup
for range 100 {
wg.Add(1)
go func() {
defer wg.Done()
tracker.Record("model-a", 1, 2)
}()
}
wg.Wait()
stats := tracker.Stats()
if len(stats.Usage) != 1 {
t.Fatalf("expected 1 model, got %d", len(stats.Usage))
}
m := stats.Usage[0]
if m.Requests != 100 {
t.Errorf("requests: expected 100, got %d", m.Requests)
}
if m.PromptTokens != 100 {
t.Errorf("prompt tokens: expected 100, got %d", m.PromptTokens)
}
if m.CompletionTokens != 200 {
t.Errorf("completion tokens: expected 200, got %d", m.CompletionTokens)
}
}
func TestUsageTrackerStart(t *testing.T) {
tracker := NewUsageTracker()
stats := tracker.Stats()
if stats.Start.IsZero() {
t.Error("expected non-zero start time")
}
}
func TestUsageHandler(t *testing.T) {
gin.SetMode(gin.TestMode)
s := &Server{
usage: NewUsageTracker(),
}
s.usage.Record("llama3", 50, 100)
s.usage.Record("llama3", 25, 50)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/api/usage", nil)
s.UsageHandler(c)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.UsageResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if len(resp.Usage) != 1 {
t.Fatalf("expected 1 model, got %d", len(resp.Usage))
}
m := resp.Usage[0]
if m.Model != "llama3" {
t.Errorf("expected model llama3, got %s", m.Model)
}
if m.Requests != 2 {
t.Errorf("expected 2 requests, got %d", m.Requests)
}
if m.PromptTokens != 75 {
t.Errorf("expected 75 prompt tokens, got %d", m.PromptTokens)
}
if m.CompletionTokens != 150 {
t.Errorf("expected 150 completion tokens, got %d", m.CompletionTokens)
}
}

View File

@@ -30,6 +30,8 @@ type ModelfileConfig struct {
Template string Template string
System string System string
License string License string
Parser string
Renderer string
} }
// CreateOptions holds all options for model creation. // CreateOptions holds all options for model creation.
@@ -37,7 +39,7 @@ type CreateOptions struct {
ModelName string ModelName string
ModelDir string ModelDir string
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
Modelfile *ModelfileConfig // template/system/license from Modelfile Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile
} }
// CreateModel imports a model from a local directory. // CreateModel imports a model from a local directory.
@@ -267,8 +269,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
ModelFormat: "safetensors", ModelFormat: "safetensors",
Capabilities: caps, Capabilities: caps,
Requires: MinOllamaVersion, Requires: MinOllamaVersion,
Parser: parserName, Parser: resolveParserName(opts.Modelfile, parserName),
Renderer: rendererName, Renderer: resolveRendererName(opts.Modelfile, rendererName),
} }
configJSON, err := json.Marshal(configData) configJSON, err := json.Marshal(configData)
if err != nil { if err != nil {
@@ -305,6 +307,22 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
} }
} }
func resolveParserName(mf *ModelfileConfig, inferred string) string {
if mf != nil && mf.Parser != "" {
return mf.Parser
}
return inferred
}
func resolveRendererName(mf *ModelfileConfig, inferred string) string {
if mf != nil && mf.Renderer != "" {
return mf.Renderer
}
return inferred
}
// createModelfileLayers creates layers for template, system, and license from Modelfile config. // createModelfileLayers creates layers for template, system, and license from Modelfile config.
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) { func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
var layers []manifest.Layer var layers []manifest.Layer
@@ -410,7 +428,7 @@ func getParserName(modelDir string) string {
return "deepseek3" return "deepseek3"
} }
if strings.Contains(archLower, "qwen3") { if strings.Contains(archLower, "qwen3") {
return "qwen3-coder" return "qwen3"
} }
} }
@@ -424,7 +442,7 @@ func getParserName(modelDir string) string {
return "deepseek3" return "deepseek3"
} }
if strings.Contains(typeLower, "qwen3") { if strings.Contains(typeLower, "qwen3") {
return "qwen3-coder" return "qwen3"
} }
} }

View File

@@ -10,6 +10,8 @@ func TestModelfileConfig(t *testing.T) {
Template: "{{ .Prompt }}", Template: "{{ .Prompt }}",
System: "You are a helpful assistant.", System: "You are a helpful assistant.",
License: "MIT", License: "MIT",
Parser: "qwen3",
Renderer: "qwen3",
} }
if config.Template != "{{ .Prompt }}" { if config.Template != "{{ .Prompt }}" {
@@ -21,6 +23,12 @@ func TestModelfileConfig(t *testing.T) {
if config.License != "MIT" { if config.License != "MIT" {
t.Errorf("License = %q, want %q", config.License, "MIT") t.Errorf("License = %q, want %q", config.License, "MIT")
} }
if config.Parser != "qwen3" {
t.Errorf("Parser = %q, want %q", config.Parser, "qwen3")
}
if config.Renderer != "qwen3" {
t.Errorf("Renderer = %q, want %q", config.Renderer, "qwen3")
}
} }
func TestModelfileConfig_Empty(t *testing.T) { func TestModelfileConfig_Empty(t *testing.T) {
@@ -35,6 +43,12 @@ func TestModelfileConfig_Empty(t *testing.T) {
if config.License != "" { if config.License != "" {
t.Errorf("License should be empty, got %q", config.License) t.Errorf("License should be empty, got %q", config.License)
} }
if config.Parser != "" {
t.Errorf("Parser should be empty, got %q", config.Parser)
}
if config.Renderer != "" {
t.Errorf("Renderer should be empty, got %q", config.Renderer)
}
} }
func TestModelfileConfig_PartialFields(t *testing.T) { func TestModelfileConfig_PartialFields(t *testing.T) {
@@ -53,6 +67,12 @@ func TestModelfileConfig_PartialFields(t *testing.T) {
if config.License != "" { if config.License != "" {
t.Error("License should be empty") t.Error("License should be empty")
} }
if config.Parser != "" {
t.Error("Parser should be empty")
}
if config.Renderer != "" {
t.Error("Renderer should be empty")
}
} }
func TestMinOllamaVersion(t *testing.T) { func TestMinOllamaVersion(t *testing.T) {
@@ -98,6 +118,8 @@ func TestCreateOptions(t *testing.T) {
Template: "test", Template: "test",
System: "system", System: "system",
License: "MIT", License: "MIT",
Parser: "qwen3-thinking",
Renderer: "qwen3",
}, },
} }
@@ -116,6 +138,92 @@ func TestCreateOptions(t *testing.T) {
if opts.Modelfile.Template != "test" { if opts.Modelfile.Template != "test" {
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test") t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
} }
if opts.Modelfile.Parser != "qwen3-thinking" {
t.Errorf("Modelfile.Parser = %q, want %q", opts.Modelfile.Parser, "qwen3-thinking")
}
if opts.Modelfile.Renderer != "qwen3" {
t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
}
}
func TestResolveParserName(t *testing.T) {
tests := []struct {
name string
mf *ModelfileConfig
inferred string
want string
}{
{
name: "nil modelfile uses inferred",
mf: nil,
inferred: "qwen3",
want: "qwen3",
},
{
name: "empty parser uses inferred",
mf: &ModelfileConfig{
Parser: "",
},
inferred: "qwen3",
want: "qwen3",
},
{
name: "explicit parser overrides inferred",
mf: &ModelfileConfig{
Parser: "qwen3-thinking",
},
inferred: "qwen3",
want: "qwen3-thinking",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveParserName(tt.mf, tt.inferred); got != tt.want {
t.Fatalf("resolveParserName() = %q, want %q", got, tt.want)
}
})
}
}
func TestResolveRendererName(t *testing.T) {
tests := []struct {
name string
mf *ModelfileConfig
inferred string
want string
}{
{
name: "nil modelfile uses inferred",
mf: nil,
inferred: "qwen3-coder",
want: "qwen3-coder",
},
{
name: "empty renderer uses inferred",
mf: &ModelfileConfig{
Renderer: "",
},
inferred: "qwen3-coder",
want: "qwen3-coder",
},
{
name: "explicit renderer overrides inferred",
mf: &ModelfileConfig{
Renderer: "qwen3",
},
inferred: "qwen3-coder",
want: "qwen3",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveRendererName(tt.mf, tt.inferred); got != tt.want {
t.Fatalf("resolveRendererName() = %q, want %q", got, tt.want)
}
})
}
} }
func TestCreateOptions_Defaults(t *testing.T) { func TestCreateOptions_Defaults(t *testing.T) {

View File

@@ -3,5 +3,8 @@
package mlxrunner package mlxrunner
import ( import (
_ "github.com/ollama/ollama/x/models/gemma3"
_ "github.com/ollama/ollama/x/models/glm4_moe_lite" _ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3"
) )

View File

@@ -0,0 +1,92 @@
//go:build mlx
package model
import (
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// LinearFactory builds linear layers from a shared tensor map and quant defaults.
type LinearFactory struct {
tensors map[string]*mlx.Array
defaultGroupSize int
defaultBits int
defaultMode string
tensorQuant map[string]*TensorQuantInfo
}
// NewLinearFactory creates a reusable constructor for model linear layers.
func NewLinearFactory(
tensors map[string]*mlx.Array,
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
) LinearFactory {
return LinearFactory{
tensors: tensors,
defaultGroupSize: defaultGroupSize,
defaultBits: defaultBits,
defaultMode: defaultMode,
tensorQuant: tensorQuant,
}
}
// Make constructs a linear layer at path.
func (f LinearFactory) Make(path string) nn.LinearLayer {
return MakeLinearLayer(
f.tensors,
path,
f.defaultGroupSize,
f.defaultBits,
f.defaultMode,
f.tensorQuant,
)
}
// MakeLinearLayer constructs a linear layer from a tensor map.
//
// For quantized tensors (path.weight + path.weight_scale), it resolves per-tensor
// quant params via TensorQuant metadata (with shape-based affine fallback).
// For non-quantized tensors, it returns a standard nn.Linear.
func MakeLinearLayer(
tensors map[string]*mlx.Array,
path string,
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
) nn.LinearLayer {
w := tensors[path+".weight"]
if w == nil {
return nil
}
scales := tensors[path+".weight_scale"]
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
bias := tensors[path+".bias"]
groupSize, bits, mode := ResolveLinearQuantParams(
defaultGroupSize,
defaultBits,
defaultMode,
tensorQuant,
path+".weight",
w,
scales,
)
return &nn.QuantizedLinear{
Weight: w,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
bias := tensors[path+".bias"]
return nn.NewLinear(w, bias)
}
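Callers then resolve one layer per call; a usage sketch, with the tensor map and quant defaults assumed already loaded from the manifest:

// Sketch only: tensors and tensorQuant come from the loaded blobs.
linears := model.NewLinearFactory(tensors, 64, 8, "affine", tensorQuant)
if q := linears.Make("model.layers.0.self_attn.q_proj"); q != nil {
	y := q.Forward(x) // same call whether the layer is Linear or QuantizedLinear
	_ = y
}

Make returning nil for a missing "<path>.weight" lets models treat optional layers (such as a tied lm_head) as simply absent.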

130
x/mlxrunner/model/quant.go Normal file
View File

@@ -0,0 +1,130 @@
//go:build mlx
package model
import (
"strings"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// QuantizationParams returns default groupSize, bits, and mode for a quantization type.
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
case "NVFP4":
return 16, 4, "nvfp4"
case "FP4", "Q4", "INT4":
return 32, 4, "affine"
case "MXFP8":
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
return 64, 8, "affine"
default:
return 32, 8, "affine"
}
}
// TensorQuantParams resolves quant params for a tensor using per-tensor metadata
// when available, otherwise falling back to the provided model defaults.
func TensorQuantParams(
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
tensorName string,
) (groupSize, bits int, mode string, fromTensor bool) {
if tensorQuant != nil {
if tq := tensorQuant[tensorName]; tq != nil {
groupSize, bits, mode = QuantizationParams(tq.QuantType)
if tq.GroupSize > 0 {
groupSize = tq.GroupSize
}
return groupSize, bits, mode, true
}
}
return defaultGroupSize, defaultBits, defaultMode, false
}
// ResolveLinearQuantParams resolves quantization params for a quantized linear
// tensor, preferring per-tensor metadata and falling back to shape-based
// inference for affine packed tensors.
func ResolveLinearQuantParams(
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
tensorName string,
weight, scales *mlx.Array,
) (groupSize, bits int, mode string) {
groupSize, bits, mode, fromTensor := TensorQuantParams(
defaultGroupSize,
defaultBits,
defaultMode,
tensorQuant,
tensorName,
)
if mode == "affine" {
if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok {
if !fromTensor || groupSize == 0 || bits == 0 {
groupSize = inferredGroupSize
bits = inferredBits
}
}
}
return groupSize, bits, mode
}
// InferAffineQuantParamsFromShapes infers (groupSize, bits) for affine quantized
// tensors from packed weight and scale shapes.
func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) {
if weight == nil || scales == nil {
return 0, 0, false
}
weightShape := weight.Dims()
scaleShape := scales.Dims()
if len(weightShape) == 0 || len(scaleShape) == 0 {
return 0, 0, false
}
weightCols := weightShape[len(weightShape)-1]
scalesCols := scaleShape[len(scaleShape)-1]
if weightCols <= 0 || scalesCols <= 0 {
return 0, 0, false
}
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
switch {
case groupSize4 == 32:
return 32, 4, true
case groupSize8 == 64:
return 64, 8, true
case groupSize4 == 64 && groupSize8 == 32:
if hintBits == 8 {
return 32, 8, true
}
if hintBits == 4 {
return 64, 4, true
}
}
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
return groupSize4, 4, true
}
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
return groupSize8, 8, true
}
return 0, 0, false
}
func isCommonGroupSize(v int) bool {
switch v {
case 16, 32, 64, 128:
return true
default:
return false
}
}
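The arithmetic reflects the packing: a uint32 column holds 8 weights at 4 bits or 4 weights at 8 bits, and scales carry one entry per group, so each candidate group size is the logical column count divided by the scale column count. A worked sketch with hypothetical shapes:

package main

import "fmt"

func main() {
	// Hypothetical packed shapes: weight [4096, 512] (uint32), scales [4096, 128].
	weightCols, scalesCols := 512, 128

	groupSize4 := weightCols * 8 / scalesCols // 4096 logical cols / 128 = 32
	groupSize8 := weightCols * 4 / scalesCols // 2048 logical cols / 128 = 16

	// groupSize4 == 32 hits the first switch case, so this tensor resolves
	// unambiguously to (groupSize=32, bits=4).
	fmt.Println(groupSize4, groupSize8) // 32 16
}

The hintBits parameter only matters for the genuinely ambiguous shapes, where both readings land on common group sizes.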

View File

@@ -8,42 +8,63 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"sort"
"strconv"
"strings" "strings"
"github.com/ollama/ollama/x/imagegen/manifest" "github.com/ollama/ollama/x/imagegen/manifest"
) )
// Root wraps a ModelManifest with pre-scanned quantization metadata. // TensorQuantInfo describes per-tensor quantization metadata.
type Root struct { type TensorQuantInfo struct {
Manifest *manifest.ModelManifest QuantType string
quantType string GroupSize int
groupSize int
} }
// Open loads a manifest for the given model name and pre-scans the first // Root wraps a ModelManifest with pre-scanned quantization metadata.
// tensor blob for quantization metadata (quant_type, group_size). type Root struct {
Manifest *manifest.ModelManifest
// Backwards-compatible model-level quant metadata (first tensor blob).
quantType string
groupSize int
// Per-tensor quantization metadata.
tensorQuant map[string]*TensorQuantInfo
}
// Open loads a manifest for the given model name and scans tensor blobs for
// quantization metadata.
func Open(modelName string) (*Root, error) { func Open(modelName string) (*Root, error) {
m, err := manifest.LoadManifest(modelName) m, err := manifest.LoadManifest(modelName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
root := &Root{Manifest: m} root := &Root{
Manifest: m,
tensorQuant: make(map[string]*TensorQuantInfo),
}
// Pre-scan first tensor blob for quantization metadata
for _, layer := range m.GetTensorLayers("") { for _, layer := range m.GetTensorLayers("") {
blobPath := m.BlobPath(layer.Digest) blobPath := m.BlobPath(layer.Digest)
meta, err := readBlobMetadata(blobPath)
if err != nil || meta == nil { infos, blobQuantType, blobGroupSize, err := readBlobTensorQuantInfo(blobPath)
if err != nil {
continue continue
} }
if qt := meta["quant_type"]; qt != "" {
root.quantType = strings.ToUpper(qt) for name, info := range infos {
root.tensorQuant[name] = info
} }
if gs := meta["group_size"]; gs != "" {
fmt.Sscanf(gs, "%d", &root.groupSize) if root.quantType == "" && blobQuantType != "" {
root.quantType = strings.ToUpper(blobQuantType)
root.groupSize = blobGroupSize
if root.groupSize == 0 {
root.groupSize = defaultGroupSize(root.quantType)
}
} }
break // only check the first tensor blob
} }
return root, nil return root, nil
@@ -52,46 +73,180 @@ func Open(modelName string) (*Root, error) {
// Close is a no-op for now (future: release resources). // Close is a no-op for now (future: release resources).
func (r *Root) Close() {} func (r *Root) Close() {}
// QuantType returns the quantization type detected from tensor metadata. // QuantType returns the quantization type detected from the first tensor blob metadata.
func (r *Root) QuantType() string { return r.quantType } func (r *Root) QuantType() string { return r.quantType }
// GroupSize returns the quantization group size detected from tensor metadata. // GroupSize returns the quantization group size detected from the first tensor blob metadata.
func (r *Root) GroupSize() int { return r.groupSize } func (r *Root) GroupSize() int { return r.groupSize }
// readBlobMetadata reads the __metadata__ from a safetensors blob header. // TensorQuant returns per-tensor quantization metadata if available.
func readBlobMetadata(path string) (map[string]string, error) { func (r *Root) TensorQuant(name string) *TensorQuantInfo {
if r == nil {
return nil
}
return r.tensorQuant[name]
}
// AllTensorQuant returns a copy of the per-tensor quantization metadata.
func (r *Root) AllTensorQuant() map[string]*TensorQuantInfo {
out := make(map[string]*TensorQuantInfo, len(r.tensorQuant))
for k, v := range r.tensorQuant {
if v == nil {
continue
}
cp := *v
out[k] = &cp
}
return out
}
func defaultGroupSize(quantType string) int {
groupSize, _, _ := QuantizationParams(quantType)
return groupSize
}
func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string, int, error) {
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
return nil, err return nil, "", 0, err
} }
defer f.Close() defer f.Close()
var headerSize uint64 var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil { if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return nil, err return nil, "", 0, err
} }
if headerSize > 1024*1024 { if headerSize > 100*1024*1024 {
return nil, fmt.Errorf("header too large: %d", headerSize) return nil, "", 0, fmt.Errorf("header too large: %d", headerSize)
} }
data := make([]byte, headerSize) data := make([]byte, headerSize)
if _, err := io.ReadFull(f, data); err != nil { if _, err := io.ReadFull(f, data); err != nil {
return nil, err return nil, "", 0, err
} }
var header map[string]json.RawMessage var header map[string]json.RawMessage
if err := json.Unmarshal(data, &header); err != nil { if err := json.Unmarshal(data, &header); err != nil {
return nil, err return nil, "", 0, err
} }
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
globalQuantType = strings.ToUpper(globalQuantType)
mainNames := mainTensorNames(header)
infos := make(map[string]*TensorQuantInfo)
for _, name := range mainNames {
if _, ok := header[name+".scale"]; !ok {
continue
}
quantType := globalQuantType
groupSize := globalGroupSize
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
if quantType == "" {
quantType = inferredType
}
if groupSize == 0 {
groupSize = inferredGroup
}
if quantType == "" {
continue
}
if groupSize == 0 {
groupSize = defaultGroupSize(quantType)
}
infos[name] = &TensorQuantInfo{QuantType: quantType, GroupSize: groupSize}
}
return infos, globalQuantType, globalGroupSize, nil
}
func parseGlobalQuantMetadata(header map[string]json.RawMessage) (quantType string, groupSize int) {
metaRaw, ok := header["__metadata__"] metaRaw, ok := header["__metadata__"]
if !ok { if !ok {
return nil, nil return "", 0
} }
var meta map[string]string var meta map[string]string
if err := json.Unmarshal(metaRaw, &meta); err != nil { if err := json.Unmarshal(metaRaw, &meta); err != nil {
return nil, err return "", 0
} }
return meta, nil
quantType = meta["quant_type"]
if gs := meta["group_size"]; gs != "" {
groupSize, _ = strconv.Atoi(gs)
}
return quantType, groupSize
}
func mainTensorNames(header map[string]json.RawMessage) []string {
names := make([]string, 0, len(header))
for name := range header {
if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
continue
}
names = append(names, name)
}
sort.Strings(names)
return names
}
func inferQuantTypeFromShapes(header map[string]json.RawMessage, tensorName string, hintQuantType string) (string, int) {
type tensorShape struct {
Shape []int64 `json:"shape"`
}
mainRaw, ok := header[tensorName]
if !ok {
return "", 0
}
scaleRaw, ok := header[tensorName+".scale"]
if !ok {
return "", 0
}
var mainInfo tensorShape
if err := json.Unmarshal(mainRaw, &mainInfo); err != nil || len(mainInfo.Shape) == 0 {
return "", 0
}
var scaleInfo tensorShape
if err := json.Unmarshal(scaleRaw, &scaleInfo); err != nil || len(scaleInfo.Shape) == 0 {
return "", 0
}
weightCols := int(mainInfo.Shape[len(mainInfo.Shape)-1])
scalesCols := int(scaleInfo.Shape[len(scaleInfo.Shape)-1])
if weightCols <= 0 || scalesCols <= 0 {
return "", 0
}
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
switch {
case groupSize4 == 32:
return "INT4", 32
case groupSize8 == 64:
return "INT8", 64
case groupSize4 == 64 && groupSize8 == 32:
h := strings.ToUpper(hintQuantType)
if strings.Contains(h, "8") {
return "INT8", 32
}
if strings.Contains(h, "4") {
return "INT4", 64
}
}
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
return "INT4", groupSize4
}
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
return "INT8", groupSize8
}
return "", 0
} }
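For reference, the blob being parsed follows the standard safetensors layout: an 8-byte little-endian header length, then a JSON object in which quantized tensors sit alongside a ".scale" companion and an optional __metadata__ block carries model-wide quant fields. A hypothetical single-tensor header, with the companion lookup the scanner performs:

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// Hypothetical header for one INT4 tensor; real blobs hold many entries.
const header = `{
  "__metadata__": {"quant_type": "INT4", "group_size": "32"},
  "model.layers.0.mlp.up_proj.weight":       {"dtype": "U32", "shape": [4096, 512]},
  "model.layers.0.mlp.up_proj.weight.scale": {"dtype": "F16", "shape": [4096, 128]}
}`

func main() {
	var h map[string]json.RawMessage
	if err := json.Unmarshal([]byte(header), &h); err != nil {
		panic(err)
	}
	for name := range h {
		if name == "__metadata__" || strings.HasSuffix(name, ".scale") {
			continue
		}
		_, quantized := h[name+".scale"]
		fmt.Println(name, "quantized:", quantized) // quantized: true
	}
}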

View File

@@ -18,15 +18,27 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return errors.New("model not loaded") return errors.New("model not loaded")
} }
mlx.EnableCompile() enableCompile := true
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
enableCompile = modelCompile.EnableCompile()
}
if enableCompile {
mlx.EnableCompile()
} else {
mlx.DisableCompile()
}
inputs := r.Tokenizer.Encode(request.Prompt, true) inputs := r.Tokenizer.Encode(request.Prompt, true)
caches, tokens := r.FindNearestCache(inputs) caches, tokens := r.FindNearestCache(inputs)
if len(caches) == 0 { if len(caches) == 0 {
caches = make([]cache.Cache, r.Model.NumLayers()) if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
for i := range caches { caches = cacheFactory.NewCaches()
caches[i] = cache.NewKVCache() } else {
caches = make([]cache.Cache, r.Model.NumLayers())
for i := range caches {
caches[i] = cache.NewKVCache()
}
} }
} }
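Both hooks are optional interfaces discovered by type assertion, so a model opts in just by defining the method. A sketch of a model that disables compilation and supplies its own caches (bodies illustrative):

// Sketch only: the values and cache choices are per-model decisions.
func (m *Model) EnableCompile() bool { return false }

func (m *Model) NewCaches() []cache.Cache {
	caches := make([]cache.Cache, m.NumLayers())
	for i := range caches {
		caches[i] = cache.NewKVCache()
	}
	return caches
}

Models that define neither method keep the previous behavior: compilation enabled and a plain KV cache per layer.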

521
x/models/gemma3/gemma3.go Normal file
View File

@@ -0,0 +1,521 @@
//go:build mlx
// Package gemma3 provides the Gemma 3 text model implementation for MLX.
package gemma3
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
)
func init() {
base.Register("Gemma3ForCausalLM", newModel)
base.Register("Gemma3ForConditionalGeneration", newModel)
}
// TextConfig holds configuration for the Gemma 3 text model.
type TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
HeadDim int32 `json:"head_dim"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
SlidingWindow int32 `json:"sliding_window"`
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
LayerTypes []string `json:"layer_types"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
// Quantization parameters (set during load based on model quantization).
QuantGroupSize int `json:"-"`
QuantBits int `json:"-"`
QuantMode string `json:"-"`
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields.
Scale float32 `json:"-"`
}
// Attention implements Gemma 3 attention with Q/K normalization.
type Attention struct {
QProj nn.LinearLayer
KProj nn.LinearLayer
VProj nn.LinearLayer
OProj nn.LinearLayer
QNorm *nn.RMSNorm
KNorm *nn.RMSNorm
// Precomputed (1 + weight) for Gemma-style RMSNorm.
QNormScaled *mlx.Array
KNormScaled *mlx.Array
}
// MLP is the feed-forward network with GELU activation.
type MLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
// DecoderLayer is a single transformer block.
type DecoderLayer struct {
InputNorm *nn.RMSNorm
Attention *Attention
PostAttnNorm *nn.RMSNorm
PreFFNorm *nn.RMSNorm
MLP *MLP
PostFFNorm *nn.RMSNorm
// Precomputed (1 + weight) for Gemma-style RMSNorm.
InputNormScaled *mlx.Array
PostAttnNormScaled *mlx.Array
PreFFNormScaled *mlx.Array
PostFFNormScaled *mlx.Array
// Layer metadata.
IsSliding bool
LayerIdx int32
}
// Model is the Gemma 3 text-only model.
type Model struct {
EmbedTokens *nn.Embedding
Layers []*DecoderLayer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
// Precomputed (1 + weight) for Gemma-style RMSNorm.
NormScaled *mlx.Array
tok *tokenizer.Tokenizer
*TextConfig
weightPrefix string
}
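// defaultHeads fills in head counts when a vision-wrapped config omits them.
// The mapping assumes the published Gemma 3 text depths (roughly: 34 layers
// for 4B, 48 for 12B, 62 for 27B).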
func defaultHeads(numLayers int32) (numHeads, numKVHeads int32) {
switch numLayers {
case 34:
return 8, 4
case 48:
return 16, 8
case 62:
return 32, 16
default:
return 8, 4
}
}
func parseTextConfig(configData []byte) (TextConfig, bool, error) {
var cfg TextConfig
if err := json.Unmarshal(configData, &cfg); err != nil {
return TextConfig{}, false, fmt.Errorf("parse config: %w", err)
}
var wrapped struct {
TextConfig *TextConfig `json:"text_config"`
}
if err := json.Unmarshal(configData, &wrapped); err != nil {
return TextConfig{}, false, fmt.Errorf("parse nested text config: %w", err)
}
fromConditional := wrapped.TextConfig != nil
if fromConditional {
cfg = *wrapped.TextConfig
if cfg.HeadDim == 0 {
cfg.HeadDim = 256
}
if cfg.NumAttentionHeads == 0 {
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
}
if cfg.NumKeyValueHeads == 0 {
_, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
}
if cfg.VocabSize == 0 {
cfg.VocabSize = 262208
}
if cfg.SlidingWindowPattern == 0 && len(cfg.LayerTypes) == 0 {
cfg.SlidingWindowPattern = 6
}
if cfg.MaxPositionEmbeddings == 0 {
cfg.MaxPositionEmbeddings = 131072
}
}
if cfg.HeadDim == 0 {
cfg.HeadDim = 256
}
if cfg.NumAttentionHeads == 0 {
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
}
if cfg.NumKeyValueHeads == 0 {
cfg.NumKeyValueHeads = max(1, cfg.NumAttentionHeads/2)
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
if cfg.RopeLocalBaseFreq == 0 {
cfg.RopeLocalBaseFreq = 10000
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
if cfg.VocabSize == 0 {
cfg.VocabSize = 262208
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
return cfg, fromConditional, nil
}
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
for _, prefix := range []string{"", "language_model."} {
if tensors[prefix+"model.embed_tokens.weight"] != nil {
return prefix
}
}
return ""
}
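// isLayerSliding reports whether a layer uses the local sliding window.
// Explicit layer_types win; otherwise every SlidingWindowPattern-th layer
// (indices 5, 11, 17, ... with the default pattern of 6) attends globally.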
func isLayerSliding(layerIdx int32, cfg *TextConfig) bool {
if len(cfg.LayerTypes) > 0 && int(layerIdx) < len(cfg.LayerTypes) {
return cfg.LayerTypes[layerIdx] == "sliding_attention"
}
if cfg.SlidingWindowPattern <= 0 {
return false
}
return (layerIdx+1)%cfg.SlidingWindowPattern != 0
}
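// precomputeGemmaScaledWeights materializes (1 + weight) once per norm so the
// forward passes can hand the scaled arrays straight to mlx.RMSNormFn instead
// of re-adding 1 on every step (Gemma defines RMSNorm with a (1 + w) gain).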
func precomputeGemmaScaledWeights(m *Model) {
if m.Norm != nil {
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
}
var scaled []*mlx.Array
if m.NormScaled != nil {
scaled = append(scaled, m.NormScaled)
}
for _, layer := range m.Layers {
if layer == nil || layer.Attention == nil {
continue
}
if layer.InputNorm != nil {
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
scaled = append(scaled, layer.InputNormScaled)
}
if layer.PostAttnNorm != nil {
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
scaled = append(scaled, layer.PostAttnNormScaled)
}
if layer.PreFFNorm != nil {
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
scaled = append(scaled, layer.PreFFNormScaled)
}
if layer.PostFFNorm != nil {
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
scaled = append(scaled, layer.PostFFNormScaled)
}
if layer.Attention.QNorm != nil {
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
scaled = append(scaled, layer.Attention.QNormScaled)
}
if layer.Attention.KNorm != nil {
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
scaled = append(scaled, layer.Attention.KNormScaled)
}
}
if len(scaled) > 0 {
mlx.Eval(scaled...)
}
}
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
cfg, _, err := parseTextConfig(configData)
if err != nil {
return nil, err
}
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{ConfigJSON: configData}
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
TextConfig: &cfg,
tok: tok,
}
for i := range m.Layers {
m.Layers[i] = &DecoderLayer{
LayerIdx: int32(i),
IsSliding: isLayerSliding(int32(i), m.TextConfig),
}
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.weightPrefix = resolveWeightPrefix(tensors)
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
if embedWeight == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = nn.NewEmbedding(embedWeight)
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
}
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Gemma usually ties output projection to embeddings.
m.LMHead = nn.NewLinear(embedWeight, nil)
}
for i := int32(0); i < m.NumHiddenLayers; i++ {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &DecoderLayer{
LayerIdx: i,
IsSliding: isLayerSliding(i, m.TextConfig),
Attention: &Attention{},
MLP: &MLP{},
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.InputNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.PostAttnNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil {
layer.PreFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil {
layer.PostFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
layer.Attention.QNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
layer.Attention.KNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
if layer.InputNorm == nil {
return fmt.Errorf("layer %d: missing input_layernorm", i)
}
if layer.PostAttnNorm == nil {
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
}
if layer.PreFFNorm == nil {
return fmt.Errorf("layer %d: missing pre_feedforward_layernorm", i)
}
if layer.PostFFNorm == nil {
return fmt.Errorf("layer %d: missing post_feedforward_layernorm", i)
}
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
return fmt.Errorf("layer %d: missing attention projections", i)
}
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
return fmt.Errorf("layer %d: missing attention q/k norms", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("layer %d: missing mlp projections", i)
}
m.Layers[i] = layer
}
precomputeGemmaScaledWeights(m)
if m.NormScaled == nil {
return fmt.Errorf("missing precomputed final norm weight")
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
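// Gemma scales token embeddings by sqrt(hidden_size) before the first block.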
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.TextConfig)
}
return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
// NewCaches creates cache objects for all layers.
func (m *Model) NewCaches() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i, layer := range m.Layers {
if m.SlidingWindow > 0 && layer.IsSliding {
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
} else {
caches[i] = cache.NewKVCache()
}
}
return caches
}
// FormatPrompt applies the Gemma 3 chat template.
func (m *Model) FormatPrompt(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps)
mlpOut := l.MLP.Forward(normed)
mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
return mlx.Add(h, mlpOut)
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Transpose(v, 0, 2, 1, 3)
q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps)
k = mlx.RMSNormFn(k, a.KNormScaled, cfg.RMSNormEps)
ropeTheta := cfg.RopeTheta
if isSliding {
ropeTheta = cfg.RopeLocalBaseFreq
}
offset := 0
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
if c != nil {
k, v = c.Update(k, v)
}
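// Grouped-query attention: replicate K/V heads to match the query head count.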
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = nn.RepeatKV(k, repeatFactor)
v = nn.RepeatKV(v, repeatFactor)
}
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.GELUApprox(m.GateProj.Forward(x))
up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up))
}

View File

@@ -8,7 +8,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"math" "math"
"strings"
"github.com/ollama/ollama/x/imagegen/tokenizer" "github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/cache"
@@ -64,9 +63,10 @@ type Config struct {
RopeScaling *RopeScaling `json:"rope_scaling"` RopeScaling *RopeScaling `json:"rope_scaling"`
// Quantization parameters (set during load based on model quantization) // Quantization parameters (set during load based on model quantization)
QuantGroupSize int `json:"-"` // Group size for quantization (default 64) QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
QuantBits int `json:"-"` // Bits per weight (4 or 8) QuantBits int `json:"-"` // Bits per weight (4 or 8)
QuantMode string `json:"-"` // Quantization mode ("affine", etc.) QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields // Computed fields
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
@@ -372,22 +372,6 @@ func supportsGatherQMM(mode string, bits int) bool {
return mode == "affine" && (bits == 4 || bits == 8) return mode == "affine" && (bits == 4 || bits == 8)
} }
// quantizationParams returns groupSize, bits, mode for a quantization type string.
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
case "NVFP4":
return 16, 4, "nvfp4"
case "FP4", "Q4", "INT4":
return 32, 4, "affine"
case "MXFP8":
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8", "":
return 64, 8, "affine"
default:
return 32, 8, "affine"
}
}
// ExpertWeight holds a single expert's weight with optional quantization components. // ExpertWeight holds a single expert's weight with optional quantization components.
type ExpertWeight struct { type ExpertWeight struct {
Weight *mlx.Array Weight *mlx.Array
@@ -408,7 +392,15 @@ func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized b
if scales != nil { if scales != nil {
qbiases := tensors[path+".weight_qbias"] qbiases := tensors[path+".weight_qbias"]
groupSize, bits, mode := cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode groupSize, bits, mode := model.ResolveLinearQuantParams(
cfg.QuantGroupSize,
cfg.QuantBits,
cfg.QuantMode,
cfg.TensorQuant,
path+".weight",
w,
scales,
)
if useQuantized && supportsGatherQMM(mode, bits) { if useQuantized && supportsGatherQMM(mode, bits) {
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize} return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
@@ -492,7 +484,16 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi
// Check if quantized and dequantize // Check if quantized and dequantize
if scales := tensors[path+".weight_scale"]; scales != nil { if scales := tensors[path+".weight_scale"]; scales != nil {
qbiases := tensors[path+".weight_qbias"] qbiases := tensors[path+".weight_qbias"]
w = mlx.Dequantize(w, scales, qbiases, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode) groupSize, bits, mode := model.ResolveLinearQuantParams(
cfg.QuantGroupSize,
cfg.QuantBits,
cfg.QuantMode,
cfg.TensorQuant,
path+".weight",
w,
scales,
)
w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
} }
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
@@ -507,32 +508,6 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi
return embedQ, unembedOut return embedQ, unembedOut
} }
// makeLinear creates a Linear or QuantizedLinear layer from the tensor map.
func makeLinear(tensors map[string]*mlx.Array, path string, cfg *Config) nn.LinearLayer {
w := tensors[path+".weight"]
if w == nil {
return nil
}
scales := tensors[path+".weight_scale"]
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
bias := tensors[path+".bias"]
return &nn.QuantizedLinear{
Weight: w,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: cfg.QuantGroupSize,
Bits: cfg.QuantBits,
Mode: cfg.QuantMode,
}
}
bias := tensors[path+".bias"]
return nn.NewLinear(w, bias)
}
// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer, // newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
// no weights loaded yet). Called by the registry via base.New(). // no weights loaded yet). Called by the registry via base.New().
func newModel(root *model.Root) (base.Model, error) { func newModel(root *model.Root) (base.Model, error) {
@@ -551,13 +526,14 @@ func newModel(root *model.Root) (base.Model, error) {
// Set up quantization parameters from pre-scanned metadata // Set up quantization parameters from pre-scanned metadata
if qt := root.QuantType(); qt != "" { if qt := root.QuantType(); qt != "" {
_, cfg.QuantBits, cfg.QuantMode = quantizationParams(qt) cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 { if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs cfg.QuantGroupSize = gs
} else {
cfg.QuantGroupSize, _, _ = quantizationParams(qt)
} }
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
} }
cfg.TensorQuant = root.AllTensorQuant()
// Load tokenizer // Load tokenizer
tokData, err := root.Manifest.ReadConfig("tokenizer.json") tokData, err := root.Manifest.ReadConfig("tokenizer.json")
@@ -596,7 +572,20 @@ func newModel(root *model.Root) (base.Model, error) {
// layer creation. // layer creation.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
cfg := m.Config cfg := m.Config
linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits) useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
if !useQuantized && cfg.TensorQuant != nil {
for _, tq := range cfg.TensorQuant {
if tq == nil {
continue
}
_, bits, mode := model.QuantizationParams(tq.QuantType)
if supportsGatherQMM(mode, bits) {
useQuantized = true
break
}
}
}
// Load embedding // Load embedding
if w := tensors["model.embed_tokens.weight"]; w != nil { if w := tensors["model.embed_tokens.weight"]; w != nil {
@@ -609,7 +598,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
} }
// Load LM head // Load LM head
m.LMHead = makeLinear(tensors, "lm_head", cfg) m.LMHead = linears.Make("lm_head")
// Load layers // Load layers
for i := int32(0); i < cfg.NumHiddenLayers; i++ { for i := int32(0); i < cfg.NumHiddenLayers; i++ {
@@ -617,16 +606,16 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
// Load attention (same for both block types) // Load attention (same for both block types)
attn := &MLAAttention{} attn := &MLAAttention{}
attn.QAProj = makeLinear(tensors, prefix+".self_attn.q_a_proj", cfg) attn.QAProj = linears.Make(prefix + ".self_attn.q_a_proj")
if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil { if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
} }
attn.QBProj = makeLinear(tensors, prefix+".self_attn.q_b_proj", cfg) attn.QBProj = linears.Make(prefix + ".self_attn.q_b_proj")
attn.KVAProjWithMQA = makeLinear(tensors, prefix+".self_attn.kv_a_proj_with_mqa", cfg) attn.KVAProjWithMQA = linears.Make(prefix + ".self_attn.kv_a_proj_with_mqa")
if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil { if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps) attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
} }
attn.OProj = makeLinear(tensors, prefix+".self_attn.o_proj", cfg) attn.OProj = linears.Make(prefix + ".self_attn.o_proj")
// Sanitize MLA weights for absorbed attention // Sanitize MLA weights for absorbed attention
embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg) embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
@@ -647,9 +636,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
} }
block.MLP = &DenseMLP{ block.MLP = &DenseMLP{
GateProj: makeLinear(tensors, prefix+".mlp.gate_proj", cfg), GateProj: linears.Make(prefix + ".mlp.gate_proj"),
UpProj: makeLinear(tensors, prefix+".mlp.up_proj", cfg), UpProj: linears.Make(prefix + ".mlp.up_proj"),
DownProj: makeLinear(tensors, prefix+".mlp.down_proj", cfg), DownProj: linears.Make(prefix + ".mlp.down_proj"),
} }
m.Layers[i] = block m.Layers[i] = block
@@ -690,7 +679,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
} }
moeGate := &MoEGate{} moeGate := &MoEGate{}
moeGate.Gate = makeLinear(tensors, prefix+".mlp.gate", cfg) moeGate.Gate = linears.Make(prefix + ".mlp.gate")
if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil { if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
moeGate.EScoreCorrectionBias = bias moeGate.EScoreCorrectionBias = bias
} }
@@ -703,9 +692,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
// Load shared experts if present // Load shared experts if present
if cfg.NSharedExperts > 0 { if cfg.NSharedExperts > 0 {
block.MoE.SharedExperts = &SharedExperts{ block.MoE.SharedExperts = &SharedExperts{
GateProj: makeLinear(tensors, prefix+".mlp.shared_experts.gate_proj", cfg), GateProj: linears.Make(prefix + ".mlp.shared_experts.gate_proj"),
UpProj: makeLinear(tensors, prefix+".mlp.shared_experts.up_proj", cfg), UpProj: linears.Make(prefix + ".mlp.shared_experts.up_proj"),
DownProj: makeLinear(tensors, prefix+".mlp.shared_experts.down_proj", cfg), DownProj: linears.Make(prefix + ".mlp.shared_experts.down_proj"),
} }
} }

323
x/models/llama/llama.go Normal file
View File

@@ -0,0 +1,323 @@
//go:build mlx
// Package llama provides a Llama-style decoder-only transformer for MLX.
package llama
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
)
func init() {
base.Register("LlamaForCausalLM", newModel)
}
// Config holds Llama model configuration.
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
// Quantization parameters (set during load based on model quantization).
QuantGroupSize int `json:"-"`
QuantBits int `json:"-"`
QuantMode string `json:"-"`
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields.
HeadDim int32 `json:"-"`
Scale float32 `json:"-"`
}
// Model is a Llama text model.
type Model struct {
EmbedTokens *nn.Embedding
Layers []*Layer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
tok *tokenizer.Tokenizer
*Config
weightPrefix string
}
type Layer struct {
Attention *Attention
MLP *MLP
AttentionNorm *nn.RMSNorm
MLPNorm *nn.RMSNorm
}
type Attention struct {
QProj nn.LinearLayer
KProj nn.LinearLayer
VProj nn.LinearLayer
OProj nn.LinearLayer
}
type MLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
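// resolveWeightPrefix picks the tensor-name prefix for this checkpoint:
// text-only exports use bare "model." names, while some checkpoints (for
// example, multimodal exports) nest the text model under "language_model.".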
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
for _, prefix := range []string{"", "language_model."} {
if tensors[prefix+"model.embed_tokens.weight"] != nil {
return prefix
}
}
return ""
}
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(configData, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
if cfg.HiddenSize <= 0 {
return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
}
if cfg.NumAttentionHeads <= 0 {
return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
}
if cfg.NumKeyValueHeads <= 0 {
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
}
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
}
if cfg.HeadDim == 0 {
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
}
if cfg.HeadDim <= 0 {
return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
}
if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 10000
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-5
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{
ConfigJSON: configData,
}
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Layer, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.weightPrefix = resolveWeightPrefix(tensors)
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
if embedWeight == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = nn.NewEmbedding(embedWeight)
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
}
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if m.TieWordEmbeddings {
m.LMHead = nn.NewLinear(embedWeight, nil)
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Fall back to tied embeddings: many Llama checkpoints omit lm_head
// when the output projection shares the embedding weights.
m.LMHead = nn.NewLinear(embedWeight, nil)
}
for i := int32(0); i < m.NumHiddenLayers; i++ {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &Layer{
Attention: &Attention{},
MLP: &MLP{},
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
if layer.AttentionNorm == nil {
return fmt.Errorf("layer %d: missing input_layernorm", i)
}
if layer.MLPNorm == nil {
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
}
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
return fmt.Errorf("layer %d: missing attention projections", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("layer %d: missing mlp projections", i)
}
m.Layers[i] = layer
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
}
return m.Norm.Forward(h, m.RMSNormEps)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
func (m *Model) NewCaches() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Transpose(v, 0, 2, 1, 3)
offset := 0
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
if c != nil {
k, v = c.Update(k, v)
}
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = nn.RepeatKV(k, repeatFactor)
v = nn.RepeatKV(v, repeatFactor)
}
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
}
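To make the reshape/transpose bookkeeping in Attention.Forward concrete, here is a worked shape trace for assumed example dimensions (B=1, L=8, 32 query heads, 8 KV heads, head_dim=128, so hidden_size=4096):

// Worked example (assumed dims) through Attention.Forward above:
//   q:    [1, 8, 4096] -> Reshape [1, 8, 32, 128] -> Transpose [1, 32, 8, 128]
//   k, v: [1, 8, 1024] -> Reshape [1, 8,  8, 128] -> Transpose [1,  8, 8, 128]
//   repeatFactor = 32/8 = 4; RepeatKV expands k, v to [1, 32, 8, 128]
//   SDPA output: [1, 32, 8, 128] -> Transpose + Reshape -> [1, 8, 4096]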

338
x/models/qwen3/qwen3.go Normal file
View File

@@ -0,0 +1,338 @@
//go:build mlx
// Package qwen3 provides the Qwen3 text model implementation for MLX.
package qwen3
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
)
func init() {
base.Register("Qwen3ForCausalLM", newModel)
}
// Config holds Qwen3 model configuration.
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
// Quantization parameters (set during load based on model quantization).
QuantGroupSize int `json:"-"`
QuantBits int `json:"-"`
QuantMode string `json:"-"`
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields.
Scale float32 `json:"-"`
QKNormEps float32 `json:"-"`
}
// Model is the Qwen3 text-only model.
type Model struct {
EmbedTokens *nn.Embedding
Layers []*Layer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
tok *tokenizer.Tokenizer
*Config
weightPrefix string
}
// Layer is a single Qwen3 decoder block.
type Layer struct {
Attention *Attention
MLP *MLP
AttentionNorm *nn.RMSNorm
MLPNorm *nn.RMSNorm
}
// Attention implements Qwen3 attention with Q/K norms.
type Attention struct {
QProj nn.LinearLayer
KProj nn.LinearLayer
VProj nn.LinearLayer
OProj nn.LinearLayer
QNorm *nn.RMSNorm
KNorm *nn.RMSNorm
}
// MLP is the feed-forward network with SwiGLU activation.
type MLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
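// resolveWeightPrefix mirrors the llama loader: bare "model." names for
// text-only checkpoints, "language_model." for checkpoints that nest the
// text model under that prefix.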
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
for _, prefix := range []string{"", "language_model."} {
if tensors[prefix+"model.embed_tokens.weight"] != nil {
return prefix
}
}
return ""
}
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(configData, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
if cfg.HiddenSize <= 0 {
return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
}
if cfg.NumAttentionHeads <= 0 {
return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
}
if cfg.NumKeyValueHeads <= 0 {
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
}
if cfg.HeadDim == 0 {
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
}
if cfg.HeadDim <= 0 {
return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
}
if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
cfg.QKNormEps = 1e-6
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{
ConfigJSON: configData,
}
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Layer, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.weightPrefix = resolveWeightPrefix(tensors)
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
if embedWeight == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = nn.NewEmbedding(embedWeight)
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
}
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if m.TieWordEmbeddings {
m.LMHead = nn.NewLinear(embedWeight, nil)
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Fall back to tied embeddings: Qwen3 checkpoints commonly omit lm_head
// and reuse the embedding weights for the output projection.
m.LMHead = nn.NewLinear(embedWeight, nil)
}
for i := int32(0); i < m.NumHiddenLayers; i++ {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &Layer{
Attention: &Attention{},
MLP: &MLP{},
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
layer.Attention.QNorm = nn.NewRMSNorm(w, m.QKNormEps)
}
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
layer.Attention.KNorm = nn.NewRMSNorm(w, m.QKNormEps)
}
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
if layer.AttentionNorm == nil {
return fmt.Errorf("layer %d: missing input_layernorm", i)
}
if layer.MLPNorm == nil {
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
}
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
return fmt.Errorf("layer %d: missing attention projections", i)
}
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
return fmt.Errorf("layer %d: missing attention q/k norms", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("layer %d: missing mlp projections", i)
}
m.Layers[i] = layer
}
collected := mlx.Collect(m)
mlx.Eval(collected...)
return nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := tokens.Dims()
B, L := int32(dims[0]), int32(dims[1])
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, c, B, L, m.Config)
}
return m.Norm.Forward(h, m.RMSNormEps)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
func (m *Model) NewCaches() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
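	// Q/K RMSNorm is applied per head while the tensors are still shaped
	// [B, L, heads, head_dim], before the transpose and RoPE below.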
q = a.QNorm.Forward(q, cfg.QKNormEps)
k = a.KNorm.Forward(k, cfg.QKNormEps)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
offset := 0
if c != nil {
offset = c.Offset()
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
if c != nil {
k, v = c.Update(k, v)
}
// MLX SDPA supports grouped-query attention directly (Q heads can be a
// multiple of K/V heads), so avoid materializing repeated K/V tensors.
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
}
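Both new architectures plug into the same registry flow: base.Register in init, then the runner resolves the architecture string from config.json via base.New. A hedged sketch of a single forward pass on the runner side, assuming the runner drives models through the methods defined above:

// Sketch only: assumed runner-side usage of the model interface.
//
//	m, err := base.New(root)             // resolves "Qwen3ForCausalLM" registered in init()
//	if err != nil { /* handle */ }
//	if err := m.LoadWeights(tensors); err != nil { /* handle */ }
//	caches := m.NewCaches()              // one KV cache per layer
//	hidden := m.Forward(tokens, caches)  // tokens: [B, L] token IDs
//	logits := m.Unembed(hidden)          // project to vocabulary for sampling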

View File

@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math"
"os" "os"
"sort" "sort"
"strings" "strings"
@@ -58,7 +59,15 @@ func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
} }
} }
return buildModelInfo(config, totalBytes, tensorCount), nil info := buildModelInfo(config, totalBytes, tensorCount)
// For quantized models, byte-based estimation can significantly undercount
// parameters. Prefer exact counting from tensor shapes in safetensors headers.
if paramCount, err := getParameterCountFromManifest(mf); err == nil && paramCount > 0 {
info["general.parameter_count"] = paramCount
}
return info, nil
} }
// buildModelInfo constructs the model info map from config and tensor stats. // buildModelInfo constructs the model info map from config and tensor stats.
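A worked example of the undercount this guards against: a 7B-parameter model quantized to 4 bits stores roughly 7e9 / 2 bytes, about 3.5 GB of weight data, so a byte-based estimate that assumes 2 bytes per parameter (the BF16 width) would report only about 1.75B parameters, a 4x undercount. Counting elements from unpacked tensor shapes sidesteps the storage width entirely.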
@@ -151,6 +160,51 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
return info return info
} }
// getParameterCountFromManifest counts model parameters from tensor shapes.
// This accounts for quantized tensors by using unpacked shapes from
// getTensorInfoFromManifest.
func getParameterCountFromManifest(mf *manifest.Manifest) (int64, error) {
tensors, err := getTensorInfoFromManifest(mf)
if err != nil {
return 0, err
}
var total int64
for _, tensor := range tensors {
if len(tensor.Shape) == 0 {
continue
}
elements := int64(1)
for _, dim := range tensor.Shape {
if dim == 0 {
elements = 0
break
}
if dim > uint64(math.MaxInt64) {
return 0, fmt.Errorf("tensor %s dimension too large: %d", tensor.Name, dim)
}
d := int64(dim)
if elements > math.MaxInt64/d {
return 0, fmt.Errorf("tensor %s element count overflow", tensor.Name)
}
elements *= d
}
if elements == 0 {
continue
}
if total > math.MaxInt64-elements {
return 0, fmt.Errorf("total parameter count overflow")
}
total += elements
}
return total, nil
}
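The inner loop is an overflow-guarded product over the unpacked shape; extracted here as a standalone helper for reference, elementCount([]uint64{10, 16}) returns 160, and any zero dimension yields zero, which the caller skips:

// Equivalent standalone sketch of the guarded per-tensor element count above.
func elementCount(shape []uint64) (int64, error) {
	n := int64(1)
	for _, dim := range shape {
		if dim == 0 {
			return 0, nil
		}
		if dim > uint64(math.MaxInt64) {
			return 0, fmt.Errorf("dimension too large: %d", dim)
		}
		d := int64(dim)
		if n > math.MaxInt64/d {
			return 0, fmt.Errorf("element count overflow")
		}
		n *= d
	}
	return n, nil
}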
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers. // GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata. // Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) { func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {

View File

@@ -714,6 +714,187 @@ func TestGetTensorInfoFromManifest_Quantized(t *testing.T) {
} }
} }
func TestGetParameterCountFromManifest(t *testing.T) {
// Create a temp directory for blobs and set OLLAMA_MODELS
tempDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", tempDir)
blobDir := filepath.Join(tempDir, "blobs")
if err := os.MkdirAll(blobDir, 0o755); err != nil {
t.Fatalf("failed to create blobs dir: %v", err)
}
// Unquantized tensor: [4,5] = 20 params
header1 := map[string]any{
"model.embed_tokens.weight": map[string]any{
"dtype": "BF16",
"shape": []int64{4, 5},
"data_offsets": []int64{0, 40},
},
}
header1JSON, _ := json.Marshal(header1)
var buf1 bytes.Buffer
binary.Write(&buf1, binary.LittleEndian, uint64(len(header1JSON)))
buf1.Write(header1JSON)
digest1 := "sha256:1111111111111111111111111111111111111111111111111111111111111111"
blobPath1, err := manifest.BlobsPath(digest1)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath1, buf1.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob1: %v", err)
}
// Quantized int4 tensor with packed shape [10,2] -> unpacked [10,16] = 160 params
header2 := map[string]any{
"__metadata__": map[string]string{
"quant_type": "int4",
"group_size": "32",
},
"model.layers.0.mlp.up_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{10, 2},
"data_offsets": []int64{0, 80},
},
"model.layers.0.mlp.up_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{10, 1},
"data_offsets": []int64{80, 100},
},
"model.layers.0.mlp.up_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{10, 1},
"data_offsets": []int64{100, 120},
},
}
header2JSON, _ := json.Marshal(header2)
var buf2 bytes.Buffer
binary.Write(&buf2, binary.LittleEndian, uint64(len(header2JSON)))
buf2.Write(header2JSON)
digest2 := "sha256:2222222222222222222222222222222222222222222222222222222222222222"
blobPath2, err := manifest.BlobsPath(digest2)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath2, buf2.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob2: %v", err)
}
mf := &manifest.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Layers: []manifest.Layer{
{
MediaType: manifest.MediaTypeImageTensor,
Digest: digest1,
Size: int64(buf1.Len() + 40),
Name: "model.embed_tokens.weight",
},
{
MediaType: manifest.MediaTypeImageTensor,
Digest: digest2,
Size: int64(buf2.Len() + 120),
Name: "model.layers.0.mlp.up_proj.weight",
},
},
}
paramCount, err := getParameterCountFromManifest(mf)
if err != nil {
t.Fatalf("getParameterCountFromManifest() error = %v", err)
}
const want int64 = 180 // 20 + 160
if paramCount != want {
t.Errorf("parameter_count = %d, want %d", paramCount, want)
}
}
func TestGetParameterCountFromManifest_MixedQuantizedPacked(t *testing.T) {
// Create a temp directory for blobs and set OLLAMA_MODELS
tempDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", tempDir)
blobDir := filepath.Join(tempDir, "blobs")
if err := os.MkdirAll(blobDir, 0o755); err != nil {
t.Fatalf("failed to create blobs dir: %v", err)
}
// Packed mixed-precision blob (no global metadata):
// - gate_proj: int4 packed [5,8] + scale [5,2] => unpacked [5,64] = 320 params
// - down_proj: int8 packed [5,16] + scale [5,1] => unpacked [5,64] = 320 params
header := map[string]any{
"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{5, 8},
"data_offsets": []int64{0, 160},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{5, 2},
"data_offsets": []int64{160, 180},
},
"model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{5, 2},
"data_offsets": []int64{180, 200},
},
"model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
"dtype": "U32",
"shape": []int64{5, 16},
"data_offsets": []int64{200, 520},
},
"model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
"dtype": "BF16",
"shape": []int64{5, 1},
"data_offsets": []int64{520, 530},
},
"model.layers.0.mlp.experts.0.down_proj.weight.bias": map[string]any{
"dtype": "BF16",
"shape": []int64{5, 1},
"data_offsets": []int64{530, 540},
},
}
headerJSON, _ := json.Marshal(header)
var buf bytes.Buffer
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
digest := "sha256:3333333333333333333333333333333333333333333333333333333333333333"
blobPath, err := manifest.BlobsPath(digest)
if err != nil {
t.Fatalf("failed to get blob path: %v", err)
}
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob: %v", err)
}
mf := &manifest.Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Layers: []manifest.Layer{
{
MediaType: manifest.MediaTypeImageTensor,
Digest: digest,
Size: int64(buf.Len() + 540),
Name: "model.layers.0.mlp.experts",
},
},
}
paramCount, err := getParameterCountFromManifest(mf)
if err != nil {
t.Fatalf("getParameterCountFromManifest() error = %v", err)
}
const want int64 = 640 // 320 + 320
if paramCount != want {
t.Errorf("parameter_count = %d, want %d", paramCount, want)
}
}
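The expected counts follow from uint32 packing: one uint32 holds eight 4-bit or four 8-bit values, so the int4 gate_proj's packed [5, 8] unpacks to [5, 64] (320 parameters) and the int8 down_proj's packed [5, 16] also unpacks to [5, 64] (320 parameters), for 640 in total.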
func TestParseSafetensorsAllHeaders(t *testing.T) { func TestParseSafetensorsAllHeaders(t *testing.T) {
tests := []struct { tests := []struct {
name string name string