mirror of
https://github.com/ollama/ollama.git
synced 2026-04-22 00:36:11 +02:00
Compare commits
27 Commits
hoyyeva/op
...
pdevine/ad
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7bcdb250b9 | ||
|
|
7bbcd2e6be | ||
|
|
22d6c817f8 | ||
|
|
ca01373b28 | ||
|
|
24e038d56a | ||
|
|
5d1021603a | ||
|
|
8e05d734b9 | ||
|
|
05e0f21bec | ||
|
|
ff23dd343f | ||
|
|
123b300af6 | ||
|
|
57653b8e42 | ||
|
|
a50ce61c54 | ||
|
|
2bb7ea00d2 | ||
|
|
55fa80d07a | ||
|
|
b9cb535407 | ||
|
|
031baef094 | ||
|
|
7d271e6dc9 | ||
|
|
c88dae2d6b | ||
|
|
9e3618d663 | ||
|
|
5d920cc6bc | ||
|
|
e585ecd11f | ||
|
|
cdddea0592 | ||
|
|
43f90def04 | ||
|
|
06ae6367bd | ||
|
|
48ad7085c4 | ||
|
|
e1e3cec8d0 | ||
|
|
d3e67e305c |
@@ -55,7 +55,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
|
|||||||
ollama
|
ollama
|
||||||
```
|
```
|
||||||
|
|
||||||
You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `claude`, `codex`, `openclaw` and more.
|
You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `Claude Code`, `OpenClaw`, `OpenCode` , `Codex`, `Copilot`, and more.
|
||||||
|
|
||||||
### Coding
|
### Coding
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ To launch a specific integration:
|
|||||||
ollama launch claude
|
ollama launch claude
|
||||||
```
|
```
|
||||||
|
|
||||||
Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
|
Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Copilot CLI](https://docs.ollama.com/integrations/copilot-cli), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
|
||||||
|
|
||||||
### AI assistant
|
### AI assistant
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,12 @@ func TestLaunchCmd(t *testing.T) {
|
|||||||
if cmd.Long == "" {
|
if cmd.Long == "" {
|
||||||
t.Error("Long description should not be empty")
|
t.Error("Long description should not be empty")
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(cmd.Long, "hermes") {
|
||||||
|
t.Error("Long description should mention hermes")
|
||||||
|
}
|
||||||
|
if !strings.Contains(cmd.Long, "kimi") {
|
||||||
|
t.Error("Long description should mention kimi")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("flags exist", func(t *testing.T) {
|
t.Run("flags exist", func(t *testing.T) {
|
||||||
|
|||||||
76
cmd/launch/copilot.go
Normal file
76
cmd/launch/copilot.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Copilot implements Runner for GitHub Copilot CLI integration.
|
||||||
|
type Copilot struct{}
|
||||||
|
|
||||||
|
func (c *Copilot) String() string { return "Copilot CLI" }
|
||||||
|
|
||||||
|
func (c *Copilot) args(model string, extra []string) []string {
|
||||||
|
var args []string
|
||||||
|
if model != "" {
|
||||||
|
args = append(args, "--model", model)
|
||||||
|
}
|
||||||
|
args = append(args, extra...)
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Copilot) findPath() (string, error) {
|
||||||
|
if p, err := exec.LookPath("copilot"); err == nil {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
name := "copilot"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "copilot.exe"
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(home, ".local", "bin", name)
|
||||||
|
if _, err := os.Stat(fallback); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fallback, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Copilot) Run(model string, args []string) error {
|
||||||
|
copilotPath, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("copilot is not installed, install from https://docs.github.com/en/copilot/how-tos/set-up/install-copilot-cli")
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(copilotPath, c.args(model, args)...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
cmd.Env = append(os.Environ(), c.envVars(model)...)
|
||||||
|
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// envVars returns the environment variables that configure Copilot CLI
|
||||||
|
// to use Ollama as its model provider.
|
||||||
|
func (c *Copilot) envVars(model string) []string {
|
||||||
|
env := []string{
|
||||||
|
"COPILOT_PROVIDER_BASE_URL=" + envconfig.Host().String() + "/v1",
|
||||||
|
"COPILOT_PROVIDER_API_KEY=",
|
||||||
|
"COPILOT_PROVIDER_WIRE_API=responses",
|
||||||
|
}
|
||||||
|
|
||||||
|
if model != "" {
|
||||||
|
env = append(env, "COPILOT_MODEL="+model)
|
||||||
|
}
|
||||||
|
|
||||||
|
return env
|
||||||
|
}
|
||||||
161
cmd/launch/copilot_test.go
Normal file
161
cmd/launch/copilot_test.go
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCopilotIntegration(t *testing.T) {
|
||||||
|
c := &Copilot{}
|
||||||
|
|
||||||
|
t.Run("String", func(t *testing.T) {
|
||||||
|
if got := c.String(); got != "Copilot CLI" {
|
||||||
|
t.Errorf("String() = %q, want %q", got, "Copilot CLI")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Runner", func(t *testing.T) {
|
||||||
|
var _ Runner = c
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopilotFindPath(t *testing.T) {
|
||||||
|
c := &Copilot{}
|
||||||
|
|
||||||
|
t.Run("finds copilot in PATH", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
name := "copilot"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "copilot.exe"
|
||||||
|
}
|
||||||
|
fakeBin := filepath.Join(tmpDir, name)
|
||||||
|
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
|
||||||
|
got, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != fakeBin {
|
||||||
|
t.Errorf("findPath() = %q, want %q", got, fakeBin)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error when not in PATH", func(t *testing.T) {
|
||||||
|
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||||
|
|
||||||
|
_, err := c.findPath()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to ~/.local/bin/copilot", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||||
|
|
||||||
|
name := "copilot"
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
name = "copilot.exe"
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(tmpDir, ".local", "bin", name)
|
||||||
|
os.MkdirAll(filepath.Dir(fallback), 0o755)
|
||||||
|
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
|
||||||
|
|
||||||
|
got, err := c.findPath()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != fallback {
|
||||||
|
t.Errorf("findPath() = %q, want %q", got, fallback)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
|
||||||
|
|
||||||
|
_, err := c.findPath()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopilotArgs(t *testing.T) {
|
||||||
|
c := &Copilot{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
args []string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
|
||||||
|
{"empty model", "", nil, nil},
|
||||||
|
{"with model and extra", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
|
||||||
|
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := c.args(tt.model, tt.args)
|
||||||
|
if !slices.Equal(got, tt.want) {
|
||||||
|
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopilotEnvVars(t *testing.T) {
|
||||||
|
c := &Copilot{}
|
||||||
|
|
||||||
|
envMap := func(envs []string) map[string]string {
|
||||||
|
m := make(map[string]string)
|
||||||
|
for _, e := range envs {
|
||||||
|
k, v, _ := strings.Cut(e, "=")
|
||||||
|
m[k] = v
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("sets required provider env vars with model", func(t *testing.T) {
|
||||||
|
got := envMap(c.envVars("llama3.2"))
|
||||||
|
if got["COPILOT_PROVIDER_BASE_URL"] == "" {
|
||||||
|
t.Error("COPILOT_PROVIDER_BASE_URL should be set")
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(got["COPILOT_PROVIDER_BASE_URL"], "/v1") {
|
||||||
|
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want /v1 suffix", got["COPILOT_PROVIDER_BASE_URL"])
|
||||||
|
}
|
||||||
|
if _, ok := got["COPILOT_PROVIDER_API_KEY"]; !ok {
|
||||||
|
t.Error("COPILOT_PROVIDER_API_KEY should be set (empty)")
|
||||||
|
}
|
||||||
|
if got["COPILOT_PROVIDER_WIRE_API"] != "responses" {
|
||||||
|
t.Errorf("COPILOT_PROVIDER_WIRE_API = %q, want %q", got["COPILOT_PROVIDER_WIRE_API"], "responses")
|
||||||
|
}
|
||||||
|
if got["COPILOT_MODEL"] != "llama3.2" {
|
||||||
|
t.Errorf("COPILOT_MODEL = %q, want %q", got["COPILOT_MODEL"], "llama3.2")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("omits COPILOT_MODEL when model is empty", func(t *testing.T) {
|
||||||
|
got := envMap(c.envVars(""))
|
||||||
|
if _, ok := got["COPILOT_MODEL"]; ok {
|
||||||
|
t.Errorf("COPILOT_MODEL should not be set for empty model, got %q", got["COPILOT_MODEL"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", "http://myhost:9999")
|
||||||
|
got := envMap(c.envVars("test"))
|
||||||
|
if !strings.Contains(got["COPILOT_PROVIDER_BASE_URL"], "myhost:9999") {
|
||||||
|
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want custom host", got["COPILOT_PROVIDER_BASE_URL"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
679
cmd/launch/hermes.go
Normal file
679
cmd/launch/hermes.go
Normal file
@@ -0,0 +1,679 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/cmd/config"
|
||||||
|
"github.com/ollama/ollama/cmd/internal/fileutil"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
hermesInstallScript = "curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash -s -- --skip-setup"
|
||||||
|
hermesProviderName = "Ollama"
|
||||||
|
hermesProviderKey = "ollama-launch"
|
||||||
|
hermesLegacyKey = "ollama"
|
||||||
|
hermesPlaceholderKey = "ollama"
|
||||||
|
hermesGatewaySetupHint = "hermes gateway setup"
|
||||||
|
hermesGatewaySetupTitle = "Connect a messaging app now?"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
hermesGOOS = runtime.GOOS
|
||||||
|
hermesLookPath = exec.LookPath
|
||||||
|
hermesCommand = exec.Command
|
||||||
|
hermesUserHome = os.UserHomeDir
|
||||||
|
hermesOllamaURL = envconfig.ConnectableHost
|
||||||
|
)
|
||||||
|
|
||||||
|
var hermesMessagingEnvGroups = [][]string{
|
||||||
|
{"TELEGRAM_BOT_TOKEN"},
|
||||||
|
{"DISCORD_BOT_TOKEN"},
|
||||||
|
{"SLACK_BOT_TOKEN"},
|
||||||
|
{"SIGNAL_ACCOUNT"},
|
||||||
|
{"EMAIL_ADDRESS"},
|
||||||
|
{"TWILIO_ACCOUNT_SID"},
|
||||||
|
{"MATRIX_ACCESS_TOKEN", "MATRIX_PASSWORD"},
|
||||||
|
{"MATTERMOST_TOKEN"},
|
||||||
|
{"WHATSAPP_PHONE_NUMBER_ID"},
|
||||||
|
{"DINGTALK_CLIENT_ID"},
|
||||||
|
{"FEISHU_APP_ID"},
|
||||||
|
{"WECOM_BOT_ID"},
|
||||||
|
{"WEIXIN_ACCOUNT_ID"},
|
||||||
|
{"BLUEBUBBLES_SERVER_URL"},
|
||||||
|
{"WEBHOOK_ENABLED"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hermes is intentionally not an Editor integration: launch owns one primary
|
||||||
|
// model and the local Ollama endpoint, while Hermes keeps its own discovery and
|
||||||
|
// switching UX after startup.
|
||||||
|
type Hermes struct{}
|
||||||
|
|
||||||
|
func (h *Hermes) String() string { return "Hermes Agent" }
|
||||||
|
|
||||||
|
func (h *Hermes) Run(_ string, args []string) error {
|
||||||
|
// Hermes reads its primary model from config.yaml. launch configures that
|
||||||
|
// default model ahead of time so we can keep runtime invocation simple and
|
||||||
|
// still let Hermes discover additional models later via its own UX.
|
||||||
|
bin, err := h.binary()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := h.runGatewaySetupPreflight(args, func() error {
|
||||||
|
return hermesAttachedCommand(bin, "gateway", "setup").Run()
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return hermesAttachedCommand(bin, args...).Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) Paths() []string {
|
||||||
|
configPath, err := hermesConfigPath()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{configPath}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) Configure(model string) error {
|
||||||
|
configPath, err := hermesConfigPath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := map[string]any{}
|
||||||
|
if data, err := os.ReadFile(configPath); err == nil {
|
||||||
|
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return fmt.Errorf("parse hermes config: %w", err)
|
||||||
|
}
|
||||||
|
} else if !os.IsNotExist(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelSection, _ := cfg["model"].(map[string]any)
|
||||||
|
if modelSection == nil {
|
||||||
|
modelSection = make(map[string]any)
|
||||||
|
}
|
||||||
|
models := h.listModels(model)
|
||||||
|
applyHermesManagedProviders(cfg, hermesBaseURL(), model, models)
|
||||||
|
|
||||||
|
// launch writes the minimum provider/default-model settings needed to
|
||||||
|
// bootstrap Hermes against Ollama. The active provider stays on a
|
||||||
|
// launch-owned key so /model stays aligned with the launcher-managed entry,
|
||||||
|
// and the Ollama endpoint lives in providers: so the picker shows one row.
|
||||||
|
modelSection["provider"] = hermesProviderKey
|
||||||
|
modelSection["default"] = model
|
||||||
|
modelSection["base_url"] = hermesBaseURL()
|
||||||
|
modelSection["api_key"] = hermesPlaceholderKey
|
||||||
|
cfg["model"] = modelSection
|
||||||
|
|
||||||
|
// use Hermes' built-in web toolset for now.
|
||||||
|
// TODO(parthsareen): move this to using Ollama web search
|
||||||
|
cfg["toolsets"] = mergeHermesToolsets(cfg["toolsets"])
|
||||||
|
|
||||||
|
data, err := yaml.Marshal(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fileutil.WriteWithBackup(configPath, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) CurrentModel() string {
|
||||||
|
configPath, err := hermesConfigPath()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := map[string]any{}
|
||||||
|
if yaml.Unmarshal(data, &cfg) != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return hermesManagedCurrentModel(cfg, hermesBaseURL())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) Onboard() error {
|
||||||
|
return config.MarkIntegrationOnboarded("hermes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) RequiresInteractiveOnboarding() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) RefreshRuntimeAfterConfigure() error {
|
||||||
|
running, err := h.gatewayRunning()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check Hermes gateway status: %w", err)
|
||||||
|
}
|
||||||
|
if !running {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%sRefreshing Hermes messaging gateway...%s\n", ansiGray, ansiReset)
|
||||||
|
if err := h.restartGateway(); err != nil {
|
||||||
|
return fmt.Errorf("restart Hermes gateway: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Fprintln(os.Stderr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) installed() bool {
|
||||||
|
_, err := h.binary()
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) ensureInstalled() error {
|
||||||
|
if h.installed() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if hermesGOOS == "windows" {
|
||||||
|
return hermesWindowsHint()
|
||||||
|
}
|
||||||
|
|
||||||
|
var missing []string
|
||||||
|
for _, dep := range []string{"bash", "curl", "git"} {
|
||||||
|
if _, err := hermesLookPath(dep); err != nil {
|
||||||
|
missing = append(missing, dep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(missing) > 0 {
|
||||||
|
return fmt.Errorf("Hermes is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch hermes", strings.Join(missing, "\n "))
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := ConfirmPrompt("Hermes is not installed. Install now?")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("hermes installation cancelled")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\nInstalling Hermes...\n")
|
||||||
|
if err := hermesAttachedCommand("bash", "-lc", hermesInstallScript).Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to install hermes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.installed() {
|
||||||
|
return fmt.Errorf("hermes was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%sHermes installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) listModels(defaultModel string) []string {
|
||||||
|
client := hermesOllamaClient()
|
||||||
|
resp, err := client.List(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return []string{defaultModel}
|
||||||
|
}
|
||||||
|
|
||||||
|
models := make([]string, 0, len(resp.Models)+1)
|
||||||
|
seen := make(map[string]struct{}, len(resp.Models)+1)
|
||||||
|
add := func(name string) {
|
||||||
|
name = strings.TrimSpace(name)
|
||||||
|
if name == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := seen[name]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[name] = struct{}{}
|
||||||
|
models = append(models, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
add(defaultModel)
|
||||||
|
for _, entry := range resp.Models {
|
||||||
|
add(entry.Name)
|
||||||
|
}
|
||||||
|
if len(models) == 0 {
|
||||||
|
return []string{defaultModel}
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) binary() (string, error) {
|
||||||
|
if path, err := hermesLookPath("hermes"); err == nil {
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if hermesGOOS == "windows" {
|
||||||
|
return "", hermesWindowsHint()
|
||||||
|
}
|
||||||
|
|
||||||
|
home, err := hermesUserHome()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
fallback := filepath.Join(home, ".local", "bin", "hermes")
|
||||||
|
if _, err := os.Stat(fallback); err == nil {
|
||||||
|
return fallback, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("hermes is not installed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesConfigPath() (string, error) {
|
||||||
|
home, err := hermesUserHome()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".hermes", "config.yaml"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesBaseURL() string {
|
||||||
|
return strings.TrimRight(hermesOllamaURL().String(), "/") + "/v1"
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesEnvPath() (string, error) {
|
||||||
|
home, err := hermesUserHome()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".hermes", ".env"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) runGatewaySetupPreflight(args []string, runSetup func() error) error {
|
||||||
|
if len(args) > 0 || !isInteractiveSession() || currentLaunchConfirmPolicy.yes || currentLaunchConfirmPolicy.requireYesMessage {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if h.messagingConfigured() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\nHermes can message you on Telegram, Discord, Slack, and more.\n\n")
|
||||||
|
ok, err := ConfirmPromptWithOptions(hermesGatewaySetupTitle, ConfirmOptions{
|
||||||
|
YesLabel: "Yes",
|
||||||
|
NoLabel: "Set up later",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := runSetup(); err != nil {
|
||||||
|
return fmt.Errorf("hermes messaging setup failed: %w\n\nTry running: %s", err, hermesGatewaySetupHint)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) messagingConfigured() bool {
|
||||||
|
envVars, err := h.gatewayEnvVars()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, group := range hermesMessagingEnvGroups {
|
||||||
|
for _, key := range group {
|
||||||
|
if strings.TrimSpace(envVars[key]) != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
|
||||||
|
envVars := make(map[string]string)
|
||||||
|
|
||||||
|
envFilePath, err := hermesEnvPath()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch data, err := os.ReadFile(envFilePath); {
|
||||||
|
case err == nil:
|
||||||
|
for key, value := range hermesParseEnvFile(data) {
|
||||||
|
envVars[key] = value
|
||||||
|
}
|
||||||
|
case os.IsNotExist(err):
|
||||||
|
// nothing persisted yet
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, group := range hermesMessagingEnvGroups {
|
||||||
|
for _, key := range group {
|
||||||
|
if value, ok := os.LookupEnv(key); ok {
|
||||||
|
envVars[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return envVars, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) gatewayRunning() (bool, error) {
|
||||||
|
status, err := h.gatewayStatusOutput()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return hermesGatewayStatusRunning(status), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) gatewayStatusOutput() (string, error) {
|
||||||
|
bin, err := h.binary()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
out, err := hermesCommand(bin, "gateway", "status").CombinedOutput()
|
||||||
|
return string(out), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hermes) restartGateway() error {
|
||||||
|
bin, err := h.binary()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return hermesAttachedCommand(bin, "gateway", "restart").Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesGatewayStatusRunning(output string) bool {
|
||||||
|
status := strings.ToLower(output)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(status, "gateway is not running"):
|
||||||
|
return false
|
||||||
|
case strings.Contains(status, "gateway service is stopped"):
|
||||||
|
return false
|
||||||
|
case strings.Contains(status, "gateway service is not loaded"):
|
||||||
|
return false
|
||||||
|
case strings.Contains(status, "gateway is running"):
|
||||||
|
return true
|
||||||
|
case strings.Contains(status, "gateway service is running"):
|
||||||
|
return true
|
||||||
|
case strings.Contains(status, "gateway service is loaded"):
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesParseEnvFile(data []byte) map[string]string {
|
||||||
|
out := make(map[string]string)
|
||||||
|
scanner := bufio.NewScanner(bytes.NewReader(data))
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(strings.TrimPrefix(scanner.Text(), "\ufeff"))
|
||||||
|
if line == "" || strings.HasPrefix(line, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(line, "export ") {
|
||||||
|
line = strings.TrimSpace(strings.TrimPrefix(line, "export "))
|
||||||
|
}
|
||||||
|
|
||||||
|
key, value, ok := strings.Cut(line, "=")
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if len(value) >= 2 {
|
||||||
|
switch {
|
||||||
|
case value[0] == '"' && value[len(value)-1] == '"':
|
||||||
|
if unquoted, err := strconv.Unquote(value); err == nil {
|
||||||
|
value = unquoted
|
||||||
|
}
|
||||||
|
case value[0] == '\'' && value[len(value)-1] == '\'':
|
||||||
|
value = value[1 : len(value)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesOllamaClient() *api.Client {
|
||||||
|
// Hermes queries the same launch-resolved Ollama host that launch writes
|
||||||
|
// into config, so model discovery follows the configured endpoint.
|
||||||
|
return api.NewClient(hermesOllamaURL(), http.DefaultClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyHermesManagedProviders(cfg map[string]any, baseURL string, model string, models []string) {
|
||||||
|
providers := hermesUserProviders(cfg["providers"])
|
||||||
|
entry := hermesManagedProviderEntry(providers)
|
||||||
|
if entry == nil {
|
||||||
|
entry = make(map[string]any)
|
||||||
|
}
|
||||||
|
entry["name"] = hermesProviderName
|
||||||
|
entry["api"] = baseURL
|
||||||
|
entry["default_model"] = model
|
||||||
|
entry["models"] = hermesStringListAny(models)
|
||||||
|
providers[hermesProviderKey] = entry
|
||||||
|
delete(providers, hermesLegacyKey)
|
||||||
|
cfg["providers"] = providers
|
||||||
|
|
||||||
|
customProviders := hermesWithoutManagedCustomProviders(cfg["custom_providers"])
|
||||||
|
if len(customProviders) == 0 {
|
||||||
|
delete(cfg, "custom_providers")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg["custom_providers"] = customProviders
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesManagedCurrentModel(cfg map[string]any, baseURL string) string {
|
||||||
|
modelCfg, _ := cfg["model"].(map[string]any)
|
||||||
|
if modelCfg == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, _ := modelCfg["provider"].(string)
|
||||||
|
if strings.TrimSpace(strings.ToLower(provider)) != hermesProviderKey {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
configBaseURL, _ := modelCfg["base_url"].(string)
|
||||||
|
if hermesNormalizeURL(configBaseURL) != hermesNormalizeURL(baseURL) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
current, _ := modelCfg["default"].(string)
|
||||||
|
current = strings.TrimSpace(current)
|
||||||
|
if current == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
providers := hermesUserProviders(cfg["providers"])
|
||||||
|
entry, _ := providers[hermesProviderKey].(map[string]any)
|
||||||
|
if entry == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if hermesHasManagedCustomProvider(cfg["custom_providers"]) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
apiURL, _ := entry["api"].(string)
|
||||||
|
if hermesNormalizeURL(apiURL) != hermesNormalizeURL(baseURL) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultModel, _ := entry["default_model"].(string)
|
||||||
|
if strings.TrimSpace(defaultModel) != current {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return current
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesUserProviders(current any) map[string]any {
|
||||||
|
switch existing := current.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
out := make(map[string]any, len(existing))
|
||||||
|
for key, value := range existing {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
case map[any]any:
|
||||||
|
out := make(map[string]any, len(existing))
|
||||||
|
for key, value := range existing {
|
||||||
|
if s, ok := key.(string); ok {
|
||||||
|
out[s] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
default:
|
||||||
|
return make(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesCustomProviders(current any) []any {
|
||||||
|
switch existing := current.(type) {
|
||||||
|
case []any:
|
||||||
|
return append([]any(nil), existing...)
|
||||||
|
case []map[string]any:
|
||||||
|
out := make([]any, 0, len(existing))
|
||||||
|
for _, entry := range existing {
|
||||||
|
out = append(out, entry)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesManagedProviderEntry(providers map[string]any) map[string]any {
|
||||||
|
for _, key := range []string{hermesProviderKey, hermesLegacyKey} {
|
||||||
|
if entry, _ := providers[key].(map[string]any); entry != nil {
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesWithoutManagedCustomProviders(current any) []any {
|
||||||
|
customProviders := hermesCustomProviders(current)
|
||||||
|
preserved := make([]any, 0, len(customProviders))
|
||||||
|
|
||||||
|
for _, item := range customProviders {
|
||||||
|
entry, _ := item.(map[string]any)
|
||||||
|
if entry == nil {
|
||||||
|
preserved = append(preserved, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if hermesManagedCustomProvider(entry) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
preserved = append(preserved, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
return preserved
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesHasManagedCustomProvider(current any) bool {
|
||||||
|
for _, item := range hermesCustomProviders(current) {
|
||||||
|
entry, _ := item.(map[string]any)
|
||||||
|
if entry != nil && hermesManagedCustomProvider(entry) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesManagedCustomProvider(entry map[string]any) bool {
|
||||||
|
name, _ := entry["name"].(string)
|
||||||
|
return strings.EqualFold(strings.TrimSpace(name), hermesProviderName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesNormalizeURL(raw string) string {
|
||||||
|
return strings.TrimRight(strings.TrimSpace(raw), "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesStringListAny(models []string) []any {
|
||||||
|
out := make([]any, 0, len(models))
|
||||||
|
for _, model := range dedupeModelList(models) {
|
||||||
|
model = strings.TrimSpace(model)
|
||||||
|
if model == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, model)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeHermesToolsets(current any) any {
|
||||||
|
added := false
|
||||||
|
switch existing := current.(type) {
|
||||||
|
case []any:
|
||||||
|
out := make([]any, 0, len(existing)+1)
|
||||||
|
for _, item := range existing {
|
||||||
|
out = append(out, item)
|
||||||
|
if s, _ := item.(string); s == "web" {
|
||||||
|
added = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !added {
|
||||||
|
out = append(out, "web")
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
case []string:
|
||||||
|
out := append([]string(nil), existing...)
|
||||||
|
if !slices.Contains(out, "web") {
|
||||||
|
out = append(out, "web")
|
||||||
|
}
|
||||||
|
asAny := make([]any, 0, len(out))
|
||||||
|
for _, item := range out {
|
||||||
|
asAny = append(asAny, item)
|
||||||
|
}
|
||||||
|
return asAny
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(existing) == "" {
|
||||||
|
return []any{"hermes-cli", "web"}
|
||||||
|
}
|
||||||
|
parts := strings.Split(existing, ",")
|
||||||
|
out := make([]any, 0, len(parts)+1)
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if part == "web" {
|
||||||
|
added = true
|
||||||
|
}
|
||||||
|
out = append(out, part)
|
||||||
|
}
|
||||||
|
if !added {
|
||||||
|
out = append(out, "web")
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
default:
|
||||||
|
return []any{"hermes-cli", "web"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
|
||||||
|
cmd := hermesCommand(name, args...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func hermesWindowsHint() error {
|
||||||
|
return fmt.Errorf("Hermes on Windows requires WSL2. Install WSL with: wsl --install\n" +
|
||||||
|
"Then run 'ollama launch hermes' from inside your WSL shell.\n" +
|
||||||
|
"Docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/")
|
||||||
|
}
|
||||||
1110
cmd/launch/hermes_test.go
Normal file
1110
cmd/launch/hermes_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -54,6 +54,7 @@ func TestIntegrationLookup(t *testing.T) {
|
|||||||
{"claude uppercase", "CLAUDE", true, "Claude Code"},
|
{"claude uppercase", "CLAUDE", true, "Claude Code"},
|
||||||
{"claude mixed case", "Claude", true, "Claude Code"},
|
{"claude mixed case", "Claude", true, "Claude Code"},
|
||||||
{"codex", "codex", true, "Codex"},
|
{"codex", "codex", true, "Codex"},
|
||||||
|
{"kimi", "kimi", true, "Kimi Code CLI"},
|
||||||
{"droid", "droid", true, "Droid"},
|
{"droid", "droid", true, "Droid"},
|
||||||
{"opencode", "opencode", true, "OpenCode"},
|
{"opencode", "opencode", true, "OpenCode"},
|
||||||
{"unknown integration", "unknown", false, ""},
|
{"unknown integration", "unknown", false, ""},
|
||||||
@@ -74,7 +75,7 @@ func TestIntegrationLookup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestIntegrationRegistry(t *testing.T) {
|
func TestIntegrationRegistry(t *testing.T) {
|
||||||
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
|
expectedIntegrations := []string{"claude", "codex", "kimi", "droid", "opencode", "hermes"}
|
||||||
|
|
||||||
for _, name := range expectedIntegrations {
|
for _, name := range expectedIntegrations {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
@@ -89,6 +90,15 @@ func TestIntegrationRegistry(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHiddenIntegrationsExcludedFromVisibleLists(t *testing.T) {
|
||||||
|
for _, info := range ListIntegrationInfos() {
|
||||||
|
switch info.Name {
|
||||||
|
case "cline", "vscode", "kimi":
|
||||||
|
t.Fatalf("hidden integration %q should not appear in ListIntegrationInfos", info.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHasLocalModel(t *testing.T) {
|
func TestHasLocalModel(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -329,7 +339,7 @@ func TestBuildModelList_NoExistingModels(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
func TestBuildModelList_OnlyLocalModels_CloudRecsStillFirst(t *testing.T) {
|
||||||
existing := []modelInfo{
|
existing := []modelInfo{
|
||||||
{Name: "llama3.2:latest", Remote: false},
|
{Name: "llama3.2:latest", Remote: false},
|
||||||
{Name: "qwen2.5:latest", Remote: false},
|
{Name: "qwen2.5:latest", Remote: false},
|
||||||
@@ -338,10 +348,11 @@ func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
|
|||||||
items, _, _, _ := buildModelList(existing, nil, "")
|
items, _, _, _ := buildModelList(existing, nil, "")
|
||||||
got := names(items)
|
got := names(items)
|
||||||
|
|
||||||
// Recommended pinned at top (local recs first, then cloud recs when only-local), then installed non-recs
|
// Cloud recs always come first among recommended, regardless of installed inventory.
|
||||||
want := []string{"gemma4", "qwen3.5", "kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "llama3.2", "qwen2.5"}
|
// Cloud disablement is handled upstream in loadSelectableModels via filterCloudItems.
|
||||||
|
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "gemma4", "qwen3.5", "llama3.2", "qwen2.5"}
|
||||||
if diff := cmp.Diff(want, got); diff != "" {
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
t.Errorf("recs pinned at top, local recs before cloud recs (-want +got):\n%s", diff)
|
t.Errorf("cloud recs pinned first even when no cloud models installed (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -588,7 +599,7 @@ func TestBuildModelList_MixedCase_CloudRecsFirst(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) {
|
func TestBuildModelList_OnlyLocal_CloudRecsStillFirst(t *testing.T) {
|
||||||
existing := []modelInfo{
|
existing := []modelInfo{
|
||||||
{Name: "llama3.2:latest", Remote: false},
|
{Name: "llama3.2:latest", Remote: false},
|
||||||
}
|
}
|
||||||
@@ -596,11 +607,11 @@ func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) {
|
|||||||
items, _, _, _ := buildModelList(existing, nil, "")
|
items, _, _, _ := buildModelList(existing, nil, "")
|
||||||
got := names(items)
|
got := names(items)
|
||||||
|
|
||||||
// Local recs should sort before cloud recs in only-local case
|
// Cloud recs sort before local recs regardless of installed inventory.
|
||||||
localIdx := slices.Index(got, "gemma4")
|
localIdx := slices.Index(got, "gemma4")
|
||||||
cloudIdx := slices.Index(got, "glm-5.1:cloud")
|
cloudIdx := slices.Index(got, "glm-5.1:cloud")
|
||||||
if localIdx > cloudIdx {
|
if cloudIdx > localIdx {
|
||||||
t.Errorf("local recs should be before cloud recs in only-local case, got %v", got)
|
t.Errorf("cloud recs should be before local recs even when only local models installed, got %v", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1509,27 +1520,13 @@ func TestListIntegrationInfos(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("sorted with custom order at end", func(t *testing.T) {
|
t.Run("follows launcher order", func(t *testing.T) {
|
||||||
// integrationOrder entries (cline, opencode) should appear last, in that order.
|
got := make([]string, 0, len(infos))
|
||||||
// All other entries should be sorted alphabetically before them.
|
for _, info := range infos {
|
||||||
orderRank := make(map[string]int)
|
got = append(got, info.Name)
|
||||||
for i, name := range integrationOrder {
|
|
||||||
orderRank[name] = i + 1
|
|
||||||
}
|
}
|
||||||
for i := 1; i < len(infos); i++ {
|
if diff := compareStrings(got, integrationOrder); diff != "" {
|
||||||
aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
|
t.Fatalf("launcher integration order mismatch: %s", diff)
|
||||||
switch {
|
|
||||||
case aRank == 0 && bRank == 0:
|
|
||||||
if infos[i-1].Name >= infos[i].Name {
|
|
||||||
t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
|
|
||||||
}
|
|
||||||
case aRank > 0 && bRank == 0:
|
|
||||||
t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
|
|
||||||
case aRank > 0 && bRank > 0:
|
|
||||||
if aRank >= bRank {
|
|
||||||
t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -1557,6 +1554,28 @@ func TestListIntegrationInfos(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("includes hermes", func(t *testing.T) {
|
||||||
|
for _, info := range infos {
|
||||||
|
if info.Name == "hermes" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Fatal("expected hermes to be included in ListIntegrationInfos")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("hermes still resolves explicitly", func(t *testing.T) {
|
||||||
|
name, runner, err := LookupIntegration("hermes")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected explicit hermes integration lookup to work, got %v", err)
|
||||||
|
}
|
||||||
|
if name != "hermes" {
|
||||||
|
t.Fatalf("expected canonical name hermes, got %q", name)
|
||||||
|
}
|
||||||
|
if runner.String() == "" {
|
||||||
|
t.Fatal("expected hermes integration runner to be present")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBuildModelList_Descriptions(t *testing.T) {
|
func TestBuildModelList_Descriptions(t *testing.T) {
|
||||||
@@ -1645,6 +1664,7 @@ func TestIntegration_AutoInstallable(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{"openclaw", true},
|
{"openclaw", true},
|
||||||
{"pi", true},
|
{"pi", true},
|
||||||
|
{"hermes", true},
|
||||||
{"claude", false},
|
{"claude", false},
|
||||||
{"codex", false},
|
{"codex", false},
|
||||||
{"opencode", false},
|
{"opencode", false},
|
||||||
|
|||||||
315
cmd/launch/kimi.go
Normal file
315
cmd/launch/kimi.go
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Kimi implements Runner for Kimi Code CLI integration.
|
||||||
|
type Kimi struct{}
|
||||||
|
|
||||||
|
const (
|
||||||
|
kimiDefaultModelAlias = "ollama"
|
||||||
|
kimiDefaultMaxContextSize = 32768
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
kimiGOOS = runtime.GOOS
|
||||||
|
kimiModelShowTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
func (k *Kimi) String() string { return "Kimi Code CLI" }
|
||||||
|
|
||||||
|
func (k *Kimi) args(config string, extra []string) []string {
|
||||||
|
args := []string{"--config", config}
|
||||||
|
args = append(args, extra...)
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *Kimi) Run(model string, args []string) error {
|
||||||
|
if strings.TrimSpace(model) == "" {
|
||||||
|
return fmt.Errorf("model is required")
|
||||||
|
}
|
||||||
|
if err := validateKimiPassthroughArgs(args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := buildKimiInlineConfig(model, resolveKimiMaxContextSize(model))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to build kimi config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bin, err := ensureKimiInstalled()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command(bin, k.args(config, args)...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
return cmd.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
func findKimiBinary() (string, error) {
|
||||||
|
if path, err := exec.LookPath("kimi"); err == nil {
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
|
||||||
|
var candidates []string
|
||||||
|
switch kimiGOOS {
|
||||||
|
case "windows":
|
||||||
|
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, ".local", "bin"))
|
||||||
|
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, "bin"))
|
||||||
|
|
||||||
|
if appData := strings.TrimSpace(os.Getenv("APPDATA")); appData != "" {
|
||||||
|
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
|
||||||
|
}
|
||||||
|
if localAppData := strings.TrimSpace(os.Getenv("LOCALAPPDATA")); localAppData != "" {
|
||||||
|
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
candidates = append(candidates,
|
||||||
|
filepath.Join(home, ".local", "bin", "kimi"),
|
||||||
|
filepath.Join(home, "bin", "kimi"),
|
||||||
|
filepath.Join(home, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi"),
|
||||||
|
filepath.Join(home, ".local", "share", "uv", "tools", "kimi", "bin", "kimi"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if xdgDataHome := strings.TrimSpace(os.Getenv("XDG_DATA_HOME")); xdgDataHome != "" {
|
||||||
|
candidates = append(candidates,
|
||||||
|
filepath.Join(xdgDataHome, "uv", "tools", "kimi-cli", "bin", "kimi"),
|
||||||
|
filepath.Join(xdgDataHome, "uv", "tools", "kimi", "bin", "kimi"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WSL users can inherit Windows env vars while launching from Linux shells.
|
||||||
|
if profile := windowsPathToWSL(os.Getenv("USERPROFILE")); profile != "" {
|
||||||
|
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(profile, ".local", "bin"))
|
||||||
|
}
|
||||||
|
if appData := windowsPathToWSL(os.Getenv("APPDATA")); appData != "" {
|
||||||
|
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
|
||||||
|
}
|
||||||
|
if localAppData := windowsPathToWSL(os.Getenv("LOCALAPPDATA")); localAppData != "" {
|
||||||
|
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
|
||||||
|
return candidate, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("kimi binary not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendWindowsKimiCandidates(candidates []string, dir string) []string {
|
||||||
|
if strings.TrimSpace(dir) == "" {
|
||||||
|
return candidates
|
||||||
|
}
|
||||||
|
|
||||||
|
return append(candidates,
|
||||||
|
filepath.Join(dir, "kimi.exe"),
|
||||||
|
filepath.Join(dir, "kimi.cmd"),
|
||||||
|
filepath.Join(dir, "kimi.bat"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func windowsPathToWSL(path string) string {
|
||||||
|
trimmed := strings.TrimSpace(path)
|
||||||
|
if len(trimmed) < 3 || trimmed[1] != ':' {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
drive := strings.ToLower(string(trimmed[0]))
|
||||||
|
rest := strings.ReplaceAll(trimmed[2:], "\\", "/")
|
||||||
|
rest = strings.TrimPrefix(rest, "/")
|
||||||
|
if rest == "" {
|
||||||
|
return filepath.Join("/mnt", drive)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join("/mnt", drive, rest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateKimiPassthroughArgs(args []string) error {
|
||||||
|
for _, arg := range args {
|
||||||
|
switch {
|
||||||
|
case arg == "--config", strings.HasPrefix(arg, "--config="):
|
||||||
|
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config", arg)
|
||||||
|
case arg == "--config-file", strings.HasPrefix(arg, "--config-file="):
|
||||||
|
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config-file", arg)
|
||||||
|
case arg == "--model", strings.HasPrefix(arg, "--model="):
|
||||||
|
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --model", arg)
|
||||||
|
case arg == "-m", strings.HasPrefix(arg, "-m="):
|
||||||
|
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages -m/--model", arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKimiInlineConfig(model string, maxContextSize int) (string, error) {
|
||||||
|
cfg := map[string]any{
|
||||||
|
"default_model": kimiDefaultModelAlias,
|
||||||
|
"providers": map[string]any{
|
||||||
|
kimiDefaultModelAlias: map[string]any{
|
||||||
|
"type": "openai_legacy",
|
||||||
|
"base_url": envconfig.ConnectableHost().String() + "/v1",
|
||||||
|
"api_key": "ollama",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"models": map[string]any{
|
||||||
|
kimiDefaultModelAlias: map[string]any{
|
||||||
|
"provider": kimiDefaultModelAlias,
|
||||||
|
"model": model,
|
||||||
|
"max_context_size": maxContextSize,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveKimiMaxContextSize(model string) int {
|
||||||
|
if l, ok := lookupCloudModelLimit(model); ok {
|
||||||
|
return l.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return kimiDefaultMaxContextSize
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), kimiModelShowTimeout)
|
||||||
|
defer cancel()
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
||||||
|
if err != nil {
|
||||||
|
return kimiDefaultMaxContextSize
|
||||||
|
}
|
||||||
|
|
||||||
|
if n, ok := modelInfoContextLength(resp.ModelInfo); ok {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
return kimiDefaultMaxContextSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func modelInfoContextLength(modelInfo map[string]any) (int, bool) {
|
||||||
|
for key, val := range modelInfo {
|
||||||
|
if !strings.HasSuffix(key, ".context_length") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch v := val.(type) {
|
||||||
|
case float64:
|
||||||
|
if v > 0 {
|
||||||
|
return int(v), true
|
||||||
|
}
|
||||||
|
case int:
|
||||||
|
if v > 0 {
|
||||||
|
return v, true
|
||||||
|
}
|
||||||
|
case int64:
|
||||||
|
if v > 0 {
|
||||||
|
return int(v), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureKimiInstalled() (string, error) {
|
||||||
|
if path, err := findKimiBinary(); err == nil {
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := checkKimiInstallerDependencies(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := ConfirmPrompt("Kimi is not installed. Install now?")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("kimi installation cancelled")
|
||||||
|
}
|
||||||
|
|
||||||
|
bin, args, err := kimiInstallerCommand(kimiGOOS)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\nInstalling Kimi...\n")
|
||||||
|
cmd := exec.Command(bin, args...)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to install kimi: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
path, err := findKimiBinary()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("kimi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "%sKimi installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkKimiInstallerDependencies() error {
|
||||||
|
switch kimiGOOS {
|
||||||
|
case "windows":
|
||||||
|
if _, err := exec.LookPath("powershell"); err != nil {
|
||||||
|
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n PowerShell: https://learn.microsoft.com/powershell/\n\nThen re-run:\n ollama launch kimi")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
var missing []string
|
||||||
|
if _, err := exec.LookPath("curl"); err != nil {
|
||||||
|
missing = append(missing, "curl: https://curl.se/")
|
||||||
|
}
|
||||||
|
if _, err := exec.LookPath("bash"); err != nil {
|
||||||
|
missing = append(missing, "bash: https://www.gnu.org/software/bash/")
|
||||||
|
}
|
||||||
|
if len(missing) > 0 {
|
||||||
|
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch kimi", strings.Join(missing, "\n "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func kimiInstallerCommand(goos string) (string, []string, error) {
|
||||||
|
switch goos {
|
||||||
|
case "windows":
|
||||||
|
return "powershell", []string{
|
||||||
|
"-NoProfile",
|
||||||
|
"-ExecutionPolicy",
|
||||||
|
"Bypass",
|
||||||
|
"-Command",
|
||||||
|
"Invoke-RestMethod https://code.kimi.com/install.ps1 | Invoke-Expression",
|
||||||
|
}, nil
|
||||||
|
case "darwin", "linux":
|
||||||
|
return "bash", []string{
|
||||||
|
"-c",
|
||||||
|
"curl -LsSf https://code.kimi.com/install.sh | bash",
|
||||||
|
}, nil
|
||||||
|
default:
|
||||||
|
return "", nil, fmt.Errorf("unsupported platform for kimi install: %s", goos)
|
||||||
|
}
|
||||||
|
}
|
||||||
636
cmd/launch/kimi_test.go
Normal file
636
cmd/launch/kimi_test.go
Normal file
@@ -0,0 +1,636 @@
|
|||||||
|
package launch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func assertKimiBinPath(t *testing.T, bin string) {
|
||||||
|
t.Helper()
|
||||||
|
base := strings.ToLower(filepath.Base(bin))
|
||||||
|
if !strings.HasPrefix(base, "kimi") {
|
||||||
|
t.Fatalf("bin = %q, want path to kimi executable", bin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKimiIntegration(t *testing.T) {
|
||||||
|
k := &Kimi{}
|
||||||
|
|
||||||
|
t.Run("String", func(t *testing.T) {
|
||||||
|
if got := k.String(); got != "Kimi Code CLI" {
|
||||||
|
t.Errorf("String() = %q, want %q", got, "Kimi Code CLI")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("implements Runner", func(t *testing.T) {
|
||||||
|
var _ Runner = k
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKimiArgs(t *testing.T) {
|
||||||
|
k := &Kimi{}
|
||||||
|
|
||||||
|
got := k.args(`{"foo":"bar"}`, []string{"--quiet", "--print"})
|
||||||
|
want := []string{"--config", `{"foo":"bar"}`, "--quiet", "--print"}
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("args() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWindowsPathToWSL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want string
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "user profile path",
|
||||||
|
in: `C:\Users\parth`,
|
||||||
|
want: filepath.Join("/mnt", "c", "Users", "parth"),
|
||||||
|
valid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "path with trailing slash",
|
||||||
|
in: `D:\tools\bin\`,
|
||||||
|
want: filepath.Join("/mnt", "d", "tools", "bin"),
|
||||||
|
valid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non windows path",
|
||||||
|
in: "/home/parth",
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
in: "",
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := windowsPathToWSL(tt.in)
|
||||||
|
if !tt.valid {
|
||||||
|
if got != "" {
|
||||||
|
t.Fatalf("windowsPathToWSL(%q) = %q, want empty", tt.in, got)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("windowsPathToWSL(%q) = %q, want %q", tt.in, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindKimiBinaryFallbacks(t *testing.T) {
|
||||||
|
oldGOOS := kimiGOOS
|
||||||
|
t.Cleanup(func() { kimiGOOS = oldGOOS })
|
||||||
|
|
||||||
|
t.Run("linux/ubuntu uv tool path", func(t *testing.T) {
|
||||||
|
homeDir := t.TempDir()
|
||||||
|
setTestHome(t, homeDir)
|
||||||
|
t.Setenv("PATH", t.TempDir())
|
||||||
|
kimiGOOS = "linux"
|
||||||
|
|
||||||
|
target := filepath.Join(homeDir, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi")
|
||||||
|
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create candidate dir: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(target, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to write kimi candidate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := findKimiBinary()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("findKimiBinary() error = %v", err)
|
||||||
|
}
|
||||||
|
if got != target {
|
||||||
|
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("windows appdata uv bin", func(t *testing.T) {
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
t.Setenv("PATH", t.TempDir())
|
||||||
|
kimiGOOS = "windows"
|
||||||
|
|
||||||
|
appDataDir := t.TempDir()
|
||||||
|
t.Setenv("APPDATA", appDataDir)
|
||||||
|
t.Setenv("LOCALAPPDATA", "")
|
||||||
|
|
||||||
|
target := filepath.Join(appDataDir, "uv", "bin", "kimi.cmd")
|
||||||
|
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to create candidate dir: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(target, []byte("@echo off\r\nexit /b 0\r\n"), 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to write kimi candidate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := findKimiBinary()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("findKimiBinary() error = %v", err)
|
||||||
|
}
|
||||||
|
if got != target {
|
||||||
|
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateKimiPassthroughArgs_RejectsConflicts(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args []string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{name: "--config", args: []string{"--config", "{}"}, want: "--config"},
|
||||||
|
{name: "--config=", args: []string{"--config={}"}, want: "--config={"},
|
||||||
|
{name: "--config-file", args: []string{"--config-file", "x.toml"}, want: "--config-file"},
|
||||||
|
{name: "--config-file=", args: []string{"--config-file=x.toml"}, want: "--config-file=x.toml"},
|
||||||
|
{name: "--model", args: []string{"--model", "foo"}, want: "--model"},
|
||||||
|
{name: "--model=", args: []string{"--model=foo"}, want: "--model=foo"},
|
||||||
|
{name: "-m", args: []string{"-m", "foo"}, want: "-m"},
|
||||||
|
{name: "-m=", args: []string{"-m=foo"}, want: "-m=foo"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validateKimiPassthroughArgs(tt.args)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error for args %v", tt.args)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), tt.want) {
|
||||||
|
t.Fatalf("error %q does not contain %q", err.Error(), tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKimiInlineConfig(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434")
|
||||||
|
|
||||||
|
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildKimiInlineConfig() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
|
||||||
|
t.Fatalf("config is not valid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsed["default_model"] != "ollama" {
|
||||||
|
t.Fatalf("default_model = %v, want ollama", parsed["default_model"])
|
||||||
|
}
|
||||||
|
|
||||||
|
providers, ok := parsed["providers"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
|
||||||
|
}
|
||||||
|
ollamaProvider, ok := providers["ollama"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
|
||||||
|
}
|
||||||
|
if ollamaProvider["type"] != "openai_legacy" {
|
||||||
|
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
|
||||||
|
}
|
||||||
|
if ollamaProvider["base_url"] != "http://127.0.0.1:11434/v1" {
|
||||||
|
t.Fatalf("provider base_url = %v, want http://127.0.0.1:11434/v1", ollamaProvider["base_url"])
|
||||||
|
}
|
||||||
|
if ollamaProvider["api_key"] != "ollama" {
|
||||||
|
t.Fatalf("provider api_key = %v, want ollama", ollamaProvider["api_key"])
|
||||||
|
}
|
||||||
|
|
||||||
|
models, ok := parsed["models"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("models missing or wrong type: %T", parsed["models"])
|
||||||
|
}
|
||||||
|
ollamaModel, ok := models["ollama"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("models.ollama missing or wrong type: %T", models["ollama"])
|
||||||
|
}
|
||||||
|
if ollamaModel["provider"] != "ollama" {
|
||||||
|
t.Fatalf("model provider = %v, want ollama", ollamaModel["provider"])
|
||||||
|
}
|
||||||
|
if ollamaModel["model"] != "llama3.2" {
|
||||||
|
t.Fatalf("model model = %v, want llama3.2", ollamaModel["model"])
|
||||||
|
}
|
||||||
|
if ollamaModel["max_context_size"] != float64(65536) {
|
||||||
|
t.Fatalf("model max_context_size = %v, want 65536", ollamaModel["max_context_size"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildKimiInlineConfig_UsesConnectableHostForUnspecifiedBind(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", "http://0.0.0.0:11434")
|
||||||
|
|
||||||
|
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildKimiInlineConfig() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
|
||||||
|
t.Fatalf("config is not valid JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
providers, ok := parsed["providers"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
|
||||||
|
}
|
||||||
|
|
||||||
|
ollamaProvider, ok := providers["ollama"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
|
||||||
|
}
|
||||||
|
if got, _ := ollamaProvider["base_url"].(string); got != "http://127.0.0.1:11434/v1" {
|
||||||
|
t.Fatalf("provider base_url = %q, want %q", got, "http://127.0.0.1:11434/v1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveKimiMaxContextSize(t *testing.T) {
|
||||||
|
t.Run("uses cloud limit when known", func(t *testing.T) {
|
||||||
|
got := resolveKimiMaxContextSize("kimi-k2.5:cloud")
|
||||||
|
if got != 262_144 {
|
||||||
|
t.Fatalf("resolveKimiMaxContextSize() = %d, want 262144", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uses model show context length for local models", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/api/show" {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Fprint(w, `{"model_info":{"llama.context_length":131072}}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
got := resolveKimiMaxContextSize("llama3.2")
|
||||||
|
if got != 131_072 {
|
||||||
|
t.Fatalf("resolveKimiMaxContextSize() = %d, want 131072", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to default when show fails", func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.NotFoundHandler())
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
oldTimeout := kimiModelShowTimeout
|
||||||
|
kimiModelShowTimeout = 100 * 1000 * 1000 // 100ms
|
||||||
|
t.Cleanup(func() { kimiModelShowTimeout = oldTimeout })
|
||||||
|
|
||||||
|
got := resolveKimiMaxContextSize("llama3.2")
|
||||||
|
if got != kimiDefaultMaxContextSize {
|
||||||
|
t.Fatalf("resolveKimiMaxContextSize() = %d, want %d", got, kimiDefaultMaxContextSize)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKimiRun_RejectsConflictingArgsBeforeInstall(t *testing.T) {
|
||||||
|
k := &Kimi{}
|
||||||
|
|
||||||
|
oldConfirm := DefaultConfirmPrompt
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
t.Fatalf("did not expect install prompt, got %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
|
||||||
|
|
||||||
|
err := k.Run("llama3.2", []string{"--model", "other"})
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "--model") {
|
||||||
|
t.Fatalf("expected conflict error mentioning --model, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKimiRun_PassesInlineConfigAndExtraArgs(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("uses POSIX shell fake binary")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setTestHome(t, tmpDir)
|
||||||
|
logPath := filepath.Join(tmpDir, "kimi-args.log")
|
||||||
|
script := fmt.Sprintf(`#!/bin/sh
|
||||||
|
for arg in "$@"; do
|
||||||
|
printf "%%s\n" "$arg" >> %q
|
||||||
|
done
|
||||||
|
exit 0
|
||||||
|
`, logPath)
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "kimi"), []byte(script), 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to write fake kimi: %v", err)
|
||||||
|
}
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.NotFoundHandler())
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
k := &Kimi{}
|
||||||
|
if err := k.Run("llama3.2", []string{"--quiet", "--print"}); err != nil {
|
||||||
|
t.Fatalf("Run() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(logPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read args log: %v", err)
|
||||||
|
}
|
||||||
|
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
|
||||||
|
if len(lines) < 4 {
|
||||||
|
t.Fatalf("expected at least 4 args, got %v", lines)
|
||||||
|
}
|
||||||
|
if lines[0] != "--config" {
|
||||||
|
t.Fatalf("first arg = %q, want --config", lines[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(lines[1]), &cfg); err != nil {
|
||||||
|
t.Fatalf("config arg is not valid JSON: %v", err)
|
||||||
|
}
|
||||||
|
providers := cfg["providers"].(map[string]any)
|
||||||
|
ollamaProvider := providers["ollama"].(map[string]any)
|
||||||
|
if ollamaProvider["type"] != "openai_legacy" {
|
||||||
|
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if lines[2] != "--quiet" || lines[3] != "--print" {
|
||||||
|
t.Fatalf("extra args = %v, want [--quiet --print]", lines[2:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureKimiInstalled(t *testing.T) {
|
||||||
|
oldGOOS := kimiGOOS
|
||||||
|
t.Cleanup(func() { kimiGOOS = oldGOOS })
|
||||||
|
|
||||||
|
withConfirm := func(t *testing.T, fn func(prompt string) (bool, error)) {
|
||||||
|
t.Helper()
|
||||||
|
oldConfirm := DefaultConfirmPrompt
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
return fn(prompt)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("already installed", func(t *testing.T) {
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
writeFakeBinary(t, tmpDir, "kimi")
|
||||||
|
kimiGOOS = runtime.GOOS
|
||||||
|
|
||||||
|
withConfirm(t, func(prompt string) (bool, error) {
|
||||||
|
t.Fatalf("did not expect prompt, got %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
bin, err := ensureKimiInstalled()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ensureKimiInstalled() error = %v", err)
|
||||||
|
}
|
||||||
|
assertKimiBinPath(t, bin)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing dependencies", func(t *testing.T) {
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
kimiGOOS = "linux"
|
||||||
|
|
||||||
|
withConfirm(t, func(prompt string) (bool, error) {
|
||||||
|
t.Fatalf("did not expect prompt, got %q", prompt)
|
||||||
|
return false, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := ensureKimiInstalled()
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "required dependencies are missing") {
|
||||||
|
t.Fatalf("expected missing dependency error, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing and user declines install", func(t *testing.T) {
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
writeFakeBinary(t, tmpDir, "curl")
|
||||||
|
writeFakeBinary(t, tmpDir, "bash")
|
||||||
|
kimiGOOS = "linux"
|
||||||
|
|
||||||
|
withConfirm(t, func(prompt string) (bool, error) {
|
||||||
|
if !strings.Contains(prompt, "Kimi is not installed.") {
|
||||||
|
t.Fatalf("unexpected prompt: %q", prompt)
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := ensureKimiInstalled()
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "installation cancelled") {
|
||||||
|
t.Fatalf("expected cancellation error, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing and user confirms install succeeds", func(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("uses POSIX shell fake binaries")
|
||||||
|
}
|
||||||
|
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
kimiGOOS = "linux"
|
||||||
|
|
||||||
|
writeFakeBinary(t, tmpDir, "curl")
|
||||||
|
|
||||||
|
installLog := filepath.Join(tmpDir, "bash.log")
|
||||||
|
kimiPath := filepath.Join(tmpDir, "kimi")
|
||||||
|
bashScript := fmt.Sprintf(`#!/bin/sh
|
||||||
|
echo "$@" >> %q
|
||||||
|
if [ "$1" = "-c" ]; then
|
||||||
|
/bin/cat > %q <<'EOS'
|
||||||
|
#!/bin/sh
|
||||||
|
exit 0
|
||||||
|
EOS
|
||||||
|
/bin/chmod +x %q
|
||||||
|
fi
|
||||||
|
exit 0
|
||||||
|
`, installLog, kimiPath, kimiPath)
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte(bashScript), 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to write fake bash: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
withConfirm(t, func(prompt string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
bin, err := ensureKimiInstalled()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ensureKimiInstalled() error = %v", err)
|
||||||
|
}
|
||||||
|
assertKimiBinPath(t, bin)
|
||||||
|
|
||||||
|
logData, err := os.ReadFile(installLog)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read install log: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(logData), "https://code.kimi.com/install.sh") {
|
||||||
|
t.Fatalf("expected install.sh command in log, got:\n%s", string(logData))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("install succeeds and kimi is in home local bin without PATH update", func(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("uses POSIX shell fake binaries")
|
||||||
|
}
|
||||||
|
|
||||||
|
homeDir := t.TempDir()
|
||||||
|
setTestHome(t, homeDir)
|
||||||
|
|
||||||
|
tmpBin := t.TempDir()
|
||||||
|
t.Setenv("PATH", tmpBin)
|
||||||
|
kimiGOOS = "linux"
|
||||||
|
writeFakeBinary(t, tmpBin, "curl")
|
||||||
|
|
||||||
|
installedKimi := filepath.Join(homeDir, ".local", "bin", "kimi")
|
||||||
|
bashScript := fmt.Sprintf(`#!/bin/sh
|
||||||
|
if [ "$1" = "-c" ]; then
|
||||||
|
/bin/mkdir -p %q
|
||||||
|
/bin/cat > %q <<'EOS'
|
||||||
|
#!/bin/sh
|
||||||
|
exit 0
|
||||||
|
EOS
|
||||||
|
/bin/chmod +x %q
|
||||||
|
fi
|
||||||
|
exit 0
|
||||||
|
`, filepath.Dir(installedKimi), installedKimi, installedKimi)
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpBin, "bash"), []byte(bashScript), 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to write fake bash: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
withConfirm(t, func(prompt string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
bin, err := ensureKimiInstalled()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ensureKimiInstalled() error = %v", err)
|
||||||
|
}
|
||||||
|
if bin != installedKimi {
|
||||||
|
t.Fatalf("bin = %q, want %q", bin, installedKimi)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("install command fails", func(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("uses POSIX shell fake binaries")
|
||||||
|
}
|
||||||
|
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
kimiGOOS = "linux"
|
||||||
|
writeFakeBinary(t, tmpDir, "curl")
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to write fake bash: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
withConfirm(t, func(prompt string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := ensureKimiInstalled()
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "failed to install kimi") {
|
||||||
|
t.Fatalf("expected install failure error, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("install succeeds but binary missing on PATH", func(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("uses POSIX shell fake binaries")
|
||||||
|
}
|
||||||
|
|
||||||
|
setTestHome(t, t.TempDir())
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
t.Setenv("PATH", tmpDir)
|
||||||
|
kimiGOOS = "linux"
|
||||||
|
writeFakeBinary(t, tmpDir, "curl")
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||||
|
t.Fatalf("failed to write fake bash: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
withConfirm(t, func(prompt string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := ensureKimiInstalled()
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "binary was not found on PATH") {
|
||||||
|
t.Fatalf("expected PATH guidance error, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKimiInstallerCommand(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
goos string
|
||||||
|
wantBin string
|
||||||
|
wantParts []string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "linux",
|
||||||
|
goos: "linux",
|
||||||
|
wantBin: "bash",
|
||||||
|
wantParts: []string{"-c", "install.sh"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "darwin",
|
||||||
|
goos: "darwin",
|
||||||
|
wantBin: "bash",
|
||||||
|
wantParts: []string{"-c", "install.sh"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "windows",
|
||||||
|
goos: "windows",
|
||||||
|
wantBin: "powershell",
|
||||||
|
wantParts: []string{"-Command", "install.ps1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported",
|
||||||
|
goos: "freebsd",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
bin, args, err := kimiInstallerCommand(tt.goos)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("kimiInstallerCommand() error = %v", err)
|
||||||
|
}
|
||||||
|
if bin != tt.wantBin {
|
||||||
|
t.Fatalf("bin = %q, want %q", bin, tt.wantBin)
|
||||||
|
}
|
||||||
|
joined := strings.Join(args, " ")
|
||||||
|
for _, part := range tt.wantParts {
|
||||||
|
if !strings.Contains(joined, part) {
|
||||||
|
t.Fatalf("args %q missing %q", joined, part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -141,6 +141,36 @@ type Editor interface {
|
|||||||
Models() []string
|
Models() []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ManagedSingleModel is the narrow launch-owned config path for integrations
|
||||||
|
// like Hermes that have one primary model selected by launcher, need launcher
|
||||||
|
// to persist minimal config, and still keep their own model discovery and
|
||||||
|
// onboarding UX. This stays separate from Runner-only integrations and the
|
||||||
|
// multi-model Editor flow so Hermes-specific behavior stays scoped to one path.
|
||||||
|
type ManagedSingleModel interface {
|
||||||
|
Paths() []string
|
||||||
|
Configure(model string) error
|
||||||
|
CurrentModel() string
|
||||||
|
Onboard() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ManagedRuntimeRefresher lets managed integrations refresh any long-lived
|
||||||
|
// background runtime after launch rewrites their config.
|
||||||
|
type ManagedRuntimeRefresher interface {
|
||||||
|
RefreshRuntimeAfterConfigure() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ManagedOnboardingValidator lets managed integrations re-check saved
|
||||||
|
// onboarding state when launcher needs a stronger live readiness signal.
|
||||||
|
type ManagedOnboardingValidator interface {
|
||||||
|
OnboardingComplete() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ManagedInteractiveOnboarding lets a managed integration declare whether its
|
||||||
|
// onboarding step really requires an interactive terminal. Hermes does not.
|
||||||
|
type ManagedInteractiveOnboarding interface {
|
||||||
|
RequiresInteractiveOnboarding() bool
|
||||||
|
}
|
||||||
|
|
||||||
type modelInfo struct {
|
type modelInfo struct {
|
||||||
Name string
|
Name string
|
||||||
Remote bool
|
Remote bool
|
||||||
@@ -176,7 +206,10 @@ Supported integrations:
|
|||||||
claude Claude Code
|
claude Claude Code
|
||||||
cline Cline
|
cline Cline
|
||||||
codex Codex
|
codex Codex
|
||||||
|
copilot Copilot CLI (aliases: copilot-cli)
|
||||||
droid Droid
|
droid Droid
|
||||||
|
hermes Hermes Agent
|
||||||
|
kimi Kimi Code CLI
|
||||||
opencode OpenCode
|
opencode OpenCode
|
||||||
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||||
pi Pi
|
pi Pi
|
||||||
@@ -186,6 +219,7 @@ Examples:
|
|||||||
ollama launch
|
ollama launch
|
||||||
ollama launch claude
|
ollama launch claude
|
||||||
ollama launch claude --model <model>
|
ollama launch claude --model <model>
|
||||||
|
ollama launch hermes
|
||||||
ollama launch droid --config (does not auto-launch)
|
ollama launch droid --config (does not auto-launch)
|
||||||
ollama launch codex -- -p myprofile (pass extra args to integration)
|
ollama launch codex -- -p myprofile (pass extra args to integration)
|
||||||
ollama launch codex -- --sandbox workspace-write`,
|
ollama launch codex -- --sandbox workspace-write`,
|
||||||
@@ -308,36 +342,54 @@ func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
policy := launchIntegrationPolicy(req)
|
||||||
|
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
|
||||||
|
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
launchClient, saved, err := prepareIntegrationLaunch(name, policy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if managed, ok := runner.(ManagedSingleModel); ok {
|
||||||
|
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return launchClient.launchManagedSingleIntegration(ctx, name, runner, managed, saved, req)
|
||||||
|
}
|
||||||
|
|
||||||
if !req.ConfigureOnly {
|
if !req.ConfigureOnly {
|
||||||
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
if err := EnsureIntegrationInstalled(name, runner); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var policy LaunchPolicy
|
|
||||||
// TUI does not set a policy, whereas ollama launch <app> does as it can have flags which change the behavior
|
|
||||||
if req.Policy == nil {
|
|
||||||
policy = defaultLaunchPolicy(isInteractiveSession(), false)
|
|
||||||
} else {
|
|
||||||
policy = *req.Policy
|
|
||||||
}
|
|
||||||
|
|
||||||
launchClient, err := newLauncherClient(policy)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
saved, _ := loadStoredIntegrationConfig(name)
|
|
||||||
// In headless --yes mode we cannot prompt, so require an explicit --model.
|
|
||||||
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
|
|
||||||
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
if editor, ok := runner.(Editor); ok {
|
if editor, ok := runner.(Editor); ok {
|
||||||
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
|
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
|
||||||
}
|
}
|
||||||
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
|
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func launchIntegrationPolicy(req IntegrationLaunchRequest) LaunchPolicy {
|
||||||
|
// TUI does not set a policy, whereas ollama launch <app> does as it can
|
||||||
|
// have flags which change the behavior.
|
||||||
|
if req.Policy != nil {
|
||||||
|
return *req.Policy
|
||||||
|
}
|
||||||
|
return defaultLaunchPolicy(isInteractiveSession(), false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareIntegrationLaunch(name string, policy LaunchPolicy) (*launcherClient, *config.IntegrationConfig, error) {
|
||||||
|
launchClient, err := newLauncherClient(policy)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
saved, _ := loadStoredIntegrationConfig(name)
|
||||||
|
return launchClient, saved, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
|
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
|
||||||
_ = c.loadModelInventoryOnce(ctx)
|
_ = c.loadModelInventoryOnce(ctx)
|
||||||
|
|
||||||
@@ -368,9 +420,18 @@ func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return LauncherIntegrationState{}, err
|
return LauncherIntegrationState{}, err
|
||||||
}
|
}
|
||||||
currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor)
|
var currentModel string
|
||||||
if err != nil {
|
var usable bool
|
||||||
return LauncherIntegrationState{}, err
|
if managed, ok := integration.spec.Runner.(ManagedSingleModel); ok {
|
||||||
|
currentModel, usable, err = c.launcherManagedModelState(ctx, info.Name, managed)
|
||||||
|
if err != nil {
|
||||||
|
return LauncherIntegrationState{}, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
currentModel, usable, err = c.launcherModelState(ctx, info.Name, integration.editor)
|
||||||
|
if err != nil {
|
||||||
|
return LauncherIntegrationState{}, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return LauncherIntegrationState{
|
return LauncherIntegrationState{
|
||||||
@@ -408,6 +469,28 @@ func (c *launcherClient) launcherModelState(ctx context.Context, name string, is
|
|||||||
return model, usableErr == nil && usable, nil
|
return model, usableErr == nil && usable, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) launcherManagedModelState(ctx context.Context, name string, managed ManagedSingleModel) (string, bool, error) {
|
||||||
|
current := managed.CurrentModel()
|
||||||
|
if current == "" {
|
||||||
|
cfg, loadErr := loadStoredIntegrationConfig(name)
|
||||||
|
if loadErr == nil {
|
||||||
|
current = primaryModelFromConfig(cfg)
|
||||||
|
}
|
||||||
|
if current != "" {
|
||||||
|
return current, false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if current == "" {
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
usable, err := c.savedModelUsable(ctx, current)
|
||||||
|
if err != nil {
|
||||||
|
return current, false, err
|
||||||
|
}
|
||||||
|
return current, usable, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
|
||||||
current := config.LastModel()
|
current := config.LastModel()
|
||||||
if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
|
if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
|
||||||
@@ -444,35 +527,15 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||||
current := primaryModelFromConfig(saved)
|
target, _, err := c.resolveSingleIntegrationTarget(ctx, runner, primaryModelFromConfig(saved), req)
|
||||||
target := req.ModelOverride
|
if err != nil {
|
||||||
needsConfigure := req.ForceConfigure
|
|
||||||
|
|
||||||
if target == "" {
|
|
||||||
target = current
|
|
||||||
usable, err := c.savedModelUsable(ctx, target)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !usable {
|
|
||||||
needsConfigure = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if needsConfigure {
|
|
||||||
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
target = selected
|
|
||||||
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if target == "" {
|
if target == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
current := primaryModelFromConfig(saved)
|
||||||
if target != current {
|
if target != current {
|
||||||
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
if err := config.SaveIntegration(name, []string{target}); err != nil {
|
||||||
return fmt.Errorf("failed to save: %w", err)
|
return fmt.Errorf("failed to save: %w", err)
|
||||||
@@ -510,6 +573,102 @@ func (c *launcherClient) launchEditorIntegration(ctx context.Context, name strin
|
|||||||
return launchAfterConfiguration(name, runner, models[0], req)
|
return launchAfterConfiguration(name, runner, models[0], req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) launchManagedSingleIntegration(ctx context.Context, name string, runner Runner, managed ManagedSingleModel, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
|
||||||
|
current := managed.CurrentModel()
|
||||||
|
selectionCurrent := current
|
||||||
|
if selectionCurrent == "" {
|
||||||
|
selectionCurrent = primaryModelFromConfig(saved)
|
||||||
|
}
|
||||||
|
|
||||||
|
target, needsConfigure, err := c.resolveSingleIntegrationTarget(ctx, runner, selectionCurrent, req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if target == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if (current == "" || needsConfigure || req.ModelOverride != "" || target != current) && !savedMatchesModels(saved, []string{target}) {
|
||||||
|
if err := prepareManagedSingleIntegration(name, runner, managed, target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if refresher, ok := managed.(ManagedRuntimeRefresher); ok {
|
||||||
|
if err := refresher.RefreshRuntimeAfterConfigure(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !managedIntegrationOnboarded(saved, managed) {
|
||||||
|
if !isInteractiveSession() && managedRequiresInteractiveOnboarding(managed) {
|
||||||
|
return fmt.Errorf("%s still needs interactive gateway setup; run 'ollama launch %s' in a terminal to finish onboarding", runner, name)
|
||||||
|
}
|
||||||
|
if err := managed.Onboard(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.ConfigureOnly {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return runIntegration(runner, target, req.ExtraArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *launcherClient) resolveSingleIntegrationTarget(ctx context.Context, runner Runner, current string, req IntegrationLaunchRequest) (string, bool, error) {
|
||||||
|
target := req.ModelOverride
|
||||||
|
needsConfigure := req.ForceConfigure
|
||||||
|
|
||||||
|
if target == "" {
|
||||||
|
target = current
|
||||||
|
usable, err := c.savedModelUsable(ctx, target)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
if !usable {
|
||||||
|
needsConfigure = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsConfigure {
|
||||||
|
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
target = selected
|
||||||
|
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return target, needsConfigure, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func savedIntegrationOnboarded(saved *config.IntegrationConfig) bool {
|
||||||
|
return saved != nil && saved.Onboarded
|
||||||
|
}
|
||||||
|
|
||||||
|
func managedIntegrationOnboarded(saved *config.IntegrationConfig, managed ManagedSingleModel) bool {
|
||||||
|
if !savedIntegrationOnboarded(saved) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
validator, ok := managed.(ManagedOnboardingValidator)
|
||||||
|
if !ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return validator.OnboardingComplete()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Most managed integrations treat onboarding as an interactive terminal step.
|
||||||
|
// Hermes opts out because its launch-owned onboarding is just bookkeeping, so
|
||||||
|
// headless launches should not be blocked once config is already prepared.
|
||||||
|
func managedRequiresInteractiveOnboarding(managed ManagedSingleModel) bool {
|
||||||
|
onboarding, ok := managed.(ManagedInteractiveOnboarding)
|
||||||
|
if !ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return onboarding.RequiresInteractiveOnboarding()
|
||||||
|
}
|
||||||
|
|
||||||
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
|
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
|
||||||
if selector == nil {
|
if selector == nil {
|
||||||
return "", fmt.Errorf("no selector configured")
|
return "", fmt.Errorf("no selector configured")
|
||||||
|
|||||||
@@ -49,6 +49,55 @@ func (r *launcherSingleRunner) Run(model string, args []string) error {
|
|||||||
|
|
||||||
func (r *launcherSingleRunner) String() string { return "StubSingle" }
|
func (r *launcherSingleRunner) String() string { return "StubSingle" }
|
||||||
|
|
||||||
|
type launcherManagedRunner struct {
|
||||||
|
paths []string
|
||||||
|
currentModel string
|
||||||
|
configured []string
|
||||||
|
ranModel string
|
||||||
|
onboarded bool
|
||||||
|
onboardCalls int
|
||||||
|
onboardingComplete bool
|
||||||
|
refreshCalls int
|
||||||
|
refreshErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) Run(model string, args []string) error {
|
||||||
|
r.ranModel = model
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) String() string { return "StubManaged" }
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) Paths() []string { return r.paths }
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) Configure(model string) error {
|
||||||
|
r.configured = append(r.configured, model)
|
||||||
|
r.currentModel = model
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) CurrentModel() string { return r.currentModel }
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) Onboard() error {
|
||||||
|
r.onboardCalls++
|
||||||
|
r.onboarded = true
|
||||||
|
r.onboardingComplete = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) OnboardingComplete() bool { return r.onboardingComplete }
|
||||||
|
|
||||||
|
func (r *launcherManagedRunner) RefreshRuntimeAfterConfigure() error {
|
||||||
|
r.refreshCalls++
|
||||||
|
return r.refreshErr
|
||||||
|
}
|
||||||
|
|
||||||
|
type launcherHeadlessManagedRunner struct {
|
||||||
|
launcherManagedRunner
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *launcherHeadlessManagedRunner) RequiresInteractiveOnboarding() bool { return false }
|
||||||
|
|
||||||
func setLaunchTestHome(t *testing.T, dir string) {
|
func setLaunchTestHome(t *testing.T, dir string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
t.Setenv("HOME", dir)
|
t.Setenv("HOME", dir)
|
||||||
@@ -141,6 +190,451 @@ func TestDefaultLaunchPolicy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildLauncherState_ManagedSingleIntegrationUsesCurrentModel(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{currentModel: "gemma4"}
|
||||||
|
withIntegrationOverride(t, "pi", runner)
|
||||||
|
|
||||||
|
state, err := BuildLauncherState(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BuildLauncherState returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.Integrations["pi"].CurrentModel != "gemma4" {
|
||||||
|
t.Fatalf("expected managed current model from integration config, got %q", state.Integrations["pi"].CurrentModel)
|
||||||
|
}
|
||||||
|
if !state.Integrations["pi"].ModelUsable {
|
||||||
|
t.Fatal("expected managed current model to be usable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildLauncherState_ManagedSingleIntegrationShowsSavedModelWhenLiveConfigMissing(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("pi", []string{"gemma4"}); err != nil {
|
||||||
|
t.Fatalf("failed to save managed integration config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{}
|
||||||
|
withIntegrationOverride(t, "pi", runner)
|
||||||
|
|
||||||
|
state, err := BuildLauncherState(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BuildLauncherState returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.Integrations["pi"].CurrentModel != "gemma4" {
|
||||||
|
t.Fatalf("expected saved model to remain visible, got %q", state.Integrations["pi"].CurrentModel)
|
||||||
|
}
|
||||||
|
if state.Integrations["pi"].ModelUsable {
|
||||||
|
t.Fatal("expected missing live config to mark managed model unusable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationConfiguresOnboardsAndRuns(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, true)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{
|
||||||
|
paths: nil,
|
||||||
|
}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
return "gemma4", nil
|
||||||
|
}
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
|
||||||
|
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
|
||||||
|
t.Fatalf("configured models mismatch: %s", diff)
|
||||||
|
}
|
||||||
|
if runner.refreshCalls != 1 {
|
||||||
|
t.Fatalf("expected runtime refresh once after configure, got %d", runner.refreshCalls)
|
||||||
|
}
|
||||||
|
if runner.onboardCalls != 1 {
|
||||||
|
t.Fatalf("expected onboarding to run once, got %d", runner.onboardCalls)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "gemma4" {
|
||||||
|
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
saved, err := config.LoadIntegration("stubmanaged")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reload managed integration config: %v", err)
|
||||||
|
}
|
||||||
|
if diff := compareStrings(saved.Models, []string{"gemma4"}); diff != "" {
|
||||||
|
t.Fatalf("saved models mismatch: %s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationReOnboardsWhenSavedFlagIsStale(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, true)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{
|
||||||
|
currentModel: "gemma4",
|
||||||
|
onboardingComplete: false,
|
||||||
|
}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
|
||||||
|
t.Fatalf("failed to save managed integration config: %v", err)
|
||||||
|
}
|
||||||
|
if err := config.MarkIntegrationOnboarded("stubmanaged"); err != nil {
|
||||||
|
t.Fatalf("failed to mark managed integration onboarded: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
|
||||||
|
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if runner.onboardCalls != 1 {
|
||||||
|
t.Fatalf("expected stale onboarded flag to trigger onboarding, got %d calls", runner.onboardCalls)
|
||||||
|
}
|
||||||
|
if runner.refreshCalls != 0 {
|
||||||
|
t.Fatalf("expected no runtime refresh when config is unchanged, got %d", runner.refreshCalls)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "gemma4" {
|
||||||
|
t.Fatalf("expected launch to run saved model after onboarding repair, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationConfigOnlySkipsFinalRun(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, true)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{
|
||||||
|
paths: nil,
|
||||||
|
}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||||
|
Name: "stubmanaged",
|
||||||
|
ModelOverride: "gemma4",
|
||||||
|
ConfigureOnly: true,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if runner.ranModel != "" {
|
||||||
|
t.Fatalf("expected configure-only flow to skip final launch, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
if runner.refreshCalls != 1 {
|
||||||
|
t.Fatalf("expected configure-only flow to refresh runtime once, got %d", runner.refreshCalls)
|
||||||
|
}
|
||||||
|
if runner.onboardCalls != 1 {
|
||||||
|
t.Fatalf("expected configure-only flow to onboard once, got %d", runner.onboardCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationSkipsRewriteWhenSavedMatches(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, true)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
|
||||||
|
t.Fatalf("failed to save managed integration config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
t.Fatal("selector should not be called when saved model matches target")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
t.Fatal("confirm prompt should not run when saved model matches target")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
|
||||||
|
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(runner.configured) != 0 {
|
||||||
|
t.Fatalf("expected Configure to be skipped when saved matches, got %v", runner.configured)
|
||||||
|
}
|
||||||
|
if runner.refreshCalls != 0 {
|
||||||
|
t.Fatalf("expected no runtime refresh when config is unchanged, got %d", runner.refreshCalls)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "gemma4" {
|
||||||
|
t.Fatalf("expected launch to run saved model, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenSavedDiffers(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, true)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/tags":
|
||||||
|
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
if err := config.SaveIntegration("stubmanaged", []string{"old-model"}); err != nil {
|
||||||
|
t.Fatalf("failed to save managed integration config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||||
|
t.Fatal("selector should not be called when model override is provided")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||||
|
Name: "stubmanaged",
|
||||||
|
ModelOverride: "gemma4",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
|
||||||
|
t.Fatalf("expected Configure to run when saved differs from target: %s", diff)
|
||||||
|
}
|
||||||
|
if runner.refreshCalls != 1 {
|
||||||
|
t.Fatalf("expected runtime refresh once after configure, got %d", runner.refreshCalls)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "gemma4" {
|
||||||
|
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationStopsWhenRuntimeRefreshFails(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, true)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{
|
||||||
|
refreshErr: fmt.Errorf("boom"),
|
||||||
|
}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||||
|
Name: "stubmanaged",
|
||||||
|
ModelOverride: "gemma4",
|
||||||
|
})
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "boom") {
|
||||||
|
t.Fatalf("expected runtime refresh error, got %v", err)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "" {
|
||||||
|
t.Fatalf("expected final launch to stop on runtime refresh failure, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
if runner.refreshCalls != 1 {
|
||||||
|
t.Fatalf("expected one runtime refresh attempt, got %d", runner.refreshCalls)
|
||||||
|
}
|
||||||
|
if runner.onboardCalls != 0 {
|
||||||
|
t.Fatalf("expected onboarding to stop after refresh failure, got %d", runner.onboardCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationHeadlessNeedsInteractiveOnboarding(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherManagedRunner{
|
||||||
|
paths: nil,
|
||||||
|
}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||||
|
Name: "stubmanaged",
|
||||||
|
ModelOverride: "gemma4",
|
||||||
|
Policy: &LaunchPolicy{Confirm: LaunchConfirmAutoApprove, MissingModel: LaunchMissingModelAutoPull},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected headless onboarding requirement to fail")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "interactive gateway setup") {
|
||||||
|
t.Fatalf("expected interactive onboarding guidance, got %v", err)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "" {
|
||||||
|
t.Fatalf("expected no final launch when onboarding is still required, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
if runner.onboardCalls != 0 {
|
||||||
|
t.Fatalf("expected no onboarding attempts in headless mode, got %d", runner.onboardCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLaunchIntegration_ManagedSingleIntegrationHeadlessAllowsNonInteractiveOnboarding(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
withInteractiveSession(t, false)
|
||||||
|
withLauncherHooks(t)
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/show":
|
||||||
|
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||||
|
|
||||||
|
runner := &launcherHeadlessManagedRunner{}
|
||||||
|
withIntegrationOverride(t, "stubmanaged", runner)
|
||||||
|
|
||||||
|
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||||
|
Name: "stubmanaged",
|
||||||
|
ModelOverride: "gemma4",
|
||||||
|
Policy: &LaunchPolicy{Confirm: LaunchConfirmAutoApprove, MissingModel: LaunchMissingModelAutoPull},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected non-interactive onboarding to succeed headlessly, got %v", err)
|
||||||
|
}
|
||||||
|
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
|
||||||
|
t.Fatalf("configured models mismatch: %s", diff)
|
||||||
|
}
|
||||||
|
if runner.onboardCalls != 1 {
|
||||||
|
t.Fatalf("expected onboarding to run once, got %d", runner.onboardCalls)
|
||||||
|
}
|
||||||
|
if runner.ranModel != "gemma4" {
|
||||||
|
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildLauncherState_InstalledAndCloudDisabled(t *testing.T) {
|
func TestBuildLauncherState_InstalledAndCloudDisabled(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
setLaunchTestHome(t, tmpDir)
|
setLaunchTestHome(t, tmpDir)
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ func pullMissingModel(ctx context.Context, client *api.Client, model string) err
|
|||||||
|
|
||||||
// prepareEditorIntegration persists models and applies editor-managed config files.
|
// prepareEditorIntegration persists models and applies editor-managed config files.
|
||||||
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
|
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
|
||||||
if ok, err := confirmEditorEdit(runner, editor); err != nil {
|
if ok, err := confirmConfigEdit(runner, editor.Paths()); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if !ok {
|
} else if !ok {
|
||||||
return errCancelled
|
return errCancelled
|
||||||
@@ -244,8 +244,22 @@ func prepareEditorIntegration(name string, runner Runner, editor Editor, models
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) {
|
func prepareManagedSingleIntegration(name string, runner Runner, managed ManagedSingleModel, model string) error {
|
||||||
paths := editor.Paths()
|
if ok, err := confirmConfigEdit(runner, managed.Paths()); err != nil {
|
||||||
|
return err
|
||||||
|
} else if !ok {
|
||||||
|
return errCancelled
|
||||||
|
}
|
||||||
|
if err := managed.Configure(model); err != nil {
|
||||||
|
return fmt.Errorf("setup failed: %w", err)
|
||||||
|
}
|
||||||
|
if err := config.SaveIntegration(name, []string{model}); err != nil {
|
||||||
|
return fmt.Errorf("failed to save: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func confirmConfigEdit(runner Runner, paths []string) (bool, error) {
|
||||||
if len(paths) == 0 {
|
if len(paths) == 0 {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
@@ -345,8 +359,6 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
|||||||
recRank[rec.Name] = i + 1
|
recRank[rec.Name] = i + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
onlyLocal := hasLocalModel && !hasCloudModel
|
|
||||||
|
|
||||||
if hasLocalModel || hasCloudModel {
|
if hasLocalModel || hasCloudModel {
|
||||||
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
slices.SortStableFunc(items, func(a, b ModelItem) int {
|
||||||
ac, bc := checked[a.Name], checked[b.Name]
|
ac, bc := checked[a.Name], checked[b.Name]
|
||||||
@@ -368,12 +380,6 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
|
|||||||
}
|
}
|
||||||
if aRec && bRec {
|
if aRec && bRec {
|
||||||
if aCloud != bCloud {
|
if aCloud != bCloud {
|
||||||
if onlyLocal {
|
|
||||||
if aCloud {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
if aCloud {
|
if aCloud {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ type IntegrationInfo struct {
|
|||||||
Description string
|
Description string
|
||||||
}
|
}
|
||||||
|
|
||||||
var launcherIntegrationOrder = []string{"opencode", "droid", "pi"}
|
var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi"}
|
||||||
|
|
||||||
var integrationSpecs = []*IntegrationSpec{
|
var integrationSpecs = []*IntegrationSpec{
|
||||||
{
|
{
|
||||||
@@ -74,6 +74,36 @@ var integrationSpecs = []*IntegrationSpec{
|
|||||||
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "kimi",
|
||||||
|
Runner: &Kimi{},
|
||||||
|
Description: "Moonshot's coding agent for terminal and IDEs",
|
||||||
|
Hidden: true,
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := exec.LookPath("kimi")
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
EnsureInstalled: func() error {
|
||||||
|
_, err := ensureKimiInstalled()
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
URL: "https://moonshotai.github.io/kimi-cli/en/guides/getting-started.html",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "copilot",
|
||||||
|
Runner: &Copilot{},
|
||||||
|
Aliases: []string{"copilot-cli"},
|
||||||
|
Description: "GitHub's AI coding agent for the terminal",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
_, err := (&Copilot{}).findPath()
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
URL: "https://github.com/features/copilot/cli/",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "droid",
|
Name: "droid",
|
||||||
Runner: &Droid{},
|
Runner: &Droid{},
|
||||||
@@ -136,6 +166,20 @@ var integrationSpecs = []*IntegrationSpec{
|
|||||||
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
|
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "hermes",
|
||||||
|
Runner: &Hermes{},
|
||||||
|
Description: "Self-improving AI agent built by Nous Research",
|
||||||
|
Install: IntegrationInstallSpec{
|
||||||
|
CheckInstalled: func() bool {
|
||||||
|
return (&Hermes{}).installed()
|
||||||
|
},
|
||||||
|
EnsureInstalled: func() error {
|
||||||
|
return (&Hermes{}).ensureInstalled()
|
||||||
|
},
|
||||||
|
URL: "https://hermes-agent.nousresearch.com/docs/getting-started/installation/",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "vscode",
|
Name: "vscode",
|
||||||
Runner: &VSCode{},
|
Runner: &VSCode{},
|
||||||
@@ -255,10 +299,10 @@ func ListVisibleIntegrationSpecs() []IntegrationSpec {
|
|||||||
return aRank - bRank
|
return aRank - bRank
|
||||||
}
|
}
|
||||||
if aRank > 0 {
|
if aRank > 0 {
|
||||||
return 1
|
return -1
|
||||||
}
|
}
|
||||||
if bRank > 0 {
|
if bRank > 0 {
|
||||||
return -1
|
return 1
|
||||||
}
|
}
|
||||||
return strings.Compare(a.Name, b.Name)
|
return strings.Compare(a.Name, b.Name)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -45,6 +45,14 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
|||||||
return filepath.Join(home, ".pi", "agent", "models.json")
|
return filepath.Join(home, ".pi", "agent", "models.json")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "kimi",
|
||||||
|
binary: "kimi",
|
||||||
|
runner: &Kimi{},
|
||||||
|
checkPath: func(home string) string {
|
||||||
|
return filepath.Join(home, ".kimi", "config.toml")
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -57,6 +65,10 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
|||||||
if tt.name == "pi" {
|
if tt.name == "pi" {
|
||||||
writeFakeBinary(t, binDir, "npm")
|
writeFakeBinary(t, binDir, "npm")
|
||||||
}
|
}
|
||||||
|
if tt.name == "kimi" {
|
||||||
|
writeFakeBinary(t, binDir, "curl")
|
||||||
|
writeFakeBinary(t, binDir, "bash")
|
||||||
|
}
|
||||||
t.Setenv("PATH", binDir)
|
t.Setenv("PATH", binDir)
|
||||||
|
|
||||||
configPath := tt.checkPath(home)
|
configPath := tt.checkPath(home)
|
||||||
|
|||||||
@@ -45,21 +45,12 @@ type menuItem struct {
|
|||||||
isOthers bool
|
isOthers bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var mainMenuItems = []menuItem{
|
const pinnedIntegrationCount = 3
|
||||||
{
|
|
||||||
title: "Chat with a model",
|
var runModelMenuItem = menuItem{
|
||||||
description: "Start an interactive chat with a model",
|
title: "Chat with a model",
|
||||||
isRunModel: true,
|
description: "Start an interactive chat with a model",
|
||||||
},
|
isRunModel: true,
|
||||||
{
|
|
||||||
integration: "openclaw",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
integration: "claude",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
integration: "opencode",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var othersMenuItem = menuItem{
|
var othersMenuItem = menuItem{
|
||||||
@@ -102,20 +93,14 @@ func shouldExpandOthers(state *launch.LauncherState) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
|
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
|
||||||
items := make([]menuItem, 0, len(mainMenuItems)+1)
|
items := []menuItem{runModelMenuItem}
|
||||||
for _, item := range mainMenuItems {
|
items = append(items, pinnedIntegrationItems(state)...)
|
||||||
if item.integration == "" {
|
|
||||||
items = append(items, item)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if integrationState, ok := state.Integrations[item.integration]; ok {
|
|
||||||
items = append(items, integrationMenuItem(integrationState))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if showOthers {
|
otherItems := otherIntegrationItems(state)
|
||||||
items = append(items, otherIntegrationItems(state)...)
|
switch {
|
||||||
} else {
|
case showOthers:
|
||||||
|
items = append(items, otherItems...)
|
||||||
|
case len(otherItems) > 0:
|
||||||
items = append(items, othersMenuItem)
|
items = append(items, othersMenuItem)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,17 +120,28 @@ func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||||
pinned := map[string]bool{
|
ordered := orderedIntegrationItems(state)
|
||||||
"openclaw": true,
|
if len(ordered) <= pinnedIntegrationCount {
|
||||||
"claude": true,
|
return nil
|
||||||
"opencode": true,
|
}
|
||||||
|
return ordered[pinnedIntegrationCount:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func pinnedIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||||
|
ordered := orderedIntegrationItems(state)
|
||||||
|
if len(ordered) <= pinnedIntegrationCount {
|
||||||
|
return ordered
|
||||||
|
}
|
||||||
|
return ordered[:pinnedIntegrationCount]
|
||||||
|
}
|
||||||
|
|
||||||
|
func orderedIntegrationItems(state *launch.LauncherState) []menuItem {
|
||||||
|
if state == nil {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var items []menuItem
|
items := make([]menuItem, 0, len(state.Integrations))
|
||||||
for _, info := range launch.ListIntegrationInfos() {
|
for _, info := range launch.ListIntegrationInfos() {
|
||||||
if pinned[info.Name] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
integrationState, ok := state.Integrations[info.Name]
|
integrationState, ok := state.Integrations[info.Name]
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
@@ -155,6 +151,10 @@ func otherIntegrationItems(state *launch.LauncherState) []menuItem {
|
|||||||
return items
|
return items
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func primaryMenuItemCount(state *launch.LauncherState) int {
|
||||||
|
return 1 + len(pinnedIntegrationItems(state))
|
||||||
|
}
|
||||||
|
|
||||||
func initialCursor(state *launch.LauncherState, items []menuItem) int {
|
func initialCursor(state *launch.LauncherState, items []menuItem) int {
|
||||||
if state == nil || state.LastSelection == "" {
|
if state == nil || state.LastSelection == "" {
|
||||||
return 0
|
return 0
|
||||||
@@ -190,7 +190,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
if m.cursor > 0 {
|
if m.cursor > 0 {
|
||||||
m.cursor--
|
m.cursor--
|
||||||
}
|
}
|
||||||
if m.showOthers && m.cursor < len(mainMenuItems) {
|
if m.showOthers && m.cursor < primaryMenuItemCount(m.state) {
|
||||||
m.showOthers = false
|
m.showOthers = false
|
||||||
m.items = buildMenuItems(m.state, false)
|
m.items = buildMenuItems(m.state, false)
|
||||||
m.cursor = min(m.cursor, len(m.items)-1)
|
m.cursor = min(m.cursor, len(m.items)-1)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ollama/ollama/cmd/launch"
|
"github.com/ollama/ollama/cmd/launch"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,6 +44,13 @@ func launcherTestState() *launch.LauncherState {
|
|||||||
Selectable: true,
|
Selectable: true,
|
||||||
Changeable: true,
|
Changeable: true,
|
||||||
},
|
},
|
||||||
|
"hermes": {
|
||||||
|
Name: "hermes",
|
||||||
|
DisplayName: "Hermes Agent",
|
||||||
|
Description: "Self-improving AI agent built by Nous Research",
|
||||||
|
Selectable: true,
|
||||||
|
Changeable: true,
|
||||||
|
},
|
||||||
"droid": {
|
"droid": {
|
||||||
Name: "droid",
|
Name: "droid",
|
||||||
DisplayName: "Droid",
|
DisplayName: "Droid",
|
||||||
@@ -70,8 +78,28 @@ func findMenuCursorByIntegration(items []menuItem, name string) int {
|
|||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func integrationSequence(items []menuItem) []string {
|
||||||
|
sequence := make([]string, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
switch {
|
||||||
|
case item.isRunModel:
|
||||||
|
sequence = append(sequence, "run")
|
||||||
|
case item.isOthers:
|
||||||
|
sequence = append(sequence, "more")
|
||||||
|
case item.integration != "":
|
||||||
|
sequence = append(sequence, item.integration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sequence
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareStrings(got, want []string) string {
|
||||||
|
return cmp.Diff(want, got)
|
||||||
|
}
|
||||||
|
|
||||||
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
||||||
view := newModel(launcherTestState()).View()
|
menu := newModel(launcherTestState())
|
||||||
|
view := menu.View()
|
||||||
for _, want := range []string{"Chat with a model", "Launch OpenClaw", "Launch Claude Code", "Launch OpenCode", "More..."} {
|
for _, want := range []string{"Chat with a model", "Launch OpenClaw", "Launch Claude Code", "Launch OpenCode", "More..."} {
|
||||||
if !strings.Contains(view, want) {
|
if !strings.Contains(view, want) {
|
||||||
t.Fatalf("expected menu view to contain %q\n%s", want, view)
|
t.Fatalf("expected menu view to contain %q\n%s", want, view)
|
||||||
@@ -80,23 +108,31 @@ func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
|
|||||||
if strings.Contains(view, "Launch Codex") {
|
if strings.Contains(view, "Launch Codex") {
|
||||||
t.Fatalf("expected Codex to be under More, not pinned\n%s", view)
|
t.Fatalf("expected Codex to be under More, not pinned\n%s", view)
|
||||||
}
|
}
|
||||||
|
wantOrder := []string{"run", "openclaw", "claude", "opencode", "more"}
|
||||||
|
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
|
||||||
|
t.Fatalf("unexpected pinned order: %s", diff)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
|
func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
|
||||||
state := launcherTestState()
|
state := launcherTestState()
|
||||||
state.LastSelection = "pi"
|
state.LastSelection = "codex"
|
||||||
|
|
||||||
menu := newModel(state)
|
menu := newModel(state)
|
||||||
if !menu.showOthers {
|
if !menu.showOthers {
|
||||||
t.Fatal("expected others section to expand when last selection is in the overflow list")
|
t.Fatal("expected others section to expand when last selection is in the overflow list")
|
||||||
}
|
}
|
||||||
view := menu.View()
|
view := menu.View()
|
||||||
if !strings.Contains(view, "Launch Pi") {
|
if !strings.Contains(view, "Launch Codex") {
|
||||||
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
|
t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
|
||||||
}
|
}
|
||||||
if strings.Contains(view, "More...") {
|
if strings.Contains(view, "More...") {
|
||||||
t.Fatalf("expected expanded view to replace More... item\n%s", view)
|
t.Fatalf("expected expanded view to replace More... item\n%s", view)
|
||||||
}
|
}
|
||||||
|
wantOrder := []string{"run", "openclaw", "claude", "opencode", "hermes", "codex", "droid", "pi"}
|
||||||
|
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
|
||||||
|
t.Fatalf("unexpected expanded order: %s", diff)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
|
func TestMenuEnterOnRunSelectsRun(t *testing.T) {
|
||||||
|
|||||||
@@ -120,6 +120,7 @@
|
|||||||
"pages": [
|
"pages": [
|
||||||
"/integrations/claude-code",
|
"/integrations/claude-code",
|
||||||
"/integrations/codex",
|
"/integrations/codex",
|
||||||
|
"/integrations/copilot-cli",
|
||||||
"/integrations/opencode",
|
"/integrations/opencode",
|
||||||
"/integrations/droid",
|
"/integrations/droid",
|
||||||
"/integrations/goose",
|
"/integrations/goose",
|
||||||
|
|||||||
BIN
docs/images/hermes.png
Normal file
BIN
docs/images/hermes.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.4 MiB |
93
docs/integrations/copilot-cli.mdx
Normal file
93
docs/integrations/copilot-cli.mdx
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
---
|
||||||
|
title: Copilot CLI
|
||||||
|
---
|
||||||
|
|
||||||
|
GitHub Copilot CLI is GitHub's AI coding agent for the terminal. It can understand your codebase, make edits, run commands, and help you build software faster.
|
||||||
|
|
||||||
|
Open models can be used with Copilot CLI through Ollama, enabling you to use models such as `qwen3.5`, `glm-5.1:cloud`, `kimi-k2.5:cloud`.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
Install [Copilot CLI](https://github.com/features/copilot/cli/):
|
||||||
|
|
||||||
|
<CodeGroup>
|
||||||
|
|
||||||
|
```shell macOS / Linux (Homebrew)
|
||||||
|
brew install copilot-cli
|
||||||
|
```
|
||||||
|
|
||||||
|
```shell npm (all platforms)
|
||||||
|
npm install -g @github/copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
```shell macOS / Linux (script)
|
||||||
|
curl -fsSL https://gh.io/copilot-install | bash
|
||||||
|
```
|
||||||
|
|
||||||
|
```powershell Windows (WinGet)
|
||||||
|
winget install GitHub.Copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
</CodeGroup>
|
||||||
|
|
||||||
|
## Usage with Ollama
|
||||||
|
|
||||||
|
### Quick setup
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run directly with a model
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch copilot --model kimi-k2.5:cloud
|
||||||
|
```
|
||||||
|
|
||||||
|
## Recommended Models
|
||||||
|
|
||||||
|
- `kimi-k2.5:cloud`
|
||||||
|
- `glm-5:cloud`
|
||||||
|
- `minimax-m2.7:cloud`
|
||||||
|
- `qwen3.5:cloud`
|
||||||
|
- `glm-4.7-flash`
|
||||||
|
- `qwen3.5`
|
||||||
|
|
||||||
|
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
|
||||||
|
|
||||||
|
## Non-interactive (headless) mode
|
||||||
|
|
||||||
|
Run Copilot CLI without interaction for use in Docker, CI/CD, or scripts:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama launch copilot --model kimi-k2.5:cloud --yes -- -p "how does this repository work?"
|
||||||
|
```
|
||||||
|
|
||||||
|
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified. Arguments after `--` are passed directly to Copilot CLI.
|
||||||
|
|
||||||
|
## Manual setup
|
||||||
|
|
||||||
|
Copilot CLI connects to Ollama using the OpenAI-compatible API via environment variables.
|
||||||
|
|
||||||
|
1. Set the environment variables:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
export COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1
|
||||||
|
export COPILOT_PROVIDER_API_KEY=
|
||||||
|
export COPILOT_PROVIDER_WIRE_API=responses
|
||||||
|
export COPILOT_MODEL=qwen3.5
|
||||||
|
```
|
||||||
|
|
||||||
|
1. Run Copilot CLI:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
Or run with environment variables inline:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1 COPILOT_PROVIDER_API_KEY= COPILOT_PROVIDER_WIRE_API=responses COPILOT_MODEL=glm-5:cloud copilot
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note:** Copilot requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
|
||||||
@@ -2,29 +2,66 @@
|
|||||||
title: Hermes Agent
|
title: Hermes Agent
|
||||||
---
|
---
|
||||||
|
|
||||||
Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation, cross-session memory, and connects messaging platforms (Telegram, Discord, Slack, WhatsApp, Signal, Email) to models through a unified gateway.
|
Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation, cross-session memory, and 70+ skills that it ships with by default.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
## Quick start
|
## Quick start
|
||||||
|
|
||||||
### Pull a model
|
|
||||||
|
|
||||||
Before running the setup wizard, make sure you have a model available. Hermes will auto-detect models downloaded through Ollama.
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ollama pull kimi-k2.5:cloud
|
ollama launch hermes
|
||||||
```
|
```
|
||||||
|
|
||||||
See [Recommended models](#recommended-models) for more options.
|
Ollama handles everything automatically:
|
||||||
|
|
||||||
### Install
|
1. **Install** — If Hermes isn't installed, Ollama prompts to install it via the Nous Research install script
|
||||||
|
2. **Model** — Pick a model from the selector (local or cloud)
|
||||||
|
3. **Onboarding** — Ollama configures the Ollama provider, points Hermes at `http://127.0.0.1:11434/v1`, and sets your model as the primary
|
||||||
|
4. **Gateway** — Optionally connects a messaging platform (Telegram, Discord, Slack, WhatsApp, Signal, Email) and launches the Hermes chat
|
||||||
|
|
||||||
|
<Note>Hermes on Windows requires WSL2. Install it with `wsl --install` and re-run from inside the WSL shell.</Note>
|
||||||
|
|
||||||
|
## Recommended models
|
||||||
|
|
||||||
|
**Cloud models**:
|
||||||
|
|
||||||
|
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
|
||||||
|
- `glm-5.1:cloud` — Reasoning and code generation
|
||||||
|
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
|
||||||
|
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
|
||||||
|
|
||||||
|
**Local models:**
|
||||||
|
|
||||||
|
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
|
||||||
|
- `qwen3.6` — Reasoning, coding, and visual understanding locally (~24 GB VRAM)
|
||||||
|
|
||||||
|
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
|
||||||
|
|
||||||
|
## Connect messaging apps
|
||||||
|
|
||||||
|
Link Telegram, Discord, Slack, WhatsApp, Signal, or Email to chat with your models from anywhere:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
hermes gateway setup
|
||||||
|
```
|
||||||
|
|
||||||
|
## Reconfigure
|
||||||
|
|
||||||
|
Re-run the full setup wizard at any time:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
hermes setup
|
||||||
|
```
|
||||||
|
|
||||||
|
## Manual setup
|
||||||
|
|
||||||
|
If you'd rather drive Hermes's own wizard instead of `ollama launch hermes`, install it directly:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
|
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
|
||||||
```
|
```
|
||||||
|
|
||||||
### Set up
|
Hermes launches the setup wizard automatically. Choose **Quick setup**:
|
||||||
|
|
||||||
After installation, Hermes launches the setup wizard automatically. Choose **Quick setup**:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
How would you like to set up Hermes?
|
How would you like to set up Hermes?
|
||||||
@@ -80,32 +117,3 @@ Connect a messaging platform? (Telegram, Discord, etc.)
|
|||||||
Launch hermes chat now? [Y/n]: Y
|
Launch hermes chat now? [Y/n]: Y
|
||||||
```
|
```
|
||||||
|
|
||||||
## Recommended models
|
|
||||||
|
|
||||||
**Cloud models**:
|
|
||||||
|
|
||||||
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
|
|
||||||
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
|
|
||||||
- `glm-5.1:cloud` — Reasoning and code generation
|
|
||||||
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
|
|
||||||
|
|
||||||
**Local models:**
|
|
||||||
|
|
||||||
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
|
|
||||||
- `qwen3.5` — Reasoning, coding, and visual understanding locally (~11 GB VRAM)
|
|
||||||
|
|
||||||
More models at [ollama.com/search](https://ollama.com/models).
|
|
||||||
|
|
||||||
## Configure later
|
|
||||||
|
|
||||||
Re-run the setup wizard at any time:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
hermes setup
|
|
||||||
```
|
|
||||||
|
|
||||||
To configure just messaging:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
hermes setup gateway
|
|
||||||
```
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ Coding assistants that can read, modify, and execute code in your projects.
|
|||||||
|
|
||||||
- [Claude Code](/integrations/claude-code)
|
- [Claude Code](/integrations/claude-code)
|
||||||
- [Codex](/integrations/codex)
|
- [Codex](/integrations/codex)
|
||||||
|
- [Copilot CLI](/integrations/copilot-cli)
|
||||||
- [OpenCode](/integrations/opencode)
|
- [OpenCode](/integrations/opencode)
|
||||||
- [Droid](/integrations/droid)
|
- [Droid](/integrations/droid)
|
||||||
- [Goose](/integrations/goose)
|
- [Goose](/integrations/goose)
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -106,5 +106,5 @@ require (
|
|||||||
golang.org/x/term v0.36.0
|
golang.org/x/term v0.36.0
|
||||||
golang.org/x/text v0.30.0
|
golang.org/x/text v0.30.0
|
||||||
google.golang.org/protobuf v1.34.1
|
google.golang.org/protobuf v1.34.1
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -406,10 +406,6 @@ func TestAPIShowModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAPIGenerateLogprobs(t *testing.T) {
|
func TestAPIGenerateLogprobs(t *testing.T) {
|
||||||
if testModel != "" {
|
|
||||||
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
|
|
||||||
t.Skip("logprobs not supported by all runners")
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -523,10 +519,6 @@ func TestAPIGenerateLogprobs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAPIChatLogprobs(t *testing.T) {
|
func TestAPIChatLogprobs(t *testing.T) {
|
||||||
if testModel != "" {
|
|
||||||
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
|
|
||||||
t.Skip("logprobs not supported by all runners")
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Layer struct {
|
type Layer struct {
|
||||||
@@ -60,6 +61,9 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
|||||||
return Layer{}, err
|
return Layer{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := touchLayer(blob); err != nil {
|
||||||
|
return Layer{}, err
|
||||||
|
}
|
||||||
|
|
||||||
return Layer{
|
return Layer{
|
||||||
MediaType: mediatype,
|
MediaType: mediatype,
|
||||||
@@ -83,6 +87,9 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return Layer{}, err
|
return Layer{}, err
|
||||||
}
|
}
|
||||||
|
if err := touchLayer(blob); err != nil {
|
||||||
|
return Layer{}, err
|
||||||
|
}
|
||||||
|
|
||||||
return Layer{
|
return Layer{
|
||||||
MediaType: mediatype,
|
MediaType: mediatype,
|
||||||
@@ -93,6 +100,11 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func touchLayer(path string) error {
|
||||||
|
now := time.Now()
|
||||||
|
return os.Chtimes(path, now, now)
|
||||||
|
}
|
||||||
|
|
||||||
func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
||||||
if l.Digest == "" {
|
if l.Digest == "" {
|
||||||
return nil, errors.New("opening layer with empty digest")
|
return nil, errors.New("opening layer with empty digest")
|
||||||
|
|||||||
@@ -1,18 +1,23 @@
|
|||||||
package manifest
|
package manifest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var blobFilenamePattern = regexp.MustCompile(`^sha256-[0-9a-fA-F]{64}$`)
|
||||||
|
|
||||||
type Manifest struct {
|
type Manifest struct {
|
||||||
SchemaVersion int `json:"schemaVersion"`
|
SchemaVersion int `json:"schemaVersion"`
|
||||||
MediaType string `json:"mediaType"`
|
MediaType string `json:"mediaType"`
|
||||||
@@ -22,6 +27,7 @@ type Manifest struct {
|
|||||||
filepath string
|
filepath string
|
||||||
fi os.FileInfo
|
fi os.FileInfo
|
||||||
digest string
|
digest string
|
||||||
|
name model.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manifest) Size() (size int64) {
|
func (m *Manifest) Size() (size int64) {
|
||||||
@@ -36,6 +42,14 @@ func (m *Manifest) Digest() string {
|
|||||||
return m.digest
|
return m.digest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manifest) BlobDigest() string {
|
||||||
|
if m.digest == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return "sha256:" + m.digest
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manifest) FileInfo() os.FileInfo {
|
func (m *Manifest) FileInfo() os.FileInfo {
|
||||||
return m.fi
|
return m.fi
|
||||||
}
|
}
|
||||||
@@ -59,16 +73,7 @@ func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manifest) Remove() error {
|
func (m *Manifest) Remove() error {
|
||||||
if err := os.Remove(m.filepath); err != nil {
|
return removeNamedManifestPaths(m.name)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
manifests, err := Path()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return PruneDirectory(manifests)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manifest) RemoveLayers() error {
|
func (m *Manifest) RemoveLayers() error {
|
||||||
@@ -80,6 +85,9 @@ func (m *Manifest) RemoveLayers() error {
|
|||||||
// Build set of digests still in use by other manifests
|
// Build set of digests still in use by other manifests
|
||||||
inUse := make(map[string]struct{})
|
inUse := make(map[string]struct{})
|
||||||
for _, other := range ms {
|
for _, other := range ms {
|
||||||
|
if other.BlobDigest() != "" {
|
||||||
|
inUse[other.BlobDigest()] = struct{}{}
|
||||||
|
}
|
||||||
for _, layer := range append(other.Layers, other.Config) {
|
for _, layer := range append(other.Layers, other.Config) {
|
||||||
if layer.Digest != "" {
|
if layer.Digest != "" {
|
||||||
inUse[layer.Digest] = struct{}{}
|
inUse[layer.Digest] = struct{}{}
|
||||||
@@ -87,20 +95,27 @@ func (m *Manifest) RemoveLayers() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove layers not used by any other manifest
|
digests := make([]string, 0, len(m.Layers)+2)
|
||||||
for _, layer := range append(m.Layers, m.Config) {
|
digests = append(digests, m.BlobDigest())
|
||||||
if layer.Digest == "" {
|
for _, layer := range m.Layers {
|
||||||
|
digests = append(digests, layer.Digest)
|
||||||
|
}
|
||||||
|
digests = append(digests, m.Config.Digest)
|
||||||
|
|
||||||
|
// Remove manifest and layer blobs not used by any other manifest
|
||||||
|
for _, digest := range digests {
|
||||||
|
if digest == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if _, used := inUse[layer.Digest]; used {
|
if _, used := inUse[digest]; used {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
blob, err := BlobsPath(layer.Digest)
|
blob, err := BlobsPath(digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := os.Remove(blob); os.IsNotExist(err) {
|
if err := os.Remove(blob); os.IsNotExist(err) {
|
||||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
slog.Debug("blob does not exist", "digest", digest)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -114,15 +129,36 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
|||||||
return nil, model.Unqualified(n)
|
return nil, model.Unqualified(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
manifests, err := Path()
|
p, root, err := resolveManifestPath(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
p := filepath.Join(manifests, n.Filepath())
|
return parseManifestFile(normalizeLogicalName(n), p, root)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadManifestData(n model.Name) ([]byte, error) {
|
||||||
|
if !n.IsFullyQualified() {
|
||||||
|
return nil, model.Unqualified(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, root, err := resolveManifestPath(n)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f, _, err := OpenVerifiedManifest(p, root)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
return io.ReadAll(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseManifestFile(name model.Name, path, root string) (*Manifest, error) {
|
||||||
var m Manifest
|
var m Manifest
|
||||||
f, err := os.Open(p)
|
f, digest, err := OpenVerifiedManifest(path, root)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -133,35 +169,19 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sha256sum := sha256.New()
|
if err := json.NewDecoder(f).Decode(&m); err != nil {
|
||||||
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.filepath = p
|
m.filepath = path
|
||||||
m.fi = fi
|
m.fi = fi
|
||||||
m.digest = hex.EncodeToString(sha256sum.Sum(nil))
|
m.digest = digest
|
||||||
|
m.name = name
|
||||||
|
|
||||||
return &m, nil
|
return &m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
||||||
manifests, err := Path()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
p := filepath.Join(manifests, name.Filepath())
|
|
||||||
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
f, err := os.Create(p)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
m := Manifest{
|
m := Manifest{
|
||||||
SchemaVersion: 2,
|
SchemaVersion: 2,
|
||||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||||
@@ -169,33 +189,371 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
|||||||
Layers: layers,
|
Layers: layers,
|
||||||
}
|
}
|
||||||
|
|
||||||
return json.NewEncoder(f).Encode(m)
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(m); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return WriteManifestData(name, b.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
// WriteManifestData stores raw manifest bytes as a content-addressed blob and
|
||||||
|
// updates the v2 named manifest path to reference that blob. Any legacy named
|
||||||
|
// manifest for the same model is removed after the v2 write succeeds.
|
||||||
|
func WriteManifestData(name model.Name, data []byte) error {
|
||||||
|
if !name.IsFullyQualified() {
|
||||||
|
return model.Unqualified(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
digest, err := writeManifestBlob(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := LinkManifest(name, digest); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return removeLegacyManifestPaths(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkManifest updates the v2 named manifest path to reference an existing
|
||||||
|
// manifest blob. It prefers symlinks, then hardlinks, then a byte-for-byte copy
|
||||||
|
// for filesystems that do not support links.
|
||||||
|
func LinkManifest(name model.Name, digest string) error {
|
||||||
|
if !name.IsFullyQualified() {
|
||||||
|
return model.Unqualified(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
manifestPath, err := V2PathForName(name)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
blobPath, err := BlobsPath(digest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(blobPath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := checkBlobDigest(blobPath, digest); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.Remove(manifestPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if rel, err := filepath.Rel(filepath.Dir(manifestPath), blobPath); err == nil {
|
||||||
|
if err := os.Symlink(rel, manifestPath); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Link(blobPath, manifestPath); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return copyManifestFile(blobPath, manifestPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeManifestBlob(data []byte) (string, error) {
|
||||||
|
sum := sha256.Sum256(data)
|
||||||
|
digest := fmt.Sprintf("sha256:%x", sum)
|
||||||
|
|
||||||
|
blobPath, err := BlobsPath(digest)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if existing, err := os.ReadFile(blobPath); err == nil && bytes.Equal(existing, data) {
|
||||||
|
return digest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
blobs, err := BlobsPath("")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
temp, err := os.CreateTemp(blobs, "sha256-")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
tempName := temp.Name()
|
||||||
|
defer os.Remove(tempName)
|
||||||
|
|
||||||
|
if _, err := temp.Write(data); err != nil {
|
||||||
|
temp.Close()
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := temp.Close(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := os.Chmod(tempName, 0o644); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := os.Rename(tempName, blobPath); err != nil {
|
||||||
|
if err := os.Remove(blobPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := os.Rename(tempName, blobPath); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return digest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyManifestFile(src, dst string) error {
|
||||||
|
in, err := os.Open(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer in.Close()
|
||||||
|
|
||||||
|
temp, err := os.CreateTemp(filepath.Dir(dst), ".manifest-*")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tempName := temp.Name()
|
||||||
|
defer os.Remove(tempName)
|
||||||
|
|
||||||
|
if _, err := io.Copy(temp, in); err != nil {
|
||||||
|
temp.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := temp.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.Chmod(tempName, 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.Rename(tempName, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenVerifiedManifest opens a named manifest path rooted under root. Symlinks must resolve to a
|
||||||
|
// blob whose basename is sha256-<hex> and whose bytes hash to that digest.
|
||||||
|
// Regular-file manifests are treated as legacy/copy fallback manifests and are
|
||||||
|
// opened without mutating the local store.
|
||||||
|
func OpenVerifiedManifest(path, root string) (*os.File, string, error) {
|
||||||
|
resolvedRoot, err := filepath.EvalSymlinks(root)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := os.Lstat(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
target, err := evalAbs(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Mode()&os.ModeSymlink != 0 {
|
||||||
|
base := filepath.Base(target)
|
||||||
|
if !blobFilenamePattern.MatchString(base) {
|
||||||
|
return nil, "", fmt.Errorf("manifest symlink target %q is not a sha256 blob", target)
|
||||||
|
}
|
||||||
|
|
||||||
|
digest := strings.ToLower(strings.TrimPrefix(base, "sha256-"))
|
||||||
|
blobPath, err := BlobsPath("sha256:" + digest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
if !sameFile(target, blobPath) {
|
||||||
|
return nil, "", fmt.Errorf("manifest symlink target %q does not match blob %q", target, blobPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
if err := checkBlobDigestReader(f, "sha256:"+digest); err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
if _, err := f.Seek(0, io.SeekStart); err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return f, digest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pathWithin(target, resolvedRoot) {
|
||||||
|
return nil, "", fmt.Errorf("manifest path %q resolves outside manifest directory", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
h := sha256.New()
|
||||||
|
if _, err := io.Copy(h, f); err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
if _, err := f.Seek(0, io.SeekStart); err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
digest := fmt.Sprintf("%x", h.Sum(nil))
|
||||||
|
|
||||||
|
return f, digest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MigrateManifestLinks moves legacy named manifests into manifests-v2. This is currently unwired but
|
||||||
|
// will be added in the future.
|
||||||
|
func MigrateManifestLinks() (int, error) {
|
||||||
manifests, err := Path()
|
manifests, err := Path()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(mxyng): use something less brittle
|
// TODO(mxyng): use something less brittle
|
||||||
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
|
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ms := make(map[model.Name]*Manifest)
|
var migrated int
|
||||||
for _, match := range matches {
|
for _, match := range matches {
|
||||||
fi, err := os.Stat(match)
|
fi, err := os.Stat(match)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return migrated, err
|
||||||
|
}
|
||||||
|
if fi.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rel, err := filepath.Rel(manifests, match)
|
||||||
|
if err != nil {
|
||||||
|
return migrated, fmt.Errorf("%s %w", match, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n := model.ParseNameFromFilepath(rel)
|
||||||
|
if !n.IsFullyQualified() {
|
||||||
|
slog.Warn("bad manifest name", "path", rel)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := readManifestPath(match, manifests)
|
||||||
|
if err != nil {
|
||||||
|
return migrated, err
|
||||||
|
}
|
||||||
|
if err := WriteManifestData(normalizeLogicalName(n), data); err != nil {
|
||||||
|
return migrated, err
|
||||||
|
}
|
||||||
|
migrated++
|
||||||
|
}
|
||||||
|
|
||||||
|
return migrated, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readManifestPath(path, root string) ([]byte, error) {
|
||||||
|
f, _, err := OpenVerifiedManifest(path, root)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
return io.ReadAll(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pathWithin(path, root string) bool {
|
||||||
|
rel, err := filepath.Rel(root, path)
|
||||||
|
return err == nil && rel != "." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) && rel != ".."
|
||||||
|
}
|
||||||
|
|
||||||
|
func evalAbs(path string) (string, error) {
|
||||||
|
abs, err := filepath.Abs(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return filepath.EvalSymlinks(abs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sameFile(a, b string) bool {
|
||||||
|
ai, err := os.Stat(a)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
bi, err := os.Stat(b)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return os.SameFile(ai, bi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkBlobDigest(path, digest string) error {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
return checkBlobDigestReader(f, digest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkBlobDigestReader(r io.Reader, digest string) error {
|
||||||
|
h := sha256.New()
|
||||||
|
if _, err := io.Copy(h, r); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
got := fmt.Sprintf("sha256:%x", h.Sum(nil))
|
||||||
|
if got != strings.ToLower(strings.Replace(digest, "-", ":", 1)) {
|
||||||
|
return errors.New("digest mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
||||||
|
ms := make(map[model.Name]*Manifest)
|
||||||
|
|
||||||
|
manifestsV2, err := V2Path()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := collectManifests(ms, manifestsV2, continueOnError); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
manifests, err := Path()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := collectManifests(ms, manifests, continueOnError); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ms, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectManifests(ms map[model.Name]*Manifest, root string, continueOnError bool) error {
|
||||||
|
// TODO(mxyng): use something less brittle
|
||||||
|
matches, err := filepath.Glob(filepath.Join(root, "*", "*", "*", "*"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, match := range matches {
|
||||||
|
fi, err := os.Lstat(match)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !fi.IsDir() {
|
if !fi.IsDir() {
|
||||||
rel, err := filepath.Rel(manifests, match)
|
rel, err := filepath.Rel(root, match)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !continueOnError {
|
if !continueOnError {
|
||||||
return nil, fmt.Errorf("%s %w", match, err)
|
return fmt.Errorf("%s %w", match, err)
|
||||||
}
|
}
|
||||||
slog.Warn("bad filepath", "path", match, "error", err)
|
slog.Warn("bad filepath", "path", match, "error", err)
|
||||||
continue
|
continue
|
||||||
@@ -204,16 +562,21 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
|||||||
n := model.ParseNameFromFilepath(rel)
|
n := model.ParseNameFromFilepath(rel)
|
||||||
if !n.IsValid() {
|
if !n.IsValid() {
|
||||||
if !continueOnError {
|
if !continueOnError {
|
||||||
return nil, fmt.Errorf("%s %w", rel, err)
|
return fmt.Errorf("invalid manifest name: %s", rel)
|
||||||
}
|
}
|
||||||
slog.Warn("bad manifest name", "path", rel)
|
slog.Warn("bad manifest name", "path", rel)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := ParseNamedManifest(n)
|
n = normalizeLogicalName(n)
|
||||||
|
if _, ok := ms[n]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := parseManifestFile(n, match, root)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !continueOnError {
|
if !continueOnError {
|
||||||
return nil, fmt.Errorf("%s %w", n, err)
|
return fmt.Errorf("%s %w", n, err)
|
||||||
}
|
}
|
||||||
slog.Warn("bad manifest", "name", n, "error", err)
|
slog.Warn("bad manifest", "name", n, "error", err)
|
||||||
continue
|
continue
|
||||||
@@ -223,5 +586,5 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ms, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,23 @@
|
|||||||
package manifest
|
package manifest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createManifest(t *testing.T, path, name string) {
|
func createManifestAtRoot(t *testing.T, path, root, name string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
p := filepath.Join(path, "manifests", name)
|
p := filepath.Join(path, root, name)
|
||||||
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
|
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -29,6 +33,309 @@ func createManifest(t *testing.T, path, name string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createManifest(t *testing.T, path, name string) {
|
||||||
|
t.Helper()
|
||||||
|
createManifestAtRoot(t, path, "manifests", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteManifestStoresManifestAsBlob(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
|
||||||
|
name := model.ParseName("example")
|
||||||
|
config := Layer{
|
||||||
|
MediaType: "application/vnd.docker.container.image.v1+json",
|
||||||
|
Digest: "sha256:" + strings.Repeat("a", 64),
|
||||||
|
Size: 12,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := WriteManifest(name, config, nil); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
manifestPath, err := V2PathForName(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
manifestData, err := os.ReadFile(manifestPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sum := sha256.Sum256(manifestData)
|
||||||
|
digest := fmt.Sprintf("sha256:%x", sum)
|
||||||
|
blobPath, err := BlobsPath(digest)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
blobData, err := os.ReadFile(blobPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(blobData, manifestData) {
|
||||||
|
t.Fatal("manifest path and blob content differ")
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := ParseNamedManifest(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got := m.Digest(); got != fmt.Sprintf("%x", sum) {
|
||||||
|
t.Fatalf("digest = %q, want %x", got, sum)
|
||||||
|
}
|
||||||
|
if got := m.BlobDigest(); got != digest {
|
||||||
|
t.Fatalf("blob digest = %q, want %q", got, digest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseNamedManifestLeavesLegacyManifestInPlace(t *testing.T) {
|
||||||
|
models := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", models)
|
||||||
|
|
||||||
|
name := model.ParseName("example")
|
||||||
|
createManifest(t, models, name.Filepath())
|
||||||
|
|
||||||
|
manifestPath, err := PathForName(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := ParseNamedManifest(name); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fi, err := os.Lstat(manifestPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if fi.Mode()&os.ModeSymlink != 0 {
|
||||||
|
t.Fatal("legacy manifest was converted to a symlink while reading")
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(manifestPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
sum := sha256.Sum256(data)
|
||||||
|
blobPath, err := BlobsPath(fmt.Sprintf("sha256:%x", sum))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(blobPath); !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("legacy manifest read created blob: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateManifestLinks(t *testing.T) {
|
||||||
|
models := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", models)
|
||||||
|
|
||||||
|
name := model.ParseName("example")
|
||||||
|
createManifest(t, models, name.Filepath())
|
||||||
|
|
||||||
|
migrated, err := MigrateManifestLinks()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if migrated != 1 {
|
||||||
|
t.Fatalf("migrated = %d, want 1", migrated)
|
||||||
|
}
|
||||||
|
|
||||||
|
manifestPath, err := V2PathForName(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
manifestData, err := os.ReadFile(manifestPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
sum := sha256.Sum256(manifestData)
|
||||||
|
blobPath, err := BlobsPath(fmt.Sprintf("sha256:%x", sum))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
blobData, err := os.ReadFile(blobPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(blobData, manifestData) {
|
||||||
|
t.Fatal("migrated manifest path and blob content differ")
|
||||||
|
}
|
||||||
|
|
||||||
|
legacyPath, err := PathForName(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(legacyPath); !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("legacy manifest still exists: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
migrated, err = MigrateManifestLinks()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if migrated != 0 {
|
||||||
|
t.Fatalf("migrated on second run = %d, want 0", migrated)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := MigrateManifestLinks(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
manifestDataAfter, err := os.ReadFile(manifestPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(manifestDataAfter, manifestData) {
|
||||||
|
t.Fatal("second migration changed manifest content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemoveLayersRemovesUnreferencedManifestBlob(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
|
||||||
|
name := model.ParseName("example")
|
||||||
|
if err := WriteManifest(name, Layer{}, nil); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := ParseNamedManifest(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
blobPath, err := BlobsPath(m.BlobDigest())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(blobPath); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Remove(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := m.RemoveLayers(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(blobPath); !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("manifest blob still exists: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseNamedManifestRejectsUnsafeSymlinks(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
|
||||||
|
name := model.ParseName("example")
|
||||||
|
manifestPath, err := PathForName(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("non blob basename", func(t *testing.T) {
|
||||||
|
target := filepath.Join(t.TempDir(), "not-a-blob")
|
||||||
|
if err := os.WriteFile(target, []byte(`{"schemaVersion":2}`), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.Remove(manifestPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.Symlink(target, manifestPath); err != nil {
|
||||||
|
t.Skipf("symlink unavailable: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ParseNamedManifest(name)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "not a sha256 blob") {
|
||||||
|
t.Fatalf("err = %v, want not a sha256 blob", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("blob basename outside blob store", func(t *testing.T) {
|
||||||
|
data := []byte(`{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json"}`)
|
||||||
|
sum := sha256.Sum256(data)
|
||||||
|
target := filepath.Join(t.TempDir(), fmt.Sprintf("sha256-%x", sum))
|
||||||
|
if err := os.WriteFile(target, data, 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.Remove(manifestPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.Symlink(target, manifestPath); err != nil {
|
||||||
|
t.Skipf("symlink unavailable: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ParseNamedManifest(name)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "does not match blob") {
|
||||||
|
t.Fatalf("err = %v, want does not match blob", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseNamedManifestPrefersV2(t *testing.T) {
|
||||||
|
models := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", models)
|
||||||
|
|
||||||
|
name := model.ParseName("example")
|
||||||
|
|
||||||
|
legacyPath, err := PathForName(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Dir(legacyPath), 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(legacyPath, []byte(`{"schemaVersion":2,"mediaType":"legacy"}`), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := WriteManifestData(name, []byte(`{"schemaVersion":2,"mediaType":"v2"}`)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := ParseNamedManifest(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if m.MediaType != "v2" {
|
||||||
|
t.Fatalf("media type = %q, want %q", m.MediaType, "v2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestsV2ShadowsLegacy(t *testing.T) {
|
||||||
|
models := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", models)
|
||||||
|
|
||||||
|
name := model.ParseName("example")
|
||||||
|
createManifest(t, models, name.Filepath())
|
||||||
|
if err := WriteManifestData(name, []byte(`{"schemaVersion":2,"mediaType":"v2"}`)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ms, err := Manifests(true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ms) != 1 {
|
||||||
|
t.Fatalf("manifest count = %d, want 1", len(ms))
|
||||||
|
}
|
||||||
|
|
||||||
|
var m *Manifest
|
||||||
|
for gotName, gotManifest := range ms {
|
||||||
|
if gotName.EqualFold(model.ParseName("example")) {
|
||||||
|
m = gotManifest
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if m == nil {
|
||||||
|
t.Fatalf("missing v2 manifest for %s", name)
|
||||||
|
}
|
||||||
|
if m.MediaType != "v2" {
|
||||||
|
t.Fatalf("media type = %q, want %q", m.MediaType, "v2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestManifests(t *testing.T) {
|
func TestManifests(t *testing.T) {
|
||||||
cases := map[string]struct {
|
cases := map[string]struct {
|
||||||
ps []string
|
ps []string
|
||||||
|
|||||||
@@ -14,8 +14,23 @@ import (
|
|||||||
|
|
||||||
var ErrInvalidDigestFormat = errors.New("invalid digest format")
|
var ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||||
|
|
||||||
|
const (
|
||||||
|
legacyDirName = "manifests"
|
||||||
|
v2DirName = "manifests-v2"
|
||||||
|
defaultPublicHost = "registry.ollama.ai"
|
||||||
|
v2CanonicalHost = "ollama.com"
|
||||||
|
)
|
||||||
|
|
||||||
func Path() (string, error) {
|
func Path() (string, error) {
|
||||||
path := filepath.Join(envconfig.Models(), "manifests")
|
return manifestPath(legacyDirName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func V2Path() (string, error) {
|
||||||
|
return manifestPath(v2DirName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func manifestPath(dir string) (string, error) {
|
||||||
|
path := filepath.Join(envconfig.Models(), dir)
|
||||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||||
}
|
}
|
||||||
@@ -25,6 +40,10 @@ func Path() (string, error) {
|
|||||||
|
|
||||||
// PathForName returns the path to the manifest file for a specific model name.
|
// PathForName returns the path to the manifest file for a specific model name.
|
||||||
func PathForName(n model.Name) (string, error) {
|
func PathForName(n model.Name) (string, error) {
|
||||||
|
return LegacyPathForName(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LegacyPathForName(n model.Name) (string, error) {
|
||||||
if !n.IsValid() {
|
if !n.IsValid() {
|
||||||
return "", os.ErrNotExist
|
return "", os.ErrNotExist
|
||||||
}
|
}
|
||||||
@@ -37,6 +56,162 @@ func PathForName(n model.Name) (string, error) {
|
|||||||
return filepath.Join(manifests, n.Filepath()), nil
|
return filepath.Join(manifests, n.Filepath()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func V2PathForName(n model.Name) (string, error) {
|
||||||
|
if !n.IsValid() {
|
||||||
|
return "", os.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
manifests, err := V2Path()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join(manifests, canonicalV2Name(n).Filepath()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResolvePathForName(n model.Name) (string, error) {
|
||||||
|
path, _, err := resolveManifestPath(n)
|
||||||
|
return path, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveManifestPath(n model.Name) (string, string, error) {
|
||||||
|
if !n.IsValid() {
|
||||||
|
return "", "", os.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
v2Path, err := V2PathForName(n)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
if _, err := os.Lstat(v2Path); err == nil {
|
||||||
|
root, err := V2Path()
|
||||||
|
return v2Path, root, err
|
||||||
|
} else if !os.IsNotExist(err) {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
legacyRoot, err := Path()
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
for _, legacyName := range legacyNameCandidates(n) {
|
||||||
|
legacyPath := filepath.Join(legacyRoot, legacyName.Filepath())
|
||||||
|
if _, err := os.Lstat(legacyPath); err == nil {
|
||||||
|
return legacyPath, legacyRoot, nil
|
||||||
|
} else if !os.IsNotExist(err) {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", "", os.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeNamedManifestPaths(n model.Name) error {
|
||||||
|
candidates := legacyNameCandidates(n)
|
||||||
|
paths := make([]string, 0, 1+len(candidates))
|
||||||
|
|
||||||
|
v2Path, err := V2PathForName(n)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
paths = append(paths, v2Path)
|
||||||
|
|
||||||
|
for _, legacyName := range candidates {
|
||||||
|
legacyPath, err := LegacyPathForName(legacyName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
paths = append(paths, legacyPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, path := range paths {
|
||||||
|
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pruneManifestRoots()
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeLegacyManifestPaths(n model.Name) error {
|
||||||
|
for _, legacyName := range legacyNameCandidates(n) {
|
||||||
|
legacyPath, err := LegacyPathForName(legacyName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.Remove(legacyPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
legacyRoot, err := Path()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := PruneDirectory(legacyRoot); err != nil && !os.IsNotExist(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pruneManifestRoots() error {
|
||||||
|
roots := []func() (string, error){Path, V2Path}
|
||||||
|
for _, rootFn := range roots {
|
||||||
|
root, err := rootFn()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := PruneDirectory(root); err != nil && !os.IsNotExist(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeLogicalName maps any public host to the legacy default
|
||||||
|
// so that map keys use a single identity regardless of on-disk host.
|
||||||
|
func normalizeLogicalName(n model.Name) model.Name {
|
||||||
|
if isDefaultPublicHost(n.Host) {
|
||||||
|
n.Host = defaultPublicHost
|
||||||
|
}
|
||||||
|
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// canonicalV2Name maps any public host to the v2 canonical host
|
||||||
|
// for use in manifests-v2/ on-disk paths.
|
||||||
|
func canonicalV2Name(n model.Name) model.Name {
|
||||||
|
if isDefaultPublicHost(n.Host) {
|
||||||
|
n.Host = v2CanonicalHost
|
||||||
|
}
|
||||||
|
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func legacyNameCandidates(n model.Name) []model.Name {
|
||||||
|
names := []model.Name{n}
|
||||||
|
if !isDefaultPublicHost(n.Host) {
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
alt := n
|
||||||
|
switch {
|
||||||
|
case strings.EqualFold(n.Host, defaultPublicHost):
|
||||||
|
alt.Host = v2CanonicalHost
|
||||||
|
default:
|
||||||
|
alt.Host = defaultPublicHost
|
||||||
|
}
|
||||||
|
|
||||||
|
return append(names, alt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDefaultPublicHost(host string) bool {
|
||||||
|
return strings.EqualFold(host, defaultPublicHost) || strings.EqualFold(host, v2CanonicalHost)
|
||||||
|
}
|
||||||
|
|
||||||
func BlobsPath(digest string) (string, error) {
|
func BlobsPath(digest string) (string, error) {
|
||||||
// only accept actual sha256 digests
|
// only accept actual sha256 digests
|
||||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ import (
|
|||||||
// <|turn>/<turn|> markers, <|"|> string delimiters, and <|tool>/
|
// <|turn>/<turn|> markers, <|"|> string delimiters, and <|tool>/
|
||||||
// <|tool_call>/<|tool_response> tags for function calling.
|
// <|tool_call>/<|tool_response> tags for function calling.
|
||||||
type Gemma4Renderer struct {
|
type Gemma4Renderer struct {
|
||||||
useImgTags bool
|
useImgTags bool
|
||||||
|
emptyBlockOnNothink bool
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -124,6 +125,9 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
|||||||
// Generation prompt.
|
// Generation prompt.
|
||||||
if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
|
if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
|
||||||
sb.WriteString("<|turn>model\n")
|
sb.WriteString("<|turn>model\n")
|
||||||
|
if r.emptyBlockOnNothink && !hasThink {
|
||||||
|
sb.WriteString("<|channel>thought\n<channel|>")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ package renderers
|
|||||||
// TestGemma4RendererMatchesReference verifies our renderer matches the checked-in
|
// TestGemma4RendererMatchesReference verifies our renderer matches the checked-in
|
||||||
// Gemma 4 reference template.
|
// Gemma 4 reference template.
|
||||||
//
|
//
|
||||||
// Current upstream Gemma 4 chat templates differ by model size, so the checked-in
|
// Current upstream Gemma 4 chat templates differ by model size. The checked-in
|
||||||
// reference intentionally uses the shared baseline without an empty generation-time
|
// reference cases below use the small (e2b/e4b-style) baseline, with large
|
||||||
// thought channel until renderer selection is split by size.
|
// (26b/31b-style) checks covered separately in this file.
|
||||||
//
|
//
|
||||||
// To regenerate expected values, save the E2B template to
|
// To regenerate expected values, save the E2B template to
|
||||||
// gemma4_e2b_chat_template.jinja2 and run:
|
// gemma4_e2b_chat_template.jinja2 and run:
|
||||||
@@ -1474,6 +1474,47 @@ Hi<turn|>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGemma4RendererVariantsMatchExpectedGenerationPrompt(t *testing.T) {
|
||||||
|
messages := []api.Message{{Role: "user", Content: "Hello"}}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rendererName string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "legacy_alias",
|
||||||
|
rendererName: "gemma4",
|
||||||
|
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "small",
|
||||||
|
rendererName: "gemma4-small",
|
||||||
|
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large",
|
||||||
|
rendererName: "gemma4-large",
|
||||||
|
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := RenderWithRenderer(tt.rendererName, messages, nil, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expected, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGemma4LargeRendererOmitsEmptyThoughtBlockWhenThinkingEnabled(t *testing.T) {
|
||||||
|
got, err := RenderWithRenderer("gemma4-large", []api.Message{{Role: "user", Content: "Hello"}}, nil, thinkTrue())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "<bos><|turn>system\n<|think|>\n<turn|>\n<|turn>user\nHello<turn|>\n<|turn>model\n", got)
|
||||||
|
assert.NotContains(t, got, "<|channel>thought\n<channel|>")
|
||||||
|
}
|
||||||
|
|
||||||
func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
|
func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
|
||||||
if os.Getenv("VERIFY_JINJA2") == "" {
|
if os.Getenv("VERIFY_JINJA2") == "" {
|
||||||
t.Skip("set VERIFY_JINJA2=1 to run expanded Jinja2 parity checks")
|
t.Skip("set VERIFY_JINJA2=1 to run expanded Jinja2 parity checks")
|
||||||
@@ -1616,15 +1657,35 @@ func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
variants := []struct {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
name string
|
||||||
renderer := &Gemma4Renderer{useImgTags: RenderImgTags}
|
renderer *Gemma4Renderer
|
||||||
got, err := renderer.Render(tt.messages, tt.tools, tt.think)
|
templateRel string
|
||||||
assert.NoError(t, err)
|
}{
|
||||||
|
{
|
||||||
|
name: "small",
|
||||||
|
renderer: &Gemma4Renderer{useImgTags: RenderImgTags},
|
||||||
|
templateRel: gemma4E2BTemplate,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large",
|
||||||
|
renderer: &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true},
|
||||||
|
templateRel: gemma431BTemplate,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, tt.think)
|
for _, variant := range variants {
|
||||||
assert.Equal(t, jinja2Output, got,
|
t.Run(variant.name, func(t *testing.T) {
|
||||||
"renderer output doesn't match Jinja2 template output")
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := variant.renderer.Render(tt.messages, tt.tools, tt.think)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
jinja2Output := renderWithJinja2Template(t, variant.templateRel, tt.messages, tt.tools, tt.think)
|
||||||
|
assert.Equal(t, jinja2Output, got,
|
||||||
|
"renderer output doesn't match Jinja2 template output")
|
||||||
|
})
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -81,8 +81,10 @@ func rendererForName(name string) Renderer {
|
|||||||
return renderer
|
return renderer
|
||||||
case "nemotron-3-nano":
|
case "nemotron-3-nano":
|
||||||
return &Nemotron3NanoRenderer{}
|
return &Nemotron3NanoRenderer{}
|
||||||
case "gemma4":
|
case "gemma4", "gemma4-small":
|
||||||
return &Gemma4Renderer{useImgTags: RenderImgTags}
|
return &Gemma4Renderer{useImgTags: RenderImgTags}
|
||||||
|
case "gemma4-large":
|
||||||
|
return &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true}
|
||||||
case "functiongemma":
|
case "functiongemma":
|
||||||
return &FunctionGemmaRenderer{}
|
return &FunctionGemmaRenderer{}
|
||||||
case "glm-4.7":
|
case "glm-4.7":
|
||||||
|
|||||||
@@ -523,7 +523,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
|||||||
arch := layer.GGML.KV().Architecture()
|
arch := layer.GGML.KV().Architecture()
|
||||||
switch arch {
|
switch arch {
|
||||||
case "gemma4":
|
case "gemma4":
|
||||||
config.Renderer = cmp.Or(config.Renderer, "gemma4")
|
config.Renderer = cmp.Or(config.Renderer, gemma4RendererLegacy)
|
||||||
config.Parser = cmp.Or(config.Parser, "gemma4")
|
config.Parser = cmp.Or(config.Parser, "gemma4")
|
||||||
if _, ok := r.Parameters["stop"]; !ok {
|
if _, ok := r.Parameters["stop"]; !ok {
|
||||||
if r.Parameters == nil {
|
if r.Parameters == nil {
|
||||||
|
|||||||
78
server/gemma4_test.go
Normal file
78
server/gemma4_test.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestResolveGemma4Renderer(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model *Model
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil model falls back to legacy alias",
|
||||||
|
model: nil,
|
||||||
|
want: gemma4RendererLegacy,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit small passes through",
|
||||||
|
model: &Model{
|
||||||
|
Config: testConfigWithRenderer(gemma4RendererSmall),
|
||||||
|
},
|
||||||
|
want: gemma4RendererSmall,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit large passes through",
|
||||||
|
model: &Model{
|
||||||
|
Config: testConfigWithRenderer(gemma4RendererLarge),
|
||||||
|
},
|
||||||
|
want: gemma4RendererLarge,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy e4b tag resolves small",
|
||||||
|
model: &Model{
|
||||||
|
Name: "gemma4:e4b",
|
||||||
|
ShortName: "gemma4:e4b",
|
||||||
|
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||||
|
},
|
||||||
|
want: gemma4RendererSmall,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy 31b tag resolves large",
|
||||||
|
model: &Model{
|
||||||
|
Name: "gemma4:31b-cloud",
|
||||||
|
ShortName: "gemma4:31b-cloud",
|
||||||
|
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||||
|
},
|
||||||
|
want: gemma4RendererLarge,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy model type resolves small",
|
||||||
|
model: &Model{
|
||||||
|
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "4.3B"),
|
||||||
|
},
|
||||||
|
want: gemma4RendererSmall,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy model type resolves large",
|
||||||
|
model: &Model{
|
||||||
|
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
|
||||||
|
},
|
||||||
|
want: gemma4RendererLarge,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy unknown defaults small",
|
||||||
|
model: &Model{
|
||||||
|
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||||
|
},
|
||||||
|
want: gemma4RendererSmall,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := resolveGemma4Renderer(tt.model); got != tt.want {
|
||||||
|
t.Fatalf("resolveGemma4Renderer() = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
@@ -33,6 +34,10 @@ import (
|
|||||||
"github.com/ollama/ollama/x/imagegen/transfer"
|
"github.com/ollama/ollama/x/imagegen/transfer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Blobs newer than this may belong to another process that has not written its
|
||||||
|
// manifest yet. They become eligible for the normal mark-and-sweep pass later.
|
||||||
|
const layerPruneGracePeriod = time.Hour
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errCapabilities = errors.New("does not support")
|
errCapabilities = errors.New("does not support")
|
||||||
errCapabilityCompletion = errors.New("completion")
|
errCapabilityCompletion = errors.New("completion")
|
||||||
@@ -156,7 +161,7 @@ func (m *Model) Capabilities() []model.Capability {
|
|||||||
|
|
||||||
// Temporary workaround — suppress vision/audio for gemma4 MLX models
|
// Temporary workaround — suppress vision/audio for gemma4 MLX models
|
||||||
// until multimodal runtime pipeline lands. Remove when imageproc.go is wired up.
|
// until multimodal runtime pipeline lands. Remove when imageproc.go is wired up.
|
||||||
if m.Config.ModelFormat == "safetensors" && m.Config.Renderer == "gemma4" {
|
if m.Config.ModelFormat == "safetensors" && isGemma4Renderer(m.Config.Renderer) {
|
||||||
capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool {
|
capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool {
|
||||||
return c == model.CapabilityVision || c == "audio"
|
return c == model.CapabilityVision || c == "audio"
|
||||||
})
|
})
|
||||||
@@ -406,31 +411,12 @@ func CopyModel(src, dst model.Name) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
manifests, err := manifest.Path()
|
data, err := manifest.ReadManifestData(src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dstpath := filepath.Join(manifests, dst.Filepath())
|
return manifest.WriteManifestData(dst, data)
|
||||||
if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
srcpath := filepath.Join(manifests, src.Filepath())
|
|
||||||
srcfile, err := os.Open(srcpath)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer srcfile.Close()
|
|
||||||
|
|
||||||
dstfile, err := os.Create(dstpath)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer dstfile.Close()
|
|
||||||
|
|
||||||
_, err = io.Copy(dstfile, srcfile)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
||||||
@@ -441,6 +427,10 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, manifest := range manifests {
|
for _, manifest := range manifests {
|
||||||
|
if manifest.BlobDigest() != "" {
|
||||||
|
delete(deleteMap, manifest.BlobDigest())
|
||||||
|
}
|
||||||
|
|
||||||
for _, layer := range manifest.Layers {
|
for _, layer := range manifest.Layers {
|
||||||
delete(deleteMap, layer.Digest)
|
delete(deleteMap, layer.Digest)
|
||||||
}
|
}
|
||||||
@@ -478,10 +468,23 @@ func PruneLayers() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, blob := range blobs {
|
for _, blob := range blobs {
|
||||||
|
if blob.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := blob.Info()
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("couldn't stat blob", "blob", blob.Name(), "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if time.Since(info.ModTime()) < layerPruneGracePeriod {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
name := blob.Name()
|
name := blob.Name()
|
||||||
name = strings.ReplaceAll(name, "-", ":")
|
name = strings.ReplaceAll(name, "-", ":")
|
||||||
|
|
||||||
_, err := manifest.BlobsPath(name)
|
_, err = manifest.BlobsPath(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
|
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
|
||||||
// remove invalid blobs (e.g. partial downloads)
|
// remove invalid blobs (e.g. partial downloads)
|
||||||
@@ -531,11 +534,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
// Use fast transfer for models with tensor layers (many small blobs)
|
// Use fast transfer for models with tensor layers (many small blobs)
|
||||||
if hasTensorLayers(layers) {
|
if hasTensorLayers(layers) {
|
||||||
// Read raw manifest JSON to preserve tensor metadata fields
|
// Read raw manifest JSON to preserve tensor metadata fields
|
||||||
manifestPath, err := manifest.PathForName(n)
|
manifestJSON, err := manifest.ReadManifestData(n)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
manifestJSON, err := os.ReadFile(manifestPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -592,6 +591,14 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
if existingMf.Config.Digest != "" {
|
if existingMf.Config.Digest != "" {
|
||||||
deleteMap[existingMf.Config.Digest] = struct{}{}
|
deleteMap[existingMf.Config.Digest] = struct{}{}
|
||||||
}
|
}
|
||||||
|
if existingMf.BlobDigest() != "" {
|
||||||
|
digest := existingMf.BlobDigest()
|
||||||
|
if blob, err := manifest.BlobsPath(digest); err == nil {
|
||||||
|
if _, err := os.Stat(blob); err == nil {
|
||||||
|
deleteMap[digest] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if n.ProtocolScheme == "http" && !regOpts.Insecure {
|
if n.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||||
@@ -661,21 +668,12 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|||||||
|
|
||||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||||
|
|
||||||
fp, err := manifest.PathForName(n)
|
if err := manifest.WriteManifestData(n, manifestData); err != nil {
|
||||||
if err != nil {
|
slog.Info(fmt.Sprintf("couldn't write manifest for %s", n.DisplayShortest()))
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = os.WriteFile(fp, manifestData, 0o644)
|
slog.Debug("manifest written", "name", n.DisplayShortest(), "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
|
||||||
if err != nil {
|
|
||||||
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Debug("manifest written", "path", fp, "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
|
|
||||||
|
|
||||||
if !envconfig.NoPrune() && len(deleteMap) > 0 {
|
if !envconfig.NoPrune() && len(deleteMap) > 0 {
|
||||||
fn(api.ProgressResponse{Status: "removing unused layers"})
|
fn(api.ProgressResponse{Status: "removing unused layers"})
|
||||||
@@ -758,19 +756,11 @@ func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
|
|||||||
// Write manifest
|
// Write manifest
|
||||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||||
|
|
||||||
fp, err := manifest.PathForName(n)
|
if err := manifest.WriteManifestData(n, manifestData); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.WriteFile(fp, manifestData, 0o644); err != nil {
|
slog.Debug("manifest written", "name", n.DisplayShortest(), "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Debug("manifest written", "path", fp, "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,14 +5,58 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
"github.com/ollama/ollama/manifest"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestPruneLayersSkipsRecentOrphans(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
|
||||||
|
recentDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000001"
|
||||||
|
oldDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000002"
|
||||||
|
|
||||||
|
for _, digest := range []string{recentDigest, oldDigest} {
|
||||||
|
p, err := manifest.BlobsPath(digest)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(p, nil, 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
oldPath, err := manifest.BlobsPath(oldDigest)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
oldTime := time.Now().Add(-layerPruneGracePeriod - time.Hour)
|
||||||
|
if err := os.Chtimes(oldPath, oldTime, oldTime); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := PruneLayers(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
recentPath, err := manifest.BlobsPath(recentDigest)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(recentPath); err != nil {
|
||||||
|
t.Fatalf("recent orphan was pruned: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(oldPath); !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("old orphan still exists: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestModelCapabilities(t *testing.T) {
|
func TestModelCapabilities(t *testing.T) {
|
||||||
// Create completion model (llama architecture without vision)
|
// Create completion model (llama architecture without vision)
|
||||||
completionModelPath, _ := createBinFile(t, ggml.KV{
|
completionModelPath, _ := createBinFile(t, ggml.KV{
|
||||||
@@ -118,6 +162,39 @@ func TestModelCapabilities(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectedCaps: []model.Capability{model.CapabilityEmbedding},
|
expectedCaps: []model.Capability{model.CapabilityEmbedding},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "gemma4 small safetensors suppresses vision and audio",
|
||||||
|
model: Model{
|
||||||
|
Config: model.ConfigV2{
|
||||||
|
ModelFormat: "safetensors",
|
||||||
|
Renderer: gemma4RendererSmall,
|
||||||
|
Capabilities: []string{"vision", "audio"},
|
||||||
|
},
|
||||||
|
Template: chatTemplate,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gemma4 large safetensors suppresses vision and audio",
|
||||||
|
model: Model{
|
||||||
|
Config: model.ConfigV2{
|
||||||
|
ModelFormat: "safetensors",
|
||||||
|
Renderer: gemma4RendererLarge,
|
||||||
|
Capabilities: []string{"vision", "audio"},
|
||||||
|
},
|
||||||
|
Template: chatTemplate,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "legacy gemma4 safetensors suppresses vision and audio",
|
||||||
|
model: Model{
|
||||||
|
Config: model.ConfigV2{
|
||||||
|
ModelFormat: "safetensors",
|
||||||
|
Renderer: gemma4RendererLegacy,
|
||||||
|
Capabilities: []string{"vision", "audio"},
|
||||||
|
},
|
||||||
|
Template: chatTemplate,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// compare two slices of model.Capability regardless of order
|
// compare two slices of model.Capability regardless of order
|
||||||
|
|||||||
@@ -116,6 +116,10 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
|
|||||||
proxied, err := func() (bool, error) {
|
proxied, err := func() (bool, error) {
|
||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
case "/api/delete":
|
case "/api/delete":
|
||||||
|
if s.Fallback != nil {
|
||||||
|
s.Fallback.ServeHTTP(rec, r)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
return false, s.handleDelete(rec, r)
|
return false, s.handleDelete(rec, r)
|
||||||
case "/api/pull":
|
case "/api/pull":
|
||||||
return false, s.handlePull(rec, r)
|
return false, s.handlePull(rec, r)
|
||||||
|
|||||||
@@ -115,7 +115,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
|
|
||||||
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
||||||
if m.Config.Renderer != "" {
|
if m.Config.Renderer != "" {
|
||||||
rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
|
rendererName := resolveRendererName(m)
|
||||||
|
rendered, err := renderers.RenderWithRenderer(rendererName, msgs, tools, think)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,14 @@ import (
|
|||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func testConfigWithRenderer(renderer string) model.ConfigV2 {
|
||||||
|
return model.ConfigV2{Renderer: renderer}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testConfigWithRendererAndType(renderer, modelType string) model.ConfigV2 {
|
||||||
|
return model.ConfigV2{Renderer: renderer, ModelType: modelType}
|
||||||
|
}
|
||||||
|
|
||||||
func TestChatPrompt(t *testing.T) {
|
func TestChatPrompt(t *testing.T) {
|
||||||
type expect struct {
|
type expect struct {
|
||||||
prompt string
|
prompt string
|
||||||
@@ -397,3 +405,43 @@ func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
|
|||||||
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) {
|
||||||
|
msgs := []api.Message{{Role: "user", Content: "Hello"}}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model Model
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "small from name",
|
||||||
|
model: Model{
|
||||||
|
Name: "gemma4:e4b",
|
||||||
|
ShortName: "gemma4:e4b",
|
||||||
|
Config: testConfigWithRenderer(gemma4RendererLegacy),
|
||||||
|
},
|
||||||
|
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large from model type",
|
||||||
|
model: Model{
|
||||||
|
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
|
||||||
|
},
|
||||||
|
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := renderPrompt(&tt.model, msgs, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||||
|
t.Fatalf("rendered prompt mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
110
server/renderer_resolution.go
Normal file
110
server/renderer_resolution.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
gemma4RendererLegacy = "gemma4"
|
||||||
|
gemma4RendererSmall = "gemma4-small"
|
||||||
|
gemma4RendererLarge = "gemma4-large"
|
||||||
|
|
||||||
|
// Gemma 4 small templates cover the e2b/e4b family, while 26b/31b use the
|
||||||
|
// large template. Default to the small prompt unless the model is clearly in
|
||||||
|
// the large range.
|
||||||
|
gemma4LargeMinParameterCount = 16_000_000_000
|
||||||
|
)
|
||||||
|
|
||||||
|
func resolveRendererName(m *Model) string {
|
||||||
|
if m == nil || m.Config.Renderer == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch m.Config.Renderer {
|
||||||
|
case gemma4RendererLegacy:
|
||||||
|
return resolveGemma4Renderer(m)
|
||||||
|
default:
|
||||||
|
return m.Config.Renderer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveGemma4Renderer(m *Model) string {
|
||||||
|
if m == nil || m.Config.Renderer != gemma4RendererLegacy {
|
||||||
|
if m == nil {
|
||||||
|
return gemma4RendererLegacy
|
||||||
|
}
|
||||||
|
return m.Config.Renderer
|
||||||
|
}
|
||||||
|
|
||||||
|
if renderer, ok := gemma4RendererFromName(m.ShortName); ok {
|
||||||
|
return renderer
|
||||||
|
}
|
||||||
|
|
||||||
|
if renderer, ok := gemma4RendererFromName(m.Name); ok {
|
||||||
|
return renderer
|
||||||
|
}
|
||||||
|
|
||||||
|
if parameterCount, ok := parseHumanParameterCount(m.Config.ModelType); ok {
|
||||||
|
return gemma4RendererForParameterCount(parameterCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return gemma4RendererSmall
|
||||||
|
}
|
||||||
|
|
||||||
|
func gemma4RendererForParameterCount(parameterCount uint64) string {
|
||||||
|
if parameterCount >= gemma4LargeMinParameterCount {
|
||||||
|
return gemma4RendererLarge
|
||||||
|
}
|
||||||
|
|
||||||
|
return gemma4RendererSmall
|
||||||
|
}
|
||||||
|
|
||||||
|
func gemma4RendererFromName(name string) (string, bool) {
|
||||||
|
lower := strings.ToLower(name)
|
||||||
|
switch {
|
||||||
|
case strings.Contains(lower, "e2b"), strings.Contains(lower, "e4b"):
|
||||||
|
return gemma4RendererSmall, true
|
||||||
|
case strings.Contains(lower, "26b"), strings.Contains(lower, "31b"):
|
||||||
|
return gemma4RendererLarge, true
|
||||||
|
default:
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHumanParameterCount(s string) (uint64, bool) {
|
||||||
|
if s == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
unit := strings.ToUpper(s[len(s)-1:])
|
||||||
|
var multiplier float64
|
||||||
|
switch unit {
|
||||||
|
case "B":
|
||||||
|
multiplier = float64(format.Billion)
|
||||||
|
case "M":
|
||||||
|
multiplier = float64(format.Million)
|
||||||
|
case "K":
|
||||||
|
multiplier = float64(format.Thousand)
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := strconv.ParseFloat(s[:len(s)-1], 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint64(value * multiplier), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isGemma4Renderer(renderer string) bool {
|
||||||
|
switch renderer {
|
||||||
|
case gemma4RendererLegacy, gemma4RendererSmall, gemma4RendererLarge:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1770,13 +1770,15 @@ func Serve(ln net.Listener) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
manifestsPath, err := manifest.Path()
|
for _, rootFn := range []func() (string, error){manifest.Path, manifest.V2Path} {
|
||||||
if err != nil {
|
manifestsPath, err := rootFn()
|
||||||
return err
|
if err != nil {
|
||||||
}
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if err := manifest.PruneDirectory(manifestsPath); err != nil {
|
if err := manifest.PruneDirectory(manifestsPath); err != nil && !os.IsNotExist(err) {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2408,7 +2410,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
// current approach uses the transition from parsed thinking content to
|
// current approach uses the transition from parsed thinking content to
|
||||||
// parsed non-thinking content as the signal to turn constraining on
|
// parsed non-thinking content as the signal to turn constraining on
|
||||||
|
|
||||||
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
|
// TODO(parthsareen): temporary fix for https://github.com/ollama/ollama/issues/15260.
|
||||||
|
// To revisit for other models and have a consistent pattern across models through parsers.
|
||||||
|
forceImmediate := m.Config.Parser == "gemma4" && req.Think != nil && !req.Think.Bool()
|
||||||
|
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && !forceImmediate && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
|
||||||
currentFormat = nil
|
currentFormat = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -109,12 +109,44 @@ func checkFileExists(t *testing.T, p string, expect []string) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
if strings.HasSuffix(filepath.ToSlash(p), "/blobs/*") {
|
||||||
|
actual = slices.DeleteFunc(actual, isManifestBlobForTest)
|
||||||
|
}
|
||||||
|
|
||||||
if diff := gocmp.Diff(expect, actual, gocmpopts.SortSlices(strings.Compare), gocmpopts.EquateEmpty()); diff != "" {
|
if diff := gocmp.Diff(expect, actual, gocmpopts.SortSlices(strings.Compare), gocmpopts.EquateEmpty()); diff != "" {
|
||||||
t.Errorf("file exists mismatch (-want +got):\n%s", diff)
|
t.Errorf("file exists mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkManifestFiles(t *testing.T, names ...string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
expect := make([]string, len(names))
|
||||||
|
for i, name := range names {
|
||||||
|
p, err := manifest.V2PathForName(model.ParseName(name))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
expect[i] = p
|
||||||
|
}
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(envconfig.Models(), "manifests-v2", "*", "*", "*", "*"), expect)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isManifestBlobForTest(path string) bool {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var m manifest.Manifest
|
||||||
|
if err := json.Unmarshal(data, &m); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.SchemaVersion != 0 && m.MediaType != "" && (m.Config.Digest != "" || len(m.Layers) > 0)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreateFromBin(t *testing.T) {
|
func TestCreateFromBin(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
@@ -136,9 +168,7 @@ func TestCreateFromBin(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
||||||
@@ -196,9 +226,7 @@ func TestCreateFromModel(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
Name: "test2",
|
Name: "test2",
|
||||||
@@ -210,10 +238,7 @@ func TestCreateFromModel(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test", "test2")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
||||||
@@ -306,9 +331,7 @@ func TestCreateRemovesLayers(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
|
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
|
||||||
@@ -327,9 +350,7 @@ func TestCreateRemovesLayers(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
|
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
|
||||||
@@ -357,9 +378,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-0a666d113e8e0a3d27e9c7bd136a0bdfb6241037db50729d81568451ebfdbde8"),
|
filepath.Join(p, "blobs", "sha256-0a666d113e8e0a3d27e9c7bd136a0bdfb6241037db50729d81568451ebfdbde8"),
|
||||||
@@ -378,9 +397,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
||||||
@@ -411,9 +428,7 @@ func TestCreateMergeParameters(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
|
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
|
||||||
@@ -436,10 +451,7 @@ func TestCreateMergeParameters(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test", "test2")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Display contents of each blob in the directory
|
// Display contents of each blob in the directory
|
||||||
blobDir := filepath.Join(p, "blobs")
|
blobDir := filepath.Join(p, "blobs")
|
||||||
@@ -495,10 +507,7 @@ func TestCreateMergeParameters(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test", "test2")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"),
|
filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"),
|
||||||
@@ -555,9 +564,7 @@ func TestCreateReplacesMessages(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
|
filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
|
||||||
@@ -589,10 +596,7 @@ func TestCreateReplacesMessages(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test", "test2")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Old layers will not have been pruned
|
// Old layers will not have been pruned
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
@@ -650,9 +654,7 @@ func TestCreateTemplateSystem(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-0a04d979734167da3b80811a1874d734697f366a689f3912589b99d2e86e7ad1"),
|
filepath.Join(p, "blobs", "sha256-0a04d979734167da3b80811a1874d734697f366a689f3912589b99d2e86e7ad1"),
|
||||||
@@ -850,9 +852,7 @@ func TestCreateLicenses(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"),
|
filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"),
|
||||||
@@ -928,6 +928,59 @@ func TestCreateDetectTemplate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateGemma4KeepsDynamicRendererAlias(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
p := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
|
var s Server
|
||||||
|
|
||||||
|
_, digest := createBinFile(t, ggml.KV{
|
||||||
|
"general.architecture": "gemma4",
|
||||||
|
"general.parameter_count": uint64(25_200_000_000),
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Files: map[string]string{"test.gguf": digest},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
mf, err := manifest.ParseNamedManifest(model.ParseName("test"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse manifest: %v", err)
|
||||||
|
}
|
||||||
|
if mf.Config.Digest == "" {
|
||||||
|
t.Fatalf("unexpected empty config digest for manifest")
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath, err := manifest.BlobsPath(mf.Config.Digest)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("config blob path: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfgFile, err := os.Open(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open config blob: %v", err)
|
||||||
|
}
|
||||||
|
defer cfgFile.Close()
|
||||||
|
|
||||||
|
var cfg model.ConfigV2
|
||||||
|
if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
|
||||||
|
t.Fatalf("decode config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Renderer != gemma4RendererLegacy {
|
||||||
|
t.Fatalf("expected renderer %q, got %q", gemma4RendererLegacy, cfg.Renderer)
|
||||||
|
}
|
||||||
|
if cfg.Parser != "gemma4" {
|
||||||
|
t.Fatalf("expected parser %q, got %q", "gemma4", cfg.Parser)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDetectModelTypeFromFiles(t *testing.T) {
|
func TestDetectModelTypeFromFiles(t *testing.T) {
|
||||||
t.Run("gguf file", func(t *testing.T) {
|
t.Run("gguf file", func(t *testing.T) {
|
||||||
_, digest := createBinFile(t, nil, nil)
|
_, digest := createBinFile(t, nil, nil)
|
||||||
|
|||||||
@@ -42,10 +42,7 @@ func TestDelete(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test", "test2")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
|
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
|
||||||
@@ -60,9 +57,7 @@ func TestDelete(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "test2")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
|
||||||
})
|
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
|
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
|
||||||
@@ -76,7 +71,7 @@ func TestDelete(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
checkManifestFiles(t)
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{})
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,7 +104,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
|||||||
t.Errorf("expected status code 200, actual %d", w.Code)
|
t.Errorf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
checkManifestFiles(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
|
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
|
||||||
@@ -129,14 +124,12 @@ func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
|
|||||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
checkManifestFiles(t, "gpt-oss:20b-cloud")
|
||||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"),
|
|
||||||
})
|
|
||||||
|
|
||||||
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
|
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
|
||||||
if w.Code != http.StatusOK {
|
if w.Code != http.StatusOK {
|
||||||
t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String())
|
t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
checkManifestFiles(t)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2108,6 +2108,132 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestChatFormatWithThinkFalse verifies that when a model uses a builtin
|
||||||
|
// parser that supports thinking (e.g. gemma4) and the request explicitly
|
||||||
|
// disables thinking (think=false), the format constraint is passed to the
|
||||||
|
// first and only completion call. Previously, format was deferred for all
|
||||||
|
// thinking-capable parsers and only re-applied after an end-of-thinking
|
||||||
|
// transition — a transition that never fires when thinking is off. See
|
||||||
|
// https://github.com/ollama/ollama/issues/15260.
|
||||||
|
func TestChatFormatWithThinkFalse(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
mock := &mockRunner{
|
||||||
|
CompletionResponse: llm.CompletionResponse{
|
||||||
|
Done: true,
|
||||||
|
DoneReason: llm.DoneReasonStop,
|
||||||
|
PromptEvalCount: 1,
|
||||||
|
PromptEvalDuration: 1,
|
||||||
|
EvalCount: 1,
|
||||||
|
EvalDuration: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Server{
|
||||||
|
sched: &Scheduler{
|
||||||
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
expiredCh: make(chan *runnerRef, 1),
|
||||||
|
unloadedCh: make(chan any, 1),
|
||||||
|
loaded: make(map[string]*runnerRef),
|
||||||
|
newServerFn: newMockServer(mock),
|
||||||
|
getGpuFn: getGpuFn,
|
||||||
|
getSystemInfoFn: getSystemInfoFn,
|
||||||
|
waitForRecovery: 250 * time.Millisecond,
|
||||||
|
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
req.successCh <- &runnerRef{llama: mock}
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.sched.Run(t.Context())
|
||||||
|
|
||||||
|
_, digest := createBinFile(t, ggml.KV{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"llama.block_count": uint32(1),
|
||||||
|
"llama.context_length": uint32(8192),
|
||||||
|
"llama.embedding_length": uint32(4096),
|
||||||
|
"llama.attention.head_count": uint32(32),
|
||||||
|
"llama.attention.head_count_kv": uint32(8),
|
||||||
|
"tokenizer.ggml.tokens": []string{""},
|
||||||
|
"tokenizer.ggml.scores": []float32{0},
|
||||||
|
"tokenizer.ggml.token_type": []int32{0},
|
||||||
|
}, []*ggml.Tensor{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Use the gemma4 builtin parser — it reports HasThinkingSupport=true, which
|
||||||
|
// adds CapabilityThinking to the model and previously triggered deferral of
|
||||||
|
// the format even when the user passed think=false.
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Model: "test-gemma4-parser",
|
||||||
|
Files: map[string]string{"file.gguf": digest},
|
||||||
|
Parser: "gemma4",
|
||||||
|
Template: `{{- range .Messages }}{{ .Role }}: {{ .Content }}{{ end }}`,
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("create: expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}},"required":["answer"]}`)
|
||||||
|
|
||||||
|
var (
|
||||||
|
requestsMu sync.Mutex
|
||||||
|
requests []llm.CompletionRequest
|
||||||
|
)
|
||||||
|
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||||
|
requestsMu.Lock()
|
||||||
|
requests = append(requests, r)
|
||||||
|
requestsMu.Unlock()
|
||||||
|
|
||||||
|
fn(llm.CompletionResponse{
|
||||||
|
Content: `{"answer":"42"}`,
|
||||||
|
Done: true,
|
||||||
|
DoneReason: llm.DoneReasonStop,
|
||||||
|
PromptEvalCount: 1,
|
||||||
|
PromptEvalDuration: 1,
|
||||||
|
EvalCount: 1,
|
||||||
|
EvalDuration: 1,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
streamRequest := false
|
||||||
|
think := false
|
||||||
|
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test-gemma4-parser",
|
||||||
|
Messages: []api.Message{{Role: "user", Content: "Respond in JSON."}},
|
||||||
|
Think: &api.ThinkValue{Value: think},
|
||||||
|
Stream: &streamRequest,
|
||||||
|
Format: format,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("chat: expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(requests) != 1 {
|
||||||
|
t.Fatalf("expected a single completion call, got %d", len(requests))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal([]byte(format), []byte(requests[0].Format)) {
|
||||||
|
t.Errorf("expected first completion format to match the request format, got %q", string(requests[0].Format))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateUnload(t *testing.T) {
|
func TestGenerateUnload(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@@ -658,11 +658,14 @@ func TestManifestCaseSensitivity(t *testing.T) {
|
|||||||
checkManifestList := func() {
|
checkManifestList := func() {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/")
|
mandir, err := manifest.V2Path()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to resolve v2 manifest path: %v", err)
|
||||||
|
}
|
||||||
var entries []string
|
var entries []string
|
||||||
t.Logf("dir entries:")
|
t.Logf("dir entries:")
|
||||||
fsys := os.DirFS(mandir)
|
fsys := os.DirFS(mandir)
|
||||||
err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
|
err = fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -685,7 +688,14 @@ func TestManifestCaseSensitivity(t *testing.T) {
|
|||||||
|
|
||||||
g := entries[0] // raw path
|
g := entries[0] // raw path
|
||||||
g = filepath.ToSlash(g)
|
g = filepath.ToSlash(g)
|
||||||
w := model.ParseName(wantStableName).Filepath()
|
wp, err := manifest.V2PathForName(model.ParseName(wantStableName))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to resolve expected manifest path: %v", err)
|
||||||
|
}
|
||||||
|
w, err := filepath.Rel(mandir, wp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to make expected manifest path relative: %v", err)
|
||||||
|
}
|
||||||
w = filepath.ToSlash(w)
|
w = filepath.ToSlash(w)
|
||||||
if g != w {
|
if g != w {
|
||||||
t.Errorf("\ngot: %s\nwant: %s", g, w)
|
t.Errorf("\ngot: %s\nwant: %s", g, w)
|
||||||
|
|||||||
@@ -93,6 +93,13 @@ func (t gemma4ImportTransform) quantizationType(name string, shape []int32, quan
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MoE router logits choose the top-k expert set. Quantization noise here
|
||||||
|
// can flip expert selection, after which downstream activations diverge
|
||||||
|
// sharply. The tensor is small, so leave it in source precision.
|
||||||
|
if isGemma4RouterProjection(name) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// Mixed-precision quantization: sensitive tensors get higher precision.
|
// Mixed-precision quantization: sensitive tensors get higher precision.
|
||||||
//
|
//
|
||||||
// Value projections (v_proj) directly determine attention output quality.
|
// Value projections (v_proj) directly determine attention output quality.
|
||||||
@@ -170,6 +177,12 @@ func isEmbedTokensWeight(name string) bool {
|
|||||||
!strings.Contains(name, "per_layer")
|
!strings.Contains(name, "per_layer")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isGemma4RouterProjection(name string) bool {
|
||||||
|
return strings.HasSuffix(name, ".router.proj.weight") &&
|
||||||
|
!strings.Contains(name, "audio_tower") &&
|
||||||
|
!strings.Contains(name, "vision_tower")
|
||||||
|
}
|
||||||
|
|
||||||
func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
|
func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
|
||||||
if td == nil {
|
if td == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
|||||||
@@ -68,6 +68,11 @@ func TestGemma4QuantizationType(t *testing.T) {
|
|||||||
{"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"},
|
{"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"},
|
||||||
{"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"},
|
{"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"},
|
||||||
|
|
||||||
|
// === Router projection: expert selection is sensitive; keep source precision ===
|
||||||
|
{"router proj int4", transform26B, "model.layers.0.router.proj.weight", aligned, "int4", ""},
|
||||||
|
{"router proj nvfp4", transform26B, "model.layers.0.router.proj.weight", aligned, "nvfp4", ""},
|
||||||
|
{"router proj mxfp4", transform26B, "model.layers.0.router.proj.weight", aligned, "mxfp4", ""},
|
||||||
|
|
||||||
// === k_proj: promoted only for 8-expert models ===
|
// === k_proj: promoted only for 8-expert models ===
|
||||||
{"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"},
|
{"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"},
|
||||||
{"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
|
{"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
rootmanifest "github.com/ollama/ollama/manifest"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ManifestLayer represents a layer in the manifest.
|
// ManifestLayer represents a layer in the manifest.
|
||||||
@@ -49,9 +51,7 @@ func DefaultManifestDir() string {
|
|||||||
// LoadManifest loads a manifest for the given model name.
|
// LoadManifest loads a manifest for the given model name.
|
||||||
// Model name format: "modelname" or "modelname:tag" or "host/namespace/name:tag"
|
// Model name format: "modelname" or "modelname:tag" or "host/namespace/name:tag"
|
||||||
func LoadManifest(modelName string) (*ModelManifest, error) {
|
func LoadManifest(modelName string) (*ModelManifest, error) {
|
||||||
manifestPath := resolveManifestPath(modelName)
|
data, err := rootmanifest.ReadManifestData(model.ParseName(modelName))
|
||||||
|
|
||||||
data, err := os.ReadFile(manifestPath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("read manifest: %w", err)
|
return nil, fmt.Errorf("read manifest: %w", err)
|
||||||
}
|
}
|
||||||
@@ -67,36 +67,6 @@ func LoadManifest(modelName string) (*ModelManifest, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolveManifestPath converts a model name to a manifest file path.
|
|
||||||
func resolveManifestPath(modelName string) string {
|
|
||||||
// Parse model name into components
|
|
||||||
// Default: registry.ollama.ai/library/<name>/<tag>
|
|
||||||
host := "registry.ollama.ai"
|
|
||||||
namespace := "library"
|
|
||||||
name := modelName
|
|
||||||
tag := "latest"
|
|
||||||
|
|
||||||
// Handle explicit tag
|
|
||||||
if idx := strings.LastIndex(name, ":"); idx != -1 {
|
|
||||||
tag = name[idx+1:]
|
|
||||||
name = name[:idx]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle full path like "host/namespace/name"
|
|
||||||
parts := strings.Split(name, "/")
|
|
||||||
switch len(parts) {
|
|
||||||
case 3:
|
|
||||||
host = parts[0]
|
|
||||||
namespace = parts[1]
|
|
||||||
name = parts[2]
|
|
||||||
case 2:
|
|
||||||
namespace = parts[0]
|
|
||||||
name = parts[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
return filepath.Join(DefaultManifestDir(), host, namespace, name, tag)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BlobPath returns the full path to a blob given its digest.
|
// BlobPath returns the full path to a blob given its digest.
|
||||||
func (m *ModelManifest) BlobPath(digest string) string {
|
func (m *ModelManifest) BlobPath(digest string) string {
|
||||||
// Convert "sha256:abc123" to "sha256-abc123"
|
// Convert "sha256:abc123" to "sha256-abc123"
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
package manifest
|
package manifest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
rootmanifest "github.com/ollama/ollama/manifest"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTotalTensorSize(t *testing.T) {
|
func TestTotalTensorSize(t *testing.T) {
|
||||||
@@ -55,3 +59,39 @@ func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
|
|||||||
t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
|
t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoadManifestPrefersV2(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
|
||||||
|
name := model.ParseName("example")
|
||||||
|
|
||||||
|
legacyPath, err := rootmanifest.PathForName(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Dir(legacyPath), 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(legacyPath, []byte(`{"schemaVersion":2,"mediaType":"legacy"}`), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
v2Path, err := rootmanifest.V2PathForName(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Dir(v2Path), 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(v2Path, []byte(`{"schemaVersion":2,"mediaType":"v2"}`), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := LoadManifest(name.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if m.Manifest.MediaType != "v2" {
|
||||||
|
t.Fatalf("media type = %q, want %q", m.Manifest.MediaType, "v2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -115,36 +115,7 @@ func (s *Server) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
|
|||||||
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
|
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
|
||||||
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", s.modelName, "--port", strconv.Itoa(port))
|
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", s.modelName, "--port", strconv.Itoa(port))
|
||||||
cmd.Env = os.Environ()
|
cmd.Env = os.Environ()
|
||||||
|
configureMLXSubprocessEnv(cmd, ml.LibraryPaths(gpus))
|
||||||
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
|
|
||||||
libraryPaths := []string{ml.LibOllamaPath}
|
|
||||||
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
|
|
||||||
libraryPaths = append(libraryPaths, mlxDirs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append existing LD_LIBRARY_PATH if set
|
|
||||||
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
|
|
||||||
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
|
||||||
|
|
||||||
// Update or add LD_LIBRARY_PATH in cmd.Env
|
|
||||||
found := false
|
|
||||||
for i := range cmd.Env {
|
|
||||||
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
|
|
||||||
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
|
|
||||||
}
|
|
||||||
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.cmd = cmd
|
s.cmd = cmd
|
||||||
|
|
||||||
@@ -200,6 +171,53 @@ func (s *Server) Ping(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mlxLibraryPathEnv() string {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "windows":
|
||||||
|
return "PATH"
|
||||||
|
case "darwin":
|
||||||
|
return "DYLD_LIBRARY_PATH"
|
||||||
|
default:
|
||||||
|
return "LD_LIBRARY_PATH"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureMLXSubprocessEnv(cmd *exec.Cmd, libraryPaths []string) {
|
||||||
|
if len(libraryPaths) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search order for the imagegen runner is:
|
||||||
|
// 1. bundled lib/ollama root
|
||||||
|
// 2. backend-specific library dirs selected during GPU discovery
|
||||||
|
// 3. any existing caller-provided library path values
|
||||||
|
pathEnv := mlxLibraryPathEnv()
|
||||||
|
pathEnvPaths := append([]string{}, libraryPaths...)
|
||||||
|
if existingPath, ok := os.LookupEnv(pathEnv); ok {
|
||||||
|
pathEnvPaths = append(pathEnvPaths, filepath.SplitList(existingPath)...)
|
||||||
|
}
|
||||||
|
setSubprocessEnv(cmd, pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
|
||||||
|
slog.Debug("mlx subprocess library path", pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
|
||||||
|
|
||||||
|
ollamaLibraryPaths := append([]string{}, libraryPaths...)
|
||||||
|
if existingPath, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
|
||||||
|
ollamaLibraryPaths = append(ollamaLibraryPaths, filepath.SplitList(existingPath)...)
|
||||||
|
}
|
||||||
|
setSubprocessEnv(cmd, "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
|
||||||
|
slog.Debug("mlx subprocess library path", "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func setSubprocessEnv(cmd *exec.Cmd, key, value string) {
|
||||||
|
for i := range cmd.Env {
|
||||||
|
name, _, ok := strings.Cut(cmd.Env[i], "=")
|
||||||
|
if ok && strings.EqualFold(name, key) {
|
||||||
|
cmd.Env[i] = key + "=" + value
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cmd.Env = append(cmd.Env, key+"="+value)
|
||||||
|
}
|
||||||
|
|
||||||
// getLastErr returns the last stderr line.
|
// getLastErr returns the last stderr line.
|
||||||
func (s *Server) getLastErr() string {
|
func (s *Server) getLastErr() string {
|
||||||
s.lastErrLock.Lock()
|
s.lastErrLock.Lock()
|
||||||
|
|||||||
24
x/mlxrunner/cache/cache.go
vendored
24
x/mlxrunner/cache/cache.go
vendored
@@ -254,8 +254,23 @@ func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV
|
|||||||
mlx.Pin(c.keys, c.values)
|
mlx.Pin(c.keys, c.values)
|
||||||
} else {
|
} else {
|
||||||
if c.idx < c.keys.Dim(2) {
|
if c.idx < c.keys.Dim(2) {
|
||||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
if c.offset <= c.maxSize {
|
||||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
// Not yet wrapped: slots [c.idx, Dim) are grow padding
|
||||||
|
// or stale post-rewind data, not live window content.
|
||||||
|
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||||
|
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||||
|
} else {
|
||||||
|
// Wrapped: logical order is slots[idx..Dim) then slots[0..idx).
|
||||||
|
// Linearize so the trim + concat below operate on contiguous
|
||||||
|
// positions and preserve the last (maxSize - 1) old tokens.
|
||||||
|
tailK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.keys.Dim(2)), mlx.Slice())
|
||||||
|
tailV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.values.Dim(2)), mlx.Slice())
|
||||||
|
headK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
|
||||||
|
headV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
|
||||||
|
c.keys.Set(tailK.Concatenate(2, headK))
|
||||||
|
c.values.Set(tailV.Concatenate(2, headV))
|
||||||
|
c.idx = c.keys.Dim(2)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trim to max_size to maintain sliding window
|
// Trim to max_size to maintain sliding window
|
||||||
@@ -322,9 +337,10 @@ func (c *RotatingKVCache) State() []*mlx.Array {
|
|||||||
if c.keys == nil || c.values == nil {
|
if c.keys == nil || c.values == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
liveLen := min(c.offset, c.keys.Dim(2))
|
||||||
return []*mlx.Array{
|
return []*mlx.Array{
|
||||||
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
|
||||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
338
x/mlxrunner/cache/rotating_multiturn_test.go
vendored
Normal file
338
x/mlxrunner/cache/rotating_multiturn_test.go
vendored
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// singleTokenKV and multiTokenKV fabricate [B=1, H=1, L, D=2] key/value
|
||||||
|
// tensors whose channel value is the token id, so stateIDs can recover
|
||||||
|
// which ids survived in the cache.
|
||||||
|
func singleTokenKV(id float32) (*mlx.Array, *mlx.Array) {
|
||||||
|
k := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
|
||||||
|
v := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
|
||||||
|
return k, v
|
||||||
|
}
|
||||||
|
|
||||||
|
func multiTokenKV(ids []float32) (*mlx.Array, *mlx.Array) {
|
||||||
|
data := make([]float32, 0, 2*len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
data = append(data, id, id)
|
||||||
|
}
|
||||||
|
k := mlx.FromValues(data, 1, 1, len(ids), 2)
|
||||||
|
v := mlx.FromValues(data, 1, 1, len(ids), 2)
|
||||||
|
return k, v
|
||||||
|
}
|
||||||
|
|
||||||
|
// stateIDs returns the ids currently in the cache in slot order (logical
|
||||||
|
// after a concat, physical/rotated after a single-token update).
|
||||||
|
func stateIDs(t *testing.T, c *RotatingKVCache) []float32 {
|
||||||
|
t.Helper()
|
||||||
|
state := c.State()
|
||||||
|
if state == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mlx.Eval(state[0])
|
||||||
|
flat := state[0].Floats()
|
||||||
|
n := state[0].Dim(2)
|
||||||
|
out := make([]float32, n)
|
||||||
|
for i := range n {
|
||||||
|
out[i] = flat[i*2]
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func equalSlice(a, b []float32) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i] != b[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func feedMulti(c *RotatingKVCache, startID float32, n int) float32 {
|
||||||
|
ids := make([]float32, n)
|
||||||
|
for i := range ids {
|
||||||
|
ids[i] = startID + float32(i)
|
||||||
|
}
|
||||||
|
k, v := multiTokenKV(ids)
|
||||||
|
c.Update(k, v)
|
||||||
|
return startID + float32(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func feedSingle(c *RotatingKVCache, id float32) {
|
||||||
|
k, v := singleTokenKV(id)
|
||||||
|
c.Update(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatMidRotationPreservesContext: after the buffer
|
||||||
|
// has wrapped, a multi-token concat must keep the (maxSize-1) most recent
|
||||||
|
// pre-existing tokens in logical order so the first Q of the new batch
|
||||||
|
// has a full sliding window.
|
||||||
|
func TestRotatingKVCacheConcatMidRotationPreservesContext(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
nextID := feedMulti(c, 1, 3)
|
||||||
|
for range 6 {
|
||||||
|
feedSingle(c, nextID)
|
||||||
|
nextID++
|
||||||
|
}
|
||||||
|
if c.Offset() != 9 {
|
||||||
|
t.Fatalf("setup: offset=%d want 9", c.Offset())
|
||||||
|
}
|
||||||
|
if c.idx >= c.maxSize {
|
||||||
|
t.Fatalf("setup: expected mid-rotation idx (<%d), got %d", c.maxSize, c.idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, 10, 2)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{7, 8, 9, 10, 11}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("post-concat window=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
if c.Offset() != 11 {
|
||||||
|
t.Fatalf("offset=%d want 11", c.Offset())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatAlignedInvariant: with an aligned buffer
|
||||||
|
// (c.idx == Dim), an L>1 concat keeps the last (maxSize-1) pre-existing
|
||||||
|
// tokens plus the full new batch. This is the chunked-prefill contract
|
||||||
|
// x/mlxrunner/pipeline.go relies on.
|
||||||
|
func TestRotatingKVCacheConcatAlignedInvariant(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
// Chunk 1 fills past maxSize, leaving Dim == maxSize aligned.
|
||||||
|
feedMulti(c, 1, 6)
|
||||||
|
// Chunk 2: the buffer is intentionally oversized to (maxSize-1) + L
|
||||||
|
// so the first new Q has its full window in scope for this forward.
|
||||||
|
feedMulti(c, 7, 3)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{4, 5, 6, 7, 8, 9}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("post-chunk-2 buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The next decode trims oversize back to maxSize; order may be
|
||||||
|
// physical (rotated), so check as a set.
|
||||||
|
feedSingle(c, 10)
|
||||||
|
got = stateIDs(t, c)
|
||||||
|
if len(got) != window {
|
||||||
|
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
|
||||||
|
}
|
||||||
|
seen := map[float32]bool{}
|
||||||
|
for _, v := range got {
|
||||||
|
seen[v] = true
|
||||||
|
}
|
||||||
|
for _, w := range []float32{7, 8, 9, 10} {
|
||||||
|
if !seen[w] {
|
||||||
|
t.Fatalf("post-decode window missing %v (got %v)", w, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatAfterDecodeGrowsBuffer: update() grows the
|
||||||
|
// underlying buffer by `step` slots via mlx.Zeros before writing, so
|
||||||
|
// after one decode on a short prefill c.idx < Dim even though the cache
|
||||||
|
// has not wrapped. Those trailing slots are zero padding and must not
|
||||||
|
// be pulled back into the live window on the next concat.
|
||||||
|
func TestRotatingKVCacheConcatAfterDecodeGrowsBuffer(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 512
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
feedMulti(c, 1, 3)
|
||||||
|
feedSingle(c, 4)
|
||||||
|
feedMulti(c, 5, 3)
|
||||||
|
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{1, 2, 3, 4, 5, 6, 7}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("growing-buffer concat=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatAfterLiveRewind: x/mlxrunner/cache.go calls
|
||||||
|
// Restore(nil, target) between conversation turns to rewind the cache to
|
||||||
|
// the matched prefix. Restore moves c.offset/c.idx without trimming the
|
||||||
|
// underlying buffer, so slots [c.idx, Dim) still hold stale pre-rewind
|
||||||
|
// tokens. A subsequent concat must drop those, not treat them as wrapped
|
||||||
|
// window content.
|
||||||
|
func TestRotatingKVCacheConcatAfterLiveRewind(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 8
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
// Grow the buffer to exactly maxSize without wrapping.
|
||||||
|
feedMulti(c, 1, 2)
|
||||||
|
for id := float32(3); id <= 8; id++ {
|
||||||
|
feedSingle(c, id)
|
||||||
|
}
|
||||||
|
if c.Offset() != window {
|
||||||
|
t.Fatalf("setup: offset=%d want %d", c.Offset(), window)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.Restore(nil, 2) {
|
||||||
|
t.Fatalf("live rewind to 2 failed")
|
||||||
|
}
|
||||||
|
if c.Offset() != 2 {
|
||||||
|
t.Fatalf("post-rewind offset=%d want 2", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, 9, 3)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{1, 2, 9, 10, 11}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("post-rewind concat=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
if c.Offset() != 5 {
|
||||||
|
t.Fatalf("offset=%d want 5", c.Offset())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheConcatGrowingBuffer: when oldLen < maxSize the trim
|
||||||
|
// formula drops to non-positive and all pre-existing tokens are kept.
|
||||||
|
func TestRotatingKVCacheConcatGrowingBuffer(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
feedMulti(c, 1, 2)
|
||||||
|
feedMulti(c, 3, 2)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{1, 2, 3, 4}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("growing buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheRunnerChunkedPrefill mirrors the
|
||||||
|
// x/mlxrunner/pipeline.go prefill loop: a long prompt fed through
|
||||||
|
// repeated L>1 Update() calls on a single cache. Scaled-down proxy for
|
||||||
|
// the Gemma 4 26B case (sliding_window=1024, prefillChunkSize=2048).
|
||||||
|
func TestRotatingKVCacheRunnerChunkedPrefill(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
feedMulti(c, 1, 8)
|
||||||
|
if c.Offset() != 8 {
|
||||||
|
t.Fatalf("chunk 1: offset=%d want 8", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, 9, 8)
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("chunk 2: buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, 17, 4)
|
||||||
|
got = stateIDs(t, c)
|
||||||
|
want = []float32{14, 15, 16, 17, 18, 19, 20}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("chunk 3: buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode trims oversize back to maxSize; order may be physical.
|
||||||
|
feedSingle(c, 21)
|
||||||
|
got = stateIDs(t, c)
|
||||||
|
if len(got) != window {
|
||||||
|
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
|
||||||
|
}
|
||||||
|
seen := map[float32]bool{}
|
||||||
|
for _, v := range got {
|
||||||
|
seen[v] = true
|
||||||
|
}
|
||||||
|
for _, w := range []float32{18, 19, 20, 21} {
|
||||||
|
if !seen[w] {
|
||||||
|
t.Fatalf("post-decode window missing %v (got %v)", w, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheMultiTurnChatSimulation walks a prefill → decode →
|
||||||
|
// prefill sequence and checks that each new prefill retains the last
|
||||||
|
// (maxSize-1) pre-existing tokens in logical order.
|
||||||
|
func TestRotatingKVCacheMultiTurnChatSimulation(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
const window = 4
|
||||||
|
c := NewRotatingKVCache(window)
|
||||||
|
|
||||||
|
nextID := feedMulti(c, 1, 2)
|
||||||
|
for range 5 {
|
||||||
|
feedSingle(c, nextID)
|
||||||
|
nextID++
|
||||||
|
}
|
||||||
|
if c.Offset() != 7 {
|
||||||
|
t.Fatalf("turn 1: offset=%d want 7", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, nextID, 3)
|
||||||
|
nextID += 3
|
||||||
|
got := stateIDs(t, c)
|
||||||
|
want := []float32{5, 6, 7, 8, 9, 10}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("turn 2 prefill buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
for range 4 {
|
||||||
|
feedSingle(c, nextID)
|
||||||
|
nextID++
|
||||||
|
}
|
||||||
|
if c.Offset() != 14 {
|
||||||
|
t.Fatalf("turn 2 decode: offset=%d want 14", c.Offset())
|
||||||
|
}
|
||||||
|
|
||||||
|
feedMulti(c, nextID, 2)
|
||||||
|
got = stateIDs(t, c)
|
||||||
|
want = []float32{12, 13, 14, 15, 16}
|
||||||
|
if !equalSlice(got, want) {
|
||||||
|
t.Fatalf("turn 3 prefill buffer=%v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRotatingKVCacheOffsetTracking: Offset() is the monotonic logical
|
||||||
|
// token count through any mix of Update() calls — Gemma 4 uses
|
||||||
|
// donorEntry.Offset - L for the consumer's RoPE offset.
|
||||||
|
func TestRotatingKVCacheOffsetTracking(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
c := NewRotatingKVCache(4)
|
||||||
|
nextID := feedMulti(c, 1, 3)
|
||||||
|
if c.Offset() != 3 {
|
||||||
|
t.Fatalf("after prefill 3: offset=%d want 3", c.Offset())
|
||||||
|
}
|
||||||
|
for i := range 5 {
|
||||||
|
feedSingle(c, nextID)
|
||||||
|
nextID++
|
||||||
|
if c.Offset() != 3+i+1 {
|
||||||
|
t.Fatalf("after decode %d: offset=%d want %d", i, c.Offset(), 3+i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nextID = feedMulti(c, nextID, 2)
|
||||||
|
if c.Offset() != 10 {
|
||||||
|
t.Fatalf("after turn-2 prefill: offset=%d want 10", c.Offset())
|
||||||
|
}
|
||||||
|
// L > maxSize concat.
|
||||||
|
feedMulti(c, nextID, 7)
|
||||||
|
if c.Offset() != 17 {
|
||||||
|
t.Fatalf("after large prefill: offset=%d want 17", c.Offset())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -151,20 +151,11 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
|
type CompletionRequest struct {
|
||||||
type completionRequest struct {
|
Prompt string
|
||||||
Prompt string `json:"prompt"`
|
Options api.Options
|
||||||
Options *completionOpts `json:"options,omitempty"`
|
Logprobs bool
|
||||||
}
|
TopLogprobs int
|
||||||
|
|
||||||
type completionOpts struct {
|
|
||||||
Temperature float32 `json:"temperature,omitempty"`
|
|
||||||
TopP float32 `json:"top_p,omitempty"`
|
|
||||||
MinP float32 `json:"min_p,omitempty"`
|
|
||||||
TopK int `json:"top_k,omitempty"`
|
|
||||||
RepeatLastN int `json:"repeat_last_n,omitempty"`
|
|
||||||
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
|
||||||
NumPredict int `json:"num_predict,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompletionResponse struct {
|
type CompletionResponse struct {
|
||||||
@@ -177,6 +168,8 @@ type CompletionResponse struct {
|
|||||||
EvalCount int
|
EvalCount int
|
||||||
EvalDuration time.Duration
|
EvalDuration time.Duration
|
||||||
|
|
||||||
|
Logprobs []llm.Logprob
|
||||||
|
|
||||||
Error *api.StatusError
|
Error *api.StatusError
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,19 +194,13 @@ func (c *Client) Close() error {
|
|||||||
|
|
||||||
// Completion implements llm.LlamaServer.
|
// Completion implements llm.LlamaServer.
|
||||||
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||||
creq := completionRequest{
|
creq := CompletionRequest{
|
||||||
Prompt: req.Prompt,
|
Prompt: req.Prompt,
|
||||||
|
Logprobs: req.Logprobs,
|
||||||
|
TopLogprobs: req.TopLogprobs,
|
||||||
}
|
}
|
||||||
if req.Options != nil {
|
if req.Options != nil {
|
||||||
creq.Options = &completionOpts{
|
creq.Options = *req.Options
|
||||||
Temperature: req.Options.Temperature,
|
|
||||||
TopP: req.Options.TopP,
|
|
||||||
MinP: req.Options.MinP,
|
|
||||||
TopK: req.Options.TopK,
|
|
||||||
RepeatLastN: req.Options.RepeatLastN,
|
|
||||||
PresencePenalty: req.Options.PresencePenalty,
|
|
||||||
NumPredict: req.Options.NumPredict,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := json.Marshal(creq)
|
body, err := json.Marshal(creq)
|
||||||
@@ -262,6 +249,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||||||
PromptEvalDuration: raw.PromptEvalDuration,
|
PromptEvalDuration: raw.PromptEvalDuration,
|
||||||
EvalCount: raw.EvalCount,
|
EvalCount: raw.EvalCount,
|
||||||
EvalDuration: raw.EvalDuration,
|
EvalDuration: raw.EvalDuration,
|
||||||
|
Logprobs: raw.Logprobs,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(cresp)
|
fn(cresp)
|
||||||
|
|||||||
@@ -1,62 +1,86 @@
|
|||||||
package mlx
|
package mlx
|
||||||
|
|
||||||
// #include "generated.h"
|
|
||||||
import "C"
|
|
||||||
import "math"
|
import "math"
|
||||||
|
|
||||||
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
|
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
|
||||||
|
|
||||||
// GELUApprox matches mlx.nn.gelu_approx:
|
// GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||||
//
|
// as a fused kernel.
|
||||||
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
var GELUApprox = Compile1(
|
||||||
func GELUApprox(x *Array) *Array {
|
"GELUApprox",
|
||||||
// Use dtype-matched scalars to avoid implicit upcasts on bf16 inputs.
|
func(x *Array) *Array {
|
||||||
half := scalarWithDtype(0.5, x)
|
// Dtype-matched scalars avoid implicit upcasts on bf16 inputs.
|
||||||
defer C.mlx_array_free(half)
|
dt := x.DType()
|
||||||
coeff := scalarWithDtype(geluCoeff, x)
|
half := FromValue[float32](0.5).AsType(dt)
|
||||||
defer C.mlx_array_free(coeff)
|
coeff := FromValue(geluCoeff).AsType(dt)
|
||||||
c := scalarWithDtype(0.044715, x)
|
c := FromValue[float32](0.044715).AsType(dt)
|
||||||
defer C.mlx_array_free(c)
|
one := FromValue[float32](1.0).AsType(dt)
|
||||||
|
|
||||||
// x^3 via x*x*x (avoids general Power which is slower)
|
// x^3 via x*x*x (avoids general Power which is slower).
|
||||||
x3 := New("GELU_X3")
|
x3 := x.Multiply(x).Multiply(x)
|
||||||
C.mlx_multiply(&x3.ctx, x.ctx, x.ctx, DefaultStream().ctx)
|
inner := x.Add(c.Multiply(x3))
|
||||||
tmp := New("GELU_X3b")
|
tanh := coeff.Multiply(inner).Tanh()
|
||||||
C.mlx_multiply(&tmp.ctx, x3.ctx, x.ctx, DefaultStream().ctx)
|
return half.Multiply(x).Multiply(one.Add(tanh))
|
||||||
x3 = tmp
|
},
|
||||||
|
Shapeless(),
|
||||||
|
)
|
||||||
|
|
||||||
// 0.044715 * x^3
|
// SiLU returns a * sigmoid(a) as a fused kernel.
|
||||||
cx3 := New("GELU_CX3")
|
var SiLU = Compile1(
|
||||||
C.mlx_multiply(&cx3.ctx, c, x3.ctx, DefaultStream().ctx)
|
"SiLU",
|
||||||
|
func(a *Array) *Array {
|
||||||
|
return a.Multiply(a.Sigmoid())
|
||||||
|
},
|
||||||
|
Shapeless(),
|
||||||
|
)
|
||||||
|
|
||||||
// x + 0.044715 * x^3
|
// SwiGLU returns silu(gate) * up as a fused kernel.
|
||||||
inner := New("GELU_INNER")
|
var SwiGLU = Compile2(
|
||||||
C.mlx_add(&inner.ctx, x.ctx, cx3.ctx, DefaultStream().ctx)
|
"SwiGLU",
|
||||||
|
func(gate, up *Array) *Array {
|
||||||
|
return SiLU(gate).Multiply(up)
|
||||||
|
},
|
||||||
|
Shapeless(),
|
||||||
|
)
|
||||||
|
|
||||||
// sqrt(2/pi) * (x + 0.044715 * x^3)
|
// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
|
||||||
scaled := New("GELU_SCALED")
|
// geglu, used by Gemma-family MLP and MoE paths.
|
||||||
C.mlx_multiply(&scaled.ctx, coeff, inner.ctx, DefaultStream().ctx)
|
var GeGLU = Compile2(
|
||||||
|
"GeGLU",
|
||||||
|
func(gate, up *Array) *Array {
|
||||||
|
return GELUApprox(gate).Multiply(up)
|
||||||
|
},
|
||||||
|
Shapeless(),
|
||||||
|
)
|
||||||
|
|
||||||
// tanh(...)
|
// LogitSoftcap returns tanh(x / cap) * cap as a fused kernel. Matches
|
||||||
th := New("GELU_TANH")
|
// mlx_lm's logit_softcap. cap must have the same dtype as x.
|
||||||
C.mlx_tanh(&th.ctx, scaled.ctx, DefaultStream().ctx)
|
var LogitSoftcap = Compile2(
|
||||||
|
"LogitSoftcap",
|
||||||
|
func(x, cap *Array) *Array {
|
||||||
|
return x.Divide(cap).Tanh().Multiply(cap)
|
||||||
|
},
|
||||||
|
Shapeless(),
|
||||||
|
)
|
||||||
|
|
||||||
// 1 + tanh(...)
|
// sigmoidRouterFused traces the DeepSeek-V2 / GLM-MoE aux-loss-free router
|
||||||
one := scalarWithDtype(1.0, x)
|
// head. Two outputs are returned so the pre-bias sigmoid (used to gather
|
||||||
defer C.mlx_array_free(one)
|
// per-expert scores after top-k) and the post-bias negation (used as the
|
||||||
onePlusTanh := New("GELU_1PT")
|
// argpartition key for top-k) share a single kernel.
|
||||||
C.mlx_add(&onePlusTanh.ctx, one, th.ctx, DefaultStream().ctx)
|
var sigmoidRouterFused = Compile(
|
||||||
|
"SigmoidRouter",
|
||||||
|
func(in ...*Array) []*Array {
|
||||||
|
gates, bias := in[0], in[1]
|
||||||
|
orig := gates.Sigmoid()
|
||||||
|
neg := orig.Add(bias).Negative()
|
||||||
|
return []*Array{orig, neg}
|
||||||
|
},
|
||||||
|
Shapeless(),
|
||||||
|
)
|
||||||
|
|
||||||
// 0.5 * x
|
// SigmoidRouter returns (sigmoid(gates), -(sigmoid(gates)+bias)) as a fused
|
||||||
halfX := New("GELU_HALFX")
|
// kernel — the DeepSeek-V2 / GLM-MoE aux-loss-free router head.
|
||||||
C.mlx_multiply(&halfX.ctx, half, x.ctx, DefaultStream().ctx)
|
func SigmoidRouter(gates, bias *Array) (origScores, negScores *Array) {
|
||||||
|
out := sigmoidRouterFused(gates, bias)
|
||||||
// 0.5 * x * (1 + tanh(...))
|
return out[0], out[1]
|
||||||
out := New("GELU_APPROX")
|
|
||||||
C.mlx_multiply(&out.ctx, halfX.ctx, onePlusTanh.ctx, DefaultStream().ctx)
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func SILU(t *Array) *Array {
|
|
||||||
return t.Multiply(t.Sigmoid()).AsType(t.DType())
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,11 @@ var arrays []*Array
|
|||||||
|
|
||||||
func New(name string) *Array {
|
func New(name string) *Array {
|
||||||
t := &Array{name: name}
|
t := &Array{name: name}
|
||||||
arrays = append(arrays, t)
|
if tracing {
|
||||||
|
traceScratch = append(traceScratch, t)
|
||||||
|
} else {
|
||||||
|
arrays = append(arrays, t)
|
||||||
|
}
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -234,6 +238,9 @@ func (t Array) Float() float64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t Array) Ints() []int {
|
func (t Array) Ints() []int {
|
||||||
|
if dt := t.DType(); dt != DTypeInt32 {
|
||||||
|
panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt))
|
||||||
|
}
|
||||||
ints := make([]int, t.Size())
|
ints := make([]int, t.Size())
|
||||||
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
|
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
|
||||||
ints[i] = int(f)
|
ints[i] = int(f)
|
||||||
@@ -242,6 +249,9 @@ func (t Array) Ints() []int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t Array) Floats() []float32 {
|
func (t Array) Floats() []float32 {
|
||||||
|
if dt := t.DType(); dt != DTypeFloat32 {
|
||||||
|
panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt))
|
||||||
|
}
|
||||||
floats := make([]float32, t.Size())
|
floats := make([]float32, t.Size())
|
||||||
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
|
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
|
||||||
floats[i] = float32(f)
|
floats[i] = float32(f)
|
||||||
|
|||||||
192
x/mlxrunner/mlx/compile.go
Normal file
192
x/mlxrunner/mlx/compile.go
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
package mlx
|
||||||
|
|
||||||
|
// #include <stdlib.h>
|
||||||
|
// #include "generated.h"
|
||||||
|
//
|
||||||
|
// extern int closureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
|
||||||
|
// extern void closureDestructor(void* payload);
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"runtime/cgo"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CompileFunc is the signature of a function that can be compiled.
|
||||||
|
type CompileFunc func(inputs ...*Array) []*Array
|
||||||
|
|
||||||
|
// CompileOption configures Compile behavior.
|
||||||
|
type CompileOption func(*compileConfig)
|
||||||
|
|
||||||
|
type compileConfig struct {
|
||||||
|
shapeless bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shapeless traces the function once against symbolic shapes so the compiled
|
||||||
|
// graph accepts any input shape afterwards. Without this option, MLX re-traces
|
||||||
|
// on each new (shape, dtype) combination and caches each specialization.
|
||||||
|
func Shapeless() CompileOption {
|
||||||
|
return func(c *compileConfig) { c.shapeless = true }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile returns a compiled version of fn. When called during another
|
||||||
|
// compile's trace, fn is inlined directly so outer compiles can fuse through
|
||||||
|
// inner ones.
|
||||||
|
//
|
||||||
|
// Compiled functions must not have side effects outside of the function. Do
|
||||||
|
// not access data other than the arguments passed in (either Go data or MLX
|
||||||
|
// arrays) unless it is a constant.
|
||||||
|
func Compile(name string, fn CompileFunc, opts ...CompileOption) CompileFunc {
|
||||||
|
var cfg compileConfig
|
||||||
|
for _, o := range opts {
|
||||||
|
o(&cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
var closure C.mlx_closure
|
||||||
|
var once sync.Once
|
||||||
|
|
||||||
|
return func(inputs ...*Array) []*Array {
|
||||||
|
if tracing {
|
||||||
|
return fn(inputs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
once.Do(func() {
|
||||||
|
payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
|
||||||
|
*payload = cgo.NewHandle(fn)
|
||||||
|
src := C.mlx_closure_new_func_payload(
|
||||||
|
(*[0]byte)(C.closureCallback),
|
||||||
|
unsafe.Pointer(payload),
|
||||||
|
(*[0]byte)(C.closureDestructor),
|
||||||
|
)
|
||||||
|
defer C.mlx_closure_free(src)
|
||||||
|
|
||||||
|
closure = C.mlx_closure_new()
|
||||||
|
mlxCheck(name+": compile failed", func() C.int {
|
||||||
|
return C.mlx_compile(&closure, src, C.bool(cfg.shapeless))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
inVec := C.mlx_vector_array_new()
|
||||||
|
defer C.mlx_vector_array_free(inVec)
|
||||||
|
for _, in := range inputs {
|
||||||
|
C.mlx_vector_array_append_value(inVec, in.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
outVec := C.mlx_vector_array_new()
|
||||||
|
defer C.mlx_vector_array_free(outVec)
|
||||||
|
mlxCheck(name+": closure apply failed", func() C.int {
|
||||||
|
return C.mlx_closure_apply(&outVec, closure, inVec)
|
||||||
|
})
|
||||||
|
|
||||||
|
n := int(C.mlx_vector_array_size(outVec))
|
||||||
|
outputs := make([]*Array, n)
|
||||||
|
for i := range n {
|
||||||
|
outputs[i] = New(name)
|
||||||
|
C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile1 compiles a unary function. See Compile.
|
||||||
|
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
|
||||||
|
cf := Compile(name, func(in ...*Array) []*Array {
|
||||||
|
return []*Array{fn(in[0])}
|
||||||
|
}, opts...)
|
||||||
|
return func(a *Array) *Array {
|
||||||
|
return cf(a)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile2 compiles a binary function. See Compile.
|
||||||
|
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
|
||||||
|
cf := Compile(name, func(in ...*Array) []*Array {
|
||||||
|
return []*Array{fn(in[0], in[1])}
|
||||||
|
}, opts...)
|
||||||
|
return func(a, b *Array) *Array {
|
||||||
|
return cf(a, b)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile3 compiles a ternary function. See Compile.
|
||||||
|
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
|
||||||
|
cf := Compile(name, func(in ...*Array) []*Array {
|
||||||
|
return []*Array{fn(in[0], in[1], in[2])}
|
||||||
|
}, opts...)
|
||||||
|
return func(a, b, c *Array) *Array {
|
||||||
|
return cf(a, b, c)[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tracing is true while a compile callback is running. Since MLX is
|
||||||
|
// single-threaded at this level a plain Go bool suffices.
|
||||||
|
var tracing bool
|
||||||
|
|
||||||
|
// traceScratch collects arrays created during a compile trace so they can be
|
||||||
|
// freed as a group when the callback returns.
|
||||||
|
var traceScratch []*Array
|
||||||
|
|
||||||
|
//export closureCallback
|
||||||
|
func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
slog.Error("mlx closure callback panicked", "panic", r)
|
||||||
|
rc = 1
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
handle := *(*cgo.Handle)(payload)
|
||||||
|
fn := handle.Value().(CompileFunc)
|
||||||
|
|
||||||
|
// When tracing, we track all of the intermediates that are created and free them separately at the end of
|
||||||
|
// the process. This will give the effect of a single op - inputs are owned by the original caller (via
|
||||||
|
// the MLX layer) and outputs are transferred back to MLX to create a new Go side tensor.
|
||||||
|
if tracing {
|
||||||
|
panic("mlx: nested compile trace")
|
||||||
|
}
|
||||||
|
tracing = true
|
||||||
|
traceScratch = nil
|
||||||
|
defer func() {
|
||||||
|
for _, a := range traceScratch {
|
||||||
|
if a.pinned > 0 {
|
||||||
|
panic("mlx: traced array was pinned during compilation")
|
||||||
|
}
|
||||||
|
if a.Valid() {
|
||||||
|
C.mlx_array_free(a.ctx)
|
||||||
|
a.ctx.ctx = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing = false
|
||||||
|
traceScratch = nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
n := int(C.mlx_vector_array_size(input))
|
||||||
|
inputs := make([]*Array, n)
|
||||||
|
for i := range n {
|
||||||
|
a := New("")
|
||||||
|
C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
|
||||||
|
inputs[i] = a
|
||||||
|
}
|
||||||
|
|
||||||
|
outputs := fn(inputs...)
|
||||||
|
|
||||||
|
var arrPtr *C.mlx_array
|
||||||
|
if len(outputs) > 0 {
|
||||||
|
handles := make([]C.mlx_array, len(outputs))
|
||||||
|
for i, out := range outputs {
|
||||||
|
handles[i] = out.ctx
|
||||||
|
}
|
||||||
|
arrPtr = &handles[0]
|
||||||
|
}
|
||||||
|
C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
//export closureDestructor
|
||||||
|
func closureDestructor(payload unsafe.Pointer) {
|
||||||
|
handle := *(*cgo.Handle)(payload)
|
||||||
|
handle.Delete()
|
||||||
|
C.free(payload)
|
||||||
|
}
|
||||||
147
x/mlxrunner/mlx/compile_test.go
Normal file
147
x/mlxrunner/mlx/compile_test.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package mlx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCompileFusion(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
// Compile fuses the ops inside a function body into a single kernel,
|
||||||
|
// eliminating intermediate buffers. Use a diamond-shaped graph where
|
||||||
|
// two branches must be materialized simultaneously without fusion,
|
||||||
|
// then compare peak memory against the compiled version which fuses
|
||||||
|
// everything into one kernel with no intermediates.
|
||||||
|
const n = 1024 * 1024 // 4MB per float32 array
|
||||||
|
data := make([]float32, n)
|
||||||
|
for i := range data {
|
||||||
|
data[i] = float32(i + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Diamond: both a*b and a+b must be live for the final multiply.
|
||||||
|
// Without fusion: peak includes both intermediates (~8MB extra).
|
||||||
|
// With fusion: single kernel, no intermediates.
|
||||||
|
body := func(a, b *Array) *Array {
|
||||||
|
return a.Multiply(b).Multiply(a.Add(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
a := FromValues(data, n)
|
||||||
|
b := FromValues(data, n)
|
||||||
|
Pin(a, b)
|
||||||
|
defer Unpin(a, b)
|
||||||
|
|
||||||
|
// Compiled: ops fused into a single kernel.
|
||||||
|
EnableCompile()
|
||||||
|
fn := Compile2("diamond", body, Shapeless())
|
||||||
|
warm := fn(a, b)
|
||||||
|
Eval(warm)
|
||||||
|
Sweep()
|
||||||
|
ClearCache()
|
||||||
|
ResetPeakMemory()
|
||||||
|
y := fn(a, b)
|
||||||
|
Eval(y)
|
||||||
|
compiledPeak := PeakMemory()
|
||||||
|
Sweep()
|
||||||
|
|
||||||
|
// Uncompiled: ops evaluated individually, intermediates materialized.
|
||||||
|
ClearCache()
|
||||||
|
ResetPeakMemory()
|
||||||
|
z := body(a, b)
|
||||||
|
Eval(z)
|
||||||
|
uncompiledPeak := PeakMemory()
|
||||||
|
Sweep()
|
||||||
|
|
||||||
|
if compiledPeak == 0 && uncompiledPeak == 0 {
|
||||||
|
t.Skip("peak memory tracking not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||||
|
|
||||||
|
if compiledPeak >= uncompiledPeak {
|
||||||
|
t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileNested(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
// A compiled function that calls another compiled function should
|
||||||
|
// produce correct results. The inner function inlines via isTracing()
|
||||||
|
// during the outer's trace.
|
||||||
|
inner := Compile1("silu", func(a *Array) *Array {
|
||||||
|
return a.Multiply(a.Sigmoid())
|
||||||
|
}, Shapeless())
|
||||||
|
|
||||||
|
outer := Compile2("swiglu", func(gate, up *Array) *Array {
|
||||||
|
return inner(gate).Multiply(up)
|
||||||
|
}, Shapeless())
|
||||||
|
|
||||||
|
gate := FromValues([]float32{0, 1, 2}, 3)
|
||||||
|
up := FromValues([]float32{1, 1, 1}, 3)
|
||||||
|
Pin(gate, up)
|
||||||
|
defer Unpin(gate, up)
|
||||||
|
|
||||||
|
y := outer(gate, up)
|
||||||
|
Eval(y)
|
||||||
|
|
||||||
|
// silu(x) = x * sigmoid(x); for x=0 → 0, x=1 → ~0.7311, x=2 → ~1.7616
|
||||||
|
got := y.Floats()
|
||||||
|
want := []float32{0, 0.7310586, 1.7615942}
|
||||||
|
for i, v := range got {
|
||||||
|
if v-want[i] > 1e-4 || want[i]-v > 1e-4 {
|
||||||
|
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileCallbackPanicRecovers(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
boom := Compile1("boom", func(a *Array) *Array {
|
||||||
|
panic("intentional test panic")
|
||||||
|
})
|
||||||
|
|
||||||
|
x := FromValues([]float32{1}, 1)
|
||||||
|
Pin(x)
|
||||||
|
defer Unpin(x)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if r == nil {
|
||||||
|
t.Fatal("expected panic from Call, got none")
|
||||||
|
}
|
||||||
|
if _, ok := r.(string); !ok {
|
||||||
|
t.Fatalf("expected string panic, got %T: %v", r, r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
boom(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileNoTrackingGrowth(t *testing.T) {
|
||||||
|
skipIfNoMLX(t)
|
||||||
|
|
||||||
|
// Repeated invocations of a compiled kernel should not grow the
|
||||||
|
// tracked-arrays list — the callback's traceScratch collects
|
||||||
|
// intermediates during tracing and frees them when the callback returns.
|
||||||
|
fn := Compile2("mul_add", func(a, b *Array) *Array {
|
||||||
|
return a.Multiply(b).Add(b)
|
||||||
|
})
|
||||||
|
|
||||||
|
a := FromValues([]float32{1, 2}, 2)
|
||||||
|
b := FromValues([]float32{3, 4}, 2)
|
||||||
|
Pin(a, b)
|
||||||
|
defer Unpin(a, b)
|
||||||
|
|
||||||
|
Sweep()
|
||||||
|
before := len(arrays)
|
||||||
|
|
||||||
|
for range 100 {
|
||||||
|
_ = fn(a, b)
|
||||||
|
Sweep()
|
||||||
|
}
|
||||||
|
|
||||||
|
after := len(arrays)
|
||||||
|
if after > before+2 {
|
||||||
|
t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", before, after)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,8 +9,8 @@ package mlx
|
|||||||
// #include "generated.h"
|
// #include "generated.h"
|
||||||
// #include <string.h>
|
// #include <string.h>
|
||||||
//
|
//
|
||||||
// static char _mlx_last_error_msg[1024] = {0};
|
// static __thread char _mlx_last_error_msg[1024] = {0};
|
||||||
// static int _mlx_last_error_flag = 0;
|
// static __thread int _mlx_last_error_flag = 0;
|
||||||
//
|
//
|
||||||
// static void _mlx_capture_error_handler(const char* msg, void* data) {
|
// static void _mlx_capture_error_handler(const char* msg, void* data) {
|
||||||
// (void)data;
|
// (void)data;
|
||||||
@@ -30,15 +30,13 @@ package mlx
|
|||||||
// _mlx_last_error_msg[0] = '\0';
|
// _mlx_last_error_msg[0] = '\0';
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
// static int mlx_had_last_error(void) {
|
|
||||||
// return _mlx_last_error_flag;
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// static const char* mlx_get_last_error(void) {
|
// static const char* mlx_get_last_error(void) {
|
||||||
// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL;
|
// return _mlx_last_error_flag ? _mlx_last_error_msg : "";
|
||||||
// }
|
// }
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
|
import "runtime"
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// Replace the default exit(-1) error handler with one that captures
|
// Replace the default exit(-1) error handler with one that captures
|
||||||
// the error message so we can surface it in Go.
|
// the error message so we can surface it in Go.
|
||||||
@@ -53,6 +51,24 @@ func Version() string {
|
|||||||
return C.GoString(C.mlx_string_data(str))
|
return C.GoString(C.mlx_string_data(str))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mlxCheck locks the goroutine to its OS thread, clears the captured error
|
||||||
|
// state, calls fn, and panics with the captured message if fn returns non-zero.
|
||||||
|
// The thread lock ensures the thread-local error state is read from the same
|
||||||
|
// thread that executed the call.
|
||||||
|
func mlxCheck(fallback string, fn func() C.int) {
|
||||||
|
runtime.LockOSThread()
|
||||||
|
defer runtime.UnlockOSThread()
|
||||||
|
|
||||||
|
C.mlx_clear_last_error()
|
||||||
|
if fn() != 0 {
|
||||||
|
msg := C.GoString(C.mlx_get_last_error())
|
||||||
|
if msg == "" {
|
||||||
|
msg = fallback
|
||||||
|
}
|
||||||
|
panic("mlx: " + msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func doEval(outputs []*Array, async bool) {
|
func doEval(outputs []*Array, async bool) {
|
||||||
if len(outputs) == 0 {
|
if len(outputs) == 0 {
|
||||||
return
|
return
|
||||||
@@ -67,20 +83,12 @@ func doEval(outputs []*Array, async bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
C.mlx_clear_last_error()
|
mlxCheck("eval failed", func() C.int {
|
||||||
var rc C.int
|
if async {
|
||||||
if async {
|
return C.mlx_async_eval(vector)
|
||||||
rc = C.mlx_async_eval(vector)
|
|
||||||
} else {
|
|
||||||
rc = C.mlx_eval(vector)
|
|
||||||
}
|
|
||||||
if rc != 0 {
|
|
||||||
msg := "mlx eval failed"
|
|
||||||
if C.mlx_had_last_error() != 0 {
|
|
||||||
msg = C.GoString(C.mlx_get_last_error())
|
|
||||||
}
|
}
|
||||||
panic("mlx: " + msg)
|
return C.mlx_eval(vector)
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func AsyncEval(outputs ...*Array) {
|
func AsyncEval(outputs ...*Array) {
|
||||||
|
|||||||
@@ -139,6 +139,12 @@ func (t *Array) Less(other *Array) *Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
|
||||||
|
out := New("MAX_AXIS")
|
||||||
|
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Array) Matmul(other *Array) *Array {
|
func (t *Array) Matmul(other *Array) *Array {
|
||||||
out := New("MATMUL")
|
out := New("MATMUL")
|
||||||
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||||
@@ -169,6 +175,12 @@ func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Array) ScatterAddAxis(indices, values *Array, axis int) *Array {
|
||||||
|
out := New("SCATTER_ADD_AXIS")
|
||||||
|
C.mlx_scatter_add_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Array) Reshape(axes ...int) *Array {
|
func (t *Array) Reshape(axes ...int) *Array {
|
||||||
cAxes := make([]C.int, len(axes))
|
cAxes := make([]C.int, len(axes))
|
||||||
for i := range axes {
|
for i := range axes {
|
||||||
|
|||||||
@@ -404,11 +404,6 @@ func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *A
|
|||||||
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
|
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SiLU(a *Array) *Array {
|
|
||||||
sig := a.Sigmoid()
|
|
||||||
return a.Multiply(sig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
|
||||||
return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
|
return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,11 +7,15 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
|
||||||
|
"github.com/ollama/ollama/x/tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
func prefillChunkSize() int {
|
func prefillChunkSize() int {
|
||||||
@@ -23,28 +27,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
return errors.New("model not loaded")
|
return errors.New("model not loaded")
|
||||||
}
|
}
|
||||||
|
|
||||||
enableCompile := true
|
|
||||||
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
|
|
||||||
enableCompile = modelCompile.EnableCompile()
|
|
||||||
}
|
|
||||||
if enableCompile {
|
|
||||||
mlx.EnableCompile()
|
|
||||||
} else {
|
|
||||||
mlx.DisableCompile()
|
|
||||||
}
|
|
||||||
mlx.ResetPeakMemory()
|
mlx.ResetPeakMemory()
|
||||||
ctx := request.Ctx
|
ctx := request.Ctx
|
||||||
var (
|
var sample, nextSample sampler.Result
|
||||||
sample, logprobs *mlx.Array
|
|
||||||
nextSample, nextLogprobs *mlx.Array
|
|
||||||
)
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if request.Sampler != nil {
|
if request.Sampler != nil {
|
||||||
request.Sampler.Free()
|
request.Sampler.Free()
|
||||||
}
|
}
|
||||||
mlx.Unpin(sample, logprobs)
|
mlx.Unpin(sample.Arrays()...)
|
||||||
mlx.Unpin(nextSample, nextLogprobs)
|
mlx.Unpin(nextSample.Arrays()...)
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
|
|
||||||
@@ -69,10 +61,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
|
|
||||||
// Cap generation to stay within the model's context length
|
// Cap generation to stay within the model's context length
|
||||||
maxGenerate := r.contextLength - len(inputs)
|
maxGenerate := r.contextLength - len(inputs)
|
||||||
if request.Options.MaxTokens <= 0 {
|
if request.Options.NumPredict <= 0 {
|
||||||
request.Options.MaxTokens = maxGenerate
|
request.Options.NumPredict = maxGenerate
|
||||||
} else {
|
} else {
|
||||||
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
|
request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate)
|
||||||
}
|
}
|
||||||
|
|
||||||
request.Sampler.ResetHistory(inputs)
|
request.Sampler.ResetHistory(inputs)
|
||||||
@@ -144,41 +136,38 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
step := func(token *mlx.Array) sampler.Result {
|
||||||
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
||||||
logits := r.Model.Unembed(fwd)
|
logits := r.Model.Unembed(fwd)
|
||||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||||
|
|
||||||
logprobs := logits.Subtract(logits.Logsumexp(true))
|
sample := request.Sampler.Sample(logits)
|
||||||
sample := request.Sampler.Sample(logprobs)
|
mlx.Pin(sample.Arrays()...)
|
||||||
|
|
||||||
mlx.Pin(sample, logprobs)
|
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
mlx.AsyncEval(sample, logprobs)
|
mlx.AsyncEval(sample.Arrays()...)
|
||||||
|
return sample
|
||||||
return sample, logprobs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
|
sample = step(mlx.FromValues(tokens[processed:], total-processed))
|
||||||
|
|
||||||
var b bytes.Buffer
|
dec := decoder{tokenizer: r.Tokenizer}
|
||||||
|
|
||||||
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1}
|
||||||
for i := range request.Options.MaxTokens {
|
for i := range request.Options.NumPredict {
|
||||||
if err := ctx.Err(); err != nil {
|
if err := ctx.Err(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
request.Sampler.AppendToken(sample)
|
request.Sampler.AppendToken(sample.Token)
|
||||||
nextSample, nextLogprobs = step(sample)
|
nextSample = step(sample.Token)
|
||||||
|
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
mlx.Eval(sample)
|
mlx.Eval(sample.Arrays()...)
|
||||||
final.PromptEvalDuration = time.Since(now)
|
final.PromptEvalDuration = time.Since(now)
|
||||||
now = time.Now()
|
now = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
output := int32(sample.Int())
|
output := int32(sample.Token.Int())
|
||||||
session.outputs = append(session.outputs, output)
|
session.outputs = append(session.outputs, output)
|
||||||
|
|
||||||
if r.Tokenizer.IsEOS(output) {
|
if r.Tokenizer.IsEOS(output) {
|
||||||
@@ -187,17 +176,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
if resp, ok := dec.decode(sample); ok {
|
||||||
case <-ctx.Done():
|
select {
|
||||||
return ctx.Err()
|
case <-ctx.Done():
|
||||||
case request.Responses <- CompletionResponse{
|
return ctx.Err()
|
||||||
Content: r.Decode(output, &b),
|
case request.Responses <- resp:
|
||||||
}:
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mlx.Unpin(sample, logprobs)
|
mlx.Unpin(sample.Arrays()...)
|
||||||
sample, logprobs = nextSample, nextLogprobs
|
sample, nextSample = nextSample, sampler.Result{}
|
||||||
nextSample, nextLogprobs = nil, nil
|
|
||||||
|
|
||||||
if i%256 == 0 {
|
if i%256 == 0 {
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
@@ -213,13 +201,57 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
|
// decoder serializes sampled tokens into response chunks, holding bytes
|
||||||
token := r.Tokenizer.Decode([]int32{sample})
|
// whose UTF-8 sequence hasn't completed yet and the logprobs that belong
|
||||||
|
// with those bytes so Content and Logprobs stay aligned when a chunk does
|
||||||
|
// flush.
|
||||||
|
type decoder struct {
|
||||||
|
tokenizer *tokenizer.Tokenizer
|
||||||
|
buf bytes.Buffer
|
||||||
|
logprobs []llm.Logprob
|
||||||
|
}
|
||||||
|
|
||||||
if _, err := b.WriteString(token); err != nil {
|
func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
|
||||||
slog.Error("Failed to write token to buffer", "error", err)
|
output := int32(res.Token.Int())
|
||||||
return ""
|
d.buf.WriteString(d.tokenizer.Decode([]int32{output}))
|
||||||
|
d.logprobs = append(d.logprobs, buildLogprob(res, d.tokenizer.Decode)...)
|
||||||
|
|
||||||
|
content := flushValidUTF8Prefix(&d.buf)
|
||||||
|
if content == "" {
|
||||||
|
return CompletionResponse{}, false
|
||||||
|
}
|
||||||
|
resp := CompletionResponse{Content: content, Logprobs: d.logprobs}
|
||||||
|
d.logprobs = nil
|
||||||
|
return resp, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logprob {
|
||||||
|
if sample.Logprob == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tok := func(id int32) string { return decode([]int32{id}) }
|
||||||
|
|
||||||
|
out := llm.Logprob{
|
||||||
|
TokenLogprob: llm.TokenLogprob{
|
||||||
|
Token: tok(int32(sample.Token.Int())),
|
||||||
|
Logprob: float64(sample.Logprob.Floats()[0]),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return flushValidUTF8Prefix(b)
|
if sample.TopTokens != nil {
|
||||||
|
ids := sample.TopTokens.Ints()
|
||||||
|
vals := sample.TopLogprobs.Floats()
|
||||||
|
pairs := make([]llm.TokenLogprob, len(ids))
|
||||||
|
for i, id := range ids {
|
||||||
|
pairs[i] = llm.TokenLogprob{
|
||||||
|
Token: tok(int32(id)),
|
||||||
|
Logprob: float64(vals[i]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Slice(pairs, func(i, j int) bool {
|
||||||
|
return pairs[i].Logprob > pairs[j].Logprob
|
||||||
|
})
|
||||||
|
out.TopLogprobs = pairs
|
||||||
|
}
|
||||||
|
return []llm.Logprob{out}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
TextCompletionsRequest
|
CompletionRequest
|
||||||
Responses chan CompletionResponse
|
Responses chan CompletionResponse
|
||||||
Pipeline func(Request) error
|
Pipeline func(Request) error
|
||||||
|
|
||||||
@@ -28,22 +28,6 @@ type Request struct {
|
|||||||
Sampler *sample.Sampler
|
Sampler *sample.Sampler
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextCompletionsRequest struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Options struct {
|
|
||||||
Temperature float32 `json:"temperature"`
|
|
||||||
TopP float32 `json:"top_p"`
|
|
||||||
MinP float32 `json:"min_p"`
|
|
||||||
TopK int `json:"top_k"`
|
|
||||||
RepeatLastN int `json:"repeat_last_n"`
|
|
||||||
PresencePenalty float32 `json:"presence_penalty"`
|
|
||||||
MaxTokens int `json:"max_tokens"`
|
|
||||||
|
|
||||||
// Deprecated: use MaxTokens instead
|
|
||||||
NumPredict int `json:"num_predict"`
|
|
||||||
} `json:"options"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Runner struct {
|
type Runner struct {
|
||||||
Model base.Model
|
Model base.Model
|
||||||
Tokenizer *tokenizer.Tokenizer
|
Tokenizer *tokenizer.Tokenizer
|
||||||
@@ -79,6 +63,8 @@ func (r *Runner) Load(modelName string) error {
|
|||||||
r.Model = m
|
r.Model = m
|
||||||
r.Tokenizer = m.Tokenizer()
|
r.Tokenizer = m.Tokenizer()
|
||||||
r.contextLength = m.MaxContextLength()
|
r.contextLength = m.MaxContextLength()
|
||||||
|
|
||||||
|
mlx.EnableCompile()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
249
x/mlxrunner/sample/logprob_test.go
Normal file
249
x/mlxrunner/sample/logprob_test.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
//go:build mlx
|
||||||
|
|
||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||||
|
)
|
||||||
|
|
||||||
|
// logprobEntry is the (token id, logprob) pair returned by the sampler's
|
||||||
|
// top-K extraction, used after the test-side descending sort.
|
||||||
|
type logprobEntry struct {
|
||||||
|
id int
|
||||||
|
logprob float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// runSampleLogprobs drives Sample on a fresh Sampler configured for logprobs
|
||||||
|
// and returns the greedily-sampled token id, its logprob, and the top-K
|
||||||
|
// entries sorted descending by logprob. Logits must be a [vocab]-shaped
|
||||||
|
// slice; the helper reshapes it to [1, vocab] before calling the sampler.
|
||||||
|
func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, []logprobEntry) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
s := New(Options{Logprobs: true, TopLogprobs: topK})
|
||||||
|
defer func() {
|
||||||
|
s.Free()
|
||||||
|
mlx.Sweep()
|
||||||
|
}()
|
||||||
|
|
||||||
|
tensor := mlx.FromValues(logits, 1, len(logits))
|
||||||
|
res := s.Sample(tensor)
|
||||||
|
|
||||||
|
mlx.Pin(res.Arrays()...)
|
||||||
|
defer mlx.Unpin(res.Arrays()...)
|
||||||
|
mlx.Sweep()
|
||||||
|
mlx.Eval(res.Arrays()...)
|
||||||
|
|
||||||
|
selected := res.Token.Int()
|
||||||
|
selLP := float64(res.Logprob.Floats()[0])
|
||||||
|
|
||||||
|
var top []logprobEntry
|
||||||
|
if topK > 0 && res.TopTokens != nil {
|
||||||
|
ids := res.TopTokens.Ints()
|
||||||
|
vals := res.TopLogprobs.Floats()
|
||||||
|
top = make([]logprobEntry, len(ids))
|
||||||
|
for i, id := range ids {
|
||||||
|
top[i] = logprobEntry{id: id, logprob: float64(vals[i])}
|
||||||
|
}
|
||||||
|
sort.Slice(top, func(i, j int) bool { return top[i].logprob > top[j].logprob })
|
||||||
|
}
|
||||||
|
return selected, selLP, top
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSampleLogprobsBasic(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
logits []float32
|
||||||
|
topK int
|
||||||
|
wantSelectedID int
|
||||||
|
wantTopLen int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single token without top logprobs",
|
||||||
|
logits: []float32{1.0, 0.5, 0.3, 0.1},
|
||||||
|
topK: 0,
|
||||||
|
wantSelectedID: 0,
|
||||||
|
wantTopLen: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single token with top logprobs",
|
||||||
|
logits: []float32{1.0, 0.5, 0.3, 0.1},
|
||||||
|
topK: 3,
|
||||||
|
wantSelectedID: 0,
|
||||||
|
wantTopLen: 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
selected, _, top := runSampleLogprobs(t, tt.logits, tt.topK)
|
||||||
|
if selected != tt.wantSelectedID {
|
||||||
|
t.Errorf("selected = %d, want %d", selected, tt.wantSelectedID)
|
||||||
|
}
|
||||||
|
if len(top) != tt.wantTopLen {
|
||||||
|
t.Errorf("top-K length = %d, want %d", len(top), tt.wantTopLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSampleLogprobsNumericalStability(t *testing.T) {
|
||||||
|
logits := []float32{1000.0, 999.0, 998.0}
|
||||||
|
_, selLP, top := runSampleLogprobs(t, logits, 3)
|
||||||
|
|
||||||
|
if math.IsInf(selLP, 0) || math.IsNaN(selLP) {
|
||||||
|
t.Errorf("selected logprob is not finite: %f", selLP)
|
||||||
|
}
|
||||||
|
for i, e := range top {
|
||||||
|
if math.IsInf(e.logprob, 0) || math.IsNaN(e.logprob) {
|
||||||
|
t.Errorf("top[%d] logprob is not finite: %f", i, e.logprob)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := 1; i < len(top); i++ {
|
||||||
|
if top[i].logprob > top[i-1].logprob {
|
||||||
|
t.Errorf("top logprobs not descending: %f > %f", top[i].logprob, top[i-1].logprob)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSampleLogprobsProbabilityCorrectness(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
logits []float32
|
||||||
|
}{
|
||||||
|
{"uniform", []float32{1.0, 1.0, 1.0, 1.0}},
|
||||||
|
{"different", []float32{2.0, 1.0, 0.5, 0.1}},
|
||||||
|
{"negative", []float32{-1.0, -2.0, -3.0, -4.0}},
|
||||||
|
{"mixed", []float32{5.0, -5.0, 0.0, 2.5}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
selected, selLP, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
|
||||||
|
|
||||||
|
if selLP > 0 {
|
||||||
|
t.Errorf("selected logprob should be <= 0, got %f", selLP)
|
||||||
|
}
|
||||||
|
for i, e := range top {
|
||||||
|
if e.logprob > 0 {
|
||||||
|
t.Errorf("top[%d] logprob should be <= 0, got %f", i, e.logprob)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.name == "uniform" {
|
||||||
|
want := 1.0 / float64(len(tt.logits))
|
||||||
|
got := math.Exp(selLP)
|
||||||
|
if math.Abs(got-want) > 1e-6 {
|
||||||
|
t.Errorf("uniform logits: selected prob = %f, want %f", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i < len(top); i++ {
|
||||||
|
if top[i].logprob > top[i-1].logprob {
|
||||||
|
t.Errorf("top logprobs not descending at %d: %f > %f",
|
||||||
|
i, top[i].logprob, top[i-1].logprob)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, e := range top {
|
||||||
|
if e.id == selected {
|
||||||
|
found = true
|
||||||
|
if math.Abs(e.logprob-selLP) > 1e-6 {
|
||||||
|
t.Errorf("selected logprob mismatch: selLP=%f top=%f", selLP, e.logprob)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Errorf("selected token %d not present in top-K", selected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSampleLogprobsSoftmaxCorrectness(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
logits []float32
|
||||||
|
}{
|
||||||
|
{"small vocabulary", []float32{1.0, 2.0, 3.0}},
|
||||||
|
{"large differences", []float32{10.0, 0.0, -10.0}},
|
||||||
|
{"all equal", []float32{5.0, 5.0, 5.0, 5.0, 5.0}},
|
||||||
|
{"very large values", []float32{500.0, 499.0, 498.0}},
|
||||||
|
{"very small values", []float32{-500.0, -499.0, -498.0}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, _, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
|
||||||
|
if len(top) != len(tt.logits) {
|
||||||
|
t.Fatalf("top-K length = %d, want %d", len(top), len(tt.logits))
|
||||||
|
}
|
||||||
|
|
||||||
|
var sum float64
|
||||||
|
for _, e := range top {
|
||||||
|
p := math.Exp(e.logprob)
|
||||||
|
if p < 0 || p > 1 {
|
||||||
|
t.Errorf("token %d: probability %f out of [0,1]", e.id, p)
|
||||||
|
}
|
||||||
|
sum += p
|
||||||
|
}
|
||||||
|
|
||||||
|
if math.Abs(sum-1.0) > 1e-5 {
|
||||||
|
t.Errorf("probabilities sum = %f, want 1.0", sum)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSampleLogprobsSelectedTokenCorrectness(t *testing.T) {
|
||||||
|
logits := []float32{3.0, 1.0, 2.0, 0.5}
|
||||||
|
|
||||||
|
maxIdx := 0
|
||||||
|
for i, v := range logits[1:] {
|
||||||
|
if v > logits[maxIdx] {
|
||||||
|
maxIdx = i + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
selected, selLP, top := runSampleLogprobs(t, logits, len(logits))
|
||||||
|
|
||||||
|
if selected != maxIdx {
|
||||||
|
t.Errorf("selected = %d, want argmax %d", selected, maxIdx)
|
||||||
|
}
|
||||||
|
|
||||||
|
if top[0].id != maxIdx {
|
||||||
|
t.Errorf("top[0].id = %d, want argmax %d", top[0].id, maxIdx)
|
||||||
|
}
|
||||||
|
if math.Abs(top[0].logprob-selLP) > 1e-6 {
|
||||||
|
t.Errorf("top[0].logprob = %f, want selected %f", top[0].logprob, selLP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSampleLogprobsTopKOrdering(t *testing.T) {
|
||||||
|
// Logits chosen so argmax order differs from index order.
|
||||||
|
logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0}
|
||||||
|
wantOrder := []int{1, 3, 4, 0, 2}
|
||||||
|
|
||||||
|
_, _, top := runSampleLogprobs(t, logits, len(logits))
|
||||||
|
|
||||||
|
if len(top) != len(wantOrder) {
|
||||||
|
t.Fatalf("top-K length = %d, want %d", len(top), len(wantOrder))
|
||||||
|
}
|
||||||
|
for i, e := range top {
|
||||||
|
if e.id != wantOrder[i] {
|
||||||
|
t.Errorf("top[%d].id = %d, want %d", i, e.id, wantOrder[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := 1; i < len(top); i++ {
|
||||||
|
if top[i].logprob > top[i-1].logprob {
|
||||||
|
t.Errorf("top[%d].logprob (%f) > top[%d].logprob (%f)",
|
||||||
|
i, top[i].logprob, i-1, top[i-1].logprob)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,47 +8,76 @@ import (
|
|||||||
|
|
||||||
type Transform func(*Sampler, *mlx.Array) *mlx.Array
|
type Transform func(*Sampler, *mlx.Array) *mlx.Array
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
Temperature float32
|
||||||
|
TopP float32
|
||||||
|
MinP float32
|
||||||
|
TopK int
|
||||||
|
RepeatLastN int
|
||||||
|
RepeatPenalty float32
|
||||||
|
PresencePenalty float32
|
||||||
|
FrequencyPenalty float32
|
||||||
|
|
||||||
|
// Logprobs causes Sample to populate Result.Logprob with the selected
|
||||||
|
// token's log-probability. TopLogprobs (when > 0) adds top-K pairs.
|
||||||
|
Logprobs bool
|
||||||
|
TopLogprobs int
|
||||||
|
}
|
||||||
|
|
||||||
type Sampler struct {
|
type Sampler struct {
|
||||||
Temperature float32
|
Options
|
||||||
TopP float32
|
|
||||||
MinP float32
|
|
||||||
TopK int
|
|
||||||
RepeatLastN int
|
|
||||||
PresencePenalty float32
|
|
||||||
|
|
||||||
history *mlx.Array
|
history *mlx.Array
|
||||||
historyLen int
|
historyLen int
|
||||||
transforms []Transform
|
transforms []Transform
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(temp, top_p, min_p float32, top_k, repeatLastN int, presencePenalty float32) *Sampler {
|
// Result bundles the outputs of one decode step. The logprob tensors are
|
||||||
s := &Sampler{
|
// populated only when the sampler is configured to report them.
|
||||||
Temperature: temp,
|
type Result struct {
|
||||||
TopP: top_p,
|
Token *mlx.Array // sampled token id, shape [B]
|
||||||
MinP: min_p,
|
Logprob *mlx.Array // sampled-token logprob, shape [B,1]; nil unless Logprobs
|
||||||
TopK: top_k,
|
TopTokens *mlx.Array // top-K token ids, shape [B,K]; nil unless TopLogprobs > 0
|
||||||
RepeatLastN: repeatLastN,
|
TopLogprobs *mlx.Array // top-K logprobs, shape [B,K]; nil unless TopLogprobs > 0
|
||||||
PresencePenalty: presencePenalty,
|
}
|
||||||
|
|
||||||
|
// Arrays returns the tensor fields as a slice so callers can drive the mlx
|
||||||
|
// lifecycle verbs (Pin, Unpin, Eval, AsyncEval) over the whole group. Unset
|
||||||
|
// fields stay nil; the mlx helpers skip them.
|
||||||
|
func (r Result) Arrays() []*mlx.Array {
|
||||||
|
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(opts Options) *Sampler {
|
||||||
|
if opts.RepeatPenalty <= 0 {
|
||||||
|
opts.RepeatPenalty = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s := &Sampler{Options: opts}
|
||||||
|
|
||||||
var transforms []Transform
|
var transforms []Transform
|
||||||
if presencePenalty != 0 {
|
if s.usesHistory() {
|
||||||
transforms = append(transforms, penalty)
|
transforms = append(transforms, penalty)
|
||||||
}
|
}
|
||||||
|
|
||||||
if top_p > 0 && top_p < 1 {
|
hasTopP := opts.TopP > 0 && opts.TopP < 1
|
||||||
transforms = append(transforms, topP)
|
hasTopK := opts.TopK > 0
|
||||||
}
|
switch {
|
||||||
|
case hasTopP:
|
||||||
if min_p != 0 {
|
// topKTopP always does a full descending sort for the top-P
|
||||||
transforms = append(transforms, minP)
|
// cumulative mask and opportunistically masks top-K during the
|
||||||
}
|
// same pass when it is also configured.
|
||||||
|
transforms = append(transforms, topKTopP)
|
||||||
if top_k > 0 {
|
case hasTopK:
|
||||||
|
// Argpartition (partial sort) is cheaper than a full sort.
|
||||||
transforms = append(transforms, topK)
|
transforms = append(transforms, topK)
|
||||||
}
|
}
|
||||||
|
|
||||||
if temp == 0 {
|
if opts.MinP != 0 {
|
||||||
|
transforms = append(transforms, minP)
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.Temperature == 0 {
|
||||||
transforms = append(transforms, greedy)
|
transforms = append(transforms, greedy)
|
||||||
} else {
|
} else {
|
||||||
transforms = append(transforms, temperature)
|
transforms = append(transforms, temperature)
|
||||||
@@ -59,7 +88,7 @@ func New(temp, top_p, min_p float32, top_k, repeatLastN int, presencePenalty flo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sampler) usesHistory() bool {
|
func (s *Sampler) usesHistory() bool {
|
||||||
return s.PresencePenalty != 0
|
return s.RepeatPenalty != 1 || s.PresencePenalty != 0 || s.FrequencyPenalty != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sampler) setHistory(history *mlx.Array, historyLen int) {
|
func (s *Sampler) setHistory(history *mlx.Array, historyLen int) {
|
||||||
@@ -115,75 +144,138 @@ func (s *Sampler) Free() {
|
|||||||
s.setHistory(nil, 0)
|
s.setHistory(nil, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sampler) Sample(logits *mlx.Array) *mlx.Array {
|
// Sample runs the configured transform chain on the raw per-token logits
|
||||||
|
// and returns the sampled token id plus, when configured, the reported
|
||||||
|
// log-probability tensors for the selected token and the top-K tokens.
|
||||||
|
func (s *Sampler) Sample(logits *mlx.Array) Result {
|
||||||
|
scores := logits
|
||||||
for _, transform := range s.transforms {
|
for _, transform := range s.transforms {
|
||||||
logits = transform(s, logits)
|
scores = transform(s, scores)
|
||||||
}
|
}
|
||||||
return logits
|
res := Result{Token: scores}
|
||||||
}
|
|
||||||
|
|
||||||
func greedy(_ *Sampler, logits *mlx.Array) *mlx.Array {
|
if s.Logprobs {
|
||||||
return logits.Argmax(-1, false)
|
// Compute log_softmax in fp32 and subtract the max before
|
||||||
}
|
// logsumexp so the final subtraction stays on small values.
|
||||||
|
// Otherwise it cancels two large numbers and loses precision.
|
||||||
func temperature(s *Sampler, logits *mlx.Array) *mlx.Array {
|
lp := logits.AsType(mlx.DTypeFloat32)
|
||||||
return mlx.DivScalar(logits, s.Temperature).Categorical(-1)
|
lp = lp.Subtract(lp.MaxAxis(-1, true))
|
||||||
}
|
lp = lp.Subtract(lp.Logsumexp(true))
|
||||||
|
res.Logprob = lp.TakeAlongAxis(res.Token.ExpandDims(-1), -1)
|
||||||
func topP(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
if k := s.TopLogprobs; k > 0 {
|
||||||
if s.TopP <= 0 || s.TopP >= 1 {
|
if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
|
||||||
return logprobs
|
k = vocab
|
||||||
|
}
|
||||||
|
// Argpartition on the negated values places the K largest
|
||||||
|
// (unsorted) in positions [0:K].
|
||||||
|
idx := lp.Negative().ArgpartitionAxis(k-1, -1).Slice(mlx.Slice(), mlx.Slice(0, k))
|
||||||
|
res.TopTokens = idx.AsType(mlx.DTypeInt32)
|
||||||
|
res.TopLogprobs = lp.TakeAlongAxis(idx, -1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
order := logprobs.Negative().ArgsortAxis(-1)
|
func greedy(_ *Sampler, scores *mlx.Array) *mlx.Array {
|
||||||
sortedLogprobs := logprobs.TakeAlongAxis(order, -1)
|
return scores.Argmax(-1, false)
|
||||||
sortedProbs := mlx.SoftmaxAxis(sortedLogprobs, -1, true)
|
}
|
||||||
prevCumProbs := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
|
|
||||||
|
func temperature(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||||
|
return mlx.DivScalar(scores, s.Temperature).Categorical(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// topKTopP applies top-P in a descending sort pass and, when top-K is also
|
||||||
|
// configured, masks any surviving value below the K-th largest in the same
|
||||||
|
// pass. Callers dispatch here whenever top-P is enabled — the top-K-only
|
||||||
|
// case uses a cheaper partial sort via the topK transform.
|
||||||
|
func topKTopP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||||
|
vocab := scores.Dim(scores.NumDims() - 1)
|
||||||
|
applyTopK := s.TopK > 0 && s.TopK < vocab
|
||||||
|
|
||||||
|
order := scores.Negative().ArgsortAxis(-1)
|
||||||
|
sorted := scores.TakeAlongAxis(order, -1)
|
||||||
|
negInf := mlx.FromValue(float32(math.Inf(-1)))
|
||||||
|
|
||||||
|
// Top-P: in descending order, keep tokens whose exclusive cumulative
|
||||||
|
// probability is still below s.TopP.
|
||||||
|
probs := mlx.SoftmaxAxis(sorted, -1, true)
|
||||||
|
prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs)
|
||||||
keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
|
keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
|
||||||
filtered := mlx.Where(keep, sortedLogprobs, mlx.FromValue(float32(math.Inf(-1))))
|
sorted = mlx.Where(keep, sorted, negInf)
|
||||||
return logprobs.PutAlongAxis(order, filtered, -1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func minP(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
out := scores.PutAlongAxis(order, sorted, -1)
|
||||||
if s.MinP <= 0 || s.MinP > 1 {
|
|
||||||
return logprobs
|
// Top-K: sorted is already in descending order, so positions [K, V)
|
||||||
|
// are the ones to drop. Scatter -inf through their original-layout
|
||||||
|
// indices (order[K:]). Positional (not value-based) so exactly K
|
||||||
|
// tokens survive — ties at the K-th logit get broken by the sort
|
||||||
|
// order rather than promoted through the filter.
|
||||||
|
if applyTopK {
|
||||||
|
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
||||||
|
out = out.PutAlongAxis(dropOrder, negInf, -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
maxLogprobs := logprobs.TakeAlongAxis(logprobs.Argmax(-1, true), -1)
|
return out
|
||||||
minLogprobs := mlx.AddScalar(maxLogprobs, float32(math.Log(float64(s.MinP))))
|
}
|
||||||
|
|
||||||
|
func minP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||||
|
if s.MinP <= 0 || s.MinP > 1 {
|
||||||
|
return scores
|
||||||
|
}
|
||||||
|
|
||||||
|
maxScore := scores.MaxAxis(-1, true)
|
||||||
|
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(s.MinP))))
|
||||||
|
|
||||||
return mlx.Where(
|
return mlx.Where(
|
||||||
logprobs.Less(minLogprobs),
|
scores.Less(threshold),
|
||||||
mlx.FromValue(float32(math.Inf(-1))),
|
mlx.FromValue(float32(math.Inf(-1))),
|
||||||
logprobs,
|
scores,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func topK(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
func topK(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||||
if s.TopK <= 0 {
|
if s.TopK <= 0 {
|
||||||
return logprobs
|
return scores
|
||||||
}
|
}
|
||||||
|
|
||||||
vocab := logprobs.Dim(logprobs.NumDims() - 1)
|
vocab := scores.Dim(scores.NumDims() - 1)
|
||||||
if s.TopK >= vocab {
|
if s.TopK >= vocab {
|
||||||
return logprobs
|
return scores
|
||||||
}
|
}
|
||||||
|
|
||||||
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
mask := scores.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
||||||
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func penalty(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
func penalty(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||||
if s.history == nil || s.historyLen == 0 || s.PresencePenalty == 0 {
|
if s.historyLen == 0 {
|
||||||
return logprobs
|
return scores
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenIndices := s.history
|
tokenIndices := s.history
|
||||||
if logprobs.NumDims() > 1 {
|
if scores.NumDims() > 1 {
|
||||||
tokenIndices = tokenIndices.ExpandDims(0)
|
tokenIndices = tokenIndices.ExpandDims(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
selected := logprobs.TakeAlongAxis(tokenIndices, -1)
|
if s.RepeatPenalty != 1 || s.PresencePenalty != 0 {
|
||||||
adjusted := mlx.AddScalar(selected, -s.PresencePenalty)
|
adjusted := scores.TakeAlongAxis(tokenIndices, -1)
|
||||||
return logprobs.PutAlongAxis(tokenIndices, adjusted, -1)
|
if s.RepeatPenalty != 1 {
|
||||||
|
factor := mlx.Where(
|
||||||
|
adjusted.Less(mlx.FromValue(float32(0))),
|
||||||
|
mlx.FromValue(s.RepeatPenalty),
|
||||||
|
mlx.FromValue(1/s.RepeatPenalty),
|
||||||
|
)
|
||||||
|
adjusted = adjusted.Multiply(factor)
|
||||||
|
}
|
||||||
|
if s.PresencePenalty != 0 {
|
||||||
|
adjusted = mlx.AddScalar(adjusted, -s.PresencePenalty)
|
||||||
|
}
|
||||||
|
scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.FrequencyPenalty != 0 {
|
||||||
|
scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return scores
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,8 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
|
func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
|
||||||
// RepeatLastN = 1, PresencePenalty = 6
|
s := New(Options{RepeatLastN: 1, PresencePenalty: 6})
|
||||||
s := New(0, 0, 0, 0, 1, 6)
|
|
||||||
defer func() {
|
defer func() {
|
||||||
s.Free()
|
s.Free()
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
@@ -20,11 +19,11 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
|
|||||||
s.ResetHistory([]int32{0})
|
s.ResetHistory([]int32{0})
|
||||||
s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1}))
|
s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1}))
|
||||||
|
|
||||||
logprobs := mlx.FromValues([]float32{0, 5, 4}, 3)
|
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
|
||||||
got := s.Sample(logprobs)
|
got := s.Sample(logits).Token
|
||||||
mlx.Eval(got)
|
mlx.Eval(got)
|
||||||
|
|
||||||
// logprobs will be [0, -1, 4] after the penalty
|
// logits will be [0, -1, 4] after the penalty
|
||||||
// and then (index) 2 after the greedy sampler
|
// and then (index) 2 after the greedy sampler
|
||||||
gotInt := got.Int()
|
gotInt := got.Int()
|
||||||
if gotInt != 2 {
|
if gotInt != 2 {
|
||||||
@@ -32,19 +31,59 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMinPMasksTokensBelowThreshold(t *testing.T) {
|
func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) {
|
||||||
s := New(0, 0, 0.5, 0, 0, 0)
|
s := New(Options{RepeatLastN: 1, RepeatPenalty: 2})
|
||||||
defer func() {
|
defer func() {
|
||||||
s.Free()
|
s.Free()
|
||||||
mlx.Sweep()
|
mlx.Sweep()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
logprobs := mlx.FromValues([]float32{
|
s.ResetHistory([]int32{1})
|
||||||
|
|
||||||
|
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
|
||||||
|
got := s.Sample(logits).Token
|
||||||
|
mlx.Eval(got)
|
||||||
|
|
||||||
|
// token 1 is repeated and positive, so 5 / 2 falls below token 2.
|
||||||
|
gotInt := got.Int()
|
||||||
|
if gotInt != 2 {
|
||||||
|
t.Fatalf("got %d, want 2", gotInt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) {
|
||||||
|
s := New(Options{RepeatLastN: 4, FrequencyPenalty: 2})
|
||||||
|
defer func() {
|
||||||
|
s.Free()
|
||||||
|
mlx.Sweep()
|
||||||
|
}()
|
||||||
|
|
||||||
|
s.ResetHistory([]int32{1, 1})
|
||||||
|
|
||||||
|
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
|
||||||
|
got := s.Sample(logits).Token
|
||||||
|
mlx.Eval(got)
|
||||||
|
|
||||||
|
// token 1 appears twice, so 5 - (2 * 2) falls below token 2.
|
||||||
|
gotInt := got.Int()
|
||||||
|
if gotInt != 2 {
|
||||||
|
t.Fatalf("got %d, want 2", gotInt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMinPMasksTokensBelowThreshold(t *testing.T) {
|
||||||
|
s := New(Options{MinP: 0.5})
|
||||||
|
defer func() {
|
||||||
|
s.Free()
|
||||||
|
mlx.Sweep()
|
||||||
|
}()
|
||||||
|
|
||||||
|
logits := mlx.FromValues([]float32{
|
||||||
float32(math.Log(0.5)),
|
float32(math.Log(0.5)),
|
||||||
float32(math.Log(0.3)),
|
float32(math.Log(0.3)),
|
||||||
float32(math.Log(0.2)),
|
float32(math.Log(0.2)),
|
||||||
}, 3)
|
}, 3)
|
||||||
got := minP(s, logprobs)
|
got := minP(s, logits)
|
||||||
mlx.Eval(got)
|
mlx.Eval(got)
|
||||||
|
|
||||||
gotFloats := got.Floats()
|
gotFloats := got.Floats()
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package mlxrunner
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"cmp"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
@@ -87,23 +86,25 @@ func Execute(args []string) error {
|
|||||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||||
request := Request{Responses: make(chan CompletionResponse)}
|
request := Request{Responses: make(chan CompletionResponse)}
|
||||||
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&request.CompletionRequest); err != nil {
|
||||||
slog.Error("Failed to decode request", "error", err)
|
slog.Error("Failed to decode request", "error", err)
|
||||||
http.Error(w, "Bad Request", http.StatusBadRequest)
|
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
|
||||||
|
|
||||||
request.Pipeline = runner.TextGenerationPipeline
|
request.Pipeline = runner.TextGenerationPipeline
|
||||||
request.Sampler = sample.New(
|
request.Sampler = sample.New(sample.Options{
|
||||||
request.Options.Temperature,
|
Temperature: request.Options.Temperature,
|
||||||
request.Options.TopP,
|
TopP: request.Options.TopP,
|
||||||
request.Options.MinP,
|
MinP: request.Options.MinP,
|
||||||
request.Options.TopK,
|
TopK: request.Options.TopK,
|
||||||
request.Options.RepeatLastN,
|
RepeatLastN: request.Options.RepeatLastN,
|
||||||
request.Options.PresencePenalty,
|
RepeatPenalty: request.Options.RepeatPenalty,
|
||||||
)
|
PresencePenalty: request.Options.PresencePenalty,
|
||||||
|
FrequencyPenalty: request.Options.FrequencyPenalty,
|
||||||
|
Logprobs: request.Logprobs,
|
||||||
|
TopLogprobs: request.TopLogprobs,
|
||||||
|
})
|
||||||
|
|
||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
request.Ctx, cancel = context.WithCancel(r.Context())
|
request.Ctx, cancel = context.WithCancel(r.Context())
|
||||||
|
|||||||
@@ -80,7 +80,6 @@ type TextConfig struct {
|
|||||||
PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
||||||
PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071...
|
PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071...
|
||||||
RouterScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
RouterScale float32 `json:"-"` // 1/sqrt(hidden_size)
|
||||||
SoftcapInv float32 `json:"-"` // 1/final_logit_softcapping
|
|
||||||
|
|
||||||
// KV sharing: maps shared layer index -> donor layer index.
|
// KV sharing: maps shared layer index -> donor layer index.
|
||||||
KVShareMap map[int32]int32 `json:"-"`
|
KVShareMap map[int32]int32 `json:"-"`
|
||||||
@@ -455,9 +454,6 @@ func parseTextConfig(configData []byte) (TextConfig, error) {
|
|||||||
cfg.PLECombineScale = float32(math.Pow(2.0, -0.5))
|
cfg.PLECombineScale = float32(math.Pow(2.0, -0.5))
|
||||||
}
|
}
|
||||||
cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize)))
|
cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize)))
|
||||||
if cfg.FinalLogitSoftcapping > 0 {
|
|
||||||
cfg.SoftcapInv = 1.0 / cfg.FinalLogitSoftcapping
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute KV sharing map.
|
// Compute KV sharing map.
|
||||||
cfg.KVShareMap = make(map[int32]int32)
|
cfg.KVShareMap = make(map[int32]int32)
|
||||||
@@ -1065,14 +1061,12 @@ func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
|
var donorKV *sharedKVEntry
|
||||||
|
h, donorKV = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
|
||||||
|
|
||||||
// If this layer is a donor, store its cached KV for later shared layers.
|
// If this layer is a donor, store its cached KV for later shared layers.
|
||||||
if layer.IsDonor && c != nil {
|
if layer.IsDonor && donorKV != nil {
|
||||||
state := c.State()
|
sharedKV[layer.LayerIdx] = *donorKV
|
||||||
if len(state) >= 2 && state[0] != nil && state[1] != nil {
|
|
||||||
sharedKV[layer.LayerIdx] = sharedKVEntry{K: state[0], V: state[1], Offset: c.Offset()}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1114,9 +1108,8 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
|
|||||||
logits := m.LMHead.Forward(x)
|
logits := m.LMHead.Forward(x)
|
||||||
|
|
||||||
if m.FinalLogitSoftcapping > 0 {
|
if m.FinalLogitSoftcapping > 0 {
|
||||||
logits = mlx.MulScalar(logits, m.SoftcapInv)
|
cap := mlx.FromValue(m.FinalLogitSoftcapping).AsType(logits.DType())
|
||||||
logits = logits.Tanh()
|
logits = mlx.LogitSoftcap(logits, cap)
|
||||||
logits = mlx.MulScalar(logits, m.FinalLogitSoftcapping)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
@@ -1195,9 +1188,9 @@ func sliceLayerDim(combined *mlx.Array, layerIdx, B, L, pleDim int32) *mlx.Array
|
|||||||
return mlx.Squeeze(sliced, 2)
|
return mlx.Squeeze(sliced, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array {
|
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
|
||||||
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
|
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||||
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
|
attnOut, donorKV := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
|
||||||
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||||
h := mlx.Add(x, attnOut)
|
h := mlx.Add(x, attnOut)
|
||||||
|
|
||||||
@@ -1231,8 +1224,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
|
|||||||
// PLE injection (after MLP residual).
|
// PLE injection (after MLP residual).
|
||||||
if l.PLE != nil && pleInput != nil {
|
if l.PLE != nil && pleInput != nil {
|
||||||
residual := h
|
residual := h
|
||||||
gate := mlx.GELUApprox(l.PLE.InputGate.Forward(h))
|
gated := mlx.GeGLU(l.PLE.InputGate.Forward(h), pleInput)
|
||||||
gated := mlx.Mul(gate, pleInput)
|
|
||||||
projected := l.PLE.Projection.Forward(gated)
|
projected := l.PLE.Projection.Forward(gated)
|
||||||
projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps)
|
projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps)
|
||||||
h = mlx.Add(residual, projected)
|
h = mlx.Add(residual, projected)
|
||||||
@@ -1243,10 +1235,10 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
|
|||||||
h = mlx.Mul(h, l.LayerScalar)
|
h = mlx.Mul(h, l.LayerScalar)
|
||||||
}
|
}
|
||||||
|
|
||||||
return h
|
return h, donorKV
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array {
|
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
|
||||||
// Determine head dim and scale based on layer type.
|
// Determine head dim and scale based on layer type.
|
||||||
headDim := cfg.HeadDim
|
headDim := cfg.HeadDim
|
||||||
scale := cfg.SlidingScale
|
scale := cfg.SlidingScale
|
||||||
@@ -1280,6 +1272,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||||||
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs)
|
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs)
|
||||||
|
|
||||||
var k, v *mlx.Array
|
var k, v *mlx.Array
|
||||||
|
var donorKV *sharedKVEntry
|
||||||
|
|
||||||
if donorEntry != nil {
|
if donorEntry != nil {
|
||||||
// Shared layer: use donor's cached K/V.
|
// Shared layer: use donor's cached K/V.
|
||||||
@@ -1318,6 +1311,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||||||
// Update cache.
|
// Update cache.
|
||||||
if c != nil {
|
if c != nil {
|
||||||
k, v = c.Update(k, v)
|
k, v = c.Update(k, v)
|
||||||
|
donorKV = &sharedKVEntry{K: k, V: v, Offset: c.Offset()}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1371,13 +1365,13 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||||||
// strided views differently. Metal handles them natively.
|
// strided views differently. Metal handles them natively.
|
||||||
out = mlx.Contiguous(out, false)
|
out = mlx.Contiguous(out, false)
|
||||||
}
|
}
|
||||||
return a.OProj.Forward(out)
|
return a.OProj.Forward(out), donorKV
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||||
gate := mlx.GELUApprox(m.GateProj.Forward(x))
|
gate := m.GateProj.Forward(x)
|
||||||
up := m.UpProj.Forward(x)
|
up := m.UpProj.Forward(x)
|
||||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
return m.DownProj.Forward(mlx.GeGLU(gate, up))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward runs the router to select top-k experts per token.
|
// Forward runs the router to select top-k experts per token.
|
||||||
@@ -1457,13 +1451,13 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
|
|||||||
up := mlx.SliceStartStop(gateUp,
|
up := mlx.SliceStartStop(gateUp,
|
||||||
[]int32{0, 0, 0, mid},
|
[]int32{0, 0, 0, mid},
|
||||||
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
||||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
hidden = mlx.GeGLU(gate, up)
|
||||||
} else {
|
} else {
|
||||||
gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases,
|
gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases,
|
||||||
nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort)
|
nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort)
|
||||||
up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases,
|
up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases,
|
||||||
nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort)
|
nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort)
|
||||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
hidden = mlx.GeGLU(gate, up)
|
||||||
}
|
}
|
||||||
downMode := m.DownQuantMode
|
downMode := m.DownQuantMode
|
||||||
if downMode == "" {
|
if downMode == "" {
|
||||||
@@ -1482,11 +1476,11 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
|
|||||||
up := mlx.SliceStartStop(gateUp,
|
up := mlx.SliceStartStop(gateUp,
|
||||||
[]int32{0, 0, 0, mid},
|
[]int32{0, 0, 0, mid},
|
||||||
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
|
||||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
hidden = mlx.GeGLU(gate, up)
|
||||||
} else {
|
} else {
|
||||||
gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort)
|
gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort)
|
||||||
up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort)
|
up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort)
|
||||||
hidden = mlx.Mul(mlx.GELUApprox(gate), up)
|
hidden = mlx.GeGLU(gate, up)
|
||||||
}
|
}
|
||||||
down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort)
|
down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -144,6 +144,8 @@ func TestRouterForwardMatchesLegacy(t *testing.T) {
|
|||||||
|
|
||||||
gotScores, gotInds := r.Forward(x, cfg)
|
gotScores, gotInds := r.Forward(x, cfg)
|
||||||
wantScores, wantInds := legacyRouterForward(r, x, cfg)
|
wantScores, wantInds := legacyRouterForward(r, x, cfg)
|
||||||
|
gotInds = gotInds.AsType(mlx.DTypeInt32)
|
||||||
|
wantInds = wantInds.AsType(mlx.DTypeInt32)
|
||||||
mlx.Eval(gotScores, gotInds, wantScores, wantInds)
|
mlx.Eval(gotScores, gotInds, wantScores, wantInds)
|
||||||
|
|
||||||
if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {
|
if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {
|
||||||
|
|||||||
@@ -148,9 +148,7 @@ type DenseMLP struct {
|
|||||||
|
|
||||||
// Forward applies the SwiGLU MLP
|
// Forward applies the SwiGLU MLP
|
||||||
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
|
||||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||||
up := m.UpProj.Forward(x)
|
|
||||||
return m.DownProj.Forward(mlx.Mul(gate, up))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MoEGate implements the expert gating mechanism
|
// MoEGate implements the expert gating mechanism
|
||||||
@@ -163,21 +161,21 @@ type MoEGate struct {
|
|||||||
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
|
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
|
||||||
gates := g.Gate.Forward(x)
|
gates := g.Gate.Forward(x)
|
||||||
|
|
||||||
scores := mlx.Sigmoid(gates)
|
var origScores, negScores *mlx.Array
|
||||||
origScores := scores
|
|
||||||
|
|
||||||
if g.EScoreCorrectionBias != nil {
|
if g.EScoreCorrectionBias != nil {
|
||||||
scores = mlx.Add(scores, g.EScoreCorrectionBias)
|
origScores, negScores = mlx.SigmoidRouter(gates, g.EScoreCorrectionBias)
|
||||||
|
} else {
|
||||||
|
origScores = mlx.Sigmoid(gates)
|
||||||
|
negScores = mlx.Neg(origScores)
|
||||||
}
|
}
|
||||||
|
|
||||||
topK := cfg.NumExpertsPerTok
|
topK := cfg.NumExpertsPerTok
|
||||||
negScores := mlx.Neg(scores)
|
|
||||||
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
|
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
|
||||||
|
|
||||||
dims := inds.Dims()
|
dims := inds.Dims()
|
||||||
inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})
|
inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})
|
||||||
|
|
||||||
scores = mlx.TakeAlongAxis(origScores, inds, -1)
|
scores := mlx.TakeAlongAxis(origScores, inds, -1)
|
||||||
|
|
||||||
if topK > 1 && cfg.NormTopKProb {
|
if topK > 1 && cfg.NormTopKProb {
|
||||||
sumScores := mlx.Sum(scores, -1, true)
|
sumScores := mlx.Sum(scores, -1, true)
|
||||||
@@ -242,7 +240,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
|
|||||||
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
||||||
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
||||||
|
|
||||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
hidden = mlx.SwiGLU(gate, up)
|
||||||
|
|
||||||
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
||||||
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
||||||
@@ -250,7 +248,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
|
|||||||
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
|
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||||
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
|
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||||
|
|
||||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
hidden = mlx.SwiGLU(gate, up)
|
||||||
|
|
||||||
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
|
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
|
||||||
}
|
}
|
||||||
@@ -273,9 +271,7 @@ type SharedExperts struct {
|
|||||||
|
|
||||||
// Forward applies the shared expert MLP
|
// Forward applies the shared expert MLP
|
||||||
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
|
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
|
||||||
gate := mlx.SiLU(s.GateProj.Forward(x))
|
return s.DownProj.Forward(mlx.SwiGLU(s.GateProj.Forward(x), s.UpProj.Forward(x)))
|
||||||
up := s.UpProj.Forward(x)
|
|
||||||
return s.DownProj.Forward(mlx.Mul(gate, up))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MoE implements the full Mixture of Experts layer
|
// MoE implements the full Mixture of Experts layer
|
||||||
|
|||||||
@@ -314,5 +314,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -169,8 +169,8 @@ func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
|
|||||||
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
|
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
|
||||||
mlx.Eval(dequantizedWeight)
|
mlx.Eval(dequantizedWeight)
|
||||||
|
|
||||||
qOut := ql.Forward(input)
|
qOut := ql.Forward(input).AsType(mlx.DTypeFloat32)
|
||||||
dOut := NewLinear(dequantizedWeight, nil).Forward(input)
|
dOut := NewLinear(dequantizedWeight, nil).Forward(input).AsType(mlx.DTypeFloat32)
|
||||||
mlx.Eval(qOut, dOut)
|
mlx.Eval(qOut, dOut)
|
||||||
|
|
||||||
got := qOut.Floats()
|
got := qOut.Floats()
|
||||||
|
|||||||
@@ -333,5 +333,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
|
||||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1253,7 +1253,7 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *DenseMLP) Forward(x *mlx.Array, _ *Config) *mlx.Array {
|
func (m *DenseMLP) Forward(x *mlx.Array, _ *Config) *mlx.Array {
|
||||||
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
|
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
|
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
|
||||||
@@ -1283,13 +1283,13 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
|
|||||||
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
|
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
|
||||||
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
|
||||||
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
|
||||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
hidden = mlx.SwiGLU(gate, up)
|
||||||
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
|
||||||
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
|
||||||
} else {
|
} else {
|
||||||
gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort)
|
gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort)
|
||||||
up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort)
|
up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort)
|
||||||
hidden = mlx.Mul(mlx.SiLU(gate), up)
|
hidden = mlx.SwiGLU(gate, up)
|
||||||
down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort)
|
down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user