Compare commits

...

18 Commits

Author SHA1 Message Date
Jesse Gross
a50199cd70 mlxrunner: batch the sampler across multiple sequences
Register sequences with Add/Remove; each Sample call takes any subset of
registered slots and samples one token per row, appending to each slot's
ring-buffer history. When all slots share Options and penalty rings are
full, one fused transform pass runs over the whole batch via a persistent
pooled history tensor; otherwise calls fall back to per-slot serial
processing indexed against the same pool.

Performance is unchanged for a single sequence, which is all that is
exposed for now.
2026-04-21 15:09:19 -07:00
Jesse Gross
5264ba9194 mlxrunner: track sampler history in a fixed-size ring buffer
AppendToken used to concatenate the new token onto the history tensor
and slice it back to RepeatLastN every decode step, churning the graph
shape and reallocating a fresh tensor each call. The stateful penalties
don't care about order within the window, so a fixed-capacity ring with
one SliceUpdate per append keeps the tensor shape constant across
steps.
2026-04-21 14:40:19 -07:00
Jesse Gross
ce99f24731 mlxrunner: tokenize prompts in request handler goroutines
Move tokenization out of the single GPU processing goroutine and
into each request's HTTP handler goroutine. This allows the next
request's prompt to be tokenized on the CPU while the current
request is executing on the GPU.
2026-04-21 14:38:49 -07:00
Jesse Gross
04f5f0cdb4 mlx: improve thread safety of array management
Use atomic.Int32 for Array.pinned and a sync.Mutex for the global
arrays slice so MLX arrays can be created and pinned from multiple
goroutines without racing on those structures. Convert Array value
receivers to pointer receivers and struct fields from Array to
*Array to avoid copying the atomic.

This does not fully achieve thread safety even when building
completely independent graphs. The tracing flag and traceScratch
slice in compile.go are unprotected, so concurrent Compile calls
will race. MLX itself is not fully thread-safe either although
it is working to improve.
2026-04-21 14:38:49 -07:00
Matteo Celani
fb36a01ffe app/ui: fix model picker showing stale model after switching chats (#15280)
* app/ui: fix model picker showing stale model after switching chats

Optimistic messages created during streaming were storing the full
Model object instead of the model name string. When switching back
to a chat with cached streaming data, the restore effect read an
object where it expected a string, causing the model picker to fail
matching and remain stuck on the previous chat's model.

* app/ui: fix two more instances of Model object passed as model name

Fix the same bug at lines 523 and 536 in the assistant_with_tools
event handler, where selectedModel (object) was used instead of
selectedModel.model (string).
2026-04-21 15:08:06 -04:00
Michael Verrilli
0c65ed33bc cmd: populate model capabilities in launchInteractiveModel (#15712)
launchInteractiveModel was introduced in PR #14609 without the
client.Show() capability-detection block that RunHandler uses.
This left opts.MultiModal always false in the TUI path, causing
image/audio file paths to always be treated as unknown commands
instead of being loaded as multimodal attachments.

Mirror the Show() call, pull-on-404 fallback, cloud auth handling,
and MultiModal/Think population from RunHandler into
launchInteractiveModel.

Fixes #15711
2026-04-21 14:37:36 -04:00
Jesse Gross
22d6c817f8 mlxrunner: fuse top-P and top-K into a single sort pass
When both filters are active, avoid paying for a full sort in top-P
and a partial sort in top-K. Single-filter paths are unchanged.
Improves generation throughput on gemma4:e4b by 1.5%.
2026-04-20 17:43:00 -07:00
Jesse Gross
ca01373b28 mlxrunner: use MaxAxis in the min-P sampler
One reduction op instead of Argmax + TakeAlongAxis.
2026-04-20 17:43:00 -07:00
Jesse Gross
24e038d56a mlxrunner: add logprobs support
Match the ollamarunner and OpenAI semantics: raw, full-vocab log-softmax
with the top-K ranked by probability. Skipped on the GPU when the request
doesn't ask for logprobs so decode doesn't pay for it otherwise.
2026-04-20 17:43:00 -07:00
Parth Sareen
5d1021603a server: apply format when think=false for gemma4 (#15678) 2026-04-20 17:42:29 -07:00
Parth Sareen
8e05d734b9 launch: add kimi cli integration with installer flow (#15723) 2026-04-20 15:33:32 -07:00
Jesse Gross
05e0f21bec mlx: fuse sigmoid router head in glm4_moe_lite
DeepSeek-V2-style aux-loss-free routing computes sigmoid(gates) once but
needs it twice: the raw sigmoid output is gathered after top-k, while the
post-bias negation is the argpartition key. Fuse into a single multi-output
Compiled kernel returning both, saving two launches on the routing path
per token. Exposed as a general SigmoidRouter since the same pattern is
shared across DeepSeek-V2 descendants.

Improves glm4.7 generation performance by approximately 1%.
2026-04-20 15:02:14 -07:00
Daniel Hiltgen
ff23dd343f mlx: apply repeat penalties in sampler (#15631) 2026-04-18 07:49:38 -07:00
Parth Sareen
123b300af6 docs: update hermes (#15655) 2026-04-17 14:20:59 -07:00
Parth Sareen
57653b8e42 cmd/launch: show WSL guidance on Windows instead of handing off (#15637) 2026-04-16 17:18:04 -07:00
Parth Sareen
a50ce61c54 launch: skip unchanged managed-single rewrite (#15633) 2026-04-16 16:20:42 -07:00
Daniel Hiltgen
2bb7ea00d2 create: avoid gc race with create (#15628)
If you have a long running create, and start another ollama server with the
same model dir, the GC algorithm deletes the pending blobs and breaks the
create.  This adds a 1h grace period to avoid deleting in-flight creation
operations.
2026-04-16 13:29:16 -07:00
Daniel Hiltgen
55fa80d07a mlx: additional gemma4 cache fixes (#15607)
Harden additional corner cases
2026-04-16 13:07:19 -07:00
38 changed files with 2708 additions and 859 deletions

View File

@@ -381,7 +381,7 @@ export const useSendMessage = (chatId: string) => {
role: "assistant",
content: "",
thinking: "",
model: effectiveModel,
model: effectiveModel.model,
}),
);
lastMessage = newMessages[newMessages.length - 1];
@@ -433,7 +433,7 @@ export const useSendMessage = (chatId: string) => {
role: "assistant",
content: "",
thinking: "",
model: effectiveModel,
model: effectiveModel.model,
}),
);
lastMessage = newMessages[newMessages.length - 1];
@@ -520,7 +520,7 @@ export const useSendMessage = (chatId: string) => {
thinkingTimeStart:
lastMessage.thinkingTimeStart || event.thinkingTimeStart,
thinkingTimeEnd: event.thinkingTimeEnd,
model: selectedModel,
model: selectedModel.model,
});
newMessages[newMessages.length - 1] = updatedMessage;
} else {
@@ -533,7 +533,7 @@ export const useSendMessage = (chatId: string) => {
tool_calls: event.toolCalls,
thinkingTimeStart: event.thinkingTimeStart,
thinkingTimeEnd: event.thinkingTimeEnd,
model: selectedModel,
model: selectedModel.model,
}),
);
}
@@ -699,7 +699,7 @@ export const useSendMessage = (chatId: string) => {
queryClient.setQueryData(["chat", newId], {
chat: new Chat({
id: newId,
model: effectiveModel,
model: effectiveModel.model,
messages: [
new Message({
role: "user",

View File

@@ -1975,8 +1975,61 @@ func launchInteractiveModel(cmd *cobra.Command, modelName string) error {
Options: map[string]any{},
ShowConnect: true,
}
// loadOrUnloadModel is cloud-safe here: remote/cloud models skip local preload
// and only validate auth/connectivity before interactive chat starts.
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
requestedCloud := modelref.HasExplicitCloudSource(modelName)
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: modelName}
info, err := client.Show(cmd.Context(), showReq)
var se api.StatusError
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
if requestedCloud {
return nil, err
}
if err := PullHandler(cmd, []string{modelName}); err != nil {
return nil, err
}
return client.Show(cmd.Context(), &api.ShowRequest{Name: modelName})
}
return info, err
}()
if err != nil {
if handleCloudAuthorizationError(err) {
return nil
}
return err
}
ensureCloudStub(cmd.Context(), client, modelName)
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, false)
if err != nil {
return err
}
audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio)
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable
// TODO: remove the projector info and vision info checks below,
// these are left in for backwards compatibility with older servers
// that don't have the capabilities field in the model info
if len(info.ProjectorInfo) != 0 {
opts.MultiModal = true
}
for k := range info.ModelInfo {
if strings.Contains(k, ".vision.") {
opts.MultiModal = true
break
}
}
applyShowResponseToRunOptions(&opts, info)
if err := loadOrUnloadModel(cmd, &opts); err != nil {
return fmt.Errorf("error loading model: %w", err)
}

View File

@@ -61,6 +61,9 @@ func TestLaunchCmd(t *testing.T) {
if !strings.Contains(cmd.Long, "hermes") {
t.Error("Long description should mention hermes")
}
if !strings.Contains(cmd.Long, "kimi") {
t.Error("Long description should mention kimi")
}
})
t.Run("flags exist", func(t *testing.T) {

View File

@@ -4,18 +4,15 @@ import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"net/http"
"os"
"os/exec"
pathpkg "path"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"time"
"gopkg.in/yaml.v3"
@@ -66,23 +63,13 @@ var hermesMessagingEnvGroups = [][]string{
// switching UX after startup.
type Hermes struct{}
type hermesConfigBackend struct {
displayPath string
read func() ([]byte, error)
write func([]byte) error
}
func (h *Hermes) String() string { return "Hermes Agent" }
func (h *Hermes) Run(_ string, args []string) error {
// Hermes reads its primary model from config.yaml. launch configures that
// default model ahead of time so we can keep runtime invocation simple and
// still let Hermes discover additional models later via its own UX.
if hermesGOOS == "windows" {
return h.runWindows(args)
}
bin, err := h.findUnixBinary()
bin, err := h.binary()
if err != nil {
return err
}
@@ -95,21 +82,21 @@ func (h *Hermes) Run(_ string, args []string) error {
}
func (h *Hermes) Paths() []string {
backend, err := h.configBackend()
configPath, err := hermesConfigPath()
if err != nil {
return nil
}
return []string{backend.displayPath}
return []string{configPath}
}
func (h *Hermes) Configure(model string) error {
backend, err := h.configBackend()
configPath, err := hermesConfigPath()
if err != nil {
return err
}
cfg := map[string]any{}
if data, err := backend.read(); err == nil {
if data, err := os.ReadFile(configPath); err == nil {
if err := yaml.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("parse hermes config: %w", err)
}
@@ -142,15 +129,18 @@ func (h *Hermes) Configure(model string) error {
if err != nil {
return err
}
return backend.write(data)
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
return fileutil.WriteWithBackup(configPath, data)
}
func (h *Hermes) CurrentModel() string {
backend, err := h.configBackend()
configPath, err := hermesConfigPath()
if err != nil {
return ""
}
data, err := backend.read()
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
@@ -188,14 +178,7 @@ func (h *Hermes) RefreshRuntimeAfterConfigure() error {
}
func (h *Hermes) installed() bool {
if hermesGOOS == "windows" {
if _, err := hermesLookPath("hermes"); err == nil {
return true
}
return h.wslHasHermes()
}
_, err := h.findUnixBinary()
_, err := h.binary()
return err == nil
}
@@ -205,7 +188,7 @@ func (h *Hermes) ensureInstalled() error {
}
if hermesGOOS == "windows" {
return h.ensureInstalledWindows()
return hermesWindowsHint()
}
var missing []string
@@ -239,42 +222,6 @@ func (h *Hermes) ensureInstalled() error {
return nil
}
func (h *Hermes) ensureInstalledWindows() error {
// Hermes upstream support is WSL-oriented, so Windows launch uses a hybrid
// WSL handoff that stays on the same install path as upstream Hermes.
if _, err := hermesLookPath("hermes"); err == nil {
return nil
}
if !h.wslAvailable() {
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
}
if h.wslHasHermes() {
return nil
}
ok, err := ConfirmPromptWithOptions("Hermes runs through WSL2 on Windows. Install it in WSL now?", ConfirmOptions{
YesLabel: "Use WSL",
NoLabel: "Show manual steps",
})
if err != nil {
return err
}
if !ok {
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
}
fmt.Fprintf(os.Stderr, "\nInstalling Hermes in WSL...\n")
if err := h.runWSL("bash", "-lc", hermesInstallScript); err != nil {
return hermesWindowsHint(fmt.Errorf("failed to install hermes in WSL: %w", err))
}
if !h.wslHasHermes() {
return hermesWindowsHint(fmt.Errorf("hermes install finished but the WSL binary was not found"))
}
fmt.Fprintf(os.Stderr, "%sHermes installed successfully in WSL%s\n\n", ansiGreen, ansiReset)
return nil
}
func (h *Hermes) listModels(defaultModel string) []string {
client := hermesOllamaClient()
resp, err := client.List(context.Background())
@@ -306,11 +253,15 @@ func (h *Hermes) listModels(defaultModel string) []string {
return models
}
func (h *Hermes) findUnixBinary() (string, error) {
func (h *Hermes) binary() (string, error) {
if path, err := hermesLookPath("hermes"); err == nil {
return path, nil
}
if hermesGOOS == "windows" {
return "", hermesWindowsHint()
}
home, err := hermesUserHome()
if err != nil {
return "", err
@@ -323,70 +274,6 @@ func (h *Hermes) findUnixBinary() (string, error) {
return "", fmt.Errorf("hermes is not installed")
}
func (h *Hermes) runWindows(args []string) error {
if path, err := hermesLookPath("hermes"); err == nil {
if err := h.runGatewaySetupPreflight(args, func() error {
return hermesAttachedCommand(path, "gateway", "setup").Run()
}); err != nil {
return err
}
return hermesAttachedCommand(path, args...).Run()
}
if !h.wslAvailable() {
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
}
if err := h.runGatewaySetupPreflight(args, func() error {
return h.runWSL("hermes", "gateway", "setup")
}); err != nil {
return err
}
if err := h.runWSL(append([]string{"hermes"}, args...)...); err != nil {
return hermesWindowsHint(err)
}
return nil
}
func (h *Hermes) runWSL(args ...string) error {
if !h.wslAvailable() {
return fmt.Errorf("wsl.exe is not available")
}
return hermesAttachedCommand("wsl.exe", "bash", "-lc", shellQuoteArgs(args)).Run()
}
func (h *Hermes) runWSLCombinedOutput(args ...string) ([]byte, error) {
if !h.wslAvailable() {
return nil, fmt.Errorf("wsl.exe is not available")
}
return hermesCommand("wsl.exe", "bash", "-lc", shellQuoteArgs(args)).CombinedOutput()
}
func (h *Hermes) wslAvailable() bool {
_, err := hermesLookPath("wsl.exe")
return err == nil
}
func (h *Hermes) wslHasHermes() bool {
if !h.wslAvailable() {
return false
}
cmd := hermesCommand("wsl.exe", "bash", "-lc", "command -v hermes >/dev/null 2>&1")
return cmd.Run() == nil
}
func (h *Hermes) configBackend() (*hermesConfigBackend, error) {
if hermesGOOS == "windows" {
if _, err := hermesLookPath("hermes"); err == nil {
return hermesLocalConfigBackend()
}
if h.wslAvailable() {
return h.wslConfigBackend()
}
}
return hermesLocalConfigBackend()
}
func hermesConfigPath() (string, error) {
home, err := hermesUserHome()
if err != nil {
@@ -395,110 +282,6 @@ func hermesConfigPath() (string, error) {
return filepath.Join(home, ".hermes", "config.yaml"), nil
}
func hermesLocalConfigBackend() (*hermesConfigBackend, error) {
configPath, err := hermesConfigPath()
if err != nil {
return nil, err
}
return &hermesConfigBackend{
displayPath: configPath,
read: func() ([]byte, error) {
return os.ReadFile(configPath)
},
write: func(data []byte) error {
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
return err
}
return fileutil.WriteWithBackup(configPath, data)
},
}, nil
}
func (h *Hermes) wslConfigBackend() (*hermesConfigBackend, error) {
home, err := h.wslHome()
if err != nil {
return nil, err
}
configPath := pathpkg.Join(home, ".hermes", "config.yaml")
return &hermesConfigBackend{
displayPath: configPath,
read: func() ([]byte, error) {
return h.readWSLFile(configPath)
},
write: func(data []byte) error {
return h.writeWSLConfig(configPath, data)
},
}, nil
}
func (h *Hermes) wslHome() (string, error) {
if !h.wslAvailable() {
return "", fmt.Errorf("wsl.exe is not available")
}
cmd := hermesCommand("wsl.exe", "bash", "-lc", `printf %s "$HOME"`)
out, err := cmd.Output()
if err != nil {
return "", err
}
home := strings.TrimSpace(string(out))
if home == "" {
return "", fmt.Errorf("could not resolve WSL home directory")
}
return home, nil
}
func (h *Hermes) readWSLFile(path string) ([]byte, error) {
pathArg := shellQuoteArgs([]string{path})
cmd := hermesCommand("wsl.exe", "bash", "-lc", fmt.Sprintf("if [ -f %s ]; then cat %s; else exit 42; fi", pathArg, pathArg))
out, err := cmd.Output()
if err == nil {
return out, nil
}
var exitErr *exec.ExitError
if errors.As(err, &exitErr) && exitErr.ExitCode() == 42 {
return nil, os.ErrNotExist
}
return nil, err
}
func (h *Hermes) writeWSLConfig(path string, data []byte) error {
if existing, err := h.readWSLFile(path); err == nil {
if !bytes.Equal(existing, data) {
if err := hermesBackupData(path, existing); err != nil {
return fmt.Errorf("backup failed: %w", err)
}
}
} else if !os.IsNotExist(err) {
return fmt.Errorf("read existing file: %w", err)
}
dir := pathpkg.Dir(path)
dirArg := shellQuoteArgs([]string{dir})
pathArg := shellQuoteArgs([]string{path})
script := fmt.Sprintf(
"dir=%s; path=%s; mkdir -p \"$dir\" && tmp=$(mktemp \"$dir/.tmp-XXXXXX\") && cat > \"$tmp\" && mv \"$tmp\" \"$path\"",
dirArg,
pathArg,
)
cmd := hermesCommand("wsl.exe", "bash", "-lc", script)
cmd.Stdin = bytes.NewReader(data)
if out, err := cmd.CombinedOutput(); err != nil {
if msg := strings.TrimSpace(string(out)); msg != "" {
return fmt.Errorf("%w: %s", err, msg)
}
return err
}
return nil
}
func hermesBackupData(path string, data []byte) error {
if err := os.MkdirAll(fileutil.BackupDir(), 0o755); err != nil {
return err
}
backupPath := filepath.Join(fileutil.BackupDir(), fmt.Sprintf("%s.%d", filepath.Base(path), time.Now().Unix()))
return os.WriteFile(backupPath, data, 0o644)
}
func hermesBaseURL() string {
return strings.TrimRight(hermesOllamaURL().String(), "/") + "/v1"
}
@@ -554,8 +337,11 @@ func (h *Hermes) messagingConfigured() bool {
func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
envVars := make(map[string]string)
data, err := h.readGatewayEnvFile()
switch {
envFilePath, err := hermesEnvPath()
if err != nil {
return nil, err
}
switch data, err := os.ReadFile(envFilePath); {
case err == nil:
for key, value := range hermesParseEnvFile(data) {
envVars[key] = value
@@ -566,12 +352,10 @@ func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
return nil, err
}
if h.usesLocalRuntimeEnv() {
for _, group := range hermesMessagingEnvGroups {
for _, key := range group {
if value, ok := os.LookupEnv(key); ok {
envVars[key] = value
}
for _, group := range hermesMessagingEnvGroups {
for _, key := range group {
if value, ok := os.LookupEnv(key); ok {
envVars[key] = value
}
}
}
@@ -579,39 +363,6 @@ func (h *Hermes) gatewayEnvVars() (map[string]string, error) {
return envVars, nil
}
func (h *Hermes) readGatewayEnvFile() ([]byte, error) {
if hermesGOOS == "windows" {
if _, err := hermesLookPath("hermes"); err == nil {
path, err := hermesEnvPath()
if err != nil {
return nil, err
}
return os.ReadFile(path)
}
if h.wslAvailable() {
home, err := h.wslHome()
if err != nil {
return nil, err
}
return h.readWSLFile(pathpkg.Join(home, ".hermes", ".env"))
}
}
path, err := hermesEnvPath()
if err != nil {
return nil, err
}
return os.ReadFile(path)
}
func (h *Hermes) usesLocalRuntimeEnv() bool {
if hermesGOOS != "windows" {
return true
}
_, err := hermesLookPath("hermes")
return err == nil
}
func (h *Hermes) gatewayRunning() (bool, error) {
status, err := h.gatewayStatusOutput()
if err != nil {
@@ -621,19 +372,7 @@ func (h *Hermes) gatewayRunning() (bool, error) {
}
func (h *Hermes) gatewayStatusOutput() (string, error) {
if hermesGOOS == "windows" {
if path, err := hermesLookPath("hermes"); err == nil {
out, err := hermesCommand(path, "gateway", "status").CombinedOutput()
return string(out), err
}
if !h.wslAvailable() {
return "", hermesWindowsHint(fmt.Errorf("hermes is not installed"))
}
out, err := h.runWSLCombinedOutput("hermes", "gateway", "status")
return string(out), err
}
bin, err := h.findUnixBinary()
bin, err := h.binary()
if err != nil {
return "", err
}
@@ -642,20 +381,7 @@ func (h *Hermes) gatewayStatusOutput() (string, error) {
}
func (h *Hermes) restartGateway() error {
if hermesGOOS == "windows" {
if path, err := hermesLookPath("hermes"); err == nil {
return hermesAttachedCommand(path, "gateway", "restart").Run()
}
if !h.wslAvailable() {
return hermesWindowsHint(fmt.Errorf("hermes is not installed"))
}
if err := h.runWSL("hermes", "gateway", "restart"); err != nil {
return hermesWindowsHint(err)
}
return nil
}
bin, err := h.findUnixBinary()
bin, err := h.binary()
if err != nil {
return err
}
@@ -938,14 +664,6 @@ func mergeHermesToolsets(current any) any {
}
}
func shellQuoteArgs(args []string) string {
quoted := make([]string, 0, len(args))
for _, arg := range args {
quoted = append(quoted, "'"+strings.ReplaceAll(arg, "'", `'\''`)+"'")
}
return strings.Join(quoted, " ")
}
func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
cmd := hermesCommand(name, args...)
cmd.Stdin = os.Stdin
@@ -954,9 +672,8 @@ func hermesAttachedCommand(name string, args ...string) *exec.Cmd {
return cmd
}
func hermesWindowsHint(err error) error {
if hermesGOOS != "windows" {
return err
}
return fmt.Errorf("%w\n\nHermes runs on Windows through WSL2.\nQuick setup: wsl --install\nInstaller docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/", err)
func hermesWindowsHint() error {
return fmt.Errorf("Hermes on Windows requires WSL2. Install WSL with: wsl --install\n" +
"Then run 'ollama launch hermes' from inside your WSL shell.\n" +
"Docs: https://hermes-agent.nousresearch.com/docs/getting-started/installation/")
}

View File

@@ -896,64 +896,6 @@ fi
}
}
func TestHermesRefreshRuntimeAfterConfigure_WindowsWSLRestartsRunningGateway(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell test binaries to simulate WSL")
}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withHermesPlatform(t, "windows")
t.Setenv("PATH", tmpDir+string(os.PathListSeparator)+os.Getenv("PATH"))
wslPath := filepath.Join(tmpDir, "wsl.exe")
wslScript := `#!/bin/sh
printf '[%s]\n' "$*" >> "$HOME/wsl-invocations.log"
exec /bin/sh -lc "$3"
`
if err := os.WriteFile(wslPath, []byte(wslScript), 0o755); err != nil {
t.Fatal(err)
}
hermesBin := filepath.Join(tmpDir, "hermes")
hermesScript := `#!/bin/sh
printf '[%s]\n' "$*" >> "$HOME/hermes-invocations.log"
if [ "$1" = "gateway" ] && [ "$2" = "status" ]; then
printf '✓ Gateway is running (PID: 321)\n'
fi
`
if err := os.WriteFile(hermesBin, []byte(hermesScript), 0o755); err != nil {
t.Fatal(err)
}
withHermesLookPath(t, func(file string) (string, error) {
if file == "wsl.exe" {
return wslPath, nil
}
return "", os.ErrNotExist
})
h := &Hermes{}
if err := h.RefreshRuntimeAfterConfigure(); err != nil {
t.Fatalf("RefreshRuntimeAfterConfigure returned error: %v", err)
}
data, err := os.ReadFile(filepath.Join(tmpDir, "hermes-invocations.log"))
if err != nil {
t.Fatal(err)
}
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
if len(lines) != 2 {
t.Fatalf("expected WSL status then restart invocations, got %v", lines)
}
if lines[0] != "[gateway status]" {
t.Fatalf("expected WSL gateway status first, got %q", lines[0])
}
if lines[1] != "[gateway restart]" {
t.Fatalf("expected WSL gateway restart second, got %q", lines[1])
}
}
func TestHermesMessagingConfiguredRecognizesSupportedGatewayVars(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
@@ -1002,82 +944,7 @@ func TestHermesMessagingConfiguredRecognizesSupportedGatewayVars(t *testing.T) {
}
}
func TestHermesRunWindowsWSL_UsesGatewaySetupPreflight(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell test binaries to simulate WSL")
}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withLauncherHooks(t)
withInteractiveSession(t, true)
withHermesPlatform(t, "windows")
clearHermesMessagingEnvVars(t)
t.Setenv("PATH", tmpDir+string(os.PathListSeparator)+os.Getenv("PATH"))
wslPath := filepath.Join(tmpDir, "wsl.exe")
wslScript := `#!/bin/sh
printf '[%s]\n' "$*" >> "$HOME/wsl-invocations.log"
exec /bin/sh -lc "$3"
`
if err := os.WriteFile(wslPath, []byte(wslScript), 0o755); err != nil {
t.Fatal(err)
}
hermesBin := filepath.Join(tmpDir, "hermes")
hermesScript := `#!/bin/sh
printf '[%s]\n' "$*" >> "$HOME/hermes-invocations.log"
if [ "$1" = "gateway" ] && [ "$2" = "setup" ]; then
/bin/mkdir -p "$HOME/.hermes"
printf 'TELEGRAM_BOT_TOKEN=configured\n' > "$HOME/.hermes/.env"
fi
`
if err := os.WriteFile(hermesBin, []byte(hermesScript), 0o755); err != nil {
t.Fatal(err)
}
withHermesLookPath(t, func(file string) (string, error) {
if file == "wsl.exe" {
return wslPath, nil
}
return "", os.ErrNotExist
})
promptCount := 0
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
promptCount++
if prompt != hermesGatewaySetupTitle {
t.Fatalf("unexpected prompt %q", prompt)
}
return true, nil
}
h := &Hermes{}
if err := h.Run("", nil); err != nil {
t.Fatalf("Run returned error: %v", err)
}
if promptCount != 1 {
t.Fatalf("expected one messaging prompt, got %d", promptCount)
}
data, err := os.ReadFile(filepath.Join(tmpDir, "hermes-invocations.log"))
if err != nil {
t.Fatal(err)
}
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
if len(lines) != 2 {
t.Fatalf("expected WSL hermes to run setup then launch, got %v", lines)
}
if lines[0] != "[gateway setup]" {
t.Fatalf("expected WSL gateway setup first, got %q", lines[0])
}
if lines[1] != "[]" {
t.Fatalf("expected WSL default hermes launch second, got %q", lines[1])
}
}
func TestHermesEnsureInstalledWindowsWithoutWSLGivesGuidance(t *testing.T) {
func TestHermesEnsureInstalledWindowsShowsWSLGuidance(t *testing.T) {
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
withHermesPlatform(t, "windows")
@@ -1086,10 +953,17 @@ func TestHermesEnsureInstalledWindowsWithoutWSLGivesGuidance(t *testing.T) {
h := &Hermes{}
err := h.ensureInstalled()
if err == nil {
t.Fatal("expected missing WSL guidance error")
t.Fatal("expected WSL guidance error")
}
if !strings.Contains(err.Error(), "wsl --install") {
t.Fatalf("expected WSL guidance, got %v", err)
msg := err.Error()
if !strings.Contains(msg, "wsl --install") {
t.Fatalf("expected install command in guidance, got %v", err)
}
if !strings.Contains(msg, "hermes-agent.nousresearch.com") {
t.Fatalf("expected docs link in guidance, got %v", err)
}
if strings.Contains(msg, "hermes is not installed") {
t.Fatalf("guidance should not lead with 'hermes is not installed', got %v", err)
}
}

View File

@@ -54,6 +54,7 @@ func TestIntegrationLookup(t *testing.T) {
{"claude uppercase", "CLAUDE", true, "Claude Code"},
{"claude mixed case", "Claude", true, "Claude Code"},
{"codex", "codex", true, "Codex"},
{"kimi", "kimi", true, "Kimi Code CLI"},
{"droid", "droid", true, "Droid"},
{"opencode", "opencode", true, "OpenCode"},
{"unknown integration", "unknown", false, ""},
@@ -74,7 +75,7 @@ func TestIntegrationLookup(t *testing.T) {
}
func TestIntegrationRegistry(t *testing.T) {
expectedIntegrations := []string{"claude", "codex", "droid", "opencode", "hermes"}
expectedIntegrations := []string{"claude", "codex", "kimi", "droid", "opencode", "hermes"}
for _, name := range expectedIntegrations {
t.Run(name, func(t *testing.T) {
@@ -89,6 +90,15 @@ func TestIntegrationRegistry(t *testing.T) {
}
}
func TestHiddenIntegrationsExcludedFromVisibleLists(t *testing.T) {
for _, info := range ListIntegrationInfos() {
switch info.Name {
case "cline", "vscode", "kimi":
t.Fatalf("hidden integration %q should not appear in ListIntegrationInfos", info.Name)
}
}
}
func TestHasLocalModel(t *testing.T) {
tests := []struct {
name string

315
cmd/launch/kimi.go Normal file
View File

@@ -0,0 +1,315 @@
package launch
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
)
// Kimi implements Runner for Kimi Code CLI integration.
type Kimi struct{}
const (
kimiDefaultModelAlias = "ollama"
kimiDefaultMaxContextSize = 32768
)
var (
kimiGOOS = runtime.GOOS
kimiModelShowTimeout = 5 * time.Second
)
func (k *Kimi) String() string { return "Kimi Code CLI" }
func (k *Kimi) args(config string, extra []string) []string {
args := []string{"--config", config}
args = append(args, extra...)
return args
}
func (k *Kimi) Run(model string, args []string) error {
if strings.TrimSpace(model) == "" {
return fmt.Errorf("model is required")
}
if err := validateKimiPassthroughArgs(args); err != nil {
return err
}
config, err := buildKimiInlineConfig(model, resolveKimiMaxContextSize(model))
if err != nil {
return fmt.Errorf("failed to build kimi config: %w", err)
}
bin, err := ensureKimiInstalled()
if err != nil {
return err
}
cmd := exec.Command(bin, k.args(config, args)...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
}
func findKimiBinary() (string, error) {
if path, err := exec.LookPath("kimi"); err == nil {
return path, nil
}
home, _ := os.UserHomeDir()
var candidates []string
switch kimiGOOS {
case "windows":
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, ".local", "bin"))
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(home, "bin"))
if appData := strings.TrimSpace(os.Getenv("APPDATA")); appData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
}
if localAppData := strings.TrimSpace(os.Getenv("LOCALAPPDATA")); localAppData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
}
default:
candidates = append(candidates,
filepath.Join(home, ".local", "bin", "kimi"),
filepath.Join(home, "bin", "kimi"),
filepath.Join(home, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi"),
filepath.Join(home, ".local", "share", "uv", "tools", "kimi", "bin", "kimi"),
)
if xdgDataHome := strings.TrimSpace(os.Getenv("XDG_DATA_HOME")); xdgDataHome != "" {
candidates = append(candidates,
filepath.Join(xdgDataHome, "uv", "tools", "kimi-cli", "bin", "kimi"),
filepath.Join(xdgDataHome, "uv", "tools", "kimi", "bin", "kimi"),
)
}
// WSL users can inherit Windows env vars while launching from Linux shells.
if profile := windowsPathToWSL(os.Getenv("USERPROFILE")); profile != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(profile, ".local", "bin"))
}
if appData := windowsPathToWSL(os.Getenv("APPDATA")); appData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(appData, "uv", "bin"))
}
if localAppData := windowsPathToWSL(os.Getenv("LOCALAPPDATA")); localAppData != "" {
candidates = appendWindowsKimiCandidates(candidates, filepath.Join(localAppData, "uv", "bin"))
}
}
for _, candidate := range candidates {
if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
return candidate, nil
}
}
return "", fmt.Errorf("kimi binary not found")
}
func appendWindowsKimiCandidates(candidates []string, dir string) []string {
if strings.TrimSpace(dir) == "" {
return candidates
}
return append(candidates,
filepath.Join(dir, "kimi.exe"),
filepath.Join(dir, "kimi.cmd"),
filepath.Join(dir, "kimi.bat"),
)
}
func windowsPathToWSL(path string) string {
trimmed := strings.TrimSpace(path)
if len(trimmed) < 3 || trimmed[1] != ':' {
return ""
}
drive := strings.ToLower(string(trimmed[0]))
rest := strings.ReplaceAll(trimmed[2:], "\\", "/")
rest = strings.TrimPrefix(rest, "/")
if rest == "" {
return filepath.Join("/mnt", drive)
}
return filepath.Join("/mnt", drive, rest)
}
func validateKimiPassthroughArgs(args []string) error {
for _, arg := range args {
switch {
case arg == "--config", strings.HasPrefix(arg, "--config="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config", arg)
case arg == "--config-file", strings.HasPrefix(arg, "--config-file="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --config-file", arg)
case arg == "--model", strings.HasPrefix(arg, "--model="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages --model", arg)
case arg == "-m", strings.HasPrefix(arg, "-m="):
return fmt.Errorf("conflicting extra argument %q: ollama launch kimi manages -m/--model", arg)
}
}
return nil
}
func buildKimiInlineConfig(model string, maxContextSize int) (string, error) {
cfg := map[string]any{
"default_model": kimiDefaultModelAlias,
"providers": map[string]any{
kimiDefaultModelAlias: map[string]any{
"type": "openai_legacy",
"base_url": envconfig.ConnectableHost().String() + "/v1",
"api_key": "ollama",
},
},
"models": map[string]any{
kimiDefaultModelAlias: map[string]any{
"provider": kimiDefaultModelAlias,
"model": model,
"max_context_size": maxContextSize,
},
},
}
data, err := json.Marshal(cfg)
if err != nil {
return "", err
}
return string(data), nil
}
func resolveKimiMaxContextSize(model string) int {
if l, ok := lookupCloudModelLimit(model); ok {
return l.Context
}
client, err := api.ClientFromEnvironment()
if err != nil {
return kimiDefaultMaxContextSize
}
ctx, cancel := context.WithTimeout(context.Background(), kimiModelShowTimeout)
defer cancel()
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
if err != nil {
return kimiDefaultMaxContextSize
}
if n, ok := modelInfoContextLength(resp.ModelInfo); ok {
return n
}
return kimiDefaultMaxContextSize
}
func modelInfoContextLength(modelInfo map[string]any) (int, bool) {
for key, val := range modelInfo {
if !strings.HasSuffix(key, ".context_length") {
continue
}
switch v := val.(type) {
case float64:
if v > 0 {
return int(v), true
}
case int:
if v > 0 {
return v, true
}
case int64:
if v > 0 {
return int(v), true
}
}
}
return 0, false
}
func ensureKimiInstalled() (string, error) {
if path, err := findKimiBinary(); err == nil {
return path, nil
}
if err := checkKimiInstallerDependencies(); err != nil {
return "", err
}
ok, err := ConfirmPrompt("Kimi is not installed. Install now?")
if err != nil {
return "", err
}
if !ok {
return "", fmt.Errorf("kimi installation cancelled")
}
bin, args, err := kimiInstallerCommand(kimiGOOS)
if err != nil {
return "", err
}
fmt.Fprintf(os.Stderr, "\nInstalling Kimi...\n")
cmd := exec.Command(bin, args...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("failed to install kimi: %w", err)
}
path, err := findKimiBinary()
if err != nil {
return "", fmt.Errorf("kimi was installed but the binary was not found on PATH\n\nYou may need to restart your shell")
}
fmt.Fprintf(os.Stderr, "%sKimi installed successfully%s\n\n", ansiGreen, ansiReset)
return path, nil
}
func checkKimiInstallerDependencies() error {
switch kimiGOOS {
case "windows":
if _, err := exec.LookPath("powershell"); err != nil {
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n PowerShell: https://learn.microsoft.com/powershell/\n\nThen re-run:\n ollama launch kimi")
}
default:
var missing []string
if _, err := exec.LookPath("curl"); err != nil {
missing = append(missing, "curl: https://curl.se/")
}
if _, err := exec.LookPath("bash"); err != nil {
missing = append(missing, "bash: https://www.gnu.org/software/bash/")
}
if len(missing) > 0 {
return fmt.Errorf("kimi is not installed and required dependencies are missing\n\nInstall the following first:\n %s\n\nThen re-run:\n ollama launch kimi", strings.Join(missing, "\n "))
}
}
return nil
}
func kimiInstallerCommand(goos string) (string, []string, error) {
switch goos {
case "windows":
return "powershell", []string{
"-NoProfile",
"-ExecutionPolicy",
"Bypass",
"-Command",
"Invoke-RestMethod https://code.kimi.com/install.ps1 | Invoke-Expression",
}, nil
case "darwin", "linux":
return "bash", []string{
"-c",
"curl -LsSf https://code.kimi.com/install.sh | bash",
}, nil
default:
return "", nil, fmt.Errorf("unsupported platform for kimi install: %s", goos)
}
}

636
cmd/launch/kimi_test.go Normal file
View File

@@ -0,0 +1,636 @@
package launch
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
func assertKimiBinPath(t *testing.T, bin string) {
t.Helper()
base := strings.ToLower(filepath.Base(bin))
if !strings.HasPrefix(base, "kimi") {
t.Fatalf("bin = %q, want path to kimi executable", bin)
}
}
func TestKimiIntegration(t *testing.T) {
k := &Kimi{}
t.Run("String", func(t *testing.T) {
if got := k.String(); got != "Kimi Code CLI" {
t.Errorf("String() = %q, want %q", got, "Kimi Code CLI")
}
})
t.Run("implements Runner", func(t *testing.T) {
var _ Runner = k
})
}
func TestKimiArgs(t *testing.T) {
k := &Kimi{}
got := k.args(`{"foo":"bar"}`, []string{"--quiet", "--print"})
want := []string{"--config", `{"foo":"bar"}`, "--quiet", "--print"}
if !slices.Equal(got, want) {
t.Fatalf("args() = %v, want %v", got, want)
}
}
func TestWindowsPathToWSL(t *testing.T) {
tests := []struct {
name string
in string
want string
valid bool
}{
{
name: "user profile path",
in: `C:\Users\parth`,
want: filepath.Join("/mnt", "c", "Users", "parth"),
valid: true,
},
{
name: "path with trailing slash",
in: `D:\tools\bin\`,
want: filepath.Join("/mnt", "d", "tools", "bin"),
valid: true,
},
{
name: "non windows path",
in: "/home/parth",
valid: false,
},
{
name: "empty",
in: "",
valid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := windowsPathToWSL(tt.in)
if !tt.valid {
if got != "" {
t.Fatalf("windowsPathToWSL(%q) = %q, want empty", tt.in, got)
}
return
}
if got != tt.want {
t.Fatalf("windowsPathToWSL(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
}
func TestFindKimiBinaryFallbacks(t *testing.T) {
oldGOOS := kimiGOOS
t.Cleanup(func() { kimiGOOS = oldGOOS })
t.Run("linux/ubuntu uv tool path", func(t *testing.T) {
homeDir := t.TempDir()
setTestHome(t, homeDir)
t.Setenv("PATH", t.TempDir())
kimiGOOS = "linux"
target := filepath.Join(homeDir, ".local", "share", "uv", "tools", "kimi-cli", "bin", "kimi")
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
t.Fatalf("failed to create candidate dir: %v", err)
}
if err := os.WriteFile(target, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatalf("failed to write kimi candidate: %v", err)
}
got, err := findKimiBinary()
if err != nil {
t.Fatalf("findKimiBinary() error = %v", err)
}
if got != target {
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
}
})
t.Run("windows appdata uv bin", func(t *testing.T) {
setTestHome(t, t.TempDir())
t.Setenv("PATH", t.TempDir())
kimiGOOS = "windows"
appDataDir := t.TempDir()
t.Setenv("APPDATA", appDataDir)
t.Setenv("LOCALAPPDATA", "")
target := filepath.Join(appDataDir, "uv", "bin", "kimi.cmd")
if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
t.Fatalf("failed to create candidate dir: %v", err)
}
if err := os.WriteFile(target, []byte("@echo off\r\nexit /b 0\r\n"), 0o755); err != nil {
t.Fatalf("failed to write kimi candidate: %v", err)
}
got, err := findKimiBinary()
if err != nil {
t.Fatalf("findKimiBinary() error = %v", err)
}
if got != target {
t.Fatalf("findKimiBinary() = %q, want %q", got, target)
}
})
}
func TestValidateKimiPassthroughArgs_RejectsConflicts(t *testing.T) {
tests := []struct {
name string
args []string
want string
}{
{name: "--config", args: []string{"--config", "{}"}, want: "--config"},
{name: "--config=", args: []string{"--config={}"}, want: "--config={"},
{name: "--config-file", args: []string{"--config-file", "x.toml"}, want: "--config-file"},
{name: "--config-file=", args: []string{"--config-file=x.toml"}, want: "--config-file=x.toml"},
{name: "--model", args: []string{"--model", "foo"}, want: "--model"},
{name: "--model=", args: []string{"--model=foo"}, want: "--model=foo"},
{name: "-m", args: []string{"-m", "foo"}, want: "-m"},
{name: "-m=", args: []string{"-m=foo"}, want: "-m=foo"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateKimiPassthroughArgs(tt.args)
if err == nil {
t.Fatalf("expected error for args %v", tt.args)
}
if !strings.Contains(err.Error(), tt.want) {
t.Fatalf("error %q does not contain %q", err.Error(), tt.want)
}
})
}
}
func TestBuildKimiInlineConfig(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://127.0.0.1:11434")
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
if err != nil {
t.Fatalf("buildKimiInlineConfig() error = %v", err)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
t.Fatalf("config is not valid JSON: %v", err)
}
if parsed["default_model"] != "ollama" {
t.Fatalf("default_model = %v, want ollama", parsed["default_model"])
}
providers, ok := parsed["providers"].(map[string]any)
if !ok {
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
}
ollamaProvider, ok := providers["ollama"].(map[string]any)
if !ok {
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
}
if ollamaProvider["type"] != "openai_legacy" {
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
}
if ollamaProvider["base_url"] != "http://127.0.0.1:11434/v1" {
t.Fatalf("provider base_url = %v, want http://127.0.0.1:11434/v1", ollamaProvider["base_url"])
}
if ollamaProvider["api_key"] != "ollama" {
t.Fatalf("provider api_key = %v, want ollama", ollamaProvider["api_key"])
}
models, ok := parsed["models"].(map[string]any)
if !ok {
t.Fatalf("models missing or wrong type: %T", parsed["models"])
}
ollamaModel, ok := models["ollama"].(map[string]any)
if !ok {
t.Fatalf("models.ollama missing or wrong type: %T", models["ollama"])
}
if ollamaModel["provider"] != "ollama" {
t.Fatalf("model provider = %v, want ollama", ollamaModel["provider"])
}
if ollamaModel["model"] != "llama3.2" {
t.Fatalf("model model = %v, want llama3.2", ollamaModel["model"])
}
if ollamaModel["max_context_size"] != float64(65536) {
t.Fatalf("model max_context_size = %v, want 65536", ollamaModel["max_context_size"])
}
}
func TestBuildKimiInlineConfig_UsesConnectableHostForUnspecifiedBind(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://0.0.0.0:11434")
cfg, err := buildKimiInlineConfig("llama3.2", 65536)
if err != nil {
t.Fatalf("buildKimiInlineConfig() error = %v", err)
}
var parsed map[string]any
if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
t.Fatalf("config is not valid JSON: %v", err)
}
providers, ok := parsed["providers"].(map[string]any)
if !ok {
t.Fatalf("providers missing or wrong type: %T", parsed["providers"])
}
ollamaProvider, ok := providers["ollama"].(map[string]any)
if !ok {
t.Fatalf("providers.ollama missing or wrong type: %T", providers["ollama"])
}
if got, _ := ollamaProvider["base_url"].(string); got != "http://127.0.0.1:11434/v1" {
t.Fatalf("provider base_url = %q, want %q", got, "http://127.0.0.1:11434/v1")
}
}
func TestResolveKimiMaxContextSize(t *testing.T) {
t.Run("uses cloud limit when known", func(t *testing.T) {
got := resolveKimiMaxContextSize("kimi-k2.5:cloud")
if got != 262_144 {
t.Fatalf("resolveKimiMaxContextSize() = %d, want 262144", got)
}
})
t.Run("uses model show context length for local models", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/show" {
http.NotFound(w, r)
return
}
fmt.Fprint(w, `{"model_info":{"llama.context_length":131072}}`)
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
got := resolveKimiMaxContextSize("llama3.2")
if got != 131_072 {
t.Fatalf("resolveKimiMaxContextSize() = %d, want 131072", got)
}
})
t.Run("falls back to default when show fails", func(t *testing.T) {
srv := httptest.NewServer(http.NotFoundHandler())
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
oldTimeout := kimiModelShowTimeout
kimiModelShowTimeout = 100 * 1000 * 1000 // 100ms
t.Cleanup(func() { kimiModelShowTimeout = oldTimeout })
got := resolveKimiMaxContextSize("llama3.2")
if got != kimiDefaultMaxContextSize {
t.Fatalf("resolveKimiMaxContextSize() = %d, want %d", got, kimiDefaultMaxContextSize)
}
})
}
func TestKimiRun_RejectsConflictingArgsBeforeInstall(t *testing.T) {
k := &Kimi{}
oldConfirm := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
t.Fatalf("did not expect install prompt, got %q", prompt)
return false, nil
}
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
err := k.Run("llama3.2", []string{"--model", "other"})
if err == nil || !strings.Contains(err.Error(), "--model") {
t.Fatalf("expected conflict error mentioning --model, got %v", err)
}
}
func TestKimiRun_PassesInlineConfigAndExtraArgs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binary")
}
tmpDir := t.TempDir()
setTestHome(t, tmpDir)
logPath := filepath.Join(tmpDir, "kimi-args.log")
script := fmt.Sprintf(`#!/bin/sh
for arg in "$@"; do
printf "%%s\n" "$arg" >> %q
done
exit 0
`, logPath)
if err := os.WriteFile(filepath.Join(tmpDir, "kimi"), []byte(script), 0o755); err != nil {
t.Fatalf("failed to write fake kimi: %v", err)
}
t.Setenv("PATH", tmpDir)
srv := httptest.NewServer(http.NotFoundHandler())
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
k := &Kimi{}
if err := k.Run("llama3.2", []string{"--quiet", "--print"}); err != nil {
t.Fatalf("Run() error = %v", err)
}
data, err := os.ReadFile(logPath)
if err != nil {
t.Fatalf("failed to read args log: %v", err)
}
lines := strings.Split(strings.TrimSpace(string(data)), "\n")
if len(lines) < 4 {
t.Fatalf("expected at least 4 args, got %v", lines)
}
if lines[0] != "--config" {
t.Fatalf("first arg = %q, want --config", lines[0])
}
var cfg map[string]any
if err := json.Unmarshal([]byte(lines[1]), &cfg); err != nil {
t.Fatalf("config arg is not valid JSON: %v", err)
}
providers := cfg["providers"].(map[string]any)
ollamaProvider := providers["ollama"].(map[string]any)
if ollamaProvider["type"] != "openai_legacy" {
t.Fatalf("provider type = %v, want openai_legacy", ollamaProvider["type"])
}
if lines[2] != "--quiet" || lines[3] != "--print" {
t.Fatalf("extra args = %v, want [--quiet --print]", lines[2:])
}
}
func TestEnsureKimiInstalled(t *testing.T) {
oldGOOS := kimiGOOS
t.Cleanup(func() { kimiGOOS = oldGOOS })
withConfirm := func(t *testing.T, fn func(prompt string) (bool, error)) {
t.Helper()
oldConfirm := DefaultConfirmPrompt
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return fn(prompt)
}
t.Cleanup(func() { DefaultConfirmPrompt = oldConfirm })
}
t.Run("already installed", func(t *testing.T) {
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
writeFakeBinary(t, tmpDir, "kimi")
kimiGOOS = runtime.GOOS
withConfirm(t, func(prompt string) (bool, error) {
t.Fatalf("did not expect prompt, got %q", prompt)
return false, nil
})
bin, err := ensureKimiInstalled()
if err != nil {
t.Fatalf("ensureKimiInstalled() error = %v", err)
}
assertKimiBinPath(t, bin)
})
t.Run("missing dependencies", func(t *testing.T) {
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
withConfirm(t, func(prompt string) (bool, error) {
t.Fatalf("did not expect prompt, got %q", prompt)
return false, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "required dependencies are missing") {
t.Fatalf("expected missing dependency error, got %v", err)
}
})
t.Run("missing and user declines install", func(t *testing.T) {
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
writeFakeBinary(t, tmpDir, "curl")
writeFakeBinary(t, tmpDir, "bash")
kimiGOOS = "linux"
withConfirm(t, func(prompt string) (bool, error) {
if !strings.Contains(prompt, "Kimi is not installed.") {
t.Fatalf("unexpected prompt: %q", prompt)
}
return false, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "installation cancelled") {
t.Fatalf("expected cancellation error, got %v", err)
}
})
t.Run("missing and user confirms install succeeds", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
writeFakeBinary(t, tmpDir, "curl")
installLog := filepath.Join(tmpDir, "bash.log")
kimiPath := filepath.Join(tmpDir, "kimi")
bashScript := fmt.Sprintf(`#!/bin/sh
echo "$@" >> %q
if [ "$1" = "-c" ]; then
/bin/cat > %q <<'EOS'
#!/bin/sh
exit 0
EOS
/bin/chmod +x %q
fi
exit 0
`, installLog, kimiPath, kimiPath)
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte(bashScript), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
bin, err := ensureKimiInstalled()
if err != nil {
t.Fatalf("ensureKimiInstalled() error = %v", err)
}
assertKimiBinPath(t, bin)
logData, err := os.ReadFile(installLog)
if err != nil {
t.Fatalf("failed to read install log: %v", err)
}
if !strings.Contains(string(logData), "https://code.kimi.com/install.sh") {
t.Fatalf("expected install.sh command in log, got:\n%s", string(logData))
}
})
t.Run("install succeeds and kimi is in home local bin without PATH update", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
homeDir := t.TempDir()
setTestHome(t, homeDir)
tmpBin := t.TempDir()
t.Setenv("PATH", tmpBin)
kimiGOOS = "linux"
writeFakeBinary(t, tmpBin, "curl")
installedKimi := filepath.Join(homeDir, ".local", "bin", "kimi")
bashScript := fmt.Sprintf(`#!/bin/sh
if [ "$1" = "-c" ]; then
/bin/mkdir -p %q
/bin/cat > %q <<'EOS'
#!/bin/sh
exit 0
EOS
/bin/chmod +x %q
fi
exit 0
`, filepath.Dir(installedKimi), installedKimi, installedKimi)
if err := os.WriteFile(filepath.Join(tmpBin, "bash"), []byte(bashScript), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
bin, err := ensureKimiInstalled()
if err != nil {
t.Fatalf("ensureKimiInstalled() error = %v", err)
}
if bin != installedKimi {
t.Fatalf("bin = %q, want %q", bin, installedKimi)
}
})
t.Run("install command fails", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
writeFakeBinary(t, tmpDir, "curl")
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 1\n"), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "failed to install kimi") {
t.Fatalf("expected install failure error, got %v", err)
}
})
t.Run("install succeeds but binary missing on PATH", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("uses POSIX shell fake binaries")
}
setTestHome(t, t.TempDir())
tmpDir := t.TempDir()
t.Setenv("PATH", tmpDir)
kimiGOOS = "linux"
writeFakeBinary(t, tmpDir, "curl")
if err := os.WriteFile(filepath.Join(tmpDir, "bash"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatalf("failed to write fake bash: %v", err)
}
withConfirm(t, func(prompt string) (bool, error) {
return true, nil
})
_, err := ensureKimiInstalled()
if err == nil || !strings.Contains(err.Error(), "binary was not found on PATH") {
t.Fatalf("expected PATH guidance error, got %v", err)
}
})
}
func TestKimiInstallerCommand(t *testing.T) {
tests := []struct {
name string
goos string
wantBin string
wantParts []string
wantErr bool
}{
{
name: "linux",
goos: "linux",
wantBin: "bash",
wantParts: []string{"-c", "install.sh"},
},
{
name: "darwin",
goos: "darwin",
wantBin: "bash",
wantParts: []string{"-c", "install.sh"},
},
{
name: "windows",
goos: "windows",
wantBin: "powershell",
wantParts: []string{"-Command", "install.ps1"},
},
{
name: "unsupported",
goos: "freebsd",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bin, args, err := kimiInstallerCommand(tt.goos)
if tt.wantErr {
if err == nil {
t.Fatal("expected error")
}
return
}
if err != nil {
t.Fatalf("kimiInstallerCommand() error = %v", err)
}
if bin != tt.wantBin {
t.Fatalf("bin = %q, want %q", bin, tt.wantBin)
}
joined := strings.Join(args, " ")
for _, part := range tt.wantParts {
if !strings.Contains(joined, part) {
t.Fatalf("args %q missing %q", joined, part)
}
}
})
}
}

View File

@@ -209,6 +209,7 @@ Supported integrations:
copilot Copilot CLI (aliases: copilot-cli)
droid Droid
hermes Hermes Agent
kimi Kimi Code CLI
opencode OpenCode
openclaw OpenClaw (aliases: clawdbot, moltbot)
pi Pi
@@ -587,7 +588,7 @@ func (c *launcherClient) launchManagedSingleIntegration(ctx context.Context, nam
return nil
}
if current == "" || needsConfigure || req.ModelOverride != "" || target != current {
if (current == "" || needsConfigure || req.ModelOverride != "" || target != current) && !savedMatchesModels(saved, []string{target}) {
if err := prepareManagedSingleIntegration(name, runner, managed, target); err != nil {
return err
}

View File

@@ -409,7 +409,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigOnlySkipsFinalRun(t *te
}
}
func TestLaunchIntegration_ManagedSingleIntegrationRepairsMissingLiveConfigUsingSavedModel(t *testing.T) {
func TestLaunchIntegration_ManagedSingleIntegrationSkipsRewriteWhenSavedMatches(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
@@ -436,29 +436,30 @@ func TestLaunchIntegration_ManagedSingleIntegrationRepairsMissingLiveConfigUsing
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
t.Fatal("selector should not be called when saved model is reused for repair")
t.Fatal("selector should not be called when saved model matches target")
return "", nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
return true, nil
t.Fatal("confirm prompt should not run when saved model matches target")
return false, nil
}
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{Name: "stubmanaged"}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
t.Fatalf("expected missing live config to be rewritten from saved model: %s", diff)
if len(runner.configured) != 0 {
t.Fatalf("expected Configure to be skipped when saved matches, got %v", runner.configured)
}
if runner.refreshCalls != 1 {
t.Fatalf("expected repaired config to refresh runtime once, got %d", runner.refreshCalls)
if runner.refreshCalls != 0 {
t.Fatalf("expected no runtime refresh when config is unchanged, got %d", runner.refreshCalls)
}
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to use repaired saved model, got %q", runner.ranModel)
t.Fatalf("expected launch to run saved model, got %q", runner.ranModel)
}
}
func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLiveConfig(t *testing.T) {
func TestLaunchIntegration_ManagedSingleIntegrationRewritesWhenSavedDiffers(t *testing.T) {
tmpDir := t.TempDir()
setLaunchTestHome(t, tmpDir)
withInteractiveSession(t, true)
@@ -466,6 +467,8 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLi
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/tags":
fmt.Fprint(w, `{"models":[{"name":"gemma4"}]}`)
case "/api/show":
fmt.Fprint(w, `{"model_info":{"general.context_length":131072}}`)
default:
@@ -475,7 +478,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLi
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
if err := config.SaveIntegration("stubmanaged", []string{"gemma4"}); err != nil {
if err := config.SaveIntegration("stubmanaged", []string{"old-model"}); err != nil {
t.Fatalf("failed to save managed integration config: %v", err)
}
@@ -483,7 +486,7 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLi
withIntegrationOverride(t, "stubmanaged", runner)
DefaultSingleSelector = func(title string, items []ModelItem, current string) (string, error) {
t.Fatal("selector should not be called when saved model is reused for repair")
t.Fatal("selector should not be called when model override is provided")
return "", nil
}
DefaultConfirmPrompt = func(prompt string, options ConfirmOptions) (bool, error) {
@@ -492,19 +495,19 @@ func TestLaunchIntegration_ManagedSingleIntegrationConfigureOnlyRepairsMissingLi
if err := LaunchIntegration(context.Background(), IntegrationLaunchRequest{
Name: "stubmanaged",
ConfigureOnly: true,
ModelOverride: "gemma4",
}); err != nil {
t.Fatalf("LaunchIntegration returned error: %v", err)
}
if diff := compareStrings(runner.configured, []string{"gemma4"}); diff != "" {
t.Fatalf("expected configure-only flow to rewrite missing live config: %s", diff)
t.Fatalf("expected Configure to run when saved differs from target: %s", diff)
}
if runner.refreshCalls != 1 {
t.Fatalf("expected configure-only repair to refresh runtime once, got %d", runner.refreshCalls)
t.Fatalf("expected runtime refresh once after configure, got %d", runner.refreshCalls)
}
if runner.ranModel != "" {
t.Fatalf("expected configure-only flow to skip final launch, got %q", runner.ranModel)
if runner.ranModel != "gemma4" {
t.Fatalf("expected launch to run configured model, got %q", runner.ranModel)
}
}

View File

@@ -74,6 +74,23 @@ var integrationSpecs = []*IntegrationSpec{
Command: []string{"npm", "install", "-g", "@openai/codex"},
},
},
{
Name: "kimi",
Runner: &Kimi{},
Description: "Moonshot's coding agent for terminal and IDEs",
Hidden: true,
Install: IntegrationInstallSpec{
CheckInstalled: func() bool {
_, err := exec.LookPath("kimi")
return err == nil
},
EnsureInstalled: func() error {
_, err := ensureKimiInstalled()
return err
},
URL: "https://moonshotai.github.io/kimi-cli/en/guides/getting-started.html",
},
},
{
Name: "copilot",
Runner: &Copilot{},

View File

@@ -45,6 +45,14 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
return filepath.Join(home, ".pi", "agent", "models.json")
},
},
{
name: "kimi",
binary: "kimi",
runner: &Kimi{},
checkPath: func(home string) string {
return filepath.Join(home, ".kimi", "config.toml")
},
},
}
for _, tt := range tests {
@@ -57,6 +65,10 @@ func TestEditorRunsDoNotRewriteConfig(t *testing.T) {
if tt.name == "pi" {
writeFakeBinary(t, binDir, "npm")
}
if tt.name == "kimi" {
writeFakeBinary(t, binDir, "curl")
writeFakeBinary(t, binDir, "bash")
}
t.Setenv("PATH", binDir)
configPath := tt.checkPath(home)

BIN
docs/images/hermes.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

View File

@@ -2,7 +2,9 @@
title: Hermes Agent
---
Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation, cross-session memory, and connects messaging platforms (Telegram, Discord, Slack, WhatsApp, Signal, Email) to models through a unified gateway.
Hermes Agent is a self-improving AI agent built by Nous Research. It features automatic skill creation, cross-session memory, and 70+ skills that it ships with by default.
![Hermes Agent with Ollama](/images/hermes.png)
## Quick start
@@ -10,25 +12,56 @@ Hermes Agent is a self-improving AI agent built by Nous Research. It features au
ollama launch hermes
```
### Pull a model
Ollama handles everything automatically:
Before running the setup wizard, make sure you have a model available. Hermes will auto-detect models downloaded through Ollama.
1. **Install** — If Hermes isn't installed, Ollama prompts to install it via the Nous Research install script
2. **Model** — Pick a model from the selector (local or cloud)
3. **Onboarding** — Ollama configures the Ollama provider, points Hermes at `http://127.0.0.1:11434/v1`, and sets your model as the primary
4. **Gateway** — Optionally connects a messaging platform (Telegram, Discord, Slack, WhatsApp, Signal, Email) and launches the Hermes chat
<Note>Hermes on Windows requires WSL2. Install it with `wsl --install` and re-run from inside the WSL shell.</Note>
## Recommended models
**Cloud models**:
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
- `glm-5.1:cloud` — Reasoning and code generation
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
**Local models:**
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
- `qwen3.6` — Reasoning, coding, and visual understanding locally (~24 GB VRAM)
More models at [ollama.com/search](https://ollama.com/search?c=cloud).
## Connect messaging apps
Link Telegram, Discord, Slack, WhatsApp, Signal, or Email to chat with your models from anywhere:
```bash
ollama pull kimi-k2.5:cloud
hermes gateway setup
```
See [Recommended models](#recommended-models) for more options.
## Reconfigure
### Install
Re-run the full setup wizard at any time:
```bash
hermes setup
```
## Manual setup
If you'd rather drive Hermes's own wizard instead of `ollama launch hermes`, install it directly:
```bash
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
```
### Set up
After installation, Hermes launches the setup wizard automatically. Choose **Quick setup**:
Hermes launches the setup wizard automatically. Choose **Quick setup**:
```
How would you like to set up Hermes?
@@ -84,32 +117,3 @@ Connect a messaging platform? (Telegram, Discord, etc.)
Launch hermes chat now? [Y/n]: Y
```
## Recommended models
**Cloud models**:
- `kimi-k2.5:cloud` — Multimodal reasoning with subagents
- `qwen3.5:cloud` — Reasoning, coding, and agentic tool use with vision
- `glm-5.1:cloud` — Reasoning and code generation
- `minimax-m2.7:cloud` — Fast, efficient coding and real-world productivity
**Local models:**
- `gemma4` — Reasoning and code generation locally (~16 GB VRAM)
- `qwen3.5` — Reasoning, coding, and visual understanding locally (~11 GB VRAM)
More models at [ollama.com/search](https://ollama.com/models).
## Configure later
Re-run the setup wizard at any time:
```bash
hermes setup
```
To configure just messaging:
```bash
hermes setup gateway
```

View File

@@ -406,10 +406,6 @@ func TestAPIShowModel(t *testing.T) {
}
func TestAPIGenerateLogprobs(t *testing.T) {
if testModel != "" {
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
t.Skip("logprobs not supported by all runners")
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
@@ -523,10 +519,6 @@ func TestAPIGenerateLogprobs(t *testing.T) {
}
func TestAPIChatLogprobs(t *testing.T) {
if testModel != "" {
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
t.Skip("logprobs not supported by all runners")
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"os"
"time"
)
type Layer struct {
@@ -60,6 +61,9 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
return Layer{}, err
}
}
if err := touchLayer(blob); err != nil {
return Layer{}, err
}
return Layer{
MediaType: mediatype,
@@ -83,6 +87,9 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
if err != nil {
return Layer{}, err
}
if err := touchLayer(blob); err != nil {
return Layer{}, err
}
return Layer{
MediaType: mediatype,
@@ -93,6 +100,11 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
}, nil
}
func touchLayer(path string) error {
now := time.Now()
return os.Chtimes(path, now, now)
}
func (l *Layer) Open() (io.ReadSeekCloser, error) {
if l.Digest == "" {
return nil, errors.New("opening layer with empty digest")

View File

@@ -19,6 +19,7 @@ import (
"slices"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
@@ -33,6 +34,10 @@ import (
"github.com/ollama/ollama/x/imagegen/transfer"
)
// Blobs newer than this may belong to another process that has not written its
// manifest yet. They become eligible for the normal mark-and-sweep pass later.
const layerPruneGracePeriod = time.Hour
var (
errCapabilities = errors.New("does not support")
errCapabilityCompletion = errors.New("completion")
@@ -478,10 +483,23 @@ func PruneLayers() error {
}
for _, blob := range blobs {
if blob.IsDir() {
continue
}
info, err := blob.Info()
if err != nil {
slog.Error("couldn't stat blob", "blob", blob.Name(), "error", err)
continue
}
if time.Since(info.ModTime()) < layerPruneGracePeriod {
continue
}
name := blob.Name()
name = strings.ReplaceAll(name, "-", ":")
_, err := manifest.BlobsPath(name)
_, err = manifest.BlobsPath(name)
if err != nil {
if errors.Is(err, manifest.ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads)

View File

@@ -5,14 +5,58 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
func TestPruneLayersSkipsRecentOrphans(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
recentDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000001"
oldDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000002"
for _, digest := range []string{recentDigest, oldDigest} {
p, err := manifest.BlobsPath(digest)
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(p, nil, 0o644); err != nil {
t.Fatal(err)
}
}
oldPath, err := manifest.BlobsPath(oldDigest)
if err != nil {
t.Fatal(err)
}
oldTime := time.Now().Add(-layerPruneGracePeriod - time.Hour)
if err := os.Chtimes(oldPath, oldTime, oldTime); err != nil {
t.Fatal(err)
}
if err := PruneLayers(); err != nil {
t.Fatal(err)
}
recentPath, err := manifest.BlobsPath(recentDigest)
if err != nil {
t.Fatal(err)
}
if _, err := os.Stat(recentPath); err != nil {
t.Fatalf("recent orphan was pruned: %v", err)
}
if _, err := os.Stat(oldPath); !os.IsNotExist(err) {
t.Fatalf("old orphan still exists: %v", err)
}
}
func TestModelCapabilities(t *testing.T) {
// Create completion model (llama architecture without vision)
completionModelPath, _ := createBinFile(t, ggml.KV{

View File

@@ -2408,7 +2408,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
// current approach uses the transition from parsed thinking content to
// parsed non-thinking content as the signal to turn constraining on
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
// TODO(parthsareen): temporary fix for https://github.com/ollama/ollama/issues/15260.
// To revisit for other models and have a consistent pattern across models through parsers.
forceImmediate := m.Config.Parser == "gemma4" && req.Think != nil && !req.Think.Bool()
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && !forceImmediate && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
currentFormat = nil
}

View File

@@ -2108,6 +2108,132 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
})
}
// TestChatFormatWithThinkFalse verifies that when a model uses a builtin
// parser that supports thinking (e.g. gemma4) and the request explicitly
// disables thinking (think=false), the format constraint is passed to the
// first and only completion call. Previously, format was deferred for all
// thinking-capable parsers and only re-applied after an end-of-thinking
// transition — a transition that never fires when thinking is off. See
// https://github.com/ollama/ollama/issues/15260.
func TestChatFormatWithThinkFalse(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := &mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := &Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{llama: mock}
return false
},
},
}
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
// Use the gemma4 builtin parser — it reports HasThinkingSupport=true, which
// adds CapabilityThinking to the model and previously triggered deferral of
// the format even when the user passed think=false.
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-gemma4-parser",
Files: map[string]string{"file.gguf": digest},
Parser: "gemma4",
Template: `{{- range .Messages }}{{ .Role }}: {{ .Content }}{{ end }}`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("create: expected status 200, got %d: %s", w.Code, w.Body.String())
}
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}},"required":["answer"]}`)
var (
requestsMu sync.Mutex
requests []llm.CompletionRequest
)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
requestsMu.Lock()
requests = append(requests, r)
requestsMu.Unlock()
fn(llm.CompletionResponse{
Content: `{"answer":"42"}`,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
})
return nil
}
streamRequest := false
think := false
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-gemma4-parser",
Messages: []api.Message{{Role: "user", Content: "Respond in JSON."}},
Think: &api.ThinkValue{Value: think},
Stream: &streamRequest,
Format: format,
})
if w.Code != http.StatusOK {
t.Fatalf("chat: expected status 200, got %d: %s", w.Code, w.Body.String())
}
if len(requests) != 1 {
t.Fatalf("expected a single completion call, got %d", len(requests))
}
if !bytes.Equal([]byte(format), []byte(requests[0].Format)) {
t.Errorf("expected first completion format to match the request format, got %q", string(requests[0].Format))
}
}
func TestGenerateUnload(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -337,9 +337,10 @@ func (c *RotatingKVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return nil
}
liveLen := min(c.offset, c.keys.Dim(2))
return []*mlx.Array{
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
}
}

View File

@@ -151,20 +151,11 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
}
}
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
type completionRequest struct {
Prompt string `json:"prompt"`
Options *completionOpts `json:"options,omitempty"`
}
type completionOpts struct {
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
MinP float32 `json:"min_p,omitempty"`
TopK int `json:"top_k,omitempty"`
RepeatLastN int `json:"repeat_last_n,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
type CompletionRequest struct {
Prompt string
Options api.Options
Logprobs bool
TopLogprobs int
}
type CompletionResponse struct {
@@ -177,6 +168,8 @@ type CompletionResponse struct {
EvalCount int
EvalDuration time.Duration
Logprobs []llm.Logprob
Error *api.StatusError
}
@@ -201,19 +194,13 @@ func (c *Client) Close() error {
// Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
creq := completionRequest{
Prompt: req.Prompt,
creq := CompletionRequest{
Prompt: req.Prompt,
Logprobs: req.Logprobs,
TopLogprobs: req.TopLogprobs,
}
if req.Options != nil {
creq.Options = &completionOpts{
Temperature: req.Options.Temperature,
TopP: req.Options.TopP,
MinP: req.Options.MinP,
TopK: req.Options.TopK,
RepeatLastN: req.Options.RepeatLastN,
PresencePenalty: req.Options.PresencePenalty,
NumPredict: req.Options.NumPredict,
}
creq.Options = *req.Options
}
body, err := json.Marshal(creq)
@@ -239,7 +226,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%s", strings.TrimSpace(string(respBody)))
return api.StatusError{StatusCode: resp.StatusCode, ErrorMessage: strings.TrimSpace(string(respBody))}
}
scanner := bufio.NewScanner(resp.Body)
@@ -262,6 +249,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
PromptEvalDuration: raw.PromptEvalDuration,
EvalCount: raw.EvalCount,
EvalDuration: raw.EvalDuration,
Logprobs: raw.Logprobs,
}
fn(cresp)

View File

@@ -62,3 +62,25 @@ var LogitSoftcap = Compile2(
},
Shapeless(),
)
// sigmoidRouterFused traces the DeepSeek-V2 / GLM-MoE aux-loss-free router
// head. Two outputs are returned so the pre-bias sigmoid (used to gather
// per-expert scores after top-k) and the post-bias negation (used as the
// argpartition key for top-k) share a single kernel.
var sigmoidRouterFused = Compile(
"SigmoidRouter",
func(in ...*Array) []*Array {
gates, bias := in[0], in[1]
orig := gates.Sigmoid()
neg := orig.Add(bias).Negative()
return []*Array{orig, neg}
},
Shapeless(),
)
// SigmoidRouter returns (sigmoid(gates), -(sigmoid(gates)+bias)) as a fused
// kernel — the DeepSeek-V2 / GLM-MoE aux-loss-free router head.
func SigmoidRouter(gates, bias *Array) (origScores, negScores *Array) {
out := sigmoidRouterFused(gates, bias)
return out[0], out[1]
}

View File

@@ -10,6 +10,8 @@ import (
"reflect"
"sort"
"strings"
"sync"
"sync/atomic"
"unsafe"
"github.com/ollama/ollama/logutil"
@@ -18,20 +20,28 @@ import (
type Array struct {
ctx C.mlx_array
name string
pinned int
pinned atomic.Int32
}
var arrays []*Array
var (
arrays []*Array
arraysMu sync.Mutex
)
// constructor utilities
func New(name string) *Array {
t := &Array{name: name}
if tracing {
traceScratch = append(traceScratch, t)
} else {
arraysMu.Lock()
defer arraysMu.Unlock()
arrays = append(arrays, t)
}
return t
}
@@ -131,7 +141,7 @@ func (t *Array) Clone() *Array {
func Pin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned++
t.pinned.Add(1)
}
}
}
@@ -140,8 +150,7 @@ func Pin(s ...*Array) {
func Unpin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned--
if t.pinned < 0 {
if t.pinned.Add(-1) < 0 {
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
}
}
@@ -151,9 +160,11 @@ func Unpin(s ...*Array) {
// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
// free them when there are no other references, including dependencies in the graph.
func Sweep() {
arraysMu.Lock()
defer arraysMu.Unlock()
n := 0
for _, t := range arrays {
if t.pinned > 0 && t.Valid() {
if t.pinned.Load() > 0 && t.Valid() {
arrays[n] = t
n++
} else if t.Valid() {
@@ -180,7 +191,7 @@ func (t *Array) String() string {
func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{
slog.String("name", t.name),
slog.Int("pinned", t.pinned),
slog.Int("pinned", int(t.pinned.Load())),
}
if t.Valid() {
attrs = append(attrs,
@@ -194,19 +205,19 @@ func (t *Array) LogValue() slog.Value {
// shape utilities
func (t Array) Size() int {
func (t *Array) Size() int {
return int(C.mlx_array_size(t.ctx))
}
func (t Array) NumBytes() int {
func (t *Array) NumBytes() int {
return int(C.mlx_array_nbytes(t.ctx))
}
func (t Array) NumDims() int {
func (t *Array) NumDims() int {
return int(C.mlx_array_ndim(t.ctx))
}
func (t Array) Dims() []int {
func (t *Array) Dims() []int {
dims := make([]int, t.NumDims())
for i := range dims {
dims[i] = t.Dim(i)
@@ -215,29 +226,32 @@ func (t Array) Dims() []int {
return dims
}
func (t Array) Dim(dim int) int {
func (t *Array) Dim(dim int) int {
return int(C.mlx_array_dim(t.ctx, C.int(dim)))
}
func (t Array) DType() DType {
func (t *Array) DType() DType {
return DType(C.mlx_array_dtype(t.ctx))
}
// data utilities
func (t Array) Int() int {
func (t *Array) Int() int {
var item C.int64_t
C.mlx_array_item_int64(&item, t.ctx)
return int(item)
}
func (t Array) Float() float64 {
func (t *Array) Float() float64 {
var item C.double
C.mlx_array_item_float64(&item, t.ctx)
return float64(item)
}
func (t Array) Ints() []int {
func (t *Array) Ints() []int {
if dt := t.DType(); dt != DTypeInt32 {
panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt))
}
ints := make([]int, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
ints[i] = int(f)
@@ -245,7 +259,10 @@ func (t Array) Ints() []int {
return ints
}
func (t Array) Floats() []float32 {
func (t *Array) Floats() []float32 {
if dt := t.DType(); dt != DTypeFloat32 {
panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt))
}
floats := make([]float32, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
floats[i] = float32(f)
@@ -253,7 +270,7 @@ func (t Array) Floats() []float32 {
return floats
}
func (t Array) Save(name string) error {
func (t *Array) Save(name string) error {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
C.mlx_save(cName, t.ctx)
@@ -262,6 +279,8 @@ func (t Array) Save(name string) error {
// LogArrays logs all live arrays, sorted by size
func LogArrays() {
arraysMu.Lock()
defer arraysMu.Unlock()
sort.Slice(arrays, func(i, j int) bool {
return arrays[i].NumBytes() > arrays[j].NumBytes()
})
@@ -270,7 +289,7 @@ func LogArrays() {
for _, t := range arrays {
nb := t.NumBytes()
total += nb
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims()))
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned.Load(), t.Dims()))
}
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory())))
}

View File

@@ -150,7 +150,7 @@ func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload
traceScratch = nil
defer func() {
for _, a := range traceScratch {
if a.pinned > 0 {
if a.pinned.Load() > 0 {
panic("mlx: traced array was pinned during compilation")
}
if a.Valid() {

View File

@@ -24,8 +24,8 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *A
}
type LayerNorm struct {
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
Weight *Array `weight:"weight"`
Bias *Array `weight:"bias"`
}
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
@@ -35,10 +35,10 @@ func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
}
type RMSNorm struct {
Weight Array `weight:"weight"`
Weight *Array `weight:"weight"`
}
func (r RMSNorm) Forward(x *Array, eps float32) *Array {
func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
return out

View File

@@ -1,12 +1,12 @@
package mlx
type Linear struct {
Weight Array `weight:"weight"`
Bias Array `weight:"bias"`
Weight *Array `weight:"weight"`
Bias *Array `weight:"bias"`
}
// Forward computes the linear transformation: x @ Weight.T + Bias
func (m Linear) Forward(x *Array) *Array {
func (m *Linear) Forward(x *Array) *Array {
w := m.Weight.Transpose(1, 0)
if m.Bias.Valid() {
return m.Bias.Addmm(x, w, 1.0, 1.0)
@@ -15,14 +15,14 @@ func (m Linear) Forward(x *Array) *Array {
return x.Matmul(w)
}
func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
func (m *Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
w := m.Weight.Transpose(0, 2, 1)
// TODO: bias
return x.GatherMM(w, lhs, rhs, sorted)
}
type Embedding struct {
Weight Array `weight:"weight"`
Weight *Array `weight:"weight"`
}
func (e *Embedding) Forward(indices *Array) *Array {

View File

@@ -72,6 +72,10 @@ func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
}
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
if len(others) == 0 {
return t
}
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
@@ -127,9 +131,9 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
return out
}
func (t *Array) Logsumexp(keepDims bool) *Array {
out := New("LOGSUMEXP")
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
func (t *Array) LogsumexpAxis(axis int, keepDims bool) *Array {
out := New("LOGSUMEXP_AXIS")
C.mlx_logsumexp_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
@@ -139,6 +143,12 @@ func (t *Array) Less(other *Array) *Array {
return out
}
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
out := New("MAX_AXIS")
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL")
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
@@ -169,6 +179,12 @@ func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
return out
}
func (t *Array) ScatterAddAxis(indices, values *Array, axis int) *Array {
out := New("SCATTER_ADD_AXIS")
C.mlx_scatter_add_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Reshape(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i := range axes {

View File

@@ -376,6 +376,9 @@ func Concatenate(arrays []*Array, axis int) *Array {
if len(arrays) == 0 {
return nil
}
if len(arrays) == 1 {
return arrays[0]
}
return arrays[0].Concatenate(axis, arrays[1:]...)
}

View File

@@ -6,36 +6,60 @@ import (
"errors"
"fmt"
"log/slog"
"net/http"
"sort"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/tokenizer"
)
func prefillChunkSize() int {
return 2 << 10
}
func (r *Runner) TextGenerationPipeline(request Request) error {
// Prepare tokenizes the prompt and validates it against the model's
// context length. It is safe to call from any goroutine. On success it
// populates request.Tokens and adjusts request.Options.NumPredict.
func (r *Runner) Prepare(request *Request) error {
if r.Model == nil {
return errors.New("model not loaded")
}
tokens := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
if len(tokens) == 0 {
return errors.New("empty prompt")
}
if len(tokens) >= r.contextLength {
return fmt.Errorf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(tokens), r.contextLength)
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(tokens)
if request.Options.NumPredict <= 0 {
request.Options.NumPredict = maxGenerate
} else {
request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate)
}
request.Tokens = tokens
return nil
}
// The runner serializes requests today so we just use a fixed slot ID.
const pipelineSlot = 0
func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) error {
mlx.ResetPeakMemory()
ctx := request.Ctx
var (
sample, logprobs *mlx.Array
nextSample, nextLogprobs *mlx.Array
)
var sample, nextSample sampler.Result
defer func() {
if request.Sampler != nil {
request.Sampler.Free()
}
mlx.Unpin(sample, logprobs)
mlx.Unpin(nextSample, nextLogprobs)
r.Sampler.Remove(pipelineSlot)
mlx.Unpin(sample.Arrays()...)
mlx.Unpin(nextSample.Arrays()...)
mlx.Sweep()
mlx.ClearCache()
@@ -46,27 +70,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
}()
inputs := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
if len(inputs) == 0 {
return errors.New("empty prompt")
}
if len(inputs) >= r.contextLength {
return api.StatusError{
StatusCode: http.StatusBadRequest,
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
}
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(inputs)
if request.Options.MaxTokens <= 0 {
request.Options.MaxTokens = maxGenerate
} else {
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
}
request.Sampler.ResetHistory(inputs)
inputs := request.Tokens
session := r.cache.begin(r.Model, inputs)
defer session.close()
@@ -118,7 +122,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
}
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], 1, n), caches)
mlx.Sweep()
materializeCaches()
processed += n
@@ -135,41 +139,44 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
mlx.ClearCache()
}
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
fwd := r.Model.Forward(token.ExpandDims(0), caches)
// Register the sampler after prefill completes.
r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs)
step := func(token *mlx.Array) sampler.Result {
fwd := r.Model.Forward(token, caches)
logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
logprobs := logits.Subtract(logits.Logsumexp(true))
sample := request.Sampler.Sample(logprobs)
mlx.Pin(sample, logprobs)
sample := r.Sampler.Sample([]int{pipelineSlot}, logits)
mlx.Pin(sample.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(sample, logprobs)
return sample, logprobs
mlx.AsyncEval(sample.Arrays()...)
return sample
}
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
sample = step(mlx.FromValues(tokens[processed:], 1, total-processed))
var b bytes.Buffer
dec := decoder{
tokenizer: r.Tokenizer,
wantLogprobs: request.SamplerOpts.Logprobs,
wantTopLogprobs: request.SamplerOpts.TopLogprobs,
}
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1}
for i := range request.Options.MaxTokens {
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1}
for i := range request.Options.NumPredict {
if err := ctx.Err(); err != nil {
return err
}
request.Sampler.AppendToken(sample)
nextSample, nextLogprobs = step(sample)
nextSample = step(sample.Token.ExpandDims(-1))
if i == 0 {
mlx.Eval(sample)
mlx.Eval(sample.Arrays()...)
final.PromptEvalDuration = time.Since(now)
now = time.Now()
}
output := int32(sample.Int())
output := int32(sample.Token.Int())
session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) {
@@ -178,17 +185,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
break
}
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- CompletionResponse{
Content: r.Decode(output, &b),
}:
if resp, ok := dec.decode(sample); ok {
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- resp:
}
}
mlx.Unpin(sample, logprobs)
sample, logprobs = nextSample, nextLogprobs
nextSample, nextLogprobs = nil, nil
mlx.Unpin(sample.Arrays()...)
sample, nextSample = nextSample, sampler.Result{}
if i%256 == 0 {
mlx.ClearCache()
@@ -204,13 +210,69 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
}
}
func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
token := r.Tokenizer.Decode([]int32{sample})
// decoder serializes sampled tokens into response chunks, holding bytes
// whose UTF-8 sequence hasn't completed yet and the logprobs that belong
// with those bytes so Content and Logprobs stay aligned when a chunk does
// flush.
type decoder struct {
tokenizer *tokenizer.Tokenizer
buf bytes.Buffer
logprobs []llm.Logprob
wantLogprobs bool
wantTopLogprobs int
}
if _, err := b.WriteString(token); err != nil {
slog.Error("Failed to write token to buffer", "error", err)
return ""
func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
output := int32(res.Token.Int())
d.buf.WriteString(d.tokenizer.Decode([]int32{output}))
d.logprobs = append(d.logprobs, buildLogprob(res, d.wantLogprobs, d.wantTopLogprobs, d.tokenizer.Decode)...)
content := flushValidUTF8Prefix(&d.buf)
if content == "" {
return CompletionResponse{}, false
}
resp := CompletionResponse{Content: content, Logprobs: d.logprobs}
d.logprobs = nil
return resp, true
}
// buildLogprob converts the sampler's logprob tensors into the wire-format
// llm.Logprob entries the caller wants. The sampler populates its logprob
// tensors whenever any registered slot requested them, so the caller must
// gate emission on its own request config (wantLogprobs / wantTopLogprobs)
// rather than on whether the tensors happen to be non-nil.
func buildLogprob(sample sampler.Result, wantLogprobs bool, wantTopLogprobs int, decode func([]int32) string) []llm.Logprob {
if !wantLogprobs || sample.Logprob == nil {
return nil
}
tok := func(id int32) string { return decode([]int32{id}) }
out := llm.Logprob{
TokenLogprob: llm.TokenLogprob{
Token: tok(int32(sample.Token.Int())),
Logprob: float64(sample.Logprob.Floats()[0]),
},
}
return flushValidUTF8Prefix(b)
if wantTopLogprobs > 0 && sample.TopTokens != nil {
ids := sample.TopTokens.Ints()
vals := sample.TopLogprobs.Floats()
pairs := make([]llm.TokenLogprob, len(ids))
for i, id := range ids {
pairs[i] = llm.TokenLogprob{
Token: tok(int32(id)),
Logprob: float64(vals[i]),
}
}
// The sampler emits the top maxK across registered slots via
// Argpartition, which leaves entries unsorted.
sort.Slice(pairs, func(i, j int) bool {
return pairs[i].Logprob > pairs[j].Logprob
})
if wantTopLogprobs < len(pairs) {
pairs = pairs[:wantTopLogprobs]
}
out.TopLogprobs = pairs
}
return []llm.Logprob{out}
}

View File

@@ -18,36 +18,25 @@ import (
"github.com/ollama/ollama/x/tokenizer"
)
// Request is a short-lived struct that carries a completion request through
// a channel from the HTTP handler to the runner goroutine. The ctx field
// must travel with the request so that cancellation propagates across the
// channel boundary.
type Request struct {
TextCompletionsRequest
CompletionRequest
Responses chan CompletionResponse
Pipeline func(Request) error
Pipeline func(context.Context, Request) error
Ctx context.Context
Sampler *sample.Sampler
}
type TextCompletionsRequest struct {
Prompt string `json:"prompt"`
Options struct {
Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TopK int `json:"top_k"`
RepeatLastN int `json:"repeat_last_n"`
PresencePenalty float32 `json:"presence_penalty"`
MaxTokens int `json:"max_tokens"`
// Deprecated: use MaxTokens instead
NumPredict int `json:"num_predict"`
} `json:"options"`
Ctx context.Context //nolint:containedctx
Tokens []int32
SamplerOpts sample.Options
}
type Runner struct {
Model base.Model
Tokenizer *tokenizer.Tokenizer
Requests chan Request
Sampler *sample.Sampler
cache kvCache
contextLength int
}
@@ -79,6 +68,7 @@ func (r *Runner) Load(modelName string) error {
r.Model = m
r.Tokenizer = m.Tokenizer()
r.contextLength = m.MaxContextLength()
r.Sampler = sample.New(r.contextLength)
mlx.EnableCompile()
return nil
@@ -147,7 +137,7 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
case <-ctx.Done():
return nil
case request := <-r.Requests:
if err := request.Pipeline(request); err != nil {
if err := request.Pipeline(request.Ctx, request); err != nil {
slog.Info("Request terminated", "error", err)
var statusErr api.StatusError
if !errors.As(err, &statusErr) {

View File

@@ -0,0 +1,286 @@
//go:build mlx
package sample
import (
"math"
"sort"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// logprobEntry is the (token id, logprob) pair returned by the sampler's
// top-K extraction, used after the test-side descending sort.
type logprobEntry struct {
id int
logprob float64
}
// runSampleLogprobs drives Sample on a fresh Sampler configured for logprobs
// and returns the greedily-sampled token id, its logprob, and the top-K
// entries sorted descending by logprob. Logits must be a [vocab]-shaped
// slice; the helper reshapes it to [1, vocab] before calling the sampler.
func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, []logprobEntry) {
t.Helper()
s := New(128)
defer func() {
s.Free()
mlx.Sweep()
}()
s.Add(0, Options{Logprobs: true, TopLogprobs: topK}, nil)
tensor := mlx.FromValues(logits, 1, len(logits))
res := s.Sample([]int{0}, tensor)
mlx.Pin(res.Arrays()...)
defer mlx.Unpin(res.Arrays()...)
mlx.Sweep()
mlx.Eval(res.Arrays()...)
selected := res.Token.Int()
selLP := float64(res.Logprob.Floats()[0])
var top []logprobEntry
if topK > 0 && res.TopTokens != nil {
ids := res.TopTokens.Ints()
vals := res.TopLogprobs.Floats()
top = make([]logprobEntry, len(ids))
for i, id := range ids {
top[i] = logprobEntry{id: id, logprob: float64(vals[i])}
}
sort.Slice(top, func(i, j int) bool { return top[i].logprob > top[j].logprob })
}
return selected, selLP, top
}
func TestSampleLogprobsBasic(t *testing.T) {
tests := []struct {
name string
logits []float32
topK int
wantSelectedID int
wantTopLen int
}{
{
name: "single token without top logprobs",
logits: []float32{1.0, 0.5, 0.3, 0.1},
topK: 0,
wantSelectedID: 0,
wantTopLen: 0,
},
{
name: "single token with top logprobs",
logits: []float32{1.0, 0.5, 0.3, 0.1},
topK: 3,
wantSelectedID: 0,
wantTopLen: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
selected, _, top := runSampleLogprobs(t, tt.logits, tt.topK)
if selected != tt.wantSelectedID {
t.Errorf("selected = %d, want %d", selected, tt.wantSelectedID)
}
if len(top) != tt.wantTopLen {
t.Errorf("top-K length = %d, want %d", len(top), tt.wantTopLen)
}
})
}
}
func TestSampleLogprobsNumericalStability(t *testing.T) {
logits := []float32{1000.0, 999.0, 998.0}
_, selLP, top := runSampleLogprobs(t, logits, 3)
if math.IsInf(selLP, 0) || math.IsNaN(selLP) {
t.Errorf("selected logprob is not finite: %f", selLP)
}
for i, e := range top {
if math.IsInf(e.logprob, 0) || math.IsNaN(e.logprob) {
t.Errorf("top[%d] logprob is not finite: %f", i, e.logprob)
}
}
for i := 1; i < len(top); i++ {
if top[i].logprob > top[i-1].logprob {
t.Errorf("top logprobs not descending: %f > %f", top[i].logprob, top[i-1].logprob)
}
}
}
func TestSampleLogprobsProbabilityCorrectness(t *testing.T) {
tests := []struct {
name string
logits []float32
}{
{"uniform", []float32{1.0, 1.0, 1.0, 1.0}},
{"different", []float32{2.0, 1.0, 0.5, 0.1}},
{"negative", []float32{-1.0, -2.0, -3.0, -4.0}},
{"mixed", []float32{5.0, -5.0, 0.0, 2.5}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
selected, selLP, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
if selLP > 0 {
t.Errorf("selected logprob should be <= 0, got %f", selLP)
}
for i, e := range top {
if e.logprob > 0 {
t.Errorf("top[%d] logprob should be <= 0, got %f", i, e.logprob)
}
}
if tt.name == "uniform" {
want := 1.0 / float64(len(tt.logits))
got := math.Exp(selLP)
if math.Abs(got-want) > 1e-6 {
t.Errorf("uniform logits: selected prob = %f, want %f", got, want)
}
}
for i := 1; i < len(top); i++ {
if top[i].logprob > top[i-1].logprob {
t.Errorf("top logprobs not descending at %d: %f > %f",
i, top[i].logprob, top[i-1].logprob)
}
}
found := false
for _, e := range top {
if e.id == selected {
found = true
if math.Abs(e.logprob-selLP) > 1e-6 {
t.Errorf("selected logprob mismatch: selLP=%f top=%f", selLP, e.logprob)
}
break
}
}
if !found {
t.Errorf("selected token %d not present in top-K", selected)
}
})
}
}
func TestSampleLogprobsSoftmaxCorrectness(t *testing.T) {
tests := []struct {
name string
logits []float32
}{
{"small vocabulary", []float32{1.0, 2.0, 3.0}},
{"large differences", []float32{10.0, 0.0, -10.0}},
{"all equal", []float32{5.0, 5.0, 5.0, 5.0, 5.0}},
{"very large values", []float32{500.0, 499.0, 498.0}},
{"very small values", []float32{-500.0, -499.0, -498.0}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, _, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
if len(top) != len(tt.logits) {
t.Fatalf("top-K length = %d, want %d", len(top), len(tt.logits))
}
var sum float64
for _, e := range top {
p := math.Exp(e.logprob)
if p < 0 || p > 1 {
t.Errorf("token %d: probability %f out of [0,1]", e.id, p)
}
sum += p
}
if math.Abs(sum-1.0) > 1e-5 {
t.Errorf("probabilities sum = %f, want 1.0", sum)
}
})
}
}
func TestSampleLogprobsSelectedTokenCorrectness(t *testing.T) {
logits := []float32{3.0, 1.0, 2.0, 0.5}
maxIdx := 0
for i, v := range logits[1:] {
if v > logits[maxIdx] {
maxIdx = i + 1
}
}
selected, selLP, top := runSampleLogprobs(t, logits, len(logits))
if selected != maxIdx {
t.Errorf("selected = %d, want argmax %d", selected, maxIdx)
}
if top[0].id != maxIdx {
t.Errorf("top[0].id = %d, want argmax %d", top[0].id, maxIdx)
}
if math.Abs(top[0].logprob-selLP) > 1e-6 {
t.Errorf("top[0].logprob = %f, want selected %f", top[0].logprob, selLP)
}
}
// TestBatchedLogprobsPerRow verifies that per-row logprobs in a batched
// sample call match the per-slot reference. The numerically-stable softmax
// must reduce along the last axis only, not over the whole batch.
func TestBatchedLogprobsPerRow(t *testing.T) {
rowA := []float32{2, 1, 0}
rowB := []float32{0, 5, 0}
_, wantA, _ := runSampleLogprobs(t, rowA, 0)
_, wantB, _ := runSampleLogprobs(t, rowB, 0)
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
s.Add(1, Options{Logprobs: true}, nil)
s.Add(2, Options{Logprobs: true}, nil)
logits := mlx.FromValues(append(append([]float32{}, rowA...), rowB...), 2, 3)
res := s.Sample([]int{1, 2}, logits)
mlx.Pin(res.Arrays()...)
t.Cleanup(func() { mlx.Unpin(res.Arrays()...) })
mlx.Eval(res.Arrays()...)
got := res.Logprob.Floats()
if len(got) != 2 {
t.Fatalf("Logprob length = %d, want 2", len(got))
}
if math.Abs(float64(got[0])-wantA) > 1e-5 {
t.Errorf("row 0 logprob = %f, want %f (per-slot reference)", got[0], wantA)
}
if math.Abs(float64(got[1])-wantB) > 1e-5 {
t.Errorf("row 1 logprob = %f, want %f (per-slot reference)", got[1], wantB)
}
}
func TestSampleLogprobsTopKOrdering(t *testing.T) {
// Logits chosen so argmax order differs from index order.
logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0}
wantOrder := []int{1, 3, 4, 0, 2}
_, _, top := runSampleLogprobs(t, logits, len(logits))
if len(top) != len(wantOrder) {
t.Fatalf("top-K length = %d, want %d", len(top), len(wantOrder))
}
for i, e := range top {
if e.id != wantOrder[i] {
t.Errorf("top[%d].id = %d, want %d", i, e.id, wantOrder[i])
}
}
for i := 1; i < len(top); i++ {
if top[i].logprob > top[i-1].logprob {
t.Errorf("top[%d].logprob (%f) > top[%d].logprob (%f)",
i, top[i].logprob, i-1, top[i-1].logprob)
}
}
}

View File

@@ -1,189 +1,574 @@
package sample
import (
"fmt"
"math"
"slices"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type Transform func(*Sampler, *mlx.Array) *mlx.Array
type Options struct {
Temperature float32
TopP float32
MinP float32
TopK int
RepeatLastN int
RepeatPenalty float32
PresencePenalty float32
FrequencyPenalty float32
// Logprobs causes Sample to populate Result.Logprob with the selected
// token's log-probability. TopLogprobs (when > 0) adds top-K pairs.
Logprobs bool
TopLogprobs int
}
// Result bundles the outputs of one decode step. Logprob/TopTokens/
// TopLogprobs are populated whenever any registered slot has Logprobs
// (respectively TopLogprobs>0). Consumers need to filter by their
// per-slot Options.
type Result struct {
Token *mlx.Array // sampled token ids, shape [B]
Logprob *mlx.Array // sampled-token logprobs, shape [B,1]; nil unless any registered slot has Logprobs
TopTokens *mlx.Array // top-K token ids, shape [B,maxK]; nil unless any registered slot has TopLogprobs>0
TopLogprobs *mlx.Array // top-K logprobs, shape [B,maxK]; same
}
// Arrays returns the tensor fields as a slice so callers can drive the mlx
// lifecycle verbs (Pin, Unpin, Eval, AsyncEval) over the whole group. Unset
// fields stay nil; the mlx helpers skip them.
func (r Result) Arrays() []*mlx.Array {
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
}
// Sampler is a batched, slot-based sampler. Sequences are registered with
// Add and released with Remove. Each Sample call takes a subset of
// registered slots (in any order) with their [B,V] logits, samples one
// token per row, and appends it to that slot's ring-buffer history. Slots
// not named in a given call are untouched.
type Sampler struct {
Temperature float32
TopP float32
MinP float32
TopK int
RepeatLastN int
PresencePenalty float32
slots []*slotState
byID map[int]*slotState
history *mlx.Array
// history is the pooled ring-buffer storage, [B, W] int32. Row i
// belongs to slots[i]; W is max(RepeatLastN) across penalty slots.
// Allocated on the first penalty slot, rebuilt only in Add/Remove.
history *mlx.Array
// allSameOpts: every registered slot shares Options. When true the
// canonical shared value is s.slots[0].opts.
allSameOpts bool
// anyLogprobs / maxTopLogprobs: compute-for-all output config.
// Sample populates Logprob (and Top* when maxTopLogprobs>0) whenever
// any registered slot requests them, even if that slot isn't in the
// current call.
anyLogprobs bool
maxTopLogprobs int
// numCtx is the runner's context window; normalize uses it to
// resolve the repeat_last_n == -1 sentinel.
numCtx int
}
type slotState struct {
opts Options
transforms []transform
historyLen int
transforms []Transform
}
func New(temp, top_p, min_p float32, top_k, repeatLastN int, presencePenalty float32) *Sampler {
s := &Sampler{
Temperature: temp,
TopP: top_p,
MinP: min_p,
TopK: top_k,
RepeatLastN: repeatLastN,
PresencePenalty: presencePenalty,
type slotCtx struct {
opts Options
history *mlx.Array // 2D [B, W] when penalties are configured; nil otherwise
}
type transform func(*slotCtx, *mlx.Array) *mlx.Array
// New constructs an empty sampler with no registered slots. numCtx is
// the runner's context window and must be positive.
func New(numCtx int) *Sampler {
return &Sampler{
byID: make(map[int]*slotState),
allSameOpts: true,
numCtx: numCtx,
}
}
// historyWidth returns the column count of the pooled history tensor,
// or 0 when no penalty slot has forced it to be allocated.
func (s *Sampler) historyWidth() int {
if s.history == nil {
return 0
}
return s.history.Dim(1)
}
func (o Options) usesHistory() bool {
// RepeatLastN == 0 disables the penalty ring per the repeat_last_n API
// contract (0 = disabled), overriding any penalty coefficients.
if o.RepeatLastN == 0 {
return false
}
return o.RepeatPenalty != 1 || o.PresencePenalty != 0 || o.FrequencyPenalty != 0
}
func (o Options) normalize(numCtx int) Options {
if o.RepeatPenalty <= 0 {
o.RepeatPenalty = 1
}
// Resolve the repeat_last_n == -1 sentinel ("-1 = num_ctx") against
// the caller's context window.
if o.RepeatLastN < 0 {
o.RepeatLastN = numCtx
}
if !o.usesHistory() {
// Zero the ring capacity so slots that differ only in a spurious
// RepeatLastN still batch together and don't inflate pool width.
o.RepeatLastN = 0
}
return o
}
func (o Options) buildTransforms() []transform {
var ts []transform
if o.usesHistory() {
ts = append(ts, penalty)
}
var transforms []Transform
if presencePenalty != 0 {
transforms = append(transforms, penalty)
hasTopP := o.TopP > 0 && o.TopP < 1
hasTopK := o.TopK > 0
switch {
case hasTopP:
// topKTopP always does a full descending sort for the top-P
// cumulative mask and opportunistically masks top-K during the
// same pass when it is also configured.
ts = append(ts, topKTopP)
case hasTopK:
// Argpartition (partial sort) is cheaper than a full sort.
ts = append(ts, topK)
}
if top_p > 0 && top_p < 1 {
transforms = append(transforms, topP)
if o.MinP != 0 {
ts = append(ts, minP)
}
if min_p != 0 {
transforms = append(transforms, minP)
}
if top_k > 0 {
transforms = append(transforms, topK)
}
if temp == 0 {
transforms = append(transforms, greedy)
if o.Temperature == 0 {
ts = append(ts, greedy)
} else {
transforms = append(transforms, temperature)
ts = append(ts, temperature)
}
s.transforms = transforms
return s
return ts
}
func (s *Sampler) usesHistory() bool {
return s.PresencePenalty != 0
}
func (s *Sampler) setHistory(history *mlx.Array, historyLen int) {
if history != nil {
mlx.Pin(history)
// Add registers a sequence under seqID. The last RepeatLastN entries of
// priorTokens seed the ring buffer.
func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) {
if _, dup := s.byID[seqID]; dup {
panic(fmt.Sprintf("sample.Sampler.Add: seqID %d already registered", seqID))
}
if s.history != nil {
opts = opts.normalize(s.numCtx)
slot := &slotState{
opts: opts,
transforms: opts.buildTransforms(),
}
// Grow the pool to hold this slot's row. The pool is lazy — the first
// penalty slot allocates it — and thereafter every registered slot
// gets a row (rows for non-penalty slots are zero and never read).
// Invariant: s.history is pinned whenever non-nil.
if s.history != nil || opts.usesHistory() {
targetWidth := max(opts.RepeatLastN, s.historyWidth())
newRow := makeHistoryRow(priorTokens, opts.RepeatLastN, targetWidth)
var pool *mlx.Array
switch {
case s.history == nil && len(s.slots) == 0:
pool = newRow
case s.history == nil:
// First penalty slot with non-penalty slots already registered;
// seed zero rows so s.slots and pool row indices stay aligned.
zeros := mlx.Zeros(mlx.DTypeInt32, len(s.slots), targetWidth)
pool = zeros.Concatenate(0, newRow)
case targetWidth > s.historyWidth():
pad := mlx.Zeros(mlx.DTypeInt32, s.history.Dim(0), targetWidth-s.historyWidth())
pool = s.history.Concatenate(1, pad).Concatenate(0, newRow)
default:
pool = s.history.Concatenate(0, newRow)
}
mlx.Pin(pool)
mlx.Unpin(s.history)
s.history = pool
if opts.usesHistory() {
// Cap on seed so the next write's ring position
// (historyLen % RepeatLastN) lands at 0, overwriting the
// oldest entry when the ring was filled from priors.
slot.historyLen = min(len(priorTokens), opts.RepeatLastN)
}
}
s.history = history
s.historyLen = historyLen
s.slots = append(s.slots, slot)
s.byID[seqID] = slot
s.recomputeInvariants()
}
func (s *Sampler) ResetHistory(history []int32) {
if !s.usesHistory() {
// makeHistoryRow builds a [1, width] int32 row with the last repeatLastN
// entries of priorTokens packed into [0, min(len, repeatLastN)), zeros
// elsewhere.
func makeHistoryRow(priorTokens []int32, repeatLastN, width int) *mlx.Array {
take := min(len(priorTokens), repeatLastN)
if take <= 0 {
return mlx.Zeros(mlx.DTypeInt32, 1, width)
}
row := make([]int32, width)
copy(row, priorTokens[len(priorTokens)-take:])
return mlx.NewArrayInt32(row, []int32{1, int32(width)})
}
// recomputeInvariants refreshes allSameOpts and anyLogprobs/maxTopLogprobs
// from s.slots. Called at the end of Add and Remove.
func (s *Sampler) recomputeInvariants() {
if len(s.slots) == 0 {
s.allSameOpts = true
s.anyLogprobs = false
s.maxTopLogprobs = 0
return
}
if s.RepeatLastN > 0 && len(history) > s.RepeatLastN {
history = history[len(history)-s.RepeatLastN:]
first := s.slots[0].opts
s.allSameOpts = true
s.anyLogprobs = false
s.maxTopLogprobs = 0
for _, slot := range s.slots {
if slot.opts != first {
s.allSameOpts = false
}
if slot.opts.Logprobs {
s.anyLogprobs = true
if slot.opts.TopLogprobs > s.maxTopLogprobs {
s.maxTopLogprobs = slot.opts.TopLogprobs
}
}
}
if len(history) == 0 {
s.setHistory(nil, 0)
}
// Remove releases the slot. The pool tensor is rebuilt to drop the row.
func (s *Sampler) Remove(seqID int) {
slot, ok := s.byID[seqID]
if !ok {
return
}
delete(s.byID, seqID)
row := slices.Index(s.slots, slot)
s.slots = slices.Delete(s.slots, row, row+1)
s.recomputeInvariants()
if s.history == nil {
return
}
tokens := append([]int32(nil), history...)
s.setHistory(mlx.NewArrayInt32(tokens, []int32{int32(len(tokens))}), len(tokens))
}
func (s *Sampler) AppendToken(token *mlx.Array) {
if !s.usesHistory() || token == nil {
return
}
next := token.AsType(mlx.DTypeInt32)
nextLen := next.Size()
if s.history != nil && s.historyLen > 0 {
next = s.history.Concatenate(0, next)
nextLen += s.historyLen
}
if s.RepeatLastN > 0 && nextLen > s.RepeatLastN {
trim := nextLen - s.RepeatLastN
next = next.Slice(mlx.Slice(trim, nextLen))
nextLen = s.RepeatLastN
}
s.setHistory(next, nextLen)
n := s.history.Dim(0)
var newHistory *mlx.Array
switch {
case n == 1:
newHistory = nil
case row == 0:
newHistory = s.history.Slice(mlx.Slice(1, n), mlx.Slice())
case row == n-1:
newHistory = s.history.Slice(mlx.Slice(0, row), mlx.Slice())
default:
before := s.history.Slice(mlx.Slice(0, row), mlx.Slice())
after := s.history.Slice(mlx.Slice(row+1, n), mlx.Slice())
newHistory = before.Concatenate(0, after)
}
mlx.Pin(newHistory)
mlx.Unpin(s.history)
s.history = newHistory
}
// Free releases the pooled history tensor and resets the sampler to the
// New-equivalent state so it may be reused.
func (s *Sampler) Free() {
s.setHistory(nil, 0)
}
func (s *Sampler) Sample(logits *mlx.Array) *mlx.Array {
for _, transform := range s.transforms {
logits = transform(s, logits)
mlx.Unpin(s.history)
*s = Sampler{
byID: make(map[int]*slotState),
allSameOpts: true,
numCtx: s.numCtx,
}
return logits
}
func greedy(_ *Sampler, logits *mlx.Array) *mlx.Array {
return logits.Argmax(-1, false)
}
func temperature(s *Sampler, logits *mlx.Array) *mlx.Array {
return mlx.DivScalar(logits, s.Temperature).Categorical(-1)
}
func topP(s *Sampler, logprobs *mlx.Array) *mlx.Array {
if s.TopP <= 0 || s.TopP >= 1 {
return logprobs
// Sample draws one token per row of logits ([B,V]); seqIDs[i] names the
// slot whose logits live at row i. Each sampled token is appended to its
// slot's ring. Slots not named in seqIDs are untouched.
func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result {
if len(seqIDs) == 0 {
return Result{}
}
order := logprobs.Negative().ArgsortAxis(-1)
sortedLogprobs := logprobs.TakeAlongAxis(order, -1)
sortedProbs := mlx.SoftmaxAxis(sortedLogprobs, -1, true)
prevCumProbs := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
filtered := mlx.Where(keep, sortedLogprobs, mlx.FromValue(float32(math.Inf(-1))))
return logprobs.PutAlongAxis(order, filtered, -1)
}
func minP(s *Sampler, logprobs *mlx.Array) *mlx.Array {
if s.MinP <= 0 || s.MinP > 1 {
return logprobs
slots := make([]*slotState, len(seqIDs))
for i, id := range seqIDs {
slot, ok := s.byID[id]
if !ok {
panic(fmt.Sprintf("sample.Sampler.Sample: seqID %d not registered", id))
}
slots[i] = slot
}
maxLogprobs := logprobs.TakeAlongAxis(logprobs.Argmax(-1, true), -1)
minLogprobs := mlx.AddScalar(maxLogprobs, float32(math.Log(float64(s.MinP))))
var token *mlx.Array
if opts0, ok := s.canBatch(slots); ok {
token = s.sampleTokensUniform(slots, opts0, logits)
} else {
token = s.sampleTokensSerial(slots, logits)
}
res := Result{Token: token}
if s.anyLogprobs {
// Log-softmax over original logits so every row holds a truthful
// value (compute-for-all; consumers filter per-slot). Subtract
// max first for numerical stability in the logsumexp.
lp := logits.AsType(mlx.DTypeFloat32)
lp = lp.Subtract(lp.MaxAxis(-1, true))
lp = lp.Subtract(lp.LogsumexpAxis(-1, true))
res.Logprob = lp.TakeAlongAxis(token.ExpandDims(-1), -1)
if s.maxTopLogprobs > 0 {
k := s.maxTopLogprobs
if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
k = vocab
}
// Argpartition on the negated values places the K largest
// (unsorted) in positions [0:K].
idx := lp.Negative().ArgpartitionAxis(k-1, -1).Slice(mlx.Slice(), mlx.Slice(0, k))
res.TopTokens = idx.AsType(mlx.DTypeInt32)
res.TopLogprobs = lp.TakeAlongAxis(idx, -1)
}
}
return res
}
// canBatch reports whether the call can take the uniform batched path.
// All slots must share Options; when penalties are active the call must
// additionally cover every registered slot in registration order with a
// full ring, because the uniform path indexes the pool positionally.
func (s *Sampler) canBatch(slots []*slotState) (Options, bool) {
if !s.allSameOpts {
return Options{}, false
}
// slots is non-empty (Sample guards) and every slot is registered,
// so s.slots[0].opts is the canonical shared value.
shared := s.slots[0].opts
if !shared.usesHistory() {
return shared, true
}
if len(slots) != len(s.slots) {
return Options{}, false
}
for i, slot := range slots {
if s.slots[i] != slot || slot.historyLen < shared.RepeatLastN {
return Options{}, false
}
}
return shared, true
}
// sampleTokensUniform runs one fused transform pass over the whole batch.
// Reached only when canBatch is true, which lets the pool be used in place
// with a single PutAlongAxis write-back and no gather.
func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *mlx.Array) *mlx.Array {
B := len(slots)
var hist *mlx.Array
if opts.usesHistory() {
hist = s.history
if s.historyWidth() > opts.RepeatLastN {
hist = hist.Slice(mlx.Slice(), mlx.Slice(0, opts.RepeatLastN))
}
}
ctx := &slotCtx{opts: opts, history: hist}
scores := logits
for _, t := range slots[0].transforms {
scores = t(ctx, scores)
}
token := scores
if !opts.usesHistory() {
return token
}
writeIdxData := make([]int32, B)
for i, slot := range slots {
writeIdxData[i] = int32(slot.historyLen % opts.RepeatLastN)
slot.historyLen++
}
writeIdx := mlx.NewArrayInt32(writeIdxData, []int32{int32(B), 1})
s.history.Set(s.history.PutAlongAxis(writeIdx, token.ExpandDims(-1), 1))
return token
}
// sampleTokensSerial runs each slot's transforms against its own row of
// logits.
func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx.Array {
perSlotTokens := make([]*mlx.Array, len(slots))
rowOf := make(map[*slotState]int, len(s.slots))
for i, slot := range s.slots {
rowOf[slot] = i
}
for i, slot := range slots {
row := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
var hist *mlx.Array
if slot.opts.usesHistory() && slot.historyLen > 0 && s.history != nil {
poolRow := rowOf[slot]
fill := min(slot.historyLen, slot.opts.RepeatLastN)
hist = s.history.Slice(
mlx.Slice(poolRow, poolRow+1),
mlx.Slice(0, fill),
)
}
ctx := &slotCtx{opts: slot.opts, history: hist}
scores := row
for _, t := range slot.transforms {
scores = t(ctx, scores)
}
perSlotTokens[i] = scores
}
token := mlx.Concatenate(perSlotTokens, 0)
if s.history != nil {
// For each writing slot collect its flat (row-major) pool offset
// and the call-order position of its token. One PutAlongAxis on a
// flat view of the pool scatters all writes in a single op.
flatOffsets := make([]int32, 0, len(slots))
tokenPos := make([]int32, 0, len(slots))
for i, slot := range slots {
if !slot.opts.usesHistory() {
continue
}
ringPos := slot.historyLen % slot.opts.RepeatLastN
flatOffsets = append(flatOffsets, int32(rowOf[slot]*s.historyWidth()+ringPos))
tokenPos = append(tokenPos, int32(i))
slot.historyLen++
}
if len(flatOffsets) > 0 {
m := len(flatOffsets)
flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(m), 1})
writingTokens := token
if m != len(slots) {
tokenPosIdx := mlx.NewArrayInt32(tokenPos, []int32{int32(m)})
writingTokens = token.TakeAxis(tokenPosIdx, 0)
}
flatHist := s.history.Reshape(s.history.Dim(0)*s.historyWidth(), 1)
s.history.Set(flatHist.PutAlongAxis(flatIdx, writingTokens.ExpandDims(-1), 0).Reshape(s.history.Dim(0), s.historyWidth()))
}
}
return token
}
func greedy(_ *slotCtx, scores *mlx.Array) *mlx.Array {
return scores.Argmax(-1, false).AsType(mlx.DTypeInt32)
}
func temperature(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
return mlx.DivScalar(scores, ctx.opts.Temperature).Categorical(-1).AsType(mlx.DTypeInt32)
}
// topKTopP applies top-P in a descending sort pass and, when top-K is also
// configured, masks any surviving value below the K-th largest in the same
// pass. Callers dispatch here whenever top-P is enabled — the top-K-only case
// uses a cheaper partial sort via the topK transform.
func topKTopP(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
vocab := scores.Dim(scores.NumDims() - 1)
applyTopK := ctx.opts.TopK > 0 && ctx.opts.TopK < vocab
order := scores.Negative().ArgsortAxis(-1)
sorted := scores.TakeAlongAxis(order, -1)
negInf := mlx.FromValue(float32(math.Inf(-1)))
// Top-P: in descending order, keep tokens whose exclusive cumulative
// probability is still below TopP.
probs := mlx.SoftmaxAxis(sorted, -1, true)
prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs)
keep := prevCumProbs.Less(mlx.FromValue(ctx.opts.TopP))
sorted = mlx.Where(keep, sorted, negInf)
out := scores.PutAlongAxis(order, sorted, -1)
// Top-K: sorted is already in descending order, so positions [K, V) are
// the ones to drop. Scatter -inf through their original-layout indices
// (order[K:]). Positional (not value-based) so exactly K tokens survive —
// ties at the K-th logit get broken by the sort order rather than
// promoted through the filter.
if applyTopK {
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End))
out = out.PutAlongAxis(dropOrder, negInf, -1)
}
return out
}
func minP(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
if ctx.opts.MinP <= 0 || ctx.opts.MinP > 1 {
return scores
}
maxScore := scores.MaxAxis(-1, true)
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(ctx.opts.MinP))))
return mlx.Where(
logprobs.Less(minLogprobs),
scores.Less(threshold),
mlx.FromValue(float32(math.Inf(-1))),
logprobs,
scores,
)
}
func topK(s *Sampler, logprobs *mlx.Array) *mlx.Array {
if s.TopK <= 0 {
return logprobs
func topK(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
if ctx.opts.TopK <= 0 {
return scores
}
vocab := scores.Dim(scores.NumDims() - 1)
if ctx.opts.TopK >= vocab {
return scores
}
vocab := logprobs.Dim(logprobs.NumDims() - 1)
if s.TopK >= vocab {
return logprobs
}
mask := logprobs.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
mask := scores.Negative().ArgpartitionAxis(ctx.opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(ctx.opts.TopK, mlx.End))
return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}
func penalty(s *Sampler, logprobs *mlx.Array) *mlx.Array {
if s.history == nil || s.historyLen == 0 || s.PresencePenalty == 0 {
return logprobs
func penalty(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
tokenIndices := ctx.history
if tokenIndices == nil {
return scores
}
tokenIndices := s.history
if logprobs.NumDims() > 1 {
tokenIndices = tokenIndices.ExpandDims(0)
if ctx.opts.RepeatPenalty != 1 || ctx.opts.PresencePenalty != 0 {
adjusted := scores.TakeAlongAxis(tokenIndices, -1)
if ctx.opts.RepeatPenalty != 1 {
factor := mlx.Where(
adjusted.Less(mlx.FromValue(float32(0))),
mlx.FromValue(ctx.opts.RepeatPenalty),
mlx.FromValue(1/ctx.opts.RepeatPenalty),
)
adjusted = adjusted.Multiply(factor)
}
if ctx.opts.PresencePenalty != 0 {
adjusted = mlx.AddScalar(adjusted, -ctx.opts.PresencePenalty)
}
scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
}
selected := logprobs.TakeAlongAxis(tokenIndices, -1)
adjusted := mlx.AddScalar(selected, -s.PresencePenalty)
return logprobs.PutAlongAxis(tokenIndices, adjusted, -1)
if ctx.opts.FrequencyPenalty != 0 {
scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-ctx.opts.FrequencyPenalty), -1)
}
return scores
}

View File

@@ -9,54 +9,283 @@ import (
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
// RepeatLastN = 1, PresencePenalty = 6
s := New(0, 0, 0, 0, 1, 6)
defer func() {
// slotLogits builds a [1, V] logits tensor for a single-slot Sample call.
func slotLogits(values []float32) *mlx.Array {
return mlx.FromValues(values, 1, len(values))
}
// batchLogits stacks per-row float32 slices of equal length into a [B, V]
// logits tensor.
func batchLogits(rows ...[]float32) *mlx.Array {
v := len(rows[0])
flat := make([]float32, 0, len(rows)*v)
for _, r := range rows {
if len(r) != v {
panic("batchLogits: rows must share vocab size")
}
flat = append(flat, r...)
}
return mlx.FromValues(flat, len(rows), v)
}
// sampleOne runs Sample on a freshly-added single slot and returns the
// sampled token id. Used both for the single-slot options table and as the
// reference oracle for the batched-equivalence test.
func sampleOne(t *testing.T, opts Options, priorTokens []int32, values []float32) int {
t.Helper()
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
}()
})
s.Add(0, opts, priorTokens)
s.ResetHistory([]int32{0})
s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1}))
logprobs := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logprobs)
got := s.Sample([]int{0}, slotLogits(values)).Token
mlx.Eval(got)
return got.Int()
}
// logprobs will be [0, -1, 4] after the penalty
// and then (index) 2 after the greedy sampler
gotInt := got.Int()
if gotInt != 2 {
t.Fatalf("got %d, want 2", gotInt)
// logOf returns log(p) as a float32 so tests can build logits that softmax to
// a chosen probability distribution.
func logOf(p float64) float32 { return float32(math.Log(p)) }
// TestSampleSingleSlotOptions pins the per-slot behavior of each Options
// knob against a concrete expected token. Expected values are worked out by
// hand from the math of each transform, not from a second call into the
// sampler — so a regression in any single transform shows up here.
func TestSampleSingleSlotOptions(t *testing.T) {
cases := []struct {
name string
opts Options
priors []int32
logits []float32
want int
}{
{
name: "presence penalty",
opts: Options{RepeatLastN: 1, PresencePenalty: 6},
priors: []int32{1},
logits: []float32{0, 5, 4},
want: 2, // token 1: 5 - 6 = -1, argmax shifts to 2
},
{
name: "repeat penalty on positive logits",
opts: Options{RepeatLastN: 1, RepeatPenalty: 2},
priors: []int32{1},
logits: []float32{0, 5, 4},
want: 2, // token 1 positive → divided: 5/2 = 2.5, argmax shifts to 2
},
{
name: "repeat penalty on negative logits",
opts: Options{RepeatLastN: 1, RepeatPenalty: 4},
priors: []int32{1},
logits: []float32{-5, -1, -3},
want: 2, // token 1 negative → multiplied: -1*4 = -4, argmax shifts to 2
},
{
name: "frequency penalty",
opts: Options{RepeatLastN: 4, FrequencyPenalty: 2},
priors: []int32{1, 1},
logits: []float32{0, 5, 4},
want: 2, // 5 - 2*count(1)=2*2=4 → 1, argmax shifts to 2
},
{
name: "top-k",
opts: Options{Temperature: 1, TopK: 1},
logits: []float32{1, 5, 4},
want: 1, // only argmax survives → deterministic even with temperature
},
{
name: "top-p",
opts: Options{Temperature: 1, TopP: 0.4},
logits: []float32{logOf(0.5), logOf(0.3), logOf(0.2)},
want: 0, // exclusive cumsum below 0.4 keeps only token 0
},
{
name: "min-p",
opts: Options{Temperature: 1, MinP: 0.7},
logits: []float32{logOf(0.5), logOf(0.3), logOf(0.2)},
want: 0, // threshold 0.5*0.7=0.35 drops all but the top token
},
{
name: "RepeatLastN=0 disables penalties",
opts: Options{RepeatLastN: 0, RepeatPenalty: 2, PresencePenalty: 10},
priors: []int32{1},
logits: []float32{0, 5, 4},
want: 1, // 0 = disabled per API contract, argmax unchanged
},
{
name: "RepeatLastN=-1 resolves to num_ctx",
opts: Options{RepeatLastN: -1, PresencePenalty: 6},
priors: []int32{1},
logits: []float32{0, 5, 4},
want: 2, // -1 → num_ctx (128); penalty applies, argmax shifts
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if got := sampleOne(t, tc.opts, tc.priors, tc.logits); got != tc.want {
t.Errorf("got %d, want %d", got, tc.want)
}
})
}
}
func TestMinPMasksTokensBelowThreshold(t *testing.T) {
s := New(0, 0, 0.5, 0, 0, 0)
defer func() {
// TestSampleHistoryWindow verifies that penalty history respects the
// RepeatLastN window: priors longer than RepeatLastN are trimmed on Add,
// and once the ring wraps, tokens that rotate out no longer contribute
// to penalties.
func TestSampleHistoryWindow(t *testing.T) {
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
}()
})
logprobs := mlx.FromValues([]float32{
float32(math.Log(0.5)),
float32(math.Log(0.3)),
float32(math.Log(0.2)),
}, 3)
got := minP(s, logprobs)
mlx.Eval(got)
// RepeatLastN=2 with priors {1, 2, 3}: makeHistoryRow keeps only
// {2, 3}. Token 1 was trimmed — its penalty is NOT active.
s.Add(0, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{1, 2, 3})
gotFloats := got.Floats()
if len(gotFloats) != 3 {
t.Fatalf("got %d scores, want 3", len(gotFloats))
// Step 1: logits favor token 1 (trimmed). If the trim were broken it
// would be penalized and the argmax would move.
step1 := s.Sample([]int{0}, slotLogits([]float32{0, 5, 0, 0, 0})).Token
mlx.Eval(step1)
if got := step1.Int(); got != 1 {
t.Fatalf("step 1 = %d, want 1 (token 1 trimmed from priors)", got)
}
// After step 1 the ring holds {1, 3}; token 2 has rotated out.
if math.IsInf(float64(gotFloats[0]), -1) || math.IsInf(float64(gotFloats[1]), -1) {
t.Fatalf("kept tokens were masked: %v", gotFloats)
}
if !math.IsInf(float64(gotFloats[2]), -1) {
t.Fatalf("lowest-probability token should be masked, got %v", gotFloats)
// Step 2: logits favor token 2 (rotated out). If the ring wrap were
// wrong, token 2 would still be penalized.
step2 := s.Sample([]int{0}, slotLogits([]float32{0, 0, 5, 0, 0})).Token
mlx.Eval(step2)
if got := step2.Int(); got != 2 {
t.Fatalf("step 2 = %d, want 2 (token 2 rotated out of ring)", got)
}
}
// TestBatchSamplingPreservesPerSlotBehavior is the core equivalence test:
// for every representative dispatch branch (uniform, serial on mixed opts,
// serial on partial ring, subset/out-of-order), a batched Sample call must
// produce the same token per row as running the same slot alone.
func TestBatchSamplingPreservesPerSlotBehavior(t *testing.T) {
type slot struct {
id int
opts Options
priors []int32
}
cases := []struct {
name string
slots []slot
sample []int
rows [][]float32
}{
{
name: "uniform",
slots: []slot{
{10, Options{RepeatLastN: 2, PresencePenalty: 5}, []int32{1, 2}},
{20, Options{RepeatLastN: 2, PresencePenalty: 5}, []int32{0, 2}},
},
sample: []int{10, 20},
rows: [][]float32{{0, 5, 4}, {3, 0, 0}},
},
{
name: "serial — mixed opts",
slots: []slot{
{1, Options{RepeatLastN: 1, RepeatPenalty: 2}, []int32{1}},
{2, Options{Temperature: 1, TopK: 1}, nil},
},
sample: []int{1, 2},
rows: [][]float32{{0, 5, 4, 1}, {2, 1, 5, 3}},
},
{
name: "serial — partial ring",
slots: []slot{
{1, Options{RepeatLastN: 4, PresencePenalty: 5}, []int32{1, 1, 1, 1}},
{2, Options{RepeatLastN: 4, PresencePenalty: 5}, []int32{2}},
},
sample: []int{1, 2},
rows: [][]float32{{0, 5, 4}, {0, 4, 5}},
},
{
name: "subset out-of-order",
slots: []slot{
{10, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{1, 1}},
{20, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{2, 2}},
{30, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{3, 3}},
},
sample: []int{30, 10},
rows: [][]float32{{5, 5, 5, 0, 5, 5}, {5, 0, 5, 5, 0, 5}},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
// Per-slot reference for each sampled seq.
want := make([]int, len(tc.sample))
for i, id := range tc.sample {
var spec slot
for _, s := range tc.slots {
if s.id == id {
spec = s
break
}
}
want[i] = sampleOne(t, spec.opts, spec.priors, tc.rows[i])
}
// Batched call.
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
for _, spec := range tc.slots {
s.Add(spec.id, spec.opts, spec.priors)
}
res := s.Sample(tc.sample, batchLogits(tc.rows...))
mlx.Eval(res.Token)
got := res.Token.Ints()
for i, id := range tc.sample {
if got[i] != want[i] {
t.Errorf("seq %d: batched = %d, per-slot = %d", id, got[i], want[i])
}
}
})
}
}
// TestRemoveDoesNotLeakHistory: after Remove, a newly-added slot at the
// recycled row must start from its own priors only — no carryover from
// the removed slot's history.
func TestRemoveDoesNotLeakHistory(t *testing.T) {
opts := Options{RepeatLastN: 1, PresencePenalty: 10}
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
s.Add(1, opts, []int32{1})
s.Add(2, opts, []int32{2})
s.Remove(1)
s.Add(3, opts, []int32{0})
// Slot 2 retains history {2}; slot 3 retains history {0}. With
// equal logits and PresencePenalty=10 the argmax drops to the first
// unpenalized token.
res := s.Sample([]int{2, 3}, batchLogits(
[]float32{3, 3, 0},
[]float32{3, 3, 0},
))
mlx.Eval(res.Token)
tokens := res.Token.Ints()
if tokens[0] != 0 {
t.Errorf("slot 2 = %d, want 0 (token 2 penalized)", tokens[0])
}
if tokens[1] != 1 {
t.Errorf("slot 3 = %d, want 1 (token 0 penalized, no slot-1 carryover)", tokens[1])
}
}

View File

@@ -2,7 +2,6 @@ package mlxrunner
import (
"bytes"
"cmp"
"context"
"encoding/json"
"flag"
@@ -87,23 +86,30 @@ func Execute(args []string) error {
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan CompletionResponse)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
if err := json.NewDecoder(r.Body).Decode(&request.CompletionRequest); err != nil {
slog.Error("Failed to decode request", "error", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New(
request.Options.Temperature,
request.Options.TopP,
request.Options.MinP,
request.Options.TopK,
request.Options.RepeatLastN,
request.Options.PresencePenalty,
)
request.SamplerOpts = sample.Options{
Temperature: request.Options.Temperature,
TopP: request.Options.TopP,
MinP: request.Options.MinP,
TopK: request.Options.TopK,
RepeatLastN: request.Options.RepeatLastN,
RepeatPenalty: request.Options.RepeatPenalty,
PresencePenalty: request.Options.PresencePenalty,
FrequencyPenalty: request.Options.FrequencyPenalty,
Logprobs: request.Logprobs,
TopLogprobs: request.TopLogprobs,
}
if err := runner.Prepare(&request); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var cancel context.CancelFunc
request.Ctx, cancel = context.WithCancel(r.Context())

View File

@@ -144,6 +144,8 @@ func TestRouterForwardMatchesLegacy(t *testing.T) {
gotScores, gotInds := r.Forward(x, cfg)
wantScores, wantInds := legacyRouterForward(r, x, cfg)
gotInds = gotInds.AsType(mlx.DTypeInt32)
wantInds = wantInds.AsType(mlx.DTypeInt32)
mlx.Eval(gotScores, gotInds, wantScores, wantInds)
if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {

View File

@@ -161,21 +161,21 @@ type MoEGate struct {
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
gates := g.Gate.Forward(x)
scores := mlx.Sigmoid(gates)
origScores := scores
var origScores, negScores *mlx.Array
if g.EScoreCorrectionBias != nil {
scores = mlx.Add(scores, g.EScoreCorrectionBias)
origScores, negScores = mlx.SigmoidRouter(gates, g.EScoreCorrectionBias)
} else {
origScores = mlx.Sigmoid(gates)
negScores = mlx.Neg(origScores)
}
topK := cfg.NumExpertsPerTok
negScores := mlx.Neg(scores)
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
dims := inds.Dims()
inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})
scores = mlx.TakeAlongAxis(origScores, inds, -1)
scores := mlx.TakeAlongAxis(origScores, inds, -1)
if topK > 1 && cfg.NormTopKProb {
sumScores := mlx.Sum(scores, -1, true)

View File

@@ -169,8 +169,8 @@ func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
mlx.Eval(dequantizedWeight)
qOut := ql.Forward(input)
dOut := NewLinear(dequantizedWeight, nil).Forward(input)
qOut := ql.Forward(input).AsType(mlx.DTypeFloat32)
dOut := NewLinear(dequantizedWeight, nil).Forward(input).AsType(mlx.DTypeFloat32)
mlx.Eval(qOut, dOut)
got := qOut.Floats()