Compare commits

...

15 Commits

Author SHA1 Message Date
Parth Sareen
5d1021603a server: apply format when think=false for gemma4 (#15678) 2026-04-20 17:42:29 -07:00
Parth Sareen
8e05d734b9 launch: add kimi cli integration with installer flow (#15723) 2026-04-20 15:33:32 -07:00
Jesse Gross
05e0f21bec mlx: fuse sigmoid router head in glm4_moe_lite
DeepSeek-V2-style aux-loss-free routing computes sigmoid(gates) once but
needs it twice: the raw sigmoid output is gathered after top-k, while the
post-bias negation is the argpartition key. Fuse into a single multi-output
Compiled kernel returning both, saving two launches on the routing path
per token. Exposed as a general SigmoidRouter since the same pattern is
shared across DeepSeek-V2 descendants.

Improves glm4.7 generation performance by approximately 1%.
2026-04-20 15:02:14 -07:00
Daniel Hiltgen
ff23dd343f mlx: apply repeat penalties in sampler (#15631) 2026-04-18 07:49:38 -07:00
Parth Sareen
123b300af6 docs: update hermes (#15655) 2026-04-17 14:20:59 -07:00
Parth Sareen
57653b8e42 cmd/launch: show WSL guidance on Windows instead of handing off (#15637) 2026-04-16 17:18:04 -07:00
Parth Sareen
a50ce61c54 launch: skip unchanged managed-single rewrite (#15633) 2026-04-16 16:20:42 -07:00
Daniel Hiltgen
2bb7ea00d2 create: avoid gc race with create (#15628)
If you have a long running create, and start another ollama server with the
same model dir, the GC algorithm deletes the pending blobs and breaks the
create.  This adds a 1h grace period to avoid deleting in-flight creation
operations.
2026-04-16 13:29:16 -07:00
Daniel Hiltgen
55fa80d07a mlx: additional gemma4 cache fixes (#15607)
Harden additional corner cases
2026-04-16 13:07:19 -07:00
Daniel Hiltgen
b9cb535407 mlx: fix gemma4 cache to use logical view (#15617) 2026-04-16 11:54:30 -07:00
Daniel Hiltgen
031baef094 mlx: fix imagegen lookup (#15588)
* mlx: fix imagegen lookup

Fixes #15533 - imagegen had fallen out of sync with the new layout
for multiple mlx libraries on Metal.

* review comments
2026-04-16 10:39:00 -07:00
Mike Wallio
7d271e6dc9 cmd/launch: add Copilot CLI integration (#15583)
---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: ParthSareen <parth.sareen@ollama.com>
2026-04-15 17:22:53 -07:00
Devon Rifkin
c88dae2d6b Merge pull request #15612 from ollama/drifkin/gemma4-split-templates
gemma4: render differently based on model size
2026-04-15 17:15:35 -07:00
Devon Rifkin
9e3618d663 make empty block conditional 2026-04-15 15:35:25 -07:00
Devon Rifkin
e585ecd11f gemma4: render differently based on model size
Following up on #15560, this change now has e2b/e4b render differently
from 26b/31b.

For backwards compatibility, we take the existing renderer name `gemma4`
and make it do dynamic resolution based on the model name/size, but the
intended use is for the models to be republished with the renderer
variant specified explicitly: `gemma4-small` or `gemma4-large`.
2026-04-15 14:37:16 -07:00
44 changed files with 2315 additions and 664 deletions

View File

@@ -55,7 +55,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
ollama
```
You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `claude`, `codex`, `openclaw` and more.
You'll be prompted to run a model or connect Ollama to your existing agents or applications such as `Claude Code`, `OpenClaw`, `OpenCode` , `Codex`, `Copilot`, and more.
### Coding
@@ -65,7 +65,7 @@ To launch a specific integration:
ollama launch claude
```
Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
Supported integrations include [Claude Code](https://docs.ollama.com/integrations/claude-code), [Codex](https://docs.ollama.com/integrations/codex), [Copilot CLI](https://docs.ollama.com/integrations/copilot-cli), [Droid](https://docs.ollama.com/integrations/droid), and [OpenCode](https://docs.ollama.com/integrations/opencode).
### AI assistant

View File

@@ -61,6 +61,9 @@ func TestLaunchCmd(t *testing.T) {
if !strings.Contains(cmd.Long, "hermes") {
t.Error("Long description should mention hermes")
}
if !strings.Contains(cmd.Long, "kimi") {
t.Error("Long description should mention kimi")
}
})
t.Run("flags exist", func(t *testing.T) {

76
cmd/launch/copilot.go Normal file
View File

@@ -0,0 +1,76 @@
package launch
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/ollama/ollama/envconfig"
)
// Copilot implements Runner for GitHub Copilot CLI integration.
type Copilot struct{}
func (c *Copilot) String() string { return "Copilot CLI" }
func (c *Copilot) args(model string, extra []string) []string {
var args []string
if model != "" {
args = append(args, "--model", model)
}
args = append(args, extra...)
return args
}
func (c *Copilot) findPath() (string, error) {
if p, err := exec.LookPath("copilot"); err == nil {
return p, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
name := "copilot"
if runtime.GOOS == "windows" {
name = "copilot.exe"
}
fallback := filepath.Join(home, ".local", "bin", name)
if _, err := os.Stat(fallback); err != nil {
return "", err
}
return fallback, nil
}
func (c *Copilot) Run(model string, args []string) error {
copilotPath, err := c.findPath()
if err != nil {
return fmt.Errorf("copilot is not installed, install from https://docs.github.com/en/copilot/how-tos/set-up/install-copilot-cli")
}
cmd := exec.Command(copilotPath, c.args(model, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Env = append(os.Environ(), c.envVars(model)...)
return cmd.Run()
}
// envVars returns the environment variables that configure Copilot CLI
// to use Ollama as its model provider.
func (c *Copilot) envVars(model string) []string {
env := []string{
"COPILOT_PROVIDER_BASE_URL=" + envconfig.Host().String() + "/v1",
"COPILOT_PROVIDER_API_KEY=",
"COPILOT_PROVIDER_WIRE_API=responses",
}
if model != "" {
env = append(env, "COPILOT_MODEL="+model)
}
return env
}

161
cmd/launch/copilot_test.go Normal file
View File

@@ -0,0 +1,161 @@
package launch
import (
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
func TestCopilotIntegration(t *testing.T) {
c := &Copilot{}
t.Run("String", func(t *testing.T) {
if got := c.String(); got != "Copilot CLI" {
t.Errorf("String() = %q, want %q", got, "Copilot CLI")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = c
})
}
func TestCopilotFindPath(t *testing.T) {
c := &Copilot{}
t.Run("finds copilot in PATH", func(t *testing.T) {
tmpDir := t.TempDir()
name := "copilot"
if runtime.GOOS == "windows" {
name = "copilot.exe"
}
fakeBin := filepath.Join(tmpDir, name)
os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
t.Setenv("PATH", tmpDir)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fakeBin {
t.Errorf("findPath() = %q, want %q", got, fakeBin)
}
})
t.Run("returns error when not in PATH", func(t *testing.T) {
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
_, err := c.findPath()
if err == nil {
t.Fatal("expected error, got nil")
}
})
t.Run("falls back to ~/.local/bin/copilot", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
name := "copilot"
if runtime.GOOS == "windows" {
name = "copilot.exe"
}
fallback := filepath.Join(tmpDir, ".local", "bin", name)
os.MkdirAll(filepath.Dir(fallback), 0o755)
os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
got, err := c.findPath()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != fallback {
t.Errorf("findPath() = %q, want %q", got, fallback)
}
})
t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
t.Setenv("PATH", t.TempDir()) // empty dir, no copilot binary
_, err := c.findPath()
if err == nil {
t.Fatal("expected error, got nil")
}
})
}
func TestCopilotArgs(t *testing.T) {
c := &Copilot{}
tests := []struct {
name string
model string
args []string
want []string
}{
{"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
{"empty model", "", nil, nil},
{"with model and extra", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
{"empty model with help", "", []string{"--help"}, []string{"--help"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := c.args(tt.model, tt.args)
if !slices.Equal(got, tt.want) {
t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
}
})
}
}
func TestCopilotEnvVars(t *testing.T) {
c := &Copilot{}
envMap := func(envs []string) map[string]string {
m := make(map[string]string)
for _, e := range envs {
k, v, _ := strings.Cut(e, "=")
m[k] = v
}
return m
}
t.Run("sets required provider env vars with model", func(t *testing.T) {
got := envMap(c.envVars("llama3.2"))
if got["COPILOT_PROVIDER_BASE_URL"] == "" {
t.Error("COPILOT_PROVIDER_BASE_URL should be set")
}
if !strings.HasSuffix(got["COPILOT_PROVIDER_BASE_URL"], "/v1") {
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want /v1 suffix", got["COPILOT_PROVIDER_BASE_URL"])
}
if _, ok := got["COPILOT_PROVIDER_API_KEY"]; !ok {
t.Error("COPILOT_PROVIDER_API_KEY should be set (empty)")
}
if got["COPILOT_PROVIDER_WIRE_API"] != "responses" {
t.Errorf("COPILOT_PROVIDER_WIRE_API = %q, want %q", got["COPILOT_PROVIDER_WIRE_API"], "responses")
}
if got["COPILOT_MODEL"] != "llama3.2" {
t.Errorf("COPILOT_MODEL = %q, want %q", got["COPILOT_MODEL"], "llama3.2")
}
})
t.Run("omits COPILOT_MODEL when model is empty", func(t *testing.T) {
got := envMap(c.envVars(""))
if _, ok := got["COPILOT_MODEL"]; ok {
t.Errorf("COPILOT_MODEL should not be set for empty model, got %q", got["COPILOT_MODEL"])
}
})
t.Run("uses custom OLLAMA_HOST", func(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://myhost:9999")
got := envMap(c.envVars("test"))
if !strings.Contains(got["COPILOT_PROVIDER_BASE_URL"], "myhost:9999") {
t.Errorf("COPILOT_PROVIDER_BASE_URL = %q, want custom host", got["COPILOT_PROVIDER_BASE_URL"])
}
})
}

View File

@@ -4,18 +4,15 @@ import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"net/http"
"os"
"os/exec"
pathpkg "path"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"time"
"gopkg.in/yaml.v3"
@@ -66,23 +63,13 @@ var hermesMessagingEnvGroups = [][]string{
// switching UX after startup.
type Hermes struct{}
type hermesConfigBackend struct {
displayPath string
read func() ([]byte, error)
write func([]byte) error
}
func (h *Hermes) String() string { return "Hermes Agent" }
func (h *Hermes) Run(_ string, args []string) error {
// Hermes reads its primary model from config.yaml. launch configures that
// default model ahead of time so we can keep runtime invocation simple and
// still let Hermes discover additional models later via its own UX.
if hermesGOOS == "windows" {
return h.runWindows(args)
}
bin, err := h.findUnixBinary()
bin, err := h.binary()
if err != nil {
return err
}
@@ -95,21 +82,21 @@ func (h *Hermes) Run(_ string, args []string) error {
}
func (h *Hermes) Paths() []string {
backend, err := h.configBackend()
configPath, err := hermesConfigPath()
if err != nil {
return nil
}
return []string{backend.displayPath}
return []string{configPath}
}
func (h *Hermes) Configure(model string) error {
backend, err := h.configBackend()
configPath, err := hermesConfigPath()
if err != nil {
return err
}
cfg := map[string]any{}
if data, err := backend.read(); err == nil {
if data, err := os.ReadFile(configPath); err == nil {
if err := yaml.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse hermes config: %w", err)
}
@@ -142,15 +129,18 @@ func (h *Hermes) Configure(model string) error {
if err != nil {
return err
}
return backend.write(data)
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
return fileutil.WriteWithBackup(configPath, data)
}
func (h *Hermes) CurrentModel() string {
backend, err := h.configBackend()
configPath, err := hermesConfigPath()
if err != nil {
return ""
}
data, err := backend.read()
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
@@ -188,14 +178,7 @@ func (h *Hermes) RefreshRuntimeAfterConfigure() error {
}
func (h *Hermes) installed() bool {
if hermesGOOS == "windows" {
if _, err := hermesLookPath("hermes"); err == nil {
return true
}
return h.wslHasHermes()
}
_, err := h.findUnixBinary()
_, err := h.binary()
return err == nil
}
@@ -205,7 +188,7 @@ func (h *Hermes) ensureInstalled() error {
}
if hermesGOOS == "windows" {
return h.ensureInstalledWindows()
return hermesWindowsHint()
}
var missing []string
@@ -239,42 +222,6 @@ func (h *Hermes) ensureInstalled() error {
return nil
}
func (h *Hermes) ensureInstalledWindows() error {
// Hermes upstream support is WSL-oriented, so Windows launch uses a hybrid
// WSL handoff that stays on the same install path as upstream Hermes.
if _, err := hermesLookPath("hermes"); err == nil {
return nil
}
if !h.wslAvailable() {
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
}
if h.wslHasHermes() {
return nil
}
ok, err := ConfirmPromptWithOptions("Hermes runs through WSL2 on Windows. Install it in WSL now?", ConfirmOptions{
YesLabel: "Use WSL",
NoLabel: "Show manual steps",
})
if err != nil {
return err
}
if !ok {
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
}
fmt.Fprintf(os.Stderr, "\nInstalling Hermes in WSL...\n")
if err := h.runWSL("bash", "-lc", hermesInstallScript); err != nil {
return hermesWindowsHint(fmt.Errorf("failed to install hermes in WSL: %w", err))
}
if !h.wslHasHermes() {
return hermesWindowsHint(fmt.Errorf("hermes install finished but the WSL binary was not found"))
}
fmt.Fprintf(os.Stderr, "%sHermes installed successfully in WSL%s\n\n", ansiGreen, ansiReset)
return nil
}
func (h *Hermes) listModels(defaultModel string) []string {
client := hermesOllamaClient()
resp, err := client.List(context.Background())
@@ -306,11 +253,15 @@ func (h *Hermes) listModels(defaultModel string) []string {
return models
}
func (h *Hermes) findUnixBinary() (string, error) {
func (h *Hermes) binary() (string, error) {
if path, err := hermesLookPath("hermes"); err == nil {
return path, nil
}
if hermesGOOS == "windows" {
return "", hermesWindowsHint()
}
home, err := hermesUserHome()
if err != nil {
return "", err
@@ -323,70 +274,6 @@ func (h *Hermes) findUnixBinary() (string, error) {
return "", fmt.Errorf("hermes is not installed")
}
func (h *Hermes) runWindows(args []string) error {
if path, err := hermesLookPath("hermes"); err == nil {
if err := h.runGatewaySetupPreflight(args, func() error {
return hermesAttachedCommand(path, "gateway", "setup").Run()
}); err != nil {
return err
}
return hermesAttachedCommand(path, args...).Run()
}
if !h.wslAvailable() {
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
}
if err := h.runGatewaySetupPreflight(args, func() error {
return h.runWSL("hermes", "gateway", "setup")
}); err != nil {
return err
}
if err := h.runWSL(append([]string{"hermes"}, args...)...); err != nil {
return hermesWindowsHint(err)
}
return nil
}
func (h *Hermes) runWSL(args ...string) error {
if !h.wslAvailable() {
return fmt.Errorf("wsl.exe is not available")
}
return hermesAttachedCommand("wsl.exe", "bash", "-lc", shellQuoteArgs(args)).Run()
}
func (h *Hermes) runWSLCombinedOutput(args ...string) ([]byte, error) {
if !h.wslAvailable() {
return nil, fmt.Errorf("wsl.exe is not available")
}
return hermesCommand("wsl.exe", "bash", "-lc", shellQuoteArgs(args)).CombinedOutput()
}
func (h *Hermes) wslAvailable() bool {
_, err := hermesLookPath("wsl.exe")
return err == nil
}
func (h *Hermes) wslHasHermes() bool {
if !h.wslAvailable() {
return false
}
cmd := hermesCommand("wsl.exe", "bash", "-lc", "command -v hermes >/dev/null 2>&1")
return cmd.Run() == nil
}
func (h *Hermes) configBackend() (*hermesConfigBackend, error) {
if hermesGOOS == "windows" {
if _, err := hermesLookPath("hermes"); err == nil {
return hermesLocalConfigBackend()
}
if h.wslAvailable() {
return h.wslConfigBackend()
}
}
return hermesLocalConfigBackend()
}
func hermesConfigPath() (string, error) {
home, err := hermesUserHome()
if err != nil {
@@ -395,110 +282,6 @@ func hermesConfigPath() (string, error) {
return filepath.Join(home, ".hermes", "config.yaml"), nil
}
func hermesLocalConfigBackend() (*hermesConfigBackend, error) {
configPath, err := hermesConfigPath()
if err != nil {
return nil, err
}
return &hermesConfigBackend{
displayPath: configPath,
read: func() ([]byte, error) {
return os.ReadFile(configPath)
},
write: func(data []byte) error {
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
return fileutil.WriteWithBackup(configPath, data)
},
}, nil
}
func (h *Hermes) wslConfigBackend() (*hermesConfigBackend, error) {
home, err := h.wslHome()
if err != nil {
return nil, err
}
configPath := pathpkg.Join(home, ".hermes", "config.yaml")
return &hermesConfigBackend{
displayPath: configPath,
read: func() ([]byte, error) {
return h.readWSLFile(configPath)
},
write: func(data []byte) error {
return h.writeWSLConfig(configPath, data)
},
}, nil
}
func (h *Hermes) wslHome() (string, error) {
if !h.wslAvailable() {
return "", fmt.Errorf("wsl.exe is not available")
}
cmd := hermesCommand("wsl.exe", "bash", "-lc", `printf %s "$HOME"`)
out, err := cmd.Output()
if err != nil {
return "", err
}
home := strings.TrimSpace(string(out))
if home == "" {
return "", fmt.Errorf("could not resolve WSL home directory")
}
return home, nil
}
func (h *Hermes) readWSLFile(path string) ([]byte, error) {
pathArg := shellQuoteArgs([]string{path})
cmd := hermesCommand("wsl.exe", "bash", "-lc", fmt.Sprintf("if [ -f %s ]; then cat %s; else exit 42; fi", pathArg, pathArg))
out, err := cmd.Output()
if err == nil {
return out, nil
}
var exitErr *exec.ExitError
if errors.As(err, &exitErr) && exitErr.ExitCode() == 42 {
return nil, os.ErrNotExist
}
return nil, err
}
func (h *Hermes) writeWSLConfig(path string, data []byte) error {
if existing, err := h.readWSLFile(path); err == nil {
if !bytes.Equal(existing, data) {
if err := hermesBackupData(path, existing); err != nil {
return fmt.Errorf("backup failed: %w", err)
}
}
} else if !os.IsNotExist(err) {
return fmt.Errorf("read existing file: %w", err)
}
dir := pathpkg.Dir(path)
dirArg := shellQuoteArgs([]string{dir})
pathArg := shellQuoteArgs([]string{path})
script := fmt.Sprintf(
"dir=%s; path=%s; mkdir -p \"$dir\" && tmp=$(mktemp \"$dir/.tmp-XXXXXX\") && cat > \"$tmp\" && mv \"$tmp\" \"$path\"",
dirArg,
pathArg,
)
cmd := hermesCommand("wsl.exe", "bash", "-lc", script)
cmd.Stdin = bytes.NewReader(data)
if out, err := cmd.CombinedOutput(); err != nil {
if msg := strings.TrimSpace(string(out)); msg != "" {
return fmt.Errorf("%w: %s", err, msg)
}
return err
}
return nil
}
func hermesBackupData(path string, data []byte) error {
if err := os.MkdirAll(fileutil.BackupDir(), 0o755); err != nil {
return err
}
backupPath := filepath.Join(fileutil.BackupDir(), fmt.Sprintf("%s.%d", filepath.Base(path), time.Now().Unix()))
return os.WriteFile(backupPath, data, 0o644)
}
func hermesBaseURL() string {
return strings.TrimRight(hermesOllamaURL().String(), "/") + "/v1"
}
@@ -554,8 +337,11 @@ func (h *Hermes) messagingConfigured() bool {
func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
envVars := make(map[string]string)
data, err := h.readGatewayEnvFile()
switch {
envFilePath, err := hermesEnvPath()
if err != nil {
return nil, err
}
switch data, err := os.ReadFile(envFilePath); {
case err == nil:
for key, value := range hermesParseEnvFile(data) {
envVars[key] = value
@@ -566,12 +352,10 @@ func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
return nil, err
}
if h.usesLocalRuntimeEnv() {
for _, group := range hermesMessagingEnvGroups {
for _, key := range group {
if value, ok := os.LookupEnv(key); ok {
envVars[key] = value
}
for _, group := range hermesMessagingEnvGroups {
for _, key := range group {
if value, ok := os.LookupEnv(key); ok {
envVars[key] = value
}
}
}
@@ -579,39 +363,6 @@ func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
return envVars, nil
}
func (h *Hermes) readGatewayEnvFile() ([]byte, error) {
if hermesGOOS == "windows" {
if _, err := hermesLookPath("hermes"); err == nil {
path, err := hermesEnvPath()
if err != nil {
return nil, err
}
return os.ReadFile(path)
}
if h.wslAvailable() {
home, err := h.wslHome()
if err != nil {
return nil, err
}
return h.readWSLFile(pathpkg.Join(home, ".hermes", ".env"))
}
}
path, err := hermesEnvPath()
if err != nil {
return nil, err
}
return os.ReadFile(path)
}
func (h *Hermes) usesLocalRuntimeEnv() bool {
if hermesGOOS != "windows" {
return true
}
_, err := hermesLookPath("hermes")
return err == nil
}
func (h *Hermes) gatewayRunning() (bool, error) {
status, err := h.gatewayStatusOutput()
if err != nil {
@@ -621,19 +372,7 @@ func (h *Hermes) gatewayRunning() (bool, error) {
}
func (h *Hermes) gatewayStatusOutput() (string, error) {
if hermesGOOS == "windows" {
if path, err := hermesLookPath("hermes"); err == nil {
out, err := hermesCommand(path, "gateway", "status").CombinedOutput()
return string(out), err
}
if !h.wslAvailable() {
return "", hermesWindowsHint(fmt.Errorf("hermes is not installed"))
}
out, err := h.runWSLCombinedOutput("hermes", "gateway", "status")
return string(out), err
}
bin, err := h.findUnixBinary()
bin, err := h.binary()
if err != nil {
return "", err
}
@@ -642,20 +381,7 @@ func (h *Hermes) gatewayStatusOutput() (string, error) {
}
func (h *Hermes) restartGateway() error {
if hermesGOOS == "windows" {
if path, err := hermesLookPath("hermes"); err == nil {
return hermesAttachedCommand(path, "gateway", "restart").Run()
}
if !h.wslAvailable() {
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
}
if err := h.runWSL("hermes", "gateway", "restart"); err != nil {
return hermesWindowsHint(err)
}
return nil
}
bin, err := h.findUnixBinary()
bin, err := h.binary()
if err != nil {
return err
}
@@ -938,14 +664,6 @@ func mergeHermesToolsets(current any) any {
}
}
func shellQuoteArgs(args []string) string {
quoted := make([]string, 0, len(args))
for _, arg := range args {
quoted = append(quoted, "'"+strings.ReplaceAll(arg, "'", `'\''`)+"'")
}
return strings.Join(quoted, " ")
}
func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
cmd := hermesCommand(name, args...)
cmd.Stdin = os.Stdin
@@ -954,9 +672,8 @@ func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
return cmd
}
func hermesWindowsHint(err error) error {
if hermesGOOS != "windows" {
return err
}
return fmt.Errorf("%w\n\nHermes runs on Windows through WSL2.\nQuick setup: wsl --install\nInstaller docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/", err)
func hermesWindowsHint() error {
return fmt.Errorf("Hermes on Windows requires WSL2. Install WSL with: wsl --install\n" +
"Then run 'ollama launch hermes' from inside your WSL shell.\n" +
"Docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/")
}

View File

@@ -896,64 +896,6 @@ fi
}
}
func TestHermesRefreshRuntimeAfterConfigure_WindowsWSLRestartsRunningGateway(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell test binaries to simulate WSL")
}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withHermesPlatform(t, "windows")
t.Setenv("PATH", tmpDir+string(os.PathListSeparator)+os.Getenv("PATH"))
wslPath := filepath.Join(tmpDir, "wsl.exe")
wslScript := `#!/bin/sh
printf '[%s]\n' "$*" >> "$HOME/wsl-invocations.log"
exec /bin/sh -lc "$3"
`
if err := os.WriteFile(wslPath, []byte(wslScript), 0o755); err != nil {
t.Fatal(err)
}
hermesBin := filepath.Join(tmpDir, "hermes")
hermesScript := `#!/bin/sh
printf '[%s]\n' "$*" >> "$HOME/hermes-invocations.log"
if [ "$1" = "gateway" ] && [ "$2" = "status" ]; then
printf '✓ Gateway is running (PID: 321)\n'
fi
`
if err := os.WriteFile(hermesBin, []byte(hermesScript), 0o755); err != nil {
t.Fatal(err)
}
withHermesLookPath(t, func(file string) (string, error) {
if file == "wsl.exe" {
return wslPath, nil
}
return "", os.ErrNotExist
})
h := &Hermes{}
if err := h.RefreshRuntimeAfterConfigure(); err != nil {
t.Fatalf("RefreshRuntimeAfterConfigure returned error: %v", err)
}
data, err := os.ReadFile(filepath.Join(tmpDir, "hermes-invocations.log"))
if err != nil {
t.Fatal(err)
}
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
if len(lines) != 2 {
t.Fatalf("expected WSL status then restart invocations, got %v", lines)
}
if lines[0] != "[gateway status]" {
t.Fatalf("expected WSL gateway status first, got %q", lines[0])
}
if lines[1] != "[gateway restart]" {
t.Fatalf("expected WSL gateway restart second, got %q", lines[1])
}
}
func TestHermesMessagingConfiguredRecognizesSupportedGatewayVars(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
@@ -1002,82 +944,7 @@ func TestHermesMessagingConfiguredRecognizesSupportedGatewayVars(t *testing.T) {
}
}
func TestHermesRunWindowsWSL_UsesGatewaySetupPreflight(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell test binaries to simulate WSL")
}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, true)
withHermesPlatform(t, "windows")
clearHermesMessagingEnvVars(t)
t.Setenv("PATH", tmpDir+string(os.PathListSeparator)+os.Getenv("PATH"))
wslPath := filepath.Join(tmpDir, "wsl.exe")
wslScript := `#!/bin/sh
printf '[%s]\n' "$*" >> "$HOME/wsl-invocations.log"
exec /bin/sh -lc "$3"
`
if err := os.WriteFile(wslPath, []byte(wslScript), 0o755); err != nil {
t.Fatal(err)
}
hermesBin := filepath.Join(tmpDir, "hermes")
hermesScript := `#!/bin/sh
printf '[%s]\n' "$*" >> "$HOME/hermes-invocations.log"
if [ "$1" = "gateway" ] && [ "$2" = "setup" ]; then
/bin/mkdir -p "$HOME/.hermes"
printf 'TELEGRAM_BOT_TOKEN=configured\n' > "$HOME/.hermes/.env"
fi
`
if err := os.WriteFile(hermesBin, []byte(hermesScript), 0o755); err != nil {
t.Fatal(err)
}
withHermesLookPath(t, func(file string) (string, error) {
if file == "wsl.exe" {
return wslPath, nil
}
return "", os.ErrNotExist
})
promptCount := 0
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
promptCount++
if prompt != hermesGatewaySetupTitle {
t.Fatalf("unexpected prompt %q", prompt)
}
return true, nil
}
h := &Hermes{}
if err := h.Run("", nil); err != nil {
t.Fatalf("Run returned error: %v", err)
}
if promptCount != 1 {
t.Fatalf("expected one messaging prompt, got %d", promptCount)
}
data, err := os.ReadFile(filepath.Join(tmpDir, "hermes-invocations.log"))
if err != nil {
t.Fatal(err)
}
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
if len(lines) != 2 {
t.Fatalf("expected WSL hermes to run setup then launch, got %v", lines)
}
if lines[0] != "[gateway setup]" {
t.Fatalf("expected WSL gateway setup first, got %q", lines[0])
}
if lines[1] != "[]" {
t.Fatalf("expected WSL default hermes launch second, got %q", lines[1])
}
}
func TestHermesEnsureInstalledWindowsWithoutWSLGivesGuidance(t *testing.T) {
func TestHermesEnsureInstalledWindowsShowsWSLGuidance(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withHermesPlatform(t, "windows")
@@ -1086,10 +953,17 @@ func TestHermesEnsureInstalledWindowsWithoutWSLGivesGuidance(t *testing.T) {
h := &Hermes{}
err := h.ensureInstalled()
if err == nil {
t.Fatal("expected missing WSL guidance error")
t.Fatal("expected WSL guidance error")
}
if !strings.Contains(err.Error(), "wsl --install") {
t.Fatalf("expected WSL guidance, got %v", err)
msg := err.Error()
if !strings.Contains(msg, "wsl --install") {
t.Fatalf("expected install command in guidance, got %v", err)
}
if !strings.Contains(msg, "hermes-agent.nousresearch.com") {
t.Fatalf("expected docs link in guidance, got %v", err)
}
if strings.Contains(msg, "hermes is not installed") {
t.Fatalf("guidance should not lead with 'hermes is not installed', got %v", err)
}
}

View File

@@ -54,6 +54,7 @@ func TestIntegrationLookup(t *testing.T) {
{"claude uppercase", "CLAUDE", true, "Claude Code"},
{"claude mixed case", "Claude", true, "Claude Code"},
{"codex", "codex", true, "Codex"},
{"kimi", "kimi", true, "Kimi Code CLI"},
{"droid", "droid", true, "Droid"},
{"opencode", "opencode", true, "OpenCode"},
{"unknown integration", "unknown", false, ""},
@@ -74,7 +75,7 @@ func TestIntegrationLookup(t *testing.T) {
}
func TestIntegrationRegistry(t *testing.T) {
expectedIntegrations := []string{"claude", "codex", "droid", "opencode", "hermes"}
expectedIntegrations := []string{"claude", "codex", "kimi", "droid", "opencode", "hermes"}
for _, name := range expectedIntegrations {
t.Run(name, func(t *testing.T) {
@@ -89,6 +90,15 @@ func TestIntegrationRegistry(t *testing.T) {
}
}
func TestHiddenIntegrationsExcludedFromVisibleLists(t *testing.T) {
for _, info := range ListIntegrationInfos() {
switch info.Name {
case "cline", "vscode", "kimi":
t.Fatalf("hidden integration %q should not appear in ListIntegrationInfos", info.Name)
}
}
}
func TestHasLocalModel(t *testing.T) {
tests := []struct {
name string

315
cmd/launch/kimi.go Normal file
View File

@@ -0,0 +1,315 @@
package launch
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// Kimi implements Runner for Kimi Code CLI integration.
type Kimi struct{}
const (
kimiDefaultModelAlias = "ollama"
kimiDefaultMaxContextSize = 32768
)
var (
kimiGOOS = runtime.GOOS
kimiModelShowTimeout = 5 * time.Second
)
func (k *Kimi) String() string { return "Kimi Code CLI" }
func (k *Kimi) args(config string, extra []string) []string {
args := []string{"--config", config}
args = append(args, extra...)
return args
}
func (k *Kimi) Run(model string, args []string) error {
if strings.TrimSpace(model) == "" {
return fmt.Errorf("model is required")
}
if err := validateKimiPassthroughArgs(args); err != nil {
return err
}
config, err := buildKimiInlineConfig(model, resolveKimiMaxContextSize(model))
if err != nil {
return fmt.Errorf("failed to build kimi config: %w", err)
}
bin, err := ensureKimiInstalled()
if err != nil {
return err
}
cmd := exec.Command(bin, k.args(config, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func findKimiBinary() (string, error) {
if path, err := exec.LookPath("kimi"); err == nil {
return path, nil
}
home, _ := os.UserHomeDir()
var candidates []string
switch kimiGOOS {
case "windows":
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, ".local", "bin"))
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, "bin"))
if appData := strings.TrimSpace(os.Getenv("APPDATA")); appData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
}
if localAppData := strings.TrimSpace(os.Getenv("LOCALAPPDATA")); localAppData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
}
default:
candidates = append(candidates,
filepath.Join(home, ".local", "bin", "kimi"),
filepath.Join(home, "bin", "kimi"),
filepath.Join(home, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi"),
filepath.Join(home, ".local", "share", "uv", "tools", "kimi", "bin", "kimi"),
)
if xdgDataHome := strings.TrimSpace(os.Getenv("XDG_DATA_HOME")); xdgDataHome != "" {
candidates = append(candidates,
filepath.Join(xdgDataHome, "uv", "tools", "kimi-cli", "bin", "kimi"),
filepath.Join(xdgDataHome, "uv", "tools", "kimi", "bin", "kimi"),
)
}
// WSL users can inherit Windows env vars while launching from Linux shells.
if profile := windowsPathToWSL(os.Getenv("USERPROFILE")); profile != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(profile, ".local", "bin"))
}
if appData := windowsPathToWSL(os.Getenv("APPDATA")); appData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
}
if localAppData := windowsPathToWSL(os.Getenv("LOCALAPPDATA")); localAppData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
}
}
for _, candidate := range candidates {
if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
return candidate, nil
}
}
return "", fmt.Errorf("kimi binary not found")
}
func appendWindowsKimiCandidates(candidates []string, dir string) []string {
if strings.TrimSpace(dir) == "" {
return candidates
}
return append(candidates,
filepath.Join(dir, "kimi.exe"),
filepath.Join(dir, "kimi.cmd"),
filepath.Join(dir, "kimi.bat"),
)
}
func windowsPathToWSL(path string) string {
trimmed := strings.TrimSpace(path)
if len(trimmed) < 3 || trimmed[1] != ':' {
return ""
}
drive := strings.ToLower(string(trimmed[0]))
rest := strings.ReplaceAll(trimmed[2:], "\\", "/")
rest = strings.TrimPrefix(rest, "/")
if rest == "" {
return filepath.Join("/mnt", drive)
}
return filepath.Join("/mnt", drive, rest)
}
func validateKimiPassthroughArgs(args []string) error {
for _, arg := range args {
switch {
case arg == "--config", strings.HasPrefix(arg, "--config="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config", arg)
case arg == "--config-file", strings.HasPrefix(arg, "--config-file="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config-file", arg)
case arg == "--model", strings.HasPrefix(arg, "--model="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --model", arg)
case arg == "-m", strings.HasPrefix(arg, "-m="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages -m/--model", arg)
}
}
return nil
}
func buildKimiInlineConfig(model string, maxContextSize int) (string, error) {
cfg := map[string]any{
"default_model": kimiDefaultModelAlias,
"providers": map[string]any{
kimiDefaultModelAlias: map[string]any{
"type": "openai_legacy",
"base_url": envconfig.ConnectableHost().String() + "/v1",
"api_key": "ollama",
},
},
"models": map[string]any{
kimiDefaultModelAlias: map[string]any{
"provider": kimiDefaultModelAlias,
"model": model,
"max_context_size": maxContextSize,
},
},
}
data, err := json.Marshal(cfg)
if err != nil {
return "", err
}
return string(data), nil
}
func resolveKimiMaxContextSize(model string) int {
if l, ok := lookupCloudModelLimit(model); ok {
return l.Context
}
client, err := api.ClientFromEnvironment()
if err != nil {
return kimiDefaultMaxContextSize
}
ctx, cancel := context.WithTimeout(context.Background(), kimiModelShowTimeout)
defer cancel()
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
if err != nil {
return kimiDefaultMaxContextSize
}
if n, ok := modelInfoContextLength(resp.ModelInfo); ok {
return n
}
return kimiDefaultMaxContextSize
}
func modelInfoContextLength(modelInfo map[string]any) (int, bool) {
for key, val := range modelInfo {
if !strings.HasSuffix(key, ".context_length") {
continue
}
switch v := val.(type) {
case float64:
if v > 0 {
return int(v), true
}
case int:
if v > 0 {
return v, true
}
case int64:
if v > 0 {
return int(v), true
}
}
}
return 0, false
}
func ensureKimiInstalled() (string, error) {
if path, err := findKimiBinary(); err == nil {
return path, nil
}
if err := checkKimiInstallerDependencies(); err != nil {
return "", err
}
ok, err := ConfirmPrompt("Kimi is not installed. Install now?")
if err != nil {
return "", err
}
if !ok {
return "", fmt.Errorf("kimi installation cancelled")
}
bin, args, err := kimiInstallerCommand(kimiGOOS)
if err != nil {
return "", err
}
fmt.Fprintf(os.Stderr, "\nInstalling Kimi...\n")
cmd := exec.Command(bin, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("failed to install kimi: %w", err)
}
path, err := findKimiBinary()
if err != nil {
return "", fmt.Errorf("kimi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sKimi installed successfully%s\n\n", ansiGreen, ansiReset)
return path, nil
}
func checkKimiInstallerDependencies() error {
switch kimiGOOS {
case "windows":
if _, err := exec.LookPath("powershell"); err != nil {
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n PowerShell: https://learn.microsoft.com/powershell/\n\nThen re-run:\n ollama launch kimi")
}
default:
var missing []string
if _, err := exec.LookPath("curl"); err != nil {
missing = append(missing, "curl: https://curl.se/")
}
if _, err := exec.LookPath("bash"); err != nil {
missing = append(missing, "bash: https://www.gnu.org/software/bash/")
}
if len(missing) > 0 {
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch kimi", strings.Join(missing, "\n "))
}
}
return nil
}
func kimiInstallerCommand(goos string) (string, []string, error) {
switch goos {
case "windows":
return "powershell", []string{
"-NoProfile",
"-ExecutionPolicy",
"Bypass",
"-Command",
"Invoke-RestMethod https://code.kimi.com/install.ps1 | Invoke-Expression",
}, nil
case "darwin", "linux":
return "bash", []string{
"-c",
"curl -LsSf https://code.kimi.com/install.sh | bash",
}, nil
default:
return "", nil, fmt.Errorf("unsupported platform for kimi install: %s", goos)
}
}

636
cmd/launch/kimi_test.go Normal file
View File

@@ -0,0 +1,636 @@
package launch
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
func assertKimiBinPath(t *testing.T, bin string) {
t.Helper()
base := strings.ToLower(filepath.Base(bin))
if !strings.HasPrefix(base, "kimi") {
t.Fatalf("bin = %q, want path to kimi executable", bin)
}
}
func TestKimiIntegration(t *testing.T) {
k := &Kimi{}
t.Run("String", func(t *testing.T) {
if got := k.String(); got != "Kimi Code CLI" {
t.Errorf("String() = %q, want %q", got, "Kimi Code CLI")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = k
})
}
func TestKimiArgs(t *testing.T) {
k := &Kimi{}
got := k.args(`{"foo":"bar"}`, []string{"--quiet", "--print"})
want := []string{"--config", `{"foo":"bar"}`, "--quiet", "--print"}
if !slices.Equal(got, want) {
t.Fatalf("args() = %v, want %v", got, want)
}
}
func TestWindowsPathToWSL(t *testing.T) {
tests := []struct {
name string
in string
want string
valid bool
}{
{
name: "user profile path",
in: `C:\Users\parth`,
want: filepath.Join("/mnt", "c", "Users", "parth"),
valid: true,
},
{
name: "path with trailing slash",
in: `D:\tools\bin\`,
want: filepath.Join("/mnt", "d", "tools", "bin"),
valid: true,
},
{
name: "non windows path",
in: "/home/parth",
valid: false,
},
{
name: "empty",
in: "",
valid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := windowsPathToWSL(tt.in)
if !tt.valid {
if got != "" {
t.Fatalf("windowsPathToWSL(%q) = %q, want empty", tt.in, got)
}
return
}
if got != tt.want {
t.Fatalf("windowsPathToWSL(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
}
func TestFindKimiBinaryFallbacks(t *testing.T) {
oldGOOS := kimiGOOS
t.Cleanup(func() { kimiGOOS = oldGOOS })
t.Run("linux/ubuntu uv tool path", func(t *testing.T) {
homeDir := t.TempDir()
setTestHome(t, homeDir)
t.Setenv("PATH", t.TempDir())
kimiGOOS = "linux"
target := filepath.Join(homeDir, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi")
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
t.Fatalf("failed to create candidate dir: %v", err)
}
if err := os.WriteFile(target, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatalf("failed to write kimi candidate: %v", err)
}
got, err := findKimiBinary()
if err != nil {
t.Fatalf("findKimiBinary() error = %v", err)
}
if got != target {
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
}
})
t.Run("windows appdata uv bin", func(t *testing.T) {
setTestHome(t, t.TempDir())
t.Setenv("PATH", t.TempDir())
kimiGOOS = "windows"
appDataDir := t.TempDir()
t.Setenv("APPDATA", appDataDir)
t.Setenv("LOCALAPPDATA", "")
target := filepath.Join(appDataDir, "uv", "bin", "kimi.cmd")
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
t.Fatalf("failed to create candidate dir: %v", err)
}
if err := os.WriteFile(target, []byte("@echo off\r\nexit /b 0\r\n"), 0o755); err != nil {
t.Fatalf("failed to write kimi candidate: %v", err)
}
got, err := findKimiBinary()
if err != nil {
t.Fatalf("findKimiBinary() error = %v", err)
}
if got != target {
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
}
})
}
func TestValidateKimiPassthroughArgs_RejectsConflicts(t *testing.T) {
tests := []struct {
name string
args []string
want string
}{
{name: "--config", args: []string{"--config", "{}"}, want: "--config"},
{name: "--config=", args: []string{"--config={}"}, want: "--config={"},
{name: "--config-file", args: []string{"--config-file", "x.toml"}, want: "--config-file"},
{name: "--config-file=", args: []string{"--config-file=x.toml"}, want: "--config-file=x.toml"},
{name: "--model", args: []string{"--model", "foo"}, want: "--model"},
{name: "--model=", args: []string{"--model=foo"}, want: "--model=foo"},
{name: "-m", args: []string{"-m", "foo"}, want: "-m"},
{name: "-m=", args: []string{"-m=foo"}, want: "-m=foo"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateKimiPassthroughArgs(tt.args)
if err == nil {
t.Fatalf("expected error for args %v", tt.args)
}
if !strings.Contains(err.Error(), tt.want) {
t.Fatalf("error %q does not contain %q", err.Error(), tt.want)
}
})
}
}
func TestBuildKimiInlineConfig(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434")
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
if err != nil {
t.Fatalf("buildKimiInlineConfig() error = %v", err)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
t.Fatalf("config is not valid JSON: %v", err)
}
if parsed["default_model"] != "ollama" {
t.Fatalf("default_model = %v, want ollama", parsed["default_model"])
}
providers, ok := parsed["providers"].(map[string]any)
if !ok {
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
}
ollamaProvider, ok := providers["ollama"].(map[string]any)
if !ok {
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
}
if ollamaProvider["type"] != "openai_legacy" {
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
}
if ollamaProvider["base_url"] != "http://127.0.0.1:11434/v1" {
t.Fatalf("provider base_url = %v, want http://127.0.0.1:11434/v1", ollamaProvider["base_url"])
}
if ollamaProvider["api_key"] != "ollama" {
t.Fatalf("provider api_key = %v, want ollama", ollamaProvider["api_key"])
}
models, ok := parsed["models"].(map[string]any)
if !ok {
t.Fatalf("models missing or wrong type: %T", parsed["models"])
}
ollamaModel, ok := models["ollama"].(map[string]any)
if !ok {
t.Fatalf("models.ollama missing or wrong type: %T", models["ollama"])
}
if ollamaModel["provider"] != "ollama" {
t.Fatalf("model provider = %v, want ollama", ollamaModel["provider"])
}
if ollamaModel["model"] != "llama3.2" {
t.Fatalf("model model = %v, want llama3.2", ollamaModel["model"])
}
if ollamaModel["max_context_size"] != float64(65536) {
t.Fatalf("model max_context_size = %v, want 65536", ollamaModel["max_context_size"])
}
}
func TestBuildKimiInlineConfig_UsesConnectableHostForUnspecifiedBind(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://0.0.0.0:11434")
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
if err != nil {
t.Fatalf("buildKimiInlineConfig() error = %v", err)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
t.Fatalf("config is not valid JSON: %v", err)
}
providers, ok := parsed["providers"].(map[string]any)
if !ok {
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
}
ollamaProvider, ok := providers["ollama"].(map[string]any)
if !ok {
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
}
if got, _ := ollamaProvider["base_url"].(string); got != "http://127.0.0.1:11434/v1" {
t.Fatalf("provider base_url = %q, want %q", got, "http://127.0.0.1:11434/v1")
}
}
func TestResolveKimiMaxContextSize(t *testing.T) {
t.Run("uses cloud limit when known", func(t *testing.T) {
got := resolveKimiMaxContextSize("kimi-k2.5:cloud")
if got != 262_144 {
t.Fatalf("resolveKimiMaxContextSize() = %d, want 262144", got)
}
})
t.Run("uses model show context length for local models", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/show" {
http.NotFound(w, r)
return
}
fmt.Fprint(w, `{"model_info":{"llama.context_length":131072}}`)
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
got := resolveKimiMaxContextSize("llama3.2")
if got != 131_072 {
t.Fatalf("resolveKimiMaxContextSize() = %d, want 131072", got)
}
})
t.Run("falls back to default when show fails", func(t *testing.T) {
srv := httptest.NewServer(http.NotFoundHandler())
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
oldTimeout := kimiModelShowTimeout
kimiModelShowTimeout = 100 * 1000 * 1000 // 100ms
t.Cleanup(func() { kimiModelShowTimeout = oldTimeout })
got := resolveKimiMaxContextSize("llama3.2")
if got != kimiDefaultMaxContextSize {
t.Fatalf("resolveKimiMaxContextSize() = %d, want %d", got, kimiDefaultMaxContextSize)
}
})
}
func TestKimiRun_RejectsConflictingArgsBeforeInstall(t *testing.T) {
k := &Kimi{}
oldConfirm := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
t.Fatalf("did not expect install prompt, got %q", prompt)
return false, nil
}
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
err := k.Run("llama3.2", []string{"--model", "other"})
if err == nil || !strings.Contains(err.Error(), "--model") {
t.Fatalf("expected conflict error mentioning --model, got %v", err)
}
}
func TestKimiRun_PassesInlineConfigAndExtraArgs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binary")
}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
logPath := filepath.Join(tmpDir, "kimi-args.log")
script := fmt.Sprintf(`#!/bin/sh
for arg in "$@"; do
printf "%%s\n" "$arg" >> %q
done
exit 0
`, logPath)
if err := os.WriteFile(filepath.Join(tmpDir, "kimi"), []byte(script), 0o755); err != nil {
t.Fatalf("failed to write fake kimi: %v", err)
}
t.Setenv("PATH", tmpDir)
srv := httptest.NewServer(http.NotFoundHandler())
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
k := &Kimi{}
if err := k.Run("llama3.2", []string{"--quiet", "--print"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
data, err := os.ReadFile(logPath)
if err != nil {
t.Fatalf("failed to read args log: %v", err)
}
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
if len(lines) < 4 {
t.Fatalf("expected at least 4 args, got %v", lines)
}
if lines[0] != "--config" {
t.Fatalf("first arg = %q, want --config", lines[0])
}
var cfg map[string]any
if err := json.Unmarshal([]byte(lines[1]), &cfg); err != nil {
t.Fatalf("config arg is not valid JSON: %v", err)
}
providers := cfg["providers"].(map[string]any)
ollamaProvider := providers["ollama"].(map[string]any)
if ollamaProvider["type"] != "openai_legacy" {
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
}
if lines[2] != "--quiet" || lines[3] != "--print" {
t.Fatalf("extra args = %v, want [--quiet --print]", lines[2:])
}
}
func TestEnsureKimiInstalled(t *testing.T) {
oldGOOS := kimiGOOS
t.Cleanup(func() { kimiGOOS = oldGOOS })
withConfirm := func(t *testing.T, fn func(prompt string) (bool, error)) {
t.Helper()
oldConfirm := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return fn(prompt)
}
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
}
t.Run("already installed", func(t *testing.T) {
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
writeFakeBinary(t, tmpDir, "kimi")
kimiGOOS = runtime.GOOS
withConfirm(t, func(prompt string) (bool, error) {
t.Fatalf("did not expect prompt, got %q", prompt)
return false, nil
})
bin, err := ensureKimiInstalled()
if err != nil {
t.Fatalf("ensureKimiInstalled() error = %v", err)
}
assertKimiBinPath(t, bin)
})
t.Run("missing dependencies", func(t *testing.T) {
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
withConfirm(t, func(prompt string) (bool, error) {
t.Fatalf("did not expect prompt, got %q", prompt)
return false, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "required dependencies are missing") {
t.Fatalf("expected missing dependency error, got %v", err)
}
})
t.Run("missing and user declines install", func(t *testing.T) {
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
writeFakeBinary(t, tmpDir, "curl")
writeFakeBinary(t, tmpDir, "bash")
kimiGOOS = "linux"
withConfirm(t, func(prompt string) (bool, error) {
if !strings.Contains(prompt, "Kimi is not installed.") {
t.Fatalf("unexpected prompt: %q", prompt)
}
return false, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "installation cancelled") {
t.Fatalf("expected cancellation error, got %v", err)
}
})
t.Run("missing and user confirms install succeeds", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
writeFakeBinary(t, tmpDir, "curl")
installLog := filepath.Join(tmpDir, "bash.log")
kimiPath := filepath.Join(tmpDir, "kimi")
bashScript := fmt.Sprintf(`#!/bin/sh
echo "$@" >> %q
if [ "$1" = "-c" ]; then
/bin/cat > %q <<'EOS'
#!/bin/sh
exit 0
EOS
/bin/chmod +x %q
fi
exit 0
`, installLog, kimiPath, kimiPath)
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte(bashScript), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
bin, err := ensureKimiInstalled()
if err != nil {
t.Fatalf("ensureKimiInstalled() error = %v", err)
}
assertKimiBinPath(t, bin)
logData, err := os.ReadFile(installLog)
if err != nil {
t.Fatalf("failed to read install log: %v", err)
}
if !strings.Contains(string(logData), "https://code.kimi.com/install.sh") {
t.Fatalf("expected install.sh command in log, got:\n%s", string(logData))
}
})
t.Run("install succeeds and kimi is in home local bin without PATH update", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
homeDir := t.TempDir()
setTestHome(t, homeDir)
tmpBin := t.TempDir()
t.Setenv("PATH", tmpBin)
kimiGOOS = "linux"
writeFakeBinary(t, tmpBin, "curl")
installedKimi := filepath.Join(homeDir, ".local", "bin", "kimi")
bashScript := fmt.Sprintf(`#!/bin/sh
if [ "$1" = "-c" ]; then
/bin/mkdir -p %q
/bin/cat > %q <<'EOS'
#!/bin/sh
exit 0
EOS
/bin/chmod +x %q
fi
exit 0
`, filepath.Dir(installedKimi), installedKimi, installedKimi)
if err := os.WriteFile(filepath.Join(tmpBin, "bash"), []byte(bashScript), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
bin, err := ensureKimiInstalled()
if err != nil {
t.Fatalf("ensureKimiInstalled() error = %v", err)
}
if bin != installedKimi {
t.Fatalf("bin = %q, want %q", bin, installedKimi)
}
})
t.Run("install command fails", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
writeFakeBinary(t, tmpDir, "curl")
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "failed to install kimi") {
t.Fatalf("expected install failure error, got %v", err)
}
})
t.Run("install succeeds but binary missing on PATH", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
writeFakeBinary(t, tmpDir, "curl")
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "binary was not found on PATH") {
t.Fatalf("expected PATH guidance error, got %v", err)
}
})
}
func TestKimiInstallerCommand(t *testing.T) {
tests := []struct {
name string
goos string
wantBin string
wantParts []string
wantErr bool
}{
{
name: "linux",
goos: "linux",
wantBin: "bash",
wantParts: []string{"-c", "install.sh"},
},
{
name: "darwin",
goos: "darwin",
wantBin: "bash",
wantParts: []string{"-c", "install.sh"},
},
{
name: "windows",
goos: "windows",
wantBin: "powershell",
wantParts: []string{"-Command", "install.ps1"},
},
{
name: "unsupported",
goos: "freebsd",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bin, args, err := kimiInstallerCommand(tt.goos)
if tt.wantErr {
if err == nil {
t.Fatal("expected error")
}
return
}
if err != nil {
t.Fatalf("kimiInstallerCommand() error = %v", err)
}
if bin != tt.wantBin {
t.Fatalf("bin = %q, want %q", bin, tt.wantBin)
}
joined := strings.Join(args, " ")
for _, part := range tt.wantParts {
if !strings.Contains(joined, part) {
t.Fatalf("args %q missing %q", joined, part)
}
}
})
}
}

View File

@@ -206,8 +206,10 @@ Supported integrations:
claude Claude Code
cline Cline
codex Codex
copilot Copilot CLI (aliases: copilot-cli)
droid Droid
hermes Hermes Agent
kimi Kimi Code CLI
opencode OpenCode
openclaw OpenClaw (aliases: clawdbot, moltbot)
pi Pi
@@ -586,7 +588,7 @@ func (c *launcherClient) launchManagedSingleIntegration(ctx context.Context, nam
return nil
}
if current == "" || needsConfigure || req.ModelOverride != "" || target != current {
if (current == "" || needsConfigure || req.ModelOverride != "" || target != current) && !savedMatchesModels(saved, []string{target}) {
if err := prepareManagedSingleIntegration(name, runner, managed, target); err != nil {
return err
}

View File

@@ -409,7 +409,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigOnlySkipsFinalRun(t *te
}
}
func TestLaunchIntegration_ManagedSingleIntegrationRepairsMissingLiveConfigUsingSavedModel(t *testing.T) {
func TestLaunchIntegration_ManagedSingleIntegrationSkipsRewriteWhenSavedMatches(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
@@ -436,29 +436,30 @@ func TestLaunchIntegration_ManagedSingleIntegrationRepairsMissingLiveConfigUsing
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
t.Fatal("selector should not be called when saved model is reused for repair")
t.Fatal("selector should not be called when saved model matches target")
return "", nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return true, nil
t.Fatal("confirm prompt should not run when saved model matches target")
return false, nil
}
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
t.Fatalf("expected missing live config to be rewritten from saved model: %s", diff)
if len(runner.configured) != 0 {
t.Fatalf("expected Configure to be skipped when saved matches, got %v", runner.configured)
}
if runner.refreshCalls != 1 {
t.Fatalf("expected repaired config to refresh runtime once, got %d", runner.refreshCalls)
if runner.refreshCalls != 0 {
t.Fatalf("expected no runtime refresh when config is unchanged, got %d", runner.refreshCalls)
}
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to use repaired saved model, got %q", runner.ranModel)
t.Fatalf("expected launch to run saved model, got %q", runner.ranModel)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLiveConfig(t *testing.T) {
func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenSavedDiffers(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
@@ -466,6 +467,8 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLi
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
@@ -475,7 +478,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLi
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
if err := config.SaveIntegration("stubmanaged", []string{"old-model"}); err != nil {
t.Fatalf("failed to save managed integration config: %v", err)
}
@@ -483,7 +486,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLi
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
t.Fatal("selector should not be called when saved model is reused for repair")
t.Fatal("selector should not be called when model override is provided")
return "", nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
@@ -492,19 +495,19 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLi
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "stubmanaged",
ConfigureOnly: true,
ModelOverride: "gemma4",
}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
t.Fatalf("expected configure-only flow to rewrite missing live config: %s", diff)
t.Fatalf("expected Configure to run when saved differs from target: %s", diff)
}
if runner.refreshCalls != 1 {
t.Fatalf("expected configure-only repair to refresh runtime once, got %d", runner.refreshCalls)
t.Fatalf("expected runtime refresh once after configure, got %d", runner.refreshCalls)
}
if runner.ranModel != "" {
t.Fatalf("expected configure-only flow to skip final launch, got %q", runner.ranModel)
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
}
}

View File

@@ -33,7 +33,7 @@ type IntegrationInfo struct {
Description string
}
var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "droid", "pi"}
var launcherIntegrationOrder = []string{"openclaw", "claude", "opencode", "hermes", "codex", "copilot", "droid", "pi"}
var integrationSpecs = []*IntegrationSpec{
{
@@ -74,6 +74,36 @@ var integrationSpecs = []*IntegrationSpec{
Command: []string{"npm", "install", "-g", "@openai/codex"},
},
},
{
Name: "kimi",
Runner: &Kimi{},
Description: "Moonshot's coding agent for terminal and IDEs",
Hidden: true,
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("kimi")
return err == nil
},
EnsureInstalled: func() error {
_, err := ensureKimiInstalled()
return err
},
URL: "https://moonshotai.github.io/kimi-cli/en/guides/getting-started.html",
},
},
{
Name: "copilot",
Runner: &Copilot{},
Aliases: []string{"copilot-cli"},
Description: "GitHub's AI coding agent for the terminal",
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := (&Copilot{}).findPath()
return err == nil
},
URL: "https://github.com/features/copilot/cli/",
},
},
{
Name: "droid",
Runner: &Droid{},

View File

@@ -45,6 +45,14 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
return filepath.Join(home, ".pi", "agent", "models.json")
},
},
{
name: "kimi",
binary: "kimi",
runner: &Kimi{},
checkPath: func(home string) string {
return filepath.Join(home, ".kimi", "config.toml")
},
},
}
for _, tt := range tests {
@@ -57,6 +65,10 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
if tt.name == "pi" {
writeFakeBinary(t, binDir, "npm")
}
if tt.name == "kimi" {
writeFakeBinary(t, binDir, "curl")
writeFakeBinary(t, binDir, "bash")
}
t.Setenv("PATH", binDir)
configPath := tt.checkPath(home)

View File

@@ -120,6 +120,7 @@
"pages": [
"/integrations/claude-code",
"/integrations/codex",
"/integrations/copilot-cli",
"/integrations/opencode",
"/integrations/droid",
"/integrations/goose",

BIN
docs/images/hermes.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

View File

@@ -0,0 +1,93 @@
---
title: Copilot CLI
---
GitHub Copilot CLI is GitHub's AI coding agent for the terminal. It can understand your codebase, make edits, run commands, and help you build software faster.
Open models can be used with Copilot CLI through Ollama, enabling you to use models such as `qwen3.5`, `glm-5.1:cloud`, `kimi-k2.5:cloud`.
## Install
Install [Copilot CLI](https://github.com/features/copilot/cli/):
<CodeGroup>
```shell macOS / Linux (Homebrew)
brew install copilot-cli
```
```shell npm (all platforms)
npm install -g @github/copilot
```
```shell macOS / Linux (script)
curl -fsSL https://gh.io/copilot-install | bash
```
```powershell Windows (WinGet)
winget install GitHub.Copilot
```
</CodeGroup>
## Usage with Ollama
### Quick setup
```shell
ollama launch copilot
```
### Run directly with a model
```shell
ollama launch copilot --model kimi-k2.5:cloud
```
## Recommended Models
- `kimi-k2.5:cloud`
- `glm-5:cloud`
- `minimax-m2.7:cloud`
- `qwen3.5:cloud`
- `glm-4.7-flash`
- `qwen3.5`
Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
## Non-interactive (headless) mode
Run Copilot CLI without interaction for use in Docker, CI/CD, or scripts:
```shell
ollama launch copilot --model kimi-k2.5:cloud --yes -- -p "how does this repository work?"
```
The `--yes` flag auto-pulls the model, skips selectors, and requires `--model` to be specified. Arguments after `--` are passed directly to Copilot CLI.
## Manual setup
Copilot CLI connects to Ollama using the OpenAI-compatible API via environment variables.
1. Set the environment variables:
```shell
export COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1
export COPILOT_PROVIDER_API_KEY=
export COPILOT_PROVIDER_WIRE_API=responses
export COPILOT_MODEL=qwen3.5
```
1. Run Copilot CLI:
```shell
copilot
```
Or run with environment variables inline:
```shell
COPILOT_PROVIDER_BASE_URL=http://localhost:11434/v1 COPILOT_PROVIDER_API_KEY= COPILOT_PROVIDER_WIRE_API=responses COPILOT_MODEL=glm-5:cloud copilot
```
**Note:** Copilot requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.

View File

@@ -2,7 +2,9 @@
title: Hermes Agent
---
Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation, cross-session memory, and connects messaging platforms (Telegram, Discord, Slack, WhatsApp, Signal, Email) to models through a unified gateway.
Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation, cross-session memory, and 70+ skills that it ships with by default.
![Hermes Agent with Ollama](/images/hermes.png)
## Quick start
@@ -10,25 +12,56 @@ Hermes Agent is a self-improving AI agent built by Nous Research. It features au
ollama launch hermes
```
### Pull a model
Ollama handles everything automatically:
Before running the setup wizard, make sure you have a model available. Hermes will auto-detect models downloaded through Ollama.
1. **Install** — If Hermes isn't installed, Ollama prompts to install it via the Nous Research install script
2. **Model** — Pick a model from the selector (local or cloud)
3. **Onboarding** — Ollama configures the Ollama provider, points Hermes at `http://127.0.0.1:11434/v1`, and sets your model as the primary
4. **Gateway** — Optionally connects a messaging platform (Telegram, Discord, Slack, WhatsApp, Signal, Email) and launches the Hermes chat
<Note>Hermes on Windows requires WSL2. Install it with `wsl --install` and re-run from inside the WSL shell.</Note>
## Recommended models
**Cloud models**:
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
- `glm-5.1:cloud` — Reasoning and code generation
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
**Local models:**
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
- `qwen3.6` — Reasoning, coding, and visual understanding locally (~24 GB VRAM)
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
## Connect messaging apps
Link Telegram, Discord, Slack, WhatsApp, Signal, or Email to chat with your models from anywhere:
```bash
ollama pull kimi-k2.5:cloud
hermes gateway setup
```
See [Recommended models](#recommended-models) for more options.
## Reconfigure
### Install
Re-run the full setup wizard at any time:
```bash
hermes setup
```
## Manual setup
If you'd rather drive Hermes's own wizard instead of `ollama launch hermes`, install it directly:
```bash
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
```
### Set up
After installation, Hermes launches the setup wizard automatically. Choose **Quick setup**:
Hermes launches the setup wizard automatically. Choose **Quick setup**:
```
How would you like to set up Hermes?
@@ -84,32 +117,3 @@ Connect a messaging platform? (Telegram, Discord, etc.)
Launch hermes chat now? [Y/n]: Y
```
## Recommended models
**Cloud models**:
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
- `glm-5.1:cloud` — Reasoning and code generation
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
**Local models:**
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
- `qwen3.5` — Reasoning, coding, and visual understanding locally (~11 GB VRAM)
More models at [ollama.com/search](https://ollama.com/models).
## Configure later
Re-run the setup wizard at any time:
```bash
hermes setup
```
To configure just messaging:
```bash
hermes setup gateway
```

View File

@@ -10,6 +10,7 @@ Coding assistants that can read, modify, and execute code in your projects.
- [Claude Code](/integrations/claude-code)
- [Codex](/integrations/codex)
- [Copilot CLI](/integrations/copilot-cli)
- [OpenCode](/integrations/opencode)
- [Droid](/integrations/droid)
- [Goose](/integrations/goose)

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"os"
"time"
)
type Layer struct {
@@ -60,6 +61,9 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
return Layer{}, err
}
}
if err := touchLayer(blob); err != nil {
return Layer{}, err
}
return Layer{
MediaType: mediatype,
@@ -83,6 +87,9 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
if err != nil {
return Layer{}, err
}
if err := touchLayer(blob); err != nil {
return Layer{}, err
}
return Layer{
MediaType: mediatype,
@@ -93,6 +100,11 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
}, nil
}
func touchLayer(path string) error {
now := time.Now()
return os.Chtimes(path, now, now)
}
func (l *Layer) Open() (io.ReadSeekCloser, error) {
if l.Digest == "" {
return nil, errors.New("opening layer with empty digest")

View File

@@ -12,7 +12,8 @@ import (
// <|turn>/<turn|> markers, <|"|> string delimiters, and <|tool>/
// <|tool_call>/<|tool_response> tags for function calling.
type Gemma4Renderer struct {
useImgTags bool
useImgTags bool
emptyBlockOnNothink bool
}
const (
@@ -124,6 +125,9 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
// Generation prompt.
if prevMessageType != "tool_response" && prevMessageType != "tool_call" {
sb.WriteString("<|turn>model\n")
if r.emptyBlockOnNothink && !hasThink {
sb.WriteString("<|channel>thought\n<channel|>")
}
}
return sb.String(), nil

View File

@@ -3,9 +3,9 @@ package renderers
// TestGemma4RendererMatchesReference verifies our renderer matches the checked-in
// Gemma 4 reference template.
//
// Current upstream Gemma 4 chat templates differ by model size, so the checked-in
// reference intentionally uses the shared baseline without an empty generation-time
// thought channel until renderer selection is split by size.
// Current upstream Gemma 4 chat templates differ by model size. The checked-in
// reference cases below use the small (e2b/e4b-style) baseline, with large
// (26b/31b-style) checks covered separately in this file.
//
// To regenerate expected values, save the E2B template to
// gemma4_e2b_chat_template.jinja2 and run:
@@ -1474,6 +1474,47 @@ Hi<turn|>
}
}
func TestGemma4RendererVariantsMatchExpectedGenerationPrompt(t *testing.T) {
messages := []api.Message{{Role: "user", Content: "Hello"}}
tests := []struct {
name string
rendererName string
expected string
}{
{
name: "legacy_alias",
rendererName: "gemma4",
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
},
{
name: "small",
rendererName: "gemma4-small",
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
},
{
name: "large",
rendererName: "gemma4-large",
expected: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := RenderWithRenderer(tt.rendererName, messages, nil, nil)
assert.NoError(t, err)
assert.Equal(t, tt.expected, got)
})
}
}
func TestGemma4LargeRendererOmitsEmptyThoughtBlockWhenThinkingEnabled(t *testing.T) {
got, err := RenderWithRenderer("gemma4-large", []api.Message{{Role: "user", Content: "Hello"}}, nil, thinkTrue())
assert.NoError(t, err)
assert.Equal(t, "<bos><|turn>system\n<|think|>\n<turn|>\n<|turn>user\nHello<turn|>\n<|turn>model\n", got)
assert.NotContains(t, got, "<|channel>thought\n<channel|>")
}
func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
if os.Getenv("VERIFY_JINJA2") == "" {
t.Skip("set VERIFY_JINJA2=1 to run expanded Jinja2 parity checks")
@@ -1616,15 +1657,35 @@ func TestGemma4RendererMatchesJinja2ExpandedParity(t *testing.T) {
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
renderer := &Gemma4Renderer{useImgTags: RenderImgTags}
got, err := renderer.Render(tt.messages, tt.tools, tt.think)
assert.NoError(t, err)
variants := []struct {
name string
renderer *Gemma4Renderer
templateRel string
}{
{
name: "small",
renderer: &Gemma4Renderer{useImgTags: RenderImgTags},
templateRel: gemma4E2BTemplate,
},
{
name: "large",
renderer: &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true},
templateRel: gemma431BTemplate,
},
}
jinja2Output := renderWithJinja2(t, tt.messages, tt.tools, tt.think)
assert.Equal(t, jinja2Output, got,
"renderer output doesn't match Jinja2 template output")
for _, variant := range variants {
t.Run(variant.name, func(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := variant.renderer.Render(tt.messages, tt.tools, tt.think)
assert.NoError(t, err)
jinja2Output := renderWithJinja2Template(t, variant.templateRel, tt.messages, tt.tools, tt.think)
assert.Equal(t, jinja2Output, got,
"renderer output doesn't match Jinja2 template output")
})
}
})
}
}

View File

@@ -81,8 +81,10 @@ func rendererForName(name string) Renderer {
return renderer
case "nemotron-3-nano":
return &Nemotron3NanoRenderer{}
case "gemma4":
case "gemma4", "gemma4-small":
return &Gemma4Renderer{useImgTags: RenderImgTags}
case "gemma4-large":
return &Gemma4Renderer{useImgTags: RenderImgTags, emptyBlockOnNothink: true}
case "functiongemma":
return &FunctionGemmaRenderer{}
case "glm-4.7":

View File

@@ -523,7 +523,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
arch := layer.GGML.KV().Architecture()
switch arch {
case "gemma4":
config.Renderer = cmp.Or(config.Renderer, "gemma4")
config.Renderer = cmp.Or(config.Renderer, gemma4RendererLegacy)
config.Parser = cmp.Or(config.Parser, "gemma4")
if _, ok := r.Parameters["stop"]; !ok {
if r.Parameters == nil {

78
server/gemma4_test.go Normal file
View File

@@ -0,0 +1,78 @@
package server
import "testing"
func TestResolveGemma4Renderer(t *testing.T) {
tests := []struct {
name string
model *Model
want string
}{
{
name: "nil model falls back to legacy alias",
model: nil,
want: gemma4RendererLegacy,
},
{
name: "explicit small passes through",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererSmall),
},
want: gemma4RendererSmall,
},
{
name: "explicit large passes through",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererLarge),
},
want: gemma4RendererLarge,
},
{
name: "legacy e4b tag resolves small",
model: &Model{
Name: "gemma4:e4b",
ShortName: "gemma4:e4b",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererSmall,
},
{
name: "legacy 31b tag resolves large",
model: &Model{
Name: "gemma4:31b-cloud",
ShortName: "gemma4:31b-cloud",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererLarge,
},
{
name: "legacy model type resolves small",
model: &Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "4.3B"),
},
want: gemma4RendererSmall,
},
{
name: "legacy model type resolves large",
model: &Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
},
want: gemma4RendererLarge,
},
{
name: "legacy unknown defaults small",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererSmall,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveGemma4Renderer(tt.model); got != tt.want {
t.Fatalf("resolveGemma4Renderer() = %q, want %q", got, tt.want)
}
})
}
}

View File

@@ -19,6 +19,7 @@ import (
"slices"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
@@ -33,6 +34,10 @@ import (
"github.com/ollama/ollama/x/imagegen/transfer"
)
// Blobs newer than this may belong to another process that has not written its
// manifest yet. They become eligible for the normal mark-and-sweep pass later.
const layerPruneGracePeriod = time.Hour
var (
errCapabilities = errors.New("does not support")
errCapabilityCompletion = errors.New("completion")
@@ -156,7 +161,7 @@ func (m *Model) Capabilities() []model.Capability {
// Temporary workaround — suppress vision/audio for gemma4 MLX models
// until multimodal runtime pipeline lands. Remove when imageproc.go is wired up.
if m.Config.ModelFormat == "safetensors" && m.Config.Renderer == "gemma4" {
if m.Config.ModelFormat == "safetensors" && isGemma4Renderer(m.Config.Renderer) {
capabilities = slices.DeleteFunc(capabilities, func(c model.Capability) bool {
return c == model.CapabilityVision || c == "audio"
})
@@ -478,10 +483,23 @@ func PruneLayers() error {
}
for _, blob := range blobs {
if blob.IsDir() {
continue
}
info, err := blob.Info()
if err != nil {
slog.Error("couldn't stat blob", "blob", blob.Name(), "error", err)
continue
}
if time.Since(info.ModTime()) < layerPruneGracePeriod {
continue
}
name := blob.Name()
name = strings.ReplaceAll(name, "-", ":")
_, err := manifest.BlobsPath(name)
_, err = manifest.BlobsPath(name)
if err != nil {
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads)

View File

@@ -5,14 +5,58 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
func TestPruneLayersSkipsRecentOrphans(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
recentDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000001"
oldDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000002"
for _, digest := range []string{recentDigest, oldDigest} {
p, err := manifest.BlobsPath(digest)
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(p, nil, 0o644); err != nil {
t.Fatal(err)
}
}
oldPath, err := manifest.BlobsPath(oldDigest)
if err != nil {
t.Fatal(err)
}
oldTime := time.Now().Add(-layerPruneGracePeriod - time.Hour)
if err := os.Chtimes(oldPath, oldTime, oldTime); err != nil {
t.Fatal(err)
}
if err := PruneLayers(); err != nil {
t.Fatal(err)
}
recentPath, err := manifest.BlobsPath(recentDigest)
if err != nil {
t.Fatal(err)
}
if _, err := os.Stat(recentPath); err != nil {
t.Fatalf("recent orphan was pruned: %v", err)
}
if _, err := os.Stat(oldPath); !os.IsNotExist(err) {
t.Fatalf("old orphan still exists: %v", err)
}
}
func TestModelCapabilities(t *testing.T) {
// Create completion model (llama architecture without vision)
completionModelPath, _ := createBinFile(t, ggml.KV{
@@ -118,6 +162,39 @@ func TestModelCapabilities(t *testing.T) {
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
{
name: "gemma4 small safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererSmall,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
{
name: "gemma4 large safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererLarge,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
{
name: "legacy gemma4 safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererLegacy,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
}
// compare two slices of model.Capability regardless of order

View File

@@ -115,7 +115,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
if m.Config.Renderer != "" {
rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
rendererName := resolveRendererName(m)
rendered, err := renderers.RenderWithRenderer(rendererName, msgs, tools, think)
if err != nil {
return "", err
}

View File

@@ -13,6 +13,14 @@ import (
"github.com/ollama/ollama/types/model"
)
func testConfigWithRenderer(renderer string) model.ConfigV2 {
return model.ConfigV2{Renderer: renderer}
}
func testConfigWithRendererAndType(renderer, modelType string) model.ConfigV2 {
return model.ConfigV2{Renderer: renderer, ModelType: modelType}
}
func TestChatPrompt(t *testing.T) {
type expect struct {
prompt string
@@ -397,3 +405,43 @@ func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
}
}
func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) {
msgs := []api.Message{{Role: "user", Content: "Hello"}}
tests := []struct {
name string
model Model
want string
}{
{
name: "small from name",
model: Model{
Name: "gemma4:e4b",
ShortName: "gemma4:e4b",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
},
{
name: "large from model type",
model: Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
},
want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := renderPrompt(&tt.model, msgs, nil, nil)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Fatalf("rendered prompt mismatch (-got +want):\n%s", diff)
}
})
}
}

View File

@@ -0,0 +1,110 @@
package server
import (
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
const (
gemma4RendererLegacy = "gemma4"
gemma4RendererSmall = "gemma4-small"
gemma4RendererLarge = "gemma4-large"
// Gemma 4 small templates cover the e2b/e4b family, while 26b/31b use the
// large template. Default to the small prompt unless the model is clearly in
// the large range.
gemma4LargeMinParameterCount = 16_000_000_000
)
func resolveRendererName(m *Model) string {
if m == nil || m.Config.Renderer == "" {
return ""
}
switch m.Config.Renderer {
case gemma4RendererLegacy:
return resolveGemma4Renderer(m)
default:
return m.Config.Renderer
}
}
func resolveGemma4Renderer(m *Model) string {
if m == nil || m.Config.Renderer != gemma4RendererLegacy {
if m == nil {
return gemma4RendererLegacy
}
return m.Config.Renderer
}
if renderer, ok := gemma4RendererFromName(m.ShortName); ok {
return renderer
}
if renderer, ok := gemma4RendererFromName(m.Name); ok {
return renderer
}
if parameterCount, ok := parseHumanParameterCount(m.Config.ModelType); ok {
return gemma4RendererForParameterCount(parameterCount)
}
return gemma4RendererSmall
}
func gemma4RendererForParameterCount(parameterCount uint64) string {
if parameterCount >= gemma4LargeMinParameterCount {
return gemma4RendererLarge
}
return gemma4RendererSmall
}
func gemma4RendererFromName(name string) (string, bool) {
lower := strings.ToLower(name)
switch {
case strings.Contains(lower, "e2b"), strings.Contains(lower, "e4b"):
return gemma4RendererSmall, true
case strings.Contains(lower, "26b"), strings.Contains(lower, "31b"):
return gemma4RendererLarge, true
default:
return "", false
}
}
func parseHumanParameterCount(s string) (uint64, bool) {
if s == "" {
return 0, false
}
unit := strings.ToUpper(s[len(s)-1:])
var multiplier float64
switch unit {
case "B":
multiplier = float64(format.Billion)
case "M":
multiplier = float64(format.Million)
case "K":
multiplier = float64(format.Thousand)
default:
return 0, false
}
value, err := strconv.ParseFloat(s[:len(s)-1], 64)
if err != nil {
return 0, false
}
return uint64(value * multiplier), true
}
func isGemma4Renderer(renderer string) bool {
switch renderer {
case gemma4RendererLegacy, gemma4RendererSmall, gemma4RendererLarge:
return true
default:
return false
}
}

View File

@@ -2408,7 +2408,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
// current approach uses the transition from parsed thinking content to
// parsed non-thinking content as the signal to turn constraining on
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
// TODO(parthsareen): temporary fix for https://github.com/ollama/ollama/issues/15260.
// To revisit for other models and have a consistent pattern across models through parsers.
forceImmediate := m.Config.Parser == "gemma4" && req.Think != nil && !req.Think.Bool()
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && !forceImmediate && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
currentFormat = nil
}

View File

@@ -928,6 +928,59 @@ func TestCreateDetectTemplate(t *testing.T) {
})
}
func TestCreateGemma4KeepsDynamicRendererAlias(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "gemma4",
"general.parameter_count": uint64(25_200_000_000),
}, nil)
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test",
Files: map[string]string{"test.gguf": digest},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
mf, err := manifest.ParseNamedManifest(model.ParseName("test"))
if err != nil {
t.Fatalf("parse manifest: %v", err)
}
if mf.Config.Digest == "" {
t.Fatalf("unexpected empty config digest for manifest")
}
configPath, err := manifest.BlobsPath(mf.Config.Digest)
if err != nil {
t.Fatalf("config blob path: %v", err)
}
cfgFile, err := os.Open(configPath)
if err != nil {
t.Fatalf("open config blob: %v", err)
}
defer cfgFile.Close()
var cfg model.ConfigV2
if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
t.Fatalf("decode config: %v", err)
}
if cfg.Renderer != gemma4RendererLegacy {
t.Fatalf("expected renderer %q, got %q", gemma4RendererLegacy, cfg.Renderer)
}
if cfg.Parser != "gemma4" {
t.Fatalf("expected parser %q, got %q", "gemma4", cfg.Parser)
}
}
func TestDetectModelTypeFromFiles(t *testing.T) {
t.Run("gguf file", func(t *testing.T) {
_, digest := createBinFile(t, nil, nil)

View File

@@ -2108,6 +2108,132 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
})
}
// TestChatFormatWithThinkFalse verifies that when a model uses a builtin
// parser that supports thinking (e.g. gemma4) and the request explicitly
// disables thinking (think=false), the format constraint is passed to the
// first and only completion call. Previously, format was deferred for all
// thinking-capable parsers and only re-applied after an end-of-thinking
// transition — a transition that never fires when thinking is off. See
// https://github.com/ollama/ollama/issues/15260.
func TestChatFormatWithThinkFalse(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := &mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := &Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{llama: mock}
return false
},
},
}
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
// Use the gemma4 builtin parser — it reports HasThinkingSupport=true, which
// adds CapabilityThinking to the model and previously triggered deferral of
// the format even when the user passed think=false.
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-gemma4-parser",
Files: map[string]string{"file.gguf": digest},
Parser: "gemma4",
Template: `{{- range .Messages }}{{ .Role }}: {{ .Content }}{{ end }}`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("create: expected status 200, got %d: %s", w.Code, w.Body.String())
}
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}},"required":["answer"]}`)
var (
requestsMu sync.Mutex
requests []llm.CompletionRequest
)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
requestsMu.Lock()
requests = append(requests, r)
requestsMu.Unlock()
fn(llm.CompletionResponse{
Content: `{"answer":"42"}`,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
})
return nil
}
streamRequest := false
think := false
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-gemma4-parser",
Messages: []api.Message{{Role: "user", Content: "Respond in JSON."}},
Think: &api.ThinkValue{Value: think},
Stream: &streamRequest,
Format: format,
})
if w.Code != http.StatusOK {
t.Fatalf("chat: expected status 200, got %d: %s", w.Code, w.Body.String())
}
if len(requests) != 1 {
t.Fatalf("expected a single completion call, got %d", len(requests))
}
if !bytes.Equal([]byte(format), []byte(requests[0].Format)) {
t.Errorf("expected first completion format to match the request format, got %q", string(requests[0].Format))
}
}
func TestGenerateUnload(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -115,36 +115,7 @@ func (s *Server) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo
// Spawn subprocess: ollama runner --imagegen-engine --model <path> --port <port>
cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", s.modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
if runtime.GOOS == "linux" {
// Build library paths: start with LibOllamaPath, then add any mlx_* subdirectories
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
// Append existing LD_LIBRARY_PATH if set
if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add LD_LIBRARY_PATH in cmd.Env
found := false
for i := range cmd.Env {
if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
configureMLXSubprocessEnv(cmd, ml.LibraryPaths(gpus))
s.cmd = cmd
@@ -200,6 +171,53 @@ func (s *Server) Ping(ctx context.Context) error {
return nil
}
func mlxLibraryPathEnv() string {
switch runtime.GOOS {
case "windows":
return "PATH"
case "darwin":
return "DYLD_LIBRARY_PATH"
default:
return "LD_LIBRARY_PATH"
}
}
func configureMLXSubprocessEnv(cmd *exec.Cmd, libraryPaths []string) {
if len(libraryPaths) == 0 {
return
}
// Search order for the imagegen runner is:
// 1. bundled lib/ollama root
// 2. backend-specific library dirs selected during GPU discovery
// 3. any existing caller-provided library path values
pathEnv := mlxLibraryPathEnv()
pathEnvPaths := append([]string{}, libraryPaths...)
if existingPath, ok := os.LookupEnv(pathEnv); ok {
pathEnvPaths = append(pathEnvPaths, filepath.SplitList(existingPath)...)
}
setSubprocessEnv(cmd, pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
slog.Debug("mlx subprocess library path", pathEnv, strings.Join(pathEnvPaths, string(filepath.ListSeparator)))
ollamaLibraryPaths := append([]string{}, libraryPaths...)
if existingPath, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
ollamaLibraryPaths = append(ollamaLibraryPaths, filepath.SplitList(existingPath)...)
}
setSubprocessEnv(cmd, "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
slog.Debug("mlx subprocess library path", "OLLAMA_LIBRARY_PATH", strings.Join(ollamaLibraryPaths, string(filepath.ListSeparator)))
}
func setSubprocessEnv(cmd *exec.Cmd, key, value string) {
for i := range cmd.Env {
name, _, ok := strings.Cut(cmd.Env[i], "=")
if ok && strings.EqualFold(name, key) {
cmd.Env[i] = key + "=" + value
return
}
}
cmd.Env = append(cmd.Env, key+"="+value)
}
// getLastErr returns the last stderr line.
func (s *Server) getLastErr() string {
s.lastErrLock.Lock()

View File

@@ -337,9 +337,10 @@ func (c *RotatingKVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return nil
}
liveLen := min(c.offset, c.keys.Dim(2))
return []*mlx.Array{
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
}
}

View File

@@ -158,13 +158,15 @@ type completionRequest struct {
}
type completionOpts struct {
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
MinP float32 `json:"min_p,omitempty"`
TopK int `json:"top_k,omitempty"`
RepeatLastN int `json:"repeat_last_n,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
MinP float32 `json:"min_p,omitempty"`
TopK int `json:"top_k,omitempty"`
RepeatLastN int `json:"repeat_last_n,omitempty"`
RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
}
type CompletionResponse struct {
@@ -206,13 +208,15 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
}
if req.Options != nil {
creq.Options = &completionOpts{
Temperature: req.Options.Temperature,
TopP: req.Options.TopP,
MinP: req.Options.MinP,
TopK: req.Options.TopK,
RepeatLastN: req.Options.RepeatLastN,
PresencePenalty: req.Options.PresencePenalty,
NumPredict: req.Options.NumPredict,
Temperature: req.Options.Temperature,
TopP: req.Options.TopP,
MinP: req.Options.MinP,
TopK: req.Options.TopK,
RepeatLastN: req.Options.RepeatLastN,
RepeatPenalty: req.Options.RepeatPenalty,
PresencePenalty: req.Options.PresencePenalty,
FrequencyPenalty: req.Options.FrequencyPenalty,
NumPredict: req.Options.NumPredict,
}
}

View File

@@ -62,3 +62,25 @@ var LogitSoftcap = Compile2(
},
Shapeless(),
)
// sigmoidRouterFused traces the DeepSeek-V2 / GLM-MoE aux-loss-free router
// head. Two outputs are returned so the pre-bias sigmoid (used to gather
// per-expert scores after top-k) and the post-bias negation (used as the
// argpartition key for top-k) share a single kernel.
var sigmoidRouterFused = Compile(
"SigmoidRouter",
func(in ...*Array) []*Array {
gates, bias := in[0], in[1]
orig := gates.Sigmoid()
neg := orig.Add(bias).Negative()
return []*Array{orig, neg}
},
Shapeless(),
)
// SigmoidRouter returns (sigmoid(gates), -(sigmoid(gates)+bias)) as a fused
// kernel — the DeepSeek-V2 / GLM-MoE aux-loss-free router head.
func SigmoidRouter(gates, bias *Array) (origScores, negScores *Array) {
out := sigmoidRouterFused(gates, bias)
return out[0], out[1]
}

View File

@@ -169,6 +169,12 @@ func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
return out
}
func (t *Array) ScatterAddAxis(indices, values *Array, axis int) *Array {
out := New("SCATTER_ADD_AXIS")
C.mlx_scatter_add_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Reshape(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i := range axes {

View File

@@ -26,16 +26,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
mlx.ResetPeakMemory()
ctx := request.Ctx
var (
sample, logprobs *mlx.Array
nextSample, nextLogprobs *mlx.Array
sample *mlx.Array
nextSample *mlx.Array
)
defer func() {
if request.Sampler != nil {
request.Sampler.Free()
}
mlx.Unpin(sample, logprobs)
mlx.Unpin(nextSample, nextLogprobs)
mlx.Unpin(sample)
mlx.Unpin(nextSample)
mlx.Sweep()
mlx.ClearCache()
@@ -135,22 +135,21 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
mlx.ClearCache()
}
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
step := func(token *mlx.Array) *mlx.Array {
fwd := r.Model.Forward(token.ExpandDims(0), caches)
logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
logprobs := logits.Subtract(logits.Logsumexp(true))
sample := request.Sampler.Sample(logprobs)
sample := request.Sampler.Sample(logits)
mlx.Pin(sample, logprobs)
mlx.Pin(sample)
mlx.Sweep()
mlx.AsyncEval(sample, logprobs)
mlx.AsyncEval(sample)
return sample, logprobs
return sample
}
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
sample = step(mlx.FromValues(tokens[processed:], total-processed))
var b bytes.Buffer
@@ -161,7 +160,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
request.Sampler.AppendToken(sample)
nextSample, nextLogprobs = step(sample)
nextSample = step(sample)
if i == 0 {
mlx.Eval(sample)
@@ -186,9 +185,9 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}:
}
mlx.Unpin(sample, logprobs)
sample, logprobs = nextSample, nextLogprobs
nextSample, nextLogprobs = nil, nil
mlx.Unpin(sample)
sample = nextSample
nextSample = nil
if i%256 == 0 {
mlx.ClearCache()

View File

@@ -31,13 +31,15 @@ type Request struct {
type TextCompletionsRequest struct {
Prompt string `json:"prompt"`
Options struct {
Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TopK int `json:"top_k"`
RepeatLastN int `json:"repeat_last_n"`
PresencePenalty float32 `json:"presence_penalty"`
MaxTokens int `json:"max_tokens"`
Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TopK int `json:"top_k"`
RepeatLastN int `json:"repeat_last_n"`
RepeatPenalty float32 `json:"repeat_penalty"`
PresencePenalty float32 `json:"presence_penalty"`
FrequencyPenalty float32 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"`
// Deprecated: use MaxTokens instead
NumPredict int `json:"num_predict"`

View File

@@ -9,30 +9,38 @@ import (
type Transform func(*Sampler, *mlx.Array) *mlx.Array
type Sampler struct {
Temperature float32
TopP float32
MinP float32
TopK int
RepeatLastN int
PresencePenalty float32
Temperature float32
TopP float32
MinP float32
TopK int
RepeatLastN int
RepeatPenalty float32
PresencePenalty float32
FrequencyPenalty float32
history *mlx.Array
historyLen int
transforms []Transform
}
func New(temp, top_p, min_p float32, top_k, repeatLastN int, presencePenalty float32) *Sampler {
func New(temp, top_p, min_p float32, top_k, repeatLastN int, repeatPenalty, presencePenalty, frequencyPenalty float32) *Sampler {
if repeatPenalty <= 0 {
repeatPenalty = 1
}
s := &Sampler{
Temperature: temp,
TopP: top_p,
MinP: min_p,
TopK: top_k,
RepeatLastN: repeatLastN,
PresencePenalty: presencePenalty,
Temperature: temp,
TopP: top_p,
MinP: min_p,
TopK: top_k,
RepeatLastN: repeatLastN,
RepeatPenalty: repeatPenalty,
PresencePenalty: presencePenalty,
FrequencyPenalty: frequencyPenalty,
}
var transforms []Transform
if presencePenalty != 0 {
if s.usesHistory() {
transforms = append(transforms, penalty)
}
@@ -59,7 +67,7 @@ func New(temp, top_p, min_p float32, top_k, repeatLastN int, presencePenalty flo
}
func (s *Sampler) usesHistory() bool {
return s.PresencePenalty != 0
return s.RepeatPenalty != 1 || s.PresencePenalty != 0 || s.FrequencyPenalty != 0
}
func (s *Sampler) setHistory(history *mlx.Array, historyLen int) {
@@ -130,60 +138,78 @@ func temperature(s *Sampler, logits *mlx.Array) *mlx.Array {
return mlx.DivScalar(logits, s.Temperature).Categorical(-1)
}
func topP(s *Sampler, logprobs *mlx.Array) *mlx.Array {
func topP(s *Sampler, logits *mlx.Array) *mlx.Array {
if s.TopP <= 0 || s.TopP >= 1 {
return logprobs
return logits
}
order := logprobs.Negative().ArgsortAxis(-1)
sortedLogprobs := logprobs.TakeAlongAxis(order, -1)
sortedProbs := mlx.SoftmaxAxis(sortedLogprobs, -1, true)
order := logits.Negative().ArgsortAxis(-1)
sortedLogits := logits.TakeAlongAxis(order, -1)
sortedProbs := mlx.SoftmaxAxis(sortedLogits, -1, true)
prevCumProbs := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
filtered := mlx.Where(keep, sortedLogprobs, mlx.FromValue(float32(math.Inf(-1))))
return logprobs.PutAlongAxis(order, filtered, -1)
filtered := mlx.Where(keep, sortedLogits, mlx.FromValue(float32(math.Inf(-1))))
return logits.PutAlongAxis(order, filtered, -1)
}
func minP(s *Sampler, logprobs *mlx.Array) *mlx.Array {
func minP(s *Sampler, logits *mlx.Array) *mlx.Array {
if s.MinP <= 0 || s.MinP > 1 {
return logprobs
return logits
}
maxLogprobs := logprobs.TakeAlongAxis(logprobs.Argmax(-1, true), -1)
minLogprobs := mlx.AddScalar(maxLogprobs, float32(math.Log(float64(s.MinP))))
maxLogits := logits.TakeAlongAxis(logits.Argmax(-1, true), -1)
minLogits := mlx.AddScalar(maxLogits, float32(math.Log(float64(s.MinP))))
return mlx.Where(
logprobs.Less(minLogprobs),
logits.Less(minLogits),
mlx.FromValue(float32(math.Inf(-1))),
logprobs,
logits,
)
}
func topK(s *Sampler, logprobs *mlx.Array) *mlx.Array {
func topK(s *Sampler, logits *mlx.Array) *mlx.Array {
if s.TopK <= 0 {
return logprobs
return logits
}
vocab := logprobs.Dim(logprobs.NumDims() - 1)
vocab := logits.Dim(logits.NumDims() - 1)
if s.TopK >= vocab {
return logprobs
return logits
}
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
mask := logits.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
return logits.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}
func penalty(s *Sampler, logprobs *mlx.Array) *mlx.Array {
if s.history == nil || s.historyLen == 0 || s.PresencePenalty == 0 {
return logprobs
func penalty(s *Sampler, logits *mlx.Array) *mlx.Array {
if s.historyLen == 0 {
return logits
}
tokenIndices := s.history
if logprobs.NumDims() > 1 {
if logits.NumDims() > 1 {
tokenIndices = tokenIndices.ExpandDims(0)
}
selected := logprobs.TakeAlongAxis(tokenIndices, -1)
adjusted := mlx.AddScalar(selected, -s.PresencePenalty)
return logprobs.PutAlongAxis(tokenIndices, adjusted, -1)
if s.RepeatPenalty != 1 || s.PresencePenalty != 0 {
adjusted := logits.TakeAlongAxis(tokenIndices, -1)
if s.RepeatPenalty != 1 {
factor := mlx.Where(
adjusted.Less(mlx.FromValue(float32(0))),
mlx.FromValue(s.RepeatPenalty),
mlx.FromValue(1/s.RepeatPenalty),
)
adjusted = adjusted.Multiply(factor)
}
if s.PresencePenalty != 0 {
adjusted = mlx.AddScalar(adjusted, -s.PresencePenalty)
}
logits = logits.PutAlongAxis(tokenIndices, adjusted, -1)
}
if s.FrequencyPenalty != 0 {
logits = logits.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1)
}
return logits
}

View File

@@ -11,7 +11,7 @@ import (
func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
// RepeatLastN = 1, PresencePenalty = 6
s := New(0, 0, 0, 0, 1, 6)
s := New(0, 0, 0, 0, 1, 1, 6, 0)
defer func() {
s.Free()
mlx.Sweep()
@@ -20,11 +20,11 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
s.ResetHistory([]int32{0})
s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1}))
logprobs := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logprobs)
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logits)
mlx.Eval(got)
// logprobs will be [0, -1, 4] after the penalty
// logits will be [0, -1, 4] after the penalty
// and then (index) 2 after the greedy sampler
gotInt := got.Int()
if gotInt != 2 {
@@ -32,19 +32,59 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
}
}
func TestMinPMasksTokensBelowThreshold(t *testing.T) {
s := New(0, 0, 0.5, 0, 0, 0)
func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) {
s := New(0, 0, 0, 0, 1, 2, 0, 0)
defer func() {
s.Free()
mlx.Sweep()
}()
logprobs := mlx.FromValues([]float32{
s.ResetHistory([]int32{1})
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logits)
mlx.Eval(got)
// token 1 is repeated and positive, so 5 / 2 falls below token 2.
gotInt := got.Int()
if gotInt != 2 {
t.Fatalf("got %d, want 2", gotInt)
}
}
func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) {
s := New(0, 0, 0, 0, 4, 1, 0, 2)
defer func() {
s.Free()
mlx.Sweep()
}()
s.ResetHistory([]int32{1, 1})
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logits)
mlx.Eval(got)
// token 1 appears twice, so 5 - (2 * 2) falls below token 2.
gotInt := got.Int()
if gotInt != 2 {
t.Fatalf("got %d, want 2", gotInt)
}
}
func TestMinPMasksTokensBelowThreshold(t *testing.T) {
s := New(0, 0, 0.5, 0, 0, 1, 0, 0)
defer func() {
s.Free()
mlx.Sweep()
}()
logits := mlx.FromValues([]float32{
float32(math.Log(0.5)),
float32(math.Log(0.3)),
float32(math.Log(0.2)),
}, 3)
got := minP(s, logprobs)
got := minP(s, logits)
mlx.Eval(got)
gotFloats := got.Floats()

View File

@@ -102,7 +102,9 @@ func Execute(args []string) error {
request.Options.MinP,
request.Options.TopK,
request.Options.RepeatLastN,
request.Options.RepeatPenalty,
request.Options.PresencePenalty,
request.Options.FrequencyPenalty,
)
var cancel context.CancelFunc

View File

@@ -1061,14 +1061,12 @@ func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
}
}
h = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
var donorKV *sharedKVEntry
h, donorKV = layer.Forward(h, c, B, L, m.TextConfig, pleInput, donorEntry, smc)
// If this layer is a donor, store its cached KV for later shared layers.
if layer.IsDonor && c != nil {
state := c.State()
if len(state) >= 2 && state[0] != nil && state[1] != nil {
sharedKV[layer.LayerIdx] = sharedKVEntry{K: state[0], V: state[1], Offset: c.Offset()}
}
if layer.IsDonor && donorKV != nil {
sharedKV[layer.LayerIdx] = *donorKV
}
}
@@ -1190,9 +1188,9 @@ func sliceLayerDim(combined *mlx.Array, layerIdx, B, L, pleDim int32) *mlx.Array
return mlx.Squeeze(sliced, 2)
}
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array {
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig, pleInput *mlx.Array, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
attnOut, donorKV := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg, donorEntry, slidingMaskCache)
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
@@ -1237,10 +1235,10 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
h = mlx.Mul(h, l.LayerScalar)
}
return h
return h, donorKV
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) *mlx.Array {
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig, donorEntry *sharedKVEntry, slidingMaskCache *slidingMaskCache) (*mlx.Array, *sharedKVEntry) {
// Determine head dim and scale based on layer type.
headDim := cfg.HeadDim
scale := cfg.SlidingScale
@@ -1274,6 +1272,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, offset, ropeFreqs)
var k, v *mlx.Array
var donorKV *sharedKVEntry
if donorEntry != nil {
// Shared layer: use donor's cached K/V.
@@ -1312,6 +1311,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
// Update cache.
if c != nil {
k, v = c.Update(k, v)
donorKV = &sharedKVEntry{K: k, V: v, Offset: c.Offset()}
}
}
@@ -1365,7 +1365,7 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
// strided views differently. Metal handles them natively.
out = mlx.Contiguous(out, false)
}
return a.OProj.Forward(out)
return a.OProj.Forward(out), donorKV
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {

View File

@@ -161,21 +161,21 @@ type MoEGate struct {
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
gates := g.Gate.Forward(x)
scores := mlx.Sigmoid(gates)
origScores := scores
var origScores, negScores *mlx.Array
if g.EScoreCorrectionBias != nil {
scores = mlx.Add(scores, g.EScoreCorrectionBias)
origScores, negScores = mlx.SigmoidRouter(gates, g.EScoreCorrectionBias)
} else {
origScores = mlx.Sigmoid(gates)
negScores = mlx.Neg(origScores)
}
topK := cfg.NumExpertsPerTok
negScores := mlx.Neg(scores)
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
dims := inds.Dims()
inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})
scores = mlx.TakeAlongAxis(origScores, inds, -1)
scores := mlx.TakeAlongAxis(origScores, inds, -1)
if topK > 1 && cfg.NormTopKProb {
sumScores := mlx.Sum(scores, -1, true)