Compare commits

...

27 Commits

Author SHA1 Message Date
Patrick Devine
7bcdb250b9 fix failing client2 unit tests 2026-04-21 13:56:39 -07:00
Patrick Devine
7bbcd2e6be server: add v2 manifest path
This change adds a new manifest-v2/ path for new models created with the
create/pull/copy commands. Under manifest-v2, manifests are now just blobs which are
content addressable similar to tensors/config files. The named tags instead
will symlink/hard link/contain a copy depending on what the file system supports.

Downgrades to older versions of ollama are still possible, but any create/pull/copy
done with the newer version will potentially have its blobs pruned by the older
version.

manifest-v2 also changes the default registry name to `ollama.com` instead of
`registry.ollama.ai`.
2026-04-21 12:05:54 -07:00
Jesse Gross
22d6c817f8 mlxrunner: fuse top-P and top-K into a single sort pass
When both filters are active, avoid paying for a full sort in top-P
and a partial sort in top-K. Single-filter paths are unchanged.
Improves generation throughput on gemma4:e4b by 1.5%.
2026-04-20 17:43:00 -07:00
Jesse Gross
ca01373b28 mlxrunner: use MaxAxis in the min-P sampler
One reduction op instead of Argmax + TakeAlongAxis.
2026-04-20 17:43:00 -07:00
Jesse Gross
24e038d56a mlxrunner: add logprobs support
Match the ollamarunner and OpenAI semantics: raw, full-vocab log-softmax
with the top-K ranked by probability. Skipped on the GPU when the request
doesn't ask for logprobs so decode doesn't pay for it otherwise.
2026-04-20 17:43:00 -07:00
Parth Sareen
5d1021603a server: apply format when think=false for gemma4 (#15678) 2026-04-20 17:42:29 -07:00
Parth Sareen
8e05d734b9 launch: add kimi cli integration with installer flow (#15723) 2026-04-20 15:33:32 -07:00
Jesse Gross
05e0f21bec mlx: fuse sigmoid router head in glm4_moe_lite
DeepSeek-V2-style aux-loss-free routing computes sigmoid(gates) once but
needs it twice: the raw sigmoid output is gathered after top-k, while the
post-bias negation is the argpartition key. Fuse into a single multi-output
Compiled kernel returning both, saving two launches on the routing path
per token. Exposed as a general SigmoidRouter since the same pattern is
shared across DeepSeek-V2 descendants.

Improves glm4.7 generation performance by approximately 1%.
2026-04-20 15:02:14 -07:00
Daniel Hiltgen
ff23dd343f mlx: apply repeat penalties in sampler (#15631) 2026-04-18 07:49:38 -07:00
Parth Sareen
123b300af6 docs: update hermes (#15655) 2026-04-17 14:20:59 -07:00
Parth Sareen
57653b8e42 cmd/launch: show WSL guidance on Windows instead of handing off (#15637) 2026-04-16 17:18:04 -07:00
Parth Sareen
a50ce61c54 launch: skip unchanged managed-single rewrite (#15633) 2026-04-16 16:20:42 -07:00
Daniel Hiltgen
2bb7ea00d2 create: avoid gc race with create (#15628)
If you have a long running create, and start another ollama server with the
same model dir, the GC algorithm deletes the pending blobs and breaks the
create.  This adds a 1h grace period to avoid deleting in-flight creation
operations.
2026-04-16 13:29:16 -07:00
Daniel Hiltgen
55fa80d07a mlx: additional gemma4 cache fixes (#15607)
Harden additional corner cases
2026-04-16 13:07:19 -07:00
Daniel Hiltgen
b9cb535407 mlx: fix gemma4 cache to use logical view (#15617) 2026-04-16 11:54:30 -07:00
Daniel Hiltgen
031baef094 mlx: fix imagegen lookup (#15588)
* mlx: fix imagegen lookup

Fixes #15533 - imagegen had fallen out of sync with the new layout
for multiple mlx libraries on Metal.

* review comments
2026-04-16 10:39:00 -07:00
Mike Wallio
7d271e6dc9 cmd/launch: add Copilot CLI integration (#15583)
---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-04-15 17:22:53 -07:00
Devon Rifkin
c88dae2d6b Merge pull request #15612 from ollama/drifkin/gemma4-split-templates
gemma4: render differently based on model size
2026-04-15 17:15:35 -07:00
Devon Rifkin
9e3618d663 make empty block conditional 2026-04-15 15:35:25 -07:00
Daniel Hiltgen
5d920cc6bc Keep Gemma4 router projection in source precision (#15613) 2026-04-15 15:04:23 -07:00
Devon Rifkin
e585ecd11f gemma4: render differently based on model size
Following up on #15560, this change now has e2b/e4b render differently
from 26b/31b.

For backwards compatibility, we take the existing renderer name `gemma4`
and make it do dynamic resolution based on the model name/size, but the
intended use is for the models to be republished with the renderer
variant specified explicitly: `gemma4-small` or `gemma4-large`.
2026-04-15 14:37:16 -07:00
Eva H
cdddea0592 launch: always list cloud recommendations first (#15593) 2026-04-15 13:17:35 -07:00
Parth Sareen
43f90def04 launch: add hermes (#15569) 2026-04-15 12:00:23 -07:00
Daniel Hiltgen
06ae6367bd mlx: fix RotatingKVCache.concat() dropping context on mid-rotation (#15591)
After the rotating buffer has wrapped (c.offset > c.maxSize) a subsequent
L>1 Update() went through a slice-to-[0, c.idx) path that discarded all
slots in [c.idx, Dim), losing the older-but-still-in-window tokens the
first Q of the new batch needs for its sliding-window attention.

Linearize the circular buffer to logical order in that wrapped case so
the existing trim + concat preserves the last (maxSize - 1) old tokens.
When the buffer has not yet wrapped (c.offset <= c.maxSize), slots
[c.idx, Dim) are grow padding or stale post-rewind data, so keep
dropping them.
2026-04-14 18:29:06 -07:00
Daniel Hiltgen
48ad7085c4 mlx: Improve gemma4 performance with fused operations (#15587)
* mlx: Improve gemma4 performance with fused operations

* review comments
2026-04-14 18:04:04 -07:00
Jesse Gross
e1e3cec8d0 models: fuse MLP activation functions via mlx_compile
Converts SiLU/GELUApprox to compiled kernels and adds SwiGLU,
matching upstream mlx/mlx_lm's activations pattern. Routes llama,
qwen3, qwen3_5 (dense + MoE), and glm4_moe_lite MLP paths through
mlx.SwiGLU so each MLP invocation runs as one fused Metal/CUDA
kernel rather than a chain of per-op launches.
2026-04-14 16:38:32 -07:00
Jesse Gross
d3e67e305c mlx: add compiled closure support
Wraps MLX's mlx_compile API so Go functions can be traced into fused
kernels. Contiguous elementwise chains collapse into a single
Metal/CUDA kernel instead of launching one per op.

Exposes Compile plus arity helpers (Compile1/2/3) that mirror Python's
@mx.compile decorator shape, lazily building the closure on first call
so package-level declarations work before the MLX dylib loads.
2026-04-14 16:38:32 -07:00
71 changed files with 7170 additions and 735 deletions

View File

@@ -55,7 +55,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
ollama ollama
``` ```
You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `claude`, `codex`, `openclaw` and more. You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `Claude Code`, `OpenClaw`, `OpenCode` , `Codex`, `Copilot`, and more.
### Coding ### Coding
@@ -65,7 +65,7 @@ To launch a specific integration:
ollama launch claude ollama launch claude
``` ```
Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode). Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Copilot CLI](https://docs.ollama.com/integrations/copilot-cli), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
### AI assistant ### AI assistant

View File

@@ -58,6 +58,12 @@ func TestLaunchCmd(t *testing.T) {
if cmd.Long == "" { if cmd.Long == "" {
t.Error("Long description should not be empty") t.Error("Long description should not be empty")
} }
if !strings.Contains(cmd.Long, "hermes") {
t.Error("Long description should mention hermes")
}
if !strings.Contains(cmd.Long, "kimi") {
t.Error("Long description should mention kimi")
}
}) })
t.Run("flags exist", func(t *testing.T) { t.Run("flags exist", func(t *testing.T) {

76
cmd/launch/copilot.go Normal file
View File

@@ -0,0 +1,76 @@
package launch
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/envconfig"
)
// Copilot implements Runner for GitHub Copilot CLI integration.
type Copilot struct{}
func (c *Copilot) String() string { return "Copilot CLI" }
func (c *Copilot) args(model string, extra []string) []string {
var args []string
if model != "" {
args = append(args, "--model", model)
}
args = append(args, extra...)
return args
}
func (c *Copilot) findPath() (string, error) {
if p, err := exec.LookPath("copilot"); err == nil {
return p, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
name := "copilot"
if runtime.GOOS == "windows" {
name = "copilot.exe"
}
fallback := filepath.Join(home, ".local", "bin", name)
if _, err := os.Stat(fallback); err != nil {
return "", err
}
return fallback, nil
}
func (c *Copilot) Run(model string, args []string) error {
copilotPath, err := c.findPath()
if err != nil {
return fmt.Errorf("copilot is not installed, install from https://docs.github.com/en/copilot/how-tos/set-up/install-copilot-cli")
}
cmd := exec.Command(copilotPath, c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(), c.envVars(model)...)
return cmd.Run()
}
// envVars returns the environment variables that configure Copilot CLI
// to use Ollama as its model provider.
func (c *Copilot) envVars(model string) []string {
env := []string{
"COPILOT_PROVIDER_BASE_URL=" + envconfig.Host().String() + "/v1",
"COPILOT_PROVIDER_API_KEY=",
"COPILOT_PROVIDER_WIRE_API=responses",
}
if model != "" {
env = append(env, "COPILOT_MODEL="+model)
}
return env
}

161
cmd/launch/copilot_test.go Normal file
View File

@@ -0,0 +1,161 @@
package launch
import (
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
func TestCopilotIntegration(t *testing.T) {
c := &Copilot{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Copilot CLI" {
t.Errorf("String() = %q, want %q", got, "Copilot CLI")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
}
func TestCopilotFindPath(t *testing.T) {
c := &Copilot{}
t.Run("finds copilot in PATH", func(t *testing.T) {
tmpDir := t.TempDir()
name := "copilot"
if runtime.GOOS == "windows" {
name = "copilot.exe"
}
fakeBin := filepath.Join(tmpDir, name)
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
t.Setenv("PATH", tmpDir)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fakeBin {
t.Errorf("findPath() = %q, want %q", got, fakeBin)
}
})
t.Run("returns error when not in PATH", func(t *testing.T) {
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
_, err := c.findPath()
if err == nil {
t.Fatal("expected error, got nil")
}
})
t.Run("falls back to ~/.local/bin/copilot", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
name := "copilot"
if runtime.GOOS == "windows" {
name = "copilot.exe"
}
fallback := filepath.Join(tmpDir, ".local", "bin", name)
os.MkdirAll(filepath.Dir(fallback), 0o755)
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fallback {
t.Errorf("findPath() = %q, want %q", got, fallback)
}
})
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
_, err := c.findPath()
if err == nil {
t.Fatal("expected error, got nil")
}
})
}
func TestCopilotArgs(t *testing.T) {
c := &Copilot{}
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
{"empty model", "", nil, nil},
{"with model and extra", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model, tt.args)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
}
})
}
}
func TestCopilotEnvVars(t *testing.T) {
c := &Copilot{}
envMap := func(envs []string) map[string]string {
m := make(map[string]string)
for _, e := range envs {
k, v, _ := strings.Cut(e, "=")
m[k] = v
}
return m
}
t.Run("sets required provider env vars with model", func(t *testing.T) {
got := envMap(c.envVars("llama3.2"))
if got["COPILOT_PROVIDER_BASE_URL"] == "" {
t.Error("COPILOT_PROVIDER_BASE_URL should be set")
}
if !strings.HasSuffix(got["COPILOT_PROVIDER_BASE_URL"], "/v1") {
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want /v1 suffix", got["COPILOT_PROVIDER_BASE_URL"])
}
if _, ok := got["COPILOT_PROVIDER_API_KEY"]; !ok {
t.Error("COPILOT_PROVIDER_API_KEY should be set (empty)")
}
if got["COPILOT_PROVIDER_WIRE_API"] != "responses" {
t.Errorf("COPILOT_PROVIDER_WIRE_API = %q, want %q", got["COPILOT_PROVIDER_WIRE_API"], "responses")
}
if got["COPILOT_MODEL"] != "llama3.2" {
t.Errorf("COPILOT_MODEL = %q, want %q", got["COPILOT_MODEL"], "llama3.2")
}
})
t.Run("omits COPILOT_MODEL when model is empty", func(t *testing.T) {
got := envMap(c.envVars(""))
if _, ok := got["COPILOT_MODEL"]; ok {
t.Errorf("COPILOT_MODEL should not be set for empty model, got %q", got["COPILOT_MODEL"])
}
})
t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://myhost:9999")
got := envMap(c.envVars("test"))
if !strings.Contains(got["COPILOT_PROVIDER_BASE_URL"], "myhost:9999") {
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want custom host", got["COPILOT_PROVIDER_BASE_URL"])
}
})
}

679
cmd/launch/hermes.go Normal file
View File

@@ -0,0 +1,679 @@
package launch
import (
"bufio"
"bytes"
"context"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"gopkg.in/yaml.v3"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/config"
"github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig"
)
const (
hermesInstallScript = "curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash -s -- --skip-setup"
hermesProviderName = "Ollama"
hermesProviderKey = "ollama-launch"
hermesLegacyKey = "ollama"
hermesPlaceholderKey = "ollama"
hermesGatewaySetupHint = "hermes gateway setup"
hermesGatewaySetupTitle = "Connect a messaging app now?"
)
var (
hermesGOOS = runtime.GOOS
hermesLookPath = exec.LookPath
hermesCommand = exec.Command
hermesUserHome = os.UserHomeDir
hermesOllamaURL = envconfig.ConnectableHost
)
var hermesMessagingEnvGroups = [][]string{
{"TELEGRAM_BOT_TOKEN"},
{"DISCORD_BOT_TOKEN"},
{"SLACK_BOT_TOKEN"},
{"SIGNAL_ACCOUNT"},
{"EMAIL_ADDRESS"},
{"TWILIO_ACCOUNT_SID"},
{"MATRIX_ACCESS_TOKEN", "MATRIX_PASSWORD"},
{"MATTERMOST_TOKEN"},
{"WHATSAPP_PHONE_NUMBER_ID"},
{"DINGTALK_CLIENT_ID"},
{"FEISHU_APP_ID"},
{"WECOM_BOT_ID"},
{"WEIXIN_ACCOUNT_ID"},
{"BLUEBUBBLES_SERVER_URL"},
{"WEBHOOK_ENABLED"},
}
// Hermes is intentionally not an Editor integration: launch owns one primary
// model and the local Ollama endpoint, while Hermes keeps its own discovery and
// switching UX after startup.
type Hermes struct{}
func (h *Hermes) String() string { return "Hermes Agent" }
func (h *Hermes) Run(_ string, args []string) error {
// Hermes reads its primary model from config.yaml. launch configures that
// default model ahead of time so we can keep runtime invocation simple and
// still let Hermes discover additional models later via its own UX.
bin, err := h.binary()
if err != nil {
return err
}
if err := h.runGatewaySetupPreflight(args, func() error {
return hermesAttachedCommand(bin, "gateway", "setup").Run()
}); err != nil {
return err
}
return hermesAttachedCommand(bin, args...).Run()
}
func (h *Hermes) Paths() []string {
configPath, err := hermesConfigPath()
if err != nil {
return nil
}
return []string{configPath}
}
func (h *Hermes) Configure(model string) error {
configPath, err := hermesConfigPath()
if err != nil {
return err
}
cfg := map[string]any{}
if data, err := os.ReadFile(configPath); err == nil {
if err := yaml.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse hermes config: %w", err)
}
} else if !os.IsNotExist(err) {
return err
}
modelSection, _ := cfg["model"].(map[string]any)
if modelSection == nil {
modelSection = make(map[string]any)
}
models := h.listModels(model)
applyHermesManagedProviders(cfg, hermesBaseURL(), model, models)
// launch writes the minimum provider/default-model settings needed to
// bootstrap Hermes against Ollama. The active provider stays on a
// launch-owned key so /model stays aligned with the launcher-managed entry,
// and the Ollama endpoint lives in providers: so the picker shows one row.
modelSection["provider"] = hermesProviderKey
modelSection["default"] = model
modelSection["base_url"] = hermesBaseURL()
modelSection["api_key"] = hermesPlaceholderKey
cfg["model"] = modelSection
// use Hermes' built-in web toolset for now.
// TODO(parthsareen): move this to using Ollama web search
cfg["toolsets"] = mergeHermesToolsets(cfg["toolsets"])
data, err := yaml.Marshal(cfg)
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
return fileutil.WriteWithBackup(configPath, data)
}
func (h *Hermes) CurrentModel() string {
configPath, err := hermesConfigPath()
if err != nil {
return ""
}
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
cfg := map[string]any{}
if yaml.Unmarshal(data, &cfg) != nil {
return ""
}
return hermesManagedCurrentModel(cfg, hermesBaseURL())
}
func (h *Hermes) Onboard() error {
return config.MarkIntegrationOnboarded("hermes")
}
func (h *Hermes) RequiresInteractiveOnboarding() bool {
return false
}
func (h *Hermes) RefreshRuntimeAfterConfigure() error {
running, err := h.gatewayRunning()
if err != nil {
return fmt.Errorf("check Hermes gateway status: %w", err)
}
if !running {
return nil
}
fmt.Fprintf(os.Stderr, "%sRefreshing Hermes messaging gateway...%s\n", ansiGray, ansiReset)
if err := h.restartGateway(); err != nil {
return fmt.Errorf("restart Hermes gateway: %w", err)
}
fmt.Fprintln(os.Stderr)
return nil
}
func (h *Hermes) installed() bool {
_, err := h.binary()
return err == nil
}
func (h *Hermes) ensureInstalled() error {
if h.installed() {
return nil
}
if hermesGOOS == "windows" {
return hermesWindowsHint()
}
var missing []string
for _, dep := range []string{"bash", "curl", "git"} {
if _, err := hermesLookPath(dep); err != nil {
missing = append(missing, dep)
}
}
if len(missing) > 0 {
return fmt.Errorf("Hermes is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch hermes", strings.Join(missing, "\n "))
}
ok, err := ConfirmPrompt("Hermes is not installed. Install now?")
if err != nil {
return err
}
if !ok {
return fmt.Errorf("hermes installation cancelled")
}
fmt.Fprintf(os.Stderr, "\nInstalling Hermes...\n")
if err := hermesAttachedCommand("bash", "-lc", hermesInstallScript).Run(); err != nil {
return fmt.Errorf("failed to install hermes: %w", err)
}
if !h.installed() {
return fmt.Errorf("hermes was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sHermes installed successfully%s\n\n", ansiGreen, ansiReset)
return nil
}
func (h *Hermes) listModels(defaultModel string) []string {
client := hermesOllamaClient()
resp, err := client.List(context.Background())
if err != nil {
return []string{defaultModel}
}
models := make([]string, 0, len(resp.Models)+1)
seen := make(map[string]struct{}, len(resp.Models)+1)
add := func(name string) {
name = strings.TrimSpace(name)
if name == "" {
return
}
if _, ok := seen[name]; ok {
return
}
seen[name] = struct{}{}
models = append(models, name)
}
add(defaultModel)
for _, entry := range resp.Models {
add(entry.Name)
}
if len(models) == 0 {
return []string{defaultModel}
}
return models
}
func (h *Hermes) binary() (string, error) {
if path, err := hermesLookPath("hermes"); err == nil {
return path, nil
}
if hermesGOOS == "windows" {
return "", hermesWindowsHint()
}
home, err := hermesUserHome()
if err != nil {
return "", err
}
fallback := filepath.Join(home, ".local", "bin", "hermes")
if _, err := os.Stat(fallback); err == nil {
return fallback, nil
}
return "", fmt.Errorf("hermes is not installed")
}
func hermesConfigPath() (string, error) {
home, err := hermesUserHome()
if err != nil {
return "", err
}
return filepath.Join(home, ".hermes", "config.yaml"), nil
}
func hermesBaseURL() string {
return strings.TrimRight(hermesOllamaURL().String(), "/") + "/v1"
}
func hermesEnvPath() (string, error) {
home, err := hermesUserHome()
if err != nil {
return "", err
}
return filepath.Join(home, ".hermes", ".env"), nil
}
func (h *Hermes) runGatewaySetupPreflight(args []string, runSetup func() error) error {
if len(args) > 0 || !isInteractiveSession() || currentLaunchConfirmPolicy.yes || currentLaunchConfirmPolicy.requireYesMessage {
return nil
}
if h.messagingConfigured() {
return nil
}
fmt.Fprintf(os.Stderr, "\nHermes can message you on Telegram, Discord, Slack, and more.\n\n")
ok, err := ConfirmPromptWithOptions(hermesGatewaySetupTitle, ConfirmOptions{
YesLabel: "Yes",
NoLabel: "Set up later",
})
if err != nil {
return err
}
if !ok {
return nil
}
if err := runSetup(); err != nil {
return fmt.Errorf("hermes messaging setup failed: %w\n\nTry running: %s", err, hermesGatewaySetupHint)
}
return nil
}
func (h *Hermes) messagingConfigured() bool {
envVars, err := h.gatewayEnvVars()
if err != nil {
return false
}
for _, group := range hermesMessagingEnvGroups {
for _, key := range group {
if strings.TrimSpace(envVars[key]) != "" {
return true
}
}
}
return false
}
func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
envVars := make(map[string]string)
envFilePath, err := hermesEnvPath()
if err != nil {
return nil, err
}
switch data, err := os.ReadFile(envFilePath); {
case err == nil:
for key, value := range hermesParseEnvFile(data) {
envVars[key] = value
}
case os.IsNotExist(err):
// nothing persisted yet
default:
return nil, err
}
for _, group := range hermesMessagingEnvGroups {
for _, key := range group {
if value, ok := os.LookupEnv(key); ok {
envVars[key] = value
}
}
}
return envVars, nil
}
func (h *Hermes) gatewayRunning() (bool, error) {
status, err := h.gatewayStatusOutput()
if err != nil {
return false, err
}
return hermesGatewayStatusRunning(status), nil
}
func (h *Hermes) gatewayStatusOutput() (string, error) {
bin, err := h.binary()
if err != nil {
return "", err
}
out, err := hermesCommand(bin, "gateway", "status").CombinedOutput()
return string(out), err
}
func (h *Hermes) restartGateway() error {
bin, err := h.binary()
if err != nil {
return err
}
return hermesAttachedCommand(bin, "gateway", "restart").Run()
}
func hermesGatewayStatusRunning(output string) bool {
status := strings.ToLower(output)
switch {
case strings.Contains(status, "gateway is not running"):
return false
case strings.Contains(status, "gateway service is stopped"):
return false
case strings.Contains(status, "gateway service is not loaded"):
return false
case strings.Contains(status, "gateway is running"):
return true
case strings.Contains(status, "gateway service is running"):
return true
case strings.Contains(status, "gateway service is loaded"):
return true
default:
return false
}
}
func hermesParseEnvFile(data []byte) map[string]string {
out := make(map[string]string)
scanner := bufio.NewScanner(bytes.NewReader(data))
for scanner.Scan() {
line := strings.TrimSpace(strings.TrimPrefix(scanner.Text(), "\ufeff"))
if line == "" || strings.HasPrefix(line, "#") {
continue
}
if strings.HasPrefix(line, "export ") {
line = strings.TrimSpace(strings.TrimPrefix(line, "export "))
}
key, value, ok := strings.Cut(line, "=")
if !ok {
continue
}
key = strings.TrimSpace(key)
if key == "" {
continue
}
value = strings.TrimSpace(value)
if len(value) >= 2 {
switch {
case value[0] == '"' && value[len(value)-1] == '"':
if unquoted, err := strconv.Unquote(value); err == nil {
value = unquoted
}
case value[0] == '\'' && value[len(value)-1] == '\'':
value = value[1 : len(value)-1]
}
}
out[key] = value
}
return out
}
func hermesOllamaClient() *api.Client {
// Hermes queries the same launch-resolved Ollama host that launch writes
// into config, so model discovery follows the configured endpoint.
return api.NewClient(hermesOllamaURL(), http.DefaultClient)
}
func applyHermesManagedProviders(cfg map[string]any, baseURL string, model string, models []string) {
providers := hermesUserProviders(cfg["providers"])
entry := hermesManagedProviderEntry(providers)
if entry == nil {
entry = make(map[string]any)
}
entry["name"] = hermesProviderName
entry["api"] = baseURL
entry["default_model"] = model
entry["models"] = hermesStringListAny(models)
providers[hermesProviderKey] = entry
delete(providers, hermesLegacyKey)
cfg["providers"] = providers
customProviders := hermesWithoutManagedCustomProviders(cfg["custom_providers"])
if len(customProviders) == 0 {
delete(cfg, "custom_providers")
return
}
cfg["custom_providers"] = customProviders
}
func hermesManagedCurrentModel(cfg map[string]any, baseURL string) string {
modelCfg, _ := cfg["model"].(map[string]any)
if modelCfg == nil {
return ""
}
provider, _ := modelCfg["provider"].(string)
if strings.TrimSpace(strings.ToLower(provider)) != hermesProviderKey {
return ""
}
configBaseURL, _ := modelCfg["base_url"].(string)
if hermesNormalizeURL(configBaseURL) != hermesNormalizeURL(baseURL) {
return ""
}
current, _ := modelCfg["default"].(string)
current = strings.TrimSpace(current)
if current == "" {
return ""
}
providers := hermesUserProviders(cfg["providers"])
entry, _ := providers[hermesProviderKey].(map[string]any)
if entry == nil {
return ""
}
if hermesHasManagedCustomProvider(cfg["custom_providers"]) {
return ""
}
apiURL, _ := entry["api"].(string)
if hermesNormalizeURL(apiURL) != hermesNormalizeURL(baseURL) {
return ""
}
defaultModel, _ := entry["default_model"].(string)
if strings.TrimSpace(defaultModel) != current {
return ""
}
return current
}
func hermesUserProviders(current any) map[string]any {
switch existing := current.(type) {
case map[string]any:
out := make(map[string]any, len(existing))
for key, value := range existing {
out[key] = value
}
return out
case map[any]any:
out := make(map[string]any, len(existing))
for key, value := range existing {
if s, ok := key.(string); ok {
out[s] = value
}
}
return out
default:
return make(map[string]any)
}
}
func hermesCustomProviders(current any) []any {
switch existing := current.(type) {
case []any:
return append([]any(nil), existing...)
case []map[string]any:
out := make([]any, 0, len(existing))
for _, entry := range existing {
out = append(out, entry)
}
return out
default:
return nil
}
}
func hermesManagedProviderEntry(providers map[string]any) map[string]any {
for _, key := range []string{hermesProviderKey, hermesLegacyKey} {
if entry, _ := providers[key].(map[string]any); entry != nil {
return entry
}
}
return nil
}
func hermesWithoutManagedCustomProviders(current any) []any {
customProviders := hermesCustomProviders(current)
preserved := make([]any, 0, len(customProviders))
for _, item := range customProviders {
entry, _ := item.(map[string]any)
if entry == nil {
preserved = append(preserved, item)
continue
}
if hermesManagedCustomProvider(entry) {
continue
}
preserved = append(preserved, entry)
}
return preserved
}
func hermesHasManagedCustomProvider(current any) bool {
for _, item := range hermesCustomProviders(current) {
entry, _ := item.(map[string]any)
if entry != nil && hermesManagedCustomProvider(entry) {
return true
}
}
return false
}
func hermesManagedCustomProvider(entry map[string]any) bool {
name, _ := entry["name"].(string)
return strings.EqualFold(strings.TrimSpace(name), hermesProviderName)
}
func hermesNormalizeURL(raw string) string {
return strings.TrimRight(strings.TrimSpace(raw), "/")
}
func hermesStringListAny(models []string) []any {
out := make([]any, 0, len(models))
for _, model := range dedupeModelList(models) {
model = strings.TrimSpace(model)
if model == "" {
continue
}
out = append(out, model)
}
return out
}
func mergeHermesToolsets(current any) any {
added := false
switch existing := current.(type) {
case []any:
out := make([]any, 0, len(existing)+1)
for _, item := range existing {
out = append(out, item)
if s, _ := item.(string); s == "web" {
added = true
}
}
if !added {
out = append(out, "web")
}
return out
case []string:
out := append([]string(nil), existing...)
if !slices.Contains(out, "web") {
out = append(out, "web")
}
asAny := make([]any, 0, len(out))
for _, item := range out {
asAny = append(asAny, item)
}
return asAny
case string:
if strings.TrimSpace(existing) == "" {
return []any{"hermes-cli", "web"}
}
parts := strings.Split(existing, ",")
out := make([]any, 0, len(parts)+1)
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
if part == "web" {
added = true
}
out = append(out, part)
}
if !added {
out = append(out, "web")
}
return out
default:
return []any{"hermes-cli", "web"}
}
}
func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
cmd := hermesCommand(name, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd
}
func hermesWindowsHint() error {
return fmt.Errorf("Hermes on Windows requires WSL2. Install WSL with: wsl --install\n" +
"Then run 'ollama launch hermes' from inside your WSL shell.\n" +
"Docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/")
}

1110
cmd/launch/hermes_test.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -54,6 +54,7 @@ func TestIntegrationLookup(t *testing.T) {
{"claude uppercase", "CLAUDE", true, "Claude Code"}, {"claude uppercase", "CLAUDE", true, "Claude Code"},
{"claude mixed case", "Claude", true, "Claude Code"}, {"claude mixed case", "Claude", true, "Claude Code"},
{"codex", "codex", true, "Codex"}, {"codex", "codex", true, "Codex"},
{"kimi", "kimi", true, "Kimi Code CLI"},
{"droid", "droid", true, "Droid"}, {"droid", "droid", true, "Droid"},
{"opencode", "opencode", true, "OpenCode"}, {"opencode", "opencode", true, "OpenCode"},
{"unknown integration", "unknown", false, ""}, {"unknown integration", "unknown", false, ""},
@@ -74,7 +75,7 @@ func TestIntegrationLookup(t *testing.T) {
} }
func TestIntegrationRegistry(t *testing.T) { func TestIntegrationRegistry(t *testing.T) {
expectedIntegrations := []string{"claude", "codex", "droid", "opencode"} expectedIntegrations := []string{"claude", "codex", "kimi", "droid", "opencode", "hermes"}
for _, name := range expectedIntegrations { for _, name := range expectedIntegrations {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
@@ -89,6 +90,15 @@ func TestIntegrationRegistry(t *testing.T) {
} }
} }
func TestHiddenIntegrationsExcludedFromVisibleLists(t *testing.T) {
for _, info := range ListIntegrationInfos() {
switch info.Name {
case "cline", "vscode", "kimi":
t.Fatalf("hidden integration %q should not appear in ListIntegrationInfos", info.Name)
}
}
}
func TestHasLocalModel(t *testing.T) { func TestHasLocalModel(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -329,7 +339,7 @@ func TestBuildModelList_NoExistingModels(t *testing.T) {
} }
} }
func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) { func TestBuildModelList_OnlyLocalModels_CloudRecsStillFirst(t *testing.T) {
existing := []modelInfo{ existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false}, {Name: "llama3.2:latest", Remote: false},
{Name: "qwen2.5:latest", Remote: false}, {Name: "qwen2.5:latest", Remote: false},
@@ -338,10 +348,11 @@ func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
items, _, _, _ := buildModelList(existing, nil, "") items, _, _, _ := buildModelList(existing, nil, "")
got := names(items) got := names(items)
// Recommended pinned at top (local recs first, then cloud recs when only-local), then installed non-recs // Cloud recs always come first among recommended, regardless of installed inventory.
want := []string{"gemma4", "qwen3.5", "kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "llama3.2", "qwen2.5"} // Cloud disablement is handled upstream in loadSelectableModels via filterCloudItems.
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "gemma4", "qwen3.5", "llama3.2", "qwen2.5"}
if diff := cmp.Diff(want, got); diff != "" { if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("recs pinned at top, local recs before cloud recs (-want +got):\n%s", diff) t.Errorf("cloud recs pinned first even when no cloud models installed (-want +got):\n%s", diff)
} }
} }
@@ -588,7 +599,7 @@ func TestBuildModelList_MixedCase_CloudRecsFirst(t *testing.T) {
} }
} }
func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) { func TestBuildModelList_OnlyLocal_CloudRecsStillFirst(t *testing.T) {
existing := []modelInfo{ existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false}, {Name: "llama3.2:latest", Remote: false},
} }
@@ -596,11 +607,11 @@ func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) {
items, _, _, _ := buildModelList(existing, nil, "") items, _, _, _ := buildModelList(existing, nil, "")
got := names(items) got := names(items)
// Local recs should sort before cloud recs in only-local case // Cloud recs sort before local recs regardless of installed inventory.
localIdx := slices.Index(got, "gemma4") localIdx := slices.Index(got, "gemma4")
cloudIdx := slices.Index(got, "glm-5.1:cloud") cloudIdx := slices.Index(got, "glm-5.1:cloud")
if localIdx > cloudIdx { if cloudIdx > localIdx {
t.Errorf("local recs should be before cloud recs in only-local case, got %v", got) t.Errorf("cloud recs should be before local recs even when only local models installed, got %v", got)
} }
} }
@@ -1509,27 +1520,13 @@ func TestListIntegrationInfos(t *testing.T) {
} }
}) })
t.Run("sorted with custom order at end", func(t *testing.T) { t.Run("follows launcher order", func(t *testing.T) {
// integrationOrder entries (cline, opencode) should appear last, in that order. got := make([]string, 0, len(infos))
// All other entries should be sorted alphabetically before them. for _, info := range infos {
orderRank := make(map[string]int) got = append(got, info.Name)
for i, name := range integrationOrder {
orderRank[name] = i + 1
} }
for i := 1; i < len(infos); i++ { if diff := compareStrings(got, integrationOrder); diff != "" {
aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name] t.Fatalf("launcher integration order mismatch: %s", diff)
switch {
case aRank == 0 && bRank == 0:
if infos[i-1].Name >= infos[i].Name {
t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
}
case aRank > 0 && bRank == 0:
t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
case aRank > 0 && bRank > 0:
if aRank >= bRank {
t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
}
}
} }
}) })
@@ -1557,6 +1554,28 @@ func TestListIntegrationInfos(t *testing.T) {
} }
} }
}) })
t.Run("includes hermes", func(t *testing.T) {
for _, info := range infos {
if info.Name == "hermes" {
return
}
}
t.Fatal("expected hermes to be included in ListIntegrationInfos")
})
t.Run("hermes still resolves explicitly", func(t *testing.T) {
name, runner, err := LookupIntegration("hermes")
if err != nil {
t.Fatalf("expected explicit hermes integration lookup to work, got %v", err)
}
if name != "hermes" {
t.Fatalf("expected canonical name hermes, got %q", name)
}
if runner.String() == "" {
t.Fatal("expected hermes integration runner to be present")
}
})
} }
func TestBuildModelList_Descriptions(t *testing.T) { func TestBuildModelList_Descriptions(t *testing.T) {
@@ -1645,6 +1664,7 @@ func TestIntegration_AutoInstallable(t *testing.T) {
}{ }{
{"openclaw", true}, {"openclaw", true},
{"pi", true}, {"pi", true},
{"hermes", true},
{"claude", false}, {"claude", false},
{"codex", false}, {"codex", false},
{"opencode", false}, {"opencode", false},

315
cmd/launch/kimi.go Normal file
View File

@@ -0,0 +1,315 @@
package launch
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// Kimi implements Runner for Kimi Code CLI integration.
type Kimi struct{}
const (
kimiDefaultModelAlias = "ollama"
kimiDefaultMaxContextSize = 32768
)
var (
kimiGOOS = runtime.GOOS
kimiModelShowTimeout = 5 * time.Second
)
func (k *Kimi) String() string { return "Kimi Code CLI" }
func (k *Kimi) args(config string, extra []string) []string {
args := []string{"--config", config}
args = append(args, extra...)
return args
}
func (k *Kimi) Run(model string, args []string) error {
if strings.TrimSpace(model) == "" {
return fmt.Errorf("model is required")
}
if err := validateKimiPassthroughArgs(args); err != nil {
return err
}
config, err := buildKimiInlineConfig(model, resolveKimiMaxContextSize(model))
if err != nil {
return fmt.Errorf("failed to build kimi config: %w", err)
}
bin, err := ensureKimiInstalled()
if err != nil {
return err
}
cmd := exec.Command(bin, k.args(config, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func findKimiBinary() (string, error) {
if path, err := exec.LookPath("kimi"); err == nil {
return path, nil
}
home, _ := os.UserHomeDir()
var candidates []string
switch kimiGOOS {
case "windows":
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, ".local", "bin"))
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, "bin"))
if appData := strings.TrimSpace(os.Getenv("APPDATA")); appData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
}
if localAppData := strings.TrimSpace(os.Getenv("LOCALAPPDATA")); localAppData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
}
default:
candidates = append(candidates,
filepath.Join(home, ".local", "bin", "kimi"),
filepath.Join(home, "bin", "kimi"),
filepath.Join(home, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi"),
filepath.Join(home, ".local", "share", "uv", "tools", "kimi", "bin", "kimi"),
)
if xdgDataHome := strings.TrimSpace(os.Getenv("XDG_DATA_HOME")); xdgDataHome != "" {
candidates = append(candidates,
filepath.Join(xdgDataHome, "uv", "tools", "kimi-cli", "bin", "kimi"),
filepath.Join(xdgDataHome, "uv", "tools", "kimi", "bin", "kimi"),
)
}
// WSL users can inherit Windows env vars while launching from Linux shells.
if profile := windowsPathToWSL(os.Getenv("USERPROFILE")); profile != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(profile, ".local", "bin"))
}
if appData := windowsPathToWSL(os.Getenv("APPDATA")); appData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
}
if localAppData := windowsPathToWSL(os.Getenv("LOCALAPPDATA")); localAppData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
}
}
for _, candidate := range candidates {
if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
return candidate, nil
}
}
return "", fmt.Errorf("kimi binary not found")
}
func appendWindowsKimiCandidates(candidates []string, dir string) []string {
if strings.TrimSpace(dir) == "" {
return candidates
}
return append(candidates,
filepath.Join(dir, "kimi.exe"),
filepath.Join(dir, "kimi.cmd"),
filepath.Join(dir, "kimi.bat"),
)
}
func windowsPathToWSL(path string) string {
trimmed := strings.TrimSpace(path)
if len(trimmed) < 3 || trimmed[1] != ':' {
return ""
}
drive := strings.ToLower(string(trimmed[0]))
rest := strings.ReplaceAll(trimmed[2:], "\\", "/")
rest = strings.TrimPrefix(rest, "/")
if rest == "" {
return filepath.Join("/mnt", drive)
}
return filepath.Join("/mnt", drive, rest)
}
func validateKimiPassthroughArgs(args []string) error {
for _, arg := range args {
switch {
case arg == "--config", strings.HasPrefix(arg, "--config="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config", arg)
case arg == "--config-file", strings.HasPrefix(arg, "--config-file="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config-file", arg)
case arg == "--model", strings.HasPrefix(arg, "--model="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --model", arg)
case arg == "-m", strings.HasPrefix(arg, "-m="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages -m/--model", arg)
}
}
return nil
}
func buildKimiInlineConfig(model string, maxContextSize int) (string, error) {
cfg := map[string]any{
"default_model": kimiDefaultModelAlias,
"providers": map[string]any{
kimiDefaultModelAlias: map[string]any{
"type": "openai_legacy",
"base_url": envconfig.ConnectableHost().String() + "/v1",
"api_key": "ollama",
},
},
"models": map[string]any{
kimiDefaultModelAlias: map[string]any{
"provider": kimiDefaultModelAlias,
"model": model,
"max_context_size": maxContextSize,
},
},
}
data, err := json.Marshal(cfg)
if err != nil {
return "", err
}
return string(data), nil
}
func resolveKimiMaxContextSize(model string) int {
if l, ok := lookupCloudModelLimit(model); ok {
return l.Context
}
client, err := api.ClientFromEnvironment()
if err != nil {
return kimiDefaultMaxContextSize
}
ctx, cancel := context.WithTimeout(context.Background(), kimiModelShowTimeout)
defer cancel()
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
if err != nil {
return kimiDefaultMaxContextSize
}
if n, ok := modelInfoContextLength(resp.ModelInfo); ok {
return n
}
return kimiDefaultMaxContextSize
}
func modelInfoContextLength(modelInfo map[string]any) (int, bool) {
for key, val := range modelInfo {
if !strings.HasSuffix(key, ".context_length") {
continue
}
switch v := val.(type) {
case float64:
if v > 0 {
return int(v), true
}
case int:
if v > 0 {
return v, true
}
case int64:
if v > 0 {
return int(v), true
}
}
}
return 0, false
}
func ensureKimiInstalled() (string, error) {
if path, err := findKimiBinary(); err == nil {
return path, nil
}
if err := checkKimiInstallerDependencies(); err != nil {
return "", err
}
ok, err := ConfirmPrompt("Kimi is not installed. Install now?")
if err != nil {
return "", err
}
if !ok {
return "", fmt.Errorf("kimi installation cancelled")
}
bin, args, err := kimiInstallerCommand(kimiGOOS)
if err != nil {
return "", err
}
fmt.Fprintf(os.Stderr, "\nInstalling Kimi...\n")
cmd := exec.Command(bin, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("failed to install kimi: %w", err)
}
path, err := findKimiBinary()
if err != nil {
return "", fmt.Errorf("kimi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sKimi installed successfully%s\n\n", ansiGreen, ansiReset)
return path, nil
}
func checkKimiInstallerDependencies() error {
switch kimiGOOS {
case "windows":
if _, err := exec.LookPath("powershell"); err != nil {
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n PowerShell: https://learn.microsoft.com/powershell/\n\nThen re-run:\n ollama launch kimi")
}
default:
var missing []string
if _, err := exec.LookPath("curl"); err != nil {
missing = append(missing, "curl: https://curl.se/")
}
if _, err := exec.LookPath("bash"); err != nil {
missing = append(missing, "bash: https://www.gnu.org/software/bash/")
}
if len(missing) > 0 {
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch kimi", strings.Join(missing, "\n "))
}
}
return nil
}
func kimiInstallerCommand(goos string) (string, []string, error) {
switch goos {
case "windows":
return "powershell", []string{
"-NoProfile",
"-ExecutionPolicy",
"Bypass",
"-Command",
"Invoke-RestMethod https://code.kimi.com/install.ps1 | Invoke-Expression",
}, nil
case "darwin", "linux":
return "bash", []string{
"-c",
"curl -LsSf https://code.kimi.com/install.sh | bash",
}, nil
default:
return "", nil, fmt.Errorf("unsupported platform for kimi install: %s", goos)
}
}

636
cmd/launch/kimi_test.go Normal file
View File

@@ -0,0 +1,636 @@
package launch
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
func assertKimiBinPath(t *testing.T, bin string) {
t.Helper()
base := strings.ToLower(filepath.Base(bin))
if !strings.HasPrefix(base, "kimi") {
t.Fatalf("bin = %q, want path to kimi executable", bin)
}
}
func TestKimiIntegration(t *testing.T) {
k := &Kimi{}
t.Run("String", func(t *testing.T) {
if got := k.String(); got != "Kimi Code CLI" {
t.Errorf("String() = %q, want %q", got, "Kimi Code CLI")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = k
})
}
func TestKimiArgs(t *testing.T) {
k := &Kimi{}
got := k.args(`{"foo":"bar"}`, []string{"--quiet", "--print"})
want := []string{"--config", `{"foo":"bar"}`, "--quiet", "--print"}
if !slices.Equal(got, want) {
t.Fatalf("args() = %v, want %v", got, want)
}
}
func TestWindowsPathToWSL(t *testing.T) {
tests := []struct {
name string
in string
want string
valid bool
}{
{
name: "user profile path",
in: `C:\Users\parth`,
want: filepath.Join("/mnt", "c", "Users", "parth"),
valid: true,
},
{
name: "path with trailing slash",
in: `D:\tools\bin\`,
want: filepath.Join("/mnt", "d", "tools", "bin"),
valid: true,
},
{
name: "non windows path",
in: "/home/parth",
valid: false,
},
{
name: "empty",
in: "",
valid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := windowsPathToWSL(tt.in)
if !tt.valid {
if got != "" {
t.Fatalf("windowsPathToWSL(%q) = %q, want empty", tt.in, got)
}
return
}
if got != tt.want {
t.Fatalf("windowsPathToWSL(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
}
func TestFindKimiBinaryFallbacks(t *testing.T) {
oldGOOS := kimiGOOS
t.Cleanup(func() { kimiGOOS = oldGOOS })
t.Run("linux/ubuntu uv tool path", func(t *testing.T) {
homeDir := t.TempDir()
setTestHome(t, homeDir)
t.Setenv("PATH", t.TempDir())
kimiGOOS = "linux"
target := filepath.Join(homeDir, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi")
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
t.Fatalf("failed to create candidate dir: %v", err)
}
if err := os.WriteFile(target, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatalf("failed to write kimi candidate: %v", err)
}
got, err := findKimiBinary()
if err != nil {
t.Fatalf("findKimiBinary() error = %v", err)
}
if got != target {
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
}
})
t.Run("windows appdata uv bin", func(t *testing.T) {
setTestHome(t, t.TempDir())
t.Setenv("PATH", t.TempDir())
kimiGOOS = "windows"
appDataDir := t.TempDir()
t.Setenv("APPDATA", appDataDir)
t.Setenv("LOCALAPPDATA", "")
target := filepath.Join(appDataDir, "uv", "bin", "kimi.cmd")
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
t.Fatalf("failed to create candidate dir: %v", err)
}
if err := os.WriteFile(target, []byte("@echo off\r\nexit /b 0\r\n"), 0o755); err != nil {
t.Fatalf("failed to write kimi candidate: %v", err)
}
got, err := findKimiBinary()
if err != nil {
t.Fatalf("findKimiBinary() error = %v", err)
}
if got != target {
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
}
})
}
func TestValidateKimiPassthroughArgs_RejectsConflicts(t *testing.T) {
tests := []struct {
name string
args []string
want string
}{
{name: "--config", args: []string{"--config", "{}"}, want: "--config"},
{name: "--config=", args: []string{"--config={}"}, want: "--config={"},
{name: "--config-file", args: []string{"--config-file", "x.toml"}, want: "--config-file"},
{name: "--config-file=", args: []string{"--config-file=x.toml"}, want: "--config-file=x.toml"},
{name: "--model", args: []string{"--model", "foo"}, want: "--model"},
{name: "--model=", args: []string{"--model=foo"}, want: "--model=foo"},
{name: "-m", args: []string{"-m", "foo"}, want: "-m"},
{name: "-m=", args: []string{"-m=foo"}, want: "-m=foo"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateKimiPassthroughArgs(tt.args)
if err == nil {
t.Fatalf("expected error for args %v", tt.args)
}
if !strings.Contains(err.Error(), tt.want) {
t.Fatalf("error %q does not contain %q", err.Error(), tt.want)
}
})
}
}
func TestBuildKimiInlineConfig(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434")
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
if err != nil {
t.Fatalf("buildKimiInlineConfig() error = %v", err)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
t.Fatalf("config is not valid JSON: %v", err)
}
if parsed["default_model"] != "ollama" {
t.Fatalf("default_model = %v, want ollama", parsed["default_model"])
}
providers, ok := parsed["providers"].(map[string]any)
if !ok {
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
}
ollamaProvider, ok := providers["ollama"].(map[string]any)
if !ok {
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
}
if ollamaProvider["type"] != "openai_legacy" {
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
}
if ollamaProvider["base_url"] != "http://127.0.0.1:11434/v1" {
t.Fatalf("provider base_url = %v, want http://127.0.0.1:11434/v1", ollamaProvider["base_url"])
}
if ollamaProvider["api_key"] != "ollama" {
t.Fatalf("provider api_key = %v, want ollama", ollamaProvider["api_key"])
}
models, ok := parsed["models"].(map[string]any)
if !ok {
t.Fatalf("models missing or wrong type: %T", parsed["models"])
}
ollamaModel, ok := models["ollama"].(map[string]any)
if !ok {
t.Fatalf("models.ollama missing or wrong type: %T", models["ollama"])
}
if ollamaModel["provider"] != "ollama" {
t.Fatalf("model provider = %v, want ollama", ollamaModel["provider"])
}
if ollamaModel["model"] != "llama3.2" {
t.Fatalf("model model = %v, want llama3.2", ollamaModel["model"])
}
if ollamaModel["max_context_size"] != float64(65536) {
t.Fatalf("model max_context_size = %v, want 65536", ollamaModel["max_context_size"])
}
}
func TestBuildKimiInlineConfig_UsesConnectableHostForUnspecifiedBind(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://0.0.0.0:11434")
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
if err != nil {
t.Fatalf("buildKimiInlineConfig() error = %v", err)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
t.Fatalf("config is not valid JSON: %v", err)
}
providers, ok := parsed["providers"].(map[string]any)
if !ok {
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
}
ollamaProvider, ok := providers["ollama"].(map[string]any)
if !ok {
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
}
if got, _ := ollamaProvider["base_url"].(string); got != "http://127.0.0.1:11434/v1" {
t.Fatalf("provider base_url = %q, want %q", got, "http://127.0.0.1:11434/v1")
}
}
func TestResolveKimiMaxContextSize(t *testing.T) {
t.Run("uses cloud limit when known", func(t *testing.T) {
got := resolveKimiMaxContextSize("kimi-k2.5:cloud")
if got != 262_144 {
t.Fatalf("resolveKimiMaxContextSize() = %d, want 262144", got)
}
})
t.Run("uses model show context length for local models", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/show" {
http.NotFound(w, r)
return
}
fmt.Fprint(w, `{"model_info":{"llama.context_length":131072}}`)
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
got := resolveKimiMaxContextSize("llama3.2")
if got != 131_072 {
t.Fatalf("resolveKimiMaxContextSize() = %d, want 131072", got)
}
})
t.Run("falls back to default when show fails", func(t *testing.T) {
srv := httptest.NewServer(http.NotFoundHandler())
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
oldTimeout := kimiModelShowTimeout
kimiModelShowTimeout = 100 * 1000 * 1000 // 100ms
t.Cleanup(func() { kimiModelShowTimeout = oldTimeout })
got := resolveKimiMaxContextSize("llama3.2")
if got != kimiDefaultMaxContextSize {
t.Fatalf("resolveKimiMaxContextSize() = %d, want %d", got, kimiDefaultMaxContextSize)
}
})
}
func TestKimiRun_RejectsConflictingArgsBeforeInstall(t *testing.T) {
k := &Kimi{}
oldConfirm := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
t.Fatalf("did not expect install prompt, got %q", prompt)
return false, nil
}
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
err := k.Run("llama3.2", []string{"--model", "other"})
if err == nil || !strings.Contains(err.Error(), "--model") {
t.Fatalf("expected conflict error mentioning --model, got %v", err)
}
}
func TestKimiRun_PassesInlineConfigAndExtraArgs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binary")
}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
logPath := filepath.Join(tmpDir, "kimi-args.log")
script := fmt.Sprintf(`#!/bin/sh
for arg in "$@"; do
printf "%%s\n" "$arg" >> %q
done
exit 0
`, logPath)
if err := os.WriteFile(filepath.Join(tmpDir, "kimi"), []byte(script), 0o755); err != nil {
t.Fatalf("failed to write fake kimi: %v", err)
}
t.Setenv("PATH", tmpDir)
srv := httptest.NewServer(http.NotFoundHandler())
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
k := &Kimi{}
if err := k.Run("llama3.2", []string{"--quiet", "--print"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
data, err := os.ReadFile(logPath)
if err != nil {
t.Fatalf("failed to read args log: %v", err)
}
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
if len(lines) < 4 {
t.Fatalf("expected at least 4 args, got %v", lines)
}
if lines[0] != "--config" {
t.Fatalf("first arg = %q, want --config", lines[0])
}
var cfg map[string]any
if err := json.Unmarshal([]byte(lines[1]), &cfg); err != nil {
t.Fatalf("config arg is not valid JSON: %v", err)
}
providers := cfg["providers"].(map[string]any)
ollamaProvider := providers["ollama"].(map[string]any)
if ollamaProvider["type"] != "openai_legacy" {
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
}
if lines[2] != "--quiet" || lines[3] != "--print" {
t.Fatalf("extra args = %v, want [--quiet --print]", lines[2:])
}
}
func TestEnsureKimiInstalled(t *testing.T) {
oldGOOS := kimiGOOS
t.Cleanup(func() { kimiGOOS = oldGOOS })
withConfirm := func(t *testing.T, fn func(prompt string) (bool, error)) {
t.Helper()
oldConfirm := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return fn(prompt)
}
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
}
t.Run("already installed", func(t *testing.T) {
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
writeFakeBinary(t, tmpDir, "kimi")
kimiGOOS = runtime.GOOS
withConfirm(t, func(prompt string) (bool, error) {
t.Fatalf("did not expect prompt, got %q", prompt)
return false, nil
})
bin, err := ensureKimiInstalled()
if err != nil {
t.Fatalf("ensureKimiInstalled() error = %v", err)
}
assertKimiBinPath(t, bin)
})
t.Run("missing dependencies", func(t *testing.T) {
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
withConfirm(t, func(prompt string) (bool, error) {
t.Fatalf("did not expect prompt, got %q", prompt)
return false, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "required dependencies are missing") {
t.Fatalf("expected missing dependency error, got %v", err)
}
})
t.Run("missing and user declines install", func(t *testing.T) {
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
writeFakeBinary(t, tmpDir, "curl")
writeFakeBinary(t, tmpDir, "bash")
kimiGOOS = "linux"
withConfirm(t, func(prompt string) (bool, error) {
if !strings.Contains(prompt, "Kimi is not installed.") {
t.Fatalf("unexpected prompt: %q", prompt)
}
return false, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "installation cancelled") {
t.Fatalf("expected cancellation error, got %v", err)
}
})
t.Run("missing and user confirms install succeeds", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
writeFakeBinary(t, tmpDir, "curl")
installLog := filepath.Join(tmpDir, "bash.log")
kimiPath := filepath.Join(tmpDir, "kimi")
bashScript := fmt.Sprintf(`#!/bin/sh
echo "$@" >> %q
if [ "$1" = "-c" ]; then
/bin/cat > %q <<'EOS'
#!/bin/sh
exit 0
EOS
/bin/chmod +x %q
fi
exit 0
`, installLog, kimiPath, kimiPath)
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte(bashScript), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
bin, err := ensureKimiInstalled()
if err != nil {
t.Fatalf("ensureKimiInstalled() error = %v", err)
}
assertKimiBinPath(t, bin)
logData, err := os.ReadFile(installLog)
if err != nil {
t.Fatalf("failed to read install log: %v", err)
}
if !strings.Contains(string(logData), "https://code.kimi.com/install.sh") {
t.Fatalf("expected install.sh command in log, got:\n%s", string(logData))
}
})
t.Run("install succeeds and kimi is in home local bin without PATH update", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
homeDir := t.TempDir()
setTestHome(t, homeDir)
tmpBin := t.TempDir()
t.Setenv("PATH", tmpBin)
kimiGOOS = "linux"
writeFakeBinary(t, tmpBin, "curl")
installedKimi := filepath.Join(homeDir, ".local", "bin", "kimi")
bashScript := fmt.Sprintf(`#!/bin/sh
if [ "$1" = "-c" ]; then
/bin/mkdir -p %q
/bin/cat > %q <<'EOS'
#!/bin/sh
exit 0
EOS
/bin/chmod +x %q
fi
exit 0
`, filepath.Dir(installedKimi), installedKimi, installedKimi)
if err := os.WriteFile(filepath.Join(tmpBin, "bash"), []byte(bashScript), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
bin, err := ensureKimiInstalled()
if err != nil {
t.Fatalf("ensureKimiInstalled() error = %v", err)
}
if bin != installedKimi {
t.Fatalf("bin = %q, want %q", bin, installedKimi)
}
})
t.Run("install command fails", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
writeFakeBinary(t, tmpDir, "curl")
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "failed to install kimi") {
t.Fatalf("expected install failure error, got %v", err)
}
})
t.Run("install succeeds but binary missing on PATH", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
writeFakeBinary(t, tmpDir, "curl")
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "binary was not found on PATH") {
t.Fatalf("expected PATH guidance error, got %v", err)
}
})
}
func TestKimiInstallerCommand(t *testing.T) {
tests := []struct {
name string
goos string
wantBin string
wantParts []string
wantErr bool
}{
{
name: "linux",
goos: "linux",
wantBin: "bash",
wantParts: []string{"-c", "install.sh"},
},
{
name: "darwin",
goos: "darwin",
wantBin: "bash",
wantParts: []string{"-c", "install.sh"},
},
{
name: "windows",
goos: "windows",
wantBin: "powershell",
wantParts: []string{"-Command", "install.ps1"},
},
{
name: "unsupported",
goos: "freebsd",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bin, args, err := kimiInstallerCommand(tt.goos)
if tt.wantErr {
if err == nil {
t.Fatal("expected error")
}
return
}
if err != nil {
t.Fatalf("kimiInstallerCommand() error = %v", err)
}
if bin != tt.wantBin {
t.Fatalf("bin = %q, want %q", bin, tt.wantBin)
}
joined := strings.Join(args, " ")
for _, part := range tt.wantParts {
if !strings.Contains(joined, part) {
t.Fatalf("args %q missing %q", joined, part)
}
}
})
}
}

View File

@@ -141,6 +141,36 @@ type Editor interface {
Models() []string Models() []string
} }
// ManagedSingleModel is the narrow launch-owned config path for integrations
// like Hermes that have one primary model selected by launcher, need launcher
// to persist minimal config, and still keep their own model discovery and
// onboarding UX. This stays separate from Runner-only integrations and the
// multi-model Editor flow so Hermes-specific behavior stays scoped to one path.
type ManagedSingleModel interface {
Paths() []string
Configure(model string) error
CurrentModel() string
Onboard() error
}
// ManagedRuntimeRefresher lets managed integrations refresh any long-lived
// background runtime after launch rewrites their config.
type ManagedRuntimeRefresher interface {
RefreshRuntimeAfterConfigure() error
}
// ManagedOnboardingValidator lets managed integrations re-check saved
// onboarding state when launcher needs a stronger live readiness signal.
type ManagedOnboardingValidator interface {
OnboardingComplete() bool
}
// ManagedInteractiveOnboarding lets a managed integration declare whether its
// onboarding step really requires an interactive terminal. Hermes does not.
type ManagedInteractiveOnboarding interface {
RequiresInteractiveOnboarding() bool
}
type modelInfo struct { type modelInfo struct {
Name string Name string
Remote bool Remote bool
@@ -176,7 +206,10 @@ Supported integrations:
claude Claude Code claude Claude Code
cline Cline cline Cline
codex Codex codex Codex
copilot Copilot CLI (aliases: copilot-cli)
droid Droid droid Droid
hermes Hermes Agent
kimi Kimi Code CLI
opencode OpenCode opencode OpenCode
openclaw OpenClaw (aliases: clawdbot, moltbot) openclaw OpenClaw (aliases: clawdbot, moltbot)
pi Pi pi Pi
@@ -186,6 +219,7 @@ Examples:
ollama launch ollama launch
ollama launch claude ollama launch claude
ollama launch claude --model <model> ollama launch claude --model <model>
ollama launch hermes
ollama launch droid --config (does not auto-launch) ollama launch droid --config (does not auto-launch)
ollama launch codex -- -p myprofile (pass extra args to integration) ollama launch codex -- -p myprofile (pass extra args to integration)
ollama launch codex -- --sandbox workspace-write`, ollama launch codex -- --sandbox workspace-write`,
@@ -308,36 +342,54 @@ func LaunchIntegration(ctx context.Context, req IntegrationLaunchRequest) error
if err != nil { if err != nil {
return err return err
} }
policy := launchIntegrationPolicy(req)
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
}
launchClient, saved, err := prepareIntegrationLaunch(name, policy)
if err != nil {
return err
}
if managed, ok := runner.(ManagedSingleModel); ok {
if err := EnsureIntegrationInstalled(name, runner); err != nil {
return err
}
return launchClient.launchManagedSingleIntegration(ctx, name, runner, managed, saved, req)
}
if !req.ConfigureOnly { if !req.ConfigureOnly {
if err := EnsureIntegrationInstalled(name, runner); err != nil { if err := EnsureIntegrationInstalled(name, runner); err != nil {
return err return err
} }
} }
var policy LaunchPolicy
// TUI does not set a policy, whereas ollama launch <app> does as it can have flags which change the behavior
if req.Policy == nil {
policy = defaultLaunchPolicy(isInteractiveSession(), false)
} else {
policy = *req.Policy
}
launchClient, err := newLauncherClient(policy)
if err != nil {
return err
}
saved, _ := loadStoredIntegrationConfig(name)
// In headless --yes mode we cannot prompt, so require an explicit --model.
if policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() && req.ModelOverride == "" {
return fmt.Errorf("headless --yes launch for %s requires --model <model>", name)
}
if editor, ok := runner.(Editor); ok { if editor, ok := runner.(Editor); ok {
return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req) return launchClient.launchEditorIntegration(ctx, name, runner, editor, saved, req)
} }
return launchClient.launchSingleIntegration(ctx, name, runner, saved, req) return launchClient.launchSingleIntegration(ctx, name, runner, saved, req)
} }
func launchIntegrationPolicy(req IntegrationLaunchRequest) LaunchPolicy {
// TUI does not set a policy, whereas ollama launch <app> does as it can
// have flags which change the behavior.
if req.Policy != nil {
return *req.Policy
}
return defaultLaunchPolicy(isInteractiveSession(), false)
}
func prepareIntegrationLaunch(name string, policy LaunchPolicy) (*launcherClient, *config.IntegrationConfig, error) {
launchClient, err := newLauncherClient(policy)
if err != nil {
return nil, nil, err
}
saved, _ := loadStoredIntegrationConfig(name)
return launchClient, saved, nil
}
func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) { func (c *launcherClient) buildLauncherState(ctx context.Context) (*LauncherState, error) {
_ = c.loadModelInventoryOnce(ctx) _ = c.loadModelInventoryOnce(ctx)
@@ -368,9 +420,18 @@ func (c *launcherClient) buildLauncherIntegrationState(ctx context.Context, info
if err != nil { if err != nil {
return LauncherIntegrationState{}, err return LauncherIntegrationState{}, err
} }
currentModel, usable, err := c.launcherModelState(ctx, info.Name, integration.editor) var currentModel string
if err != nil { var usable bool
return LauncherIntegrationState{}, err if managed, ok := integration.spec.Runner.(ManagedSingleModel); ok {
currentModel, usable, err = c.launcherManagedModelState(ctx, info.Name, managed)
if err != nil {
return LauncherIntegrationState{}, err
}
} else {
currentModel, usable, err = c.launcherModelState(ctx, info.Name, integration.editor)
if err != nil {
return LauncherIntegrationState{}, err
}
} }
return LauncherIntegrationState{ return LauncherIntegrationState{
@@ -408,6 +469,28 @@ func (c *launcherClient) launcherModelState(ctx context.Context, name string, is
return model, usableErr == nil && usable, nil return model, usableErr == nil && usable, nil
} }
func (c *launcherClient) launcherManagedModelState(ctx context.Context, name string, managed ManagedSingleModel) (string, bool, error) {
current := managed.CurrentModel()
if current == "" {
cfg, loadErr := loadStoredIntegrationConfig(name)
if loadErr == nil {
current = primaryModelFromConfig(cfg)
}
if current != "" {
return current, false, nil
}
}
if current == "" {
return "", false, nil
}
usable, err := c.savedModelUsable(ctx, current)
if err != nil {
return current, false, err
}
return current, usable, nil
}
func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) { func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelRequest) (string, error) {
current := config.LastModel() current := config.LastModel()
if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() { if !req.ForcePicker && current != "" && c.policy.Confirm == LaunchConfirmAutoApprove && !isInteractiveSession() {
@@ -444,35 +527,15 @@ func (c *launcherClient) resolveRunModel(ctx context.Context, req RunModelReques
} }
func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error { func (c *launcherClient) launchSingleIntegration(ctx context.Context, name string, runner Runner, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
current := primaryModelFromConfig(saved) target, _, err := c.resolveSingleIntegrationTarget(ctx, runner, primaryModelFromConfig(saved), req)
target := req.ModelOverride if err != nil {
needsConfigure := req.ForceConfigure
if target == "" {
target = current
usable, err := c.savedModelUsable(ctx, target)
if err != nil {
return err
}
if !usable {
needsConfigure = true
}
}
if needsConfigure {
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
if err != nil {
return err
}
target = selected
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
return err return err
} }
if target == "" { if target == "" {
return nil return nil
} }
current := primaryModelFromConfig(saved)
if target != current { if target != current {
if err := config.SaveIntegration(name, []string{target}); err != nil { if err := config.SaveIntegration(name, []string{target}); err != nil {
return fmt.Errorf("failed to save: %w", err) return fmt.Errorf("failed to save: %w", err)
@@ -510,6 +573,102 @@ func (c *launcherClient) launchEditorIntegration(ctx context.Context, name strin
return launchAfterConfiguration(name, runner, models[0], req) return launchAfterConfiguration(name, runner, models[0], req)
} }
func (c *launcherClient) launchManagedSingleIntegration(ctx context.Context, name string, runner Runner, managed ManagedSingleModel, saved *config.IntegrationConfig, req IntegrationLaunchRequest) error {
current := managed.CurrentModel()
selectionCurrent := current
if selectionCurrent == "" {
selectionCurrent = primaryModelFromConfig(saved)
}
target, needsConfigure, err := c.resolveSingleIntegrationTarget(ctx, runner, selectionCurrent, req)
if err != nil {
return err
}
if target == "" {
return nil
}
if (current == "" || needsConfigure || req.ModelOverride != "" || target != current) && !savedMatchesModels(saved, []string{target}) {
if err := prepareManagedSingleIntegration(name, runner, managed, target); err != nil {
return err
}
if refresher, ok := managed.(ManagedRuntimeRefresher); ok {
if err := refresher.RefreshRuntimeAfterConfigure(); err != nil {
return err
}
}
}
if !managedIntegrationOnboarded(saved, managed) {
if !isInteractiveSession() && managedRequiresInteractiveOnboarding(managed) {
return fmt.Errorf("%s still needs interactive gateway setup; run 'ollama launch %s' in a terminal to finish onboarding", runner, name)
}
if err := managed.Onboard(); err != nil {
return err
}
}
if req.ConfigureOnly {
return nil
}
return runIntegration(runner, target, req.ExtraArgs)
}
func (c *launcherClient) resolveSingleIntegrationTarget(ctx context.Context, runner Runner, current string, req IntegrationLaunchRequest) (string, bool, error) {
target := req.ModelOverride
needsConfigure := req.ForceConfigure
if target == "" {
target = current
usable, err := c.savedModelUsable(ctx, target)
if err != nil {
return "", false, err
}
if !usable {
needsConfigure = true
}
}
if needsConfigure {
selected, err := c.selectSingleModelWithSelector(ctx, fmt.Sprintf("Select model for %s:", runner), target, DefaultSingleSelector)
if err != nil {
return "", false, err
}
target = selected
} else if err := c.ensureModelsReady(ctx, []string{target}); err != nil {
return "", false, err
}
return target, needsConfigure, nil
}
func savedIntegrationOnboarded(saved *config.IntegrationConfig) bool {
return saved != nil && saved.Onboarded
}
func managedIntegrationOnboarded(saved *config.IntegrationConfig, managed ManagedSingleModel) bool {
if !savedIntegrationOnboarded(saved) {
return false
}
validator, ok := managed.(ManagedOnboardingValidator)
if !ok {
return true
}
return validator.OnboardingComplete()
}
// Most managed integrations treat onboarding as an interactive terminal step.
// Hermes opts out because its launch-owned onboarding is just bookkeeping, so
// headless launches should not be blocked once config is already prepared.
func managedRequiresInteractiveOnboarding(managed ManagedSingleModel) bool {
onboarding, ok := managed.(ManagedInteractiveOnboarding)
if !ok {
return true
}
return onboarding.RequiresInteractiveOnboarding()
}
func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) { func (c *launcherClient) selectSingleModelWithSelector(ctx context.Context, title, current string, selector SingleSelector) (string, error) {
if selector == nil { if selector == nil {
return "", fmt.Errorf("no selector configured") return "", fmt.Errorf("no selector configured")

View File

@@ -49,6 +49,55 @@ func (r *launcherSingleRunner) Run(model string, args []string) error {
func (r *launcherSingleRunner) String() string { return "StubSingle" } func (r *launcherSingleRunner) String() string { return "StubSingle" }
type launcherManagedRunner struct {
paths []string
currentModel string
configured []string
ranModel string
onboarded bool
onboardCalls int
onboardingComplete bool
refreshCalls int
refreshErr error
}
func (r *launcherManagedRunner) Run(model string, args []string) error {
r.ranModel = model
return nil
}
func (r *launcherManagedRunner) String() string { return "StubManaged" }
func (r *launcherManagedRunner) Paths() []string { return r.paths }
func (r *launcherManagedRunner) Configure(model string) error {
r.configured = append(r.configured, model)
r.currentModel = model
return nil
}
func (r *launcherManagedRunner) CurrentModel() string { return r.currentModel }
func (r *launcherManagedRunner) Onboard() error {
r.onboardCalls++
r.onboarded = true
r.onboardingComplete = true
return nil
}
func (r *launcherManagedRunner) OnboardingComplete() bool { return r.onboardingComplete }
func (r *launcherManagedRunner) RefreshRuntimeAfterConfigure() error {
r.refreshCalls++
return r.refreshErr
}
type launcherHeadlessManagedRunner struct {
launcherManagedRunner
}
func (r *launcherHeadlessManagedRunner) RequiresInteractiveOnboarding() bool { return false }
func setLaunchTestHome(t *testing.T, dir string) { func setLaunchTestHome(t *testing.T, dir string) {
t.Helper() t.Helper()
t.Setenv("HOME", dir) t.Setenv("HOME", dir)
@@ -141,6 +190,451 @@ func TestDefaultLaunchPolicy(t *testing.T) {
} }
} }
func TestBuildLauncherState_ManagedSingleIntegrationUsesCurrentModel(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherManagedRunner{currentModel: "gemma4"}
withIntegrationOverride(t, "pi", runner)
state, err := BuildLauncherState(context.Background())
if err != nil {
t.Fatalf("BuildLauncherState returned error: %v", err)
}
if state.Integrations["pi"].CurrentModel != "gemma4" {
t.Fatalf("expected managed current model from integration config, got %q", state.Integrations["pi"].CurrentModel)
}
if !state.Integrations["pi"].ModelUsable {
t.Fatal("expected managed current model to be usable")
}
}
func TestBuildLauncherState_ManagedSingleIntegrationShowsSavedModelWhenLiveConfigMissing(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
if err := config.SaveIntegration("pi", []string{"gemma4"}); err != nil {
t.Fatalf("failed to save managed integration config: %v", err)
}
runner := &launcherManagedRunner{}
withIntegrationOverride(t, "pi", runner)
state, err := BuildLauncherState(context.Background())
if err != nil {
t.Fatalf("BuildLauncherState returned error: %v", err)
}
if state.Integrations["pi"].CurrentModel != "gemma4" {
t.Fatalf("expected saved model to remain visible, got %q", state.Integrations["pi"].CurrentModel)
}
if state.Integrations["pi"].ModelUsable {
t.Fatal("expected missing live config to mark managed model unusable")
}
}
func TestLaunchIntegration_ManagedSingleIntegrationConfiguresOnboardsAndRuns(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherManagedRunner{
paths: nil,
}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
return "gemma4", nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return true, nil
}
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
t.Fatalf("configured models mismatch: %s", diff)
}
if runner.refreshCalls != 1 {
t.Fatalf("expected runtime refresh once after configure, got %d", runner.refreshCalls)
}
if runner.onboardCalls != 1 {
t.Fatalf("expected onboarding to run once, got %d", runner.onboardCalls)
}
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
}
saved, err := config.LoadIntegration("stubmanaged")
if err != nil {
t.Fatalf("failed to reload managed integration config: %v", err)
}
if diff := compareStrings(saved.Models, []string{"gemma4"}); diff != "" {
t.Fatalf("saved models mismatch: %s", diff)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationReOnboardsWhenSavedFlagIsStale(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherManagedRunner{
currentModel: "gemma4",
onboardingComplete: false,
}
withIntegrationOverride(t, "stubmanaged", runner)
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
t.Fatalf("failed to save managed integration config: %v", err)
}
if err := config.MarkIntegrationOnboarded("stubmanaged"); err != nil {
t.Fatalf("failed to mark managed integration onboarded: %v", err)
}
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if runner.onboardCalls != 1 {
t.Fatalf("expected stale onboarded flag to trigger onboarding, got %d calls", runner.onboardCalls)
}
if runner.refreshCalls != 0 {
t.Fatalf("expected no runtime refresh when config is unchanged, got %d", runner.refreshCalls)
}
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to run saved model after onboarding repair, got %q", runner.ranModel)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationConfigOnlySkipsFinalRun(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherManagedRunner{
paths: nil,
}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return true, nil
}
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "stubmanaged",
ModelOverride: "gemma4",
ConfigureOnly: true,
}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if runner.ranModel != "" {
t.Fatalf("expected configure-only flow to skip final launch, got %q", runner.ranModel)
}
if runner.refreshCalls != 1 {
t.Fatalf("expected configure-only flow to refresh runtime once, got %d", runner.refreshCalls)
}
if runner.onboardCalls != 1 {
t.Fatalf("expected configure-only flow to onboard once, got %d", runner.onboardCalls)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationSkipsRewriteWhenSavedMatches(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
t.Fatalf("failed to save managed integration config: %v", err)
}
runner := &launcherManagedRunner{}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
t.Fatal("selector should not be called when saved model matches target")
return "", nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
t.Fatal("confirm prompt should not run when saved model matches target")
return false, nil
}
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if len(runner.configured) != 0 {
t.Fatalf("expected Configure to be skipped when saved matches, got %v", runner.configured)
}
if runner.refreshCalls != 0 {
t.Fatalf("expected no runtime refresh when config is unchanged, got %d", runner.refreshCalls)
}
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to run saved model, got %q", runner.ranModel)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenSavedDiffers(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
if err := config.SaveIntegration("stubmanaged", []string{"old-model"}); err != nil {
t.Fatalf("failed to save managed integration config: %v", err)
}
runner := &launcherManagedRunner{}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
t.Fatal("selector should not be called when model override is provided")
return "", nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return true, nil
}
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "stubmanaged",
ModelOverride: "gemma4",
}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
t.Fatalf("expected Configure to run when saved differs from target: %s", diff)
}
if runner.refreshCalls != 1 {
t.Fatalf("expected runtime refresh once after configure, got %d", runner.refreshCalls)
}
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationStopsWhenRuntimeRefreshFails(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherManagedRunner{
refreshErr: fmt.Errorf("boom"),
}
withIntegrationOverride(t, "stubmanaged", runner)
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return true, nil
}
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "stubmanaged",
ModelOverride: "gemma4",
})
if err == nil || !strings.Contains(err.Error(), "boom") {
t.Fatalf("expected runtime refresh error, got %v", err)
}
if runner.ranModel != "" {
t.Fatalf("expected final launch to stop on runtime refresh failure, got %q", runner.ranModel)
}
if runner.refreshCalls != 1 {
t.Fatalf("expected one runtime refresh attempt, got %d", runner.refreshCalls)
}
if runner.onboardCalls != 0 {
t.Fatalf("expected onboarding to stop after refresh failure, got %d", runner.onboardCalls)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationHeadlessNeedsInteractiveOnboarding(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, false)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherManagedRunner{
paths: nil,
}
withIntegrationOverride(t, "stubmanaged", runner)
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "stubmanaged",
ModelOverride: "gemma4",
Policy: &LaunchPolicy{Confirm: LaunchConfirmAutoApprove, MissingModel: LaunchMissingModelAutoPull},
})
if err == nil {
t.Fatal("expected headless onboarding requirement to fail")
}
if !strings.Contains(err.Error(), "interactive gateway setup") {
t.Fatalf("expected interactive onboarding guidance, got %v", err)
}
if runner.ranModel != "" {
t.Fatalf("expected no final launch when onboarding is still required, got %q", runner.ranModel)
}
if runner.onboardCalls != 0 {
t.Fatalf("expected no onboarding attempts in headless mode, got %d", runner.onboardCalls)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationHeadlessAllowsNonInteractiveOnboarding(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, false)
withLauncherHooks(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
runner := &launcherHeadlessManagedRunner{}
withIntegrationOverride(t, "stubmanaged", runner)
err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "stubmanaged",
ModelOverride: "gemma4",
Policy: &LaunchPolicy{Confirm: LaunchConfirmAutoApprove, MissingModel: LaunchMissingModelAutoPull},
})
if err != nil {
t.Fatalf("expected non-interactive onboarding to succeed headlessly, got %v", err)
}
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
t.Fatalf("configured models mismatch: %s", diff)
}
if runner.onboardCalls != 1 {
t.Fatalf("expected onboarding to run once, got %d", runner.onboardCalls)
}
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
}
}
func TestBuildLauncherState_InstalledAndCloudDisabled(t *testing.T) { func TestBuildLauncherState_InstalledAndCloudDisabled(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir) setLaunchTestHome(t, tmpDir)

View File

@@ -230,7 +230,7 @@ func pullMissingModel(ctx context.Context, client *api.Client, model string) err
// prepareEditorIntegration persists models and applies editor-managed config files. // prepareEditorIntegration persists models and applies editor-managed config files.
func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error { func prepareEditorIntegration(name string, runner Runner, editor Editor, models []string) error {
if ok, err := confirmEditorEdit(runner, editor); err != nil { if ok, err := confirmConfigEdit(runner, editor.Paths()); err != nil {
return err return err
} else if !ok { } else if !ok {
return errCancelled return errCancelled
@@ -244,8 +244,22 @@ func prepareEditorIntegration(name string, runner Runner, editor Editor, models
return nil return nil
} }
func confirmEditorEdit(runner Runner, editor Editor) (bool, error) { func prepareManagedSingleIntegration(name string, runner Runner, managed ManagedSingleModel, model string) error {
paths := editor.Paths() if ok, err := confirmConfigEdit(runner, managed.Paths()); err != nil {
return err
} else if !ok {
return errCancelled
}
if err := managed.Configure(model); err != nil {
return fmt.Errorf("setup failed: %w", err)
}
if err := config.SaveIntegration(name, []string{model}); err != nil {
return fmt.Errorf("failed to save: %w", err)
}
return nil
}
func confirmConfigEdit(runner Runner, paths []string) (bool, error) {
if len(paths) == 0 { if len(paths) == 0 {
return true, nil return true, nil
} }
@@ -345,8 +359,6 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
recRank[rec.Name] = i + 1 recRank[rec.Name] = i + 1
} }
onlyLocal := hasLocalModel && !hasCloudModel
if hasLocalModel || hasCloudModel { if hasLocalModel || hasCloudModel {
slices.SortStableFunc(items, func(a, b ModelItem) int { slices.SortStableFunc(items, func(a, b ModelItem) int {
ac, bc := checked[a.Name], checked[b.Name] ac, bc := checked[a.Name], checked[b.Name]
@@ -368,12 +380,6 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
} }
if aRec && bRec { if aRec && bRec {
if aCloud != bCloud { if aCloud != bCloud {
if onlyLocal {
if aCloud {
return 1
}
return -1
}
if aCloud { if aCloud {
return -1 return -1
} }

View File

@@ -33,7 +33,7 @@ type IntegrationInfo struct {
Description string Description string
} }
var launcherIntegrationOrder = []string{"opencode", "droid", "pi"} var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi"}
var integrationSpecs = []*IntegrationSpec{ var integrationSpecs = []*IntegrationSpec{
{ {
@@ -74,6 +74,36 @@ var integrationSpecs = []*IntegrationSpec{
Command: []string{"npm", "install", "-g", "@openai/codex"}, Command: []string{"npm", "install", "-g", "@openai/codex"},
}, },
}, },
{
Name: "kimi",
Runner: &Kimi{},
Description: "Moonshot's coding agent for terminal and IDEs",
Hidden: true,
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("kimi")
return err == nil
},
EnsureInstalled: func() error {
_, err := ensureKimiInstalled()
return err
},
URL: "https://moonshotai.github.io/kimi-cli/en/guides/getting-started.html",
},
},
{
Name: "copilot",
Runner: &Copilot{},
Aliases: []string{"copilot-cli"},
Description: "GitHub's AI coding agent for the terminal",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := (&Copilot{}).findPath()
return err == nil
},
URL: "https://github.com/features/copilot/cli/",
},
},
{ {
Name: "droid", Name: "droid",
Runner: &Droid{}, Runner: &Droid{},
@@ -136,6 +166,20 @@ var integrationSpecs = []*IntegrationSpec{
Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"}, Command: []string{"npm", "install", "-g", "@mariozechner/pi-coding-agent@latest"},
}, },
}, },
{
Name: "hermes",
Runner: &Hermes{},
Description: "Self-improving AI agent built by Nous Research",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
return (&Hermes{}).installed()
},
EnsureInstalled: func() error {
return (&Hermes{}).ensureInstalled()
},
URL: "https://hermes-agent.nousresearch.com/docs/getting-started/installation/",
},
},
{ {
Name: "vscode", Name: "vscode",
Runner: &VSCode{}, Runner: &VSCode{},
@@ -255,10 +299,10 @@ func ListVisibleIntegrationSpecs() []IntegrationSpec {
return aRank - bRank return aRank - bRank
} }
if aRank > 0 { if aRank > 0 {
return 1 return -1
} }
if bRank > 0 { if bRank > 0 {
return -1 return 1
} }
return strings.Compare(a.Name, b.Name) return strings.Compare(a.Name, b.Name)
}) })

View File

@@ -45,6 +45,14 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
return filepath.Join(home, ".pi", "agent", "models.json") return filepath.Join(home, ".pi", "agent", "models.json")
}, },
}, },
{
name: "kimi",
binary: "kimi",
runner: &Kimi{},
checkPath: func(home string) string {
return filepath.Join(home, ".kimi", "config.toml")
},
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -57,6 +65,10 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
if tt.name == "pi" { if tt.name == "pi" {
writeFakeBinary(t, binDir, "npm") writeFakeBinary(t, binDir, "npm")
} }
if tt.name == "kimi" {
writeFakeBinary(t, binDir, "curl")
writeFakeBinary(t, binDir, "bash")
}
t.Setenv("PATH", binDir) t.Setenv("PATH", binDir)
configPath := tt.checkPath(home) configPath := tt.checkPath(home)

View File

@@ -45,21 +45,12 @@ type menuItem struct {
isOthers bool isOthers bool
} }
var mainMenuItems = []menuItem{ const pinnedIntegrationCount = 3
{
title: "Chat with a model", var runModelMenuItem = menuItem{
description: "Start an interactive chat with a model", title: "Chat with a model",
isRunModel: true, description: "Start an interactive chat with a model",
}, isRunModel: true,
{
integration: "openclaw",
},
{
integration: "claude",
},
{
integration: "opencode",
},
} }
var othersMenuItem = menuItem{ var othersMenuItem = menuItem{
@@ -102,20 +93,14 @@ func shouldExpandOthers(state *launch.LauncherState) bool {
} }
func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem { func buildMenuItems(state *launch.LauncherState, showOthers bool) []menuItem {
items := make([]menuItem, 0, len(mainMenuItems)+1) items := []menuItem{runModelMenuItem}
for _, item := range mainMenuItems { items = append(items, pinnedIntegrationItems(state)...)
if item.integration == "" {
items = append(items, item)
continue
}
if integrationState, ok := state.Integrations[item.integration]; ok {
items = append(items, integrationMenuItem(integrationState))
}
}
if showOthers { otherItems := otherIntegrationItems(state)
items = append(items, otherIntegrationItems(state)...) switch {
} else { case showOthers:
items = append(items, otherItems...)
case len(otherItems) > 0:
items = append(items, othersMenuItem) items = append(items, othersMenuItem)
} }
@@ -135,17 +120,28 @@ func integrationMenuItem(state launch.LauncherIntegrationState) menuItem {
} }
func otherIntegrationItems(state *launch.LauncherState) []menuItem { func otherIntegrationItems(state *launch.LauncherState) []menuItem {
pinned := map[string]bool{ ordered := orderedIntegrationItems(state)
"openclaw": true, if len(ordered) <= pinnedIntegrationCount {
"claude": true, return nil
"opencode": true, }
return ordered[pinnedIntegrationCount:]
}
func pinnedIntegrationItems(state *launch.LauncherState) []menuItem {
ordered := orderedIntegrationItems(state)
if len(ordered) <= pinnedIntegrationCount {
return ordered
}
return ordered[:pinnedIntegrationCount]
}
func orderedIntegrationItems(state *launch.LauncherState) []menuItem {
if state == nil {
return nil
} }
var items []menuItem items := make([]menuItem, 0, len(state.Integrations))
for _, info := range launch.ListIntegrationInfos() { for _, info := range launch.ListIntegrationInfos() {
if pinned[info.Name] {
continue
}
integrationState, ok := state.Integrations[info.Name] integrationState, ok := state.Integrations[info.Name]
if !ok { if !ok {
continue continue
@@ -155,6 +151,10 @@ func otherIntegrationItems(state *launch.LauncherState) []menuItem {
return items return items
} }
func primaryMenuItemCount(state *launch.LauncherState) int {
return 1 + len(pinnedIntegrationItems(state))
}
func initialCursor(state *launch.LauncherState, items []menuItem) int { func initialCursor(state *launch.LauncherState, items []menuItem) int {
if state == nil || state.LastSelection == "" { if state == nil || state.LastSelection == "" {
return 0 return 0
@@ -190,7 +190,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.cursor > 0 { if m.cursor > 0 {
m.cursor-- m.cursor--
} }
if m.showOthers && m.cursor < len(mainMenuItems) { if m.showOthers && m.cursor < primaryMenuItemCount(m.state) {
m.showOthers = false m.showOthers = false
m.items = buildMenuItems(m.state, false) m.items = buildMenuItems(m.state, false)
m.cursor = min(m.cursor, len(m.items)-1) m.cursor = min(m.cursor, len(m.items)-1)

View File

@@ -5,6 +5,7 @@ import (
"testing" "testing"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/cmd/launch" "github.com/ollama/ollama/cmd/launch"
) )
@@ -43,6 +44,13 @@ func launcherTestState() *launch.LauncherState {
Selectable: true, Selectable: true,
Changeable: true, Changeable: true,
}, },
"hermes": {
Name: "hermes",
DisplayName: "Hermes Agent",
Description: "Self-improving AI agent built by Nous Research",
Selectable: true,
Changeable: true,
},
"droid": { "droid": {
Name: "droid", Name: "droid",
DisplayName: "Droid", DisplayName: "Droid",
@@ -70,8 +78,28 @@ func findMenuCursorByIntegration(items []menuItem, name string) int {
return -1 return -1
} }
func integrationSequence(items []menuItem) []string {
sequence := make([]string, 0, len(items))
for _, item := range items {
switch {
case item.isRunModel:
sequence = append(sequence, "run")
case item.isOthers:
sequence = append(sequence, "more")
case item.integration != "":
sequence = append(sequence, item.integration)
}
}
return sequence
}
func compareStrings(got, want []string) string {
return cmp.Diff(want, got)
}
func TestMenuRendersPinnedItemsAndMore(t *testing.T) { func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
view := newModel(launcherTestState()).View() menu := newModel(launcherTestState())
view := menu.View()
for _, want := range []string{"Chat with a model", "Launch OpenClaw", "Launch Claude Code", "Launch OpenCode", "More..."} { for _, want := range []string{"Chat with a model", "Launch OpenClaw", "Launch Claude Code", "Launch OpenCode", "More..."} {
if !strings.Contains(view, want) { if !strings.Contains(view, want) {
t.Fatalf("expected menu view to contain %q\n%s", want, view) t.Fatalf("expected menu view to contain %q\n%s", want, view)
@@ -80,23 +108,31 @@ func TestMenuRendersPinnedItemsAndMore(t *testing.T) {
if strings.Contains(view, "Launch Codex") { if strings.Contains(view, "Launch Codex") {
t.Fatalf("expected Codex to be under More, not pinned\n%s", view) t.Fatalf("expected Codex to be under More, not pinned\n%s", view)
} }
wantOrder := []string{"run", "openclaw", "claude", "opencode", "more"}
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
t.Fatalf("unexpected pinned order: %s", diff)
}
} }
func TestMenuExpandsOthersFromLastSelection(t *testing.T) { func TestMenuExpandsOthersFromLastSelection(t *testing.T) {
state := launcherTestState() state := launcherTestState()
state.LastSelection = "pi" state.LastSelection = "codex"
menu := newModel(state) menu := newModel(state)
if !menu.showOthers { if !menu.showOthers {
t.Fatal("expected others section to expand when last selection is in the overflow list") t.Fatal("expected others section to expand when last selection is in the overflow list")
} }
view := menu.View() view := menu.View()
if !strings.Contains(view, "Launch Pi") { if !strings.Contains(view, "Launch Codex") {
t.Fatalf("expected expanded view to contain overflow integration\n%s", view) t.Fatalf("expected expanded view to contain overflow integration\n%s", view)
} }
if strings.Contains(view, "More...") { if strings.Contains(view, "More...") {
t.Fatalf("expected expanded view to replace More... item\n%s", view) t.Fatalf("expected expanded view to replace More... item\n%s", view)
} }
wantOrder := []string{"run", "openclaw", "claude", "opencode", "hermes", "codex", "droid", "pi"}
if diff := compareStrings(integrationSequence(menu.items), wantOrder); diff != "" {
t.Fatalf("unexpected expanded order: %s", diff)
}
} }
func TestMenuEnterOnRunSelectsRun(t *testing.T) { func TestMenuEnterOnRunSelectsRun(t *testing.T) {

View File

@@ -120,6 +120,7 @@
"pages": [ "pages": [
"/integrations/claude-code", "/integrations/claude-code",
"/integrations/codex", "/integrations/codex",
"/integrations/copilot-cli",
"/integrations/opencode", "/integrations/opencode",
"/integrations/droid", "/integrations/droid",
"/integrations/goose", "/integrations/goose",

BIN
docs/images/hermes.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

View File

@@ -0,0 +1,93 @@
---
title: Copilot CLI
---
GitHub Copilot CLI is GitHub's AI coding agent for the terminal. It can understand your codebase, make edits, run commands, and help you build software faster.
Open models can be used with Copilot CLI through Ollama, enabling you to use models such as `qwen3.5`, `glm-5.1:cloud`, `kimi-k2.5:cloud`.
## Install
Install [Copilot CLI](https://github.com/features/copilot/cli/):
<CodeGroup>
```shell macOS / Linux (Homebrew)
brew install copilot-cli
```
```shell npm (all platforms)
npm install -g @github/copilot
```
```shell macOS / Linux (script)
curl -fsSL https://gh.io/copilot-install | bash
```
```powershell Windows (WinGet)
winget install GitHub.Copilot
```
</CodeGroup>
## Usage with Ollama
### Quick setup
```shell
ollama launch copilot
```
### Run directly with a model
```shell
ollama launch copilot --model kimi-k2.5:cloud
```
## Recommended Models
- `kimi-k2.5:cloud`
- `glm-5:cloud`
- `minimax-m2.7:cloud`
- `qwen3.5:cloud`
- `glm-4.7-flash`
- `qwen3.5`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
## Non-interactive (headless) mode
Run Copilot CLI without interaction for use in Docker, CI/CD, or scripts:
```shell
ollama launch copilot --model kimi-k2.5:cloud --yes -- -p "how does this repository work?"
```
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified. Arguments after `--` are passed directly to Copilot CLI.
## Manual setup
Copilot CLI connects to Ollama using the OpenAI-compatible API via environment variables.
1. Set the environment variables:
```shell
export COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1
export COPILOT_PROVIDER_API_KEY=
export COPILOT_PROVIDER_WIRE_API=responses
export COPILOT_MODEL=qwen3.5
```
1. Run Copilot CLI:
```shell
copilot
```
Or run with environment variables inline:
```shell
COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1 COPILOT_PROVIDER_API_KEY= COPILOT_PROVIDER_WIRE_API=responses COPILOT_MODEL=glm-5:cloud copilot
```
**Note:** Copilot requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.

View File

@@ -2,29 +2,66 @@
title: Hermes Agent title: Hermes Agent
--- ---
Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation, cross-session memory, and connects messaging platforms (Telegram, Discord, Slack, WhatsApp, Signal, Email) to models through a unified gateway. Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation, cross-session memory, and 70+ skills that it ships with by default.
![Hermes Agent with Ollama](/images/hermes.png)
## Quick start ## Quick start
### Pull a model
Before running the setup wizard, make sure you have a model available. Hermes will auto-detect models downloaded through Ollama.
```bash ```bash
ollama pull kimi-k2.5:cloud ollama launch hermes
``` ```
See [Recommended models](#recommended-models) for more options. Ollama handles everything automatically:
### Install 1. **Install** — If Hermes isn't installed, Ollama prompts to install it via the Nous Research install script
2. **Model** — Pick a model from the selector (local or cloud)
3. **Onboarding** — Ollama configures the Ollama provider, points Hermes at `http://127.0.0.1:11434/v1`, and sets your model as the primary
4. **Gateway** — Optionally connects a messaging platform (Telegram, Discord, Slack, WhatsApp, Signal, Email) and launches the Hermes chat
<Note>Hermes on Windows requires WSL2. Install it with `wsl --install` and re-run from inside the WSL shell.</Note>
## Recommended models
**Cloud models**:
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
- `glm-5.1:cloud` — Reasoning and code generation
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
**Local models:**
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
- `qwen3.6` — Reasoning, coding, and visual understanding locally (~24 GB VRAM)
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
## Connect messaging apps
Link Telegram, Discord, Slack, WhatsApp, Signal, or Email to chat with your models from anywhere:
```bash
hermes gateway setup
```
## Reconfigure
Re-run the full setup wizard at any time:
```bash
hermes setup
```
## Manual setup
If you'd rather drive Hermes's own wizard instead of `ollama launch hermes`, install it directly:
```bash ```bash
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
``` ```
### Set up Hermes launches the setup wizard automatically. Choose **Quick setup**:
After installation, Hermes launches the setup wizard automatically. Choose **Quick setup**:
``` ```
How would you like to set up Hermes? How would you like to set up Hermes?
@@ -80,32 +117,3 @@ Connect a messaging platform? (Telegram, Discord, etc.)
Launch hermes chat now? [Y/n]: Y Launch hermes chat now? [Y/n]: Y
``` ```
## Recommended models
**Cloud models**:
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
- `glm-5.1:cloud` — Reasoning and code generation
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
**Local models:**
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
- `qwen3.5` — Reasoning, coding, and visual understanding locally (~11 GB VRAM)
More models at [ollama.com/search](https://ollama.com/models).
## Configure later
Re-run the setup wizard at any time:
```bash
hermes setup
```
To configure just messaging:
```bash
hermes setup gateway
```

View File

@@ -10,6 +10,7 @@ Coding assistants that can read, modify, and execute code in your projects.
- [Claude Code](/integrations/claude-code) - [Claude Code](/integrations/claude-code)
- [Codex](/integrations/codex) - [Codex](/integrations/codex)
- [Copilot CLI](/integrations/copilot-cli)
- [OpenCode](/integrations/opencode) - [OpenCode](/integrations/opencode)
- [Droid](/integrations/droid) - [Droid](/integrations/droid)
- [Goose](/integrations/goose) - [Goose](/integrations/goose)

2
go.mod
View File

@@ -106,5 +106,5 @@ require (
golang.org/x/term v0.36.0 golang.org/x/term v0.36.0
golang.org/x/text v0.30.0 golang.org/x/text v0.30.0
google.golang.org/protobuf v1.34.1 google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1
) )

View File

@@ -406,10 +406,6 @@ func TestAPIShowModel(t *testing.T) {
} }
func TestAPIGenerateLogprobs(t *testing.T) { func TestAPIGenerateLogprobs(t *testing.T) {
if testModel != "" {
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
t.Skip("logprobs not supported by all runners")
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()
@@ -523,10 +519,6 @@ func TestAPIGenerateLogprobs(t *testing.T) {
} }
func TestAPIChatLogprobs(t *testing.T) { func TestAPIChatLogprobs(t *testing.T) {
if testModel != "" {
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
t.Skip("logprobs not supported by all runners")
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"time"
) )
type Layer struct { type Layer struct {
@@ -60,6 +61,9 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
return Layer{}, err return Layer{}, err
} }
} }
if err := touchLayer(blob); err != nil {
return Layer{}, err
}
return Layer{ return Layer{
MediaType: mediatype, MediaType: mediatype,
@@ -83,6 +87,9 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
if err != nil { if err != nil {
return Layer{}, err return Layer{}, err
} }
if err := touchLayer(blob); err != nil {
return Layer{}, err
}
return Layer{ return Layer{
MediaType: mediatype, MediaType: mediatype,
@@ -93,6 +100,11 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
}, nil }, nil
} }
func touchLayer(path string) error {
now := time.Now()
return os.Chtimes(path, now, now)
}
func (l *Layer) Open() (io.ReadSeekCloser, error) { func (l *Layer) Open() (io.ReadSeekCloser, error) {
if l.Digest == "" { if l.Digest == "" {
return nil, errors.New("opening layer with empty digest") return nil, errors.New("opening layer with empty digest")

View File

@@ -1,18 +1,23 @@
package manifest package manifest
import ( import (
"bytes"
"crypto/sha256" "crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"strings"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
var blobFilenamePattern = regexp.MustCompile(`^sha256-[0-9a-fA-F]{64}$`)
type Manifest struct { type Manifest struct {
SchemaVersion int `json:"schemaVersion"` SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"` MediaType string `json:"mediaType"`
@@ -22,6 +27,7 @@ type Manifest struct {
filepath string filepath string
fi os.FileInfo fi os.FileInfo
digest string digest string
name model.Name
} }
func (m *Manifest) Size() (size int64) { func (m *Manifest) Size() (size int64) {
@@ -36,6 +42,14 @@ func (m *Manifest) Digest() string {
return m.digest return m.digest
} }
func (m *Manifest) BlobDigest() string {
if m.digest == "" {
return ""
}
return "sha256:" + m.digest
}
func (m *Manifest) FileInfo() os.FileInfo { func (m *Manifest) FileInfo() os.FileInfo {
return m.fi return m.fi
} }
@@ -59,16 +73,7 @@ func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
} }
func (m *Manifest) Remove() error { func (m *Manifest) Remove() error {
if err := os.Remove(m.filepath); err != nil { return removeNamedManifestPaths(m.name)
return err
}
manifests, err := Path()
if err != nil {
return err
}
return PruneDirectory(manifests)
} }
func (m *Manifest) RemoveLayers() error { func (m *Manifest) RemoveLayers() error {
@@ -80,6 +85,9 @@ func (m *Manifest) RemoveLayers() error {
// Build set of digests still in use by other manifests // Build set of digests still in use by other manifests
inUse := make(map[string]struct{}) inUse := make(map[string]struct{})
for _, other := range ms { for _, other := range ms {
if other.BlobDigest() != "" {
inUse[other.BlobDigest()] = struct{}{}
}
for _, layer := range append(other.Layers, other.Config) { for _, layer := range append(other.Layers, other.Config) {
if layer.Digest != "" { if layer.Digest != "" {
inUse[layer.Digest] = struct{}{} inUse[layer.Digest] = struct{}{}
@@ -87,20 +95,27 @@ func (m *Manifest) RemoveLayers() error {
} }
} }
// Remove layers not used by any other manifest digests := make([]string, 0, len(m.Layers)+2)
for _, layer := range append(m.Layers, m.Config) { digests = append(digests, m.BlobDigest())
if layer.Digest == "" { for _, layer := range m.Layers {
digests = append(digests, layer.Digest)
}
digests = append(digests, m.Config.Digest)
// Remove manifest and layer blobs not used by any other manifest
for _, digest := range digests {
if digest == "" {
continue continue
} }
if _, used := inUse[layer.Digest]; used { if _, used := inUse[digest]; used {
continue continue
} }
blob, err := BlobsPath(layer.Digest) blob, err := BlobsPath(digest)
if err != nil { if err != nil {
return err return err
} }
if err := os.Remove(blob); os.IsNotExist(err) { if err := os.Remove(blob); os.IsNotExist(err) {
slog.Debug("layer does not exist", "digest", layer.Digest) slog.Debug("blob does not exist", "digest", digest)
} else if err != nil { } else if err != nil {
return err return err
} }
@@ -114,15 +129,36 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, model.Unqualified(n) return nil, model.Unqualified(n)
} }
manifests, err := Path() p, root, err := resolveManifestPath(n)
if err != nil { if err != nil {
return nil, err return nil, err
} }
p := filepath.Join(manifests, n.Filepath()) return parseManifestFile(normalizeLogicalName(n), p, root)
}
func ReadManifestData(n model.Name) ([]byte, error) {
if !n.IsFullyQualified() {
return nil, model.Unqualified(n)
}
p, root, err := resolveManifestPath(n)
if err != nil {
return nil, err
}
f, _, err := OpenVerifiedManifest(p, root)
if err != nil {
return nil, err
}
defer f.Close()
return io.ReadAll(f)
}
func parseManifestFile(name model.Name, path, root string) (*Manifest, error) {
var m Manifest var m Manifest
f, err := os.Open(p) f, digest, err := OpenVerifiedManifest(path, root)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -133,35 +169,19 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, err return nil, err
} }
sha256sum := sha256.New() if err := json.NewDecoder(f).Decode(&m); err != nil {
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
return nil, err return nil, err
} }
m.filepath = p m.filepath = path
m.fi = fi m.fi = fi
m.digest = hex.EncodeToString(sha256sum.Sum(nil)) m.digest = digest
m.name = name
return &m, nil return &m, nil
} }
func WriteManifest(name model.Name, config Layer, layers []Layer) error { func WriteManifest(name model.Name, config Layer, layers []Layer) error {
manifests, err := Path()
if err != nil {
return err
}
p := filepath.Join(manifests, name.Filepath())
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
return err
}
f, err := os.Create(p)
if err != nil {
return err
}
defer f.Close()
m := Manifest{ m := Manifest{
SchemaVersion: 2, SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json", MediaType: "application/vnd.docker.distribution.manifest.v2+json",
@@ -169,33 +189,371 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
Layers: layers, Layers: layers,
} }
return json.NewEncoder(f).Encode(m) var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(m); err != nil {
return err
}
return WriteManifestData(name, b.Bytes())
} }
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) { // WriteManifestData stores raw manifest bytes as a content-addressed blob and
// updates the v2 named manifest path to reference that blob. Any legacy named
// manifest for the same model is removed after the v2 write succeeds.
func WriteManifestData(name model.Name, data []byte) error {
if !name.IsFullyQualified() {
return model.Unqualified(name)
}
digest, err := writeManifestBlob(data)
if err != nil {
return err
}
if err := LinkManifest(name, digest); err != nil {
return err
}
return removeLegacyManifestPaths(name)
}
// LinkManifest updates the v2 named manifest path to reference an existing
// manifest blob. It prefers symlinks, then hardlinks, then a byte-for-byte copy
// for filesystems that do not support links.
func LinkManifest(name model.Name, digest string) error {
if !name.IsFullyQualified() {
return model.Unqualified(name)
}
manifestPath, err := V2PathForName(name)
if err != nil {
return err
}
blobPath, err := BlobsPath(digest)
if err != nil {
return err
}
if _, err := os.Stat(blobPath); err != nil {
return err
}
if err := checkBlobDigest(blobPath, digest); err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
return err
}
if err := os.Remove(manifestPath); err != nil && !os.IsNotExist(err) {
return err
}
if rel, err := filepath.Rel(filepath.Dir(manifestPath), blobPath); err == nil {
if err := os.Symlink(rel, manifestPath); err == nil {
return nil
}
}
if err := os.Link(blobPath, manifestPath); err == nil {
return nil
}
return copyManifestFile(blobPath, manifestPath)
}
func writeManifestBlob(data []byte) (string, error) {
sum := sha256.Sum256(data)
digest := fmt.Sprintf("sha256:%x", sum)
blobPath, err := BlobsPath(digest)
if err != nil {
return "", err
}
if existing, err := os.ReadFile(blobPath); err == nil && bytes.Equal(existing, data) {
return digest, nil
}
blobs, err := BlobsPath("")
if err != nil {
return "", err
}
temp, err := os.CreateTemp(blobs, "sha256-")
if err != nil {
return "", err
}
tempName := temp.Name()
defer os.Remove(tempName)
if _, err := temp.Write(data); err != nil {
temp.Close()
return "", err
}
if err := temp.Close(); err != nil {
return "", err
}
if err := os.Chmod(tempName, 0o644); err != nil {
return "", err
}
if err := os.Rename(tempName, blobPath); err != nil {
if err := os.Remove(blobPath); err != nil && !os.IsNotExist(err) {
return "", err
}
if err := os.Rename(tempName, blobPath); err != nil {
return "", err
}
}
return digest, nil
}
func copyManifestFile(src, dst string) error {
in, err := os.Open(src)
if err != nil {
return err
}
defer in.Close()
temp, err := os.CreateTemp(filepath.Dir(dst), ".manifest-*")
if err != nil {
return err
}
tempName := temp.Name()
defer os.Remove(tempName)
if _, err := io.Copy(temp, in); err != nil {
temp.Close()
return err
}
if err := temp.Close(); err != nil {
return err
}
if err := os.Chmod(tempName, 0o644); err != nil {
return err
}
return os.Rename(tempName, dst)
}
// OpenVerifiedManifest opens a named manifest path rooted under root. Symlinks must resolve to a
// blob whose basename is sha256-<hex> and whose bytes hash to that digest.
// Regular-file manifests are treated as legacy/copy fallback manifests and are
// opened without mutating the local store.
func OpenVerifiedManifest(path, root string) (*os.File, string, error) {
resolvedRoot, err := filepath.EvalSymlinks(root)
if err != nil {
return nil, "", err
}
info, err := os.Lstat(path)
if err != nil {
return nil, "", err
}
target, err := evalAbs(path)
if err != nil {
return nil, "", err
}
if info.Mode()&os.ModeSymlink != 0 {
base := filepath.Base(target)
if !blobFilenamePattern.MatchString(base) {
return nil, "", fmt.Errorf("manifest symlink target %q is not a sha256 blob", target)
}
digest := strings.ToLower(strings.TrimPrefix(base, "sha256-"))
blobPath, err := BlobsPath("sha256:" + digest)
if err != nil {
return nil, "", err
}
if !sameFile(target, blobPath) {
return nil, "", fmt.Errorf("manifest symlink target %q does not match blob %q", target, blobPath)
}
f, err := os.Open(path)
if err != nil {
return nil, "", err
}
if err := checkBlobDigestReader(f, "sha256:"+digest); err != nil {
f.Close()
return nil, "", err
}
if _, err := f.Seek(0, io.SeekStart); err != nil {
f.Close()
return nil, "", err
}
return f, digest, nil
}
if !pathWithin(target, resolvedRoot) {
return nil, "", fmt.Errorf("manifest path %q resolves outside manifest directory", path)
}
f, err := os.Open(path)
if err != nil {
return nil, "", err
}
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
f.Close()
return nil, "", err
}
if _, err := f.Seek(0, io.SeekStart); err != nil {
f.Close()
return nil, "", err
}
digest := fmt.Sprintf("%x", h.Sum(nil))
return f, digest, nil
}
// MigrateManifestLinks moves legacy named manifests into manifests-v2. This is currently unwired but
// will be added in the future.
func MigrateManifestLinks() (int, error) {
manifests, err := Path() manifests, err := Path()
if err != nil { if err != nil {
return nil, err return 0, err
} }
// TODO(mxyng): use something less brittle // TODO(mxyng): use something less brittle
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*")) matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
if err != nil { if err != nil {
return nil, err return 0, err
} }
ms := make(map[model.Name]*Manifest) var migrated int
for _, match := range matches { for _, match := range matches {
fi, err := os.Stat(match) fi, err := os.Stat(match)
if err != nil { if err != nil {
return nil, err return migrated, err
}
if fi.IsDir() {
continue
}
rel, err := filepath.Rel(manifests, match)
if err != nil {
return migrated, fmt.Errorf("%s %w", match, err)
}
n := model.ParseNameFromFilepath(rel)
if !n.IsFullyQualified() {
slog.Warn("bad manifest name", "path", rel)
continue
}
data, err := readManifestPath(match, manifests)
if err != nil {
return migrated, err
}
if err := WriteManifestData(normalizeLogicalName(n), data); err != nil {
return migrated, err
}
migrated++
}
return migrated, nil
}
func readManifestPath(path, root string) ([]byte, error) {
f, _, err := OpenVerifiedManifest(path, root)
if err != nil {
return nil, err
}
defer f.Close()
return io.ReadAll(f)
}
func pathWithin(path, root string) bool {
rel, err := filepath.Rel(root, path)
return err == nil && rel != "." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) && rel != ".."
}
func evalAbs(path string) (string, error) {
abs, err := filepath.Abs(path)
if err != nil {
return "", err
}
return filepath.EvalSymlinks(abs)
}
func sameFile(a, b string) bool {
ai, err := os.Stat(a)
if err != nil {
return false
}
bi, err := os.Stat(b)
if err != nil {
return false
}
return os.SameFile(ai, bi)
}
func checkBlobDigest(path, digest string) error {
f, err := os.Open(path)
if err != nil {
return err
}
defer f.Close()
return checkBlobDigestReader(f, digest)
}
func checkBlobDigestReader(r io.Reader, digest string) error {
h := sha256.New()
if _, err := io.Copy(h, r); err != nil {
return err
}
got := fmt.Sprintf("sha256:%x", h.Sum(nil))
if got != strings.ToLower(strings.Replace(digest, "-", ":", 1)) {
return errors.New("digest mismatch")
}
return nil
}
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
ms := make(map[model.Name]*Manifest)
manifestsV2, err := V2Path()
if err != nil {
return nil, err
}
if err := collectManifests(ms, manifestsV2, continueOnError); err != nil {
return nil, err
}
manifests, err := Path()
if err != nil {
return nil, err
}
if err := collectManifests(ms, manifests, continueOnError); err != nil {
return nil, err
}
return ms, nil
}
func collectManifests(ms map[model.Name]*Manifest, root string, continueOnError bool) error {
// TODO(mxyng): use something less brittle
matches, err := filepath.Glob(filepath.Join(root, "*", "*", "*", "*"))
if err != nil {
return err
}
for _, match := range matches {
fi, err := os.Lstat(match)
if err != nil {
return err
} }
if !fi.IsDir() { if !fi.IsDir() {
rel, err := filepath.Rel(manifests, match) rel, err := filepath.Rel(root, match)
if err != nil { if err != nil {
if !continueOnError { if !continueOnError {
return nil, fmt.Errorf("%s %w", match, err) return fmt.Errorf("%s %w", match, err)
} }
slog.Warn("bad filepath", "path", match, "error", err) slog.Warn("bad filepath", "path", match, "error", err)
continue continue
@@ -204,16 +562,21 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
n := model.ParseNameFromFilepath(rel) n := model.ParseNameFromFilepath(rel)
if !n.IsValid() { if !n.IsValid() {
if !continueOnError { if !continueOnError {
return nil, fmt.Errorf("%s %w", rel, err) return fmt.Errorf("invalid manifest name: %s", rel)
} }
slog.Warn("bad manifest name", "path", rel) slog.Warn("bad manifest name", "path", rel)
continue continue
} }
m, err := ParseNamedManifest(n) n = normalizeLogicalName(n)
if _, ok := ms[n]; ok {
continue
}
m, err := parseManifestFile(n, match, root)
if err != nil { if err != nil {
if !continueOnError { if !continueOnError {
return nil, fmt.Errorf("%s %w", n, err) return fmt.Errorf("%s %w", n, err)
} }
slog.Warn("bad manifest", "name", n, "error", err) slog.Warn("bad manifest", "name", n, "error", err)
continue continue
@@ -223,5 +586,5 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
} }
} }
return ms, nil return nil
} }

View File

@@ -1,19 +1,23 @@
package manifest package manifest
import ( import (
"bytes"
"crypto/sha256"
"encoding/json" "encoding/json"
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
"strings"
"testing" "testing"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
func createManifest(t *testing.T, path, name string) { func createManifestAtRoot(t *testing.T, path, root, name string) {
t.Helper() t.Helper()
p := filepath.Join(path, "manifests", name) p := filepath.Join(path, root, name)
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil { if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -29,6 +33,309 @@ func createManifest(t *testing.T, path, name string) {
} }
} }
func createManifest(t *testing.T, path, name string) {
t.Helper()
createManifestAtRoot(t, path, "manifests", name)
}
func TestWriteManifestStoresManifestAsBlob(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
name := model.ParseName("example")
config := Layer{
MediaType: "application/vnd.docker.container.image.v1+json",
Digest: "sha256:" + strings.Repeat("a", 64),
Size: 12,
}
if err := WriteManifest(name, config, nil); err != nil {
t.Fatal(err)
}
manifestPath, err := V2PathForName(name)
if err != nil {
t.Fatal(err)
}
manifestData, err := os.ReadFile(manifestPath)
if err != nil {
t.Fatal(err)
}
sum := sha256.Sum256(manifestData)
digest := fmt.Sprintf("sha256:%x", sum)
blobPath, err := BlobsPath(digest)
if err != nil {
t.Fatal(err)
}
blobData, err := os.ReadFile(blobPath)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(blobData, manifestData) {
t.Fatal("manifest path and blob content differ")
}
m, err := ParseNamedManifest(name)
if err != nil {
t.Fatal(err)
}
if got := m.Digest(); got != fmt.Sprintf("%x", sum) {
t.Fatalf("digest = %q, want %x", got, sum)
}
if got := m.BlobDigest(); got != digest {
t.Fatalf("blob digest = %q, want %q", got, digest)
}
}
func TestParseNamedManifestLeavesLegacyManifestInPlace(t *testing.T) {
models := t.TempDir()
t.Setenv("OLLAMA_MODELS", models)
name := model.ParseName("example")
createManifest(t, models, name.Filepath())
manifestPath, err := PathForName(name)
if err != nil {
t.Fatal(err)
}
if _, err := ParseNamedManifest(name); err != nil {
t.Fatal(err)
}
fi, err := os.Lstat(manifestPath)
if err != nil {
t.Fatal(err)
}
if fi.Mode()&os.ModeSymlink != 0 {
t.Fatal("legacy manifest was converted to a symlink while reading")
}
data, err := os.ReadFile(manifestPath)
if err != nil {
t.Fatal(err)
}
sum := sha256.Sum256(data)
blobPath, err := BlobsPath(fmt.Sprintf("sha256:%x", sum))
if err != nil {
t.Fatal(err)
}
if _, err := os.Stat(blobPath); !os.IsNotExist(err) {
t.Fatalf("legacy manifest read created blob: %v", err)
}
}
func TestMigrateManifestLinks(t *testing.T) {
models := t.TempDir()
t.Setenv("OLLAMA_MODELS", models)
name := model.ParseName("example")
createManifest(t, models, name.Filepath())
migrated, err := MigrateManifestLinks()
if err != nil {
t.Fatal(err)
}
if migrated != 1 {
t.Fatalf("migrated = %d, want 1", migrated)
}
manifestPath, err := V2PathForName(name)
if err != nil {
t.Fatal(err)
}
manifestData, err := os.ReadFile(manifestPath)
if err != nil {
t.Fatal(err)
}
sum := sha256.Sum256(manifestData)
blobPath, err := BlobsPath(fmt.Sprintf("sha256:%x", sum))
if err != nil {
t.Fatal(err)
}
blobData, err := os.ReadFile(blobPath)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(blobData, manifestData) {
t.Fatal("migrated manifest path and blob content differ")
}
legacyPath, err := PathForName(name)
if err != nil {
t.Fatal(err)
}
if _, err := os.Stat(legacyPath); !os.IsNotExist(err) {
t.Fatalf("legacy manifest still exists: %v", err)
}
migrated, err = MigrateManifestLinks()
if err != nil {
t.Fatal(err)
}
if migrated != 0 {
t.Fatalf("migrated on second run = %d, want 0", migrated)
}
if _, err := MigrateManifestLinks(); err != nil {
t.Fatal(err)
}
manifestDataAfter, err := os.ReadFile(manifestPath)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(manifestDataAfter, manifestData) {
t.Fatal("second migration changed manifest content")
}
}
func TestRemoveLayersRemovesUnreferencedManifestBlob(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
name := model.ParseName("example")
if err := WriteManifest(name, Layer{}, nil); err != nil {
t.Fatal(err)
}
m, err := ParseNamedManifest(name)
if err != nil {
t.Fatal(err)
}
blobPath, err := BlobsPath(m.BlobDigest())
if err != nil {
t.Fatal(err)
}
if _, err := os.Stat(blobPath); err != nil {
t.Fatal(err)
}
if err := m.Remove(); err != nil {
t.Fatal(err)
}
if err := m.RemoveLayers(); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(blobPath); !os.IsNotExist(err) {
t.Fatalf("manifest blob still exists: %v", err)
}
}
func TestParseNamedManifestRejectsUnsafeSymlinks(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
name := model.ParseName("example")
manifestPath, err := PathForName(name)
if err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
t.Fatal(err)
}
t.Run("non blob basename", func(t *testing.T) {
target := filepath.Join(t.TempDir(), "not-a-blob")
if err := os.WriteFile(target, []byte(`{"schemaVersion":2}`), 0o644); err != nil {
t.Fatal(err)
}
if err := os.Remove(manifestPath); err != nil && !os.IsNotExist(err) {
t.Fatal(err)
}
if err := os.Symlink(target, manifestPath); err != nil {
t.Skipf("symlink unavailable: %v", err)
}
_, err := ParseNamedManifest(name)
if err == nil || !strings.Contains(err.Error(), "not a sha256 blob") {
t.Fatalf("err = %v, want not a sha256 blob", err)
}
})
t.Run("blob basename outside blob store", func(t *testing.T) {
data := []byte(`{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json"}`)
sum := sha256.Sum256(data)
target := filepath.Join(t.TempDir(), fmt.Sprintf("sha256-%x", sum))
if err := os.WriteFile(target, data, 0o644); err != nil {
t.Fatal(err)
}
if err := os.Remove(manifestPath); err != nil && !os.IsNotExist(err) {
t.Fatal(err)
}
if err := os.Symlink(target, manifestPath); err != nil {
t.Skipf("symlink unavailable: %v", err)
}
_, err := ParseNamedManifest(name)
if err == nil || !strings.Contains(err.Error(), "does not match blob") {
t.Fatalf("err = %v, want does not match blob", err)
}
})
}
func TestParseNamedManifestPrefersV2(t *testing.T) {
models := t.TempDir()
t.Setenv("OLLAMA_MODELS", models)
name := model.ParseName("example")
legacyPath, err := PathForName(name)
if err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Dir(legacyPath), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(legacyPath, []byte(`{"schemaVersion":2,"mediaType":"legacy"}`), 0o644); err != nil {
t.Fatal(err)
}
if err := WriteManifestData(name, []byte(`{"schemaVersion":2,"mediaType":"v2"}`)); err != nil {
t.Fatal(err)
}
m, err := ParseNamedManifest(name)
if err != nil {
t.Fatal(err)
}
if m.MediaType != "v2" {
t.Fatalf("media type = %q, want %q", m.MediaType, "v2")
}
}
func TestManifestsV2ShadowsLegacy(t *testing.T) {
models := t.TempDir()
t.Setenv("OLLAMA_MODELS", models)
name := model.ParseName("example")
createManifest(t, models, name.Filepath())
if err := WriteManifestData(name, []byte(`{"schemaVersion":2,"mediaType":"v2"}`)); err != nil {
t.Fatal(err)
}
ms, err := Manifests(true)
if err != nil {
t.Fatal(err)
}
if len(ms) != 1 {
t.Fatalf("manifest count = %d, want 1", len(ms))
}
var m *Manifest
for gotName, gotManifest := range ms {
if gotName.EqualFold(model.ParseName("example")) {
m = gotManifest
break
}
}
if m == nil {
t.Fatalf("missing v2 manifest for %s", name)
}
if m.MediaType != "v2" {
t.Fatalf("media type = %q, want %q", m.MediaType, "v2")
}
}
func TestManifests(t *testing.T) { func TestManifests(t *testing.T) {
cases := map[string]struct { cases := map[string]struct {
ps []string ps []string

View File

@@ -14,8 +14,23 @@ import (
var ErrInvalidDigestFormat = errors.New("invalid digest format") var ErrInvalidDigestFormat = errors.New("invalid digest format")
const (
legacyDirName = "manifests"
v2DirName = "manifests-v2"
defaultPublicHost = "registry.ollama.ai"
v2CanonicalHost = "ollama.com"
)
func Path() (string, error) { func Path() (string, error) {
path := filepath.Join(envconfig.Models(), "manifests") return manifestPath(legacyDirName)
}
func V2Path() (string, error) {
return manifestPath(v2DirName)
}
func manifestPath(dir string) (string, error) {
path := filepath.Join(envconfig.Models(), dir)
if err := os.MkdirAll(path, 0o755); err != nil { if err := os.MkdirAll(path, 0o755); err != nil {
return "", fmt.Errorf("%w: ensure path elements are traversable", err) return "", fmt.Errorf("%w: ensure path elements are traversable", err)
} }
@@ -25,6 +40,10 @@ func Path() (string, error) {
// PathForName returns the path to the manifest file for a specific model name. // PathForName returns the path to the manifest file for a specific model name.
func PathForName(n model.Name) (string, error) { func PathForName(n model.Name) (string, error) {
return LegacyPathForName(n)
}
func LegacyPathForName(n model.Name) (string, error) {
if !n.IsValid() { if !n.IsValid() {
return "", os.ErrNotExist return "", os.ErrNotExist
} }
@@ -37,6 +56,162 @@ func PathForName(n model.Name) (string, error) {
return filepath.Join(manifests, n.Filepath()), nil return filepath.Join(manifests, n.Filepath()), nil
} }
func V2PathForName(n model.Name) (string, error) {
if !n.IsValid() {
return "", os.ErrNotExist
}
manifests, err := V2Path()
if err != nil {
return "", err
}
return filepath.Join(manifests, canonicalV2Name(n).Filepath()), nil
}
func ResolvePathForName(n model.Name) (string, error) {
path, _, err := resolveManifestPath(n)
return path, err
}
func resolveManifestPath(n model.Name) (string, string, error) {
if !n.IsValid() {
return "", "", os.ErrNotExist
}
v2Path, err := V2PathForName(n)
if err != nil {
return "", "", err
}
if _, err := os.Lstat(v2Path); err == nil {
root, err := V2Path()
return v2Path, root, err
} else if !os.IsNotExist(err) {
return "", "", err
}
legacyRoot, err := Path()
if err != nil {
return "", "", err
}
for _, legacyName := range legacyNameCandidates(n) {
legacyPath := filepath.Join(legacyRoot, legacyName.Filepath())
if _, err := os.Lstat(legacyPath); err == nil {
return legacyPath, legacyRoot, nil
} else if !os.IsNotExist(err) {
return "", "", err
}
}
return "", "", os.ErrNotExist
}
func removeNamedManifestPaths(n model.Name) error {
candidates := legacyNameCandidates(n)
paths := make([]string, 0, 1+len(candidates))
v2Path, err := V2PathForName(n)
if err != nil {
return err
}
paths = append(paths, v2Path)
for _, legacyName := range candidates {
legacyPath, err := LegacyPathForName(legacyName)
if err != nil {
return err
}
paths = append(paths, legacyPath)
}
for _, path := range paths {
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return err
}
}
return pruneManifestRoots()
}
func removeLegacyManifestPaths(n model.Name) error {
for _, legacyName := range legacyNameCandidates(n) {
legacyPath, err := LegacyPathForName(legacyName)
if err != nil {
return err
}
if err := os.Remove(legacyPath); err != nil && !os.IsNotExist(err) {
return err
}
}
legacyRoot, err := Path()
if err != nil {
return err
}
if err := PruneDirectory(legacyRoot); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
func pruneManifestRoots() error {
roots := []func() (string, error){Path, V2Path}
for _, rootFn := range roots {
root, err := rootFn()
if err != nil {
return err
}
if err := PruneDirectory(root); err != nil && !os.IsNotExist(err) {
return err
}
}
return nil
}
// normalizeLogicalName maps any public host to the legacy default
// so that map keys use a single identity regardless of on-disk host.
func normalizeLogicalName(n model.Name) model.Name {
if isDefaultPublicHost(n.Host) {
n.Host = defaultPublicHost
}
return n
}
// canonicalV2Name maps any public host to the v2 canonical host
// for use in manifests-v2/ on-disk paths.
func canonicalV2Name(n model.Name) model.Name {
if isDefaultPublicHost(n.Host) {
n.Host = v2CanonicalHost
}
return n
}
func legacyNameCandidates(n model.Name) []model.Name {
names := []model.Name{n}
if !isDefaultPublicHost(n.Host) {
return names
}
alt := n
switch {
case strings.EqualFold(n.Host, defaultPublicHost):
alt.Host = v2CanonicalHost
default:
alt.Host = defaultPublicHost
}
return append(names, alt)
}
func isDefaultPublicHost(host string) bool {
return strings.EqualFold(host, defaultPublicHost) || strings.EqualFold(host, v2CanonicalHost)
}
func BlobsPath(digest string) (string, error) { func BlobsPath(digest string) (string, error) {
// only accept actual sha256 digests // only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$" pattern := "^sha256[:-][0-9a-fA-F]{64}$"

View File

@@ -12,7 +12,8 @@ import (
// <|turn>/<turn|> markers, <|"|> string delimiters, and <|tool>/ // <|turn>/<turn|> markers, <|"|> string delimiters, and <|tool>/
// <|tool_call>/<|tool_response> tags for function calling. // <|tool_call>/<|tool_response> tags for function calling.
type Gemma4Renderer struct { type Gemma4Renderer struct {
useImgTags bool useImgTags bool
emptyBlockOnNothink bool
} }
const ( const (
@@ -124,6 +125,9 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
// Generation prompt. // Generation prompt.
if prevMessageType != "tool_response" && prevMessageType != "tool_call" { if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
sb.WriteString("<|turn>model\n") sb.WriteString("<|turn>model\n")
if r.emptyBlockOnNothink && !hasThink {
sb.WriteString("<|channel>thought\n<channel|>")
}
} }
return sb.String(), nil return sb.String(), nil

View File

@@ -3,9 +3,9 @@ package renderers
// TestGemma4RendererMatchesReference verifies our renderer matches the checked-in // TestGemma4RendererMatchesReference verifies our renderer matches the checked-in
// Gemma 4 reference template. // Gemma 4 reference template.
// //
// Current upstream Gemma 4 chat templates differ by model size, so the checked-in // Current upstream Gemma 4 chat templates differ by model size. The checked-in
// reference intentionally uses the shared baseline without an empty generation-time // reference cases below use the small (e2b/e4b-style) baseline, with large
// thought channel until renderer selection is split by size. // (26b/31b-style) checks covered separately in this file.
// //
// To regenerate expected values, save the E2B template to // To regenerate expected values, save the E2B template to
// gemma4_e2b_chat_template.jinja2 and run: // gemma4_e2b_chat_template.jinja2 and run:
@@ -1474,6 +1474,47 @@ Hi<turn|>
} }
} }
func TestGemma4RendererVariantsMatchExpectedGenerationPrompt(t *testing.T) {
messages := []api.Message{{Role: "user", Content: "Hello"}}
tests := []struct {
name string
rendererName string
expected string
}{
{
name: "legacy_alias",
rendererName: "gemma4",
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
},
{
name: "small",
rendererName: "gemma4-small",
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
},
{
name: "large",
rendererName: "gemma4-large",
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := RenderWithRenderer(tt.rendererName, messages, nil, nil)
assert.NoError(t, err)
assert.Equal(t, tt.expected, got)
})
}
}
func TestGemma4LargeRendererOmitsEmptyThoughtBlockWhenThinkingEnabled(t *testing.T) {
got, err := RenderWithRenderer("gemma4-large", []api.Message{{Role: "user", Content: "Hello"}}, nil, thinkTrue())
assert.NoError(t, err)
assert.Equal(t, "<bos><|turn>system\n<|think|>\n<turn|>\n<|turn>user\nHello<turn|>\n<|turn>model\n", got)
assert.NotContains(t, got, "<|channel>thought\n<channel|>")
}
func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) { func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
if os.Getenv("VERIFY_JINJA2") == "" { if os.Getenv("VERIFY_JINJA2") == "" {
t.Skip("set VERIFY_JINJA2=1 to run expanded Jinja2 parity checks") t.Skip("set VERIFY_JINJA2=1 to run expanded Jinja2 parity checks")
@@ -1616,15 +1657,35 @@ func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
}, },
} }
for _, tt := range tests { variants := []struct {
t.Run(tt.name, func(t *testing.T) { name string
renderer := &Gemma4Renderer{useImgTags: RenderImgTags} renderer *Gemma4Renderer
got, err := renderer.Render(tt.messages, tt.tools, tt.think) templateRel string
assert.NoError(t, err) }{
{
name: "small",
renderer: &Gemma4Renderer{useImgTags: RenderImgTags},
templateRel: gemma4E2BTemplate,
},
{
name: "large",
renderer: &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true},
templateRel: gemma431BTemplate,
},
}
jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, tt.think) for _, variant := range variants {
assert.Equal(t, jinja2Output, got, t.Run(variant.name, func(t *testing.T) {
"renderer output doesn't match Jinja2 template output") for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := variant.renderer.Render(tt.messages, tt.tools, tt.think)
assert.NoError(t, err)
jinja2Output := renderWithJinja2Template(t, variant.templateRel, tt.messages, tt.tools, tt.think)
assert.Equal(t, jinja2Output, got,
"renderer output doesn't match Jinja2 template output")
})
}
}) })
} }
} }

View File

@@ -81,8 +81,10 @@ func rendererForName(name string) Renderer {
return renderer return renderer
case "nemotron-3-nano": case "nemotron-3-nano":
return &Nemotron3NanoRenderer{} return &Nemotron3NanoRenderer{}
case "gemma4": case "gemma4", "gemma4-small":
return &Gemma4Renderer{useImgTags: RenderImgTags} return &Gemma4Renderer{useImgTags: RenderImgTags}
case "gemma4-large":
return &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true}
case "functiongemma": case "functiongemma":
return &FunctionGemmaRenderer{} return &FunctionGemmaRenderer{}
case "glm-4.7": case "glm-4.7":

View File

@@ -523,7 +523,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
arch := layer.GGML.KV().Architecture() arch := layer.GGML.KV().Architecture()
switch arch { switch arch {
case "gemma4": case "gemma4":
config.Renderer = cmp.Or(config.Renderer, "gemma4") config.Renderer = cmp.Or(config.Renderer, gemma4RendererLegacy)
config.Parser = cmp.Or(config.Parser, "gemma4") config.Parser = cmp.Or(config.Parser, "gemma4")
if _, ok := r.Parameters["stop"]; !ok { if _, ok := r.Parameters["stop"]; !ok {
if r.Parameters == nil { if r.Parameters == nil {

78
server/gemma4_test.go Normal file
View File

@@ -0,0 +1,78 @@
package server
import "testing"
func TestResolveGemma4Renderer(t *testing.T) {
tests := []struct {
name string
model *Model
want string
}{
{
name: "nil model falls back to legacy alias",
model: nil,
want: gemma4RendererLegacy,
},
{
name: "explicit small passes through",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererSmall),
},
want: gemma4RendererSmall,
},
{
name: "explicit large passes through",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererLarge),
},
want: gemma4RendererLarge,
},
{
name: "legacy e4b tag resolves small",
model: &Model{
Name: "gemma4:e4b",
ShortName: "gemma4:e4b",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererSmall,
},
{
name: "legacy 31b tag resolves large",
model: &Model{
Name: "gemma4:31b-cloud",
ShortName: "gemma4:31b-cloud",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererLarge,
},
{
name: "legacy model type resolves small",
model: &Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "4.3B"),
},
want: gemma4RendererSmall,
},
{
name: "legacy model type resolves large",
model: &Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
},
want: gemma4RendererLarge,
},
{
name: "legacy unknown defaults small",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererSmall,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveGemma4Renderer(tt.model); got != tt.want {
t.Fatalf("resolveGemma4Renderer() = %q, want %q", got, tt.want)
}
})
}
}

View File

@@ -19,6 +19,7 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
@@ -33,6 +34,10 @@ import (
"github.com/ollama/ollama/x/imagegen/transfer" "github.com/ollama/ollama/x/imagegen/transfer"
) )
// Blobs newer than this may belong to another process that has not written its
// manifest yet. They become eligible for the normal mark-and-sweep pass later.
const layerPruneGracePeriod = time.Hour
var ( var (
errCapabilities = errors.New("does not support") errCapabilities = errors.New("does not support")
errCapabilityCompletion = errors.New("completion") errCapabilityCompletion = errors.New("completion")
@@ -156,7 +161,7 @@ func (m *Model) Capabilities() []model.Capability {
// Temporary workaround — suppress vision/audio for gemma4 MLX models // Temporary workaround — suppress vision/audio for gemma4 MLX models
// until multimodal runtime pipeline lands. Remove when imageproc.go is wired up. // until multimodal runtime pipeline lands. Remove when imageproc.go is wired up.
if m.Config.ModelFormat == "safetensors" && m.Config.Renderer == "gemma4" { if m.Config.ModelFormat == "safetensors" && isGemma4Renderer(m.Config.Renderer) {
capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool { capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool {
return c == model.CapabilityVision || c == "audio" return c == model.CapabilityVision || c == "audio"
}) })
@@ -406,31 +411,12 @@ func CopyModel(src, dst model.Name) error {
return nil return nil
} }
manifests, err := manifest.Path() data, err := manifest.ReadManifestData(src)
if err != nil { if err != nil {
return err return err
} }
dstpath := filepath.Join(manifests, dst.Filepath()) return manifest.WriteManifestData(dst, data)
if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil {
return err
}
srcpath := filepath.Join(manifests, src.Filepath())
srcfile, err := os.Open(srcpath)
if err != nil {
return err
}
defer srcfile.Close()
dstfile, err := os.Create(dstpath)
if err != nil {
return err
}
defer dstfile.Close()
_, err = io.Copy(dstfile, srcfile)
return err
} }
func deleteUnusedLayers(deleteMap map[string]struct{}) error { func deleteUnusedLayers(deleteMap map[string]struct{}) error {
@@ -441,6 +427,10 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
} }
for _, manifest := range manifests { for _, manifest := range manifests {
if manifest.BlobDigest() != "" {
delete(deleteMap, manifest.BlobDigest())
}
for _, layer := range manifest.Layers { for _, layer := range manifest.Layers {
delete(deleteMap, layer.Digest) delete(deleteMap, layer.Digest)
} }
@@ -478,10 +468,23 @@ func PruneLayers() error {
} }
for _, blob := range blobs { for _, blob := range blobs {
if blob.IsDir() {
continue
}
info, err := blob.Info()
if err != nil {
slog.Error("couldn't stat blob", "blob", blob.Name(), "error", err)
continue
}
if time.Since(info.ModTime()) < layerPruneGracePeriod {
continue
}
name := blob.Name() name := blob.Name()
name = strings.ReplaceAll(name, "-", ":") name = strings.ReplaceAll(name, "-", ":")
_, err := manifest.BlobsPath(name) _, err = manifest.BlobsPath(name)
if err != nil { if err != nil {
if errors.Is(err, manifest.ErrInvalidDigestFormat) { if errors.Is(err, manifest.ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads) // remove invalid blobs (e.g. partial downloads)
@@ -531,11 +534,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
// Use fast transfer for models with tensor layers (many small blobs) // Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) { if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields // Read raw manifest JSON to preserve tensor metadata fields
manifestPath, err := manifest.PathForName(n) manifestJSON, err := manifest.ReadManifestData(n)
if err != nil {
return err
}
manifestJSON, err := os.ReadFile(manifestPath)
if err != nil { if err != nil {
return err return err
} }
@@ -592,6 +591,14 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if existingMf.Config.Digest != "" { if existingMf.Config.Digest != "" {
deleteMap[existingMf.Config.Digest] = struct{}{} deleteMap[existingMf.Config.Digest] = struct{}{}
} }
if existingMf.BlobDigest() != "" {
digest := existingMf.BlobDigest()
if blob, err := manifest.BlobsPath(digest); err == nil {
if _, err := os.Stat(blob); err == nil {
deleteMap[digest] = struct{}{}
}
}
}
} }
if n.ProtocolScheme == "http" && !regOpts.Insecure { if n.ProtocolScheme == "http" && !regOpts.Insecure {
@@ -661,21 +668,12 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
fn(api.ProgressResponse{Status: "writing manifest"}) fn(api.ProgressResponse{Status: "writing manifest"})
fp, err := manifest.PathForName(n) if err := manifest.WriteManifestData(n, manifestData); err != nil {
if err != nil { slog.Info(fmt.Sprintf("couldn't write manifest for %s", n.DisplayShortest()))
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err return err
} }
err = os.WriteFile(fp, manifestData, 0o644) slog.Debug("manifest written", "name", n.DisplayShortest(), "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
if err != nil {
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
return err
}
slog.Debug("manifest written", "path", fp, "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
if !envconfig.NoPrune() && len(deleteMap) > 0 { if !envconfig.NoPrune() && len(deleteMap) > 0 {
fn(api.ProgressResponse{Status: "removing unused layers"}) fn(api.ProgressResponse{Status: "removing unused layers"})
@@ -758,19 +756,11 @@ func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
// Write manifest // Write manifest
fn(api.ProgressResponse{Status: "writing manifest"}) fn(api.ProgressResponse{Status: "writing manifest"})
fp, err := manifest.PathForName(n) if err := manifest.WriteManifestData(n, manifestData); err != nil {
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err return err
} }
if err := os.WriteFile(fp, manifestData, 0o644); err != nil { slog.Debug("manifest written", "name", n.DisplayShortest(), "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
return err
}
slog.Debug("manifest written", "path", fp, "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
return nil return nil
} }

View File

@@ -5,14 +5,58 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"strings" "strings"
"testing" "testing"
"time"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
func TestPruneLayersSkipsRecentOrphans(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
recentDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000001"
oldDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000002"
for _, digest := range []string{recentDigest, oldDigest} {
p, err := manifest.BlobsPath(digest)
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(p, nil, 0o644); err != nil {
t.Fatal(err)
}
}
oldPath, err := manifest.BlobsPath(oldDigest)
if err != nil {
t.Fatal(err)
}
oldTime := time.Now().Add(-layerPruneGracePeriod - time.Hour)
if err := os.Chtimes(oldPath, oldTime, oldTime); err != nil {
t.Fatal(err)
}
if err := PruneLayers(); err != nil {
t.Fatal(err)
}
recentPath, err := manifest.BlobsPath(recentDigest)
if err != nil {
t.Fatal(err)
}
if _, err := os.Stat(recentPath); err != nil {
t.Fatalf("recent orphan was pruned: %v", err)
}
if _, err := os.Stat(oldPath); !os.IsNotExist(err) {
t.Fatalf("old orphan still exists: %v", err)
}
}
func TestModelCapabilities(t *testing.T) { func TestModelCapabilities(t *testing.T) {
// Create completion model (llama architecture without vision) // Create completion model (llama architecture without vision)
completionModelPath, _ := createBinFile(t, ggml.KV{ completionModelPath, _ := createBinFile(t, ggml.KV{
@@ -118,6 +162,39 @@ func TestModelCapabilities(t *testing.T) {
}, },
expectedCaps: []model.Capability{model.CapabilityEmbedding}, expectedCaps: []model.Capability{model.CapabilityEmbedding},
}, },
{
name: "gemma4 small safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererSmall,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
{
name: "gemma4 large safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererLarge,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
{
name: "legacy gemma4 safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererLegacy,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
} }
// compare two slices of model.Capability regardless of order // compare two slices of model.Capability regardless of order

View File

@@ -116,6 +116,10 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
proxied, err := func() (bool, error) { proxied, err := func() (bool, error) {
switch r.URL.Path { switch r.URL.Path {
case "/api/delete": case "/api/delete":
if s.Fallback != nil {
s.Fallback.ServeHTTP(rec, r)
return true, nil
}
return false, s.handleDelete(rec, r) return false, s.handleDelete(rec, r)
case "/api/pull": case "/api/pull":
return false, s.handlePull(rec, r) return false, s.handlePull(rec, r)

View File

@@ -115,7 +115,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) { func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
if m.Config.Renderer != "" { if m.Config.Renderer != "" {
rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think) rendererName := resolveRendererName(m)
rendered, err := renderers.RenderWithRenderer(rendererName, msgs, tools, think)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@@ -13,6 +13,14 @@ import (
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
func testConfigWithRenderer(renderer string) model.ConfigV2 {
return model.ConfigV2{Renderer: renderer}
}
func testConfigWithRendererAndType(renderer, modelType string) model.ConfigV2 {
return model.ConfigV2{Renderer: renderer, ModelType: modelType}
}
func TestChatPrompt(t *testing.T) { func TestChatPrompt(t *testing.T) {
type expect struct { type expect struct {
prompt string prompt string
@@ -397,3 +405,43 @@ func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt) t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
} }
} }
func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) {
msgs := []api.Message{{Role: "user", Content: "Hello"}}
tests := []struct {
name string
model Model
want string
}{
{
name: "small from name",
model: Model{
Name: "gemma4:e4b",
ShortName: "gemma4:e4b",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
},
{
name: "large from model type",
model: Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
},
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := renderPrompt(&tt.model, msgs, nil, nil)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Fatalf("rendered prompt mismatch (-got +want):\n%s", diff)
}
})
}
}

View File

@@ -0,0 +1,110 @@
package server
import (
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
const (
gemma4RendererLegacy = "gemma4"
gemma4RendererSmall = "gemma4-small"
gemma4RendererLarge = "gemma4-large"
// Gemma 4 small templates cover the e2b/e4b family, while 26b/31b use the
// large template. Default to the small prompt unless the model is clearly in
// the large range.
gemma4LargeMinParameterCount = 16_000_000_000
)
func resolveRendererName(m *Model) string {
if m == nil || m.Config.Renderer == "" {
return ""
}
switch m.Config.Renderer {
case gemma4RendererLegacy:
return resolveGemma4Renderer(m)
default:
return m.Config.Renderer
}
}
func resolveGemma4Renderer(m *Model) string {
if m == nil || m.Config.Renderer != gemma4RendererLegacy {
if m == nil {
return gemma4RendererLegacy
}
return m.Config.Renderer
}
if renderer, ok := gemma4RendererFromName(m.ShortName); ok {
return renderer
}
if renderer, ok := gemma4RendererFromName(m.Name); ok {
return renderer
}
if parameterCount, ok := parseHumanParameterCount(m.Config.ModelType); ok {
return gemma4RendererForParameterCount(parameterCount)
}
return gemma4RendererSmall
}
func gemma4RendererForParameterCount(parameterCount uint64) string {
if parameterCount >= gemma4LargeMinParameterCount {
return gemma4RendererLarge
}
return gemma4RendererSmall
}
func gemma4RendererFromName(name string) (string, bool) {
lower := strings.ToLower(name)
switch {
case strings.Contains(lower, "e2b"), strings.Contains(lower, "e4b"):
return gemma4RendererSmall, true
case strings.Contains(lower, "26b"), strings.Contains(lower, "31b"):
return gemma4RendererLarge, true
default:
return "", false
}
}
func parseHumanParameterCount(s string) (uint64, bool) {
if s == "" {
return 0, false
}
unit := strings.ToUpper(s[len(s)-1:])
var multiplier float64
switch unit {
case "B":
multiplier = float64(format.Billion)
case "M":
multiplier = float64(format.Million)
case "K":
multiplier = float64(format.Thousand)
default:
return 0, false
}
value, err := strconv.ParseFloat(s[:len(s)-1], 64)
if err != nil {
return 0, false
}
return uint64(value * multiplier), true
}
func isGemma4Renderer(renderer string) bool {
switch renderer {
case gemma4RendererLegacy, gemma4RendererSmall, gemma4RendererLarge:
return true
default:
return false
}
}

View File

@@ -1770,13 +1770,15 @@ func Serve(ln net.Listener) error {
return err return err
} }
manifestsPath, err := manifest.Path() for _, rootFn := range []func() (string, error){manifest.Path, manifest.V2Path} {
if err != nil { manifestsPath, err := rootFn()
return err if err != nil {
} return err
}
if err := manifest.PruneDirectory(manifestsPath); err != nil { if err := manifest.PruneDirectory(manifestsPath); err != nil && !os.IsNotExist(err) {
return err return err
}
} }
} }
} }
@@ -2408,7 +2410,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
// current approach uses the transition from parsed thinking content to // current approach uses the transition from parsed thinking content to
// parsed non-thinking content as the signal to turn constraining on // parsed non-thinking content as the signal to turn constraining on
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) { // TODO(parthsareen): temporary fix for https://github.com/ollama/ollama/issues/15260.
// To revisit for other models and have a consistent pattern across models through parsers.
forceImmediate := m.Config.Parser == "gemma4" && req.Think != nil && !req.Think.Bool()
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && !forceImmediate && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
currentFormat = nil currentFormat = nil
} }

View File

@@ -109,12 +109,44 @@ func checkFileExists(t *testing.T, p string, expect []string) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if strings.HasSuffix(filepath.ToSlash(p), "/blobs/*") {
actual = slices.DeleteFunc(actual, isManifestBlobForTest)
}
if diff := gocmp.Diff(expect, actual, gocmpopts.SortSlices(strings.Compare), gocmpopts.EquateEmpty()); diff != "" { if diff := gocmp.Diff(expect, actual, gocmpopts.SortSlices(strings.Compare), gocmpopts.EquateEmpty()); diff != "" {
t.Errorf("file exists mismatch (-want +got):\n%s", diff) t.Errorf("file exists mismatch (-want +got):\n%s", diff)
} }
} }
func checkManifestFiles(t *testing.T, names ...string) {
t.Helper()
expect := make([]string, len(names))
for i, name := range names {
p, err := manifest.V2PathForName(model.ParseName(name))
if err != nil {
t.Fatal(err)
}
expect[i] = p
}
checkFileExists(t, filepath.Join(envconfig.Models(), "manifests-v2", "*", "*", "*", "*"), expect)
}
func isManifestBlobForTest(path string) bool {
data, err := os.ReadFile(path)
if err != nil {
return false
}
var m manifest.Manifest
if err := json.Unmarshal(data, &m); err != nil {
return false
}
return m.SchemaVersion != 0 && m.MediaType != "" && (m.Config.Digest != "" || len(m.Layers) > 0)
}
func TestCreateFromBin(t *testing.T) { func TestCreateFromBin(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
@@ -136,9 +168,7 @@ func TestCreateFromBin(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"), filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
@@ -196,9 +226,7 @@ func TestCreateFromModel(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
w = createRequest(t, s.CreateHandler, api.CreateRequest{ w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test2", Name: "test2",
@@ -210,10 +238,7 @@ func TestCreateFromModel(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test", "test2")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"), filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
@@ -306,9 +331,7 @@ func TestCreateRemovesLayers(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"), filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
@@ -327,9 +350,7 @@ func TestCreateRemovesLayers(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"), filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
@@ -357,9 +378,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-0a666d113e8e0a3d27e9c7bd136a0bdfb6241037db50729d81568451ebfdbde8"), filepath.Join(p, "blobs", "sha256-0a666d113e8e0a3d27e9c7bd136a0bdfb6241037db50729d81568451ebfdbde8"),
@@ -378,9 +397,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"), filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
@@ -411,9 +428,7 @@ func TestCreateMergeParameters(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"), filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
@@ -436,10 +451,7 @@ func TestCreateMergeParameters(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test", "test2")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
// Display contents of each blob in the directory // Display contents of each blob in the directory
blobDir := filepath.Join(p, "blobs") blobDir := filepath.Join(p, "blobs")
@@ -495,10 +507,7 @@ func TestCreateMergeParameters(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test", "test2")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"), filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"),
@@ -555,9 +564,7 @@ func TestCreateReplacesMessages(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"), filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
@@ -589,10 +596,7 @@ func TestCreateReplacesMessages(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test", "test2")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
// Old layers will not have been pruned // Old layers will not have been pruned
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
@@ -650,9 +654,7 @@ func TestCreateTemplateSystem(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-0a04d979734167da3b80811a1874d734697f366a689f3912589b99d2e86e7ad1"), filepath.Join(p, "blobs", "sha256-0a04d979734167da3b80811a1874d734697f366a689f3912589b99d2e86e7ad1"),
@@ -850,9 +852,7 @@ func TestCreateLicenses(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"), filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"),
@@ -928,6 +928,59 @@ func TestCreateDetectTemplate(t *testing.T) {
}) })
} }
func TestCreateGemma4KeepsDynamicRendererAlias(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "gemma4",
"general.parameter_count": uint64(25_200_000_000),
}, nil)
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test",
Files: map[string]string{"test.gguf": digest},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
mf, err := manifest.ParseNamedManifest(model.ParseName("test"))
if err != nil {
t.Fatalf("parse manifest: %v", err)
}
if mf.Config.Digest == "" {
t.Fatalf("unexpected empty config digest for manifest")
}
configPath, err := manifest.BlobsPath(mf.Config.Digest)
if err != nil {
t.Fatalf("config blob path: %v", err)
}
cfgFile, err := os.Open(configPath)
if err != nil {
t.Fatalf("open config blob: %v", err)
}
defer cfgFile.Close()
var cfg model.ConfigV2
if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
t.Fatalf("decode config: %v", err)
}
if cfg.Renderer != gemma4RendererLegacy {
t.Fatalf("expected renderer %q, got %q", gemma4RendererLegacy, cfg.Renderer)
}
if cfg.Parser != "gemma4" {
t.Fatalf("expected parser %q, got %q", "gemma4", cfg.Parser)
}
}
func TestDetectModelTypeFromFiles(t *testing.T) { func TestDetectModelTypeFromFiles(t *testing.T) {
t.Run("gguf file", func(t *testing.T) { t.Run("gguf file", func(t *testing.T) {
_, digest := createBinFile(t, nil, nil) _, digest := createBinFile(t, nil, nil)

View File

@@ -42,10 +42,7 @@ func TestDelete(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test", "test2")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"), filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
@@ -60,9 +57,7 @@ func TestDelete(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "test2")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"), filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
@@ -76,7 +71,7 @@ func TestDelete(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{}) checkManifestFiles(t)
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{}) checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{})
} }
@@ -109,7 +104,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
t.Errorf("expected status code 200, actual %d", w.Code) t.Errorf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{}) checkManifestFiles(t)
} }
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) { func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
@@ -129,14 +124,12 @@ func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ checkManifestFiles(t, "gpt-oss:20b-cloud")
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"),
})
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"}) w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String()) t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String())
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{}) checkManifestFiles(t)
} }

View File

@@ -2108,6 +2108,132 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
}) })
} }
// TestChatFormatWithThinkFalse verifies that when a model uses a builtin
// parser that supports thinking (e.g. gemma4) and the request explicitly
// disables thinking (think=false), the format constraint is passed to the
// first and only completion call. Previously, format was deferred for all
// thinking-capable parsers and only re-applied after an end-of-thinking
// transition — a transition that never fires when thinking is off. See
// https://github.com/ollama/ollama/issues/15260.
func TestChatFormatWithThinkFalse(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := &mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := &Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{llama: mock}
return false
},
},
}
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
// Use the gemma4 builtin parser — it reports HasThinkingSupport=true, which
// adds CapabilityThinking to the model and previously triggered deferral of
// the format even when the user passed think=false.
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-gemma4-parser",
Files: map[string]string{"file.gguf": digest},
Parser: "gemma4",
Template: `{{- range .Messages }}{{ .Role }}: {{ .Content }}{{ end }}`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("create: expected status 200, got %d: %s", w.Code, w.Body.String())
}
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}},"required":["answer"]}`)
var (
requestsMu sync.Mutex
requests []llm.CompletionRequest
)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
requestsMu.Lock()
requests = append(requests, r)
requestsMu.Unlock()
fn(llm.CompletionResponse{
Content: `{"answer":"42"}`,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
})
return nil
}
streamRequest := false
think := false
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-gemma4-parser",
Messages: []api.Message{{Role: "user", Content: "Respond in JSON."}},
Think: &api.ThinkValue{Value: think},
Stream: &streamRequest,
Format: format,
})
if w.Code != http.StatusOK {
t.Fatalf("chat: expected status 200, got %d: %s", w.Code, w.Body.String())
}
if len(requests) != 1 {
t.Fatalf("expected a single completion call, got %d", len(requests))
}
if !bytes.Equal([]byte(format), []byte(requests[0].Format)) {
t.Errorf("expected first completion format to match the request format, got %q", string(requests[0].Format))
}
}
func TestGenerateUnload(t *testing.T) { func TestGenerateUnload(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@@ -658,11 +658,14 @@ func TestManifestCaseSensitivity(t *testing.T) {
checkManifestList := func() { checkManifestList := func() {
t.Helper() t.Helper()
mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/") mandir, err := manifest.V2Path()
if err != nil {
t.Fatalf("failed to resolve v2 manifest path: %v", err)
}
var entries []string var entries []string
t.Logf("dir entries:") t.Logf("dir entries:")
fsys := os.DirFS(mandir) fsys := os.DirFS(mandir)
err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error { err = fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
if err != nil { if err != nil {
return err return err
} }
@@ -685,7 +688,14 @@ func TestManifestCaseSensitivity(t *testing.T) {
g := entries[0] // raw path g := entries[0] // raw path
g = filepath.ToSlash(g) g = filepath.ToSlash(g)
w := model.ParseName(wantStableName).Filepath() wp, err := manifest.V2PathForName(model.ParseName(wantStableName))
if err != nil {
t.Fatalf("failed to resolve expected manifest path: %v", err)
}
w, err := filepath.Rel(mandir, wp)
if err != nil {
t.Fatalf("failed to make expected manifest path relative: %v", err)
}
w = filepath.ToSlash(w) w = filepath.ToSlash(w)
if g != w { if g != w {
t.Errorf("\ngot: %s\nwant: %s", g, w) t.Errorf("\ngot: %s\nwant: %s", g, w)

View File

@@ -93,6 +93,13 @@ func (t gemma4ImportTransform) quantizationType(name string, shape []int32, quan
return "" return ""
} }
// MoE router logits choose the top-k expert set. Quantization noise here
// can flip expert selection, after which downstream activations diverge
// sharply. The tensor is small, so leave it in source precision.
if isGemma4RouterProjection(name) {
return ""
}
// Mixed-precision quantization: sensitive tensors get higher precision. // Mixed-precision quantization: sensitive tensors get higher precision.
// //
// Value projections (v_proj) directly determine attention output quality. // Value projections (v_proj) directly determine attention output quality.
@@ -170,6 +177,12 @@ func isEmbedTokensWeight(name string) bool {
!strings.Contains(name, "per_layer") !strings.Contains(name, "per_layer")
} }
func isGemma4RouterProjection(name string) bool {
return strings.HasSuffix(name, ".router.proj.weight") &&
!strings.Contains(name, "audio_tower") &&
!strings.Contains(name, "vision_tower")
}
func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) { func (t gemma4ImportTransform) transformTensor(td *safetensors.TensorData) ([]*safetensors.TensorData, error) {
if td == nil { if td == nil {
return nil, nil return nil, nil

View File

@@ -68,6 +68,11 @@ func TestGemma4QuantizationType(t *testing.T) {
{"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"}, {"expert gate_up nvfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "nvfp4", "nvfp4"},
{"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"}, {"expert gate_up mxfp4", transform26B, "model.layers.0.moe.experts.42.gate_up_proj.weight", aligned, "mxfp4", "mxfp4"},
// === Router projection: expert selection is sensitive; keep source precision ===
{"router proj int4", transform26B, "model.layers.0.router.proj.weight", aligned, "int4", ""},
{"router proj nvfp4", transform26B, "model.layers.0.router.proj.weight", aligned, "nvfp4", ""},
{"router proj mxfp4", transform26B, "model.layers.0.router.proj.weight", aligned, "mxfp4", ""},
// === k_proj: promoted only for 8-expert models === // === k_proj: promoted only for 8-expert models ===
{"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"}, {"k_proj 128 experts int4", transform26B, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int4"},
{"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"}, {"k_proj 8 experts int4", transform8E, "model.layers.0.self_attn.k_proj.weight", aligned, "int4", "int8"},

View File

@@ -11,6 +11,8 @@ import (
"strings" "strings"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
rootmanifest "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
) )
// ManifestLayer represents a layer in the manifest. // ManifestLayer represents a layer in the manifest.
@@ -49,9 +51,7 @@ func DefaultManifestDir() string {
// LoadManifest loads a manifest for the given model name. // LoadManifest loads a manifest for the given model name.
// Model name format: "modelname" or "modelname:tag" or "host/namespace/name:tag" // Model name format: "modelname" or "modelname:tag" or "host/namespace/name:tag"
func LoadManifest(modelName string) (*ModelManifest, error) { func LoadManifest(modelName string) (*ModelManifest, error) {
manifestPath := resolveManifestPath(modelName) data, err := rootmanifest.ReadManifestData(model.ParseName(modelName))
data, err := os.ReadFile(manifestPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("read manifest: %w", err) return nil, fmt.Errorf("read manifest: %w", err)
} }
@@ -67,36 +67,6 @@ func LoadManifest(modelName string) (*ModelManifest, error) {
}, nil }, nil
} }
// resolveManifestPath converts a model name to a manifest file path.
func resolveManifestPath(modelName string) string {
// Parse model name into components
// Default: registry.ollama.ai/library/<name>/<tag>
host := "registry.ollama.ai"
namespace := "library"
name := modelName
tag := "latest"
// Handle explicit tag
if idx := strings.LastIndex(name, ":"); idx != -1 {
tag = name[idx+1:]
name = name[:idx]
}
// Handle full path like "host/namespace/name"
parts := strings.Split(name, "/")
switch len(parts) {
case 3:
host = parts[0]
namespace = parts[1]
name = parts[2]
case 2:
namespace = parts[0]
name = parts[1]
}
return filepath.Join(DefaultManifestDir(), host, namespace, name, tag)
}
// BlobPath returns the full path to a blob given its digest. // BlobPath returns the full path to a blob given its digest.
func (m *ModelManifest) BlobPath(digest string) string { func (m *ModelManifest) BlobPath(digest string) string {
// Convert "sha256:abc123" to "sha256-abc123" // Convert "sha256:abc123" to "sha256-abc123"

View File

@@ -1,8 +1,12 @@
package manifest package manifest
import ( import (
"os"
"path/filepath" "path/filepath"
"testing" "testing"
rootmanifest "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
) )
func TestTotalTensorSize(t *testing.T) { func TestTotalTensorSize(t *testing.T) {
@@ -55,3 +59,39 @@ func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs) t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
} }
} }
func TestLoadManifestPrefersV2(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
name := model.ParseName("example")
legacyPath, err := rootmanifest.PathForName(name)
if err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Dir(legacyPath), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(legacyPath, []byte(`{"schemaVersion":2,"mediaType":"legacy"}`), 0o644); err != nil {
t.Fatal(err)
}
v2Path, err := rootmanifest.V2PathForName(name)
if err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(filepath.Dir(v2Path), 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(v2Path, []byte(`{"schemaVersion":2,"mediaType":"v2"}`), 0o644); err != nil {
t.Fatal(err)
}
m, err := LoadManifest(name.String())
if err != nil {
t.Fatal(err)
}
if m.Manifest.MediaType != "v2" {
t.Fatalf("media type = %q, want %q", m.Manifest.MediaType, "v2")
}
}

View File

@@ -115,36 +115,7 @@ func (s *Server) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port> // Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", s.modelName, "--port", strconv.Itoa(port)) cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", s.modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ() cmd.Env = os.Environ()
configureMLXSubprocessEnv(cmd, ml.LibraryPaths(gpus))
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
if runtime.GOOS == "linux" {
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
// Append existing LD_LIBRARY_PATH if set
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add LD_LIBRARY_PATH in cmd.Env
found := false
for i := range cmd.Env {
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
s.cmd = cmd s.cmd = cmd
@@ -200,6 +171,53 @@ func (s *Server) Ping(ctx context.Context) error {
return nil return nil
} }
func mlxLibraryPathEnv() string {
switch runtime.GOOS {
case "windows":
return "PATH"
case "darwin":
return "DYLD_LIBRARY_PATH"
default:
return "LD_LIBRARY_PATH"
}
}
func configureMLXSubprocessEnv(cmd *exec.Cmd, libraryPaths []string) {
if len(libraryPaths) == 0 {
return
}
// Search order for the imagegen runner is:
// 1. bundled lib/ollama root
// 2. backend-specific library dirs selected during GPU discovery
// 3. any existing caller-provided library path values
pathEnv := mlxLibraryPathEnv()
pathEnvPaths := append([]string{}, libraryPaths...)
if existingPath, ok := os.LookupEnv(pathEnv); ok {
pathEnvPaths = append(pathEnvPaths, filepath.SplitList(existingPath)...)
}
setSubprocessEnv(cmd, pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
slog.Debug("mlx subprocess library path", pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
ollamaLibraryPaths := append([]string{}, libraryPaths...)
if existingPath, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
ollamaLibraryPaths = append(ollamaLibraryPaths, filepath.SplitList(existingPath)...)
}
setSubprocessEnv(cmd, "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
slog.Debug("mlx subprocess library path", "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
}
func setSubprocessEnv(cmd *exec.Cmd, key, value string) {
for i := range cmd.Env {
name, _, ok := strings.Cut(cmd.Env[i], "=")
if ok && strings.EqualFold(name, key) {
cmd.Env[i] = key + "=" + value
return
}
}
cmd.Env = append(cmd.Env, key+"="+value)
}
// getLastErr returns the last stderr line. // getLastErr returns the last stderr line.
func (s *Server) getLastErr() string { func (s *Server) getLastErr() string {
s.lastErrLock.Lock() s.lastErrLock.Lock()

View File

@@ -254,8 +254,23 @@ func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV
mlx.Pin(c.keys, c.values) mlx.Pin(c.keys, c.values)
} else { } else {
if c.idx < c.keys.Dim(2) { if c.idx < c.keys.Dim(2) {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())) if c.offset <= c.maxSize {
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())) // Not yet wrapped: slots [c.idx, Dim) are grow padding
// or stale post-rewind data, not live window content.
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
} else {
// Wrapped: logical order is slots[idx..Dim) then slots[0..idx).
// Linearize so the trim + concat below operate on contiguous
// positions and preserve the last (maxSize - 1) old tokens.
tailK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.keys.Dim(2)), mlx.Slice())
tailV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.values.Dim(2)), mlx.Slice())
headK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
headV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
c.keys.Set(tailK.Concatenate(2, headK))
c.values.Set(tailV.Concatenate(2, headV))
c.idx = c.keys.Dim(2)
}
} }
// Trim to max_size to maintain sliding window // Trim to max_size to maintain sliding window
@@ -322,9 +337,10 @@ func (c *RotatingKVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil { if c.keys == nil || c.values == nil {
return nil return nil
} }
liveLen := min(c.offset, c.keys.Dim(2))
return []*mlx.Array{ return []*mlx.Array{
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()), c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
} }
} }

View File

@@ -0,0 +1,338 @@
package cache
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// singleTokenKV and multiTokenKV fabricate [B=1, H=1, L, D=2] key/value
// tensors whose channel value is the token id, so stateIDs can recover
// which ids survived in the cache.
func singleTokenKV(id float32) (*mlx.Array, *mlx.Array) {
k := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
v := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
return k, v
}
func multiTokenKV(ids []float32) (*mlx.Array, *mlx.Array) {
data := make([]float32, 0, 2*len(ids))
for _, id := range ids {
data = append(data, id, id)
}
k := mlx.FromValues(data, 1, 1, len(ids), 2)
v := mlx.FromValues(data, 1, 1, len(ids), 2)
return k, v
}
// stateIDs returns the ids currently in the cache in slot order (logical
// after a concat, physical/rotated after a single-token update).
func stateIDs(t *testing.T, c *RotatingKVCache) []float32 {
t.Helper()
state := c.State()
if state == nil {
return nil
}
mlx.Eval(state[0])
flat := state[0].Floats()
n := state[0].Dim(2)
out := make([]float32, n)
for i := range n {
out[i] = flat[i*2]
}
return out
}
func equalSlice(a, b []float32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func feedMulti(c *RotatingKVCache, startID float32, n int) float32 {
ids := make([]float32, n)
for i := range ids {
ids[i] = startID + float32(i)
}
k, v := multiTokenKV(ids)
c.Update(k, v)
return startID + float32(n)
}
func feedSingle(c *RotatingKVCache, id float32) {
k, v := singleTokenKV(id)
c.Update(k, v)
}
// TestRotatingKVCacheConcatMidRotationPreservesContext: after the buffer
// has wrapped, a multi-token concat must keep the (maxSize-1) most recent
// pre-existing tokens in logical order so the first Q of the new batch
// has a full sliding window.
func TestRotatingKVCacheConcatMidRotationPreservesContext(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
nextID := feedMulti(c, 1, 3)
for range 6 {
feedSingle(c, nextID)
nextID++
}
if c.Offset() != 9 {
t.Fatalf("setup: offset=%d want 9", c.Offset())
}
if c.idx >= c.maxSize {
t.Fatalf("setup: expected mid-rotation idx (<%d), got %d", c.maxSize, c.idx)
}
feedMulti(c, 10, 2)
got := stateIDs(t, c)
want := []float32{7, 8, 9, 10, 11}
if !equalSlice(got, want) {
t.Fatalf("post-concat window=%v want %v", got, want)
}
if c.Offset() != 11 {
t.Fatalf("offset=%d want 11", c.Offset())
}
}
// TestRotatingKVCacheConcatAlignedInvariant: with an aligned buffer
// (c.idx == Dim), an L>1 concat keeps the last (maxSize-1) pre-existing
// tokens plus the full new batch. This is the chunked-prefill contract
// x/mlxrunner/pipeline.go relies on.
func TestRotatingKVCacheConcatAlignedInvariant(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
// Chunk 1 fills past maxSize, leaving Dim == maxSize aligned.
feedMulti(c, 1, 6)
// Chunk 2: the buffer is intentionally oversized to (maxSize-1) + L
// so the first new Q has its full window in scope for this forward.
feedMulti(c, 7, 3)
got := stateIDs(t, c)
want := []float32{4, 5, 6, 7, 8, 9}
if !equalSlice(got, want) {
t.Fatalf("post-chunk-2 buffer=%v want %v", got, want)
}
// The next decode trims oversize back to maxSize; order may be
// physical (rotated), so check as a set.
feedSingle(c, 10)
got = stateIDs(t, c)
if len(got) != window {
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
}
seen := map[float32]bool{}
for _, v := range got {
seen[v] = true
}
for _, w := range []float32{7, 8, 9, 10} {
if !seen[w] {
t.Fatalf("post-decode window missing %v (got %v)", w, got)
}
}
}
// TestRotatingKVCacheConcatAfterDecodeGrowsBuffer: update() grows the
// underlying buffer by `step` slots via mlx.Zeros before writing, so
// after one decode on a short prefill c.idx < Dim even though the cache
// has not wrapped. Those trailing slots are zero padding and must not
// be pulled back into the live window on the next concat.
func TestRotatingKVCacheConcatAfterDecodeGrowsBuffer(t *testing.T) {
skipIfNoMLX(t)
const window = 512
c := NewRotatingKVCache(window)
feedMulti(c, 1, 3)
feedSingle(c, 4)
feedMulti(c, 5, 3)
got := stateIDs(t, c)
want := []float32{1, 2, 3, 4, 5, 6, 7}
if !equalSlice(got, want) {
t.Fatalf("growing-buffer concat=%v want %v", got, want)
}
}
// TestRotatingKVCacheConcatAfterLiveRewind: x/mlxrunner/cache.go calls
// Restore(nil, target) between conversation turns to rewind the cache to
// the matched prefix. Restore moves c.offset/c.idx without trimming the
// underlying buffer, so slots [c.idx, Dim) still hold stale pre-rewind
// tokens. A subsequent concat must drop those, not treat them as wrapped
// window content.
func TestRotatingKVCacheConcatAfterLiveRewind(t *testing.T) {
skipIfNoMLX(t)
const window = 8
c := NewRotatingKVCache(window)
// Grow the buffer to exactly maxSize without wrapping.
feedMulti(c, 1, 2)
for id := float32(3); id <= 8; id++ {
feedSingle(c, id)
}
if c.Offset() != window {
t.Fatalf("setup: offset=%d want %d", c.Offset(), window)
}
if !c.Restore(nil, 2) {
t.Fatalf("live rewind to 2 failed")
}
if c.Offset() != 2 {
t.Fatalf("post-rewind offset=%d want 2", c.Offset())
}
feedMulti(c, 9, 3)
got := stateIDs(t, c)
want := []float32{1, 2, 9, 10, 11}
if !equalSlice(got, want) {
t.Fatalf("post-rewind concat=%v want %v", got, want)
}
if c.Offset() != 5 {
t.Fatalf("offset=%d want 5", c.Offset())
}
}
// TestRotatingKVCacheConcatGrowingBuffer: when oldLen < maxSize the trim
// formula drops to non-positive and all pre-existing tokens are kept.
func TestRotatingKVCacheConcatGrowingBuffer(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
feedMulti(c, 1, 2)
feedMulti(c, 3, 2)
got := stateIDs(t, c)
want := []float32{1, 2, 3, 4}
if !equalSlice(got, want) {
t.Fatalf("growing buffer=%v want %v", got, want)
}
}
// TestRotatingKVCacheRunnerChunkedPrefill mirrors the
// x/mlxrunner/pipeline.go prefill loop: a long prompt fed through
// repeated L>1 Update() calls on a single cache. Scaled-down proxy for
// the Gemma 4 26B case (sliding_window=1024, prefillChunkSize=2048).
func TestRotatingKVCacheRunnerChunkedPrefill(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
feedMulti(c, 1, 8)
if c.Offset() != 8 {
t.Fatalf("chunk 1: offset=%d want 8", c.Offset())
}
feedMulti(c, 9, 8)
got := stateIDs(t, c)
want := []float32{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
if !equalSlice(got, want) {
t.Fatalf("chunk 2: buffer=%v want %v", got, want)
}
feedMulti(c, 17, 4)
got = stateIDs(t, c)
want = []float32{14, 15, 16, 17, 18, 19, 20}
if !equalSlice(got, want) {
t.Fatalf("chunk 3: buffer=%v want %v", got, want)
}
// Decode trims oversize back to maxSize; order may be physical.
feedSingle(c, 21)
got = stateIDs(t, c)
if len(got) != window {
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
}
seen := map[float32]bool{}
for _, v := range got {
seen[v] = true
}
for _, w := range []float32{18, 19, 20, 21} {
if !seen[w] {
t.Fatalf("post-decode window missing %v (got %v)", w, got)
}
}
}
// TestRotatingKVCacheMultiTurnChatSimulation walks a prefill → decode →
// prefill sequence and checks that each new prefill retains the last
// (maxSize-1) pre-existing tokens in logical order.
func TestRotatingKVCacheMultiTurnChatSimulation(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
nextID := feedMulti(c, 1, 2)
for range 5 {
feedSingle(c, nextID)
nextID++
}
if c.Offset() != 7 {
t.Fatalf("turn 1: offset=%d want 7", c.Offset())
}
feedMulti(c, nextID, 3)
nextID += 3
got := stateIDs(t, c)
want := []float32{5, 6, 7, 8, 9, 10}
if !equalSlice(got, want) {
t.Fatalf("turn 2 prefill buffer=%v want %v", got, want)
}
for range 4 {
feedSingle(c, nextID)
nextID++
}
if c.Offset() != 14 {
t.Fatalf("turn 2 decode: offset=%d want 14", c.Offset())
}
feedMulti(c, nextID, 2)
got = stateIDs(t, c)
want = []float32{12, 13, 14, 15, 16}
if !equalSlice(got, want) {
t.Fatalf("turn 3 prefill buffer=%v want %v", got, want)
}
}
// TestRotatingKVCacheOffsetTracking: Offset() is the monotonic logical
// token count through any mix of Update() calls — Gemma 4 uses
// donorEntry.Offset - L for the consumer's RoPE offset.
func TestRotatingKVCacheOffsetTracking(t *testing.T) {
skipIfNoMLX(t)
c := NewRotatingKVCache(4)
nextID := feedMulti(c, 1, 3)
if c.Offset() != 3 {
t.Fatalf("after prefill 3: offset=%d want 3", c.Offset())
}
for i := range 5 {
feedSingle(c, nextID)
nextID++
if c.Offset() != 3+i+1 {
t.Fatalf("after decode %d: offset=%d want %d", i, c.Offset(), 3+i+1)
}
}
nextID = feedMulti(c, nextID, 2)
if c.Offset() != 10 {
t.Fatalf("after turn-2 prefill: offset=%d want 10", c.Offset())
}
// L > maxSize concat.
feedMulti(c, nextID, 7)
if c.Offset() != 17 {
t.Fatalf("after large prefill: offset=%d want 17", c.Offset())
}
}

View File

@@ -151,20 +151,11 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
} }
} }
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization. type CompletionRequest struct {
type completionRequest struct { Prompt string
Prompt string `json:"prompt"` Options api.Options
Options *completionOpts `json:"options,omitempty"` Logprobs bool
} TopLogprobs int
type completionOpts struct {
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
MinP float32 `json:"min_p,omitempty"`
TopK int `json:"top_k,omitempty"`
RepeatLastN int `json:"repeat_last_n,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
} }
type CompletionResponse struct { type CompletionResponse struct {
@@ -177,6 +168,8 @@ type CompletionResponse struct {
EvalCount int EvalCount int
EvalDuration time.Duration EvalDuration time.Duration
Logprobs []llm.Logprob
Error *api.StatusError Error *api.StatusError
} }
@@ -201,19 +194,13 @@ func (c *Client) Close() error {
// Completion implements llm.LlamaServer. // Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
creq := completionRequest{ creq := CompletionRequest{
Prompt: req.Prompt, Prompt: req.Prompt,
Logprobs: req.Logprobs,
TopLogprobs: req.TopLogprobs,
} }
if req.Options != nil { if req.Options != nil {
creq.Options = &completionOpts{ creq.Options = *req.Options
Temperature: req.Options.Temperature,
TopP: req.Options.TopP,
MinP: req.Options.MinP,
TopK: req.Options.TopK,
RepeatLastN: req.Options.RepeatLastN,
PresencePenalty: req.Options.PresencePenalty,
NumPredict: req.Options.NumPredict,
}
} }
body, err := json.Marshal(creq) body, err := json.Marshal(creq)
@@ -262,6 +249,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
PromptEvalDuration: raw.PromptEvalDuration, PromptEvalDuration: raw.PromptEvalDuration,
EvalCount: raw.EvalCount, EvalCount: raw.EvalCount,
EvalDuration: raw.EvalDuration, EvalDuration: raw.EvalDuration,
Logprobs: raw.Logprobs,
} }
fn(cresp) fn(cresp)

View File

@@ -1,62 +1,86 @@
package mlx package mlx
// #include "generated.h"
import "C"
import "math" import "math"
var geluCoeff = float32(math.Sqrt(2 / math.Pi)) var geluCoeff = float32(math.Sqrt(2 / math.Pi))
// GELUApprox matches mlx.nn.gelu_approx: // GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// // as a fused kernel.
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) var GELUApprox = Compile1(
func GELUApprox(x *Array) *Array { "GELUApprox",
// Use dtype-matched scalars to avoid implicit upcasts on bf16 inputs. func(x *Array) *Array {
half := scalarWithDtype(0.5, x) // Dtype-matched scalars avoid implicit upcasts on bf16 inputs.
defer C.mlx_array_free(half) dt := x.DType()
coeff := scalarWithDtype(geluCoeff, x) half := FromValue[float32](0.5).AsType(dt)
defer C.mlx_array_free(coeff) coeff := FromValue(geluCoeff).AsType(dt)
c := scalarWithDtype(0.044715, x) c := FromValue[float32](0.044715).AsType(dt)
defer C.mlx_array_free(c) one := FromValue[float32](1.0).AsType(dt)
// x^3 via x*x*x (avoids general Power which is slower) // x^3 via x*x*x (avoids general Power which is slower).
x3 := New("GELU_X3") x3 := x.Multiply(x).Multiply(x)
C.mlx_multiply(&x3.ctx, x.ctx, x.ctx, DefaultStream().ctx) inner := x.Add(c.Multiply(x3))
tmp := New("GELU_X3b") tanh := coeff.Multiply(inner).Tanh()
C.mlx_multiply(&tmp.ctx, x3.ctx, x.ctx, DefaultStream().ctx) return half.Multiply(x).Multiply(one.Add(tanh))
x3 = tmp },
Shapeless(),
)
// 0.044715 * x^3 // SiLU returns a * sigmoid(a) as a fused kernel.
cx3 := New("GELU_CX3") var SiLU = Compile1(
C.mlx_multiply(&cx3.ctx, c, x3.ctx, DefaultStream().ctx) "SiLU",
func(a *Array) *Array {
return a.Multiply(a.Sigmoid())
},
Shapeless(),
)
// x + 0.044715 * x^3 // SwiGLU returns silu(gate) * up as a fused kernel.
inner := New("GELU_INNER") var SwiGLU = Compile2(
C.mlx_add(&inner.ctx, x.ctx, cx3.ctx, DefaultStream().ctx) "SwiGLU",
func(gate, up *Array) *Array {
return SiLU(gate).Multiply(up)
},
Shapeless(),
)
// sqrt(2/pi) * (x + 0.044715 * x^3) // GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
scaled := New("GELU_SCALED") // geglu, used by Gemma-family MLP and MoE paths.
C.mlx_multiply(&scaled.ctx, coeff, inner.ctx, DefaultStream().ctx) var GeGLU = Compile2(
"GeGLU",
func(gate, up *Array) *Array {
return GELUApprox(gate).Multiply(up)
},
Shapeless(),
)
// tanh(...) // LogitSoftcap returns tanh(x / cap) * cap as a fused kernel. Matches
th := New("GELU_TANH") // mlx_lm's logit_softcap. cap must have the same dtype as x.
C.mlx_tanh(&th.ctx, scaled.ctx, DefaultStream().ctx) var LogitSoftcap = Compile2(
"LogitSoftcap",
func(x, cap *Array) *Array {
return x.Divide(cap).Tanh().Multiply(cap)
},
Shapeless(),
)
// 1 + tanh(...) // sigmoidRouterFused traces the DeepSeek-V2 / GLM-MoE aux-loss-free router
one := scalarWithDtype(1.0, x) // head. Two outputs are returned so the pre-bias sigmoid (used to gather
defer C.mlx_array_free(one) // per-expert scores after top-k) and the post-bias negation (used as the
onePlusTanh := New("GELU_1PT") // argpartition key for top-k) share a single kernel.
C.mlx_add(&onePlusTanh.ctx, one, th.ctx, DefaultStream().ctx) var sigmoidRouterFused = Compile(
"SigmoidRouter",
func(in ...*Array) []*Array {
gates, bias := in[0], in[1]
orig := gates.Sigmoid()
neg := orig.Add(bias).Negative()
return []*Array{orig, neg}
},
Shapeless(),
)
// 0.5 * x // SigmoidRouter returns (sigmoid(gates), -(sigmoid(gates)+bias)) as a fused
halfX := New("GELU_HALFX") // kernel — the DeepSeek-V2 / GLM-MoE aux-loss-free router head.
C.mlx_multiply(&halfX.ctx, half, x.ctx, DefaultStream().ctx) func SigmoidRouter(gates, bias *Array) (origScores, negScores *Array) {
out := sigmoidRouterFused(gates, bias)
// 0.5 * x * (1 + tanh(...)) return out[0], out[1]
out := New("GELU_APPROX")
C.mlx_multiply(&out.ctx, halfX.ctx, onePlusTanh.ctx, DefaultStream().ctx)
return out
}
func SILU(t *Array) *Array {
return t.Multiply(t.Sigmoid()).AsType(t.DType())
} }

View File

@@ -27,7 +27,11 @@ var arrays []*Array
func New(name string) *Array { func New(name string) *Array {
t := &Array{name: name} t := &Array{name: name}
arrays = append(arrays, t) if tracing {
traceScratch = append(traceScratch, t)
} else {
arrays = append(arrays, t)
}
return t return t
} }
@@ -234,6 +238,9 @@ func (t Array) Float() float64 {
} }
func (t Array) Ints() []int { func (t Array) Ints() []int {
if dt := t.DType(); dt != DTypeInt32 {
panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt))
}
ints := make([]int, t.Size()) ints := make([]int, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) { for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
ints[i] = int(f) ints[i] = int(f)
@@ -242,6 +249,9 @@ func (t Array) Ints() []int {
} }
func (t Array) Floats() []float32 { func (t Array) Floats() []float32 {
if dt := t.DType(); dt != DTypeFloat32 {
panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt))
}
floats := make([]float32, t.Size()) floats := make([]float32, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) { for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
floats[i] = float32(f) floats[i] = float32(f)

192
x/mlxrunner/mlx/compile.go Normal file
View File

@@ -0,0 +1,192 @@
package mlx
// #include <stdlib.h>
// #include "generated.h"
//
// extern int closureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
// extern void closureDestructor(void* payload);
import "C"
import (
"log/slog"
"runtime/cgo"
"sync"
"unsafe"
)
// CompileFunc is the signature of a function that can be compiled.
type CompileFunc func(inputs ...*Array) []*Array
// CompileOption configures Compile behavior.
type CompileOption func(*compileConfig)
type compileConfig struct {
shapeless bool
}
// Shapeless traces the function once against symbolic shapes so the compiled
// graph accepts any input shape afterwards. Without this option, MLX re-traces
// on each new (shape, dtype) combination and caches each specialization.
func Shapeless() CompileOption {
return func(c *compileConfig) { c.shapeless = true }
}
// Compile returns a compiled version of fn. When called during another
// compile's trace, fn is inlined directly so outer compiles can fuse through
// inner ones.
//
// Compiled functions must not have side effects outside of the function. Do
// not access data other than the arguments passed in (either Go data or MLX
// arrays) unless it is a constant.
func Compile(name string, fn CompileFunc, opts ...CompileOption) CompileFunc {
var cfg compileConfig
for _, o := range opts {
o(&cfg)
}
var closure C.mlx_closure
var once sync.Once
return func(inputs ...*Array) []*Array {
if tracing {
return fn(inputs...)
}
once.Do(func() {
payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
*payload = cgo.NewHandle(fn)
src := C.mlx_closure_new_func_payload(
(*[0]byte)(C.closureCallback),
unsafe.Pointer(payload),
(*[0]byte)(C.closureDestructor),
)
defer C.mlx_closure_free(src)
closure = C.mlx_closure_new()
mlxCheck(name+": compile failed", func() C.int {
return C.mlx_compile(&closure, src, C.bool(cfg.shapeless))
})
})
inVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(inVec)
for _, in := range inputs {
C.mlx_vector_array_append_value(inVec, in.ctx)
}
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
mlxCheck(name+": closure apply failed", func() C.int {
return C.mlx_closure_apply(&outVec, closure, inVec)
})
n := int(C.mlx_vector_array_size(outVec))
outputs := make([]*Array, n)
for i := range n {
outputs[i] = New(name)
C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
}
return outputs
}
}
// Compile1 compiles a unary function. See Compile.
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
cf := Compile(name, func(in ...*Array) []*Array {
return []*Array{fn(in[0])}
}, opts...)
return func(a *Array) *Array {
return cf(a)[0]
}
}
// Compile2 compiles a binary function. See Compile.
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
cf := Compile(name, func(in ...*Array) []*Array {
return []*Array{fn(in[0], in[1])}
}, opts...)
return func(a, b *Array) *Array {
return cf(a, b)[0]
}
}
// Compile3 compiles a ternary function. See Compile.
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
cf := Compile(name, func(in ...*Array) []*Array {
return []*Array{fn(in[0], in[1], in[2])}
}, opts...)
return func(a, b, c *Array) *Array {
return cf(a, b, c)[0]
}
}
// tracing is true while a compile callback is running. Since MLX is
// single-threaded at this level a plain Go bool suffices.
var tracing bool
// traceScratch collects arrays created during a compile trace so they can be
// freed as a group when the callback returns.
var traceScratch []*Array
//export closureCallback
func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
defer func() {
if r := recover(); r != nil {
slog.Error("mlx closure callback panicked", "panic", r)
rc = 1
}
}()
handle := *(*cgo.Handle)(payload)
fn := handle.Value().(CompileFunc)
// When tracing, we track all of the intermediates that are created and free them separately at the end of
// the process. This will give the effect of a single op - inputs are owned by the original caller (via
// the MLX layer) and outputs are transferred back to MLX to create a new Go side tensor.
if tracing {
panic("mlx: nested compile trace")
}
tracing = true
traceScratch = nil
defer func() {
for _, a := range traceScratch {
if a.pinned > 0 {
panic("mlx: traced array was pinned during compilation")
}
if a.Valid() {
C.mlx_array_free(a.ctx)
a.ctx.ctx = nil
}
}
tracing = false
traceScratch = nil
}()
n := int(C.mlx_vector_array_size(input))
inputs := make([]*Array, n)
for i := range n {
a := New("")
C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
inputs[i] = a
}
outputs := fn(inputs...)
var arrPtr *C.mlx_array
if len(outputs) > 0 {
handles := make([]C.mlx_array, len(outputs))
for i, out := range outputs {
handles[i] = out.ctx
}
arrPtr = &handles[0]
}
C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
return 0
}
//export closureDestructor
func closureDestructor(payload unsafe.Pointer) {
handle := *(*cgo.Handle)(payload)
handle.Delete()
C.free(payload)
}

View File

@@ -0,0 +1,147 @@
package mlx
import (
"testing"
)
func TestCompileFusion(t *testing.T) {
skipIfNoMLX(t)
// Compile fuses the ops inside a function body into a single kernel,
// eliminating intermediate buffers. Use a diamond-shaped graph where
// two branches must be materialized simultaneously without fusion,
// then compare peak memory against the compiled version which fuses
// everything into one kernel with no intermediates.
const n = 1024 * 1024 // 4MB per float32 array
data := make([]float32, n)
for i := range data {
data[i] = float32(i + 1)
}
// Diamond: both a*b and a+b must be live for the final multiply.
// Without fusion: peak includes both intermediates (~8MB extra).
// With fusion: single kernel, no intermediates.
body := func(a, b *Array) *Array {
return a.Multiply(b).Multiply(a.Add(b))
}
a := FromValues(data, n)
b := FromValues(data, n)
Pin(a, b)
defer Unpin(a, b)
// Compiled: ops fused into a single kernel.
EnableCompile()
fn := Compile2("diamond", body, Shapeless())
warm := fn(a, b)
Eval(warm)
Sweep()
ClearCache()
ResetPeakMemory()
y := fn(a, b)
Eval(y)
compiledPeak := PeakMemory()
Sweep()
// Uncompiled: ops evaluated individually, intermediates materialized.
ClearCache()
ResetPeakMemory()
z := body(a, b)
Eval(z)
uncompiledPeak := PeakMemory()
Sweep()
if compiledPeak == 0 && uncompiledPeak == 0 {
t.Skip("peak memory tracking not available")
}
t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
if compiledPeak >= uncompiledPeak {
t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
}
}
func TestCompileNested(t *testing.T) {
skipIfNoMLX(t)
// A compiled function that calls another compiled function should
// produce correct results. The inner function inlines via isTracing()
// during the outer's trace.
inner := Compile1("silu", func(a *Array) *Array {
return a.Multiply(a.Sigmoid())
}, Shapeless())
outer := Compile2("swiglu", func(gate, up *Array) *Array {
return inner(gate).Multiply(up)
}, Shapeless())
gate := FromValues([]float32{0, 1, 2}, 3)
up := FromValues([]float32{1, 1, 1}, 3)
Pin(gate, up)
defer Unpin(gate, up)
y := outer(gate, up)
Eval(y)
// silu(x) = x * sigmoid(x); for x=0 → 0, x=1 → ~0.7311, x=2 → ~1.7616
got := y.Floats()
want := []float32{0, 0.7310586, 1.7615942}
for i, v := range got {
if v-want[i] > 1e-4 || want[i]-v > 1e-4 {
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
}
}
}
func TestCompileCallbackPanicRecovers(t *testing.T) {
skipIfNoMLX(t)
boom := Compile1("boom", func(a *Array) *Array {
panic("intentional test panic")
})
x := FromValues([]float32{1}, 1)
Pin(x)
defer Unpin(x)
defer func() {
r := recover()
if r == nil {
t.Fatal("expected panic from Call, got none")
}
if _, ok := r.(string); !ok {
t.Fatalf("expected string panic, got %T: %v", r, r)
}
}()
boom(x)
}
func TestCompileNoTrackingGrowth(t *testing.T) {
skipIfNoMLX(t)
// Repeated invocations of a compiled kernel should not grow the
// tracked-arrays list — the callback's traceScratch collects
// intermediates during tracing and frees them when the callback returns.
fn := Compile2("mul_add", func(a, b *Array) *Array {
return a.Multiply(b).Add(b)
})
a := FromValues([]float32{1, 2}, 2)
b := FromValues([]float32{3, 4}, 2)
Pin(a, b)
defer Unpin(a, b)
Sweep()
before := len(arrays)
for range 100 {
_ = fn(a, b)
Sweep()
}
after := len(arrays)
if after > before+2 {
t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", before, after)
}
}

View File

@@ -9,8 +9,8 @@ package mlx
// #include "generated.h" // #include "generated.h"
// #include <string.h> // #include <string.h>
// //
// static char _mlx_last_error_msg[1024] = {0}; // static __thread char _mlx_last_error_msg[1024] = {0};
// static int _mlx_last_error_flag = 0; // static __thread int _mlx_last_error_flag = 0;
// //
// static void _mlx_capture_error_handler(const char* msg, void* data) { // static void _mlx_capture_error_handler(const char* msg, void* data) {
// (void)data; // (void)data;
@@ -30,15 +30,13 @@ package mlx
// _mlx_last_error_msg[0] = '\0'; // _mlx_last_error_msg[0] = '\0';
// } // }
// //
// static int mlx_had_last_error(void) {
// return _mlx_last_error_flag;
// }
//
// static const char* mlx_get_last_error(void) { // static const char* mlx_get_last_error(void) {
// return _mlx_last_error_flag ? _mlx_last_error_msg : NULL; // return _mlx_last_error_flag ? _mlx_last_error_msg : "";
// } // }
import "C" import "C"
import "runtime"
func init() { func init() {
// Replace the default exit(-1) error handler with one that captures // Replace the default exit(-1) error handler with one that captures
// the error message so we can surface it in Go. // the error message so we can surface it in Go.
@@ -53,6 +51,24 @@ func Version() string {
return C.GoString(C.mlx_string_data(str)) return C.GoString(C.mlx_string_data(str))
} }
// mlxCheck locks the goroutine to its OS thread, clears the captured error
// state, calls fn, and panics with the captured message if fn returns non-zero.
// The thread lock ensures the thread-local error state is read from the same
// thread that executed the call.
func mlxCheck(fallback string, fn func() C.int) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
C.mlx_clear_last_error()
if fn() != 0 {
msg := C.GoString(C.mlx_get_last_error())
if msg == "" {
msg = fallback
}
panic("mlx: " + msg)
}
}
func doEval(outputs []*Array, async bool) { func doEval(outputs []*Array, async bool) {
if len(outputs) == 0 { if len(outputs) == 0 {
return return
@@ -67,20 +83,12 @@ func doEval(outputs []*Array, async bool) {
} }
} }
C.mlx_clear_last_error() mlxCheck("eval failed", func() C.int {
var rc C.int if async {
if async { return C.mlx_async_eval(vector)
rc = C.mlx_async_eval(vector)
} else {
rc = C.mlx_eval(vector)
}
if rc != 0 {
msg := "mlx eval failed"
if C.mlx_had_last_error() != 0 {
msg = C.GoString(C.mlx_get_last_error())
} }
panic("mlx: " + msg) return C.mlx_eval(vector)
} })
} }
func AsyncEval(outputs ...*Array) { func AsyncEval(outputs ...*Array) {

View File

@@ -139,6 +139,12 @@ func (t *Array) Less(other *Array) *Array {
return out return out
} }
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
out := New("MAX_AXIS")
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array { func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL") out := New("MATMUL")
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
@@ -169,6 +175,12 @@ func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
return out return out
} }
func (t *Array) ScatterAddAxis(indices, values *Array, axis int) *Array {
out := New("SCATTER_ADD_AXIS")
C.mlx_scatter_add_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Reshape(axes ...int) *Array { func (t *Array) Reshape(axes ...int) *Array {
cAxes := make([]C.int, len(axes)) cAxes := make([]C.int, len(axes))
for i := range axes { for i := range axes {

View File

@@ -404,11 +404,6 @@ func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *A
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices) return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
} }
func SiLU(a *Array) *Array {
sig := a.Sigmoid()
return a.Multiply(sig)
}
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array { func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil) return RoPEWithFreqs(x, dims, traditional, base, scale, offset, nil)
} }

View File

@@ -7,11 +7,15 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"net/http" "net/http"
"sort"
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/tokenizer"
) )
func prefillChunkSize() int { func prefillChunkSize() int {
@@ -23,28 +27,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
return errors.New("model not loaded") return errors.New("model not loaded")
} }
enableCompile := true
if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
enableCompile = modelCompile.EnableCompile()
}
if enableCompile {
mlx.EnableCompile()
} else {
mlx.DisableCompile()
}
mlx.ResetPeakMemory() mlx.ResetPeakMemory()
ctx := request.Ctx ctx := request.Ctx
var ( var sample, nextSample sampler.Result
sample, logprobs *mlx.Array
nextSample, nextLogprobs *mlx.Array
)
defer func() { defer func() {
if request.Sampler != nil { if request.Sampler != nil {
request.Sampler.Free() request.Sampler.Free()
} }
mlx.Unpin(sample, logprobs) mlx.Unpin(sample.Arrays()...)
mlx.Unpin(nextSample, nextLogprobs) mlx.Unpin(nextSample.Arrays()...)
mlx.Sweep() mlx.Sweep()
mlx.ClearCache() mlx.ClearCache()
@@ -69,10 +61,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
// Cap generation to stay within the model's context length // Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(inputs) maxGenerate := r.contextLength - len(inputs)
if request.Options.MaxTokens <= 0 { if request.Options.NumPredict <= 0 {
request.Options.MaxTokens = maxGenerate request.Options.NumPredict = maxGenerate
} else { } else {
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate) request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate)
} }
request.Sampler.ResetHistory(inputs) request.Sampler.ResetHistory(inputs)
@@ -144,41 +136,38 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
mlx.ClearCache() mlx.ClearCache()
} }
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) { step := func(token *mlx.Array) sampler.Result {
fwd := r.Model.Forward(token.ExpandDims(0), caches) fwd := r.Model.Forward(token.ExpandDims(0), caches)
logits := r.Model.Unembed(fwd) logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
logprobs := logits.Subtract(logits.Logsumexp(true)) sample := request.Sampler.Sample(logits)
sample := request.Sampler.Sample(logprobs) mlx.Pin(sample.Arrays()...)
mlx.Pin(sample, logprobs)
mlx.Sweep() mlx.Sweep()
mlx.AsyncEval(sample, logprobs) mlx.AsyncEval(sample.Arrays()...)
return sample
return sample, logprobs
} }
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed)) sample = step(mlx.FromValues(tokens[processed:], total-processed))
var b bytes.Buffer dec := decoder{tokenizer: r.Tokenizer}
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1} final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1}
for i := range request.Options.MaxTokens { for i := range request.Options.NumPredict {
if err := ctx.Err(); err != nil { if err := ctx.Err(); err != nil {
return err return err
} }
request.Sampler.AppendToken(sample) request.Sampler.AppendToken(sample.Token)
nextSample, nextLogprobs = step(sample) nextSample = step(sample.Token)
if i == 0 { if i == 0 {
mlx.Eval(sample) mlx.Eval(sample.Arrays()...)
final.PromptEvalDuration = time.Since(now) final.PromptEvalDuration = time.Since(now)
now = time.Now() now = time.Now()
} }
output := int32(sample.Int()) output := int32(sample.Token.Int())
session.outputs = append(session.outputs, output) session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) { if r.Tokenizer.IsEOS(output) {
@@ -187,17 +176,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
break break
} }
select { if resp, ok := dec.decode(sample); ok {
case <-ctx.Done(): select {
return ctx.Err() case <-ctx.Done():
case request.Responses <- CompletionResponse{ return ctx.Err()
Content: r.Decode(output, &b), case request.Responses <- resp:
}: }
} }
mlx.Unpin(sample, logprobs) mlx.Unpin(sample.Arrays()...)
sample, logprobs = nextSample, nextLogprobs sample, nextSample = nextSample, sampler.Result{}
nextSample, nextLogprobs = nil, nil
if i%256 == 0 { if i%256 == 0 {
mlx.ClearCache() mlx.ClearCache()
@@ -213,13 +201,57 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} }
} }
func (r Runner) Decode(sample int32, b *bytes.Buffer) string { // decoder serializes sampled tokens into response chunks, holding bytes
token := r.Tokenizer.Decode([]int32{sample}) // whose UTF-8 sequence hasn't completed yet and the logprobs that belong
// with those bytes so Content and Logprobs stay aligned when a chunk does
// flush.
type decoder struct {
tokenizer *tokenizer.Tokenizer
buf bytes.Buffer
logprobs []llm.Logprob
}
if _, err := b.WriteString(token); err != nil { func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
slog.Error("Failed to write token to buffer", "error", err) output := int32(res.Token.Int())
return "" d.buf.WriteString(d.tokenizer.Decode([]int32{output}))
d.logprobs = append(d.logprobs, buildLogprob(res, d.tokenizer.Decode)...)
content := flushValidUTF8Prefix(&d.buf)
if content == "" {
return CompletionResponse{}, false
}
resp := CompletionResponse{Content: content, Logprobs: d.logprobs}
d.logprobs = nil
return resp, true
}
func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logprob {
if sample.Logprob == nil {
return nil
}
tok := func(id int32) string { return decode([]int32{id}) }
out := llm.Logprob{
TokenLogprob: llm.TokenLogprob{
Token: tok(int32(sample.Token.Int())),
Logprob: float64(sample.Logprob.Floats()[0]),
},
} }
return flushValidUTF8Prefix(b) if sample.TopTokens != nil {
ids := sample.TopTokens.Ints()
vals := sample.TopLogprobs.Floats()
pairs := make([]llm.TokenLogprob, len(ids))
for i, id := range ids {
pairs[i] = llm.TokenLogprob{
Token: tok(int32(id)),
Logprob: float64(vals[i]),
}
}
sort.Slice(pairs, func(i, j int) bool {
return pairs[i].Logprob > pairs[j].Logprob
})
out.TopLogprobs = pairs
}
return []llm.Logprob{out}
} }

View File

@@ -19,7 +19,7 @@ import (
) )
type Request struct { type Request struct {
TextCompletionsRequest CompletionRequest
Responses chan CompletionResponse Responses chan CompletionResponse
Pipeline func(Request) error Pipeline func(Request) error
@@ -28,22 +28,6 @@ type Request struct {
Sampler *sample.Sampler Sampler *sample.Sampler
} }
type TextCompletionsRequest struct {
Prompt string `json:"prompt"`
Options struct {
Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TopK int `json:"top_k"`
RepeatLastN int `json:"repeat_last_n"`
PresencePenalty float32 `json:"presence_penalty"`
MaxTokens int `json:"max_tokens"`
// Deprecated: use MaxTokens instead
NumPredict int `json:"num_predict"`
} `json:"options"`
}
type Runner struct { type Runner struct {
Model base.Model Model base.Model
Tokenizer *tokenizer.Tokenizer Tokenizer *tokenizer.Tokenizer
@@ -79,6 +63,8 @@ func (r *Runner) Load(modelName string) error {
r.Model = m r.Model = m
r.Tokenizer = m.Tokenizer() r.Tokenizer = m.Tokenizer()
r.contextLength = m.MaxContextLength() r.contextLength = m.MaxContextLength()
mlx.EnableCompile()
return nil return nil
} }

View File

@@ -0,0 +1,249 @@
//go:build mlx
package sample
import (
"math"
"sort"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// logprobEntry is the (token id, logprob) pair returned by the sampler's
// top-K extraction, used after the test-side descending sort.
type logprobEntry struct {
id int
logprob float64
}
// runSampleLogprobs drives Sample on a fresh Sampler configured for logprobs
// and returns the greedily-sampled token id, its logprob, and the top-K
// entries sorted descending by logprob. Logits must be a [vocab]-shaped
// slice; the helper reshapes it to [1, vocab] before calling the sampler.
func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, []logprobEntry) {
t.Helper()
s := New(Options{Logprobs: true, TopLogprobs: topK})
defer func() {
s.Free()
mlx.Sweep()
}()
tensor := mlx.FromValues(logits, 1, len(logits))
res := s.Sample(tensor)
mlx.Pin(res.Arrays()...)
defer mlx.Unpin(res.Arrays()...)
mlx.Sweep()
mlx.Eval(res.Arrays()...)
selected := res.Token.Int()
selLP := float64(res.Logprob.Floats()[0])
var top []logprobEntry
if topK > 0 && res.TopTokens != nil {
ids := res.TopTokens.Ints()
vals := res.TopLogprobs.Floats()
top = make([]logprobEntry, len(ids))
for i, id := range ids {
top[i] = logprobEntry{id: id, logprob: float64(vals[i])}
}
sort.Slice(top, func(i, j int) bool { return top[i].logprob > top[j].logprob })
}
return selected, selLP, top
}
func TestSampleLogprobsBasic(t *testing.T) {
tests := []struct {
name string
logits []float32
topK int
wantSelectedID int
wantTopLen int
}{
{
name: "single token without top logprobs",
logits: []float32{1.0, 0.5, 0.3, 0.1},
topK: 0,
wantSelectedID: 0,
wantTopLen: 0,
},
{
name: "single token with top logprobs",
logits: []float32{1.0, 0.5, 0.3, 0.1},
topK: 3,
wantSelectedID: 0,
wantTopLen: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
selected, _, top := runSampleLogprobs(t, tt.logits, tt.topK)
if selected != tt.wantSelectedID {
t.Errorf("selected = %d, want %d", selected, tt.wantSelectedID)
}
if len(top) != tt.wantTopLen {
t.Errorf("top-K length = %d, want %d", len(top), tt.wantTopLen)
}
})
}
}
func TestSampleLogprobsNumericalStability(t *testing.T) {
logits := []float32{1000.0, 999.0, 998.0}
_, selLP, top := runSampleLogprobs(t, logits, 3)
if math.IsInf(selLP, 0) || math.IsNaN(selLP) {
t.Errorf("selected logprob is not finite: %f", selLP)
}
for i, e := range top {
if math.IsInf(e.logprob, 0) || math.IsNaN(e.logprob) {
t.Errorf("top[%d] logprob is not finite: %f", i, e.logprob)
}
}
for i := 1; i < len(top); i++ {
if top[i].logprob > top[i-1].logprob {
t.Errorf("top logprobs not descending: %f > %f", top[i].logprob, top[i-1].logprob)
}
}
}
func TestSampleLogprobsProbabilityCorrectness(t *testing.T) {
tests := []struct {
name string
logits []float32
}{
{"uniform", []float32{1.0, 1.0, 1.0, 1.0}},
{"different", []float32{2.0, 1.0, 0.5, 0.1}},
{"negative", []float32{-1.0, -2.0, -3.0, -4.0}},
{"mixed", []float32{5.0, -5.0, 0.0, 2.5}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
selected, selLP, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
if selLP > 0 {
t.Errorf("selected logprob should be <= 0, got %f", selLP)
}
for i, e := range top {
if e.logprob > 0 {
t.Errorf("top[%d] logprob should be <= 0, got %f", i, e.logprob)
}
}
if tt.name == "uniform" {
want := 1.0 / float64(len(tt.logits))
got := math.Exp(selLP)
if math.Abs(got-want) > 1e-6 {
t.Errorf("uniform logits: selected prob = %f, want %f", got, want)
}
}
for i := 1; i < len(top); i++ {
if top[i].logprob > top[i-1].logprob {
t.Errorf("top logprobs not descending at %d: %f > %f",
i, top[i].logprob, top[i-1].logprob)
}
}
found := false
for _, e := range top {
if e.id == selected {
found = true
if math.Abs(e.logprob-selLP) > 1e-6 {
t.Errorf("selected logprob mismatch: selLP=%f top=%f", selLP, e.logprob)
}
break
}
}
if !found {
t.Errorf("selected token %d not present in top-K", selected)
}
})
}
}
func TestSampleLogprobsSoftmaxCorrectness(t *testing.T) {
tests := []struct {
name string
logits []float32
}{
{"small vocabulary", []float32{1.0, 2.0, 3.0}},
{"large differences", []float32{10.0, 0.0, -10.0}},
{"all equal", []float32{5.0, 5.0, 5.0, 5.0, 5.0}},
{"very large values", []float32{500.0, 499.0, 498.0}},
{"very small values", []float32{-500.0, -499.0, -498.0}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, _, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
if len(top) != len(tt.logits) {
t.Fatalf("top-K length = %d, want %d", len(top), len(tt.logits))
}
var sum float64
for _, e := range top {
p := math.Exp(e.logprob)
if p < 0 || p > 1 {
t.Errorf("token %d: probability %f out of [0,1]", e.id, p)
}
sum += p
}
if math.Abs(sum-1.0) > 1e-5 {
t.Errorf("probabilities sum = %f, want 1.0", sum)
}
})
}
}
func TestSampleLogprobsSelectedTokenCorrectness(t *testing.T) {
logits := []float32{3.0, 1.0, 2.0, 0.5}
maxIdx := 0
for i, v := range logits[1:] {
if v > logits[maxIdx] {
maxIdx = i + 1
}
}
selected, selLP, top := runSampleLogprobs(t, logits, len(logits))
if selected != maxIdx {
t.Errorf("selected = %d, want argmax %d", selected, maxIdx)
}
if top[0].id != maxIdx {
t.Errorf("top[0].id = %d, want argmax %d", top[0].id, maxIdx)
}
if math.Abs(top[0].logprob-selLP) > 1e-6 {
t.Errorf("top[0].logprob = %f, want selected %f", top[0].logprob, selLP)
}
}
func TestSampleLogprobsTopKOrdering(t *testing.T) {
// Logits chosen so argmax order differs from index order.
logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0}
wantOrder := []int{1, 3, 4, 0, 2}
_, _, top := runSampleLogprobs(t, logits, len(logits))
if len(top) != len(wantOrder) {
t.Fatalf("top-K length = %d, want %d", len(top), len(wantOrder))
}
for i, e := range top {
if e.id != wantOrder[i] {
t.Errorf("top[%d].id = %d, want %d", i, e.id, wantOrder[i])
}
}
for i := 1; i < len(top); i++ {
if top[i].logprob > top[i-1].logprob {
t.Errorf("top[%d].logprob (%f) > top[%d].logprob (%f)",
i, top[i].logprob, i-1, top[i-1].logprob)
}
}
}

View File

@@ -8,47 +8,76 @@ import (
type Transform func(*Sampler, *mlx.Array) *mlx.Array type Transform func(*Sampler, *mlx.Array) *mlx.Array
type Options struct {
Temperature float32
TopP float32
MinP float32
TopK int
RepeatLastN int
RepeatPenalty float32
PresencePenalty float32
FrequencyPenalty float32
// Logprobs causes Sample to populate Result.Logprob with the selected
// token's log-probability. TopLogprobs (when > 0) adds top-K pairs.
Logprobs bool
TopLogprobs int
}
type Sampler struct { type Sampler struct {
Temperature float32 Options
TopP float32
MinP float32
TopK int
RepeatLastN int
PresencePenalty float32
history *mlx.Array history *mlx.Array
historyLen int historyLen int
transforms []Transform transforms []Transform
} }
func New(temp, top_p, min_p float32, top_k, repeatLastN int, presencePenalty float32) *Sampler { // Result bundles the outputs of one decode step. The logprob tensors are
s := &Sampler{ // populated only when the sampler is configured to report them.
Temperature: temp, type Result struct {
TopP: top_p, Token *mlx.Array // sampled token id, shape [B]
MinP: min_p, Logprob *mlx.Array // sampled-token logprob, shape [B,1]; nil unless Logprobs
TopK: top_k, TopTokens *mlx.Array // top-K token ids, shape [B,K]; nil unless TopLogprobs > 0
RepeatLastN: repeatLastN, TopLogprobs *mlx.Array // top-K logprobs, shape [B,K]; nil unless TopLogprobs > 0
PresencePenalty: presencePenalty, }
// Arrays returns the tensor fields as a slice so callers can drive the mlx
// lifecycle verbs (Pin, Unpin, Eval, AsyncEval) over the whole group. Unset
// fields stay nil; the mlx helpers skip them.
func (r Result) Arrays() []*mlx.Array {
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
}
func New(opts Options) *Sampler {
if opts.RepeatPenalty <= 0 {
opts.RepeatPenalty = 1
} }
s := &Sampler{Options: opts}
var transforms []Transform var transforms []Transform
if presencePenalty != 0 { if s.usesHistory() {
transforms = append(transforms, penalty) transforms = append(transforms, penalty)
} }
if top_p > 0 && top_p < 1 { hasTopP := opts.TopP > 0 && opts.TopP < 1
transforms = append(transforms, topP) hasTopK := opts.TopK > 0
} switch {
case hasTopP:
if min_p != 0 { // topKTopP always does a full descending sort for the top-P
transforms = append(transforms, minP) // cumulative mask and opportunistically masks top-K during the
} // same pass when it is also configured.
transforms = append(transforms, topKTopP)
if top_k > 0 { case hasTopK:
// Argpartition (partial sort) is cheaper than a full sort.
transforms = append(transforms, topK) transforms = append(transforms, topK)
} }
if temp == 0 { if opts.MinP != 0 {
transforms = append(transforms, minP)
}
if opts.Temperature == 0 {
transforms = append(transforms, greedy) transforms = append(transforms, greedy)
} else { } else {
transforms = append(transforms, temperature) transforms = append(transforms, temperature)
@@ -59,7 +88,7 @@ func New(temp, top_p, min_p float32, top_k, repeatLastN int, presencePenalty flo
} }
func (s *Sampler) usesHistory() bool { func (s *Sampler) usesHistory() bool {
return s.PresencePenalty != 0 return s.RepeatPenalty != 1 || s.PresencePenalty != 0 || s.FrequencyPenalty != 0
} }
func (s *Sampler) setHistory(history *mlx.Array, historyLen int) { func (s *Sampler) setHistory(history *mlx.Array, historyLen int) {
@@ -115,75 +144,138 @@ func (s *Sampler) Free() {
s.setHistory(nil, 0) s.setHistory(nil, 0)
} }
func (s *Sampler) Sample(logits *mlx.Array) *mlx.Array { // Sample runs the configured transform chain on the raw per-token logits
// and returns the sampled token id plus, when configured, the reported
// log-probability tensors for the selected token and the top-K tokens.
func (s *Sampler) Sample(logits *mlx.Array) Result {
scores := logits
for _, transform := range s.transforms { for _, transform := range s.transforms {
logits = transform(s, logits) scores = transform(s, scores)
} }
return logits res := Result{Token: scores}
}
func greedy(_ *Sampler, logits *mlx.Array) *mlx.Array { if s.Logprobs {
return logits.Argmax(-1, false) // Compute log_softmax in fp32 and subtract the max before
} // logsumexp so the final subtraction stays on small values.
// Otherwise it cancels two large numbers and loses precision.
func temperature(s *Sampler, logits *mlx.Array) *mlx.Array { lp := logits.AsType(mlx.DTypeFloat32)
return mlx.DivScalar(logits, s.Temperature).Categorical(-1) lp = lp.Subtract(lp.MaxAxis(-1, true))
} lp = lp.Subtract(lp.Logsumexp(true))
res.Logprob = lp.TakeAlongAxis(res.Token.ExpandDims(-1), -1)
func topP(s *Sampler, logprobs *mlx.Array) *mlx.Array { if k := s.TopLogprobs; k > 0 {
if s.TopP <= 0 || s.TopP >= 1 { if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
return logprobs k = vocab
}
// Argpartition on the negated values places the K largest
// (unsorted) in positions [0:K].
idx := lp.Negative().ArgpartitionAxis(k-1, -1).Slice(mlx.Slice(), mlx.Slice(0, k))
res.TopTokens = idx.AsType(mlx.DTypeInt32)
res.TopLogprobs = lp.TakeAlongAxis(idx, -1)
}
} }
return res
}
order := logprobs.Negative().ArgsortAxis(-1) func greedy(_ *Sampler, scores *mlx.Array) *mlx.Array {
sortedLogprobs := logprobs.TakeAlongAxis(order, -1) return scores.Argmax(-1, false)
sortedProbs := mlx.SoftmaxAxis(sortedLogprobs, -1, true) }
prevCumProbs := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
func temperature(s *Sampler, scores *mlx.Array) *mlx.Array {
return mlx.DivScalar(scores, s.Temperature).Categorical(-1)
}
// topKTopP applies top-P in a descending sort pass and, when top-K is also
// configured, masks any surviving value below the K-th largest in the same
// pass. Callers dispatch here whenever top-P is enabled — the top-K-only
// case uses a cheaper partial sort via the topK transform.
func topKTopP(s *Sampler, scores *mlx.Array) *mlx.Array {
vocab := scores.Dim(scores.NumDims() - 1)
applyTopK := s.TopK > 0 && s.TopK < vocab
order := scores.Negative().ArgsortAxis(-1)
sorted := scores.TakeAlongAxis(order, -1)
negInf := mlx.FromValue(float32(math.Inf(-1)))
// Top-P: in descending order, keep tokens whose exclusive cumulative
// probability is still below s.TopP.
probs := mlx.SoftmaxAxis(sorted, -1, true)
prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs)
keep := prevCumProbs.Less(mlx.FromValue(s.TopP)) keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
filtered := mlx.Where(keep, sortedLogprobs, mlx.FromValue(float32(math.Inf(-1)))) sorted = mlx.Where(keep, sorted, negInf)
return logprobs.PutAlongAxis(order, filtered, -1)
}
func minP(s *Sampler, logprobs *mlx.Array) *mlx.Array { out := scores.PutAlongAxis(order, sorted, -1)
if s.MinP <= 0 || s.MinP > 1 {
return logprobs // Top-K: sorted is already in descending order, so positions [K, V)
// are the ones to drop. Scatter -inf through their original-layout
// indices (order[K:]). Positional (not value-based) so exactly K
// tokens survive — ties at the K-th logit get broken by the sort
// order rather than promoted through the filter.
if applyTopK {
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
out = out.PutAlongAxis(dropOrder, negInf, -1)
} }
maxLogprobs := logprobs.TakeAlongAxis(logprobs.Argmax(-1, true), -1) return out
minLogprobs := mlx.AddScalar(maxLogprobs, float32(math.Log(float64(s.MinP)))) }
func minP(s *Sampler, scores *mlx.Array) *mlx.Array {
if s.MinP <= 0 || s.MinP > 1 {
return scores
}
maxScore := scores.MaxAxis(-1, true)
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(s.MinP))))
return mlx.Where( return mlx.Where(
logprobs.Less(minLogprobs), scores.Less(threshold),
mlx.FromValue(float32(math.Inf(-1))), mlx.FromValue(float32(math.Inf(-1))),
logprobs, scores,
) )
} }
func topK(s *Sampler, logprobs *mlx.Array) *mlx.Array { func topK(s *Sampler, scores *mlx.Array) *mlx.Array {
if s.TopK <= 0 { if s.TopK <= 0 {
return logprobs return scores
} }
vocab := logprobs.Dim(logprobs.NumDims() - 1) vocab := scores.Dim(scores.NumDims() - 1)
if s.TopK >= vocab { if s.TopK >= vocab {
return logprobs return scores
} }
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End)) mask := scores.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1) return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
} }
func penalty(s *Sampler, logprobs *mlx.Array) *mlx.Array { func penalty(s *Sampler, scores *mlx.Array) *mlx.Array {
if s.history == nil || s.historyLen == 0 || s.PresencePenalty == 0 { if s.historyLen == 0 {
return logprobs return scores
} }
tokenIndices := s.history tokenIndices := s.history
if logprobs.NumDims() > 1 { if scores.NumDims() > 1 {
tokenIndices = tokenIndices.ExpandDims(0) tokenIndices = tokenIndices.ExpandDims(0)
} }
selected := logprobs.TakeAlongAxis(tokenIndices, -1) if s.RepeatPenalty != 1 || s.PresencePenalty != 0 {
adjusted := mlx.AddScalar(selected, -s.PresencePenalty) adjusted := scores.TakeAlongAxis(tokenIndices, -1)
return logprobs.PutAlongAxis(tokenIndices, adjusted, -1) if s.RepeatPenalty != 1 {
factor := mlx.Where(
adjusted.Less(mlx.FromValue(float32(0))),
mlx.FromValue(s.RepeatPenalty),
mlx.FromValue(1/s.RepeatPenalty),
)
adjusted = adjusted.Multiply(factor)
}
if s.PresencePenalty != 0 {
adjusted = mlx.AddScalar(adjusted, -s.PresencePenalty)
}
scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
}
if s.FrequencyPenalty != 0 {
scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1)
}
return scores
} }

View File

@@ -10,8 +10,7 @@ import (
) )
func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) { func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
// RepeatLastN = 1, PresencePenalty = 6 s := New(Options{RepeatLastN: 1, PresencePenalty: 6})
s := New(0, 0, 0, 0, 1, 6)
defer func() { defer func() {
s.Free() s.Free()
mlx.Sweep() mlx.Sweep()
@@ -20,11 +19,11 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
s.ResetHistory([]int32{0}) s.ResetHistory([]int32{0})
s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1})) s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1}))
logprobs := mlx.FromValues([]float32{0, 5, 4}, 3) logits := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logprobs) got := s.Sample(logits).Token
mlx.Eval(got) mlx.Eval(got)
// logprobs will be [0, -1, 4] after the penalty // logits will be [0, -1, 4] after the penalty
// and then (index) 2 after the greedy sampler // and then (index) 2 after the greedy sampler
gotInt := got.Int() gotInt := got.Int()
if gotInt != 2 { if gotInt != 2 {
@@ -32,19 +31,59 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
} }
} }
func TestMinPMasksTokensBelowThreshold(t *testing.T) { func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) {
s := New(0, 0, 0.5, 0, 0, 0) s := New(Options{RepeatLastN: 1, RepeatPenalty: 2})
defer func() { defer func() {
s.Free() s.Free()
mlx.Sweep() mlx.Sweep()
}() }()
logprobs := mlx.FromValues([]float32{ s.ResetHistory([]int32{1})
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logits).Token
mlx.Eval(got)
// token 1 is repeated and positive, so 5 / 2 falls below token 2.
gotInt := got.Int()
if gotInt != 2 {
t.Fatalf("got %d, want 2", gotInt)
}
}
func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) {
s := New(Options{RepeatLastN: 4, FrequencyPenalty: 2})
defer func() {
s.Free()
mlx.Sweep()
}()
s.ResetHistory([]int32{1, 1})
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logits).Token
mlx.Eval(got)
// token 1 appears twice, so 5 - (2 * 2) falls below token 2.
gotInt := got.Int()
if gotInt != 2 {
t.Fatalf("got %d, want 2", gotInt)
}
}
func TestMinPMasksTokensBelowThreshold(t *testing.T) {
s := New(Options{MinP: 0.5})
defer func() {
s.Free()
mlx.Sweep()
}()
logits := mlx.FromValues([]float32{
float32(math.Log(0.5)), float32(math.Log(0.5)),
float32(math.Log(0.3)), float32(math.Log(0.3)),
float32(math.Log(0.2)), float32(math.Log(0.2)),
}, 3) }, 3)
got := minP(s, logprobs) got := minP(s, logits)
mlx.Eval(got) mlx.Eval(got)
gotFloats := got.Floats() gotFloats := got.Floats()

View File

@@ -2,7 +2,6 @@ package mlxrunner
import ( import (
"bytes" "bytes"
"cmp"
"context" "context"
"encoding/json" "encoding/json"
"flag" "flag"
@@ -87,23 +86,25 @@ func Execute(args []string) error {
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan CompletionResponse)} request := Request{Responses: make(chan CompletionResponse)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil { if err := json.NewDecoder(r.Body).Decode(&request.CompletionRequest); err != nil {
slog.Error("Failed to decode request", "error", err) slog.Error("Failed to decode request", "error", err)
http.Error(w, "Bad Request", http.StatusBadRequest) http.Error(w, "Bad Request", http.StatusBadRequest)
return return
} }
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
request.Pipeline = runner.TextGenerationPipeline request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New( request.Sampler = sample.New(sample.Options{
request.Options.Temperature, Temperature: request.Options.Temperature,
request.Options.TopP, TopP: request.Options.TopP,
request.Options.MinP, MinP: request.Options.MinP,
request.Options.TopK, TopK: request.Options.TopK,
request.Options.RepeatLastN, RepeatLastN: request.Options.RepeatLastN,
request.Options.PresencePenalty, RepeatPenalty: request.Options.RepeatPenalty,
) PresencePenalty: request.Options.PresencePenalty,
FrequencyPenalty: request.Options.FrequencyPenalty,
Logprobs: request.Logprobs,
TopLogprobs: request.TopLogprobs,
})
var cancel context.CancelFunc var cancel context.CancelFunc
request.Ctx, cancel = context.WithCancel(r.Context()) request.Ctx, cancel = context.WithCancel(r.Context())

View File

@@ -80,7 +80,6 @@ type TextConfig struct {
PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size) PLEProjScale float32 `json:"-"` // 1/sqrt(hidden_size)
PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071... PLECombineScale float32 `json:"-"` // 2^(-0.5) = 0.7071...
RouterScale float32 `json:"-"` // 1/sqrt(hidden_size) RouterScale float32 `json:"-"` // 1/sqrt(hidden_size)
SoftcapInv float32 `json:"-"` // 1/final_logit_softcapping
// KV sharing: maps shared layer index -> donor layer index. // KV sharing: maps shared layer index -> donor layer index.
KVShareMap map[int32]int32 `json:"-"` KVShareMap map[int32]int32 `json:"-"`
@@ -455,9 +454,6 @@ func parseTextConfig(configData []byte) (TextConfig, error) {
cfg.PLECombineScale = float32(math.Pow(2.0, -0.5)) cfg.PLECombineScale = float32(math.Pow(2.0, -0.5))
} }
cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize))) cfg.RouterScale = float32(1.0 / math.Sqrt(float64(cfg.HiddenSize)))
if cfg.FinalLogitSoftcapping > 0 {
cfg.SoftcapInv = 1.0 / cfg.FinalLogitSoftcapping
}
// Compute KV sharing map. // Compute KV sharing map.
cfg.KVShareMap = make(map[int32]int32) cfg.KVShareMap = make(map[int32]int32)
@@ -1065,14 +1061,12 @@ func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
} }
} }
h = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc) var donorKV *sharedKVEntry
h, donorKV = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
// If this layer is a donor, store its cached KV for later shared layers. // If this layer is a donor, store its cached KV for later shared layers.
if layer.IsDonor && c != nil { if layer.IsDonor && donorKV != nil {
state := c.State() sharedKV[layer.LayerIdx] = *donorKV
if len(state) >= 2 && state[0] != nil && state[1] != nil {
sharedKV[layer.LayerIdx] = sharedKVEntry{K: state[0], V: state[1], Offset: c.Offset()}
}
} }
} }
@@ -1114,9 +1108,8 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
logits := m.LMHead.Forward(x) logits := m.LMHead.Forward(x)
if m.FinalLogitSoftcapping > 0 { if m.FinalLogitSoftcapping > 0 {
logits = mlx.MulScalar(logits, m.SoftcapInv) cap := mlx.FromValue(m.FinalLogitSoftcapping).AsType(logits.DType())
logits = logits.Tanh() logits = mlx.LogitSoftcap(logits, cap)
logits = mlx.MulScalar(logits, m.FinalLogitSoftcapping)
} }
return logits return logits
@@ -1195,9 +1188,9 @@ func sliceLayerDim(combined *mlx.Array, layerIdx, B, L, pleDim int32) *mlx.Array
return mlx.Squeeze(sliced, 2) return mlx.Squeeze(sliced, 2)
} }
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array { func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps) normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache) attnOut, donorKV := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut) h := mlx.Add(x, attnOut)
@@ -1231,8 +1224,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
// PLE injection (after MLP residual). // PLE injection (after MLP residual).
if l.PLE != nil && pleInput != nil { if l.PLE != nil && pleInput != nil {
residual := h residual := h
gate := mlx.GELUApprox(l.PLE.InputGate.Forward(h)) gated := mlx.GeGLU(l.PLE.InputGate.Forward(h), pleInput)
gated := mlx.Mul(gate, pleInput)
projected := l.PLE.Projection.Forward(gated) projected := l.PLE.Projection.Forward(gated)
projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps) projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps)
h = mlx.Add(residual, projected) h = mlx.Add(residual, projected)
@@ -1243,10 +1235,10 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
h = mlx.Mul(h, l.LayerScalar) h = mlx.Mul(h, l.LayerScalar)
} }
return h return h, donorKV
} }
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array { func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
// Determine head dim and scale based on layer type. // Determine head dim and scale based on layer type.
headDim := cfg.HeadDim headDim := cfg.HeadDim
scale := cfg.SlidingScale scale := cfg.SlidingScale
@@ -1280,6 +1272,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs) q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs)
var k, v *mlx.Array var k, v *mlx.Array
var donorKV *sharedKVEntry
if donorEntry != nil { if donorEntry != nil {
// Shared layer: use donor's cached K/V. // Shared layer: use donor's cached K/V.
@@ -1318,6 +1311,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
// Update cache. // Update cache.
if c != nil { if c != nil {
k, v = c.Update(k, v) k, v = c.Update(k, v)
donorKV = &sharedKVEntry{K: k, V: v, Offset: c.Offset()}
} }
} }
@@ -1371,13 +1365,13 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
// strided views differently. Metal handles them natively. // strided views differently. Metal handles them natively.
out = mlx.Contiguous(out, false) out = mlx.Contiguous(out, false)
} }
return a.OProj.Forward(out) return a.OProj.Forward(out), donorKV
} }
func (m *MLP) Forward(x *mlx.Array) *mlx.Array { func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.GELUApprox(m.GateProj.Forward(x)) gate := m.GateProj.Forward(x)
up := m.UpProj.Forward(x) up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up)) return m.DownProj.Forward(mlx.GeGLU(gate, up))
} }
// Forward runs the router to select top-k experts per token. // Forward runs the router to select top-k experts per token.
@@ -1457,13 +1451,13 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
up := mlx.SliceStartStop(gateUp, up := mlx.SliceStartStop(gateUp,
[]int32{0, 0, 0, mid}, []int32{0, 0, 0, mid},
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])}) []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
hidden = mlx.Mul(mlx.GELUApprox(gate), up) hidden = mlx.GeGLU(gate, up)
} else { } else {
gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases, gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases,
nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort) nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort)
up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases, up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases,
nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort) nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort)
hidden = mlx.Mul(mlx.GELUApprox(gate), up) hidden = mlx.GeGLU(gate, up)
} }
downMode := m.DownQuantMode downMode := m.DownQuantMode
if downMode == "" { if downMode == "" {
@@ -1482,11 +1476,11 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
up := mlx.SliceStartStop(gateUp, up := mlx.SliceStartStop(gateUp,
[]int32{0, 0, 0, mid}, []int32{0, 0, 0, mid},
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])}) []int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
hidden = mlx.Mul(mlx.GELUApprox(gate), up) hidden = mlx.GeGLU(gate, up)
} else { } else {
gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort) gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort)
up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort) up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort)
hidden = mlx.Mul(mlx.GELUApprox(gate), up) hidden = mlx.GeGLU(gate, up)
} }
down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort) down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort)
} }

View File

@@ -144,6 +144,8 @@ func TestRouterForwardMatchesLegacy(t *testing.T) {
gotScores, gotInds := r.Forward(x, cfg) gotScores, gotInds := r.Forward(x, cfg)
wantScores, wantInds := legacyRouterForward(r, x, cfg) wantScores, wantInds := legacyRouterForward(r, x, cfg)
gotInds = gotInds.AsType(mlx.DTypeInt32)
wantInds = wantInds.AsType(mlx.DTypeInt32)
mlx.Eval(gotScores, gotInds, wantScores, wantInds) mlx.Eval(gotScores, gotInds, wantScores, wantInds)
if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) { if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {

View File

@@ -148,9 +148,7 @@ type DenseMLP struct {
// Forward applies the SwiGLU MLP // Forward applies the SwiGLU MLP
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array { func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(m.GateProj.Forward(x)) return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up))
} }
// MoEGate implements the expert gating mechanism // MoEGate implements the expert gating mechanism
@@ -163,21 +161,21 @@ type MoEGate struct {
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) { func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
gates := g.Gate.Forward(x) gates := g.Gate.Forward(x)
scores := mlx.Sigmoid(gates) var origScores, negScores *mlx.Array
origScores := scores
if g.EScoreCorrectionBias != nil { if g.EScoreCorrectionBias != nil {
scores = mlx.Add(scores, g.EScoreCorrectionBias) origScores, negScores = mlx.SigmoidRouter(gates, g.EScoreCorrectionBias)
} else {
origScores = mlx.Sigmoid(gates)
negScores = mlx.Neg(origScores)
} }
topK := cfg.NumExpertsPerTok topK := cfg.NumExpertsPerTok
negScores := mlx.Neg(scores)
inds := mlx.Argpartition(negScores, int(topK)-1, -1) inds := mlx.Argpartition(negScores, int(topK)-1, -1)
dims := inds.Dims() dims := inds.Dims()
inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK}) inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})
scores = mlx.TakeAlongAxis(origScores, inds, -1) scores := mlx.TakeAlongAxis(origScores, inds, -1)
if topK > 1 && cfg.NormTopKProb { if topK > 1 && cfg.NormTopKProb {
sumScores := mlx.Sum(scores, -1, true) sumScores := mlx.Sum(scores, -1, true)
@@ -242,7 +240,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases, up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort) nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up) hidden = mlx.SwiGLU(gate, up)
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases, down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort) nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
@@ -250,7 +248,7 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort) gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort) up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up) hidden = mlx.SwiGLU(gate, up)
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort) down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
} }
@@ -273,9 +271,7 @@ type SharedExperts struct {
// Forward applies the shared expert MLP // Forward applies the shared expert MLP
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array { func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.SiLU(s.GateProj.Forward(x)) return s.DownProj.Forward(mlx.SwiGLU(s.GateProj.Forward(x), s.UpProj.Forward(x)))
up := s.UpProj.Forward(x)
return s.DownProj.Forward(mlx.Mul(gate, up))
} }
// MoE implements the full Mixture of Experts layer // MoE implements the full Mixture of Experts layer

View File

@@ -314,5 +314,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
} }
func (m *MLP) Forward(x *mlx.Array) *mlx.Array { func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x))) return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
} }

View File

@@ -169,8 +169,8 @@ func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4") dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
mlx.Eval(dequantizedWeight) mlx.Eval(dequantizedWeight)
qOut := ql.Forward(input) qOut := ql.Forward(input).AsType(mlx.DTypeFloat32)
dOut := NewLinear(dequantizedWeight, nil).Forward(input) dOut := NewLinear(dequantizedWeight, nil).Forward(input).AsType(mlx.DTypeFloat32)
mlx.Eval(qOut, dOut) mlx.Eval(qOut, dOut)
got := qOut.Floats() got := qOut.Floats()

View File

@@ -333,5 +333,5 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
} }
func (m *MLP) Forward(x *mlx.Array) *mlx.Array { func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x))) return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
} }

View File

@@ -1253,7 +1253,7 @@ func (g *GatedDeltaNet) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Co
} }
func (m *DenseMLP) Forward(x *mlx.Array, _ *Config) *mlx.Array { func (m *DenseMLP) Forward(x *mlx.Array, _ *Config) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x))) return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
} }
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array { func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
@@ -1283,13 +1283,13 @@ func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort) nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases, up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort) nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up) hidden = mlx.SwiGLU(gate, up)
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases, down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort) nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
} else { } else {
gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort) gate = mlx.GatherMM(xFlat, s.GateWeight, nil, idxFlat, doSort)
up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort) up = mlx.GatherMM(xFlat, s.UpWeight, nil, idxFlat, doSort)
hidden = mlx.Mul(mlx.SiLU(gate), up) hidden = mlx.SwiGLU(gate, up)
down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort) down = mlx.GatherMM(hidden, s.DownWeight, nil, idxFlat, doSort)
} }