mirror of
https://github.com/ollama/ollama.git
synced 2026-04-22 00:36:11 +02:00
Compare commits
14 Commits
v0.21.0-rc
...
pdevine/ad
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7bcdb250b9 | ||
|
|
7bbcd2e6be | ||
|
|
22d6c817f8 | ||
|
|
ca01373b28 | ||
|
|
24e038d56a | ||
|
|
5d1021603a | ||
|
|
8e05d734b9 | ||
|
|
05e0f21bec | ||
|
|
ff23dd343f | ||
|
|
123b300af6 | ||
|
|
57653b8e42 | ||
|
|
a50ce61c54 | ||
|
|
2bb7ea00d2 | ||
|
|
55fa80d07a |
@@ -61,6 +61,9 @@ func TestLaunchCmd(t *testing.T) {
|
||||
if !strings.Contains(cmd.Long, "hermes") {
|
||||
t.Error("Long description should mention hermes")
|
||||
}
|
||||
if !strings.Contains(cmd.Long, "kimi") {
|
||||
t.Error("Long description should mention kimi")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("flags exist", func(t *testing.T) {
|
||||
|
||||
@@ -4,18 +4,15 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
pathpkg "path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
@@ -66,23 +63,13 @@ var hermesMessagingEnvGroups = [][]string{
|
||||
// switching UX after startup.
|
||||
type Hermes struct{}
|
||||
|
||||
type hermesConfigBackend struct {
|
||||
displayPath string
|
||||
read func() ([]byte, error)
|
||||
write func([]byte) error
|
||||
}
|
||||
|
||||
func (h *Hermes) String() string { return "Hermes Agent" }
|
||||
|
||||
func (h *Hermes) Run(_ string, args []string) error {
|
||||
// Hermes reads its primary model from config.yaml. launch configures that
|
||||
// default model ahead of time so we can keep runtime invocation simple and
|
||||
// still let Hermes discover additional models later via its own UX.
|
||||
if hermesGOOS == "windows" {
|
||||
return h.runWindows(args)
|
||||
}
|
||||
|
||||
bin, err := h.findUnixBinary()
|
||||
bin, err := h.binary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -95,21 +82,21 @@ func (h *Hermes) Run(_ string, args []string) error {
|
||||
}
|
||||
|
||||
func (h *Hermes) Paths() []string {
|
||||
backend, err := h.configBackend()
|
||||
configPath, err := hermesConfigPath()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return []string{backend.displayPath}
|
||||
return []string{configPath}
|
||||
}
|
||||
|
||||
func (h *Hermes) Configure(model string) error {
|
||||
backend, err := h.configBackend()
|
||||
configPath, err := hermesConfigPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg := map[string]any{}
|
||||
if data, err := backend.read(); err == nil {
|
||||
if data, err := os.ReadFile(configPath); err == nil {
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("parse hermes config: %w", err)
|
||||
}
|
||||
@@ -142,15 +129,18 @@ func (h *Hermes) Configure(model string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return backend.write(data)
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteWithBackup(configPath, data)
|
||||
}
|
||||
|
||||
func (h *Hermes) CurrentModel() string {
|
||||
backend, err := h.configBackend()
|
||||
configPath, err := hermesConfigPath()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
data, err := backend.read()
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
@@ -188,14 +178,7 @@ func (h *Hermes) RefreshRuntimeAfterConfigure() error {
|
||||
}
|
||||
|
||||
func (h *Hermes) installed() bool {
|
||||
if hermesGOOS == "windows" {
|
||||
if _, err := hermesLookPath("hermes"); err == nil {
|
||||
return true
|
||||
}
|
||||
return h.wslHasHermes()
|
||||
}
|
||||
|
||||
_, err := h.findUnixBinary()
|
||||
_, err := h.binary()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
@@ -205,7 +188,7 @@ func (h *Hermes) ensureInstalled() error {
|
||||
}
|
||||
|
||||
if hermesGOOS == "windows" {
|
||||
return h.ensureInstalledWindows()
|
||||
return hermesWindowsHint()
|
||||
}
|
||||
|
||||
var missing []string
|
||||
@@ -239,42 +222,6 @@ func (h *Hermes) ensureInstalled() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hermes) ensureInstalledWindows() error {
|
||||
// Hermes upstream support is WSL-oriented, so Windows launch uses a hybrid
|
||||
// WSL handoff that stays on the same install path as upstream Hermes.
|
||||
if _, err := hermesLookPath("hermes"); err == nil {
|
||||
return nil
|
||||
}
|
||||
if !h.wslAvailable() {
|
||||
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
|
||||
}
|
||||
if h.wslHasHermes() {
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, err := ConfirmPromptWithOptions("Hermes runs through WSL2 on Windows. Install it in WSL now?", ConfirmOptions{
|
||||
YesLabel: "Use WSL",
|
||||
NoLabel: "Show manual steps",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nInstalling Hermes in WSL...\n")
|
||||
if err := h.runWSL("bash", "-lc", hermesInstallScript); err != nil {
|
||||
return hermesWindowsHint(fmt.Errorf("failed to install hermes in WSL: %w", err))
|
||||
}
|
||||
if !h.wslHasHermes() {
|
||||
return hermesWindowsHint(fmt.Errorf("hermes install finished but the WSL binary was not found"))
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sHermes installed successfully in WSL%s\n\n", ansiGreen, ansiReset)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hermes) listModels(defaultModel string) []string {
|
||||
client := hermesOllamaClient()
|
||||
resp, err := client.List(context.Background())
|
||||
@@ -306,11 +253,15 @@ func (h *Hermes) listModels(defaultModel string) []string {
|
||||
return models
|
||||
}
|
||||
|
||||
func (h *Hermes) findUnixBinary() (string, error) {
|
||||
func (h *Hermes) binary() (string, error) {
|
||||
if path, err := hermesLookPath("hermes"); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
if hermesGOOS == "windows" {
|
||||
return "", hermesWindowsHint()
|
||||
}
|
||||
|
||||
home, err := hermesUserHome()
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -323,70 +274,6 @@ func (h *Hermes) findUnixBinary() (string, error) {
|
||||
return "", fmt.Errorf("hermes is not installed")
|
||||
}
|
||||
|
||||
func (h *Hermes) runWindows(args []string) error {
|
||||
if path, err := hermesLookPath("hermes"); err == nil {
|
||||
if err := h.runGatewaySetupPreflight(args, func() error {
|
||||
return hermesAttachedCommand(path, "gateway", "setup").Run()
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return hermesAttachedCommand(path, args...).Run()
|
||||
}
|
||||
if !h.wslAvailable() {
|
||||
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
|
||||
}
|
||||
if err := h.runGatewaySetupPreflight(args, func() error {
|
||||
return h.runWSL("hermes", "gateway", "setup")
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := h.runWSL(append([]string{"hermes"}, args...)...); err != nil {
|
||||
return hermesWindowsHint(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hermes) runWSL(args ...string) error {
|
||||
if !h.wslAvailable() {
|
||||
return fmt.Errorf("wsl.exe is not available")
|
||||
}
|
||||
|
||||
return hermesAttachedCommand("wsl.exe", "bash", "-lc", shellQuoteArgs(args)).Run()
|
||||
}
|
||||
|
||||
func (h *Hermes) runWSLCombinedOutput(args ...string) ([]byte, error) {
|
||||
if !h.wslAvailable() {
|
||||
return nil, fmt.Errorf("wsl.exe is not available")
|
||||
}
|
||||
|
||||
return hermesCommand("wsl.exe", "bash", "-lc", shellQuoteArgs(args)).CombinedOutput()
|
||||
}
|
||||
|
||||
func (h *Hermes) wslAvailable() bool {
|
||||
_, err := hermesLookPath("wsl.exe")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (h *Hermes) wslHasHermes() bool {
|
||||
if !h.wslAvailable() {
|
||||
return false
|
||||
}
|
||||
cmd := hermesCommand("wsl.exe", "bash", "-lc", "command -v hermes >/dev/null 2>&1")
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
|
||||
func (h *Hermes) configBackend() (*hermesConfigBackend, error) {
|
||||
if hermesGOOS == "windows" {
|
||||
if _, err := hermesLookPath("hermes"); err == nil {
|
||||
return hermesLocalConfigBackend()
|
||||
}
|
||||
if h.wslAvailable() {
|
||||
return h.wslConfigBackend()
|
||||
}
|
||||
}
|
||||
return hermesLocalConfigBackend()
|
||||
}
|
||||
|
||||
func hermesConfigPath() (string, error) {
|
||||
home, err := hermesUserHome()
|
||||
if err != nil {
|
||||
@@ -395,110 +282,6 @@ func hermesConfigPath() (string, error) {
|
||||
return filepath.Join(home, ".hermes", "config.yaml"), nil
|
||||
}
|
||||
|
||||
func hermesLocalConfigBackend() (*hermesConfigBackend, error) {
|
||||
configPath, err := hermesConfigPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &hermesConfigBackend{
|
||||
displayPath: configPath,
|
||||
read: func() ([]byte, error) {
|
||||
return os.ReadFile(configPath)
|
||||
},
|
||||
write: func(data []byte) error {
|
||||
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteWithBackup(configPath, data)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Hermes) wslConfigBackend() (*hermesConfigBackend, error) {
|
||||
home, err := h.wslHome()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
configPath := pathpkg.Join(home, ".hermes", "config.yaml")
|
||||
return &hermesConfigBackend{
|
||||
displayPath: configPath,
|
||||
read: func() ([]byte, error) {
|
||||
return h.readWSLFile(configPath)
|
||||
},
|
||||
write: func(data []byte) error {
|
||||
return h.writeWSLConfig(configPath, data)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Hermes) wslHome() (string, error) {
|
||||
if !h.wslAvailable() {
|
||||
return "", fmt.Errorf("wsl.exe is not available")
|
||||
}
|
||||
cmd := hermesCommand("wsl.exe", "bash", "-lc", `printf %s "$HOME"`)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
home := strings.TrimSpace(string(out))
|
||||
if home == "" {
|
||||
return "", fmt.Errorf("could not resolve WSL home directory")
|
||||
}
|
||||
return home, nil
|
||||
}
|
||||
|
||||
func (h *Hermes) readWSLFile(path string) ([]byte, error) {
|
||||
pathArg := shellQuoteArgs([]string{path})
|
||||
cmd := hermesCommand("wsl.exe", "bash", "-lc", fmt.Sprintf("if [ -f %s ]; then cat %s; else exit 42; fi", pathArg, pathArg))
|
||||
out, err := cmd.Output()
|
||||
if err == nil {
|
||||
return out, nil
|
||||
}
|
||||
var exitErr *exec.ExitError
|
||||
if errors.As(err, &exitErr) && exitErr.ExitCode() == 42 {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (h *Hermes) writeWSLConfig(path string, data []byte) error {
|
||||
if existing, err := h.readWSLFile(path); err == nil {
|
||||
if !bytes.Equal(existing, data) {
|
||||
if err := hermesBackupData(path, existing); err != nil {
|
||||
return fmt.Errorf("backup failed: %w", err)
|
||||
}
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("read existing file: %w", err)
|
||||
}
|
||||
|
||||
dir := pathpkg.Dir(path)
|
||||
dirArg := shellQuoteArgs([]string{dir})
|
||||
pathArg := shellQuoteArgs([]string{path})
|
||||
script := fmt.Sprintf(
|
||||
"dir=%s; path=%s; mkdir -p \"$dir\" && tmp=$(mktemp \"$dir/.tmp-XXXXXX\") && cat > \"$tmp\" && mv \"$tmp\" \"$path\"",
|
||||
dirArg,
|
||||
pathArg,
|
||||
)
|
||||
cmd := hermesCommand("wsl.exe", "bash", "-lc", script)
|
||||
cmd.Stdin = bytes.NewReader(data)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
if msg := strings.TrimSpace(string(out)); msg != "" {
|
||||
return fmt.Errorf("%w: %s", err, msg)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hermesBackupData(path string, data []byte) error {
|
||||
if err := os.MkdirAll(fileutil.BackupDir(), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
backupPath := filepath.Join(fileutil.BackupDir(), fmt.Sprintf("%s.%d", filepath.Base(path), time.Now().Unix()))
|
||||
return os.WriteFile(backupPath, data, 0o644)
|
||||
}
|
||||
|
||||
func hermesBaseURL() string {
|
||||
return strings.TrimRight(hermesOllamaURL().String(), "/") + "/v1"
|
||||
}
|
||||
@@ -554,8 +337,11 @@ func (h *Hermes) messagingConfigured() bool {
|
||||
func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
|
||||
envVars := make(map[string]string)
|
||||
|
||||
data, err := h.readGatewayEnvFile()
|
||||
switch {
|
||||
envFilePath, err := hermesEnvPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch data, err := os.ReadFile(envFilePath); {
|
||||
case err == nil:
|
||||
for key, value := range hermesParseEnvFile(data) {
|
||||
envVars[key] = value
|
||||
@@ -566,12 +352,10 @@ func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if h.usesLocalRuntimeEnv() {
|
||||
for _, group := range hermesMessagingEnvGroups {
|
||||
for _, key := range group {
|
||||
if value, ok := os.LookupEnv(key); ok {
|
||||
envVars[key] = value
|
||||
}
|
||||
for _, group := range hermesMessagingEnvGroups {
|
||||
for _, key := range group {
|
||||
if value, ok := os.LookupEnv(key); ok {
|
||||
envVars[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -579,39 +363,6 @@ func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
|
||||
return envVars, nil
|
||||
}
|
||||
|
||||
func (h *Hermes) readGatewayEnvFile() ([]byte, error) {
|
||||
if hermesGOOS == "windows" {
|
||||
if _, err := hermesLookPath("hermes"); err == nil {
|
||||
path, err := hermesEnvPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
if h.wslAvailable() {
|
||||
home, err := h.wslHome()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h.readWSLFile(pathpkg.Join(home, ".hermes", ".env"))
|
||||
}
|
||||
}
|
||||
|
||||
path, err := hermesEnvPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
|
||||
func (h *Hermes) usesLocalRuntimeEnv() bool {
|
||||
if hermesGOOS != "windows" {
|
||||
return true
|
||||
}
|
||||
_, err := hermesLookPath("hermes")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (h *Hermes) gatewayRunning() (bool, error) {
|
||||
status, err := h.gatewayStatusOutput()
|
||||
if err != nil {
|
||||
@@ -621,19 +372,7 @@ func (h *Hermes) gatewayRunning() (bool, error) {
|
||||
}
|
||||
|
||||
func (h *Hermes) gatewayStatusOutput() (string, error) {
|
||||
if hermesGOOS == "windows" {
|
||||
if path, err := hermesLookPath("hermes"); err == nil {
|
||||
out, err := hermesCommand(path, "gateway", "status").CombinedOutput()
|
||||
return string(out), err
|
||||
}
|
||||
if !h.wslAvailable() {
|
||||
return "", hermesWindowsHint(fmt.Errorf("hermes is not installed"))
|
||||
}
|
||||
out, err := h.runWSLCombinedOutput("hermes", "gateway", "status")
|
||||
return string(out), err
|
||||
}
|
||||
|
||||
bin, err := h.findUnixBinary()
|
||||
bin, err := h.binary()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -642,20 +381,7 @@ func (h *Hermes) gatewayStatusOutput() (string, error) {
|
||||
}
|
||||
|
||||
func (h *Hermes) restartGateway() error {
|
||||
if hermesGOOS == "windows" {
|
||||
if path, err := hermesLookPath("hermes"); err == nil {
|
||||
return hermesAttachedCommand(path, "gateway", "restart").Run()
|
||||
}
|
||||
if !h.wslAvailable() {
|
||||
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
|
||||
}
|
||||
if err := h.runWSL("hermes", "gateway", "restart"); err != nil {
|
||||
return hermesWindowsHint(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
bin, err := h.findUnixBinary()
|
||||
bin, err := h.binary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -938,14 +664,6 @@ func mergeHermesToolsets(current any) any {
|
||||
}
|
||||
}
|
||||
|
||||
func shellQuoteArgs(args []string) string {
|
||||
quoted := make([]string, 0, len(args))
|
||||
for _, arg := range args {
|
||||
quoted = append(quoted, "'"+strings.ReplaceAll(arg, "'", `'\''`)+"'")
|
||||
}
|
||||
return strings.Join(quoted, " ")
|
||||
}
|
||||
|
||||
func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
|
||||
cmd := hermesCommand(name, args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
@@ -954,9 +672,8 @@ func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func hermesWindowsHint(err error) error {
|
||||
if hermesGOOS != "windows" {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("%w\n\nHermes runs on Windows through WSL2.\nQuick setup: wsl --install\nInstaller docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/", err)
|
||||
func hermesWindowsHint() error {
|
||||
return fmt.Errorf("Hermes on Windows requires WSL2. Install WSL with: wsl --install\n" +
|
||||
"Then run 'ollama launch hermes' from inside your WSL shell.\n" +
|
||||
"Docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/")
|
||||
}
|
||||
|
||||
@@ -896,64 +896,6 @@ fi
|
||||
}
|
||||
}
|
||||
|
||||
func TestHermesRefreshRuntimeAfterConfigure_WindowsWSLRestartsRunningGateway(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell test binaries to simulate WSL")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withHermesPlatform(t, "windows")
|
||||
t.Setenv("PATH", tmpDir+string(os.PathListSeparator)+os.Getenv("PATH"))
|
||||
|
||||
wslPath := filepath.Join(tmpDir, "wsl.exe")
|
||||
wslScript := `#!/bin/sh
|
||||
printf '[%s]\n' "$*" >> "$HOME/wsl-invocations.log"
|
||||
exec /bin/sh -lc "$3"
|
||||
`
|
||||
if err := os.WriteFile(wslPath, []byte(wslScript), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
hermesBin := filepath.Join(tmpDir, "hermes")
|
||||
hermesScript := `#!/bin/sh
|
||||
printf '[%s]\n' "$*" >> "$HOME/hermes-invocations.log"
|
||||
if [ "$1" = "gateway" ] && [ "$2" = "status" ]; then
|
||||
printf '✓ Gateway is running (PID: 321)\n'
|
||||
fi
|
||||
`
|
||||
if err := os.WriteFile(hermesBin, []byte(hermesScript), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
withHermesLookPath(t, func(file string) (string, error) {
|
||||
if file == "wsl.exe" {
|
||||
return wslPath, nil
|
||||
}
|
||||
return "", os.ErrNotExist
|
||||
})
|
||||
|
||||
h := &Hermes{}
|
||||
if err := h.RefreshRuntimeAfterConfigure(); err != nil {
|
||||
t.Fatalf("RefreshRuntimeAfterConfigure returned error: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(tmpDir, "hermes-invocations.log"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
|
||||
if len(lines) != 2 {
|
||||
t.Fatalf("expected WSL status then restart invocations, got %v", lines)
|
||||
}
|
||||
if lines[0] != "[gateway status]" {
|
||||
t.Fatalf("expected WSL gateway status first, got %q", lines[0])
|
||||
}
|
||||
if lines[1] != "[gateway restart]" {
|
||||
t.Fatalf("expected WSL gateway restart second, got %q", lines[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHermesMessagingConfiguredRecognizesSupportedGatewayVars(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
@@ -1002,82 +944,7 @@ func TestHermesMessagingConfiguredRecognizesSupportedGatewayVars(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHermesRunWindowsWSL_UsesGatewaySetupPreflight(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell test binaries to simulate WSL")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withLauncherHooks(t)
|
||||
withInteractiveSession(t, true)
|
||||
withHermesPlatform(t, "windows")
|
||||
clearHermesMessagingEnvVars(t)
|
||||
t.Setenv("PATH", tmpDir+string(os.PathListSeparator)+os.Getenv("PATH"))
|
||||
|
||||
wslPath := filepath.Join(tmpDir, "wsl.exe")
|
||||
wslScript := `#!/bin/sh
|
||||
printf '[%s]\n' "$*" >> "$HOME/wsl-invocations.log"
|
||||
exec /bin/sh -lc "$3"
|
||||
`
|
||||
if err := os.WriteFile(wslPath, []byte(wslScript), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
hermesBin := filepath.Join(tmpDir, "hermes")
|
||||
hermesScript := `#!/bin/sh
|
||||
printf '[%s]\n' "$*" >> "$HOME/hermes-invocations.log"
|
||||
if [ "$1" = "gateway" ] && [ "$2" = "setup" ]; then
|
||||
/bin/mkdir -p "$HOME/.hermes"
|
||||
printf 'TELEGRAM_BOT_TOKEN=configured\n' > "$HOME/.hermes/.env"
|
||||
fi
|
||||
`
|
||||
if err := os.WriteFile(hermesBin, []byte(hermesScript), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
withHermesLookPath(t, func(file string) (string, error) {
|
||||
if file == "wsl.exe" {
|
||||
return wslPath, nil
|
||||
}
|
||||
return "", os.ErrNotExist
|
||||
})
|
||||
|
||||
promptCount := 0
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
promptCount++
|
||||
if prompt != hermesGatewaySetupTitle {
|
||||
t.Fatalf("unexpected prompt %q", prompt)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
h := &Hermes{}
|
||||
if err := h.Run("", nil); err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
|
||||
if promptCount != 1 {
|
||||
t.Fatalf("expected one messaging prompt, got %d", promptCount)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(tmpDir, "hermes-invocations.log"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
|
||||
if len(lines) != 2 {
|
||||
t.Fatalf("expected WSL hermes to run setup then launch, got %v", lines)
|
||||
}
|
||||
if lines[0] != "[gateway setup]" {
|
||||
t.Fatalf("expected WSL gateway setup first, got %q", lines[0])
|
||||
}
|
||||
if lines[1] != "[]" {
|
||||
t.Fatalf("expected WSL default hermes launch second, got %q", lines[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHermesEnsureInstalledWindowsWithoutWSLGivesGuidance(t *testing.T) {
|
||||
func TestHermesEnsureInstalledWindowsShowsWSLGuidance(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
withHermesPlatform(t, "windows")
|
||||
@@ -1086,10 +953,17 @@ func TestHermesEnsureInstalledWindowsWithoutWSLGivesGuidance(t *testing.T) {
|
||||
h := &Hermes{}
|
||||
err := h.ensureInstalled()
|
||||
if err == nil {
|
||||
t.Fatal("expected missing WSL guidance error")
|
||||
t.Fatal("expected WSL guidance error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "wsl --install") {
|
||||
t.Fatalf("expected WSL guidance, got %v", err)
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "wsl --install") {
|
||||
t.Fatalf("expected install command in guidance, got %v", err)
|
||||
}
|
||||
if !strings.Contains(msg, "hermes-agent.nousresearch.com") {
|
||||
t.Fatalf("expected docs link in guidance, got %v", err)
|
||||
}
|
||||
if strings.Contains(msg, "hermes is not installed") {
|
||||
t.Fatalf("guidance should not lead with 'hermes is not installed', got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -54,6 +54,7 @@ func TestIntegrationLookup(t *testing.T) {
|
||||
{"claude uppercase", "CLAUDE", true, "Claude Code"},
|
||||
{"claude mixed case", "Claude", true, "Claude Code"},
|
||||
{"codex", "codex", true, "Codex"},
|
||||
{"kimi", "kimi", true, "Kimi Code CLI"},
|
||||
{"droid", "droid", true, "Droid"},
|
||||
{"opencode", "opencode", true, "OpenCode"},
|
||||
{"unknown integration", "unknown", false, ""},
|
||||
@@ -74,7 +75,7 @@ func TestIntegrationLookup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIntegrationRegistry(t *testing.T) {
|
||||
expectedIntegrations := []string{"claude", "codex", "droid", "opencode", "hermes"}
|
||||
expectedIntegrations := []string{"claude", "codex", "kimi", "droid", "opencode", "hermes"}
|
||||
|
||||
for _, name := range expectedIntegrations {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
@@ -89,6 +90,15 @@ func TestIntegrationRegistry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHiddenIntegrationsExcludedFromVisibleLists(t *testing.T) {
|
||||
for _, info := range ListIntegrationInfos() {
|
||||
switch info.Name {
|
||||
case "cline", "vscode", "kimi":
|
||||
t.Fatalf("hidden integration %q should not appear in ListIntegrationInfos", info.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasLocalModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
315
cmd/launch/kimi.go
Normal file
315
cmd/launch/kimi.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// Kimi implements Runner for Kimi Code CLI integration.
|
||||
type Kimi struct{}
|
||||
|
||||
const (
|
||||
kimiDefaultModelAlias = "ollama"
|
||||
kimiDefaultMaxContextSize = 32768
|
||||
)
|
||||
|
||||
var (
|
||||
kimiGOOS = runtime.GOOS
|
||||
kimiModelShowTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
func (k *Kimi) String() string { return "Kimi Code CLI" }
|
||||
|
||||
func (k *Kimi) args(config string, extra []string) []string {
|
||||
args := []string{"--config", config}
|
||||
args = append(args, extra...)
|
||||
return args
|
||||
}
|
||||
|
||||
func (k *Kimi) Run(model string, args []string) error {
|
||||
if strings.TrimSpace(model) == "" {
|
||||
return fmt.Errorf("model is required")
|
||||
}
|
||||
if err := validateKimiPassthroughArgs(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := buildKimiInlineConfig(model, resolveKimiMaxContextSize(model))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build kimi config: %w", err)
|
||||
}
|
||||
|
||||
bin, err := ensureKimiInstalled()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command(bin, k.args(config, args)...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func findKimiBinary() (string, error) {
|
||||
if path, err := exec.LookPath("kimi"); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
home, _ := os.UserHomeDir()
|
||||
|
||||
var candidates []string
|
||||
switch kimiGOOS {
|
||||
case "windows":
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, ".local", "bin"))
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, "bin"))
|
||||
|
||||
if appData := strings.TrimSpace(os.Getenv("APPDATA")); appData != "" {
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
|
||||
}
|
||||
if localAppData := strings.TrimSpace(os.Getenv("LOCALAPPDATA")); localAppData != "" {
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
|
||||
}
|
||||
default:
|
||||
candidates = append(candidates,
|
||||
filepath.Join(home, ".local", "bin", "kimi"),
|
||||
filepath.Join(home, "bin", "kimi"),
|
||||
filepath.Join(home, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi"),
|
||||
filepath.Join(home, ".local", "share", "uv", "tools", "kimi", "bin", "kimi"),
|
||||
)
|
||||
|
||||
if xdgDataHome := strings.TrimSpace(os.Getenv("XDG_DATA_HOME")); xdgDataHome != "" {
|
||||
candidates = append(candidates,
|
||||
filepath.Join(xdgDataHome, "uv", "tools", "kimi-cli", "bin", "kimi"),
|
||||
filepath.Join(xdgDataHome, "uv", "tools", "kimi", "bin", "kimi"),
|
||||
)
|
||||
}
|
||||
|
||||
// WSL users can inherit Windows env vars while launching from Linux shells.
|
||||
if profile := windowsPathToWSL(os.Getenv("USERPROFILE")); profile != "" {
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(profile, ".local", "bin"))
|
||||
}
|
||||
if appData := windowsPathToWSL(os.Getenv("APPDATA")); appData != "" {
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
|
||||
}
|
||||
if localAppData := windowsPathToWSL(os.Getenv("LOCALAPPDATA")); localAppData != "" {
|
||||
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
|
||||
}
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
|
||||
return candidate, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("kimi binary not found")
|
||||
}
|
||||
|
||||
func appendWindowsKimiCandidates(candidates []string, dir string) []string {
|
||||
if strings.TrimSpace(dir) == "" {
|
||||
return candidates
|
||||
}
|
||||
|
||||
return append(candidates,
|
||||
filepath.Join(dir, "kimi.exe"),
|
||||
filepath.Join(dir, "kimi.cmd"),
|
||||
filepath.Join(dir, "kimi.bat"),
|
||||
)
|
||||
}
|
||||
|
||||
func windowsPathToWSL(path string) string {
|
||||
trimmed := strings.TrimSpace(path)
|
||||
if len(trimmed) < 3 || trimmed[1] != ':' {
|
||||
return ""
|
||||
}
|
||||
|
||||
drive := strings.ToLower(string(trimmed[0]))
|
||||
rest := strings.ReplaceAll(trimmed[2:], "\\", "/")
|
||||
rest = strings.TrimPrefix(rest, "/")
|
||||
if rest == "" {
|
||||
return filepath.Join("/mnt", drive)
|
||||
}
|
||||
|
||||
return filepath.Join("/mnt", drive, rest)
|
||||
}
|
||||
|
||||
func validateKimiPassthroughArgs(args []string) error {
|
||||
for _, arg := range args {
|
||||
switch {
|
||||
case arg == "--config", strings.HasPrefix(arg, "--config="):
|
||||
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config", arg)
|
||||
case arg == "--config-file", strings.HasPrefix(arg, "--config-file="):
|
||||
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config-file", arg)
|
||||
case arg == "--model", strings.HasPrefix(arg, "--model="):
|
||||
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --model", arg)
|
||||
case arg == "-m", strings.HasPrefix(arg, "-m="):
|
||||
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages -m/--model", arg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildKimiInlineConfig(model string, maxContextSize int) (string, error) {
|
||||
cfg := map[string]any{
|
||||
"default_model": kimiDefaultModelAlias,
|
||||
"providers": map[string]any{
|
||||
kimiDefaultModelAlias: map[string]any{
|
||||
"type": "openai_legacy",
|
||||
"base_url": envconfig.ConnectableHost().String() + "/v1",
|
||||
"api_key": "ollama",
|
||||
},
|
||||
},
|
||||
"models": map[string]any{
|
||||
kimiDefaultModelAlias: map[string]any{
|
||||
"provider": kimiDefaultModelAlias,
|
||||
"model": model,
|
||||
"max_context_size": maxContextSize,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func resolveKimiMaxContextSize(model string) int {
|
||||
if l, ok := lookupCloudModelLimit(model); ok {
|
||||
return l.Context
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return kimiDefaultMaxContextSize
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), kimiModelShowTimeout)
|
||||
defer cancel()
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
||||
if err != nil {
|
||||
return kimiDefaultMaxContextSize
|
||||
}
|
||||
|
||||
if n, ok := modelInfoContextLength(resp.ModelInfo); ok {
|
||||
return n
|
||||
}
|
||||
|
||||
return kimiDefaultMaxContextSize
|
||||
}
|
||||
|
||||
func modelInfoContextLength(modelInfo map[string]any) (int, bool) {
|
||||
for key, val := range modelInfo {
|
||||
if !strings.HasSuffix(key, ".context_length") {
|
||||
continue
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case float64:
|
||||
if v > 0 {
|
||||
return int(v), true
|
||||
}
|
||||
case int:
|
||||
if v > 0 {
|
||||
return v, true
|
||||
}
|
||||
case int64:
|
||||
if v > 0 {
|
||||
return int(v), true
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func ensureKimiInstalled() (string, error) {
|
||||
if path, err := findKimiBinary(); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
if err := checkKimiInstallerDependencies(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ok, err := ConfirmPrompt("Kimi is not installed. Install now?")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !ok {
|
||||
return "", fmt.Errorf("kimi installation cancelled")
|
||||
}
|
||||
|
||||
bin, args, err := kimiInstallerCommand(kimiGOOS)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\nInstalling Kimi...\n")
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("failed to install kimi: %w", err)
|
||||
}
|
||||
|
||||
path, err := findKimiBinary()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("kimi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "%sKimi installed successfully%s\n\n", ansiGreen, ansiReset)
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func checkKimiInstallerDependencies() error {
|
||||
switch kimiGOOS {
|
||||
case "windows":
|
||||
if _, err := exec.LookPath("powershell"); err != nil {
|
||||
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n PowerShell: https://learn.microsoft.com/powershell/\n\nThen re-run:\n ollama launch kimi")
|
||||
}
|
||||
default:
|
||||
var missing []string
|
||||
if _, err := exec.LookPath("curl"); err != nil {
|
||||
missing = append(missing, "curl: https://curl.se/")
|
||||
}
|
||||
if _, err := exec.LookPath("bash"); err != nil {
|
||||
missing = append(missing, "bash: https://www.gnu.org/software/bash/")
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch kimi", strings.Join(missing, "\n "))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func kimiInstallerCommand(goos string) (string, []string, error) {
|
||||
switch goos {
|
||||
case "windows":
|
||||
return "powershell", []string{
|
||||
"-NoProfile",
|
||||
"-ExecutionPolicy",
|
||||
"Bypass",
|
||||
"-Command",
|
||||
"Invoke-RestMethod https://code.kimi.com/install.ps1 | Invoke-Expression",
|
||||
}, nil
|
||||
case "darwin", "linux":
|
||||
return "bash", []string{
|
||||
"-c",
|
||||
"curl -LsSf https://code.kimi.com/install.sh | bash",
|
||||
}, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unsupported platform for kimi install: %s", goos)
|
||||
}
|
||||
}
|
||||
636
cmd/launch/kimi_test.go
Normal file
636
cmd/launch/kimi_test.go
Normal file
@@ -0,0 +1,636 @@
|
||||
package launch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func assertKimiBinPath(t *testing.T, bin string) {
|
||||
t.Helper()
|
||||
base := strings.ToLower(filepath.Base(bin))
|
||||
if !strings.HasPrefix(base, "kimi") {
|
||||
t.Fatalf("bin = %q, want path to kimi executable", bin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKimiIntegration(t *testing.T) {
|
||||
k := &Kimi{}
|
||||
|
||||
t.Run("String", func(t *testing.T) {
|
||||
if got := k.String(); got != "Kimi Code CLI" {
|
||||
t.Errorf("String() = %q, want %q", got, "Kimi Code CLI")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("implements Runner", func(t *testing.T) {
|
||||
var _ Runner = k
|
||||
})
|
||||
}
|
||||
|
||||
func TestKimiArgs(t *testing.T) {
|
||||
k := &Kimi{}
|
||||
|
||||
got := k.args(`{"foo":"bar"}`, []string{"--quiet", "--print"})
|
||||
want := []string{"--config", `{"foo":"bar"}`, "--quiet", "--print"}
|
||||
if !slices.Equal(got, want) {
|
||||
t.Fatalf("args() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsPathToWSL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "user profile path",
|
||||
in: `C:\Users\parth`,
|
||||
want: filepath.Join("/mnt", "c", "Users", "parth"),
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "path with trailing slash",
|
||||
in: `D:\tools\bin\`,
|
||||
want: filepath.Join("/mnt", "d", "tools", "bin"),
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "non windows path",
|
||||
in: "/home/parth",
|
||||
valid: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
in: "",
|
||||
valid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := windowsPathToWSL(tt.in)
|
||||
if !tt.valid {
|
||||
if got != "" {
|
||||
t.Fatalf("windowsPathToWSL(%q) = %q, want empty", tt.in, got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("windowsPathToWSL(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindKimiBinaryFallbacks(t *testing.T) {
|
||||
oldGOOS := kimiGOOS
|
||||
t.Cleanup(func() { kimiGOOS = oldGOOS })
|
||||
|
||||
t.Run("linux/ubuntu uv tool path", func(t *testing.T) {
|
||||
homeDir := t.TempDir()
|
||||
setTestHome(t, homeDir)
|
||||
t.Setenv("PATH", t.TempDir())
|
||||
kimiGOOS = "linux"
|
||||
|
||||
target := filepath.Join(homeDir, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi")
|
||||
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
|
||||
t.Fatalf("failed to create candidate dir: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(target, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatalf("failed to write kimi candidate: %v", err)
|
||||
}
|
||||
|
||||
got, err := findKimiBinary()
|
||||
if err != nil {
|
||||
t.Fatalf("findKimiBinary() error = %v", err)
|
||||
}
|
||||
if got != target {
|
||||
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("windows appdata uv bin", func(t *testing.T) {
|
||||
setTestHome(t, t.TempDir())
|
||||
t.Setenv("PATH", t.TempDir())
|
||||
kimiGOOS = "windows"
|
||||
|
||||
appDataDir := t.TempDir()
|
||||
t.Setenv("APPDATA", appDataDir)
|
||||
t.Setenv("LOCALAPPDATA", "")
|
||||
|
||||
target := filepath.Join(appDataDir, "uv", "bin", "kimi.cmd")
|
||||
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
|
||||
t.Fatalf("failed to create candidate dir: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(target, []byte("@echo off\r\nexit /b 0\r\n"), 0o755); err != nil {
|
||||
t.Fatalf("failed to write kimi candidate: %v", err)
|
||||
}
|
||||
|
||||
got, err := findKimiBinary()
|
||||
if err != nil {
|
||||
t.Fatalf("findKimiBinary() error = %v", err)
|
||||
}
|
||||
if got != target {
|
||||
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateKimiPassthroughArgs_RejectsConflicts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
want string
|
||||
}{
|
||||
{name: "--config", args: []string{"--config", "{}"}, want: "--config"},
|
||||
{name: "--config=", args: []string{"--config={}"}, want: "--config={"},
|
||||
{name: "--config-file", args: []string{"--config-file", "x.toml"}, want: "--config-file"},
|
||||
{name: "--config-file=", args: []string{"--config-file=x.toml"}, want: "--config-file=x.toml"},
|
||||
{name: "--model", args: []string{"--model", "foo"}, want: "--model"},
|
||||
{name: "--model=", args: []string{"--model=foo"}, want: "--model=foo"},
|
||||
{name: "-m", args: []string{"-m", "foo"}, want: "-m"},
|
||||
{name: "-m=", args: []string{"-m=foo"}, want: "-m=foo"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateKimiPassthroughArgs(tt.args)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for args %v", tt.args)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.want) {
|
||||
t.Fatalf("error %q does not contain %q", err.Error(), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKimiInlineConfig(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434")
|
||||
|
||||
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
|
||||
if err != nil {
|
||||
t.Fatalf("buildKimiInlineConfig() error = %v", err)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
|
||||
t.Fatalf("config is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if parsed["default_model"] != "ollama" {
|
||||
t.Fatalf("default_model = %v, want ollama", parsed["default_model"])
|
||||
}
|
||||
|
||||
providers, ok := parsed["providers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
|
||||
}
|
||||
ollamaProvider, ok := providers["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
|
||||
}
|
||||
if ollamaProvider["type"] != "openai_legacy" {
|
||||
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
|
||||
}
|
||||
if ollamaProvider["base_url"] != "http://127.0.0.1:11434/v1" {
|
||||
t.Fatalf("provider base_url = %v, want http://127.0.0.1:11434/v1", ollamaProvider["base_url"])
|
||||
}
|
||||
if ollamaProvider["api_key"] != "ollama" {
|
||||
t.Fatalf("provider api_key = %v, want ollama", ollamaProvider["api_key"])
|
||||
}
|
||||
|
||||
models, ok := parsed["models"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("models missing or wrong type: %T", parsed["models"])
|
||||
}
|
||||
ollamaModel, ok := models["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("models.ollama missing or wrong type: %T", models["ollama"])
|
||||
}
|
||||
if ollamaModel["provider"] != "ollama" {
|
||||
t.Fatalf("model provider = %v, want ollama", ollamaModel["provider"])
|
||||
}
|
||||
if ollamaModel["model"] != "llama3.2" {
|
||||
t.Fatalf("model model = %v, want llama3.2", ollamaModel["model"])
|
||||
}
|
||||
if ollamaModel["max_context_size"] != float64(65536) {
|
||||
t.Fatalf("model max_context_size = %v, want 65536", ollamaModel["max_context_size"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKimiInlineConfig_UsesConnectableHostForUnspecifiedBind(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", "http://0.0.0.0:11434")
|
||||
|
||||
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
|
||||
if err != nil {
|
||||
t.Fatalf("buildKimiInlineConfig() error = %v", err)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
|
||||
t.Fatalf("config is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
providers, ok := parsed["providers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
|
||||
}
|
||||
|
||||
ollamaProvider, ok := providers["ollama"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
|
||||
}
|
||||
if got, _ := ollamaProvider["base_url"].(string); got != "http://127.0.0.1:11434/v1" {
|
||||
t.Fatalf("provider base_url = %q, want %q", got, "http://127.0.0.1:11434/v1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveKimiMaxContextSize(t *testing.T) {
|
||||
t.Run("uses cloud limit when known", func(t *testing.T) {
|
||||
got := resolveKimiMaxContextSize("kimi-k2.5:cloud")
|
||||
if got != 262_144 {
|
||||
t.Fatalf("resolveKimiMaxContextSize() = %d, want 262144", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses model show context length for local models", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/show" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
fmt.Fprint(w, `{"model_info":{"llama.context_length":131072}}`)
|
||||
}))
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
got := resolveKimiMaxContextSize("llama3.2")
|
||||
if got != 131_072 {
|
||||
t.Fatalf("resolveKimiMaxContextSize() = %d, want 131072", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back to default when show fails", func(t *testing.T) {
|
||||
srv := httptest.NewServer(http.NotFoundHandler())
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
oldTimeout := kimiModelShowTimeout
|
||||
kimiModelShowTimeout = 100 * 1000 * 1000 // 100ms
|
||||
t.Cleanup(func() { kimiModelShowTimeout = oldTimeout })
|
||||
|
||||
got := resolveKimiMaxContextSize("llama3.2")
|
||||
if got != kimiDefaultMaxContextSize {
|
||||
t.Fatalf("resolveKimiMaxContextSize() = %d, want %d", got, kimiDefaultMaxContextSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestKimiRun_RejectsConflictingArgsBeforeInstall(t *testing.T) {
|
||||
k := &Kimi{}
|
||||
|
||||
oldConfirm := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
t.Fatalf("did not expect install prompt, got %q", prompt)
|
||||
return false, nil
|
||||
}
|
||||
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
|
||||
|
||||
err := k.Run("llama3.2", []string{"--model", "other"})
|
||||
if err == nil || !strings.Contains(err.Error(), "--model") {
|
||||
t.Fatalf("expected conflict error mentioning --model, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKimiRun_PassesInlineConfigAndExtraArgs(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell fake binary")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
setTestHome(t, tmpDir)
|
||||
logPath := filepath.Join(tmpDir, "kimi-args.log")
|
||||
script := fmt.Sprintf(`#!/bin/sh
|
||||
for arg in "$@"; do
|
||||
printf "%%s\n" "$arg" >> %q
|
||||
done
|
||||
exit 0
|
||||
`, logPath)
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "kimi"), []byte(script), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake kimi: %v", err)
|
||||
}
|
||||
t.Setenv("PATH", tmpDir)
|
||||
|
||||
srv := httptest.NewServer(http.NotFoundHandler())
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
k := &Kimi{}
|
||||
if err := k.Run("llama3.2", []string{"--quiet", "--print"}); err != nil {
|
||||
t.Fatalf("Run() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read args log: %v", err)
|
||||
}
|
||||
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
|
||||
if len(lines) < 4 {
|
||||
t.Fatalf("expected at least 4 args, got %v", lines)
|
||||
}
|
||||
if lines[0] != "--config" {
|
||||
t.Fatalf("first arg = %q, want --config", lines[0])
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal([]byte(lines[1]), &cfg); err != nil {
|
||||
t.Fatalf("config arg is not valid JSON: %v", err)
|
||||
}
|
||||
providers := cfg["providers"].(map[string]any)
|
||||
ollamaProvider := providers["ollama"].(map[string]any)
|
||||
if ollamaProvider["type"] != "openai_legacy" {
|
||||
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
|
||||
}
|
||||
|
||||
if lines[2] != "--quiet" || lines[3] != "--print" {
|
||||
t.Fatalf("extra args = %v, want [--quiet --print]", lines[2:])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureKimiInstalled(t *testing.T) {
|
||||
oldGOOS := kimiGOOS
|
||||
t.Cleanup(func() { kimiGOOS = oldGOOS })
|
||||
|
||||
withConfirm := func(t *testing.T, fn func(prompt string) (bool, error)) {
|
||||
t.Helper()
|
||||
oldConfirm := DefaultConfirmPrompt
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
return fn(prompt)
|
||||
}
|
||||
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
|
||||
}
|
||||
|
||||
t.Run("already installed", func(t *testing.T) {
|
||||
setTestHome(t, t.TempDir())
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PATH", tmpDir)
|
||||
writeFakeBinary(t, tmpDir, "kimi")
|
||||
kimiGOOS = runtime.GOOS
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
t.Fatalf("did not expect prompt, got %q", prompt)
|
||||
return false, nil
|
||||
})
|
||||
|
||||
bin, err := ensureKimiInstalled()
|
||||
if err != nil {
|
||||
t.Fatalf("ensureKimiInstalled() error = %v", err)
|
||||
}
|
||||
assertKimiBinPath(t, bin)
|
||||
})
|
||||
|
||||
t.Run("missing dependencies", func(t *testing.T) {
|
||||
setTestHome(t, t.TempDir())
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PATH", tmpDir)
|
||||
kimiGOOS = "linux"
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
t.Fatalf("did not expect prompt, got %q", prompt)
|
||||
return false, nil
|
||||
})
|
||||
|
||||
_, err := ensureKimiInstalled()
|
||||
if err == nil || !strings.Contains(err.Error(), "required dependencies are missing") {
|
||||
t.Fatalf("expected missing dependency error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing and user declines install", func(t *testing.T) {
|
||||
setTestHome(t, t.TempDir())
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PATH", tmpDir)
|
||||
writeFakeBinary(t, tmpDir, "curl")
|
||||
writeFakeBinary(t, tmpDir, "bash")
|
||||
kimiGOOS = "linux"
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
if !strings.Contains(prompt, "Kimi is not installed.") {
|
||||
t.Fatalf("unexpected prompt: %q", prompt)
|
||||
}
|
||||
return false, nil
|
||||
})
|
||||
|
||||
_, err := ensureKimiInstalled()
|
||||
if err == nil || !strings.Contains(err.Error(), "installation cancelled") {
|
||||
t.Fatalf("expected cancellation error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing and user confirms install succeeds", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell fake binaries")
|
||||
}
|
||||
|
||||
setTestHome(t, t.TempDir())
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PATH", tmpDir)
|
||||
kimiGOOS = "linux"
|
||||
|
||||
writeFakeBinary(t, tmpDir, "curl")
|
||||
|
||||
installLog := filepath.Join(tmpDir, "bash.log")
|
||||
kimiPath := filepath.Join(tmpDir, "kimi")
|
||||
bashScript := fmt.Sprintf(`#!/bin/sh
|
||||
echo "$@" >> %q
|
||||
if [ "$1" = "-c" ]; then
|
||||
/bin/cat > %q <<'EOS'
|
||||
#!/bin/sh
|
||||
exit 0
|
||||
EOS
|
||||
/bin/chmod +x %q
|
||||
fi
|
||||
exit 0
|
||||
`, installLog, kimiPath, kimiPath)
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte(bashScript), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake bash: %v", err)
|
||||
}
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
return true, nil
|
||||
})
|
||||
|
||||
bin, err := ensureKimiInstalled()
|
||||
if err != nil {
|
||||
t.Fatalf("ensureKimiInstalled() error = %v", err)
|
||||
}
|
||||
assertKimiBinPath(t, bin)
|
||||
|
||||
logData, err := os.ReadFile(installLog)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read install log: %v", err)
|
||||
}
|
||||
if !strings.Contains(string(logData), "https://code.kimi.com/install.sh") {
|
||||
t.Fatalf("expected install.sh command in log, got:\n%s", string(logData))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("install succeeds and kimi is in home local bin without PATH update", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell fake binaries")
|
||||
}
|
||||
|
||||
homeDir := t.TempDir()
|
||||
setTestHome(t, homeDir)
|
||||
|
||||
tmpBin := t.TempDir()
|
||||
t.Setenv("PATH", tmpBin)
|
||||
kimiGOOS = "linux"
|
||||
writeFakeBinary(t, tmpBin, "curl")
|
||||
|
||||
installedKimi := filepath.Join(homeDir, ".local", "bin", "kimi")
|
||||
bashScript := fmt.Sprintf(`#!/bin/sh
|
||||
if [ "$1" = "-c" ]; then
|
||||
/bin/mkdir -p %q
|
||||
/bin/cat > %q <<'EOS'
|
||||
#!/bin/sh
|
||||
exit 0
|
||||
EOS
|
||||
/bin/chmod +x %q
|
||||
fi
|
||||
exit 0
|
||||
`, filepath.Dir(installedKimi), installedKimi, installedKimi)
|
||||
if err := os.WriteFile(filepath.Join(tmpBin, "bash"), []byte(bashScript), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake bash: %v", err)
|
||||
}
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
return true, nil
|
||||
})
|
||||
|
||||
bin, err := ensureKimiInstalled()
|
||||
if err != nil {
|
||||
t.Fatalf("ensureKimiInstalled() error = %v", err)
|
||||
}
|
||||
if bin != installedKimi {
|
||||
t.Fatalf("bin = %q, want %q", bin, installedKimi)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("install command fails", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell fake binaries")
|
||||
}
|
||||
|
||||
setTestHome(t, t.TempDir())
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PATH", tmpDir)
|
||||
kimiGOOS = "linux"
|
||||
writeFakeBinary(t, tmpDir, "curl")
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake bash: %v", err)
|
||||
}
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
return true, nil
|
||||
})
|
||||
|
||||
_, err := ensureKimiInstalled()
|
||||
if err == nil || !strings.Contains(err.Error(), "failed to install kimi") {
|
||||
t.Fatalf("expected install failure error, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("install succeeds but binary missing on PATH", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("uses POSIX shell fake binaries")
|
||||
}
|
||||
|
||||
setTestHome(t, t.TempDir())
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("PATH", tmpDir)
|
||||
kimiGOOS = "linux"
|
||||
writeFakeBinary(t, tmpDir, "curl")
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatalf("failed to write fake bash: %v", err)
|
||||
}
|
||||
|
||||
withConfirm(t, func(prompt string) (bool, error) {
|
||||
return true, nil
|
||||
})
|
||||
|
||||
_, err := ensureKimiInstalled()
|
||||
if err == nil || !strings.Contains(err.Error(), "binary was not found on PATH") {
|
||||
t.Fatalf("expected PATH guidance error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestKimiInstallerCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
goos string
|
||||
wantBin string
|
||||
wantParts []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "linux",
|
||||
goos: "linux",
|
||||
wantBin: "bash",
|
||||
wantParts: []string{"-c", "install.sh"},
|
||||
},
|
||||
{
|
||||
name: "darwin",
|
||||
goos: "darwin",
|
||||
wantBin: "bash",
|
||||
wantParts: []string{"-c", "install.sh"},
|
||||
},
|
||||
{
|
||||
name: "windows",
|
||||
goos: "windows",
|
||||
wantBin: "powershell",
|
||||
wantParts: []string{"-Command", "install.ps1"},
|
||||
},
|
||||
{
|
||||
name: "unsupported",
|
||||
goos: "freebsd",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
bin, args, err := kimiInstallerCommand(tt.goos)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("kimiInstallerCommand() error = %v", err)
|
||||
}
|
||||
if bin != tt.wantBin {
|
||||
t.Fatalf("bin = %q, want %q", bin, tt.wantBin)
|
||||
}
|
||||
joined := strings.Join(args, " ")
|
||||
for _, part := range tt.wantParts {
|
||||
if !strings.Contains(joined, part) {
|
||||
t.Fatalf("args %q missing %q", joined, part)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -209,6 +209,7 @@ Supported integrations:
|
||||
copilot Copilot CLI (aliases: copilot-cli)
|
||||
droid Droid
|
||||
hermes Hermes Agent
|
||||
kimi Kimi Code CLI
|
||||
opencode OpenCode
|
||||
openclaw OpenClaw (aliases: clawdbot, moltbot)
|
||||
pi Pi
|
||||
@@ -587,7 +588,7 @@ func (c *launcherClient) launchManagedSingleIntegration(ctx context.Context, nam
|
||||
return nil
|
||||
}
|
||||
|
||||
if current == "" || needsConfigure || req.ModelOverride != "" || target != current {
|
||||
if (current == "" || needsConfigure || req.ModelOverride != "" || target != current) && !savedMatchesModels(saved, []string{target}) {
|
||||
if err := prepareManagedSingleIntegration(name, runner, managed, target); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -409,7 +409,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigOnlySkipsFinalRun(t *te
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_ManagedSingleIntegrationRepairsMissingLiveConfigUsingSavedModel(t *testing.T) {
|
||||
func TestLaunchIntegration_ManagedSingleIntegrationSkipsRewriteWhenSavedMatches(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withInteractiveSession(t, true)
|
||||
@@ -436,29 +436,30 @@ func TestLaunchIntegration_ManagedSingleIntegrationRepairsMissingLiveConfigUsing
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called when saved model is reused for repair")
|
||||
t.Fatal("selector should not be called when saved model matches target")
|
||||
return "", nil
|
||||
}
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
return true, nil
|
||||
t.Fatal("confirm prompt should not run when saved model matches target")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
|
||||
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||
}
|
||||
|
||||
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
|
||||
t.Fatalf("expected missing live config to be rewritten from saved model: %s", diff)
|
||||
if len(runner.configured) != 0 {
|
||||
t.Fatalf("expected Configure to be skipped when saved matches, got %v", runner.configured)
|
||||
}
|
||||
if runner.refreshCalls != 1 {
|
||||
t.Fatalf("expected repaired config to refresh runtime once, got %d", runner.refreshCalls)
|
||||
if runner.refreshCalls != 0 {
|
||||
t.Fatalf("expected no runtime refresh when config is unchanged, got %d", runner.refreshCalls)
|
||||
}
|
||||
if runner.ranModel != "gemma4" {
|
||||
t.Fatalf("expected launch to use repaired saved model, got %q", runner.ranModel)
|
||||
t.Fatalf("expected launch to run saved model, got %q", runner.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLiveConfig(t *testing.T) {
|
||||
func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenSavedDiffers(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
setLaunchTestHome(t, tmpDir)
|
||||
withInteractiveSession(t, true)
|
||||
@@ -466,6 +467,8 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLi
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/tags":
|
||||
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
|
||||
case "/api/show":
|
||||
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
|
||||
default:
|
||||
@@ -475,7 +478,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLi
|
||||
defer srv.Close()
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
|
||||
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
|
||||
if err := config.SaveIntegration("stubmanaged", []string{"old-model"}); err != nil {
|
||||
t.Fatalf("failed to save managed integration config: %v", err)
|
||||
}
|
||||
|
||||
@@ -483,7 +486,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLi
|
||||
withIntegrationOverride(t, "stubmanaged", runner)
|
||||
|
||||
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
|
||||
t.Fatal("selector should not be called when saved model is reused for repair")
|
||||
t.Fatal("selector should not be called when model override is provided")
|
||||
return "", nil
|
||||
}
|
||||
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
|
||||
@@ -492,19 +495,19 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLi
|
||||
|
||||
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
|
||||
Name: "stubmanaged",
|
||||
ConfigureOnly: true,
|
||||
ModelOverride: "gemma4",
|
||||
}); err != nil {
|
||||
t.Fatalf("LaunchIntegration returned error: %v", err)
|
||||
}
|
||||
|
||||
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
|
||||
t.Fatalf("expected configure-only flow to rewrite missing live config: %s", diff)
|
||||
t.Fatalf("expected Configure to run when saved differs from target: %s", diff)
|
||||
}
|
||||
if runner.refreshCalls != 1 {
|
||||
t.Fatalf("expected configure-only repair to refresh runtime once, got %d", runner.refreshCalls)
|
||||
t.Fatalf("expected runtime refresh once after configure, got %d", runner.refreshCalls)
|
||||
}
|
||||
if runner.ranModel != "" {
|
||||
t.Fatalf("expected configure-only flow to skip final launch, got %q", runner.ranModel)
|
||||
if runner.ranModel != "gemma4" {
|
||||
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -74,6 +74,23 @@ var integrationSpecs = []*IntegrationSpec{
|
||||
Command: []string{"npm", "install", "-g", "@openai/codex"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "kimi",
|
||||
Runner: &Kimi{},
|
||||
Description: "Moonshot's coding agent for terminal and IDEs",
|
||||
Hidden: true,
|
||||
Install: IntegrationInstallSpec{
|
||||
CheckInstalled: func() bool {
|
||||
_, err := exec.LookPath("kimi")
|
||||
return err == nil
|
||||
},
|
||||
EnsureInstalled: func() error {
|
||||
_, err := ensureKimiInstalled()
|
||||
return err
|
||||
},
|
||||
URL: "https://moonshotai.github.io/kimi-cli/en/guides/getting-started.html",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "copilot",
|
||||
Runner: &Copilot{},
|
||||
|
||||
@@ -45,6 +45,14 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
||||
return filepath.Join(home, ".pi", "agent", "models.json")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "kimi",
|
||||
binary: "kimi",
|
||||
runner: &Kimi{},
|
||||
checkPath: func(home string) string {
|
||||
return filepath.Join(home, ".kimi", "config.toml")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -57,6 +65,10 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
|
||||
if tt.name == "pi" {
|
||||
writeFakeBinary(t, binDir, "npm")
|
||||
}
|
||||
if tt.name == "kimi" {
|
||||
writeFakeBinary(t, binDir, "curl")
|
||||
writeFakeBinary(t, binDir, "bash")
|
||||
}
|
||||
t.Setenv("PATH", binDir)
|
||||
|
||||
configPath := tt.checkPath(home)
|
||||
|
||||
BIN
docs/images/hermes.png
Normal file
BIN
docs/images/hermes.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.4 MiB |
@@ -2,7 +2,9 @@
|
||||
title: Hermes Agent
|
||||
---
|
||||
|
||||
Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation, cross-session memory, and connects messaging platforms (Telegram, Discord, Slack, WhatsApp, Signal, Email) to models through a unified gateway.
|
||||
Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation, cross-session memory, and 70+ skills that it ships with by default.
|
||||
|
||||

|
||||
|
||||
## Quick start
|
||||
|
||||
@@ -10,25 +12,56 @@ Hermes Agent is a self-improving AI agent built by Nous Research. It features au
|
||||
ollama launch hermes
|
||||
```
|
||||
|
||||
### Pull a model
|
||||
Ollama handles everything automatically:
|
||||
|
||||
Before running the setup wizard, make sure you have a model available. Hermes will auto-detect models downloaded through Ollama.
|
||||
1. **Install** — If Hermes isn't installed, Ollama prompts to install it via the Nous Research install script
|
||||
2. **Model** — Pick a model from the selector (local or cloud)
|
||||
3. **Onboarding** — Ollama configures the Ollama provider, points Hermes at `http://127.0.0.1:11434/v1`, and sets your model as the primary
|
||||
4. **Gateway** — Optionally connects a messaging platform (Telegram, Discord, Slack, WhatsApp, Signal, Email) and launches the Hermes chat
|
||||
|
||||
<Note>Hermes on Windows requires WSL2. Install it with `wsl --install` and re-run from inside the WSL shell.</Note>
|
||||
|
||||
## Recommended models
|
||||
|
||||
**Cloud models**:
|
||||
|
||||
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
|
||||
- `glm-5.1:cloud` — Reasoning and code generation
|
||||
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
|
||||
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
|
||||
|
||||
**Local models:**
|
||||
|
||||
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
|
||||
- `qwen3.6` — Reasoning, coding, and visual understanding locally (~24 GB VRAM)
|
||||
|
||||
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
|
||||
|
||||
## Connect messaging apps
|
||||
|
||||
Link Telegram, Discord, Slack, WhatsApp, Signal, or Email to chat with your models from anywhere:
|
||||
|
||||
```bash
|
||||
ollama pull kimi-k2.5:cloud
|
||||
hermes gateway setup
|
||||
```
|
||||
|
||||
See [Recommended models](#recommended-models) for more options.
|
||||
## Reconfigure
|
||||
|
||||
### Install
|
||||
Re-run the full setup wizard at any time:
|
||||
|
||||
```bash
|
||||
hermes setup
|
||||
```
|
||||
|
||||
## Manual setup
|
||||
|
||||
If you'd rather drive Hermes's own wizard instead of `ollama launch hermes`, install it directly:
|
||||
|
||||
```bash
|
||||
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
### Set up
|
||||
|
||||
After installation, Hermes launches the setup wizard automatically. Choose **Quick setup**:
|
||||
Hermes launches the setup wizard automatically. Choose **Quick setup**:
|
||||
|
||||
```
|
||||
How would you like to set up Hermes?
|
||||
@@ -84,32 +117,3 @@ Connect a messaging platform? (Telegram, Discord, etc.)
|
||||
Launch hermes chat now? [Y/n]: Y
|
||||
```
|
||||
|
||||
## Recommended models
|
||||
|
||||
**Cloud models**:
|
||||
|
||||
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
|
||||
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
|
||||
- `glm-5.1:cloud` — Reasoning and code generation
|
||||
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
|
||||
|
||||
**Local models:**
|
||||
|
||||
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
|
||||
- `qwen3.5` — Reasoning, coding, and visual understanding locally (~11 GB VRAM)
|
||||
|
||||
More models at [ollama.com/search](https://ollama.com/models).
|
||||
|
||||
## Configure later
|
||||
|
||||
Re-run the setup wizard at any time:
|
||||
|
||||
```bash
|
||||
hermes setup
|
||||
```
|
||||
|
||||
To configure just messaging:
|
||||
|
||||
```bash
|
||||
hermes setup gateway
|
||||
```
|
||||
|
||||
@@ -406,10 +406,6 @@ func TestAPIShowModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAPIGenerateLogprobs(t *testing.T) {
|
||||
if testModel != "" {
|
||||
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
|
||||
t.Skip("logprobs not supported by all runners")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
@@ -523,10 +519,6 @@ func TestAPIGenerateLogprobs(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAPIChatLogprobs(t *testing.T) {
|
||||
if testModel != "" {
|
||||
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
|
||||
t.Skip("logprobs not supported by all runners")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Layer struct {
|
||||
@@ -60,6 +61,9 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
return Layer{}, err
|
||||
}
|
||||
}
|
||||
if err := touchLayer(blob); err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
|
||||
return Layer{
|
||||
MediaType: mediatype,
|
||||
@@ -83,6 +87,9 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
||||
if err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
if err := touchLayer(blob); err != nil {
|
||||
return Layer{}, err
|
||||
}
|
||||
|
||||
return Layer{
|
||||
MediaType: mediatype,
|
||||
@@ -93,6 +100,11 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func touchLayer(path string) error {
|
||||
now := time.Now()
|
||||
return os.Chtimes(path, now, now)
|
||||
}
|
||||
|
||||
func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
||||
if l.Digest == "" {
|
||||
return nil, errors.New("opening layer with empty digest")
|
||||
|
||||
@@ -1,18 +1,23 @@
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var blobFilenamePattern = regexp.MustCompile(`^sha256-[0-9a-fA-F]{64}$`)
|
||||
|
||||
type Manifest struct {
|
||||
SchemaVersion int `json:"schemaVersion"`
|
||||
MediaType string `json:"mediaType"`
|
||||
@@ -22,6 +27,7 @@ type Manifest struct {
|
||||
filepath string
|
||||
fi os.FileInfo
|
||||
digest string
|
||||
name model.Name
|
||||
}
|
||||
|
||||
func (m *Manifest) Size() (size int64) {
|
||||
@@ -36,6 +42,14 @@ func (m *Manifest) Digest() string {
|
||||
return m.digest
|
||||
}
|
||||
|
||||
func (m *Manifest) BlobDigest() string {
|
||||
if m.digest == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return "sha256:" + m.digest
|
||||
}
|
||||
|
||||
func (m *Manifest) FileInfo() os.FileInfo {
|
||||
return m.fi
|
||||
}
|
||||
@@ -59,16 +73,7 @@ func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
|
||||
}
|
||||
|
||||
func (m *Manifest) Remove() error {
|
||||
if err := os.Remove(m.filepath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return PruneDirectory(manifests)
|
||||
return removeNamedManifestPaths(m.name)
|
||||
}
|
||||
|
||||
func (m *Manifest) RemoveLayers() error {
|
||||
@@ -80,6 +85,9 @@ func (m *Manifest) RemoveLayers() error {
|
||||
// Build set of digests still in use by other manifests
|
||||
inUse := make(map[string]struct{})
|
||||
for _, other := range ms {
|
||||
if other.BlobDigest() != "" {
|
||||
inUse[other.BlobDigest()] = struct{}{}
|
||||
}
|
||||
for _, layer := range append(other.Layers, other.Config) {
|
||||
if layer.Digest != "" {
|
||||
inUse[layer.Digest] = struct{}{}
|
||||
@@ -87,20 +95,27 @@ func (m *Manifest) RemoveLayers() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Remove layers not used by any other manifest
|
||||
for _, layer := range append(m.Layers, m.Config) {
|
||||
if layer.Digest == "" {
|
||||
digests := make([]string, 0, len(m.Layers)+2)
|
||||
digests = append(digests, m.BlobDigest())
|
||||
for _, layer := range m.Layers {
|
||||
digests = append(digests, layer.Digest)
|
||||
}
|
||||
digests = append(digests, m.Config.Digest)
|
||||
|
||||
// Remove manifest and layer blobs not used by any other manifest
|
||||
for _, digest := range digests {
|
||||
if digest == "" {
|
||||
continue
|
||||
}
|
||||
if _, used := inUse[layer.Digest]; used {
|
||||
if _, used := inUse[digest]; used {
|
||||
continue
|
||||
}
|
||||
blob, err := BlobsPath(layer.Digest)
|
||||
blob, err := BlobsPath(digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(blob); os.IsNotExist(err) {
|
||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
||||
slog.Debug("blob does not exist", "digest", digest)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -114,15 +129,36 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||
return nil, model.Unqualified(n)
|
||||
}
|
||||
|
||||
manifests, err := Path()
|
||||
p, root, err := resolveManifestPath(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := filepath.Join(manifests, n.Filepath())
|
||||
return parseManifestFile(normalizeLogicalName(n), p, root)
|
||||
}
|
||||
|
||||
func ReadManifestData(n model.Name) ([]byte, error) {
|
||||
if !n.IsFullyQualified() {
|
||||
return nil, model.Unqualified(n)
|
||||
}
|
||||
|
||||
p, root, err := resolveManifestPath(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, _, err := OpenVerifiedManifest(p, root)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
return io.ReadAll(f)
|
||||
}
|
||||
|
||||
func parseManifestFile(name model.Name, path, root string) (*Manifest, error) {
|
||||
var m Manifest
|
||||
f, err := os.Open(p)
|
||||
f, digest, err := OpenVerifiedManifest(path, root)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -133,35 +169,19 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sha256sum := sha256.New()
|
||||
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
|
||||
if err := json.NewDecoder(f).Decode(&m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.filepath = p
|
||||
m.filepath = path
|
||||
m.fi = fi
|
||||
m.digest = hex.EncodeToString(sha256sum.Sum(nil))
|
||||
m.digest = digest
|
||||
m.name = name
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p := filepath.Join(manifests, name.Filepath())
|
||||
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Create(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
m := Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
@@ -169,33 +189,371 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
||||
Layers: layers,
|
||||
}
|
||||
|
||||
return json.NewEncoder(f).Encode(m)
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return WriteManifestData(name, b.Bytes())
|
||||
}
|
||||
|
||||
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
||||
// WriteManifestData stores raw manifest bytes as a content-addressed blob and
|
||||
// updates the v2 named manifest path to reference that blob. Any legacy named
|
||||
// manifest for the same model is removed after the v2 write succeeds.
|
||||
func WriteManifestData(name model.Name, data []byte) error {
|
||||
if !name.IsFullyQualified() {
|
||||
return model.Unqualified(name)
|
||||
}
|
||||
|
||||
digest, err := writeManifestBlob(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := LinkManifest(name, digest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return removeLegacyManifestPaths(name)
|
||||
}
|
||||
|
||||
// LinkManifest updates the v2 named manifest path to reference an existing
|
||||
// manifest blob. It prefers symlinks, then hardlinks, then a byte-for-byte copy
|
||||
// for filesystems that do not support links.
|
||||
func LinkManifest(name model.Name, digest string) error {
|
||||
if !name.IsFullyQualified() {
|
||||
return model.Unqualified(name)
|
||||
}
|
||||
|
||||
manifestPath, err := V2PathForName(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
blobPath, err := BlobsPath(digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := os.Stat(blobPath); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkBlobDigest(blobPath, digest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(manifestPath); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
if rel, err := filepath.Rel(filepath.Dir(manifestPath), blobPath); err == nil {
|
||||
if err := os.Symlink(rel, manifestPath); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.Link(blobPath, manifestPath); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return copyManifestFile(blobPath, manifestPath)
|
||||
}
|
||||
|
||||
func writeManifestBlob(data []byte) (string, error) {
|
||||
sum := sha256.Sum256(data)
|
||||
digest := fmt.Sprintf("sha256:%x", sum)
|
||||
|
||||
blobPath, err := BlobsPath(digest)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if existing, err := os.ReadFile(blobPath); err == nil && bytes.Equal(existing, data) {
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
blobs, err := BlobsPath("")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
temp, err := os.CreateTemp(blobs, "sha256-")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
tempName := temp.Name()
|
||||
defer os.Remove(tempName)
|
||||
|
||||
if _, err := temp.Write(data); err != nil {
|
||||
temp.Close()
|
||||
return "", err
|
||||
}
|
||||
if err := temp.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := os.Chmod(tempName, 0o644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := os.Rename(tempName, blobPath); err != nil {
|
||||
if err := os.Remove(blobPath); err != nil && !os.IsNotExist(err) {
|
||||
return "", err
|
||||
}
|
||||
if err := os.Rename(tempName, blobPath); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
func copyManifestFile(src, dst string) error {
|
||||
in, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer in.Close()
|
||||
|
||||
temp, err := os.CreateTemp(filepath.Dir(dst), ".manifest-*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tempName := temp.Name()
|
||||
defer os.Remove(tempName)
|
||||
|
||||
if _, err := io.Copy(temp, in); err != nil {
|
||||
temp.Close()
|
||||
return err
|
||||
}
|
||||
if err := temp.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Chmod(tempName, 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.Rename(tempName, dst)
|
||||
}
|
||||
|
||||
// OpenVerifiedManifest opens a named manifest path rooted under root. Symlinks must resolve to a
|
||||
// blob whose basename is sha256-<hex> and whose bytes hash to that digest.
|
||||
// Regular-file manifests are treated as legacy/copy fallback manifests and are
|
||||
// opened without mutating the local store.
|
||||
func OpenVerifiedManifest(path, root string) (*os.File, string, error) {
|
||||
resolvedRoot, err := filepath.EvalSymlinks(root)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
info, err := os.Lstat(path)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
target, err := evalAbs(path)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if info.Mode()&os.ModeSymlink != 0 {
|
||||
base := filepath.Base(target)
|
||||
if !blobFilenamePattern.MatchString(base) {
|
||||
return nil, "", fmt.Errorf("manifest symlink target %q is not a sha256 blob", target)
|
||||
}
|
||||
|
||||
digest := strings.ToLower(strings.TrimPrefix(base, "sha256-"))
|
||||
blobPath, err := BlobsPath("sha256:" + digest)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if !sameFile(target, blobPath) {
|
||||
return nil, "", fmt.Errorf("manifest symlink target %q does not match blob %q", target, blobPath)
|
||||
}
|
||||
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if err := checkBlobDigestReader(f, "sha256:"+digest); err != nil {
|
||||
f.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
if _, err := f.Seek(0, io.SeekStart); err != nil {
|
||||
f.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return f, digest, nil
|
||||
}
|
||||
|
||||
if !pathWithin(target, resolvedRoot) {
|
||||
return nil, "", fmt.Errorf("manifest path %q resolves outside manifest directory", path)
|
||||
}
|
||||
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
f.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
if _, err := f.Seek(0, io.SeekStart); err != nil {
|
||||
f.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
digest := fmt.Sprintf("%x", h.Sum(nil))
|
||||
|
||||
return f, digest, nil
|
||||
}
|
||||
|
||||
// MigrateManifestLinks moves legacy named manifests into manifests-v2. This is currently unwired but
|
||||
// will be added in the future.
|
||||
func MigrateManifestLinks() (int, error) {
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// TODO(mxyng): use something less brittle
|
||||
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
ms := make(map[model.Name]*Manifest)
|
||||
var migrated int
|
||||
for _, match := range matches {
|
||||
fi, err := os.Stat(match)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return migrated, err
|
||||
}
|
||||
if fi.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(manifests, match)
|
||||
if err != nil {
|
||||
return migrated, fmt.Errorf("%s %w", match, err)
|
||||
}
|
||||
|
||||
n := model.ParseNameFromFilepath(rel)
|
||||
if !n.IsFullyQualified() {
|
||||
slog.Warn("bad manifest name", "path", rel)
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := readManifestPath(match, manifests)
|
||||
if err != nil {
|
||||
return migrated, err
|
||||
}
|
||||
if err := WriteManifestData(normalizeLogicalName(n), data); err != nil {
|
||||
return migrated, err
|
||||
}
|
||||
migrated++
|
||||
}
|
||||
|
||||
return migrated, nil
|
||||
}
|
||||
|
||||
func readManifestPath(path, root string) ([]byte, error) {
|
||||
f, _, err := OpenVerifiedManifest(path, root)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
return io.ReadAll(f)
|
||||
}
|
||||
|
||||
func pathWithin(path, root string) bool {
|
||||
rel, err := filepath.Rel(root, path)
|
||||
return err == nil && rel != "." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) && rel != ".."
|
||||
}
|
||||
|
||||
func evalAbs(path string) (string, error) {
|
||||
abs, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.EvalSymlinks(abs)
|
||||
}
|
||||
|
||||
func sameFile(a, b string) bool {
|
||||
ai, err := os.Stat(a)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
bi, err := os.Stat(b)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return os.SameFile(ai, bi)
|
||||
}
|
||||
|
||||
func checkBlobDigest(path, digest string) error {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
return checkBlobDigestReader(f, digest)
|
||||
}
|
||||
|
||||
func checkBlobDigestReader(r io.Reader, digest string) error {
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
got := fmt.Sprintf("sha256:%x", h.Sum(nil))
|
||||
if got != strings.ToLower(strings.Replace(digest, "-", ":", 1)) {
|
||||
return errors.New("digest mismatch")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
||||
ms := make(map[model.Name]*Manifest)
|
||||
|
||||
manifestsV2, err := V2Path()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := collectManifests(ms, manifestsV2, continueOnError); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
manifests, err := Path()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := collectManifests(ms, manifests, continueOnError); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
func collectManifests(ms map[model.Name]*Manifest, root string, continueOnError bool) error {
|
||||
// TODO(mxyng): use something less brittle
|
||||
matches, err := filepath.Glob(filepath.Join(root, "*", "*", "*", "*"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
fi, err := os.Lstat(match)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !fi.IsDir() {
|
||||
rel, err := filepath.Rel(manifests, match)
|
||||
rel, err := filepath.Rel(root, match)
|
||||
if err != nil {
|
||||
if !continueOnError {
|
||||
return nil, fmt.Errorf("%s %w", match, err)
|
||||
return fmt.Errorf("%s %w", match, err)
|
||||
}
|
||||
slog.Warn("bad filepath", "path", match, "error", err)
|
||||
continue
|
||||
@@ -204,16 +562,21 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
||||
n := model.ParseNameFromFilepath(rel)
|
||||
if !n.IsValid() {
|
||||
if !continueOnError {
|
||||
return nil, fmt.Errorf("%s %w", rel, err)
|
||||
return fmt.Errorf("invalid manifest name: %s", rel)
|
||||
}
|
||||
slog.Warn("bad manifest name", "path", rel)
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := ParseNamedManifest(n)
|
||||
n = normalizeLogicalName(n)
|
||||
if _, ok := ms[n]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := parseManifestFile(n, match, root)
|
||||
if err != nil {
|
||||
if !continueOnError {
|
||||
return nil, fmt.Errorf("%s %w", n, err)
|
||||
return fmt.Errorf("%s %w", n, err)
|
||||
}
|
||||
slog.Warn("bad manifest", "name", n, "error", err)
|
||||
continue
|
||||
@@ -223,5 +586,5 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return ms, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func createManifest(t *testing.T, path, name string) {
|
||||
func createManifestAtRoot(t *testing.T, path, root, name string) {
|
||||
t.Helper()
|
||||
|
||||
p := filepath.Join(path, "manifests", name)
|
||||
p := filepath.Join(path, root, name)
|
||||
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -29,6 +33,309 @@ func createManifest(t *testing.T, path, name string) {
|
||||
}
|
||||
}
|
||||
|
||||
func createManifest(t *testing.T, path, name string) {
|
||||
t.Helper()
|
||||
createManifestAtRoot(t, path, "manifests", name)
|
||||
}
|
||||
|
||||
func TestWriteManifestStoresManifestAsBlob(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
name := model.ParseName("example")
|
||||
config := Layer{
|
||||
MediaType: "application/vnd.docker.container.image.v1+json",
|
||||
Digest: "sha256:" + strings.Repeat("a", 64),
|
||||
Size: 12,
|
||||
}
|
||||
|
||||
if err := WriteManifest(name, config, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
manifestPath, err := V2PathForName(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
manifestData, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sum := sha256.Sum256(manifestData)
|
||||
digest := fmt.Sprintf("sha256:%x", sum)
|
||||
blobPath, err := BlobsPath(digest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
blobData, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(blobData, manifestData) {
|
||||
t.Fatal("manifest path and blob content differ")
|
||||
}
|
||||
|
||||
m, err := ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := m.Digest(); got != fmt.Sprintf("%x", sum) {
|
||||
t.Fatalf("digest = %q, want %x", got, sum)
|
||||
}
|
||||
if got := m.BlobDigest(); got != digest {
|
||||
t.Fatalf("blob digest = %q, want %q", got, digest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNamedManifestLeavesLegacyManifestInPlace(t *testing.T) {
|
||||
models := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", models)
|
||||
|
||||
name := model.ParseName("example")
|
||||
createManifest(t, models, name.Filepath())
|
||||
|
||||
manifestPath, err := PathForName(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := ParseNamedManifest(name); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
fi, err := os.Lstat(manifestPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if fi.Mode()&os.ModeSymlink != 0 {
|
||||
t.Fatal("legacy manifest was converted to a symlink while reading")
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
blobPath, err := BlobsPath(fmt.Sprintf("sha256:%x", sum))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(blobPath); !os.IsNotExist(err) {
|
||||
t.Fatalf("legacy manifest read created blob: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateManifestLinks(t *testing.T) {
|
||||
models := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", models)
|
||||
|
||||
name := model.ParseName("example")
|
||||
createManifest(t, models, name.Filepath())
|
||||
|
||||
migrated, err := MigrateManifestLinks()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated != 1 {
|
||||
t.Fatalf("migrated = %d, want 1", migrated)
|
||||
}
|
||||
|
||||
manifestPath, err := V2PathForName(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
manifestData, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sum := sha256.Sum256(manifestData)
|
||||
blobPath, err := BlobsPath(fmt.Sprintf("sha256:%x", sum))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
blobData, err := os.ReadFile(blobPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(blobData, manifestData) {
|
||||
t.Fatal("migrated manifest path and blob content differ")
|
||||
}
|
||||
|
||||
legacyPath, err := PathForName(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(legacyPath); !os.IsNotExist(err) {
|
||||
t.Fatalf("legacy manifest still exists: %v", err)
|
||||
}
|
||||
|
||||
migrated, err = MigrateManifestLinks()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if migrated != 0 {
|
||||
t.Fatalf("migrated on second run = %d, want 0", migrated)
|
||||
}
|
||||
|
||||
if _, err := MigrateManifestLinks(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
manifestDataAfter, err := os.ReadFile(manifestPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(manifestDataAfter, manifestData) {
|
||||
t.Fatal("second migration changed manifest content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveLayersRemovesUnreferencedManifestBlob(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
name := model.ParseName("example")
|
||||
if err := WriteManifest(name, Layer{}, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m, err := ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
blobPath, err := BlobsPath(m.BlobDigest())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(blobPath); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := m.Remove(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := m.RemoveLayers(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(blobPath); !os.IsNotExist(err) {
|
||||
t.Fatalf("manifest blob still exists: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNamedManifestRejectsUnsafeSymlinks(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
name := model.ParseName("example")
|
||||
manifestPath, err := PathForName(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("non blob basename", func(t *testing.T) {
|
||||
target := filepath.Join(t.TempDir(), "not-a-blob")
|
||||
if err := os.WriteFile(target, []byte(`{"schemaVersion":2}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.Remove(manifestPath); err != nil && !os.IsNotExist(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.Symlink(target, manifestPath); err != nil {
|
||||
t.Skipf("symlink unavailable: %v", err)
|
||||
}
|
||||
|
||||
_, err := ParseNamedManifest(name)
|
||||
if err == nil || !strings.Contains(err.Error(), "not a sha256 blob") {
|
||||
t.Fatalf("err = %v, want not a sha256 blob", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("blob basename outside blob store", func(t *testing.T) {
|
||||
data := []byte(`{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json"}`)
|
||||
sum := sha256.Sum256(data)
|
||||
target := filepath.Join(t.TempDir(), fmt.Sprintf("sha256-%x", sum))
|
||||
if err := os.WriteFile(target, data, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.Remove(manifestPath); err != nil && !os.IsNotExist(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.Symlink(target, manifestPath); err != nil {
|
||||
t.Skipf("symlink unavailable: %v", err)
|
||||
}
|
||||
|
||||
_, err := ParseNamedManifest(name)
|
||||
if err == nil || !strings.Contains(err.Error(), "does not match blob") {
|
||||
t.Fatalf("err = %v, want does not match blob", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseNamedManifestPrefersV2(t *testing.T) {
|
||||
models := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", models)
|
||||
|
||||
name := model.ParseName("example")
|
||||
|
||||
legacyPath, err := PathForName(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(legacyPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(legacyPath, []byte(`{"schemaVersion":2,"mediaType":"legacy"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := WriteManifestData(name, []byte(`{"schemaVersion":2,"mediaType":"v2"}`)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m, err := ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if m.MediaType != "v2" {
|
||||
t.Fatalf("media type = %q, want %q", m.MediaType, "v2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifestsV2ShadowsLegacy(t *testing.T) {
|
||||
models := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", models)
|
||||
|
||||
name := model.ParseName("example")
|
||||
createManifest(t, models, name.Filepath())
|
||||
if err := WriteManifestData(name, []byte(`{"schemaVersion":2,"mediaType":"v2"}`)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ms, err := Manifests(true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(ms) != 1 {
|
||||
t.Fatalf("manifest count = %d, want 1", len(ms))
|
||||
}
|
||||
|
||||
var m *Manifest
|
||||
for gotName, gotManifest := range ms {
|
||||
if gotName.EqualFold(model.ParseName("example")) {
|
||||
m = gotManifest
|
||||
break
|
||||
}
|
||||
}
|
||||
if m == nil {
|
||||
t.Fatalf("missing v2 manifest for %s", name)
|
||||
}
|
||||
if m.MediaType != "v2" {
|
||||
t.Fatalf("media type = %q, want %q", m.MediaType, "v2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifests(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
ps []string
|
||||
|
||||
@@ -14,8 +14,23 @@ import (
|
||||
|
||||
var ErrInvalidDigestFormat = errors.New("invalid digest format")
|
||||
|
||||
const (
|
||||
legacyDirName = "manifests"
|
||||
v2DirName = "manifests-v2"
|
||||
defaultPublicHost = "registry.ollama.ai"
|
||||
v2CanonicalHost = "ollama.com"
|
||||
)
|
||||
|
||||
func Path() (string, error) {
|
||||
path := filepath.Join(envconfig.Models(), "manifests")
|
||||
return manifestPath(legacyDirName)
|
||||
}
|
||||
|
||||
func V2Path() (string, error) {
|
||||
return manifestPath(v2DirName)
|
||||
}
|
||||
|
||||
func manifestPath(dir string) (string, error) {
|
||||
path := filepath.Join(envconfig.Models(), dir)
|
||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
@@ -25,6 +40,10 @@ func Path() (string, error) {
|
||||
|
||||
// PathForName returns the path to the manifest file for a specific model name.
|
||||
func PathForName(n model.Name) (string, error) {
|
||||
return LegacyPathForName(n)
|
||||
}
|
||||
|
||||
func LegacyPathForName(n model.Name) (string, error) {
|
||||
if !n.IsValid() {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
@@ -37,6 +56,162 @@ func PathForName(n model.Name) (string, error) {
|
||||
return filepath.Join(manifests, n.Filepath()), nil
|
||||
}
|
||||
|
||||
func V2PathForName(n model.Name) (string, error) {
|
||||
if !n.IsValid() {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
|
||||
manifests, err := V2Path()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(manifests, canonicalV2Name(n).Filepath()), nil
|
||||
}
|
||||
|
||||
func ResolvePathForName(n model.Name) (string, error) {
|
||||
path, _, err := resolveManifestPath(n)
|
||||
return path, err
|
||||
}
|
||||
|
||||
func resolveManifestPath(n model.Name) (string, string, error) {
|
||||
if !n.IsValid() {
|
||||
return "", "", os.ErrNotExist
|
||||
}
|
||||
|
||||
v2Path, err := V2PathForName(n)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if _, err := os.Lstat(v2Path); err == nil {
|
||||
root, err := V2Path()
|
||||
return v2Path, root, err
|
||||
} else if !os.IsNotExist(err) {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
legacyRoot, err := Path()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
for _, legacyName := range legacyNameCandidates(n) {
|
||||
legacyPath := filepath.Join(legacyRoot, legacyName.Filepath())
|
||||
if _, err := os.Lstat(legacyPath); err == nil {
|
||||
return legacyPath, legacyRoot, nil
|
||||
} else if !os.IsNotExist(err) {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", os.ErrNotExist
|
||||
}
|
||||
|
||||
func removeNamedManifestPaths(n model.Name) error {
|
||||
candidates := legacyNameCandidates(n)
|
||||
paths := make([]string, 0, 1+len(candidates))
|
||||
|
||||
v2Path, err := V2PathForName(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
paths = append(paths, v2Path)
|
||||
|
||||
for _, legacyName := range candidates {
|
||||
legacyPath, err := LegacyPathForName(legacyName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
paths = append(paths, legacyPath)
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return pruneManifestRoots()
|
||||
}
|
||||
|
||||
func removeLegacyManifestPaths(n model.Name) error {
|
||||
for _, legacyName := range legacyNameCandidates(n) {
|
||||
legacyPath, err := LegacyPathForName(legacyName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(legacyPath); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
legacyRoot, err := Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := PruneDirectory(legacyRoot); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func pruneManifestRoots() error {
|
||||
roots := []func() (string, error){Path, V2Path}
|
||||
for _, rootFn := range roots {
|
||||
root, err := rootFn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := PruneDirectory(root); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizeLogicalName maps any public host to the legacy default
|
||||
// so that map keys use a single identity regardless of on-disk host.
|
||||
func normalizeLogicalName(n model.Name) model.Name {
|
||||
if isDefaultPublicHost(n.Host) {
|
||||
n.Host = defaultPublicHost
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// canonicalV2Name maps any public host to the v2 canonical host
|
||||
// for use in manifests-v2/ on-disk paths.
|
||||
func canonicalV2Name(n model.Name) model.Name {
|
||||
if isDefaultPublicHost(n.Host) {
|
||||
n.Host = v2CanonicalHost
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
func legacyNameCandidates(n model.Name) []model.Name {
|
||||
names := []model.Name{n}
|
||||
if !isDefaultPublicHost(n.Host) {
|
||||
return names
|
||||
}
|
||||
|
||||
alt := n
|
||||
switch {
|
||||
case strings.EqualFold(n.Host, defaultPublicHost):
|
||||
alt.Host = v2CanonicalHost
|
||||
default:
|
||||
alt.Host = defaultPublicHost
|
||||
}
|
||||
|
||||
return append(names, alt)
|
||||
}
|
||||
|
||||
func isDefaultPublicHost(host string) bool {
|
||||
return strings.EqualFold(host, defaultPublicHost) || strings.EqualFold(host, v2CanonicalHost)
|
||||
}
|
||||
|
||||
func BlobsPath(digest string) (string, error) {
|
||||
// only accept actual sha256 digests
|
||||
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
@@ -33,6 +34,10 @@ import (
|
||||
"github.com/ollama/ollama/x/imagegen/transfer"
|
||||
)
|
||||
|
||||
// Blobs newer than this may belong to another process that has not written its
|
||||
// manifest yet. They become eligible for the normal mark-and-sweep pass later.
|
||||
const layerPruneGracePeriod = time.Hour
|
||||
|
||||
var (
|
||||
errCapabilities = errors.New("does not support")
|
||||
errCapabilityCompletion = errors.New("completion")
|
||||
@@ -406,31 +411,12 @@ func CopyModel(src, dst model.Name) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
manifests, err := manifest.Path()
|
||||
data, err := manifest.ReadManifestData(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dstpath := filepath.Join(manifests, dst.Filepath())
|
||||
if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
srcpath := filepath.Join(manifests, src.Filepath())
|
||||
srcfile, err := os.Open(srcpath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer srcfile.Close()
|
||||
|
||||
dstfile, err := os.Create(dstpath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dstfile.Close()
|
||||
|
||||
_, err = io.Copy(dstfile, srcfile)
|
||||
return err
|
||||
return manifest.WriteManifestData(dst, data)
|
||||
}
|
||||
|
||||
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
||||
@@ -441,6 +427,10 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
|
||||
}
|
||||
|
||||
for _, manifest := range manifests {
|
||||
if manifest.BlobDigest() != "" {
|
||||
delete(deleteMap, manifest.BlobDigest())
|
||||
}
|
||||
|
||||
for _, layer := range manifest.Layers {
|
||||
delete(deleteMap, layer.Digest)
|
||||
}
|
||||
@@ -478,10 +468,23 @@ func PruneLayers() error {
|
||||
}
|
||||
|
||||
for _, blob := range blobs {
|
||||
if blob.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
info, err := blob.Info()
|
||||
if err != nil {
|
||||
slog.Error("couldn't stat blob", "blob", blob.Name(), "error", err)
|
||||
continue
|
||||
}
|
||||
if time.Since(info.ModTime()) < layerPruneGracePeriod {
|
||||
continue
|
||||
}
|
||||
|
||||
name := blob.Name()
|
||||
name = strings.ReplaceAll(name, "-", ":")
|
||||
|
||||
_, err := manifest.BlobsPath(name)
|
||||
_, err = manifest.BlobsPath(name)
|
||||
if err != nil {
|
||||
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
|
||||
// remove invalid blobs (e.g. partial downloads)
|
||||
@@ -531,11 +534,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
// Use fast transfer for models with tensor layers (many small blobs)
|
||||
if hasTensorLayers(layers) {
|
||||
// Read raw manifest JSON to preserve tensor metadata fields
|
||||
manifestPath, err := manifest.PathForName(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manifestJSON, err := os.ReadFile(manifestPath)
|
||||
manifestJSON, err := manifest.ReadManifestData(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -592,6 +591,14 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
if existingMf.Config.Digest != "" {
|
||||
deleteMap[existingMf.Config.Digest] = struct{}{}
|
||||
}
|
||||
if existingMf.BlobDigest() != "" {
|
||||
digest := existingMf.BlobDigest()
|
||||
if blob, err := manifest.BlobsPath(digest); err == nil {
|
||||
if _, err := os.Stat(blob); err == nil {
|
||||
deleteMap[digest] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if n.ProtocolScheme == "http" && !regOpts.Insecure {
|
||||
@@ -661,21 +668,12 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
|
||||
fp, err := manifest.PathForName(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
|
||||
if err := manifest.WriteManifestData(n, manifestData); err != nil {
|
||||
slog.Info(fmt.Sprintf("couldn't write manifest for %s", n.DisplayShortest()))
|
||||
return err
|
||||
}
|
||||
|
||||
err = os.WriteFile(fp, manifestData, 0o644)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Debug("manifest written", "path", fp, "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
|
||||
slog.Debug("manifest written", "name", n.DisplayShortest(), "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
|
||||
|
||||
if !envconfig.NoPrune() && len(deleteMap) > 0 {
|
||||
fn(api.ProgressResponse{Status: "removing unused layers"})
|
||||
@@ -758,19 +756,11 @@ func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer
|
||||
// Write manifest
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
|
||||
fp, err := manifest.PathForName(n)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
|
||||
if err := manifest.WriteManifestData(n, manifestData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.WriteFile(fp, manifestData, 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Debug("manifest written", "path", fp, "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
|
||||
slog.Debug("manifest written", "name", n.DisplayShortest(), "sha256", fmt.Sprintf("%x", sha256.Sum256(manifestData)), "size", len(manifestData))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -5,14 +5,58 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestPruneLayersSkipsRecentOrphans(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
recentDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000001"
|
||||
oldDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000002"
|
||||
|
||||
for _, digest := range []string{recentDigest, oldDigest} {
|
||||
p, err := manifest.BlobsPath(digest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(p, nil, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
oldPath, err := manifest.BlobsPath(oldDigest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
oldTime := time.Now().Add(-layerPruneGracePeriod - time.Hour)
|
||||
if err := os.Chtimes(oldPath, oldTime, oldTime); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := PruneLayers(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
recentPath, err := manifest.BlobsPath(recentDigest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(recentPath); err != nil {
|
||||
t.Fatalf("recent orphan was pruned: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(oldPath); !os.IsNotExist(err) {
|
||||
t.Fatalf("old orphan still exists: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelCapabilities(t *testing.T) {
|
||||
// Create completion model (llama architecture without vision)
|
||||
completionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
|
||||
@@ -116,6 +116,10 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
|
||||
proxied, err := func() (bool, error) {
|
||||
switch r.URL.Path {
|
||||
case "/api/delete":
|
||||
if s.Fallback != nil {
|
||||
s.Fallback.ServeHTTP(rec, r)
|
||||
return true, nil
|
||||
}
|
||||
return false, s.handleDelete(rec, r)
|
||||
case "/api/pull":
|
||||
return false, s.handlePull(rec, r)
|
||||
|
||||
@@ -1770,13 +1770,15 @@ func Serve(ln net.Listener) error {
|
||||
return err
|
||||
}
|
||||
|
||||
manifestsPath, err := manifest.Path()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rootFn := range []func() (string, error){manifest.Path, manifest.V2Path} {
|
||||
manifestsPath, err := rootFn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := manifest.PruneDirectory(manifestsPath); err != nil {
|
||||
return err
|
||||
if err := manifest.PruneDirectory(manifestsPath); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2408,7 +2410,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
// current approach uses the transition from parsed thinking content to
|
||||
// parsed non-thinking content as the signal to turn constraining on
|
||||
|
||||
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
|
||||
// TODO(parthsareen): temporary fix for https://github.com/ollama/ollama/issues/15260.
|
||||
// To revisit for other models and have a consistent pattern across models through parsers.
|
||||
forceImmediate := m.Config.Parser == "gemma4" && req.Think != nil && !req.Think.Bool()
|
||||
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && !forceImmediate && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
|
||||
currentFormat = nil
|
||||
}
|
||||
|
||||
|
||||
@@ -109,12 +109,44 @@ func checkFileExists(t *testing.T, p string, expect []string) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.HasSuffix(filepath.ToSlash(p), "/blobs/*") {
|
||||
actual = slices.DeleteFunc(actual, isManifestBlobForTest)
|
||||
}
|
||||
|
||||
if diff := gocmp.Diff(expect, actual, gocmpopts.SortSlices(strings.Compare), gocmpopts.EquateEmpty()); diff != "" {
|
||||
t.Errorf("file exists mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func checkManifestFiles(t *testing.T, names ...string) {
|
||||
t.Helper()
|
||||
|
||||
expect := make([]string, len(names))
|
||||
for i, name := range names {
|
||||
p, err := manifest.V2PathForName(model.ParseName(name))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expect[i] = p
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(envconfig.Models(), "manifests-v2", "*", "*", "*", "*"), expect)
|
||||
}
|
||||
|
||||
func isManifestBlobForTest(path string) bool {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var m manifest.Manifest
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return m.SchemaVersion != 0 && m.MediaType != "" && (m.Config.Digest != "" || len(m.Layers) > 0)
|
||||
}
|
||||
|
||||
func TestCreateFromBin(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -136,9 +168,7 @@ func TestCreateFromBin(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test")
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
||||
@@ -196,9 +226,7 @@ func TestCreateFromModel(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test")
|
||||
|
||||
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "test2",
|
||||
@@ -210,10 +238,7 @@ func TestCreateFromModel(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test", "test2")
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
||||
@@ -306,9 +331,7 @@ func TestCreateRemovesLayers(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test")
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
|
||||
@@ -327,9 +350,7 @@ func TestCreateRemovesLayers(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test")
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
|
||||
@@ -357,9 +378,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test")
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-0a666d113e8e0a3d27e9c7bd136a0bdfb6241037db50729d81568451ebfdbde8"),
|
||||
@@ -378,9 +397,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test")
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
|
||||
@@ -411,9 +428,7 @@ func TestCreateMergeParameters(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test")
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
|
||||
@@ -436,10 +451,7 @@ func TestCreateMergeParameters(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test", "test2")
|
||||
|
||||
// Display contents of each blob in the directory
|
||||
blobDir := filepath.Join(p, "blobs")
|
||||
@@ -495,10 +507,7 @@ func TestCreateMergeParameters(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test", "test2")
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"),
|
||||
@@ -555,9 +564,7 @@ func TestCreateReplacesMessages(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test")
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
|
||||
@@ -589,10 +596,7 @@ func TestCreateReplacesMessages(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test", "test2")
|
||||
|
||||
// Old layers will not have been pruned
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
@@ -650,9 +654,7 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test")
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-0a04d979734167da3b80811a1874d734697f366a689f3912589b99d2e86e7ad1"),
|
||||
@@ -850,9 +852,7 @@ func TestCreateLicenses(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test")
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"),
|
||||
|
||||
@@ -42,10 +42,7 @@ func TestDelete(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test", "test2")
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
|
||||
@@ -60,9 +57,7 @@ func TestDelete(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||
})
|
||||
checkManifestFiles(t, "test2")
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
|
||||
@@ -76,7 +71,7 @@ func TestDelete(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
checkManifestFiles(t)
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{})
|
||||
}
|
||||
|
||||
@@ -109,7 +104,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
||||
t.Errorf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
checkManifestFiles(t)
|
||||
}
|
||||
|
||||
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
|
||||
@@ -129,14 +124,12 @@ func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"),
|
||||
})
|
||||
checkManifestFiles(t, "gpt-oss:20b-cloud")
|
||||
|
||||
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
checkManifestFiles(t)
|
||||
}
|
||||
|
||||
@@ -2108,6 +2108,132 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestChatFormatWithThinkFalse verifies that when a model uses a builtin
|
||||
// parser that supports thinking (e.g. gemma4) and the request explicitly
|
||||
// disables thinking (think=false), the format constraint is passed to the
|
||||
// first and only completion call. Previously, format was deferred for all
|
||||
// thinking-capable parsers and only re-applied after an end-of-thinking
|
||||
// transition — a transition that never fires when thinking is off. See
|
||||
// https://github.com/ollama/ollama/issues/15260.
|
||||
func TestChatFormatWithThinkFalse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := &mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
waitForRecovery: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{llama: mock}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
// Use the gemma4 builtin parser — it reports HasThinkingSupport=true, which
|
||||
// adds CapabilityThinking to the model and previously triggered deferral of
|
||||
// the format even when the user passed think=false.
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-gemma4-parser",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Parser: "gemma4",
|
||||
Template: `{{- range .Messages }}{{ .Role }}: {{ .Content }}{{ end }}`,
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("create: expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}},"required":["answer"]}`)
|
||||
|
||||
var (
|
||||
requestsMu sync.Mutex
|
||||
requests []llm.CompletionRequest
|
||||
)
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
requestsMu.Lock()
|
||||
requests = append(requests, r)
|
||||
requestsMu.Unlock()
|
||||
|
||||
fn(llm.CompletionResponse{
|
||||
Content: `{"answer":"42"}`,
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
streamRequest := false
|
||||
think := false
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-gemma4-parser",
|
||||
Messages: []api.Message{{Role: "user", Content: "Respond in JSON."}},
|
||||
Think: &api.ThinkValue{Value: think},
|
||||
Stream: &streamRequest,
|
||||
Format: format,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("chat: expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
if len(requests) != 1 {
|
||||
t.Fatalf("expected a single completion call, got %d", len(requests))
|
||||
}
|
||||
|
||||
if !bytes.Equal([]byte(format), []byte(requests[0].Format)) {
|
||||
t.Errorf("expected first completion format to match the request format, got %q", string(requests[0].Format))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateUnload(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -658,11 +658,14 @@ func TestManifestCaseSensitivity(t *testing.T) {
|
||||
checkManifestList := func() {
|
||||
t.Helper()
|
||||
|
||||
mandir := filepath.Join(os.Getenv("OLLAMA_MODELS"), "manifests/")
|
||||
mandir, err := manifest.V2Path()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to resolve v2 manifest path: %v", err)
|
||||
}
|
||||
var entries []string
|
||||
t.Logf("dir entries:")
|
||||
fsys := os.DirFS(mandir)
|
||||
err := fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
|
||||
err = fs.WalkDir(fsys, ".", func(path string, info fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -685,7 +688,14 @@ func TestManifestCaseSensitivity(t *testing.T) {
|
||||
|
||||
g := entries[0] // raw path
|
||||
g = filepath.ToSlash(g)
|
||||
w := model.ParseName(wantStableName).Filepath()
|
||||
wp, err := manifest.V2PathForName(model.ParseName(wantStableName))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to resolve expected manifest path: %v", err)
|
||||
}
|
||||
w, err := filepath.Rel(mandir, wp)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make expected manifest path relative: %v", err)
|
||||
}
|
||||
w = filepath.ToSlash(w)
|
||||
if g != w {
|
||||
t.Errorf("\ngot: %s\nwant: %s", g, w)
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
rootmanifest "github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// ManifestLayer represents a layer in the manifest.
|
||||
@@ -49,9 +51,7 @@ func DefaultManifestDir() string {
|
||||
// LoadManifest loads a manifest for the given model name.
|
||||
// Model name format: "modelname" or "modelname:tag" or "host/namespace/name:tag"
|
||||
func LoadManifest(modelName string) (*ModelManifest, error) {
|
||||
manifestPath := resolveManifestPath(modelName)
|
||||
|
||||
data, err := os.ReadFile(manifestPath)
|
||||
data, err := rootmanifest.ReadManifestData(model.ParseName(modelName))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read manifest: %w", err)
|
||||
}
|
||||
@@ -67,36 +67,6 @@ func LoadManifest(modelName string) (*ModelManifest, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// resolveManifestPath converts a model name to a manifest file path.
|
||||
func resolveManifestPath(modelName string) string {
|
||||
// Parse model name into components
|
||||
// Default: registry.ollama.ai/library/<name>/<tag>
|
||||
host := "registry.ollama.ai"
|
||||
namespace := "library"
|
||||
name := modelName
|
||||
tag := "latest"
|
||||
|
||||
// Handle explicit tag
|
||||
if idx := strings.LastIndex(name, ":"); idx != -1 {
|
||||
tag = name[idx+1:]
|
||||
name = name[:idx]
|
||||
}
|
||||
|
||||
// Handle full path like "host/namespace/name"
|
||||
parts := strings.Split(name, "/")
|
||||
switch len(parts) {
|
||||
case 3:
|
||||
host = parts[0]
|
||||
namespace = parts[1]
|
||||
name = parts[2]
|
||||
case 2:
|
||||
namespace = parts[0]
|
||||
name = parts[1]
|
||||
}
|
||||
|
||||
return filepath.Join(DefaultManifestDir(), host, namespace, name, tag)
|
||||
}
|
||||
|
||||
// BlobPath returns the full path to a blob given its digest.
|
||||
func (m *ModelManifest) BlobPath(digest string) string {
|
||||
// Convert "sha256:abc123" to "sha256-abc123"
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package manifest
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
rootmanifest "github.com/ollama/ollama/manifest"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestTotalTensorSize(t *testing.T) {
|
||||
@@ -55,3 +59,39 @@ func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
|
||||
t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadManifestPrefersV2(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
name := model.ParseName("example")
|
||||
|
||||
legacyPath, err := rootmanifest.PathForName(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(legacyPath), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(legacyPath, []byte(`{"schemaVersion":2,"mediaType":"legacy"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
v2Path, err := rootmanifest.V2PathForName(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(v2Path), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(v2Path, []byte(`{"schemaVersion":2,"mediaType":"v2"}`), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m, err := LoadManifest(name.String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if m.Manifest.MediaType != "v2" {
|
||||
t.Fatalf("media type = %q, want %q", m.Manifest.MediaType, "v2")
|
||||
}
|
||||
}
|
||||
|
||||
5
x/mlxrunner/cache/cache.go
vendored
5
x/mlxrunner/cache/cache.go
vendored
@@ -337,9 +337,10 @@ func (c *RotatingKVCache) State() []*mlx.Array {
|
||||
if c.keys == nil || c.values == nil {
|
||||
return nil
|
||||
}
|
||||
liveLen := min(c.offset, c.keys.Dim(2))
|
||||
return []*mlx.Array{
|
||||
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -151,20 +151,11 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
|
||||
type completionRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Options *completionOpts `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type completionOpts struct {
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
TopP float32 `json:"top_p,omitempty"`
|
||||
MinP float32 `json:"min_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
RepeatLastN int `json:"repeat_last_n,omitempty"`
|
||||
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
type CompletionRequest struct {
|
||||
Prompt string
|
||||
Options api.Options
|
||||
Logprobs bool
|
||||
TopLogprobs int
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
@@ -177,6 +168,8 @@ type CompletionResponse struct {
|
||||
EvalCount int
|
||||
EvalDuration time.Duration
|
||||
|
||||
Logprobs []llm.Logprob
|
||||
|
||||
Error *api.StatusError
|
||||
}
|
||||
|
||||
@@ -201,19 +194,13 @@ func (c *Client) Close() error {
|
||||
|
||||
// Completion implements llm.LlamaServer.
|
||||
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
creq := completionRequest{
|
||||
Prompt: req.Prompt,
|
||||
creq := CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Logprobs: req.Logprobs,
|
||||
TopLogprobs: req.TopLogprobs,
|
||||
}
|
||||
if req.Options != nil {
|
||||
creq.Options = &completionOpts{
|
||||
Temperature: req.Options.Temperature,
|
||||
TopP: req.Options.TopP,
|
||||
MinP: req.Options.MinP,
|
||||
TopK: req.Options.TopK,
|
||||
RepeatLastN: req.Options.RepeatLastN,
|
||||
PresencePenalty: req.Options.PresencePenalty,
|
||||
NumPredict: req.Options.NumPredict,
|
||||
}
|
||||
creq.Options = *req.Options
|
||||
}
|
||||
|
||||
body, err := json.Marshal(creq)
|
||||
@@ -262,6 +249,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
||||
PromptEvalDuration: raw.PromptEvalDuration,
|
||||
EvalCount: raw.EvalCount,
|
||||
EvalDuration: raw.EvalDuration,
|
||||
Logprobs: raw.Logprobs,
|
||||
}
|
||||
|
||||
fn(cresp)
|
||||
|
||||
@@ -62,3 +62,25 @@ var LogitSoftcap = Compile2(
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// sigmoidRouterFused traces the DeepSeek-V2 / GLM-MoE aux-loss-free router
|
||||
// head. Two outputs are returned so the pre-bias sigmoid (used to gather
|
||||
// per-expert scores after top-k) and the post-bias negation (used as the
|
||||
// argpartition key for top-k) share a single kernel.
|
||||
var sigmoidRouterFused = Compile(
|
||||
"SigmoidRouter",
|
||||
func(in ...*Array) []*Array {
|
||||
gates, bias := in[0], in[1]
|
||||
orig := gates.Sigmoid()
|
||||
neg := orig.Add(bias).Negative()
|
||||
return []*Array{orig, neg}
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SigmoidRouter returns (sigmoid(gates), -(sigmoid(gates)+bias)) as a fused
|
||||
// kernel — the DeepSeek-V2 / GLM-MoE aux-loss-free router head.
|
||||
func SigmoidRouter(gates, bias *Array) (origScores, negScores *Array) {
|
||||
out := sigmoidRouterFused(gates, bias)
|
||||
return out[0], out[1]
|
||||
}
|
||||
|
||||
@@ -238,6 +238,9 @@ func (t Array) Float() float64 {
|
||||
}
|
||||
|
||||
func (t Array) Ints() []int {
|
||||
if dt := t.DType(); dt != DTypeInt32 {
|
||||
panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt))
|
||||
}
|
||||
ints := make([]int, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
|
||||
ints[i] = int(f)
|
||||
@@ -246,6 +249,9 @@ func (t Array) Ints() []int {
|
||||
}
|
||||
|
||||
func (t Array) Floats() []float32 {
|
||||
if dt := t.DType(); dt != DTypeFloat32 {
|
||||
panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt))
|
||||
}
|
||||
floats := make([]float32, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
|
||||
floats[i] = float32(f)
|
||||
|
||||
@@ -139,6 +139,12 @@ func (t *Array) Less(other *Array) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
|
||||
out := New("MAX_AXIS")
|
||||
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Matmul(other *Array) *Array {
|
||||
out := New("MATMUL")
|
||||
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
@@ -169,6 +175,12 @@ func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ScatterAddAxis(indices, values *Array, axis int) *Array {
|
||||
out := New("SCATTER_ADD_AXIS")
|
||||
C.mlx_scatter_add_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Reshape(axes ...int) *Array {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
for i := range axes {
|
||||
|
||||
@@ -7,11 +7,15 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func prefillChunkSize() int {
|
||||
@@ -25,17 +29,14 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
mlx.ResetPeakMemory()
|
||||
ctx := request.Ctx
|
||||
var (
|
||||
sample, logprobs *mlx.Array
|
||||
nextSample, nextLogprobs *mlx.Array
|
||||
)
|
||||
var sample, nextSample sampler.Result
|
||||
|
||||
defer func() {
|
||||
if request.Sampler != nil {
|
||||
request.Sampler.Free()
|
||||
}
|
||||
mlx.Unpin(sample, logprobs)
|
||||
mlx.Unpin(nextSample, nextLogprobs)
|
||||
mlx.Unpin(sample.Arrays()...)
|
||||
mlx.Unpin(nextSample.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.ClearCache()
|
||||
|
||||
@@ -60,10 +61,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
|
||||
// Cap generation to stay within the model's context length
|
||||
maxGenerate := r.contextLength - len(inputs)
|
||||
if request.Options.MaxTokens <= 0 {
|
||||
request.Options.MaxTokens = maxGenerate
|
||||
if request.Options.NumPredict <= 0 {
|
||||
request.Options.NumPredict = maxGenerate
|
||||
} else {
|
||||
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
|
||||
request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate)
|
||||
}
|
||||
|
||||
request.Sampler.ResetHistory(inputs)
|
||||
@@ -135,41 +136,38 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
|
||||
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
step := func(token *mlx.Array) sampler.Result {
|
||||
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
||||
logits := r.Model.Unembed(fwd)
|
||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
|
||||
logprobs := logits.Subtract(logits.Logsumexp(true))
|
||||
sample := request.Sampler.Sample(logprobs)
|
||||
|
||||
mlx.Pin(sample, logprobs)
|
||||
sample := request.Sampler.Sample(logits)
|
||||
mlx.Pin(sample.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(sample, logprobs)
|
||||
|
||||
return sample, logprobs
|
||||
mlx.AsyncEval(sample.Arrays()...)
|
||||
return sample
|
||||
}
|
||||
|
||||
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
|
||||
sample = step(mlx.FromValues(tokens[processed:], total-processed))
|
||||
|
||||
var b bytes.Buffer
|
||||
dec := decoder{tokenizer: r.Tokenizer}
|
||||
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
|
||||
for i := range request.Options.MaxTokens {
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1}
|
||||
for i := range request.Options.NumPredict {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request.Sampler.AppendToken(sample)
|
||||
nextSample, nextLogprobs = step(sample)
|
||||
request.Sampler.AppendToken(sample.Token)
|
||||
nextSample = step(sample.Token)
|
||||
|
||||
if i == 0 {
|
||||
mlx.Eval(sample)
|
||||
mlx.Eval(sample.Arrays()...)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
output := int32(sample.Int())
|
||||
output := int32(sample.Token.Int())
|
||||
session.outputs = append(session.outputs, output)
|
||||
|
||||
if r.Tokenizer.IsEOS(output) {
|
||||
@@ -178,17 +176,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case request.Responses <- CompletionResponse{
|
||||
Content: r.Decode(output, &b),
|
||||
}:
|
||||
if resp, ok := dec.decode(sample); ok {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case request.Responses <- resp:
|
||||
}
|
||||
}
|
||||
|
||||
mlx.Unpin(sample, logprobs)
|
||||
sample, logprobs = nextSample, nextLogprobs
|
||||
nextSample, nextLogprobs = nil, nil
|
||||
mlx.Unpin(sample.Arrays()...)
|
||||
sample, nextSample = nextSample, sampler.Result{}
|
||||
|
||||
if i%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
@@ -204,13 +201,57 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
|
||||
token := r.Tokenizer.Decode([]int32{sample})
|
||||
// decoder serializes sampled tokens into response chunks, holding bytes
|
||||
// whose UTF-8 sequence hasn't completed yet and the logprobs that belong
|
||||
// with those bytes so Content and Logprobs stay aligned when a chunk does
|
||||
// flush.
|
||||
type decoder struct {
|
||||
tokenizer *tokenizer.Tokenizer
|
||||
buf bytes.Buffer
|
||||
logprobs []llm.Logprob
|
||||
}
|
||||
|
||||
if _, err := b.WriteString(token); err != nil {
|
||||
slog.Error("Failed to write token to buffer", "error", err)
|
||||
return ""
|
||||
func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
|
||||
output := int32(res.Token.Int())
|
||||
d.buf.WriteString(d.tokenizer.Decode([]int32{output}))
|
||||
d.logprobs = append(d.logprobs, buildLogprob(res, d.tokenizer.Decode)...)
|
||||
|
||||
content := flushValidUTF8Prefix(&d.buf)
|
||||
if content == "" {
|
||||
return CompletionResponse{}, false
|
||||
}
|
||||
resp := CompletionResponse{Content: content, Logprobs: d.logprobs}
|
||||
d.logprobs = nil
|
||||
return resp, true
|
||||
}
|
||||
|
||||
func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logprob {
|
||||
if sample.Logprob == nil {
|
||||
return nil
|
||||
}
|
||||
tok := func(id int32) string { return decode([]int32{id}) }
|
||||
|
||||
out := llm.Logprob{
|
||||
TokenLogprob: llm.TokenLogprob{
|
||||
Token: tok(int32(sample.Token.Int())),
|
||||
Logprob: float64(sample.Logprob.Floats()[0]),
|
||||
},
|
||||
}
|
||||
|
||||
return flushValidUTF8Prefix(b)
|
||||
if sample.TopTokens != nil {
|
||||
ids := sample.TopTokens.Ints()
|
||||
vals := sample.TopLogprobs.Floats()
|
||||
pairs := make([]llm.TokenLogprob, len(ids))
|
||||
for i, id := range ids {
|
||||
pairs[i] = llm.TokenLogprob{
|
||||
Token: tok(int32(id)),
|
||||
Logprob: float64(vals[i]),
|
||||
}
|
||||
}
|
||||
sort.Slice(pairs, func(i, j int) bool {
|
||||
return pairs[i].Logprob > pairs[j].Logprob
|
||||
})
|
||||
out.TopLogprobs = pairs
|
||||
}
|
||||
return []llm.Logprob{out}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
TextCompletionsRequest
|
||||
CompletionRequest
|
||||
Responses chan CompletionResponse
|
||||
Pipeline func(Request) error
|
||||
|
||||
@@ -28,22 +28,6 @@ type Request struct {
|
||||
Sampler *sample.Sampler
|
||||
}
|
||||
|
||||
type TextCompletionsRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Options struct {
|
||||
Temperature float32 `json:"temperature"`
|
||||
TopP float32 `json:"top_p"`
|
||||
MinP float32 `json:"min_p"`
|
||||
TopK int `json:"top_k"`
|
||||
RepeatLastN int `json:"repeat_last_n"`
|
||||
PresencePenalty float32 `json:"presence_penalty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
|
||||
// Deprecated: use MaxTokens instead
|
||||
NumPredict int `json:"num_predict"`
|
||||
} `json:"options"`
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
Model base.Model
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
|
||||
249
x/mlxrunner/sample/logprob_test.go
Normal file
249
x/mlxrunner/sample/logprob_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
//go:build mlx
|
||||
|
||||
package sample
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// logprobEntry is the (token id, logprob) pair returned by the sampler's
|
||||
// top-K extraction, used after the test-side descending sort.
|
||||
type logprobEntry struct {
|
||||
id int
|
||||
logprob float64
|
||||
}
|
||||
|
||||
// runSampleLogprobs drives Sample on a fresh Sampler configured for logprobs
|
||||
// and returns the greedily-sampled token id, its logprob, and the top-K
|
||||
// entries sorted descending by logprob. Logits must be a [vocab]-shaped
|
||||
// slice; the helper reshapes it to [1, vocab] before calling the sampler.
|
||||
func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, []logprobEntry) {
|
||||
t.Helper()
|
||||
|
||||
s := New(Options{Logprobs: true, TopLogprobs: topK})
|
||||
defer func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
}()
|
||||
|
||||
tensor := mlx.FromValues(logits, 1, len(logits))
|
||||
res := s.Sample(tensor)
|
||||
|
||||
mlx.Pin(res.Arrays()...)
|
||||
defer mlx.Unpin(res.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.Eval(res.Arrays()...)
|
||||
|
||||
selected := res.Token.Int()
|
||||
selLP := float64(res.Logprob.Floats()[0])
|
||||
|
||||
var top []logprobEntry
|
||||
if topK > 0 && res.TopTokens != nil {
|
||||
ids := res.TopTokens.Ints()
|
||||
vals := res.TopLogprobs.Floats()
|
||||
top = make([]logprobEntry, len(ids))
|
||||
for i, id := range ids {
|
||||
top[i] = logprobEntry{id: id, logprob: float64(vals[i])}
|
||||
}
|
||||
sort.Slice(top, func(i, j int) bool { return top[i].logprob > top[j].logprob })
|
||||
}
|
||||
return selected, selLP, top
|
||||
}
|
||||
|
||||
func TestSampleLogprobsBasic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
logits []float32
|
||||
topK int
|
||||
wantSelectedID int
|
||||
wantTopLen int
|
||||
}{
|
||||
{
|
||||
name: "single token without top logprobs",
|
||||
logits: []float32{1.0, 0.5, 0.3, 0.1},
|
||||
topK: 0,
|
||||
wantSelectedID: 0,
|
||||
wantTopLen: 0,
|
||||
},
|
||||
{
|
||||
name: "single token with top logprobs",
|
||||
logits: []float32{1.0, 0.5, 0.3, 0.1},
|
||||
topK: 3,
|
||||
wantSelectedID: 0,
|
||||
wantTopLen: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
selected, _, top := runSampleLogprobs(t, tt.logits, tt.topK)
|
||||
if selected != tt.wantSelectedID {
|
||||
t.Errorf("selected = %d, want %d", selected, tt.wantSelectedID)
|
||||
}
|
||||
if len(top) != tt.wantTopLen {
|
||||
t.Errorf("top-K length = %d, want %d", len(top), tt.wantTopLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleLogprobsNumericalStability(t *testing.T) {
|
||||
logits := []float32{1000.0, 999.0, 998.0}
|
||||
_, selLP, top := runSampleLogprobs(t, logits, 3)
|
||||
|
||||
if math.IsInf(selLP, 0) || math.IsNaN(selLP) {
|
||||
t.Errorf("selected logprob is not finite: %f", selLP)
|
||||
}
|
||||
for i, e := range top {
|
||||
if math.IsInf(e.logprob, 0) || math.IsNaN(e.logprob) {
|
||||
t.Errorf("top[%d] logprob is not finite: %f", i, e.logprob)
|
||||
}
|
||||
}
|
||||
for i := 1; i < len(top); i++ {
|
||||
if top[i].logprob > top[i-1].logprob {
|
||||
t.Errorf("top logprobs not descending: %f > %f", top[i].logprob, top[i-1].logprob)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleLogprobsProbabilityCorrectness(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
logits []float32
|
||||
}{
|
||||
{"uniform", []float32{1.0, 1.0, 1.0, 1.0}},
|
||||
{"different", []float32{2.0, 1.0, 0.5, 0.1}},
|
||||
{"negative", []float32{-1.0, -2.0, -3.0, -4.0}},
|
||||
{"mixed", []float32{5.0, -5.0, 0.0, 2.5}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
selected, selLP, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
|
||||
|
||||
if selLP > 0 {
|
||||
t.Errorf("selected logprob should be <= 0, got %f", selLP)
|
||||
}
|
||||
for i, e := range top {
|
||||
if e.logprob > 0 {
|
||||
t.Errorf("top[%d] logprob should be <= 0, got %f", i, e.logprob)
|
||||
}
|
||||
}
|
||||
|
||||
if tt.name == "uniform" {
|
||||
want := 1.0 / float64(len(tt.logits))
|
||||
got := math.Exp(selLP)
|
||||
if math.Abs(got-want) > 1e-6 {
|
||||
t.Errorf("uniform logits: selected prob = %f, want %f", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 1; i < len(top); i++ {
|
||||
if top[i].logprob > top[i-1].logprob {
|
||||
t.Errorf("top logprobs not descending at %d: %f > %f",
|
||||
i, top[i].logprob, top[i-1].logprob)
|
||||
}
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, e := range top {
|
||||
if e.id == selected {
|
||||
found = true
|
||||
if math.Abs(e.logprob-selLP) > 1e-6 {
|
||||
t.Errorf("selected logprob mismatch: selLP=%f top=%f", selLP, e.logprob)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("selected token %d not present in top-K", selected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleLogprobsSoftmaxCorrectness(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
logits []float32
|
||||
}{
|
||||
{"small vocabulary", []float32{1.0, 2.0, 3.0}},
|
||||
{"large differences", []float32{10.0, 0.0, -10.0}},
|
||||
{"all equal", []float32{5.0, 5.0, 5.0, 5.0, 5.0}},
|
||||
{"very large values", []float32{500.0, 499.0, 498.0}},
|
||||
{"very small values", []float32{-500.0, -499.0, -498.0}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, _, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
|
||||
if len(top) != len(tt.logits) {
|
||||
t.Fatalf("top-K length = %d, want %d", len(top), len(tt.logits))
|
||||
}
|
||||
|
||||
var sum float64
|
||||
for _, e := range top {
|
||||
p := math.Exp(e.logprob)
|
||||
if p < 0 || p > 1 {
|
||||
t.Errorf("token %d: probability %f out of [0,1]", e.id, p)
|
||||
}
|
||||
sum += p
|
||||
}
|
||||
|
||||
if math.Abs(sum-1.0) > 1e-5 {
|
||||
t.Errorf("probabilities sum = %f, want 1.0", sum)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleLogprobsSelectedTokenCorrectness(t *testing.T) {
|
||||
logits := []float32{3.0, 1.0, 2.0, 0.5}
|
||||
|
||||
maxIdx := 0
|
||||
for i, v := range logits[1:] {
|
||||
if v > logits[maxIdx] {
|
||||
maxIdx = i + 1
|
||||
}
|
||||
}
|
||||
|
||||
selected, selLP, top := runSampleLogprobs(t, logits, len(logits))
|
||||
|
||||
if selected != maxIdx {
|
||||
t.Errorf("selected = %d, want argmax %d", selected, maxIdx)
|
||||
}
|
||||
|
||||
if top[0].id != maxIdx {
|
||||
t.Errorf("top[0].id = %d, want argmax %d", top[0].id, maxIdx)
|
||||
}
|
||||
if math.Abs(top[0].logprob-selLP) > 1e-6 {
|
||||
t.Errorf("top[0].logprob = %f, want selected %f", top[0].logprob, selLP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleLogprobsTopKOrdering(t *testing.T) {
|
||||
// Logits chosen so argmax order differs from index order.
|
||||
logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0}
|
||||
wantOrder := []int{1, 3, 4, 0, 2}
|
||||
|
||||
_, _, top := runSampleLogprobs(t, logits, len(logits))
|
||||
|
||||
if len(top) != len(wantOrder) {
|
||||
t.Fatalf("top-K length = %d, want %d", len(top), len(wantOrder))
|
||||
}
|
||||
for i, e := range top {
|
||||
if e.id != wantOrder[i] {
|
||||
t.Errorf("top[%d].id = %d, want %d", i, e.id, wantOrder[i])
|
||||
}
|
||||
}
|
||||
for i := 1; i < len(top); i++ {
|
||||
if top[i].logprob > top[i-1].logprob {
|
||||
t.Errorf("top[%d].logprob (%f) > top[%d].logprob (%f)",
|
||||
i, top[i].logprob, i-1, top[i-1].logprob)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8,47 +8,76 @@ import (
|
||||
|
||||
type Transform func(*Sampler, *mlx.Array) *mlx.Array
|
||||
|
||||
type Options struct {
|
||||
Temperature float32
|
||||
TopP float32
|
||||
MinP float32
|
||||
TopK int
|
||||
RepeatLastN int
|
||||
RepeatPenalty float32
|
||||
PresencePenalty float32
|
||||
FrequencyPenalty float32
|
||||
|
||||
// Logprobs causes Sample to populate Result.Logprob with the selected
|
||||
// token's log-probability. TopLogprobs (when > 0) adds top-K pairs.
|
||||
Logprobs bool
|
||||
TopLogprobs int
|
||||
}
|
||||
|
||||
type Sampler struct {
|
||||
Temperature float32
|
||||
TopP float32
|
||||
MinP float32
|
||||
TopK int
|
||||
RepeatLastN int
|
||||
PresencePenalty float32
|
||||
Options
|
||||
|
||||
history *mlx.Array
|
||||
historyLen int
|
||||
transforms []Transform
|
||||
}
|
||||
|
||||
func New(temp, top_p, min_p float32, top_k, repeatLastN int, presencePenalty float32) *Sampler {
|
||||
s := &Sampler{
|
||||
Temperature: temp,
|
||||
TopP: top_p,
|
||||
MinP: min_p,
|
||||
TopK: top_k,
|
||||
RepeatLastN: repeatLastN,
|
||||
PresencePenalty: presencePenalty,
|
||||
// Result bundles the outputs of one decode step. The logprob tensors are
|
||||
// populated only when the sampler is configured to report them.
|
||||
type Result struct {
|
||||
Token *mlx.Array // sampled token id, shape [B]
|
||||
Logprob *mlx.Array // sampled-token logprob, shape [B,1]; nil unless Logprobs
|
||||
TopTokens *mlx.Array // top-K token ids, shape [B,K]; nil unless TopLogprobs > 0
|
||||
TopLogprobs *mlx.Array // top-K logprobs, shape [B,K]; nil unless TopLogprobs > 0
|
||||
}
|
||||
|
||||
// Arrays returns the tensor fields as a slice so callers can drive the mlx
|
||||
// lifecycle verbs (Pin, Unpin, Eval, AsyncEval) over the whole group. Unset
|
||||
// fields stay nil; the mlx helpers skip them.
|
||||
func (r Result) Arrays() []*mlx.Array {
|
||||
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
|
||||
}
|
||||
|
||||
func New(opts Options) *Sampler {
|
||||
if opts.RepeatPenalty <= 0 {
|
||||
opts.RepeatPenalty = 1
|
||||
}
|
||||
|
||||
s := &Sampler{Options: opts}
|
||||
|
||||
var transforms []Transform
|
||||
if presencePenalty != 0 {
|
||||
if s.usesHistory() {
|
||||
transforms = append(transforms, penalty)
|
||||
}
|
||||
|
||||
if top_p > 0 && top_p < 1 {
|
||||
transforms = append(transforms, topP)
|
||||
}
|
||||
|
||||
if min_p != 0 {
|
||||
transforms = append(transforms, minP)
|
||||
}
|
||||
|
||||
if top_k > 0 {
|
||||
hasTopP := opts.TopP > 0 && opts.TopP < 1
|
||||
hasTopK := opts.TopK > 0
|
||||
switch {
|
||||
case hasTopP:
|
||||
// topKTopP always does a full descending sort for the top-P
|
||||
// cumulative mask and opportunistically masks top-K during the
|
||||
// same pass when it is also configured.
|
||||
transforms = append(transforms, topKTopP)
|
||||
case hasTopK:
|
||||
// Argpartition (partial sort) is cheaper than a full sort.
|
||||
transforms = append(transforms, topK)
|
||||
}
|
||||
|
||||
if temp == 0 {
|
||||
if opts.MinP != 0 {
|
||||
transforms = append(transforms, minP)
|
||||
}
|
||||
|
||||
if opts.Temperature == 0 {
|
||||
transforms = append(transforms, greedy)
|
||||
} else {
|
||||
transforms = append(transforms, temperature)
|
||||
@@ -59,7 +88,7 @@ func New(temp, top_p, min_p float32, top_k, repeatLastN int, presencePenalty flo
|
||||
}
|
||||
|
||||
func (s *Sampler) usesHistory() bool {
|
||||
return s.PresencePenalty != 0
|
||||
return s.RepeatPenalty != 1 || s.PresencePenalty != 0 || s.FrequencyPenalty != 0
|
||||
}
|
||||
|
||||
func (s *Sampler) setHistory(history *mlx.Array, historyLen int) {
|
||||
@@ -115,75 +144,138 @@ func (s *Sampler) Free() {
|
||||
s.setHistory(nil, 0)
|
||||
}
|
||||
|
||||
func (s *Sampler) Sample(logits *mlx.Array) *mlx.Array {
|
||||
// Sample runs the configured transform chain on the raw per-token logits
|
||||
// and returns the sampled token id plus, when configured, the reported
|
||||
// log-probability tensors for the selected token and the top-K tokens.
|
||||
func (s *Sampler) Sample(logits *mlx.Array) Result {
|
||||
scores := logits
|
||||
for _, transform := range s.transforms {
|
||||
logits = transform(s, logits)
|
||||
scores = transform(s, scores)
|
||||
}
|
||||
return logits
|
||||
}
|
||||
res := Result{Token: scores}
|
||||
|
||||
func greedy(_ *Sampler, logits *mlx.Array) *mlx.Array {
|
||||
return logits.Argmax(-1, false)
|
||||
}
|
||||
|
||||
func temperature(s *Sampler, logits *mlx.Array) *mlx.Array {
|
||||
return mlx.DivScalar(logits, s.Temperature).Categorical(-1)
|
||||
}
|
||||
|
||||
func topP(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
||||
if s.TopP <= 0 || s.TopP >= 1 {
|
||||
return logprobs
|
||||
if s.Logprobs {
|
||||
// Compute log_softmax in fp32 and subtract the max before
|
||||
// logsumexp so the final subtraction stays on small values.
|
||||
// Otherwise it cancels two large numbers and loses precision.
|
||||
lp := logits.AsType(mlx.DTypeFloat32)
|
||||
lp = lp.Subtract(lp.MaxAxis(-1, true))
|
||||
lp = lp.Subtract(lp.Logsumexp(true))
|
||||
res.Logprob = lp.TakeAlongAxis(res.Token.ExpandDims(-1), -1)
|
||||
if k := s.TopLogprobs; k > 0 {
|
||||
if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
|
||||
k = vocab
|
||||
}
|
||||
// Argpartition on the negated values places the K largest
|
||||
// (unsorted) in positions [0:K].
|
||||
idx := lp.Negative().ArgpartitionAxis(k-1, -1).Slice(mlx.Slice(), mlx.Slice(0, k))
|
||||
res.TopTokens = idx.AsType(mlx.DTypeInt32)
|
||||
res.TopLogprobs = lp.TakeAlongAxis(idx, -1)
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
order := logprobs.Negative().ArgsortAxis(-1)
|
||||
sortedLogprobs := logprobs.TakeAlongAxis(order, -1)
|
||||
sortedProbs := mlx.SoftmaxAxis(sortedLogprobs, -1, true)
|
||||
prevCumProbs := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
|
||||
func greedy(_ *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
return scores.Argmax(-1, false)
|
||||
}
|
||||
|
||||
func temperature(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
return mlx.DivScalar(scores, s.Temperature).Categorical(-1)
|
||||
}
|
||||
|
||||
// topKTopP applies top-P in a descending sort pass and, when top-K is also
|
||||
// configured, masks any surviving value below the K-th largest in the same
|
||||
// pass. Callers dispatch here whenever top-P is enabled — the top-K-only
|
||||
// case uses a cheaper partial sort via the topK transform.
|
||||
func topKTopP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
vocab := scores.Dim(scores.NumDims() - 1)
|
||||
applyTopK := s.TopK > 0 && s.TopK < vocab
|
||||
|
||||
order := scores.Negative().ArgsortAxis(-1)
|
||||
sorted := scores.TakeAlongAxis(order, -1)
|
||||
negInf := mlx.FromValue(float32(math.Inf(-1)))
|
||||
|
||||
// Top-P: in descending order, keep tokens whose exclusive cumulative
|
||||
// probability is still below s.TopP.
|
||||
probs := mlx.SoftmaxAxis(sorted, -1, true)
|
||||
prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs)
|
||||
keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
|
||||
filtered := mlx.Where(keep, sortedLogprobs, mlx.FromValue(float32(math.Inf(-1))))
|
||||
return logprobs.PutAlongAxis(order, filtered, -1)
|
||||
}
|
||||
sorted = mlx.Where(keep, sorted, negInf)
|
||||
|
||||
func minP(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
||||
if s.MinP <= 0 || s.MinP > 1 {
|
||||
return logprobs
|
||||
out := scores.PutAlongAxis(order, sorted, -1)
|
||||
|
||||
// Top-K: sorted is already in descending order, so positions [K, V)
|
||||
// are the ones to drop. Scatter -inf through their original-layout
|
||||
// indices (order[K:]). Positional (not value-based) so exactly K
|
||||
// tokens survive — ties at the K-th logit get broken by the sort
|
||||
// order rather than promoted through the filter.
|
||||
if applyTopK {
|
||||
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
||||
out = out.PutAlongAxis(dropOrder, negInf, -1)
|
||||
}
|
||||
|
||||
maxLogprobs := logprobs.TakeAlongAxis(logprobs.Argmax(-1, true), -1)
|
||||
minLogprobs := mlx.AddScalar(maxLogprobs, float32(math.Log(float64(s.MinP))))
|
||||
return out
|
||||
}
|
||||
|
||||
func minP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
if s.MinP <= 0 || s.MinP > 1 {
|
||||
return scores
|
||||
}
|
||||
|
||||
maxScore := scores.MaxAxis(-1, true)
|
||||
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(s.MinP))))
|
||||
|
||||
return mlx.Where(
|
||||
logprobs.Less(minLogprobs),
|
||||
scores.Less(threshold),
|
||||
mlx.FromValue(float32(math.Inf(-1))),
|
||||
logprobs,
|
||||
scores,
|
||||
)
|
||||
}
|
||||
|
||||
func topK(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
||||
func topK(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
if s.TopK <= 0 {
|
||||
return logprobs
|
||||
return scores
|
||||
}
|
||||
|
||||
vocab := logprobs.Dim(logprobs.NumDims() - 1)
|
||||
vocab := scores.Dim(scores.NumDims() - 1)
|
||||
if s.TopK >= vocab {
|
||||
return logprobs
|
||||
return scores
|
||||
}
|
||||
|
||||
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
||||
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
||||
mask := scores.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
||||
return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
||||
}
|
||||
|
||||
func penalty(s *Sampler, logprobs *mlx.Array) *mlx.Array {
|
||||
if s.history == nil || s.historyLen == 0 || s.PresencePenalty == 0 {
|
||||
return logprobs
|
||||
func penalty(s *Sampler, scores *mlx.Array) *mlx.Array {
|
||||
if s.historyLen == 0 {
|
||||
return scores
|
||||
}
|
||||
|
||||
tokenIndices := s.history
|
||||
if logprobs.NumDims() > 1 {
|
||||
if scores.NumDims() > 1 {
|
||||
tokenIndices = tokenIndices.ExpandDims(0)
|
||||
}
|
||||
|
||||
selected := logprobs.TakeAlongAxis(tokenIndices, -1)
|
||||
adjusted := mlx.AddScalar(selected, -s.PresencePenalty)
|
||||
return logprobs.PutAlongAxis(tokenIndices, adjusted, -1)
|
||||
if s.RepeatPenalty != 1 || s.PresencePenalty != 0 {
|
||||
adjusted := scores.TakeAlongAxis(tokenIndices, -1)
|
||||
if s.RepeatPenalty != 1 {
|
||||
factor := mlx.Where(
|
||||
adjusted.Less(mlx.FromValue(float32(0))),
|
||||
mlx.FromValue(s.RepeatPenalty),
|
||||
mlx.FromValue(1/s.RepeatPenalty),
|
||||
)
|
||||
adjusted = adjusted.Multiply(factor)
|
||||
}
|
||||
if s.PresencePenalty != 0 {
|
||||
adjusted = mlx.AddScalar(adjusted, -s.PresencePenalty)
|
||||
}
|
||||
scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
|
||||
}
|
||||
|
||||
if s.FrequencyPenalty != 0 {
|
||||
scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1)
|
||||
}
|
||||
|
||||
return scores
|
||||
}
|
||||
|
||||
@@ -10,8 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
|
||||
// RepeatLastN = 1, PresencePenalty = 6
|
||||
s := New(0, 0, 0, 0, 1, 6)
|
||||
s := New(Options{RepeatLastN: 1, PresencePenalty: 6})
|
||||
defer func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
@@ -20,11 +19,11 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
|
||||
s.ResetHistory([]int32{0})
|
||||
s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1}))
|
||||
|
||||
logprobs := mlx.FromValues([]float32{0, 5, 4}, 3)
|
||||
got := s.Sample(logprobs)
|
||||
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
|
||||
got := s.Sample(logits).Token
|
||||
mlx.Eval(got)
|
||||
|
||||
// logprobs will be [0, -1, 4] after the penalty
|
||||
// logits will be [0, -1, 4] after the penalty
|
||||
// and then (index) 2 after the greedy sampler
|
||||
gotInt := got.Int()
|
||||
if gotInt != 2 {
|
||||
@@ -32,19 +31,59 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinPMasksTokensBelowThreshold(t *testing.T) {
|
||||
s := New(0, 0, 0.5, 0, 0, 0)
|
||||
func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) {
|
||||
s := New(Options{RepeatLastN: 1, RepeatPenalty: 2})
|
||||
defer func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
}()
|
||||
|
||||
logprobs := mlx.FromValues([]float32{
|
||||
s.ResetHistory([]int32{1})
|
||||
|
||||
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
|
||||
got := s.Sample(logits).Token
|
||||
mlx.Eval(got)
|
||||
|
||||
// token 1 is repeated and positive, so 5 / 2 falls below token 2.
|
||||
gotInt := got.Int()
|
||||
if gotInt != 2 {
|
||||
t.Fatalf("got %d, want 2", gotInt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) {
|
||||
s := New(Options{RepeatLastN: 4, FrequencyPenalty: 2})
|
||||
defer func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
}()
|
||||
|
||||
s.ResetHistory([]int32{1, 1})
|
||||
|
||||
logits := mlx.FromValues([]float32{0, 5, 4}, 3)
|
||||
got := s.Sample(logits).Token
|
||||
mlx.Eval(got)
|
||||
|
||||
// token 1 appears twice, so 5 - (2 * 2) falls below token 2.
|
||||
gotInt := got.Int()
|
||||
if gotInt != 2 {
|
||||
t.Fatalf("got %d, want 2", gotInt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinPMasksTokensBelowThreshold(t *testing.T) {
|
||||
s := New(Options{MinP: 0.5})
|
||||
defer func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
}()
|
||||
|
||||
logits := mlx.FromValues([]float32{
|
||||
float32(math.Log(0.5)),
|
||||
float32(math.Log(0.3)),
|
||||
float32(math.Log(0.2)),
|
||||
}, 3)
|
||||
got := minP(s, logprobs)
|
||||
got := minP(s, logits)
|
||||
mlx.Eval(got)
|
||||
|
||||
gotFloats := got.Floats()
|
||||
|
||||
@@ -2,7 +2,6 @@ package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
@@ -87,23 +86,25 @@ func Execute(args []string) error {
|
||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
request := Request{Responses: make(chan CompletionResponse)}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
|
||||
if err := json.NewDecoder(r.Body).Decode(&request.CompletionRequest); err != nil {
|
||||
slog.Error("Failed to decode request", "error", err)
|
||||
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
||||
|
||||
request.Pipeline = runner.TextGenerationPipeline
|
||||
request.Sampler = sample.New(
|
||||
request.Options.Temperature,
|
||||
request.Options.TopP,
|
||||
request.Options.MinP,
|
||||
request.Options.TopK,
|
||||
request.Options.RepeatLastN,
|
||||
request.Options.PresencePenalty,
|
||||
)
|
||||
request.Sampler = sample.New(sample.Options{
|
||||
Temperature: request.Options.Temperature,
|
||||
TopP: request.Options.TopP,
|
||||
MinP: request.Options.MinP,
|
||||
TopK: request.Options.TopK,
|
||||
RepeatLastN: request.Options.RepeatLastN,
|
||||
RepeatPenalty: request.Options.RepeatPenalty,
|
||||
PresencePenalty: request.Options.PresencePenalty,
|
||||
FrequencyPenalty: request.Options.FrequencyPenalty,
|
||||
Logprobs: request.Logprobs,
|
||||
TopLogprobs: request.TopLogprobs,
|
||||
})
|
||||
|
||||
var cancel context.CancelFunc
|
||||
request.Ctx, cancel = context.WithCancel(r.Context())
|
||||
|
||||
@@ -144,6 +144,8 @@ func TestRouterForwardMatchesLegacy(t *testing.T) {
|
||||
|
||||
gotScores, gotInds := r.Forward(x, cfg)
|
||||
wantScores, wantInds := legacyRouterForward(r, x, cfg)
|
||||
gotInds = gotInds.AsType(mlx.DTypeInt32)
|
||||
wantInds = wantInds.AsType(mlx.DTypeInt32)
|
||||
mlx.Eval(gotScores, gotInds, wantScores, wantInds)
|
||||
|
||||
if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {
|
||||
|
||||
@@ -161,21 +161,21 @@ type MoEGate struct {
|
||||
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
|
||||
gates := g.Gate.Forward(x)
|
||||
|
||||
scores := mlx.Sigmoid(gates)
|
||||
origScores := scores
|
||||
|
||||
var origScores, negScores *mlx.Array
|
||||
if g.EScoreCorrectionBias != nil {
|
||||
scores = mlx.Add(scores, g.EScoreCorrectionBias)
|
||||
origScores, negScores = mlx.SigmoidRouter(gates, g.EScoreCorrectionBias)
|
||||
} else {
|
||||
origScores = mlx.Sigmoid(gates)
|
||||
negScores = mlx.Neg(origScores)
|
||||
}
|
||||
|
||||
topK := cfg.NumExpertsPerTok
|
||||
negScores := mlx.Neg(scores)
|
||||
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
|
||||
|
||||
dims := inds.Dims()
|
||||
inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})
|
||||
|
||||
scores = mlx.TakeAlongAxis(origScores, inds, -1)
|
||||
scores := mlx.TakeAlongAxis(origScores, inds, -1)
|
||||
|
||||
if topK > 1 && cfg.NormTopKProb {
|
||||
sumScores := mlx.Sum(scores, -1, true)
|
||||
|
||||
@@ -169,8 +169,8 @@ func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
|
||||
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
|
||||
mlx.Eval(dequantizedWeight)
|
||||
|
||||
qOut := ql.Forward(input)
|
||||
dOut := NewLinear(dequantizedWeight, nil).Forward(input)
|
||||
qOut := ql.Forward(input).AsType(mlx.DTypeFloat32)
|
||||
dOut := NewLinear(dequantizedWeight, nil).Forward(input).AsType(mlx.DTypeFloat32)
|
||||
mlx.Eval(qOut, dOut)
|
||||
|
||||
got := qOut.Floats()
|
||||
|
||||
Reference in New Issue
Block a user