mirror of
https://github.com/ollama/ollama.git
synced 2026-04-21 16:25:42 +02:00
Compare commits
13 Commits
v0.16.2-rc
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9ac1300805 | ||
|
|
43d9907dd6 | ||
|
|
91dc088e8b | ||
|
|
9aefd2dfee | ||
|
|
d07e4a1dd3 | ||
|
|
8a257ec00a | ||
|
|
2f4de1acf7 | ||
|
|
ec95c45f70 | ||
|
|
3a88f7eb20 | ||
|
|
0d5da826d4 | ||
|
|
9b795698b8 | ||
|
|
041fb77639 | ||
|
|
8224cce583 |
@@ -16,7 +16,7 @@ Start building with open models.
|
|||||||
curl -fsSL https://ollama.com/install.sh | sh
|
curl -fsSL https://ollama.com/install.sh | sh
|
||||||
```
|
```
|
||||||
|
|
||||||
or [download manually](http://localhost:8080/download/Ollama.dmg)
|
or [download manually](https://ollama.com/download/Ollama.dmg)
|
||||||
|
|
||||||
### Windows
|
### Windows
|
||||||
|
|
||||||
|
|||||||
13
api/types.go
13
api/types.go
@@ -922,6 +922,19 @@ type UserResponse struct {
|
|||||||
Plan string `json:"plan,omitempty"`
|
Plan string `json:"plan,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UsageResponse struct {
|
||||||
|
// Start is the time the server started tracking usage (UTC, RFC 3339).
|
||||||
|
Start time.Time `json:"start"`
|
||||||
|
Usage []ModelUsageData `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelUsageData struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Requests int64 `json:"requests"`
|
||||||
|
PromptTokens int64 `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int64 `json:"completion_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
// Tensor describes the metadata for a given tensor.
|
// Tensor describes the metadata for a given tensor.
|
||||||
type Tensor struct {
|
type Tensor struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|||||||
12
cmd/cmd.go
12
cmd/cmd.go
@@ -57,9 +57,9 @@ import (
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
// Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
|
||||||
config.DefaultSingleSelector = func(title string, items []config.ModelItem) (string, error) {
|
config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||||
result, err := tui.SelectSingle(title, tuiItems)
|
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
return "", config.ErrCancelled
|
return "", config.ErrCancelled
|
||||||
}
|
}
|
||||||
@@ -182,6 +182,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||||||
mfConfig.System = cmd.Args
|
mfConfig.System = cmd.Args
|
||||||
case "license":
|
case "license":
|
||||||
mfConfig.License = cmd.Args
|
mfConfig.License = cmd.Args
|
||||||
|
case "parser":
|
||||||
|
mfConfig.Parser = cmd.Args
|
||||||
|
case "renderer":
|
||||||
|
mfConfig.Renderer = cmd.Args
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1897,9 +1901,9 @@ func runInteractiveTUI(cmd *cobra.Command) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Selector adapters for tui
|
// Selector adapters for tui
|
||||||
singleSelector := func(title string, items []config.ModelItem) (string, error) {
|
singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
|
||||||
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
tuiItems := tui.ReorderItems(tui.ConvertItems(items))
|
||||||
result, err := tui.SelectSingle(title, tuiItems)
|
result, err := tui.SelectSingle(title, tuiItems, current)
|
||||||
if errors.Is(err, tui.ErrCancelled) {
|
if errors.Is(err, tui.ErrCancelled) {
|
||||||
return "", config.ErrCancelled
|
return "", config.ErrCancelled
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAli
|
|||||||
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
|
||||||
|
|
||||||
if aliases["primary"] == "" || force {
|
if aliases["primary"] == "" || force {
|
||||||
primary, err := DefaultSingleSelector("Select model:", items)
|
primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|||||||
123
cmd/config/cline.go
Normal file
123
cmd/config/cline.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cline implements Runner and Editor for the Cline CLI integration
|
||||||
|
type Cline struct{}
|
||||||
|
|
||||||
|
func (c *Cline) String() string { return "Cline" }
|
||||||
|
|
||||||
|
func (c *Cline) Run(model string, args []string) error {
|
||||||
|
if _, err := exec.LookPath("cline"); err != nil {
|
||||||
|
return fmt.Errorf("cline is not installed, install with: npm install -g cline")
|
||||||
|
}
|
||||||
|
|
||||||
|
models := []string{model}
|
||||||
|
if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
|
||||||
|
models = config.Models
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
models, err = resolveEditorModels("cline", models, func() ([]string, error) {
|
||||||
|
return selectModels(context.Background(), "cline", "")
|
||||||
|
})
|
||||||
|
if errors.Is(err, errCancelled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.Edit(models); err != nil {
|
||||||
|
return fmt.Errorf("setup failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("cline", args...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cline) Paths() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||||
|
if _, err := os.Stat(p); err == nil {
|
||||||
|
return []string{p}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cline) Edit(models []string) error {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(home, ".cline", "data", "globalState.json")
|
||||||
|
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config := make(map[string]any)
|
||||||
|
if data, err := os.ReadFile(configPath); err == nil {
|
||||||
|
if err := json.Unmarshal(data, &config); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set Ollama as the provider for both act and plan modes
|
||||||
|
baseURL := envconfig.Host().String()
|
||||||
|
config["ollamaBaseUrl"] = baseURL
|
||||||
|
config["actModeApiProvider"] = "ollama"
|
||||||
|
config["actModeOllamaModelId"] = models[0]
|
||||||
|
config["actModeOllamaBaseUrl"] = baseURL
|
||||||
|
config["planModeApiProvider"] = "ollama"
|
||||||
|
config["planModeOllamaModelId"] = models[0]
|
||||||
|
config["planModeOllamaBaseUrl"] = baseURL
|
||||||
|
|
||||||
|
config["welcomeViewCompleted"] = true
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(config, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return writeWithBackup(configPath, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cline) Models() []string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if config["actModeApiProvider"] != "ollama" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
modelID, _ := config["actModeOllamaModelId"].(string)
|
||||||
|
if modelID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{modelID}
|
||||||
|
}
|
||||||
204
cmd/config/cline_test.go
Normal file
204
cmd/config/cline_test.go
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClineIntegration(t *testing.T) {
|
||||||
|
c := &Cline{}
|
||||||
|
|
||||||
|
t.Run("String", func(t *testing.T) {
|
||||||
|
if got := c.String(); got != "Cline" {
|
||||||
|
t.Errorf("String() = %q, want %q", got, "Cline")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Runner", func(t *testing.T) {
|
||||||
|
var _ Runner = c
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Editor", func(t *testing.T) {
|
||||||
|
var _ Editor = c
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClineEdit(t *testing.T) {
|
||||||
|
c := &Cline{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||||
|
configPath := filepath.Join(configDir, "globalState.json")
|
||||||
|
|
||||||
|
readConfig := func() map[string]any {
|
||||||
|
data, _ := os.ReadFile(configPath)
|
||||||
|
var config map[string]any
|
||||||
|
json.Unmarshal(data, &config)
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("creates config from scratch", func(t *testing.T) {
|
||||||
|
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||||
|
|
||||||
|
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := readConfig()
|
||||||
|
if config["actModeApiProvider"] != "ollama" {
|
||||||
|
t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
|
||||||
|
}
|
||||||
|
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||||
|
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
|
||||||
|
}
|
||||||
|
if config["planModeApiProvider"] != "ollama" {
|
||||||
|
t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
|
||||||
|
}
|
||||||
|
if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||||
|
t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
|
||||||
|
}
|
||||||
|
if config["welcomeViewCompleted"] != true {
|
||||||
|
t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves existing fields", func(t *testing.T) {
|
||||||
|
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
|
||||||
|
existing := map[string]any{
|
||||||
|
"remoteRulesToggles": map[string]any{},
|
||||||
|
"remoteWorkflowToggles": map[string]any{},
|
||||||
|
"customSetting": "keep-me",
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(existing)
|
||||||
|
os.WriteFile(configPath, data, 0o644)
|
||||||
|
|
||||||
|
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := readConfig()
|
||||||
|
if config["customSetting"] != "keep-me" {
|
||||||
|
t.Errorf("customSetting was not preserved")
|
||||||
|
}
|
||||||
|
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||||
|
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("updates model on re-edit", func(t *testing.T) {
|
||||||
|
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||||
|
|
||||||
|
if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := readConfig()
|
||||||
|
if config["actModeOllamaModelId"] != "glm-5:cloud" {
|
||||||
|
t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
|
||||||
|
}
|
||||||
|
if config["planModeOllamaModelId"] != "glm-5:cloud" {
|
||||||
|
t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty models is no-op", func(t *testing.T) {
|
||||||
|
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||||
|
|
||||||
|
if err := c.Edit(nil); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(configPath); !os.IsNotExist(err) {
|
||||||
|
t.Error("expected no config file to be created for empty models")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uses first model as primary", func(t *testing.T) {
|
||||||
|
os.RemoveAll(filepath.Join(tmpDir, ".cline"))
|
||||||
|
|
||||||
|
if err := c.Edit([]string{"kimi-k2.5:cloud", "glm-5:cloud"}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := readConfig()
|
||||||
|
if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
|
||||||
|
t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClineModels(t *testing.T) {
|
||||||
|
c := &Cline{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||||
|
configPath := filepath.Join(configDir, "globalState.json")
|
||||||
|
|
||||||
|
t.Run("returns nil when no config", func(t *testing.T) {
|
||||||
|
if models := c.Models(); models != nil {
|
||||||
|
t.Errorf("Models() = %v, want nil", models)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns nil when provider is not ollama", func(t *testing.T) {
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
config := map[string]any{
|
||||||
|
"actModeApiProvider": "anthropic",
|
||||||
|
"actModeOllamaModelId": "some-model",
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(config)
|
||||||
|
os.WriteFile(configPath, data, 0o644)
|
||||||
|
|
||||||
|
if models := c.Models(); models != nil {
|
||||||
|
t.Errorf("Models() = %v, want nil", models)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns model when ollama is configured", func(t *testing.T) {
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
config := map[string]any{
|
||||||
|
"actModeApiProvider": "ollama",
|
||||||
|
"actModeOllamaModelId": "kimi-k2.5:cloud",
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(config)
|
||||||
|
os.WriteFile(configPath, data, 0o644)
|
||||||
|
|
||||||
|
models := c.Models()
|
||||||
|
if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
|
||||||
|
t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClinePaths(t *testing.T) {
|
||||||
|
c := &Cline{}
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
t.Run("returns nil when no config exists", func(t *testing.T) {
|
||||||
|
if paths := c.Paths(); paths != nil {
|
||||||
|
t.Errorf("Paths() = %v, want nil", paths)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns path when config exists", func(t *testing.T) {
|
||||||
|
configDir := filepath.Join(tmpDir, ".cline", "data")
|
||||||
|
os.MkdirAll(configDir, 0o755)
|
||||||
|
configPath := filepath.Join(configDir, "globalState.json")
|
||||||
|
os.WriteFile(configPath, []byte("{}"), 0o644)
|
||||||
|
|
||||||
|
paths := c.Paths()
|
||||||
|
if len(paths) != 1 || paths[0] != configPath {
|
||||||
|
t.Errorf("Paths() = %v, want [%s]", paths, configPath)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
@@ -54,6 +53,7 @@ type AliasConfigurer interface {
|
|||||||
var integrations = map[string]Runner{
|
var integrations = map[string]Runner{
|
||||||
"claude": &Claude{},
|
"claude": &Claude{},
|
||||||
"clawdbot": &Openclaw{},
|
"clawdbot": &Openclaw{},
|
||||||
|
"cline": &Cline{},
|
||||||
"codex": &Codex{},
|
"codex": &Codex{},
|
||||||
"moltbot": &Openclaw{},
|
"moltbot": &Openclaw{},
|
||||||
"droid": &Droid{},
|
"droid": &Droid{},
|
||||||
@@ -102,16 +102,17 @@ var recommendedVRAM = map[string]string{
|
|||||||
var integrationAliases = map[string]bool{
|
var integrationAliases = map[string]bool{
|
||||||
"clawdbot": true,
|
"clawdbot": true,
|
||||||
"moltbot": true,
|
"moltbot": true,
|
||||||
"pi": true,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// integrationInstallHints maps integration names to install URLs.
|
// integrationInstallHints maps integration names to install URLs.
|
||||||
var integrationInstallHints = map[string]string{
|
var integrationInstallHints = map[string]string{
|
||||||
"claude": "https://code.claude.com/docs/en/quickstart",
|
"claude": "https://code.claude.com/docs/en/quickstart",
|
||||||
|
"cline": "https://cline.bot/cli",
|
||||||
"openclaw": "https://docs.openclaw.ai",
|
"openclaw": "https://docs.openclaw.ai",
|
||||||
"codex": "https://developers.openai.com/codex/cli/",
|
"codex": "https://developers.openai.com/codex/cli/",
|
||||||
"droid": "https://docs.factory.ai/cli/getting-started/quickstart",
|
"droid": "https://docs.factory.ai/cli/getting-started/quickstart",
|
||||||
"opencode": "https://opencode.ai",
|
"opencode": "https://opencode.ai",
|
||||||
|
"pi": "https://github.com/badlogic/pi-mono",
|
||||||
}
|
}
|
||||||
|
|
||||||
// hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable.
|
// hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable.
|
||||||
@@ -129,13 +130,21 @@ type IntegrationInfo struct {
|
|||||||
// integrationDescriptions maps integration names to short descriptions.
|
// integrationDescriptions maps integration names to short descriptions.
|
||||||
var integrationDescriptions = map[string]string{
|
var integrationDescriptions = map[string]string{
|
||||||
"claude": "Anthropic's coding tool with subagents",
|
"claude": "Anthropic's coding tool with subagents",
|
||||||
|
"cline": "Autonomous coding agent with parallel execution",
|
||||||
"codex": "OpenAI's open-source coding agent",
|
"codex": "OpenAI's open-source coding agent",
|
||||||
"openclaw": "Personal AI with 100+ skills",
|
"openclaw": "Personal AI with 100+ skills",
|
||||||
"droid": "Factory's coding agent across terminal and IDEs",
|
"droid": "Factory's coding agent across terminal and IDEs",
|
||||||
"opencode": "Anomaly's open-source coding agent",
|
"opencode": "Anomaly's open-source coding agent",
|
||||||
|
"pi": "Minimal AI agent toolkit with plugin support",
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name.
|
// integrationOrder defines a custom display order for integrations.
|
||||||
|
// Integrations listed here are placed at the end in the given order;
|
||||||
|
// all others appear first, sorted alphabetically.
|
||||||
|
var integrationOrder = []string{"opencode", "droid", "pi", "cline"}
|
||||||
|
|
||||||
|
// ListIntegrationInfos returns all non-alias registered integrations, sorted by name
|
||||||
|
// with integrationOrder entries placed at the end.
|
||||||
func ListIntegrationInfos() []IntegrationInfo {
|
func ListIntegrationInfos() []IntegrationInfo {
|
||||||
var result []IntegrationInfo
|
var result []IntegrationInfo
|
||||||
for name, r := range integrations {
|
for name, r := range integrations {
|
||||||
@@ -148,7 +157,26 @@ func ListIntegrationInfos() []IntegrationInfo {
|
|||||||
Description: integrationDescriptions[name],
|
Description: integrationDescriptions[name],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
orderRank := make(map[string]int, len(integrationOrder))
|
||||||
|
for i, name := range integrationOrder {
|
||||||
|
orderRank[name] = i + 1 // 1-indexed so 0 means "not in the list"
|
||||||
|
}
|
||||||
|
|
||||||
slices.SortFunc(result, func(a, b IntegrationInfo) int {
|
slices.SortFunc(result, func(a, b IntegrationInfo) int {
|
||||||
|
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||||
|
// Both have custom order: sort by their rank
|
||||||
|
if aRank > 0 && bRank > 0 {
|
||||||
|
return aRank - bRank
|
||||||
|
}
|
||||||
|
// Only one has custom order: it goes last
|
||||||
|
if aRank > 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if bRank > 0 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
// Neither has custom order: alphabetical
|
||||||
return strings.Compare(a.Name, b.Name)
|
return strings.Compare(a.Name, b.Name)
|
||||||
})
|
})
|
||||||
return result
|
return result
|
||||||
@@ -186,9 +214,15 @@ func IsIntegrationInstalled(name string) bool {
|
|||||||
case "droid":
|
case "droid":
|
||||||
_, err := exec.LookPath("droid")
|
_, err := exec.LookPath("droid")
|
||||||
return err == nil
|
return err == nil
|
||||||
|
case "cline":
|
||||||
|
_, err := exec.LookPath("cline")
|
||||||
|
return err == nil
|
||||||
case "opencode":
|
case "opencode":
|
||||||
_, err := exec.LookPath("opencode")
|
_, err := exec.LookPath("opencode")
|
||||||
return err == nil
|
return err == nil
|
||||||
|
case "pi":
|
||||||
|
_, err := exec.LookPath("pi")
|
||||||
|
return err == nil
|
||||||
default:
|
default:
|
||||||
return true // Assume installed for unknown integrations
|
return true // Assume installed for unknown integrations
|
||||||
}
|
}
|
||||||
@@ -214,7 +248,8 @@ type ModelItem struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SingleSelector is a function type for single item selection.
|
// SingleSelector is a function type for single item selection.
|
||||||
type SingleSelector func(title string, items []ModelItem) (string, error)
|
// current is the name of the previously selected item to highlight; empty means no pre-selection.
|
||||||
|
type SingleSelector func(title string, items []ModelItem, current string) (string, error)
|
||||||
|
|
||||||
// MultiSelector is a function type for multi item selection.
|
// MultiSelector is a function type for multi item selection.
|
||||||
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
|
type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
|
||||||
@@ -257,7 +292,7 @@ func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (stri
|
|||||||
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
return "", fmt.Errorf("no models available, run 'ollama pull <model>' first")
|
||||||
}
|
}
|
||||||
|
|
||||||
selected, err := selector("Select model to run:", items)
|
selected, err := selector("Select model to run:", items, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@@ -367,13 +402,11 @@ func selectIntegration() (string, error) {
|
|||||||
return "", fmt.Errorf("no integrations available")
|
return "", fmt.Errorf("no integrations available")
|
||||||
}
|
}
|
||||||
|
|
||||||
names := slices.Sorted(maps.Keys(integrations))
|
|
||||||
var items []ModelItem
|
var items []ModelItem
|
||||||
for _, name := range names {
|
for name, r := range integrations {
|
||||||
if integrationAliases[name] {
|
if integrationAliases[name] {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
r := integrations[name]
|
|
||||||
description := r.String()
|
description := r.String()
|
||||||
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
|
||||||
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
|
description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
|
||||||
@@ -381,7 +414,25 @@ func selectIntegration() (string, error) {
|
|||||||
items = append(items, ModelItem{Name: name, Description: description})
|
items = append(items, ModelItem{Name: name, Description: description})
|
||||||
}
|
}
|
||||||
|
|
||||||
return DefaultSingleSelector("Select integration:", items)
|
orderRank := make(map[string]int, len(integrationOrder))
|
||||||
|
for i, name := range integrationOrder {
|
||||||
|
orderRank[name] = i + 1
|
||||||
|
}
|
||||||
|
slices.SortFunc(items, func(a, b ModelItem) int {
|
||||||
|
aRank, bRank := orderRank[a.Name], orderRank[b.Name]
|
||||||
|
if aRank > 0 && bRank > 0 {
|
||||||
|
return aRank - bRank
|
||||||
|
}
|
||||||
|
if aRank > 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if bRank > 0 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return strings.Compare(a.Name, b.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
return DefaultSingleSelector("Select integration:", items, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
|
// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
|
||||||
@@ -439,7 +490,7 @@ func selectModelsWithSelectors(ctx context.Context, name, current string, single
|
|||||||
if _, ok := r.(AliasConfigurer); ok {
|
if _, ok := r.(AliasConfigurer); ok {
|
||||||
prompt = fmt.Sprintf("Select Primary model for %s:", r)
|
prompt = fmt.Sprintf("Select Primary model for %s:", r)
|
||||||
}
|
}
|
||||||
model, err := single(prompt, items)
|
model, err := single(prompt, items, current)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -812,10 +863,12 @@ Without arguments, this is equivalent to running 'ollama' directly.
|
|||||||
|
|
||||||
Supported integrations:
|
Supported integrations:
|
||||||
claude Claude Code
|
claude Claude Code
|
||||||
|
cline Cline
|
||||||
codex Codex
|
codex Codex
|
||||||
droid Droid
|
droid Droid
|
||||||
opencode OpenCode
|
opencode OpenCode
|
||||||
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||||
|
pi Pi
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
ollama launch
|
ollama launch
|
||||||
@@ -915,11 +968,9 @@ Examples:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validate saved model still exists
|
// Validate saved model still exists
|
||||||
cloudCleared := false
|
|
||||||
if model != "" && modelFlag == "" {
|
if model != "" && modelFlag == "" {
|
||||||
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled && isCloudModelName(model) {
|
if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled && isCloudModelName(model) {
|
||||||
model = ""
|
model = ""
|
||||||
cloudCleared = true
|
|
||||||
} else if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
|
} else if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
|
fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
|
||||||
if err := ShowOrPull(cmd.Context(), client, model); err != nil {
|
if err := ShowOrPull(cmd.Context(), client, model); err != nil {
|
||||||
@@ -928,9 +979,8 @@ Examples:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no valid model or --config flag, show picker
|
// Show picker so user can change model (skip when --model flag provided)
|
||||||
if model == "" || configFlag {
|
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, modelFlag == "")
|
||||||
aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, configFlag || cloudCleared)
|
|
||||||
if errors.Is(err, errCancelled) {
|
if errors.Is(err, errCancelled) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -939,7 +989,6 @@ Examples:
|
|||||||
}
|
}
|
||||||
model = aliases["primary"]
|
model = aliases["primary"]
|
||||||
existingAliases = aliases
|
existingAliases = aliases
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure cloud models are authenticated
|
// Ensure cloud models are authenticated
|
||||||
if isCloudModel(cmd.Context(), client, model) {
|
if isCloudModel(cmd.Context(), client, model) {
|
||||||
@@ -1001,27 +1050,13 @@ Examples:
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 && !configFlag {
|
|
||||||
savedModels := filterDisabledCloudModels(saved.Models)
|
|
||||||
if len(savedModels) != len(saved.Models) {
|
|
||||||
_ = SaveIntegration(name, savedModels)
|
|
||||||
}
|
|
||||||
if len(savedModels) == 0 {
|
|
||||||
// All saved models were cloud — fall through to picker
|
|
||||||
models, err = selectModels(cmd.Context(), name, "")
|
|
||||||
if errors.Is(err, errCancelled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
models = savedModels
|
current := ""
|
||||||
return runIntegration(name, models[0], passArgs)
|
if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 {
|
||||||
|
current = saved.Models[0]
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
var err error
|
var err error
|
||||||
models, err = selectModels(cmd.Context(), name, "")
|
models, err = selectModels(cmd.Context(), name, current)
|
||||||
if errors.Is(err, errCancelled) {
|
if errors.Is(err, errCancelled) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1248,10 +1248,26 @@ func TestListIntegrationInfos(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("sorted by name", func(t *testing.T) {
|
t.Run("sorted with custom order at end", func(t *testing.T) {
|
||||||
|
// integrationOrder entries (cline, opencode) should appear last, in that order.
|
||||||
|
// All other entries should be sorted alphabetically before them.
|
||||||
|
orderRank := make(map[string]int)
|
||||||
|
for i, name := range integrationOrder {
|
||||||
|
orderRank[name] = i + 1
|
||||||
|
}
|
||||||
for i := 1; i < len(infos); i++ {
|
for i := 1; i < len(infos); i++ {
|
||||||
|
aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
|
||||||
|
switch {
|
||||||
|
case aRank == 0 && bRank == 0:
|
||||||
if infos[i-1].Name >= infos[i].Name {
|
if infos[i-1].Name >= infos[i].Name {
|
||||||
t.Errorf("not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
|
t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
|
||||||
|
}
|
||||||
|
case aRank > 0 && bRank == 0:
|
||||||
|
t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
|
||||||
|
case aRank > 0 && bRank > 0:
|
||||||
|
if aRank >= bRank {
|
||||||
|
t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -365,7 +365,19 @@ func (m selectorModel) View() string {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func SelectSingle(title string, items []SelectItem) (string, error) {
|
// cursorForCurrent returns the item index matching current, or 0 if not found.
|
||||||
|
func cursorForCurrent(items []SelectItem, current string) int {
|
||||||
|
if current != "" {
|
||||||
|
for i, item := range items {
|
||||||
|
if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func SelectSingle(title string, items []SelectItem, current string) (string, error) {
|
||||||
if len(items) == 0 {
|
if len(items) == 0 {
|
||||||
return "", fmt.Errorf("no items to select from")
|
return "", fmt.Errorf("no items to select from")
|
||||||
}
|
}
|
||||||
@@ -373,6 +385,7 @@ func SelectSingle(title string, items []SelectItem) (string, error) {
|
|||||||
m := selectorModel{
|
m := selectorModel{
|
||||||
title: title,
|
title: title,
|
||||||
items: items,
|
items: items,
|
||||||
|
cursor: cursorForCurrent(items, current),
|
||||||
}
|
}
|
||||||
|
|
||||||
p := tea.NewProgram(m)
|
p := tea.NewProgram(m)
|
||||||
|
|||||||
@@ -382,6 +382,42 @@ func TestUpdateNavigation_Backspace(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- cursorForCurrent ---
|
||||||
|
|
||||||
|
func TestCursorForCurrent(t *testing.T) {
|
||||||
|
testItems := []SelectItem{
|
||||||
|
{Name: "llama3.2", Recommended: true},
|
||||||
|
{Name: "qwen3:8b", Recommended: true},
|
||||||
|
{Name: "gemma3:latest"},
|
||||||
|
{Name: "deepseek-r1"},
|
||||||
|
{Name: "glm-5:cloud"},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
current string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"empty current", "", 0},
|
||||||
|
{"exact match", "qwen3:8b", 1},
|
||||||
|
{"no match returns 0", "nonexistent", 0},
|
||||||
|
{"bare name matches with :latest suffix", "gemma3", 2},
|
||||||
|
{"full tag matches bare item", "llama3.2:latest", 0},
|
||||||
|
{"cloud model exact match", "glm-5:cloud", 4},
|
||||||
|
{"cloud model bare name", "glm-5", 4},
|
||||||
|
{"recommended item exact match", "llama3.2", 0},
|
||||||
|
{"recommended item with tag", "qwen3", 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := cursorForCurrent(testItems, tt.current); got != tt.want {
|
||||||
|
t.Errorf("cursorForCurrent(%q) = %d, want %d", tt.current, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- ReorderItems ---
|
// --- ReorderItems ---
|
||||||
|
|
||||||
func TestReorderItems(t *testing.T) {
|
func TestReorderItems(t *testing.T) {
|
||||||
|
|||||||
48
docs/api.md
48
docs/api.md
@@ -15,6 +15,7 @@
|
|||||||
- [Push a Model](#push-a-model)
|
- [Push a Model](#push-a-model)
|
||||||
- [Generate Embeddings](#generate-embeddings)
|
- [Generate Embeddings](#generate-embeddings)
|
||||||
- [List Running Models](#list-running-models)
|
- [List Running Models](#list-running-models)
|
||||||
|
- [Usage](#usage)
|
||||||
- [Version](#version)
|
- [Version](#version)
|
||||||
- [Experimental: Image Generation](#image-generation-experimental)
|
- [Experimental: Image Generation](#image-generation-experimental)
|
||||||
|
|
||||||
@@ -1854,6 +1855,53 @@ curl http://localhost:11434/api/embeddings -d '{
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/usage
|
||||||
|
```
|
||||||
|
|
||||||
|
Show aggregate usage statistics per model since the server started. All timestamps are UTC in RFC 3339 format.
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
#### Request
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl http://localhost:11434/api/usage
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Response
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"start": "2025-01-27T20:00:00Z",
|
||||||
|
"usage": [
|
||||||
|
{
|
||||||
|
"model": "llama3.2",
|
||||||
|
"requests": 5,
|
||||||
|
"prompt_tokens": 130,
|
||||||
|
"completion_tokens": 890
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "deepseek-r1",
|
||||||
|
"requests": 2,
|
||||||
|
"prompt_tokens": 48,
|
||||||
|
"completion_tokens": 312
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Response fields
|
||||||
|
|
||||||
|
- `start`: when the server started tracking usage (UTC, RFC 3339)
|
||||||
|
- `usage`: list of per-model usage statistics
|
||||||
|
- `model`: model name
|
||||||
|
- `requests`: total number of completed requests
|
||||||
|
- `prompt_tokens`: total prompt tokens evaluated
|
||||||
|
- `completion_tokens`: total completion tokens generated
|
||||||
|
|
||||||
## Version
|
## Version
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -106,20 +106,23 @@
|
|||||||
"group": "Integrations",
|
"group": "Integrations",
|
||||||
"pages": [
|
"pages": [
|
||||||
"/integrations/index",
|
"/integrations/index",
|
||||||
|
{
|
||||||
|
"group": "Assistants",
|
||||||
|
"expanded": true,
|
||||||
|
"pages": [
|
||||||
|
"/integrations/openclaw"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"group": "Coding",
|
"group": "Coding",
|
||||||
|
"expanded": true,
|
||||||
"pages": [
|
"pages": [
|
||||||
"/integrations/claude-code",
|
"/integrations/claude-code",
|
||||||
"/integrations/codex",
|
"/integrations/codex",
|
||||||
"/integrations/opencode",
|
"/integrations/opencode",
|
||||||
"/integrations/droid",
|
"/integrations/droid",
|
||||||
"/integrations/goose"
|
"/integrations/goose",
|
||||||
]
|
"/integrations/pi"
|
||||||
},
|
|
||||||
{
|
|
||||||
"group": "Assistants",
|
|
||||||
"pages": [
|
|
||||||
"/integrations/openclaw"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ Coding assistants that can read, modify, and execute code in your projects.
|
|||||||
- [OpenCode](/integrations/opencode)
|
- [OpenCode](/integrations/opencode)
|
||||||
- [Droid](/integrations/droid)
|
- [Droid](/integrations/droid)
|
||||||
- [Goose](/integrations/goose)
|
- [Goose](/integrations/goose)
|
||||||
|
- [Pi](/integrations/pi)
|
||||||
|
|
||||||
## Assistants
|
## Assistants
|
||||||
|
|
||||||
|
|||||||
57
docs/integrations/pi.mdx
Normal file
57
docs/integrations/pi.mdx
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
---
|
||||||
|
title: Pi
|
||||||
|
---
|
||||||
|
|
||||||
|
Pi is a minimal AI agent toolkit with plugin support.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
Install [Pi](https://github.com/badlogic/pi-mono):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npm install -g @mariozechner/pi-coding-agent
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage with Ollama
|
||||||
|
|
||||||
|
### Quick setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ollama launch pi
|
||||||
|
```
|
||||||
|
|
||||||
|
To configure without launching:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch pi --config
|
||||||
|
```
|
||||||
|
|
||||||
|
### Manual setup
|
||||||
|
|
||||||
|
Add a configuration block to `~/.pi/agent/models.json`:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"providers": {
|
||||||
|
"ollama": {
|
||||||
|
"baseUrl": "http://localhost:11434/v1",
|
||||||
|
"api": "openai-completions",
|
||||||
|
"apiKey": "ollama",
|
||||||
|
"models": [
|
||||||
|
{
|
||||||
|
"id": "qwen3-coder"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Update `~/.pi/agent/settings.json` to set the default provider:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"defaultProvider": "ollama",
|
||||||
|
"defaultModel": "qwen3-coder"
|
||||||
|
}
|
||||||
|
```
|
||||||
@@ -27,9 +27,17 @@ The menu provides quick access to:
|
|||||||
- **Launch tools** - Claude Code, Codex, OpenClaw, and more
|
- **Launch tools** - Claude Code, Codex, OpenClaw, and more
|
||||||
- **Additional integrations** - Available under "More..."
|
- **Additional integrations** - Available under "More..."
|
||||||
|
|
||||||
|
## Assistants
|
||||||
|
|
||||||
|
Launch [OpenClaw](/integrations/openclaw), a personal AI with 100+ skills:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
ollama launch openclaw
|
||||||
|
```
|
||||||
|
|
||||||
## Coding
|
## Coding
|
||||||
|
|
||||||
Launch coding tools with Ollama models:
|
Launch [Claude Code](/integrations/claude-code) and other coding tools with Ollama models:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
ollama launch claude
|
ollama launch claude
|
||||||
|
|||||||
@@ -45,6 +45,10 @@ func ParserForName(name string) Parser {
|
|||||||
var p Parser
|
var p Parser
|
||||||
|
|
||||||
switch name {
|
switch name {
|
||||||
|
case "qwen3":
|
||||||
|
p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
case "qwen3-thinking":
|
||||||
|
p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||||
case "qwen3-coder":
|
case "qwen3-coder":
|
||||||
p = &Qwen3CoderParser{}
|
p = &Qwen3CoderParser{}
|
||||||
case "qwen3-vl-instruct":
|
case "qwen3-vl-instruct":
|
||||||
|
|||||||
@@ -54,6 +54,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
}{
|
}{
|
||||||
{"passthrough"},
|
{"passthrough"},
|
||||||
|
{"qwen3"},
|
||||||
|
{"qwen3-thinking"},
|
||||||
{"qwen3-coder"},
|
{"qwen3-coder"},
|
||||||
{"harmony"},
|
{"harmony"},
|
||||||
}
|
}
|
||||||
|
|||||||
335
model/parsers/qwen3.go
Normal file
335
model/parsers/qwen3.go
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type qwen3ParserState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
qwen3ParserStateLookingForThinkingOpen qwen3ParserState = iota
|
||||||
|
qwen3ParserStateThinkingStartedEatingWhitespace
|
||||||
|
qwen3ParserStateCollectingThinking
|
||||||
|
qwen3ParserStateThinkingDoneEatingWhitespace
|
||||||
|
qwen3ParserStateCollectingContent
|
||||||
|
qwen3ParserStateToolStartedEatingWhitespace
|
||||||
|
qwen3ParserStateCollectingToolContent
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
qwen3ThinkingOpenTag = "<think>"
|
||||||
|
qwen3ThinkingCloseTag = "</think>"
|
||||||
|
qwen3ToolOpenTag = "<tool_call>"
|
||||||
|
qwen3ToolCloseTag = "</tool_call>"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Qwen3Parser parses Qwen3 output to extract thinking and tool calls.
|
||||||
|
// Qwen3 prompts end with <think> when thinking is enabled, so output begins
|
||||||
|
// with thinking content directly (without an opening tag).
|
||||||
|
type Qwen3Parser struct {
|
||||||
|
state qwen3ParserState
|
||||||
|
buffer strings.Builder
|
||||||
|
tools []api.Tool
|
||||||
|
hasThinkingSupport bool
|
||||||
|
defaultThinking bool
|
||||||
|
maybeThinkingOpenAtBOL bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen3Parser) HasToolSupport() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen3Parser) HasThinkingSupport() bool {
|
||||||
|
return p.hasThinkingSupport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
|
p.tools = tools
|
||||||
|
p.buffer.Reset()
|
||||||
|
|
||||||
|
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
|
||||||
|
if thinkValue == nil {
|
||||||
|
thinkingEnabled = p.defaultThinking
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.hasThinkingSupport && thinkingEnabled {
|
||||||
|
p.state = qwen3ParserStateCollectingThinking
|
||||||
|
p.maybeThinkingOpenAtBOL = true
|
||||||
|
} else {
|
||||||
|
p.state = qwen3ParserStateCollectingContent
|
||||||
|
p.maybeThinkingOpenAtBOL = false
|
||||||
|
}
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwen3Event interface {
|
||||||
|
isQwen3Event()
|
||||||
|
}
|
||||||
|
|
||||||
|
type qwen3EventContent struct {
|
||||||
|
content string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qwen3EventContent) isQwen3Event() {}
|
||||||
|
|
||||||
|
type qwen3EventRawToolCall struct {
|
||||||
|
raw string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qwen3EventRawToolCall) isQwen3Event() {}
|
||||||
|
|
||||||
|
type qwen3EventThinkingContent struct {
|
||||||
|
content string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qwen3EventThinkingContent) isQwen3Event() {}
|
||||||
|
|
||||||
|
func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||||
|
p.buffer.WriteString(s)
|
||||||
|
events := p.parseEvents()
|
||||||
|
|
||||||
|
var contentSb strings.Builder
|
||||||
|
var thinkingSb strings.Builder
|
||||||
|
for _, event := range events {
|
||||||
|
switch event := event.(type) {
|
||||||
|
case qwen3EventRawToolCall:
|
||||||
|
toolCall, err := parseQwen3ToolCall(event, p.tools)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("qwen3 tool call parsing failed", "error", err)
|
||||||
|
return "", "", nil, err
|
||||||
|
}
|
||||||
|
calls = append(calls, toolCall)
|
||||||
|
case qwen3EventThinkingContent:
|
||||||
|
thinkingSb.WriteString(event.content)
|
||||||
|
case qwen3EventContent:
|
||||||
|
contentSb.WriteString(event.content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return contentSb.String(), thinkingSb.String(), calls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen3Parser) parseEvents() []qwen3Event {
|
||||||
|
var all []qwen3Event
|
||||||
|
|
||||||
|
keepLooping := true
|
||||||
|
for keepLooping {
|
||||||
|
var events []qwen3Event
|
||||||
|
events, keepLooping = p.eat()
|
||||||
|
if len(events) > 0 {
|
||||||
|
all = append(all, events...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(all) > 0 {
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "qwen3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return all
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen3Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen3ParserState) ([]qwen3Event, bool) {
|
||||||
|
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||||
|
p.buffer.Reset()
|
||||||
|
if trimmed == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
p.state = nextState
|
||||||
|
p.buffer.WriteString(trimmed)
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen3Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
|
||||||
|
return splitAtTag(&p.buffer, tag, trimAfter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
|
||||||
|
var events []qwen3Event
|
||||||
|
|
||||||
|
switch p.state {
|
||||||
|
case qwen3ParserStateLookingForThinkingOpen:
|
||||||
|
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||||
|
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
|
||||||
|
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
|
||||||
|
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(after)
|
||||||
|
if after == "" {
|
||||||
|
p.state = qwen3ParserStateThinkingStartedEatingWhitespace
|
||||||
|
} else {
|
||||||
|
p.state = qwen3ParserStateCollectingThinking
|
||||||
|
}
|
||||||
|
return events, true
|
||||||
|
} else if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
|
||||||
|
return events, false
|
||||||
|
} else if trimmed == "" {
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
p.state = qwen3ParserStateCollectingContent
|
||||||
|
return events, true
|
||||||
|
|
||||||
|
case qwen3ParserStateThinkingStartedEatingWhitespace:
|
||||||
|
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingThinking)
|
||||||
|
|
||||||
|
case qwen3ParserStateCollectingThinking:
|
||||||
|
acc := p.buffer.String()
|
||||||
|
|
||||||
|
// Some qwen3 checkpoints emit an explicit opening <think> tag even
|
||||||
|
// though the prompt already ended with <think>. Strip exactly one
|
||||||
|
// leading opening tag if present.
|
||||||
|
if p.maybeThinkingOpenAtBOL {
|
||||||
|
trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
|
||||||
|
if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
|
||||||
|
after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
|
||||||
|
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(after)
|
||||||
|
if after == "" {
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
p.maybeThinkingOpenAtBOL = false
|
||||||
|
return events, true
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
p.maybeThinkingOpenAtBOL = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(acc, qwen3ThinkingCloseTag) {
|
||||||
|
thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
|
||||||
|
if len(thinking) > 0 {
|
||||||
|
events = append(events, qwen3EventThinkingContent{content: thinking})
|
||||||
|
}
|
||||||
|
if remaining == "" {
|
||||||
|
p.state = qwen3ParserStateThinkingDoneEatingWhitespace
|
||||||
|
} else {
|
||||||
|
p.state = qwen3ParserStateCollectingContent
|
||||||
|
}
|
||||||
|
return events, true
|
||||||
|
} else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
|
||||||
|
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||||
|
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||||
|
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||||
|
|
||||||
|
unambiguous := acc[:ambiguousStart]
|
||||||
|
ambiguous := acc[ambiguousStart:]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 {
|
||||||
|
events = append(events, qwen3EventThinkingContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
whitespaceLen := trailingWhitespaceLen(acc)
|
||||||
|
ambiguousStart := len(acc) - whitespaceLen
|
||||||
|
unambiguous := acc[:ambiguousStart]
|
||||||
|
ambiguous := acc[ambiguousStart:]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 {
|
||||||
|
events = append(events, qwen3EventThinkingContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
|
||||||
|
case qwen3ParserStateThinkingDoneEatingWhitespace:
|
||||||
|
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingContent)
|
||||||
|
|
||||||
|
case qwen3ParserStateCollectingContent:
|
||||||
|
acc := p.buffer.String()
|
||||||
|
if strings.Contains(acc, qwen3ToolOpenTag) {
|
||||||
|
before, after := p.splitAtTag(qwen3ToolOpenTag, true)
|
||||||
|
if len(before) > 0 {
|
||||||
|
events = append(events, qwen3EventContent{content: before})
|
||||||
|
}
|
||||||
|
if after == "" {
|
||||||
|
p.state = qwen3ParserStateToolStartedEatingWhitespace
|
||||||
|
} else {
|
||||||
|
p.state = qwen3ParserStateCollectingToolContent
|
||||||
|
}
|
||||||
|
return events, true
|
||||||
|
} else if overlapLen := overlap(acc, qwen3ToolOpenTag); overlapLen > 0 {
|
||||||
|
beforePartialTag := acc[:len(acc)-overlapLen]
|
||||||
|
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
|
||||||
|
ambiguousStart := len(beforePartialTag) - trailingWsLen
|
||||||
|
|
||||||
|
unambiguous := acc[:ambiguousStart]
|
||||||
|
ambiguous := acc[ambiguousStart:]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 {
|
||||||
|
events = append(events, qwen3EventContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
whitespaceLen := trailingWhitespaceLen(acc)
|
||||||
|
ambiguousStart := len(acc) - whitespaceLen
|
||||||
|
unambiguous := acc[:ambiguousStart]
|
||||||
|
ambiguous := acc[ambiguousStart:]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 {
|
||||||
|
events = append(events, qwen3EventContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
|
||||||
|
case qwen3ParserStateToolStartedEatingWhitespace:
|
||||||
|
return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingToolContent)
|
||||||
|
|
||||||
|
case qwen3ParserStateCollectingToolContent:
|
||||||
|
acc := p.buffer.String()
|
||||||
|
if strings.Contains(acc, qwen3ToolCloseTag) {
|
||||||
|
toolContent, _ := p.splitAtTag(qwen3ToolCloseTag, true)
|
||||||
|
if len(toolContent) == 0 {
|
||||||
|
slog.Warn("qwen3 tool call closing tag found but no content before it")
|
||||||
|
}
|
||||||
|
events = append(events, qwen3EventRawToolCall{raw: toolContent})
|
||||||
|
p.state = qwen3ParserStateCollectingContent
|
||||||
|
return events, true
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
|
||||||
|
default:
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||||
|
var parsed struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments map[string]any `json:"arguments"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil {
|
||||||
|
return api.ToolCall{}, fmt.Errorf("failed to parse JSON: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsed.Name == "" {
|
||||||
|
return api.ToolCall{}, fmt.Errorf("empty function name")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = tools // qwen3 uses direct JSON args and does not require schema coercion here.
|
||||||
|
|
||||||
|
toolCall := api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: parsed.Name,
|
||||||
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range parsed.Arguments {
|
||||||
|
toolCall.Function.Arguments.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return toolCall, nil
|
||||||
|
}
|
||||||
147
model/parsers/qwen3_test.go
Normal file
147
model/parsers/qwen3_test.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQwen3ParserThinkingEnabled(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
|
||||||
|
content, thinking, calls, err := parser.Add("Let me think...</think>Answer.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "Let me think..." {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||||
|
}
|
||||||
|
if content != "Answer." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserThinkingEnabledWithExplicitOpeningTag(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
|
||||||
|
content, thinking, calls, err := parser.Add("<think>\nLet me think...</think>Answer.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "Let me think..." {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||||
|
}
|
||||||
|
if content != "Answer." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserThinkingEnabledWithSplitOpeningTag(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: true})
|
||||||
|
|
||||||
|
content, thinking, calls, err := parser.Add("<thi", false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on first chunk: %v", err)
|
||||||
|
}
|
||||||
|
if content != "" || thinking != "" || len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no output for first chunk, got content=%q thinking=%q calls=%d", content, thinking, len(calls))
|
||||||
|
}
|
||||||
|
|
||||||
|
content, thinking, calls, err = parser.Add("nk>Let me think...</think>Answer.", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed on second chunk: %v", err)
|
||||||
|
}
|
||||||
|
if thinking != "Let me think..." {
|
||||||
|
t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
|
||||||
|
}
|
||||||
|
if content != "Answer." {
|
||||||
|
t.Fatalf("expected content %q, got %q", "Answer.", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserThinkingDisabled(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
content, thinking, calls, err := parser.Add("Direct answer", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected no thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "Direct answer" {
|
||||||
|
t.Fatalf("expected content %q, got %q", "Direct answer", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserNilThinkDefaultsToContentForInstructParser(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
content, thinking, calls, err := parser.Add("Direct answer", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected no thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if content != "Direct answer" {
|
||||||
|
t.Fatalf("expected content %q, got %q", "Direct answer", content)
|
||||||
|
}
|
||||||
|
if len(calls) != 0 {
|
||||||
|
t.Fatalf("expected no tool calls, got %d", len(calls))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQwen3ParserToolCall(t *testing.T) {
|
||||||
|
parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
|
||||||
|
parser.Init(nil, nil, &api.ThinkValue{Value: false})
|
||||||
|
|
||||||
|
input := "<tool_call>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}</tool_call>"
|
||||||
|
content, thinking, calls, err := parser.Add(input, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if content != "" {
|
||||||
|
t.Fatalf("expected empty content, got %q", content)
|
||||||
|
}
|
||||||
|
if thinking != "" {
|
||||||
|
t.Fatalf("expected empty thinking, got %q", thinking)
|
||||||
|
}
|
||||||
|
if len(calls) != 1 {
|
||||||
|
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||||
|
}
|
||||||
|
if calls[0].Function.Name != "get_weather" {
|
||||||
|
t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
location, ok := calls[0].Function.Arguments.Get("location")
|
||||||
|
if !ok || location != "San Francisco" {
|
||||||
|
t.Fatalf("expected location %q, got %v", "San Francisco", location)
|
||||||
|
}
|
||||||
|
unit, ok := calls[0].Function.Arguments.Get("unit")
|
||||||
|
if !ok || unit != "celsius" {
|
||||||
|
t.Fatalf("expected unit %q, got %v", "celsius", unit)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -91,6 +91,8 @@ type Server struct {
|
|||||||
aliasesOnce sync.Once
|
aliasesOnce sync.Once
|
||||||
aliases *store
|
aliases *store
|
||||||
aliasesErr error
|
aliasesErr error
|
||||||
|
lowVRAM bool
|
||||||
|
usage *UsageTracker
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -289,6 +291,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
c.Header("Content-Type", contentType)
|
c.Header("Content-Type", contentType)
|
||||||
|
|
||||||
fn := func(resp api.GenerateResponse) error {
|
fn := func(resp api.GenerateResponse) error {
|
||||||
|
if resp.Done {
|
||||||
|
s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
|
||||||
|
}
|
||||||
|
|
||||||
resp.Model = origModel
|
resp.Model = origModel
|
||||||
resp.RemoteModel = m.Config.RemoteModel
|
resp.RemoteModel = m.Config.RemoteModel
|
||||||
resp.RemoteHost = m.Config.RemoteHost
|
resp.RemoteHost = m.Config.RemoteHost
|
||||||
@@ -595,6 +601,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
res.Context = tokens
|
res.Context = tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.usage.Record(req.Model, cr.PromptEvalCount, cr.EvalCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
if builtinParser != nil {
|
if builtinParser != nil {
|
||||||
@@ -1622,6 +1630,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||||||
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
|
r.POST("/api/experimental/aliases", s.CreateAliasHandler)
|
||||||
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
|
r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
|
||||||
|
|
||||||
|
r.GET("/api/usage", s.UsageHandler)
|
||||||
|
|
||||||
// Inference
|
// Inference
|
||||||
r.GET("/api/ps", s.PsHandler)
|
r.GET("/api/ps", s.PsHandler)
|
||||||
r.POST("/api/generate", s.GenerateHandler)
|
r.POST("/api/generate", s.GenerateHandler)
|
||||||
@@ -1692,7 +1702,7 @@ func Serve(ln net.Listener) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &Server{addr: ln.Addr()}
|
s := &Server{addr: ln.Addr(), usage: NewUsageTracker()}
|
||||||
|
|
||||||
var rc *ollama.Registry
|
var rc *ollama.Registry
|
||||||
if useClient2 {
|
if useClient2 {
|
||||||
@@ -1927,6 +1937,10 @@ func (s *Server) SignoutHandler(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, nil)
|
c.JSON(http.StatusOK, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) UsageHandler(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, s.usage.Stats())
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) PsHandler(c *gin.Context) {
|
func (s *Server) PsHandler(c *gin.Context) {
|
||||||
models := []api.ProcessModelResponse{}
|
models := []api.ProcessModelResponse{}
|
||||||
|
|
||||||
@@ -2097,6 +2111,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
c.Header("Content-Type", contentType)
|
c.Header("Content-Type", contentType)
|
||||||
|
|
||||||
fn := func(resp api.ChatResponse) error {
|
fn := func(resp api.ChatResponse) error {
|
||||||
|
if resp.Done {
|
||||||
|
s.usage.Record(origModel, resp.PromptEvalCount, resp.EvalCount)
|
||||||
|
}
|
||||||
|
|
||||||
resp.Model = origModel
|
resp.Model = origModel
|
||||||
resp.RemoteModel = m.Config.RemoteModel
|
resp.RemoteModel = m.Config.RemoteModel
|
||||||
resp.RemoteHost = m.Config.RemoteHost
|
resp.RemoteHost = m.Config.RemoteHost
|
||||||
@@ -2317,6 +2335,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
res.DoneReason = r.DoneReason.String()
|
res.DoneReason = r.DoneReason.String()
|
||||||
res.TotalDuration = time.Since(checkpointStart)
|
res.TotalDuration = time.Since(checkpointStart)
|
||||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
|
||||||
|
s.usage.Record(req.Model, r.PromptEvalCount, r.EvalCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
if builtinParser != nil {
|
if builtinParser != nil {
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := Server{
|
s := Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
@@ -224,6 +225,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := Server{
|
s := Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := Server{
|
s := Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
@@ -220,6 +221,7 @@ func TestGenerateWithDebugRenderOnly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := Server{
|
s := Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
|||||||
@@ -88,20 +88,40 @@ func TestGenerateChatRemote(t *testing.T) {
|
|||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
t.Errorf("Expected POST request, got %s", r.Method)
|
t.Errorf("Expected POST request, got %s", r.Method)
|
||||||
}
|
}
|
||||||
if r.URL.Path != "/api/chat" {
|
|
||||||
t.Errorf("Expected path '/api/chat', got %s", r.URL.Path)
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/chat":
|
||||||
resp := api.ChatResponse{
|
resp := api.ChatResponse{
|
||||||
Model: "test",
|
Model: "test",
|
||||||
Done: true,
|
Done: true,
|
||||||
DoneReason: "load",
|
DoneReason: "load",
|
||||||
|
Metrics: api.Metrics{
|
||||||
|
PromptEvalCount: 10,
|
||||||
|
EvalCount: 20,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
case "/api/generate":
|
||||||
|
resp := api.GenerateResponse{
|
||||||
|
Model: "test",
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "stop",
|
||||||
|
Metrics: api.Metrics{
|
||||||
|
PromptEvalCount: 5,
|
||||||
|
EvalCount: 15,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected path %s", r.URL.Path)
|
||||||
|
}
|
||||||
}))
|
}))
|
||||||
defer rs.Close()
|
defer rs.Close()
|
||||||
|
|
||||||
@@ -111,7 +131,7 @@ func TestGenerateChatRemote(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Setenv("OLLAMA_REMOTES", p.Hostname())
|
t.Setenv("OLLAMA_REMOTES", p.Hostname())
|
||||||
s := Server{}
|
s := Server{usage: NewUsageTracker()}
|
||||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
Model: "test-cloud",
|
Model: "test-cloud",
|
||||||
RemoteHost: rs.URL,
|
RemoteHost: rs.URL,
|
||||||
@@ -159,6 +179,61 @@ func TestGenerateChatRemote(t *testing.T) {
|
|||||||
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("remote chat usage tracking", func(t *testing.T) {
|
||||||
|
stats := s.usage.Stats()
|
||||||
|
found := false
|
||||||
|
for _, m := range stats.Usage {
|
||||||
|
if m.Model == "test-cloud" {
|
||||||
|
found = true
|
||||||
|
if m.Requests != 1 {
|
||||||
|
t.Errorf("expected 1 request, got %d", m.Requests)
|
||||||
|
}
|
||||||
|
if m.PromptTokens != 10 {
|
||||||
|
t.Errorf("expected 10 prompt tokens, got %d", m.PromptTokens)
|
||||||
|
}
|
||||||
|
if m.CompletionTokens != 20 {
|
||||||
|
t.Errorf("expected 20 completion tokens, got %d", m.CompletionTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("expected usage entry for test-cloud")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remote generate usage tracking", func(t *testing.T) {
|
||||||
|
// Reset the tracker for a clean test
|
||||||
|
s.usage = NewUsageTracker()
|
||||||
|
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-cloud",
|
||||||
|
Prompt: "hello",
|
||||||
|
})
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := s.usage.Stats()
|
||||||
|
found := false
|
||||||
|
for _, m := range stats.Usage {
|
||||||
|
if m.Model == "test-cloud" {
|
||||||
|
found = true
|
||||||
|
if m.Requests != 1 {
|
||||||
|
t.Errorf("expected 1 request, got %d", m.Requests)
|
||||||
|
}
|
||||||
|
if m.PromptTokens != 5 {
|
||||||
|
t.Errorf("expected 5 prompt tokens, got %d", m.PromptTokens)
|
||||||
|
}
|
||||||
|
if m.CompletionTokens != 15 {
|
||||||
|
t.Errorf("expected 15 completion tokens, got %d", m.CompletionTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("expected usage entry for test-cloud")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGenerateChat(t *testing.T) {
|
func TestGenerateChat(t *testing.T) {
|
||||||
@@ -177,6 +252,7 @@ func TestGenerateChat(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := Server{
|
s := Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
@@ -894,6 +970,7 @@ func TestGenerate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := Server{
|
s := Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
@@ -1378,6 +1455,7 @@ func TestGenerateLogprobs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
@@ -1558,6 +1636,7 @@ func TestChatLogprobs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
@@ -1668,6 +1747,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
@@ -2114,6 +2194,7 @@ func TestGenerateUnload(t *testing.T) {
|
|||||||
var loadFnCalled bool
|
var loadFnCalled bool
|
||||||
|
|
||||||
s := Server{
|
s := Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
@@ -2215,6 +2296,7 @@ func TestGenerateWithImages(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := Server{
|
s := Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
@@ -2371,30 +2453,6 @@ func TestImageGenerateStreamFalse(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
opts := api.DefaultOptions()
|
|
||||||
s := Server{
|
|
||||||
sched: &Scheduler{
|
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
|
||||||
expiredCh: make(chan *runnerRef, 1),
|
|
||||||
unloadedCh: make(chan any, 1),
|
|
||||||
loaded: map[string]*runnerRef{
|
|
||||||
"": {
|
|
||||||
llama: &mock,
|
|
||||||
Options: &opts,
|
|
||||||
model: &Model{Config: model.ConfigV2{Capabilities: []string{"image"}}},
|
|
||||||
isImagegen: true,
|
|
||||||
numParallel: 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
newServerFn: newMockServer(&mock),
|
|
||||||
getGpuFn: getGpuFn,
|
|
||||||
getSystemInfoFn: getSystemInfoFn,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
go s.sched.Run(t.Context())
|
|
||||||
|
|
||||||
// Create model manifest with image capability
|
// Create model manifest with image capability
|
||||||
n := model.ParseName("test-image")
|
n := model.ParseName("test-image")
|
||||||
cfg := model.ConfigV2{Capabilities: []string{"image"}}
|
cfg := model.ConfigV2{Capabilities: []string{"image"}}
|
||||||
@@ -2410,6 +2468,36 @@ func TestImageGenerateStreamFalse(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
loadedModel, err := GetModel("test-image")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
s := Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
|
sched: &Scheduler{
|
||||||
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
expiredCh: make(chan *runnerRef, 1),
|
||||||
|
unloadedCh: make(chan any, 1),
|
||||||
|
loaded: map[string]*runnerRef{
|
||||||
|
schedulerModelKey(loadedModel): {
|
||||||
|
llama: &mock,
|
||||||
|
Options: &opts,
|
||||||
|
model: loadedModel,
|
||||||
|
isImagegen: true,
|
||||||
|
numParallel: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
newServerFn: newMockServer(&mock),
|
||||||
|
getGpuFn: getGpuFn,
|
||||||
|
getSystemInfoFn: getSystemInfoFn,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.sched.Run(t.Context())
|
||||||
|
|
||||||
streamFalse := false
|
streamFalse := false
|
||||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
Model: "test-image",
|
Model: "test-image",
|
||||||
|
|||||||
@@ -255,6 +255,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := Server{
|
s := Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
@@ -406,6 +407,7 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := Server{
|
s := Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
@@ -588,6 +590,7 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := Server{
|
s := Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
sched: &Scheduler{
|
sched: &Scheduler{
|
||||||
pendingReqCh: make(chan *LlmRequest, 1),
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
finishedReqCh: make(chan *LlmRequest, 1),
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
|||||||
@@ -83,6 +83,28 @@ func InitScheduler(ctx context.Context) *Scheduler {
|
|||||||
return sched
|
return sched
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// schedulerModelKey returns the scheduler map key for a model.
|
||||||
|
// GGUF-backed models use ModelPath; safetensors/image models without a
|
||||||
|
// ModelPath use manifest digest so distinct models don't collide.
|
||||||
|
func schedulerModelKey(m *Model) string {
|
||||||
|
if m == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if m.ModelPath != "" {
|
||||||
|
return m.ModelPath
|
||||||
|
}
|
||||||
|
if m.Digest != "" {
|
||||||
|
return "digest:" + m.Digest
|
||||||
|
}
|
||||||
|
if m.Name != "" {
|
||||||
|
return "name:" + m.Name
|
||||||
|
}
|
||||||
|
if m.ShortName != "" {
|
||||||
|
return "short:" + m.ShortName
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// context must be canceled to decrement ref count and release the runner
|
// context must be canceled to decrement ref count and release the runner
|
||||||
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
|
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
|
||||||
if opts.NumCtx < 4 {
|
if opts.NumCtx < 4 {
|
||||||
@@ -104,8 +126,9 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
|
|||||||
useImagegen: useImagegen,
|
useImagegen: useImagegen,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
key := schedulerModelKey(req.model)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
runner := s.loaded[req.model.ModelPath]
|
runner := s.loaded[key]
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
if runner != nil && !runner.needsReload(c, req) {
|
if runner != nil && !runner.needsReload(c, req) {
|
||||||
req.useLoadedRunner(runner, s.finishedReqCh)
|
req.useLoadedRunner(runner, s.finishedReqCh)
|
||||||
@@ -151,8 +174,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
var runnerToExpire *runnerRef
|
var runnerToExpire *runnerRef
|
||||||
|
pendingKey := schedulerModelKey(pending.model)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
runner := s.loaded[pending.model.ModelPath]
|
runner := s.loaded[pendingKey]
|
||||||
loadedCount := len(s.loaded)
|
loadedCount := len(s.loaded)
|
||||||
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
|
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
|
||||||
for _, r := range s.loaded {
|
for _, r := range s.loaded {
|
||||||
@@ -166,7 +190,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||||||
runnerToExpire = runner
|
runnerToExpire = runner
|
||||||
} else {
|
} else {
|
||||||
// Runner is usable, return it
|
// Runner is usable, return it
|
||||||
logutil.Trace("using existing loaded runner", "model", pending.model.ModelPath)
|
logutil.Trace("using existing loaded runner", "model", pendingKey)
|
||||||
pending.useLoadedRunner(runner, s.finishedReqCh)
|
pending.useLoadedRunner(runner, s.finishedReqCh)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -292,11 +316,12 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
|||||||
slog.Debug("shutting down scheduler completed loop")
|
slog.Debug("shutting down scheduler completed loop")
|
||||||
return
|
return
|
||||||
case finished := <-s.finishedReqCh:
|
case finished := <-s.finishedReqCh:
|
||||||
|
finishedKey := schedulerModelKey(finished.model)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
runner := s.loaded[finished.model.ModelPath]
|
runner := s.loaded[finishedKey]
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
if runner == nil {
|
if runner == nil {
|
||||||
slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath)
|
slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
runner.refMu.Lock()
|
runner.refMu.Lock()
|
||||||
@@ -347,7 +372,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
|||||||
|
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
slog.Debug("got lock to unload expired event", "runner", runner)
|
slog.Debug("got lock to unload expired event", "runner", runner)
|
||||||
runnerToUnload := s.loaded[runner.modelPath]
|
runnerToUnload := s.loaded[runner.modelKey]
|
||||||
if runnerToUnload == nil {
|
if runnerToUnload == nil {
|
||||||
// If runnerToUnload is nil, we already processed an event and
|
// If runnerToUnload is nil, we already processed an event and
|
||||||
// unloaded it. This double unload can happen if the initial
|
// unloaded it. This double unload can happen if the initial
|
||||||
@@ -376,7 +401,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
|
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
|
||||||
runner.unload()
|
runner.unload()
|
||||||
delete(s.loaded, runner.modelPath)
|
delete(s.loaded, runner.modelKey)
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
|
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
|
||||||
<-finished
|
<-finished
|
||||||
@@ -514,6 +539,7 @@ iGPUScan:
|
|||||||
runner := &runnerRef{
|
runner := &runnerRef{
|
||||||
model: req.model,
|
model: req.model,
|
||||||
modelPath: req.model.ModelPath,
|
modelPath: req.model.ModelPath,
|
||||||
|
modelKey: schedulerModelKey(req.model),
|
||||||
llama: llama,
|
llama: llama,
|
||||||
Options: &req.opts,
|
Options: &req.opts,
|
||||||
sessionDuration: sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
@@ -528,7 +554,7 @@ iGPUScan:
|
|||||||
runner.refMu.Lock() // hold lock until running or aborted
|
runner.refMu.Lock() // hold lock until running or aborted
|
||||||
|
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
if oldRunner, ok := s.loaded[req.model.ModelPath]; ok {
|
if oldRunner, ok := s.loaded[runner.modelKey]; ok {
|
||||||
// Shouldn't happen, but safeguard against leaking a runner
|
// Shouldn't happen, but safeguard against leaking a runner
|
||||||
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
|
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
|
||||||
oldRunner.refMu.Lock()
|
oldRunner.refMu.Lock()
|
||||||
@@ -536,7 +562,7 @@ iGPUScan:
|
|||||||
oldRunner.refMu.Unlock()
|
oldRunner.refMu.Unlock()
|
||||||
}
|
}
|
||||||
s.activeLoading = nil
|
s.activeLoading = nil
|
||||||
s.loaded[req.model.ModelPath] = runner
|
s.loaded[runner.modelKey] = runner
|
||||||
slog.Info("loaded runners", "count", len(s.loaded))
|
slog.Info("loaded runners", "count", len(s.loaded))
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
@@ -596,6 +622,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
|||||||
runner := &runnerRef{
|
runner := &runnerRef{
|
||||||
model: req.model,
|
model: req.model,
|
||||||
modelPath: req.model.ModelPath,
|
modelPath: req.model.ModelPath,
|
||||||
|
modelKey: schedulerModelKey(req.model),
|
||||||
llama: server,
|
llama: server,
|
||||||
Options: &req.opts,
|
Options: &req.opts,
|
||||||
loading: false,
|
loading: false,
|
||||||
@@ -606,7 +633,7 @@ func (s *Scheduler) loadMLX(req *LlmRequest) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
s.loaded[req.model.ModelPath] = runner
|
s.loaded[runner.modelKey] = runner
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
// Set up expiration timer
|
// Set up expiration timer
|
||||||
@@ -684,6 +711,7 @@ type runnerRef struct {
|
|||||||
|
|
||||||
model *Model
|
model *Model
|
||||||
modelPath string
|
modelPath string
|
||||||
|
modelKey string
|
||||||
numParallel int
|
numParallel int
|
||||||
*api.Options
|
*api.Options
|
||||||
}
|
}
|
||||||
@@ -703,7 +731,7 @@ func (runner *runnerRef) unload() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
|
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
|
||||||
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
|
slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
|
||||||
runner.refMu.Lock()
|
runner.refMu.Lock()
|
||||||
defer runner.refMu.Unlock()
|
defer runner.refMu.Unlock()
|
||||||
|
|
||||||
@@ -814,6 +842,10 @@ func (runner *runnerRef) LogValue() slog.Value {
|
|||||||
if runner == nil {
|
if runner == nil {
|
||||||
return slog.StringValue("nil")
|
return slog.StringValue("nil")
|
||||||
}
|
}
|
||||||
|
modelID := runner.modelPath
|
||||||
|
if modelID == "" {
|
||||||
|
modelID = runner.modelKey
|
||||||
|
}
|
||||||
attrs := []slog.Attr{}
|
attrs := []slog.Attr{}
|
||||||
if runner.model != nil {
|
if runner.model != nil {
|
||||||
attrs = append(attrs, slog.String("name", runner.model.Name))
|
attrs = append(attrs, slog.String("name", runner.model.Name))
|
||||||
@@ -828,7 +860,7 @@ func (runner *runnerRef) LogValue() slog.Value {
|
|||||||
slog.String("vram", format.HumanBytes2(runner.vramSize)),
|
slog.String("vram", format.HumanBytes2(runner.vramSize)),
|
||||||
slog.Int("parallel", runner.numParallel),
|
slog.Int("parallel", runner.numParallel),
|
||||||
slog.Int("pid", runner.pid),
|
slog.Int("pid", runner.pid),
|
||||||
slog.String("model", runner.modelPath),
|
slog.String("model", modelID),
|
||||||
)
|
)
|
||||||
if runner.Options != nil {
|
if runner.Options != nil {
|
||||||
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
|
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
|
||||||
@@ -873,8 +905,16 @@ func (a ByDurationAndName) Less(i, j int) bool {
|
|||||||
if d1 != d2 {
|
if d1 != d2 {
|
||||||
return d1 < d2
|
return d1 < d2
|
||||||
}
|
}
|
||||||
// Secondary sort by model path lex order
|
// Secondary sort by model key/path lex order
|
||||||
return a[i].modelPath < a[j].modelPath
|
n1 := a[i].modelPath
|
||||||
|
if n1 == "" {
|
||||||
|
n1 = a[i].modelKey
|
||||||
|
}
|
||||||
|
n2 := a[j].modelPath
|
||||||
|
if n2 == "" {
|
||||||
|
n2 = a[j].modelKey
|
||||||
|
}
|
||||||
|
return n1 < n2
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - future consideration to pick runners based on size
|
// TODO - future consideration to pick runners based on size
|
||||||
@@ -934,8 +974,9 @@ func (s *Scheduler) unloadAllRunners() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Scheduler) expireRunner(model *Model) {
|
func (s *Scheduler) expireRunner(model *Model) {
|
||||||
|
modelKey := schedulerModelKey(model)
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
runner, ok := s.loaded[model.ModelPath]
|
runner, ok := s.loaded[modelKey]
|
||||||
s.loadedMu.Unlock()
|
s.loadedMu.Unlock()
|
||||||
if ok {
|
if ok {
|
||||||
runner.refMu.Lock()
|
runner.refMu.Lock()
|
||||||
|
|||||||
@@ -448,6 +448,71 @@ func TestSchedGetRunner(t *testing.T) {
|
|||||||
b.ctxDone()
|
b.ctxDone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
opts.NumCtx = 4
|
||||||
|
|
||||||
|
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
|
||||||
|
loadedRunner := &runnerRef{
|
||||||
|
model: loadedModel,
|
||||||
|
modelKey: schedulerModelKey(loadedModel),
|
||||||
|
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
|
||||||
|
Options: &opts,
|
||||||
|
numParallel: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.loadedMu.Lock()
|
||||||
|
s.loaded[loadedRunner.modelKey] = loadedRunner
|
||||||
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
|
reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
|
||||||
|
successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
|
||||||
|
|
||||||
|
require.Empty(t, successCh)
|
||||||
|
require.Empty(t, errCh)
|
||||||
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
opts.NumCtx = 4
|
||||||
|
|
||||||
|
loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
|
||||||
|
loadedRunner := &runnerRef{
|
||||||
|
model: loadedModel,
|
||||||
|
modelKey: schedulerModelKey(loadedModel),
|
||||||
|
llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
|
||||||
|
Options: &opts,
|
||||||
|
numParallel: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.loadedMu.Lock()
|
||||||
|
s.loaded[loadedRunner.modelKey] = loadedRunner
|
||||||
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
|
reqCtx, cancelReq := context.WithCancel(ctx)
|
||||||
|
successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
|
||||||
|
cancelReq()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case runner := <-successCh:
|
||||||
|
require.Equal(t, loadedRunner, runner)
|
||||||
|
default:
|
||||||
|
t.Fatal("expected existing runner to be reused")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Empty(t, errCh)
|
||||||
|
require.Empty(t, s.pendingReqCh)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSchedExpireRunner(t *testing.T) {
|
func TestSchedExpireRunner(t *testing.T) {
|
||||||
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
|
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
|
||||||
defer done()
|
defer done()
|
||||||
|
|||||||
62
server/usage.go
Normal file
62
server/usage.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ModelUsage struct {
|
||||||
|
Requests int64
|
||||||
|
PromptTokens int64
|
||||||
|
CompletionTokens int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type UsageTracker struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
start time.Time
|
||||||
|
models map[string]*ModelUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUsageTracker() *UsageTracker {
|
||||||
|
return &UsageTracker{
|
||||||
|
start: time.Now().UTC(),
|
||||||
|
models: make(map[string]*ModelUsage),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *UsageTracker) Record(model string, promptTokens, completionTokens int) {
|
||||||
|
u.mu.Lock()
|
||||||
|
defer u.mu.Unlock()
|
||||||
|
|
||||||
|
m, ok := u.models[model]
|
||||||
|
if !ok {
|
||||||
|
m = &ModelUsage{}
|
||||||
|
u.models[model] = m
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Requests++
|
||||||
|
m.PromptTokens += int64(promptTokens)
|
||||||
|
m.CompletionTokens += int64(completionTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *UsageTracker) Stats() api.UsageResponse {
|
||||||
|
u.mu.Lock()
|
||||||
|
defer u.mu.Unlock()
|
||||||
|
|
||||||
|
byModel := make([]api.ModelUsageData, 0, len(u.models))
|
||||||
|
for model, usage := range u.models {
|
||||||
|
byModel = append(byModel, api.ModelUsageData{
|
||||||
|
Model: model,
|
||||||
|
Requests: usage.Requests,
|
||||||
|
PromptTokens: usage.PromptTokens,
|
||||||
|
CompletionTokens: usage.CompletionTokens,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.UsageResponse{
|
||||||
|
Start: u.start,
|
||||||
|
Usage: byModel,
|
||||||
|
}
|
||||||
|
}
|
||||||
136
server/usage_test.go
Normal file
136
server/usage_test.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUsageTrackerRecord(t *testing.T) {
|
||||||
|
tracker := NewUsageTracker()
|
||||||
|
|
||||||
|
tracker.Record("model-a", 10, 20)
|
||||||
|
tracker.Record("model-a", 5, 15)
|
||||||
|
tracker.Record("model-b", 100, 200)
|
||||||
|
|
||||||
|
stats := tracker.Stats()
|
||||||
|
|
||||||
|
if len(stats.Usage) != 2 {
|
||||||
|
t.Fatalf("expected 2 models, got %d", len(stats.Usage))
|
||||||
|
}
|
||||||
|
|
||||||
|
lookup := make(map[string]api.ModelUsageData)
|
||||||
|
for _, m := range stats.Usage {
|
||||||
|
lookup[m.Model] = m
|
||||||
|
}
|
||||||
|
|
||||||
|
a := lookup["model-a"]
|
||||||
|
if a.Requests != 2 {
|
||||||
|
t.Errorf("model-a requests: expected 2, got %d", a.Requests)
|
||||||
|
}
|
||||||
|
if a.PromptTokens != 15 {
|
||||||
|
t.Errorf("model-a prompt tokens: expected 15, got %d", a.PromptTokens)
|
||||||
|
}
|
||||||
|
if a.CompletionTokens != 35 {
|
||||||
|
t.Errorf("model-a completion tokens: expected 35, got %d", a.CompletionTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
b := lookup["model-b"]
|
||||||
|
if b.Requests != 1 {
|
||||||
|
t.Errorf("model-b requests: expected 1, got %d", b.Requests)
|
||||||
|
}
|
||||||
|
if b.PromptTokens != 100 {
|
||||||
|
t.Errorf("model-b prompt tokens: expected 100, got %d", b.PromptTokens)
|
||||||
|
}
|
||||||
|
if b.CompletionTokens != 200 {
|
||||||
|
t.Errorf("model-b completion tokens: expected 200, got %d", b.CompletionTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageTrackerConcurrent(t *testing.T) {
|
||||||
|
tracker := NewUsageTracker()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for range 100 {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
tracker.Record("model-a", 1, 2)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
stats := tracker.Stats()
|
||||||
|
if len(stats.Usage) != 1 {
|
||||||
|
t.Fatalf("expected 1 model, got %d", len(stats.Usage))
|
||||||
|
}
|
||||||
|
|
||||||
|
m := stats.Usage[0]
|
||||||
|
if m.Requests != 100 {
|
||||||
|
t.Errorf("requests: expected 100, got %d", m.Requests)
|
||||||
|
}
|
||||||
|
if m.PromptTokens != 100 {
|
||||||
|
t.Errorf("prompt tokens: expected 100, got %d", m.PromptTokens)
|
||||||
|
}
|
||||||
|
if m.CompletionTokens != 200 {
|
||||||
|
t.Errorf("completion tokens: expected 200, got %d", m.CompletionTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageTrackerStart(t *testing.T) {
|
||||||
|
tracker := NewUsageTracker()
|
||||||
|
|
||||||
|
stats := tracker.Stats()
|
||||||
|
if stats.Start.IsZero() {
|
||||||
|
t.Error("expected non-zero start time")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageHandler(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
s := &Server{
|
||||||
|
usage: NewUsageTracker(),
|
||||||
|
}
|
||||||
|
|
||||||
|
s.usage.Record("llama3", 50, 100)
|
||||||
|
s.usage.Record("llama3", 25, 50)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/usage", nil)
|
||||||
|
|
||||||
|
s.UsageHandler(c)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp api.UsageResponse
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.Usage) != 1 {
|
||||||
|
t.Fatalf("expected 1 model, got %d", len(resp.Usage))
|
||||||
|
}
|
||||||
|
|
||||||
|
m := resp.Usage[0]
|
||||||
|
if m.Model != "llama3" {
|
||||||
|
t.Errorf("expected model llama3, got %s", m.Model)
|
||||||
|
}
|
||||||
|
if m.Requests != 2 {
|
||||||
|
t.Errorf("expected 2 requests, got %d", m.Requests)
|
||||||
|
}
|
||||||
|
if m.PromptTokens != 75 {
|
||||||
|
t.Errorf("expected 75 prompt tokens, got %d", m.PromptTokens)
|
||||||
|
}
|
||||||
|
if m.CompletionTokens != 150 {
|
||||||
|
t.Errorf("expected 150 completion tokens, got %d", m.CompletionTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -30,6 +30,8 @@ type ModelfileConfig struct {
|
|||||||
Template string
|
Template string
|
||||||
System string
|
System string
|
||||||
License string
|
License string
|
||||||
|
Parser string
|
||||||
|
Renderer string
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateOptions holds all options for model creation.
|
// CreateOptions holds all options for model creation.
|
||||||
@@ -37,7 +39,7 @@ type CreateOptions struct {
|
|||||||
ModelName string
|
ModelName string
|
||||||
ModelDir string
|
ModelDir string
|
||||||
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
|
Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
|
||||||
Modelfile *ModelfileConfig // template/system/license from Modelfile
|
Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateModel imports a model from a local directory.
|
// CreateModel imports a model from a local directory.
|
||||||
@@ -267,8 +269,8 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
|
|||||||
ModelFormat: "safetensors",
|
ModelFormat: "safetensors",
|
||||||
Capabilities: caps,
|
Capabilities: caps,
|
||||||
Requires: MinOllamaVersion,
|
Requires: MinOllamaVersion,
|
||||||
Parser: parserName,
|
Parser: resolveParserName(opts.Modelfile, parserName),
|
||||||
Renderer: rendererName,
|
Renderer: resolveRendererName(opts.Modelfile, rendererName),
|
||||||
}
|
}
|
||||||
configJSON, err := json.Marshal(configData)
|
configJSON, err := json.Marshal(configData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -305,6 +307,22 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resolveParserName(mf *ModelfileConfig, inferred string) string {
|
||||||
|
if mf != nil && mf.Parser != "" {
|
||||||
|
return mf.Parser
|
||||||
|
}
|
||||||
|
|
||||||
|
return inferred
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveRendererName(mf *ModelfileConfig, inferred string) string {
|
||||||
|
if mf != nil && mf.Renderer != "" {
|
||||||
|
return mf.Renderer
|
||||||
|
}
|
||||||
|
|
||||||
|
return inferred
|
||||||
|
}
|
||||||
|
|
||||||
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
|
||||||
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
|
||||||
var layers []manifest.Layer
|
var layers []manifest.Layer
|
||||||
@@ -410,7 +428,7 @@ func getParserName(modelDir string) string {
|
|||||||
return "deepseek3"
|
return "deepseek3"
|
||||||
}
|
}
|
||||||
if strings.Contains(archLower, "qwen3") {
|
if strings.Contains(archLower, "qwen3") {
|
||||||
return "qwen3-coder"
|
return "qwen3"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -424,7 +442,7 @@ func getParserName(modelDir string) string {
|
|||||||
return "deepseek3"
|
return "deepseek3"
|
||||||
}
|
}
|
||||||
if strings.Contains(typeLower, "qwen3") {
|
if strings.Contains(typeLower, "qwen3") {
|
||||||
return "qwen3-coder"
|
return "qwen3"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ func TestModelfileConfig(t *testing.T) {
|
|||||||
Template: "{{ .Prompt }}",
|
Template: "{{ .Prompt }}",
|
||||||
System: "You are a helpful assistant.",
|
System: "You are a helpful assistant.",
|
||||||
License: "MIT",
|
License: "MIT",
|
||||||
|
Parser: "qwen3",
|
||||||
|
Renderer: "qwen3",
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Template != "{{ .Prompt }}" {
|
if config.Template != "{{ .Prompt }}" {
|
||||||
@@ -21,6 +23,12 @@ func TestModelfileConfig(t *testing.T) {
|
|||||||
if config.License != "MIT" {
|
if config.License != "MIT" {
|
||||||
t.Errorf("License = %q, want %q", config.License, "MIT")
|
t.Errorf("License = %q, want %q", config.License, "MIT")
|
||||||
}
|
}
|
||||||
|
if config.Parser != "qwen3" {
|
||||||
|
t.Errorf("Parser = %q, want %q", config.Parser, "qwen3")
|
||||||
|
}
|
||||||
|
if config.Renderer != "qwen3" {
|
||||||
|
t.Errorf("Renderer = %q, want %q", config.Renderer, "qwen3")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModelfileConfig_Empty(t *testing.T) {
|
func TestModelfileConfig_Empty(t *testing.T) {
|
||||||
@@ -35,6 +43,12 @@ func TestModelfileConfig_Empty(t *testing.T) {
|
|||||||
if config.License != "" {
|
if config.License != "" {
|
||||||
t.Errorf("License should be empty, got %q", config.License)
|
t.Errorf("License should be empty, got %q", config.License)
|
||||||
}
|
}
|
||||||
|
if config.Parser != "" {
|
||||||
|
t.Errorf("Parser should be empty, got %q", config.Parser)
|
||||||
|
}
|
||||||
|
if config.Renderer != "" {
|
||||||
|
t.Errorf("Renderer should be empty, got %q", config.Renderer)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModelfileConfig_PartialFields(t *testing.T) {
|
func TestModelfileConfig_PartialFields(t *testing.T) {
|
||||||
@@ -53,6 +67,12 @@ func TestModelfileConfig_PartialFields(t *testing.T) {
|
|||||||
if config.License != "" {
|
if config.License != "" {
|
||||||
t.Error("License should be empty")
|
t.Error("License should be empty")
|
||||||
}
|
}
|
||||||
|
if config.Parser != "" {
|
||||||
|
t.Error("Parser should be empty")
|
||||||
|
}
|
||||||
|
if config.Renderer != "" {
|
||||||
|
t.Error("Renderer should be empty")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMinOllamaVersion(t *testing.T) {
|
func TestMinOllamaVersion(t *testing.T) {
|
||||||
@@ -98,6 +118,8 @@ func TestCreateOptions(t *testing.T) {
|
|||||||
Template: "test",
|
Template: "test",
|
||||||
System: "system",
|
System: "system",
|
||||||
License: "MIT",
|
License: "MIT",
|
||||||
|
Parser: "qwen3-thinking",
|
||||||
|
Renderer: "qwen3",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,6 +138,92 @@ func TestCreateOptions(t *testing.T) {
|
|||||||
if opts.Modelfile.Template != "test" {
|
if opts.Modelfile.Template != "test" {
|
||||||
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
|
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
|
||||||
}
|
}
|
||||||
|
if opts.Modelfile.Parser != "qwen3-thinking" {
|
||||||
|
t.Errorf("Modelfile.Parser = %q, want %q", opts.Modelfile.Parser, "qwen3-thinking")
|
||||||
|
}
|
||||||
|
if opts.Modelfile.Renderer != "qwen3" {
|
||||||
|
t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveParserName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mf *ModelfileConfig
|
||||||
|
inferred string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil modelfile uses inferred",
|
||||||
|
mf: nil,
|
||||||
|
inferred: "qwen3",
|
||||||
|
want: "qwen3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty parser uses inferred",
|
||||||
|
mf: &ModelfileConfig{
|
||||||
|
Parser: "",
|
||||||
|
},
|
||||||
|
inferred: "qwen3",
|
||||||
|
want: "qwen3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit parser overrides inferred",
|
||||||
|
mf: &ModelfileConfig{
|
||||||
|
Parser: "qwen3-thinking",
|
||||||
|
},
|
||||||
|
inferred: "qwen3",
|
||||||
|
want: "qwen3-thinking",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := resolveParserName(tt.mf, tt.inferred); got != tt.want {
|
||||||
|
t.Fatalf("resolveParserName() = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveRendererName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mf *ModelfileConfig
|
||||||
|
inferred string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil modelfile uses inferred",
|
||||||
|
mf: nil,
|
||||||
|
inferred: "qwen3-coder",
|
||||||
|
want: "qwen3-coder",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty renderer uses inferred",
|
||||||
|
mf: &ModelfileConfig{
|
||||||
|
Renderer: "",
|
||||||
|
},
|
||||||
|
inferred: "qwen3-coder",
|
||||||
|
want: "qwen3-coder",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit renderer overrides inferred",
|
||||||
|
mf: &ModelfileConfig{
|
||||||
|
Renderer: "qwen3",
|
||||||
|
},
|
||||||
|
inferred: "qwen3-coder",
|
||||||
|
want: "qwen3",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := resolveRendererName(tt.mf, tt.inferred); got != tt.want {
|
||||||
|
t.Fatalf("resolveRendererName() = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateOptions_Defaults(t *testing.T) {
|
func TestCreateOptions_Defaults(t *testing.T) {
|
||||||
|
|||||||
@@ -3,5 +3,8 @@
|
|||||||
package mlxrunner
|
package mlxrunner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
_ "github.com/ollama/ollama/x/models/gemma3"
|
||||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||||
|
_ "github.com/ollama/ollama/x/models/llama"
|
||||||
|
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||||
)
|
)
|
||||||
|
|||||||
92
x/mlxrunner/model/linear.go
Normal file
92
x/mlxrunner/model/linear.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/models/nn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LinearFactory builds linear layers using shared tensor maps and quant defaults.
|
||||||
|
type LinearFactory struct {
|
||||||
|
tensors map[string]*mlx.Array
|
||||||
|
defaultGroupSize int
|
||||||
|
defaultBits int
|
||||||
|
defaultMode string
|
||||||
|
tensorQuant map[string]*TensorQuantInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLinearFactory creates a reusable constructor for model linear layers.
|
||||||
|
func NewLinearFactory(
|
||||||
|
tensors map[string]*mlx.Array,
|
||||||
|
defaultGroupSize, defaultBits int,
|
||||||
|
defaultMode string,
|
||||||
|
tensorQuant map[string]*TensorQuantInfo,
|
||||||
|
) LinearFactory {
|
||||||
|
return LinearFactory{
|
||||||
|
tensors: tensors,
|
||||||
|
defaultGroupSize: defaultGroupSize,
|
||||||
|
defaultBits: defaultBits,
|
||||||
|
defaultMode: defaultMode,
|
||||||
|
tensorQuant: tensorQuant,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make constructs a linear layer at path.
|
||||||
|
func (f LinearFactory) Make(path string) nn.LinearLayer {
|
||||||
|
return MakeLinearLayer(
|
||||||
|
f.tensors,
|
||||||
|
path,
|
||||||
|
f.defaultGroupSize,
|
||||||
|
f.defaultBits,
|
||||||
|
f.defaultMode,
|
||||||
|
f.tensorQuant,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeLinearLayer constructs a linear layer from a tensor map.
|
||||||
|
//
|
||||||
|
// For quantized tensors (path.weight + path.weight_scale), it resolves per-tensor
|
||||||
|
// quant params via TensorQuant metadata (with shape-based affine fallback).
|
||||||
|
// For non-quantized tensors, it returns a standard nn.Linear.
|
||||||
|
func MakeLinearLayer(
|
||||||
|
tensors map[string]*mlx.Array,
|
||||||
|
path string,
|
||||||
|
defaultGroupSize, defaultBits int,
|
||||||
|
defaultMode string,
|
||||||
|
tensorQuant map[string]*TensorQuantInfo,
|
||||||
|
) nn.LinearLayer {
|
||||||
|
w := tensors[path+".weight"]
|
||||||
|
if w == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
scales := tensors[path+".weight_scale"]
|
||||||
|
if scales != nil {
|
||||||
|
qbiases := tensors[path+".weight_qbias"]
|
||||||
|
bias := tensors[path+".bias"]
|
||||||
|
|
||||||
|
groupSize, bits, mode := ResolveLinearQuantParams(
|
||||||
|
defaultGroupSize,
|
||||||
|
defaultBits,
|
||||||
|
defaultMode,
|
||||||
|
tensorQuant,
|
||||||
|
path+".weight",
|
||||||
|
w,
|
||||||
|
scales,
|
||||||
|
)
|
||||||
|
|
||||||
|
return &nn.QuantizedLinear{
|
||||||
|
Weight: w,
|
||||||
|
Scales: scales,
|
||||||
|
QBiases: qbiases,
|
||||||
|
Bias: bias,
|
||||||
|
GroupSize: groupSize,
|
||||||
|
Bits: bits,
|
||||||
|
Mode: mode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bias := tensors[path+".bias"]
|
||||||
|
return nn.NewLinear(w, bias)
|
||||||
|
}
|
||||||
130
x/mlxrunner/model/quant.go
Normal file
130
x/mlxrunner/model/quant.go
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// QuantizationParams returns default groupSize, bits, and mode for a quantization type.
|
||||||
|
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
|
||||||
|
switch strings.ToUpper(quantization) {
|
||||||
|
case "NVFP4":
|
||||||
|
return 16, 4, "nvfp4"
|
||||||
|
case "FP4", "Q4", "INT4":
|
||||||
|
return 32, 4, "affine"
|
||||||
|
case "MXFP8":
|
||||||
|
return 32, 8, "mxfp8"
|
||||||
|
case "FP8", "Q8", "INT8", "":
|
||||||
|
return 64, 8, "affine"
|
||||||
|
default:
|
||||||
|
return 32, 8, "affine"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TensorQuantParams resolves quant params for a tensor using per-tensor metadata
|
||||||
|
// when available, otherwise falling back to the provided model defaults.
|
||||||
|
func TensorQuantParams(
|
||||||
|
defaultGroupSize, defaultBits int,
|
||||||
|
defaultMode string,
|
||||||
|
tensorQuant map[string]*TensorQuantInfo,
|
||||||
|
tensorName string,
|
||||||
|
) (groupSize, bits int, mode string, fromTensor bool) {
|
||||||
|
if tensorQuant != nil {
|
||||||
|
if tq := tensorQuant[tensorName]; tq != nil {
|
||||||
|
groupSize, bits, mode = QuantizationParams(tq.QuantType)
|
||||||
|
if tq.GroupSize > 0 {
|
||||||
|
groupSize = tq.GroupSize
|
||||||
|
}
|
||||||
|
return groupSize, bits, mode, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultGroupSize, defaultBits, defaultMode, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveLinearQuantParams resolves quantization params for a quantized linear
|
||||||
|
// tensor, preferring per-tensor metadata and falling back to shape-based
|
||||||
|
// inference for affine packed tensors.
|
||||||
|
func ResolveLinearQuantParams(
|
||||||
|
defaultGroupSize, defaultBits int,
|
||||||
|
defaultMode string,
|
||||||
|
tensorQuant map[string]*TensorQuantInfo,
|
||||||
|
tensorName string,
|
||||||
|
weight, scales *mlx.Array,
|
||||||
|
) (groupSize, bits int, mode string) {
|
||||||
|
groupSize, bits, mode, fromTensor := TensorQuantParams(
|
||||||
|
defaultGroupSize,
|
||||||
|
defaultBits,
|
||||||
|
defaultMode,
|
||||||
|
tensorQuant,
|
||||||
|
tensorName,
|
||||||
|
)
|
||||||
|
|
||||||
|
if mode == "affine" {
|
||||||
|
if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok {
|
||||||
|
if !fromTensor || groupSize == 0 || bits == 0 {
|
||||||
|
groupSize = inferredGroupSize
|
||||||
|
bits = inferredBits
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return groupSize, bits, mode
|
||||||
|
}
|
||||||
|
|
||||||
|
// InferAffineQuantParamsFromShapes infers (groupSize,bits) for affine quantized
|
||||||
|
// tensors from packed weight and scale shapes.
|
||||||
|
func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) {
|
||||||
|
if weight == nil || scales == nil {
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
weightShape := weight.Dims()
|
||||||
|
scaleShape := scales.Dims()
|
||||||
|
if len(weightShape) == 0 || len(scaleShape) == 0 {
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
weightCols := weightShape[len(weightShape)-1]
|
||||||
|
scalesCols := scaleShape[len(scaleShape)-1]
|
||||||
|
if weightCols <= 0 || scalesCols <= 0 {
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
groupSize4 := weightCols * 8 / scalesCols
|
||||||
|
groupSize8 := weightCols * 4 / scalesCols
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case groupSize4 == 32:
|
||||||
|
return 32, 4, true
|
||||||
|
case groupSize8 == 64:
|
||||||
|
return 64, 8, true
|
||||||
|
case groupSize4 == 64 && groupSize8 == 32:
|
||||||
|
if hintBits == 8 {
|
||||||
|
return 32, 8, true
|
||||||
|
}
|
||||||
|
if hintBits == 4 {
|
||||||
|
return 64, 4, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
|
||||||
|
return groupSize4, 4, true
|
||||||
|
}
|
||||||
|
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
|
||||||
|
return groupSize8, 8, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isCommonGroupSize(v int) bool {
|
||||||
|
switch v {
|
||||||
|
case 16, 32, 64, 128:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,42 +8,63 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TensorQuantInfo describes per-tensor quantization metadata.
|
||||||
|
type TensorQuantInfo struct {
|
||||||
|
QuantType string
|
||||||
|
GroupSize int
|
||||||
|
}
|
||||||
|
|
||||||
// Root wraps a ModelManifest with pre-scanned quantization metadata.
|
// Root wraps a ModelManifest with pre-scanned quantization metadata.
|
||||||
type Root struct {
|
type Root struct {
|
||||||
Manifest *manifest.ModelManifest
|
Manifest *manifest.ModelManifest
|
||||||
|
|
||||||
|
// Backwards-compatible model-level quant metadata (first tensor blob).
|
||||||
quantType string
|
quantType string
|
||||||
groupSize int
|
groupSize int
|
||||||
|
|
||||||
|
// Per-tensor quantization metadata.
|
||||||
|
tensorQuant map[string]*TensorQuantInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open loads a manifest for the given model name and pre-scans the first
|
// Open loads a manifest for the given model name and scans tensor blobs for
|
||||||
// tensor blob for quantization metadata (quant_type, group_size).
|
// quantization metadata.
|
||||||
func Open(modelName string) (*Root, error) {
|
func Open(modelName string) (*Root, error) {
|
||||||
m, err := manifest.LoadManifest(modelName)
|
m, err := manifest.LoadManifest(modelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
root := &Root{Manifest: m}
|
root := &Root{
|
||||||
|
Manifest: m,
|
||||||
|
tensorQuant: make(map[string]*TensorQuantInfo),
|
||||||
|
}
|
||||||
|
|
||||||
// Pre-scan first tensor blob for quantization metadata
|
|
||||||
for _, layer := range m.GetTensorLayers("") {
|
for _, layer := range m.GetTensorLayers("") {
|
||||||
blobPath := m.BlobPath(layer.Digest)
|
blobPath := m.BlobPath(layer.Digest)
|
||||||
meta, err := readBlobMetadata(blobPath)
|
|
||||||
if err != nil || meta == nil {
|
infos, blobQuantType, blobGroupSize, err := readBlobTensorQuantInfo(blobPath)
|
||||||
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if qt := meta["quant_type"]; qt != "" {
|
|
||||||
root.quantType = strings.ToUpper(qt)
|
for name, info := range infos {
|
||||||
|
root.tensorQuant[name] = info
|
||||||
|
}
|
||||||
|
|
||||||
|
if root.quantType == "" && blobQuantType != "" {
|
||||||
|
root.quantType = strings.ToUpper(blobQuantType)
|
||||||
|
root.groupSize = blobGroupSize
|
||||||
|
if root.groupSize == 0 {
|
||||||
|
root.groupSize = defaultGroupSize(root.quantType)
|
||||||
}
|
}
|
||||||
if gs := meta["group_size"]; gs != "" {
|
|
||||||
fmt.Sscanf(gs, "%d", &root.groupSize)
|
|
||||||
}
|
}
|
||||||
break // only check the first tensor blob
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return root, nil
|
return root, nil
|
||||||
@@ -52,46 +73,180 @@ func Open(modelName string) (*Root, error) {
|
|||||||
// Close is a no-op for now (future: release resources).
|
// Close is a no-op for now (future: release resources).
|
||||||
func (r *Root) Close() {}
|
func (r *Root) Close() {}
|
||||||
|
|
||||||
// QuantType returns the quantization type detected from tensor metadata.
|
// QuantType returns the quantization type detected from the first tensor blob metadata.
|
||||||
func (r *Root) QuantType() string { return r.quantType }
|
func (r *Root) QuantType() string { return r.quantType }
|
||||||
|
|
||||||
// GroupSize returns the quantization group size detected from tensor metadata.
|
// GroupSize returns the quantization group size detected from the first tensor blob metadata.
|
||||||
func (r *Root) GroupSize() int { return r.groupSize }
|
func (r *Root) GroupSize() int { return r.groupSize }
|
||||||
|
|
||||||
// readBlobMetadata reads the __metadata__ from a safetensors blob header.
|
// TensorQuant returns per-tensor quantization metadata if available.
|
||||||
func readBlobMetadata(path string) (map[string]string, error) {
|
func (r *Root) TensorQuant(name string) *TensorQuantInfo {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return r.tensorQuant[name]
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllTensorQuant returns a copy of the per-tensor quantization metadata.
|
||||||
|
func (r *Root) AllTensorQuant() map[string]*TensorQuantInfo {
|
||||||
|
out := make(map[string]*TensorQuantInfo, len(r.tensorQuant))
|
||||||
|
for k, v := range r.tensorQuant {
|
||||||
|
if v == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
copy := *v
|
||||||
|
out[k] = ©
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultGroupSize(quantType string) int {
|
||||||
|
groupSize, _, _ := QuantizationParams(quantType)
|
||||||
|
return groupSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string, int, error) {
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, "", 0, err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
var headerSize uint64
|
var headerSize uint64
|
||||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||||
return nil, err
|
return nil, "", 0, err
|
||||||
}
|
}
|
||||||
if headerSize > 1024*1024 {
|
if headerSize > 100*1024*1024 {
|
||||||
return nil, fmt.Errorf("header too large: %d", headerSize)
|
return nil, "", 0, fmt.Errorf("header too large: %d", headerSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
data := make([]byte, headerSize)
|
data := make([]byte, headerSize)
|
||||||
if _, err := io.ReadFull(f, data); err != nil {
|
if _, err := io.ReadFull(f, data); err != nil {
|
||||||
return nil, err
|
return nil, "", 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var header map[string]json.RawMessage
|
var header map[string]json.RawMessage
|
||||||
if err := json.Unmarshal(data, &header); err != nil {
|
if err := json.Unmarshal(data, &header); err != nil {
|
||||||
return nil, err
|
return nil, "", 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
|
||||||
|
globalQuantType = strings.ToUpper(globalQuantType)
|
||||||
|
|
||||||
|
mainNames := mainTensorNames(header)
|
||||||
|
infos := make(map[string]*TensorQuantInfo)
|
||||||
|
for _, name := range mainNames {
|
||||||
|
if _, ok := header[name+".scale"]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
quantType := globalQuantType
|
||||||
|
groupSize := globalGroupSize
|
||||||
|
|
||||||
|
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
|
||||||
|
if quantType == "" {
|
||||||
|
quantType = inferredType
|
||||||
|
}
|
||||||
|
if groupSize == 0 {
|
||||||
|
groupSize = inferredGroup
|
||||||
|
}
|
||||||
|
if quantType == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if groupSize == 0 {
|
||||||
|
groupSize = defaultGroupSize(quantType)
|
||||||
|
}
|
||||||
|
|
||||||
|
infos[name] = &TensorQuantInfo{QuantType: quantType, GroupSize: groupSize}
|
||||||
|
}
|
||||||
|
|
||||||
|
return infos, globalQuantType, globalGroupSize, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseGlobalQuantMetadata(header map[string]json.RawMessage) (quantType string, groupSize int) {
|
||||||
metaRaw, ok := header["__metadata__"]
|
metaRaw, ok := header["__metadata__"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil
|
return "", 0
|
||||||
}
|
}
|
||||||
|
|
||||||
var meta map[string]string
|
var meta map[string]string
|
||||||
if err := json.Unmarshal(metaRaw, &meta); err != nil {
|
if err := json.Unmarshal(metaRaw, &meta); err != nil {
|
||||||
return nil, err
|
return "", 0
|
||||||
}
|
}
|
||||||
return meta, nil
|
|
||||||
|
quantType = meta["quant_type"]
|
||||||
|
if gs := meta["group_size"]; gs != "" {
|
||||||
|
groupSize, _ = strconv.Atoi(gs)
|
||||||
|
}
|
||||||
|
return quantType, groupSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func mainTensorNames(header map[string]json.RawMessage) []string {
|
||||||
|
names := make([]string, 0, len(header))
|
||||||
|
for name := range header {
|
||||||
|
if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
sort.Strings(names)
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
func inferQuantTypeFromShapes(header map[string]json.RawMessage, tensorName string, hintQuantType string) (string, int) {
|
||||||
|
type tensorShape struct {
|
||||||
|
Shape []int64 `json:"shape"`
|
||||||
|
}
|
||||||
|
|
||||||
|
mainRaw, ok := header[tensorName]
|
||||||
|
if !ok {
|
||||||
|
return "", 0
|
||||||
|
}
|
||||||
|
scaleRaw, ok := header[tensorName+".scale"]
|
||||||
|
if !ok {
|
||||||
|
return "", 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var mainInfo tensorShape
|
||||||
|
if err := json.Unmarshal(mainRaw, &mainInfo); err != nil || len(mainInfo.Shape) == 0 {
|
||||||
|
return "", 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var scaleInfo tensorShape
|
||||||
|
if err := json.Unmarshal(scaleRaw, &scaleInfo); err != nil || len(scaleInfo.Shape) == 0 {
|
||||||
|
return "", 0
|
||||||
|
}
|
||||||
|
|
||||||
|
weightCols := int(mainInfo.Shape[len(mainInfo.Shape)-1])
|
||||||
|
scalesCols := int(scaleInfo.Shape[len(scaleInfo.Shape)-1])
|
||||||
|
if weightCols <= 0 || scalesCols <= 0 {
|
||||||
|
return "", 0
|
||||||
|
}
|
||||||
|
|
||||||
|
groupSize4 := weightCols * 8 / scalesCols
|
||||||
|
groupSize8 := weightCols * 4 / scalesCols
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case groupSize4 == 32:
|
||||||
|
return "INT4", 32
|
||||||
|
case groupSize8 == 64:
|
||||||
|
return "INT8", 64
|
||||||
|
case groupSize4 == 64 && groupSize8 == 32:
|
||||||
|
h := strings.ToUpper(hintQuantType)
|
||||||
|
if strings.Contains(h, "8") {
|
||||||
|
return "INT8", 32
|
||||||
|
}
|
||||||
|
if strings.Contains(h, "4") {
|
||||||
|
return "INT4", 64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
|
||||||
|
return "INT4", groupSize4
|
||||||
|
}
|
||||||
|
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
|
||||||
|
return "INT8", groupSize8
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,17 +18,29 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
return errors.New("model not loaded")
|
return errors.New("model not loaded")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enableCompile := true
|
||||||
|
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
||||||
|
enableCompile = modelCompile.EnableCompile()
|
||||||
|
}
|
||||||
|
if enableCompile {
|
||||||
mlx.EnableCompile()
|
mlx.EnableCompile()
|
||||||
|
} else {
|
||||||
|
mlx.DisableCompile()
|
||||||
|
}
|
||||||
|
|
||||||
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
inputs := r.Tokenizer.Encode(request.Prompt, true)
|
||||||
|
|
||||||
caches, tokens := r.FindNearestCache(inputs)
|
caches, tokens := r.FindNearestCache(inputs)
|
||||||
if len(caches) == 0 {
|
if len(caches) == 0 {
|
||||||
|
if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
|
||||||
|
caches = cacheFactory.NewCaches()
|
||||||
|
} else {
|
||||||
caches = make([]cache.Cache, r.Model.NumLayers())
|
caches = make([]cache.Cache, r.Model.NumLayers())
|
||||||
for i := range caches {
|
for i := range caches {
|
||||||
caches[i] = cache.NewKVCache()
|
caches[i] = cache.NewKVCache()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
total, processed := len(tokens), 0
|
total, processed := len(tokens), 0
|
||||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||||
|
|||||||
521
x/models/gemma3/gemma3.go
Normal file
521
x/models/gemma3/gemma3.go
Normal file
@@ -0,0 +1,521 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
// Package gemma3 provides the Gemma 3 text model implementation for MLX.
|
||||||
|
package gemma3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
|
"github.com/ollama/ollama/x/models/nn"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
base.Register("Gemma3ForCausalLM", newModel)
|
||||||
|
base.Register("Gemma3ForConditionalGeneration", newModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextConfig holds configuration for the Gemma 3 text model.
|
||||||
|
type TextConfig struct {
|
||||||
|
HiddenSize int32 `json:"hidden_size"`
|
||||||
|
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||||
|
IntermediateSize int32 `json:"intermediate_size"`
|
||||||
|
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||||
|
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||||
|
HeadDim int32 `json:"head_dim"`
|
||||||
|
VocabSize int32 `json:"vocab_size"`
|
||||||
|
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
|
||||||
|
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||||
|
SlidingWindow int32 `json:"sliding_window"`
|
||||||
|
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
|
||||||
|
LayerTypes []string `json:"layer_types"`
|
||||||
|
TieWordEmbeddings bool `json:"tie_word_embeddings"`
|
||||||
|
|
||||||
|
// Quantization parameters (set during load based on model quantization).
|
||||||
|
QuantGroupSize int `json:"-"`
|
||||||
|
QuantBits int `json:"-"`
|
||||||
|
QuantMode string `json:"-"`
|
||||||
|
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
|
||||||
|
|
||||||
|
// Computed fields.
|
||||||
|
Scale float32 `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attention implements Gemma 3 attention with Q/K normalization.
|
||||||
|
type Attention struct {
|
||||||
|
QProj nn.LinearLayer
|
||||||
|
KProj nn.LinearLayer
|
||||||
|
VProj nn.LinearLayer
|
||||||
|
OProj nn.LinearLayer
|
||||||
|
|
||||||
|
QNorm *nn.RMSNorm
|
||||||
|
KNorm *nn.RMSNorm
|
||||||
|
|
||||||
|
// Precomputed (1 + weight) for Gemma-style RMSNorm.
|
||||||
|
QNormScaled *mlx.Array
|
||||||
|
KNormScaled *mlx.Array
|
||||||
|
}
|
||||||
|
|
||||||
|
// MLP is the feed-forward network with GELU activation.
|
||||||
|
type MLP struct {
|
||||||
|
GateProj nn.LinearLayer
|
||||||
|
UpProj nn.LinearLayer
|
||||||
|
DownProj nn.LinearLayer
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecoderLayer is a single transformer block.
|
||||||
|
type DecoderLayer struct {
|
||||||
|
InputNorm *nn.RMSNorm
|
||||||
|
Attention *Attention
|
||||||
|
PostAttnNorm *nn.RMSNorm
|
||||||
|
PreFFNorm *nn.RMSNorm
|
||||||
|
MLP *MLP
|
||||||
|
PostFFNorm *nn.RMSNorm
|
||||||
|
|
||||||
|
// Precomputed (1 + weight) for Gemma-style RMSNorm.
|
||||||
|
InputNormScaled *mlx.Array
|
||||||
|
PostAttnNormScaled *mlx.Array
|
||||||
|
PreFFNormScaled *mlx.Array
|
||||||
|
PostFFNormScaled *mlx.Array
|
||||||
|
|
||||||
|
// Layer metadata.
|
||||||
|
IsSliding bool
|
||||||
|
LayerIdx int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model is the Gemma 3 text-only model.
|
||||||
|
type Model struct {
|
||||||
|
EmbedTokens *nn.Embedding
|
||||||
|
Layers []*DecoderLayer
|
||||||
|
Norm *nn.RMSNorm
|
||||||
|
LMHead nn.LinearLayer
|
||||||
|
|
||||||
|
// Precomputed (1 + weight) for Gemma-style RMSNorm.
|
||||||
|
NormScaled *mlx.Array
|
||||||
|
|
||||||
|
tok *tokenizer.Tokenizer
|
||||||
|
*TextConfig
|
||||||
|
|
||||||
|
weightPrefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultHeads(numLayers int32) (numHeads, numKVHeads int32) {
|
||||||
|
switch numLayers {
|
||||||
|
case 34:
|
||||||
|
return 8, 4
|
||||||
|
case 48:
|
||||||
|
return 16, 8
|
||||||
|
case 62:
|
||||||
|
return 32, 16
|
||||||
|
default:
|
||||||
|
return 8, 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTextConfig(configData []byte) (TextConfig, bool, error) {
|
||||||
|
var cfg TextConfig
|
||||||
|
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||||
|
return TextConfig{}, false, fmt.Errorf("parse config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var wrapped struct {
|
||||||
|
TextConfig *TextConfig `json:"text_config"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(configData, &wrapped); err != nil {
|
||||||
|
return TextConfig{}, false, fmt.Errorf("parse nested text config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fromConditional := wrapped.TextConfig != nil
|
||||||
|
if fromConditional {
|
||||||
|
cfg = *wrapped.TextConfig
|
||||||
|
|
||||||
|
if cfg.HeadDim == 0 {
|
||||||
|
cfg.HeadDim = 256
|
||||||
|
}
|
||||||
|
if cfg.NumAttentionHeads == 0 {
|
||||||
|
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
|
||||||
|
}
|
||||||
|
if cfg.NumKeyValueHeads == 0 {
|
||||||
|
_, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
|
||||||
|
}
|
||||||
|
if cfg.VocabSize == 0 {
|
||||||
|
cfg.VocabSize = 262208
|
||||||
|
}
|
||||||
|
if cfg.SlidingWindowPattern == 0 && len(cfg.LayerTypes) == 0 {
|
||||||
|
cfg.SlidingWindowPattern = 6
|
||||||
|
}
|
||||||
|
if cfg.MaxPositionEmbeddings == 0 {
|
||||||
|
cfg.MaxPositionEmbeddings = 131072
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.HeadDim == 0 {
|
||||||
|
cfg.HeadDim = 256
|
||||||
|
}
|
||||||
|
if cfg.NumAttentionHeads == 0 {
|
||||||
|
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
|
||||||
|
}
|
||||||
|
if cfg.NumKeyValueHeads == 0 {
|
||||||
|
cfg.NumKeyValueHeads = max(1, cfg.NumAttentionHeads/2)
|
||||||
|
}
|
||||||
|
if cfg.RopeTheta == 0 {
|
||||||
|
cfg.RopeTheta = 1000000
|
||||||
|
}
|
||||||
|
if cfg.RopeLocalBaseFreq == 0 {
|
||||||
|
cfg.RopeLocalBaseFreq = 10000
|
||||||
|
}
|
||||||
|
if cfg.RMSNormEps == 0 {
|
||||||
|
cfg.RMSNormEps = 1e-6
|
||||||
|
}
|
||||||
|
if cfg.VocabSize == 0 {
|
||||||
|
cfg.VocabSize = 262208
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||||
|
|
||||||
|
return cfg, fromConditional, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
|
||||||
|
for _, prefix := range []string{"", "language_model."} {
|
||||||
|
if tensors[prefix+"model.embed_tokens.weight"] != nil {
|
||||||
|
return prefix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLayerSliding(layerIdx int32, cfg *TextConfig) bool {
|
||||||
|
if len(cfg.LayerTypes) > 0 && int(layerIdx) < len(cfg.LayerTypes) {
|
||||||
|
return cfg.LayerTypes[layerIdx] == "sliding_attention"
|
||||||
|
}
|
||||||
|
if cfg.SlidingWindowPattern <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return (layerIdx+1)%cfg.SlidingWindowPattern != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func precomputeGemmaScaledWeights(m *Model) {
|
||||||
|
if m.Norm != nil {
|
||||||
|
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
var scaled []*mlx.Array
|
||||||
|
if m.NormScaled != nil {
|
||||||
|
scaled = append(scaled, m.NormScaled)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, layer := range m.Layers {
|
||||||
|
if layer == nil || layer.Attention == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if layer.InputNorm != nil {
|
||||||
|
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
|
||||||
|
scaled = append(scaled, layer.InputNormScaled)
|
||||||
|
}
|
||||||
|
if layer.PostAttnNorm != nil {
|
||||||
|
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
|
||||||
|
scaled = append(scaled, layer.PostAttnNormScaled)
|
||||||
|
}
|
||||||
|
if layer.PreFFNorm != nil {
|
||||||
|
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
|
||||||
|
scaled = append(scaled, layer.PreFFNormScaled)
|
||||||
|
}
|
||||||
|
if layer.PostFFNorm != nil {
|
||||||
|
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
|
||||||
|
scaled = append(scaled, layer.PostFFNormScaled)
|
||||||
|
}
|
||||||
|
|
||||||
|
if layer.Attention.QNorm != nil {
|
||||||
|
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
|
||||||
|
scaled = append(scaled, layer.Attention.QNormScaled)
|
||||||
|
}
|
||||||
|
if layer.Attention.KNorm != nil {
|
||||||
|
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
|
||||||
|
scaled = append(scaled, layer.Attention.KNormScaled)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(scaled) > 0 {
|
||||||
|
mlx.Eval(scaled...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newModel(root *model.Root) (base.Model, error) {
|
||||||
|
configData, err := root.Manifest.ReadConfig("config.json")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, _, err := parseTextConfig(configData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if qt := root.QuantType(); qt != "" {
|
||||||
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||||
|
if gs := root.GroupSize(); gs > 0 {
|
||||||
|
cfg.QuantGroupSize = gs
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
|
||||||
|
}
|
||||||
|
cfg.TensorQuant = root.AllTensorQuant()
|
||||||
|
|
||||||
|
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokConfig := &tokenizer.TokenizerConfig{ConfigJSON: configData}
|
||||||
|
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
|
||||||
|
tokConfig.GenerationConfigJSON = genConfigData
|
||||||
|
}
|
||||||
|
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||||
|
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||||
|
}
|
||||||
|
|
||||||
|
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse tokenizer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m := &Model{
|
||||||
|
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
|
||||||
|
TextConfig: &cfg,
|
||||||
|
tok: tok,
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range m.Layers {
|
||||||
|
m.Layers[i] = &DecoderLayer{
|
||||||
|
LayerIdx: int32(i),
|
||||||
|
IsSliding: isLayerSliding(int32(i), m.TextConfig),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadWeights receives all tensors loaded from the manifest and assigns them
|
||||||
|
// to model fields.
|
||||||
|
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||||
|
m.weightPrefix = resolveWeightPrefix(tensors)
|
||||||
|
prefix := m.weightPrefix
|
||||||
|
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||||
|
|
||||||
|
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
|
||||||
|
if embedWeight == nil {
|
||||||
|
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
|
||||||
|
}
|
||||||
|
m.EmbedTokens = nn.NewEmbedding(embedWeight)
|
||||||
|
|
||||||
|
normWeight := tensors[prefix+"model.norm.weight"]
|
||||||
|
if normWeight == nil {
|
||||||
|
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
|
||||||
|
}
|
||||||
|
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
|
||||||
|
|
||||||
|
if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
|
||||||
|
m.LMHead = lmHead
|
||||||
|
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
|
||||||
|
m.LMHead = lmHead
|
||||||
|
} else {
|
||||||
|
// Gemma usually ties output projection to embeddings.
|
||||||
|
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := int32(0); i < m.NumHiddenLayers; i++ {
|
||||||
|
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
|
||||||
|
|
||||||
|
layer := &DecoderLayer{
|
||||||
|
LayerIdx: i,
|
||||||
|
IsSliding: isLayerSliding(i, m.TextConfig),
|
||||||
|
Attention: &Attention{},
|
||||||
|
MLP: &MLP{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
|
||||||
|
layer.InputNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||||
|
}
|
||||||
|
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
|
||||||
|
layer.PostAttnNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||||
|
}
|
||||||
|
if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil {
|
||||||
|
layer.PreFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||||
|
}
|
||||||
|
if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil {
|
||||||
|
layer.PostFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
|
||||||
|
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
|
||||||
|
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
|
||||||
|
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
|
||||||
|
|
||||||
|
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
|
||||||
|
layer.Attention.QNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||||
|
}
|
||||||
|
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
|
||||||
|
layer.Attention.KNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
|
||||||
|
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
|
||||||
|
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
|
||||||
|
|
||||||
|
if layer.InputNorm == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing input_layernorm", i)
|
||||||
|
}
|
||||||
|
if layer.PostAttnNorm == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
|
||||||
|
}
|
||||||
|
if layer.PreFFNorm == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing pre_feedforward_layernorm", i)
|
||||||
|
}
|
||||||
|
if layer.PostFFNorm == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing post_feedforward_layernorm", i)
|
||||||
|
}
|
||||||
|
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing attention projections", i)
|
||||||
|
}
|
||||||
|
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing attention q/k norms", i)
|
||||||
|
}
|
||||||
|
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing mlp projections", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Layers[i] = layer
|
||||||
|
}
|
||||||
|
|
||||||
|
precomputeGemmaScaledWeights(m)
|
||||||
|
if m.NormScaled == nil {
|
||||||
|
return fmt.Errorf("missing precomputed final norm weight")
|
||||||
|
}
|
||||||
|
collected := mlx.Collect(m)
|
||||||
|
mlx.Eval(collected...)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||||
|
dims := tokens.Dims()
|
||||||
|
B, L := int32(dims[0]), int32(dims[1])
|
||||||
|
|
||||||
|
h := m.EmbedTokens.Forward(tokens)
|
||||||
|
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
|
||||||
|
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
var c cache.Cache
|
||||||
|
if caches != nil && i < len(caches) {
|
||||||
|
c = caches[i]
|
||||||
|
}
|
||||||
|
h = layer.Forward(h, c, B, L, m.TextConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||||
|
return m.LMHead.Forward(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) NumLayers() int {
|
||||||
|
return len(m.Layers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
|
return m.tok
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCaches creates cache objects for all layers.
|
||||||
|
func (m *Model) NewCaches() []cache.Cache {
|
||||||
|
caches := make([]cache.Cache, len(m.Layers))
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
if m.SlidingWindow > 0 && layer.IsSliding {
|
||||||
|
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
|
||||||
|
} else {
|
||||||
|
caches[i] = cache.NewKVCache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return caches
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatPrompt applies the Gemma 3 chat template.
|
||||||
|
func (m *Model) FormatPrompt(prompt string) string {
|
||||||
|
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
|
||||||
|
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||||
|
|
||||||
|
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
|
||||||
|
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||||
|
h := mlx.Add(x, attnOut)
|
||||||
|
|
||||||
|
normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
||||||
|
|
||||||
|
mlpOut := l.MLP.Forward(normed)
|
||||||
|
mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
||||||
|
|
||||||
|
return mlx.Add(h, mlpOut)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
|
||||||
|
q := a.QProj.Forward(x)
|
||||||
|
k := a.KProj.Forward(x)
|
||||||
|
v := a.VProj.Forward(x)
|
||||||
|
|
||||||
|
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
|
||||||
|
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||||
|
|
||||||
|
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||||
|
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||||
|
|
||||||
|
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||||
|
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||||
|
|
||||||
|
q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps)
|
||||||
|
k = mlx.RMSNormFn(k, a.KNormScaled, cfg.RMSNormEps)
|
||||||
|
|
||||||
|
ropeTheta := cfg.RopeTheta
|
||||||
|
if isSliding {
|
||||||
|
ropeTheta = cfg.RopeLocalBaseFreq
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := 0
|
||||||
|
if c != nil {
|
||||||
|
offset = c.Offset()
|
||||||
|
}
|
||||||
|
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
|
||||||
|
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
|
||||||
|
|
||||||
|
if c != nil {
|
||||||
|
k, v = c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||||
|
if repeatFactor > 1 {
|
||||||
|
k = nn.RepeatKV(k, repeatFactor)
|
||||||
|
v = nn.RepeatKV(v, repeatFactor)
|
||||||
|
}
|
||||||
|
|
||||||
|
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
|
||||||
|
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||||
|
return a.OProj.Forward(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||||
|
gate := mlx.GELUApprox(m.GateProj.Forward(x))
|
||||||
|
up := m.UpProj.Forward(x)
|
||||||
|
return m.DownProj.Forward(mlx.Mul(gate, up))
|
||||||
|
}
|
||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
@@ -67,6 +66,7 @@ type Config struct {
|
|||||||
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
|
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
|
||||||
QuantBits int `json:"-"` // Bits per weight (4 or 8)
|
QuantBits int `json:"-"` // Bits per weight (4 or 8)
|
||||||
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
|
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
|
||||||
|
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
|
||||||
|
|
||||||
// Computed fields
|
// Computed fields
|
||||||
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
|
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
|
||||||
@@ -372,22 +372,6 @@ func supportsGatherQMM(mode string, bits int) bool {
|
|||||||
return mode == "affine" && (bits == 4 || bits == 8)
|
return mode == "affine" && (bits == 4 || bits == 8)
|
||||||
}
|
}
|
||||||
|
|
||||||
// quantizationParams returns groupSize, bits, mode for a quantization type string.
|
|
||||||
func quantizationParams(quantization string) (groupSize, bits int, mode string) {
|
|
||||||
switch strings.ToUpper(quantization) {
|
|
||||||
case "NVFP4":
|
|
||||||
return 16, 4, "nvfp4"
|
|
||||||
case "FP4", "Q4", "INT4":
|
|
||||||
return 32, 4, "affine"
|
|
||||||
case "MXFP8":
|
|
||||||
return 32, 8, "mxfp8"
|
|
||||||
case "FP8", "Q8", "INT8", "":
|
|
||||||
return 64, 8, "affine"
|
|
||||||
default:
|
|
||||||
return 32, 8, "affine"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExpertWeight holds a single expert's weight with optional quantization components.
|
// ExpertWeight holds a single expert's weight with optional quantization components.
|
||||||
type ExpertWeight struct {
|
type ExpertWeight struct {
|
||||||
Weight *mlx.Array
|
Weight *mlx.Array
|
||||||
@@ -408,7 +392,15 @@ func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized b
|
|||||||
if scales != nil {
|
if scales != nil {
|
||||||
qbiases := tensors[path+".weight_qbias"]
|
qbiases := tensors[path+".weight_qbias"]
|
||||||
|
|
||||||
groupSize, bits, mode := cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode
|
groupSize, bits, mode := model.ResolveLinearQuantParams(
|
||||||
|
cfg.QuantGroupSize,
|
||||||
|
cfg.QuantBits,
|
||||||
|
cfg.QuantMode,
|
||||||
|
cfg.TensorQuant,
|
||||||
|
path+".weight",
|
||||||
|
w,
|
||||||
|
scales,
|
||||||
|
)
|
||||||
|
|
||||||
if useQuantized && supportsGatherQMM(mode, bits) {
|
if useQuantized && supportsGatherQMM(mode, bits) {
|
||||||
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
|
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
|
||||||
@@ -492,7 +484,16 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi
|
|||||||
// Check if quantized and dequantize
|
// Check if quantized and dequantize
|
||||||
if scales := tensors[path+".weight_scale"]; scales != nil {
|
if scales := tensors[path+".weight_scale"]; scales != nil {
|
||||||
qbiases := tensors[path+".weight_qbias"]
|
qbiases := tensors[path+".weight_qbias"]
|
||||||
w = mlx.Dequantize(w, scales, qbiases, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode)
|
groupSize, bits, mode := model.ResolveLinearQuantParams(
|
||||||
|
cfg.QuantGroupSize,
|
||||||
|
cfg.QuantBits,
|
||||||
|
cfg.QuantMode,
|
||||||
|
cfg.TensorQuant,
|
||||||
|
path+".weight",
|
||||||
|
w,
|
||||||
|
scales,
|
||||||
|
)
|
||||||
|
w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
|
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
|
||||||
@@ -507,32 +508,6 @@ func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Confi
|
|||||||
return embedQ, unembedOut
|
return embedQ, unembedOut
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeLinear creates a Linear or QuantizedLinear layer from the tensor map.
|
|
||||||
func makeLinear(tensors map[string]*mlx.Array, path string, cfg *Config) nn.LinearLayer {
|
|
||||||
w := tensors[path+".weight"]
|
|
||||||
if w == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
scales := tensors[path+".weight_scale"]
|
|
||||||
if scales != nil {
|
|
||||||
qbiases := tensors[path+".weight_qbias"]
|
|
||||||
bias := tensors[path+".bias"]
|
|
||||||
return &nn.QuantizedLinear{
|
|
||||||
Weight: w,
|
|
||||||
Scales: scales,
|
|
||||||
QBiases: qbiases,
|
|
||||||
Bias: bias,
|
|
||||||
GroupSize: cfg.QuantGroupSize,
|
|
||||||
Bits: cfg.QuantBits,
|
|
||||||
Mode: cfg.QuantMode,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bias := tensors[path+".bias"]
|
|
||||||
return nn.NewLinear(w, bias)
|
|
||||||
}
|
|
||||||
|
|
||||||
// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
|
// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
|
||||||
// no weights loaded yet). Called by the registry via base.New().
|
// no weights loaded yet). Called by the registry via base.New().
|
||||||
func newModel(root *model.Root) (base.Model, error) {
|
func newModel(root *model.Root) (base.Model, error) {
|
||||||
@@ -551,13 +526,14 @@ func newModel(root *model.Root) (base.Model, error) {
|
|||||||
|
|
||||||
// Set up quantization parameters from pre-scanned metadata
|
// Set up quantization parameters from pre-scanned metadata
|
||||||
if qt := root.QuantType(); qt != "" {
|
if qt := root.QuantType(); qt != "" {
|
||||||
_, cfg.QuantBits, cfg.QuantMode = quantizationParams(qt)
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||||
if gs := root.GroupSize(); gs > 0 {
|
if gs := root.GroupSize(); gs > 0 {
|
||||||
cfg.QuantGroupSize = gs
|
cfg.QuantGroupSize = gs
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
cfg.QuantGroupSize, _, _ = quantizationParams(qt)
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
cfg.TensorQuant = root.AllTensorQuant()
|
||||||
|
|
||||||
// Load tokenizer
|
// Load tokenizer
|
||||||
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
||||||
@@ -596,7 +572,20 @@ func newModel(root *model.Root) (base.Model, error) {
|
|||||||
// layer creation.
|
// layer creation.
|
||||||
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||||
cfg := m.Config
|
cfg := m.Config
|
||||||
|
linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
|
||||||
useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
|
useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
|
||||||
|
if !useQuantized && cfg.TensorQuant != nil {
|
||||||
|
for _, tq := range cfg.TensorQuant {
|
||||||
|
if tq == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, bits, mode := model.QuantizationParams(tq.QuantType)
|
||||||
|
if supportsGatherQMM(mode, bits) {
|
||||||
|
useQuantized = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Load embedding
|
// Load embedding
|
||||||
if w := tensors["model.embed_tokens.weight"]; w != nil {
|
if w := tensors["model.embed_tokens.weight"]; w != nil {
|
||||||
@@ -609,7 +598,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Load LM head
|
// Load LM head
|
||||||
m.LMHead = makeLinear(tensors, "lm_head", cfg)
|
m.LMHead = linears.Make("lm_head")
|
||||||
|
|
||||||
// Load layers
|
// Load layers
|
||||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||||
@@ -617,16 +606,16 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|||||||
|
|
||||||
// Load attention (same for both block types)
|
// Load attention (same for both block types)
|
||||||
attn := &MLAAttention{}
|
attn := &MLAAttention{}
|
||||||
attn.QAProj = makeLinear(tensors, prefix+".self_attn.q_a_proj", cfg)
|
attn.QAProj = linears.Make(prefix + ".self_attn.q_a_proj")
|
||||||
if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
|
if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
|
||||||
attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
||||||
}
|
}
|
||||||
attn.QBProj = makeLinear(tensors, prefix+".self_attn.q_b_proj", cfg)
|
attn.QBProj = linears.Make(prefix + ".self_attn.q_b_proj")
|
||||||
attn.KVAProjWithMQA = makeLinear(tensors, prefix+".self_attn.kv_a_proj_with_mqa", cfg)
|
attn.KVAProjWithMQA = linears.Make(prefix + ".self_attn.kv_a_proj_with_mqa")
|
||||||
if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
|
if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
|
||||||
attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
|
||||||
}
|
}
|
||||||
attn.OProj = makeLinear(tensors, prefix+".self_attn.o_proj", cfg)
|
attn.OProj = linears.Make(prefix + ".self_attn.o_proj")
|
||||||
|
|
||||||
// Sanitize MLA weights for absorbed attention
|
// Sanitize MLA weights for absorbed attention
|
||||||
embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
|
embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
|
||||||
@@ -647,9 +636,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
block.MLP = &DenseMLP{
|
block.MLP = &DenseMLP{
|
||||||
GateProj: makeLinear(tensors, prefix+".mlp.gate_proj", cfg),
|
GateProj: linears.Make(prefix + ".mlp.gate_proj"),
|
||||||
UpProj: makeLinear(tensors, prefix+".mlp.up_proj", cfg),
|
UpProj: linears.Make(prefix + ".mlp.up_proj"),
|
||||||
DownProj: makeLinear(tensors, prefix+".mlp.down_proj", cfg),
|
DownProj: linears.Make(prefix + ".mlp.down_proj"),
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Layers[i] = block
|
m.Layers[i] = block
|
||||||
@@ -690,7 +679,7 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
moeGate := &MoEGate{}
|
moeGate := &MoEGate{}
|
||||||
moeGate.Gate = makeLinear(tensors, prefix+".mlp.gate", cfg)
|
moeGate.Gate = linears.Make(prefix + ".mlp.gate")
|
||||||
if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
|
if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
|
||||||
moeGate.EScoreCorrectionBias = bias
|
moeGate.EScoreCorrectionBias = bias
|
||||||
}
|
}
|
||||||
@@ -703,9 +692,9 @@ func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
|||||||
// Load shared experts if present
|
// Load shared experts if present
|
||||||
if cfg.NSharedExperts > 0 {
|
if cfg.NSharedExperts > 0 {
|
||||||
block.MoE.SharedExperts = &SharedExperts{
|
block.MoE.SharedExperts = &SharedExperts{
|
||||||
GateProj: makeLinear(tensors, prefix+".mlp.shared_experts.gate_proj", cfg),
|
GateProj: linears.Make(prefix + ".mlp.shared_experts.gate_proj"),
|
||||||
UpProj: makeLinear(tensors, prefix+".mlp.shared_experts.up_proj", cfg),
|
UpProj: linears.Make(prefix + ".mlp.shared_experts.up_proj"),
|
||||||
DownProj: makeLinear(tensors, prefix+".mlp.shared_experts.down_proj", cfg),
|
DownProj: linears.Make(prefix + ".mlp.shared_experts.down_proj"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
323
x/models/llama/llama.go
Normal file
323
x/models/llama/llama.go
Normal file
@@ -0,0 +1,323 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
// Package llama provides a Llama-style decoder-only transformer for MLX.
|
||||||
|
package llama
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
|
"github.com/ollama/ollama/x/models/nn"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
base.Register("LlamaForCausalLM", newModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config holds Llama model configuration.
|
||||||
|
type Config struct {
|
||||||
|
HiddenSize int32 `json:"hidden_size"`
|
||||||
|
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||||
|
IntermediateSize int32 `json:"intermediate_size"`
|
||||||
|
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||||
|
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||||
|
VocabSize int32 `json:"vocab_size"`
|
||||||
|
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||||
|
TieWordEmbeddings bool `json:"tie_word_embeddings"`
|
||||||
|
|
||||||
|
// Quantization parameters (set during load based on model quantization).
|
||||||
|
QuantGroupSize int `json:"-"`
|
||||||
|
QuantBits int `json:"-"`
|
||||||
|
QuantMode string `json:"-"`
|
||||||
|
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
|
||||||
|
|
||||||
|
// Computed fields.
|
||||||
|
HeadDim int32 `json:"-"`
|
||||||
|
Scale float32 `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model is a Llama text model.
|
||||||
|
type Model struct {
|
||||||
|
EmbedTokens *nn.Embedding
|
||||||
|
Layers []*Layer
|
||||||
|
Norm *nn.RMSNorm
|
||||||
|
LMHead nn.LinearLayer
|
||||||
|
|
||||||
|
tok *tokenizer.Tokenizer
|
||||||
|
*Config
|
||||||
|
|
||||||
|
weightPrefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Layer struct {
|
||||||
|
Attention *Attention
|
||||||
|
MLP *MLP
|
||||||
|
AttentionNorm *nn.RMSNorm
|
||||||
|
MLPNorm *nn.RMSNorm
|
||||||
|
}
|
||||||
|
|
||||||
|
type Attention struct {
|
||||||
|
QProj nn.LinearLayer
|
||||||
|
KProj nn.LinearLayer
|
||||||
|
VProj nn.LinearLayer
|
||||||
|
OProj nn.LinearLayer
|
||||||
|
}
|
||||||
|
|
||||||
|
type MLP struct {
|
||||||
|
GateProj nn.LinearLayer
|
||||||
|
UpProj nn.LinearLayer
|
||||||
|
DownProj nn.LinearLayer
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
|
||||||
|
for _, prefix := range []string{"", "language_model."} {
|
||||||
|
if tensors[prefix+"model.embed_tokens.weight"] != nil {
|
||||||
|
return prefix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func newModel(root *model.Root) (base.Model, error) {
|
||||||
|
configData, err := root.Manifest.ReadConfig("config.json")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg Config
|
||||||
|
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.HiddenSize <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
|
||||||
|
}
|
||||||
|
if cfg.NumAttentionHeads <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
|
||||||
|
}
|
||||||
|
if cfg.NumKeyValueHeads <= 0 {
|
||||||
|
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
|
||||||
|
}
|
||||||
|
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
|
||||||
|
return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
|
||||||
|
}
|
||||||
|
if cfg.HeadDim == 0 {
|
||||||
|
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||||
|
}
|
||||||
|
if cfg.HeadDim <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
|
||||||
|
}
|
||||||
|
if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
|
||||||
|
return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
|
||||||
|
}
|
||||||
|
if cfg.RopeTheta == 0 {
|
||||||
|
cfg.RopeTheta = 10000
|
||||||
|
}
|
||||||
|
if cfg.RMSNormEps == 0 {
|
||||||
|
cfg.RMSNormEps = 1e-5
|
||||||
|
}
|
||||||
|
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||||
|
|
||||||
|
if qt := root.QuantType(); qt != "" {
|
||||||
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||||
|
if gs := root.GroupSize(); gs > 0 {
|
||||||
|
cfg.QuantGroupSize = gs
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
|
||||||
|
}
|
||||||
|
cfg.TensorQuant = root.AllTensorQuant()
|
||||||
|
|
||||||
|
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokConfig := &tokenizer.TokenizerConfig{
|
||||||
|
ConfigJSON: configData,
|
||||||
|
}
|
||||||
|
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
|
||||||
|
tokConfig.GenerationConfigJSON = genConfigData
|
||||||
|
}
|
||||||
|
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||||
|
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||||
|
}
|
||||||
|
|
||||||
|
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse tokenizer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m := &Model{
|
||||||
|
Layers: make([]*Layer, cfg.NumHiddenLayers),
|
||||||
|
Config: &cfg,
|
||||||
|
tok: tok,
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadWeights receives all tensors loaded from the manifest and assigns them
|
||||||
|
// to model fields.
|
||||||
|
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||||
|
m.weightPrefix = resolveWeightPrefix(tensors)
|
||||||
|
prefix := m.weightPrefix
|
||||||
|
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||||
|
|
||||||
|
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
|
||||||
|
if embedWeight == nil {
|
||||||
|
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
|
||||||
|
}
|
||||||
|
m.EmbedTokens = nn.NewEmbedding(embedWeight)
|
||||||
|
|
||||||
|
normWeight := tensors[prefix+"model.norm.weight"]
|
||||||
|
if normWeight == nil {
|
||||||
|
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
|
||||||
|
}
|
||||||
|
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
|
||||||
|
|
||||||
|
if m.TieWordEmbeddings {
|
||||||
|
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||||
|
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
|
||||||
|
m.LMHead = lmHead
|
||||||
|
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
|
||||||
|
m.LMHead = lmHead
|
||||||
|
} else {
|
||||||
|
// Fallback used by many Llama checkpoints where output is tied.
|
||||||
|
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := int32(0); i < m.NumHiddenLayers; i++ {
|
||||||
|
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
|
||||||
|
|
||||||
|
layer := &Layer{
|
||||||
|
Attention: &Attention{},
|
||||||
|
MLP: &MLP{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
|
||||||
|
layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||||
|
}
|
||||||
|
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
|
||||||
|
layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
|
||||||
|
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
|
||||||
|
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
|
||||||
|
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
|
||||||
|
|
||||||
|
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
|
||||||
|
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
|
||||||
|
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
|
||||||
|
|
||||||
|
if layer.AttentionNorm == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing input_layernorm", i)
|
||||||
|
}
|
||||||
|
if layer.MLPNorm == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
|
||||||
|
}
|
||||||
|
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing attention projections", i)
|
||||||
|
}
|
||||||
|
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing mlp projections", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Layers[i] = layer
|
||||||
|
}
|
||||||
|
|
||||||
|
collected := mlx.Collect(m)
|
||||||
|
mlx.Eval(collected...)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||||
|
dims := tokens.Dims()
|
||||||
|
B, L := int32(dims[0]), int32(dims[1])
|
||||||
|
|
||||||
|
h := m.EmbedTokens.Forward(tokens)
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
var c cache.Cache
|
||||||
|
if caches != nil && i < len(caches) {
|
||||||
|
c = caches[i]
|
||||||
|
}
|
||||||
|
h = layer.Forward(h, c, B, L, m.Config)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.Norm.Forward(h, m.RMSNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||||
|
return m.LMHead.Forward(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) NumLayers() int {
|
||||||
|
return len(m.Layers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
|
return m.tok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) NewCaches() []cache.Cache {
|
||||||
|
caches := make([]cache.Cache, len(m.Layers))
|
||||||
|
for i := range caches {
|
||||||
|
caches[i] = cache.NewKVCache()
|
||||||
|
}
|
||||||
|
return caches
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||||
|
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||||
|
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||||
|
q := a.QProj.Forward(x)
|
||||||
|
k := a.KProj.Forward(x)
|
||||||
|
v := a.VProj.Forward(x)
|
||||||
|
|
||||||
|
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
|
||||||
|
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||||
|
|
||||||
|
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||||
|
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||||
|
|
||||||
|
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||||
|
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||||
|
|
||||||
|
offset := 0
|
||||||
|
if c != nil {
|
||||||
|
offset = c.Offset()
|
||||||
|
}
|
||||||
|
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||||
|
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||||
|
|
||||||
|
if c != nil {
|
||||||
|
k, v = c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||||
|
if repeatFactor > 1 {
|
||||||
|
k = nn.RepeatKV(k, repeatFactor)
|
||||||
|
v = nn.RepeatKV(v, repeatFactor)
|
||||||
|
}
|
||||||
|
|
||||||
|
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
|
||||||
|
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||||
|
return a.OProj.Forward(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||||
|
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||||
|
}
|
||||||
338
x/models/qwen3/qwen3.go
Normal file
338
x/models/qwen3/qwen3.go
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
// Package qwen3 provides the Qwen3 text model implementation for MLX.
|
||||||
|
package qwen3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||||
|
"github.com/ollama/ollama/x/models/nn"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
base.Register("Qwen3ForCausalLM", newModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config holds Qwen3 model configuration.
|
||||||
|
type Config struct {
|
||||||
|
HiddenSize int32 `json:"hidden_size"`
|
||||||
|
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||||
|
IntermediateSize int32 `json:"intermediate_size"`
|
||||||
|
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||||
|
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||||
|
VocabSize int32 `json:"vocab_size"`
|
||||||
|
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
HeadDim int32 `json:"head_dim"`
|
||||||
|
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||||
|
TieWordEmbeddings bool `json:"tie_word_embeddings"`
|
||||||
|
|
||||||
|
// Quantization parameters (set during load based on model quantization).
|
||||||
|
QuantGroupSize int `json:"-"`
|
||||||
|
QuantBits int `json:"-"`
|
||||||
|
QuantMode string `json:"-"`
|
||||||
|
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
|
||||||
|
|
||||||
|
// Computed fields.
|
||||||
|
Scale float32 `json:"-"`
|
||||||
|
QKNormEps float32 `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model is the Qwen3 text-only model.
|
||||||
|
type Model struct {
|
||||||
|
EmbedTokens *nn.Embedding
|
||||||
|
Layers []*Layer
|
||||||
|
Norm *nn.RMSNorm
|
||||||
|
LMHead nn.LinearLayer
|
||||||
|
|
||||||
|
tok *tokenizer.Tokenizer
|
||||||
|
*Config
|
||||||
|
|
||||||
|
weightPrefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Layer is a single Qwen3 decoder block.
|
||||||
|
type Layer struct {
|
||||||
|
Attention *Attention
|
||||||
|
MLP *MLP
|
||||||
|
AttentionNorm *nn.RMSNorm
|
||||||
|
MLPNorm *nn.RMSNorm
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attention implements Qwen3 attention with Q/K norms.
|
||||||
|
type Attention struct {
|
||||||
|
QProj nn.LinearLayer
|
||||||
|
KProj nn.LinearLayer
|
||||||
|
VProj nn.LinearLayer
|
||||||
|
OProj nn.LinearLayer
|
||||||
|
QNorm *nn.RMSNorm
|
||||||
|
KNorm *nn.RMSNorm
|
||||||
|
}
|
||||||
|
|
||||||
|
// MLP is the feed-forward network with SwiGLU activation.
|
||||||
|
type MLP struct {
|
||||||
|
GateProj nn.LinearLayer
|
||||||
|
UpProj nn.LinearLayer
|
||||||
|
DownProj nn.LinearLayer
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
|
||||||
|
for _, prefix := range []string{"", "language_model."} {
|
||||||
|
if tensors[prefix+"model.embed_tokens.weight"] != nil {
|
||||||
|
return prefix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func newModel(root *model.Root) (base.Model, error) {
|
||||||
|
configData, err := root.Manifest.ReadConfig("config.json")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg Config
|
||||||
|
if err := json.Unmarshal(configData, &cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.HiddenSize <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
|
||||||
|
}
|
||||||
|
if cfg.NumAttentionHeads <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
|
||||||
|
}
|
||||||
|
if cfg.NumKeyValueHeads <= 0 {
|
||||||
|
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
|
||||||
|
}
|
||||||
|
if cfg.HeadDim == 0 {
|
||||||
|
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
|
||||||
|
return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
|
||||||
|
}
|
||||||
|
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||||
|
}
|
||||||
|
if cfg.HeadDim <= 0 {
|
||||||
|
return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
|
||||||
|
}
|
||||||
|
if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
|
||||||
|
return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
|
||||||
|
}
|
||||||
|
if cfg.RMSNormEps == 0 {
|
||||||
|
cfg.RMSNormEps = 1e-6
|
||||||
|
}
|
||||||
|
if cfg.RopeTheta == 0 {
|
||||||
|
cfg.RopeTheta = 1000000
|
||||||
|
}
|
||||||
|
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||||
|
cfg.QKNormEps = 1e-6
|
||||||
|
|
||||||
|
if qt := root.QuantType(); qt != "" {
|
||||||
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
|
||||||
|
if gs := root.GroupSize(); gs > 0 {
|
||||||
|
cfg.QuantGroupSize = gs
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
|
||||||
|
}
|
||||||
|
cfg.TensorQuant = root.AllTensorQuant()
|
||||||
|
|
||||||
|
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load tokenizer config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokConfig := &tokenizer.TokenizerConfig{
|
||||||
|
ConfigJSON: configData,
|
||||||
|
}
|
||||||
|
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
|
||||||
|
tokConfig.GenerationConfigJSON = genConfigData
|
||||||
|
}
|
||||||
|
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
|
||||||
|
tokConfig.TokenizerConfigJSON = tokConfigData
|
||||||
|
}
|
||||||
|
|
||||||
|
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse tokenizer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m := &Model{
|
||||||
|
Layers: make([]*Layer, cfg.NumHiddenLayers),
|
||||||
|
Config: &cfg,
|
||||||
|
tok: tok,
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadWeights receives all tensors loaded from the manifest and assigns them
|
||||||
|
// to model fields.
|
||||||
|
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
|
||||||
|
m.weightPrefix = resolveWeightPrefix(tensors)
|
||||||
|
prefix := m.weightPrefix
|
||||||
|
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
|
||||||
|
|
||||||
|
embedWeight := tensors[prefix+"model.embed_tokens.weight"]
|
||||||
|
if embedWeight == nil {
|
||||||
|
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
|
||||||
|
}
|
||||||
|
m.EmbedTokens = nn.NewEmbedding(embedWeight)
|
||||||
|
|
||||||
|
normWeight := tensors[prefix+"model.norm.weight"]
|
||||||
|
if normWeight == nil {
|
||||||
|
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
|
||||||
|
}
|
||||||
|
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
|
||||||
|
|
||||||
|
if m.TieWordEmbeddings {
|
||||||
|
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||||
|
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
|
||||||
|
m.LMHead = lmHead
|
||||||
|
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
|
||||||
|
m.LMHead = lmHead
|
||||||
|
} else {
|
||||||
|
// Qwen3 checkpoints commonly tie output projection to embeddings.
|
||||||
|
m.LMHead = nn.NewLinear(embedWeight, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := int32(0); i < m.NumHiddenLayers; i++ {
|
||||||
|
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
|
||||||
|
|
||||||
|
layer := &Layer{
|
||||||
|
Attention: &Attention{},
|
||||||
|
MLP: &MLP{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
|
||||||
|
layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||||
|
}
|
||||||
|
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
|
||||||
|
layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
|
||||||
|
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
|
||||||
|
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
|
||||||
|
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
|
||||||
|
|
||||||
|
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
|
||||||
|
layer.Attention.QNorm = nn.NewRMSNorm(w, m.QKNormEps)
|
||||||
|
}
|
||||||
|
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
|
||||||
|
layer.Attention.KNorm = nn.NewRMSNorm(w, m.QKNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
|
||||||
|
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
|
||||||
|
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
|
||||||
|
|
||||||
|
if layer.AttentionNorm == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing input_layernorm", i)
|
||||||
|
}
|
||||||
|
if layer.MLPNorm == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
|
||||||
|
}
|
||||||
|
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing attention projections", i)
|
||||||
|
}
|
||||||
|
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing attention q/k norms", i)
|
||||||
|
}
|
||||||
|
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
|
||||||
|
return fmt.Errorf("layer %d: missing mlp projections", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Layers[i] = layer
|
||||||
|
}
|
||||||
|
|
||||||
|
collected := mlx.Collect(m)
|
||||||
|
mlx.Eval(collected...)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||||
|
dims := tokens.Dims()
|
||||||
|
B, L := int32(dims[0]), int32(dims[1])
|
||||||
|
|
||||||
|
h := m.EmbedTokens.Forward(tokens)
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
var c cache.Cache
|
||||||
|
if caches != nil && i < len(caches) {
|
||||||
|
c = caches[i]
|
||||||
|
}
|
||||||
|
h = layer.Forward(h, c, B, L, m.Config)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.Norm.Forward(h, m.RMSNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
||||||
|
return m.LMHead.Forward(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) NumLayers() int {
|
||||||
|
return len(m.Layers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
|
||||||
|
return m.tok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) NewCaches() []cache.Cache {
|
||||||
|
caches := make([]cache.Cache, len(m.Layers))
|
||||||
|
for i := range caches {
|
||||||
|
caches[i] = cache.NewKVCache()
|
||||||
|
}
|
||||||
|
return caches
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||||
|
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
|
||||||
|
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
|
||||||
|
q := a.QProj.Forward(x)
|
||||||
|
k := a.KProj.Forward(x)
|
||||||
|
v := a.VProj.Forward(x)
|
||||||
|
|
||||||
|
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
|
||||||
|
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||||
|
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
|
||||||
|
|
||||||
|
q = a.QNorm.Forward(q, cfg.QKNormEps)
|
||||||
|
k = a.KNorm.Forward(k, cfg.QKNormEps)
|
||||||
|
|
||||||
|
q = mlx.Transpose(q, 0, 2, 1, 3)
|
||||||
|
k = mlx.Transpose(k, 0, 2, 1, 3)
|
||||||
|
v = mlx.Transpose(v, 0, 2, 1, 3)
|
||||||
|
|
||||||
|
offset := 0
|
||||||
|
if c != nil {
|
||||||
|
offset = c.Offset()
|
||||||
|
}
|
||||||
|
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||||
|
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
|
||||||
|
|
||||||
|
if c != nil {
|
||||||
|
k, v = c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MLX SDPA supports grouped-query attention directly (Q heads can be a
|
||||||
|
// multiple of K/V heads), so avoid materializing repeated K/V tensors.
|
||||||
|
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
|
||||||
|
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||||
|
return a.OProj.Forward(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||||
|
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -58,7 +59,15 @@ func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return buildModelInfo(config, totalBytes, tensorCount), nil
|
info := buildModelInfo(config, totalBytes, tensorCount)
|
||||||
|
|
||||||
|
// For quantized models, byte-based estimation can significantly undercount
|
||||||
|
// parameters. Prefer exact counting from tensor shapes in safetensors headers.
|
||||||
|
if paramCount, err := getParameterCountFromManifest(mf); err == nil && paramCount > 0 {
|
||||||
|
info["general.parameter_count"] = paramCount
|
||||||
|
}
|
||||||
|
|
||||||
|
return info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildModelInfo constructs the model info map from config and tensor stats.
|
// buildModelInfo constructs the model info map from config and tensor stats.
|
||||||
@@ -151,6 +160,51 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
|
|||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getParameterCountFromManifest counts model parameters from tensor shapes.
|
||||||
|
// This accounts for quantized tensors by using unpacked shapes from
|
||||||
|
// getTensorInfoFromManifest.
|
||||||
|
func getParameterCountFromManifest(mf *manifest.Manifest) (int64, error) {
|
||||||
|
tensors, err := getTensorInfoFromManifest(mf)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
for _, tensor := range tensors {
|
||||||
|
if len(tensor.Shape) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
elements := int64(1)
|
||||||
|
for _, dim := range tensor.Shape {
|
||||||
|
if dim == 0 {
|
||||||
|
elements = 0
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if dim > uint64(math.MaxInt64) {
|
||||||
|
return 0, fmt.Errorf("tensor %s dimension too large: %d", tensor.Name, dim)
|
||||||
|
}
|
||||||
|
|
||||||
|
d := int64(dim)
|
||||||
|
if elements > math.MaxInt64/d {
|
||||||
|
return 0, fmt.Errorf("tensor %s element count overflow", tensor.Name)
|
||||||
|
}
|
||||||
|
elements *= d
|
||||||
|
}
|
||||||
|
|
||||||
|
if elements == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if total > math.MaxInt64-elements {
|
||||||
|
return 0, fmt.Errorf("total parameter count overflow")
|
||||||
|
}
|
||||||
|
total += elements
|
||||||
|
}
|
||||||
|
|
||||||
|
return total, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
|
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
|
||||||
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
|
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
|
||||||
func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
|
func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
|
||||||
|
|||||||
@@ -714,6 +714,187 @@ func TestGetTensorInfoFromManifest_Quantized(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetParameterCountFromManifest(t *testing.T) {
|
||||||
|
// Create a temp directory for blobs and set OLLAMA_MODELS
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", tempDir)
|
||||||
|
|
||||||
|
blobDir := filepath.Join(tempDir, "blobs")
|
||||||
|
if err := os.MkdirAll(blobDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create blobs dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unquantized tensor: [4,5] = 20 params
|
||||||
|
header1 := map[string]any{
|
||||||
|
"model.embed_tokens.weight": map[string]any{
|
||||||
|
"dtype": "BF16",
|
||||||
|
"shape": []int64{4, 5},
|
||||||
|
"data_offsets": []int64{0, 40},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
header1JSON, _ := json.Marshal(header1)
|
||||||
|
var buf1 bytes.Buffer
|
||||||
|
binary.Write(&buf1, binary.LittleEndian, uint64(len(header1JSON)))
|
||||||
|
buf1.Write(header1JSON)
|
||||||
|
|
||||||
|
digest1 := "sha256:1111111111111111111111111111111111111111111111111111111111111111"
|
||||||
|
blobPath1, err := manifest.BlobsPath(digest1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get blob path: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(blobPath1, buf1.Bytes(), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write blob1: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quantized int4 tensor with packed shape [10,2] -> unpacked [10,16] = 160 params
|
||||||
|
header2 := map[string]any{
|
||||||
|
"__metadata__": map[string]string{
|
||||||
|
"quant_type": "int4",
|
||||||
|
"group_size": "32",
|
||||||
|
},
|
||||||
|
"model.layers.0.mlp.up_proj.weight": map[string]any{
|
||||||
|
"dtype": "U32",
|
||||||
|
"shape": []int64{10, 2},
|
||||||
|
"data_offsets": []int64{0, 80},
|
||||||
|
},
|
||||||
|
"model.layers.0.mlp.up_proj.weight.scale": map[string]any{
|
||||||
|
"dtype": "BF16",
|
||||||
|
"shape": []int64{10, 1},
|
||||||
|
"data_offsets": []int64{80, 100},
|
||||||
|
},
|
||||||
|
"model.layers.0.mlp.up_proj.weight.bias": map[string]any{
|
||||||
|
"dtype": "BF16",
|
||||||
|
"shape": []int64{10, 1},
|
||||||
|
"data_offsets": []int64{100, 120},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
header2JSON, _ := json.Marshal(header2)
|
||||||
|
var buf2 bytes.Buffer
|
||||||
|
binary.Write(&buf2, binary.LittleEndian, uint64(len(header2JSON)))
|
||||||
|
buf2.Write(header2JSON)
|
||||||
|
|
||||||
|
digest2 := "sha256:2222222222222222222222222222222222222222222222222222222222222222"
|
||||||
|
blobPath2, err := manifest.BlobsPath(digest2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get blob path: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(blobPath2, buf2.Bytes(), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write blob2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mf := &manifest.Manifest{
|
||||||
|
SchemaVersion: 2,
|
||||||
|
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||||
|
Layers: []manifest.Layer{
|
||||||
|
{
|
||||||
|
MediaType: manifest.MediaTypeImageTensor,
|
||||||
|
Digest: digest1,
|
||||||
|
Size: int64(buf1.Len() + 40),
|
||||||
|
Name: "model.embed_tokens.weight",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MediaType: manifest.MediaTypeImageTensor,
|
||||||
|
Digest: digest2,
|
||||||
|
Size: int64(buf2.Len() + 120),
|
||||||
|
Name: "model.layers.0.mlp.up_proj.weight",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
paramCount, err := getParameterCountFromManifest(mf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getParameterCountFromManifest() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const want int64 = 180 // 20 + 160
|
||||||
|
if paramCount != want {
|
||||||
|
t.Errorf("parameter_count = %d, want %d", paramCount, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetParameterCountFromManifest_MixedQuantizedPacked(t *testing.T) {
|
||||||
|
// Create a temp directory for blobs and set OLLAMA_MODELS
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", tempDir)
|
||||||
|
|
||||||
|
blobDir := filepath.Join(tempDir, "blobs")
|
||||||
|
if err := os.MkdirAll(blobDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create blobs dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Packed mixed-precision blob (no global metadata):
|
||||||
|
// - gate_proj: int4 packed [5,8] + scale [5,2] => unpacked [5,64] = 320 params
|
||||||
|
// - down_proj: int8 packed [5,16] + scale [5,1] => unpacked [5,64] = 320 params
|
||||||
|
header := map[string]any{
|
||||||
|
"model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
|
||||||
|
"dtype": "U32",
|
||||||
|
"shape": []int64{5, 8},
|
||||||
|
"data_offsets": []int64{0, 160},
|
||||||
|
},
|
||||||
|
"model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
|
||||||
|
"dtype": "BF16",
|
||||||
|
"shape": []int64{5, 2},
|
||||||
|
"data_offsets": []int64{160, 180},
|
||||||
|
},
|
||||||
|
"model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
|
||||||
|
"dtype": "BF16",
|
||||||
|
"shape": []int64{5, 2},
|
||||||
|
"data_offsets": []int64{180, 200},
|
||||||
|
},
|
||||||
|
"model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
|
||||||
|
"dtype": "U32",
|
||||||
|
"shape": []int64{5, 16},
|
||||||
|
"data_offsets": []int64{200, 520},
|
||||||
|
},
|
||||||
|
"model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
|
||||||
|
"dtype": "BF16",
|
||||||
|
"shape": []int64{5, 1},
|
||||||
|
"data_offsets": []int64{520, 530},
|
||||||
|
},
|
||||||
|
"model.layers.0.mlp.experts.0.down_proj.weight.bias": map[string]any{
|
||||||
|
"dtype": "BF16",
|
||||||
|
"shape": []int64{5, 1},
|
||||||
|
"data_offsets": []int64{530, 540},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
headerJSON, _ := json.Marshal(header)
|
||||||
|
var buf bytes.Buffer
|
||||||
|
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
|
||||||
|
buf.Write(headerJSON)
|
||||||
|
|
||||||
|
digest := "sha256:3333333333333333333333333333333333333333333333333333333333333333"
|
||||||
|
blobPath, err := manifest.BlobsPath(digest)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get blob path: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
|
||||||
|
t.Fatalf("failed to write blob: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mf := &manifest.Manifest{
|
||||||
|
SchemaVersion: 2,
|
||||||
|
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||||
|
Layers: []manifest.Layer{
|
||||||
|
{
|
||||||
|
MediaType: manifest.MediaTypeImageTensor,
|
||||||
|
Digest: digest,
|
||||||
|
Size: int64(buf.Len() + 540),
|
||||||
|
Name: "model.layers.0.mlp.experts",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
paramCount, err := getParameterCountFromManifest(mf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getParameterCountFromManifest() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const want int64 = 640 // 320 + 320
|
||||||
|
if paramCount != want {
|
||||||
|
t.Errorf("parameter_count = %d, want %d", paramCount, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseSafetensorsAllHeaders(t *testing.T) {
|
func TestParseSafetensorsAllHeaders(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
Reference in New Issue
Block a user