Compare commits

...

12 Commits

Author SHA1 Message Date
Eva Ho
0cd8a0a442 launch: add codex model metadata catalog 2026-04-23 17:09:41 -04:00
Eva H
85ff8e4a21 launch: keep launch recommended models in a fixed canonical order (#15750) 2026-04-23 16:33:00 -04:00
Parth Sareen
160660e572 launch: use bundled OpenClaw ollama web search (#15757) 2026-04-22 16:34:19 -07:00
madflow
3b43b9bc4b docs: update structured outputs doc for cloud (#15733)
---------

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2026-04-22 00:42:39 -07:00
Parth Sareen
21883571b7 launch: replace kimi-k2.5 with k2.6 as top recommended model (#15737) 2026-04-21 15:13:20 -07:00
Jesse Gross
ce99f24731 mlxrunner: tokenize prompts in request handler goroutines
Move tokenization out of the single GPU processing goroutine and
into each request's HTTP handler goroutine. This allows the next
request's prompt to be tokenized on the CPU while the current
request is executing on the GPU.
2026-04-21 14:38:49 -07:00
Jesse Gross
04f5f0cdb4 mlx: improve thread safety of array management
Use atomic.Int32 for Array.pinned and a sync.Mutex for the global
arrays slice so MLX arrays can be created and pinned from multiple
goroutines without racing on those structures. Convert Array value
receivers to pointer receivers and struct fields from Array to
*Array to avoid copying the atomic.

This does not fully achieve thread safety even when building
completely independent graphs. The tracing flag and traceScratch
slice in compile.go are unprotected, so concurrent Compile calls
will race. MLX itself is not fully thread-safe either, although
work is ongoing to improve it.
2026-04-21 14:38:49 -07:00
2026-04-21 14:38:49 -07:00
Matteo Celani
fb36a01ffe app/ui: fix model picker showing stale model after switching chats (#15280)
* app/ui: fix model picker showing stale model after switching chats

Optimistic messages created during streaming were storing the full
Model object instead of the model name string. When switching back
to a chat with cached streaming data, the restore effect read an
object where it expected a string, causing the model picker to fail
matching and remain stuck on the previous chat's model.

* app/ui: fix two more instances of Model object passed as model name

Fix the same bug at lines 523 and 536 in the assistant_with_tools
event handler, where selectedModel (object) was used instead of
selectedModel.model (string).
2026-04-21 15:08:06 -04:00
Michael Verrilli
0c65ed33bc cmd: populate model capabilities in launchInteractiveModel (#15712)
launchInteractiveModel was introduced in PR #14609 without the
client.Show() capability-detection block that RunHandler uses.
This left opts.MultiModal always false in the TUI path, causing
image/audio file paths to always be treated as unknown commands
instead of being loaded as multimodal attachments.

Mirror the Show() call, pull-on-404 fallback, cloud auth handling,
and MultiModal/Think population from RunHandler into
launchInteractiveModel.

Fixes #15711
2026-04-21 14:37:36 -04:00
Jesse Gross
22d6c817f8 mlxrunner: fuse top-P and top-K into a single sort pass
When both filters are active, avoid paying for a full sort in top-P
and a partial sort in top-K. Single-filter paths are unchanged.
Improves generation throughput on gemma4:e4b by 1.5%.
2026-04-20 17:43:00 -07:00
Jesse Gross
ca01373b28 mlxrunner: use MaxAxis in the min-P sampler
One reduction op instead of Argmax + TakeAlongAxis.
2026-04-20 17:43:00 -07:00
Jesse Gross
24e038d56a mlxrunner: add logprobs support
Match the ollamarunner and OpenAI semantics: raw, full-vocab log-softmax
with the top-K ranked by probability. Skipped on the GPU when the request
doesn't ask for logprobs so decode doesn't pay for it otherwise.
2026-04-20 17:43:00 -07:00
26 changed files with 1246 additions and 579 deletions

View File

@@ -381,7 +381,7 @@ export const useSendMessage = (chatId: string) => {
role: "assistant", role: "assistant",
content: "", content: "",
thinking: "", thinking: "",
model: effectiveModel, model: effectiveModel.model,
}), }),
); );
lastMessage = newMessages[newMessages.length - 1]; lastMessage = newMessages[newMessages.length - 1];
@@ -433,7 +433,7 @@ export const useSendMessage = (chatId: string) => {
role: "assistant", role: "assistant",
content: "", content: "",
thinking: "", thinking: "",
model: effectiveModel, model: effectiveModel.model,
}), }),
); );
lastMessage = newMessages[newMessages.length - 1]; lastMessage = newMessages[newMessages.length - 1];
@@ -520,7 +520,7 @@ export const useSendMessage = (chatId: string) => {
thinkingTimeStart: thinkingTimeStart:
lastMessage.thinkingTimeStart || event.thinkingTimeStart, lastMessage.thinkingTimeStart || event.thinkingTimeStart,
thinkingTimeEnd: event.thinkingTimeEnd, thinkingTimeEnd: event.thinkingTimeEnd,
model: selectedModel, model: selectedModel.model,
}); });
newMessages[newMessages.length - 1] = updatedMessage; newMessages[newMessages.length - 1] = updatedMessage;
} else { } else {
@@ -533,7 +533,7 @@ export const useSendMessage = (chatId: string) => {
tool_calls: event.toolCalls, tool_calls: event.toolCalls,
thinkingTimeStart: event.thinkingTimeStart, thinkingTimeStart: event.thinkingTimeStart,
thinkingTimeEnd: event.thinkingTimeEnd, thinkingTimeEnd: event.thinkingTimeEnd,
model: selectedModel, model: selectedModel.model,
}), }),
); );
} }
@@ -699,7 +699,7 @@ export const useSendMessage = (chatId: string) => {
queryClient.setQueryData(["chat", newId], { queryClient.setQueryData(["chat", newId], {
chat: new Chat({ chat: new Chat({
id: newId, id: newId,
model: effectiveModel, model: effectiveModel.model,
messages: [ messages: [
new Message({ new Message({
role: "user", role: "user",

View File

@@ -1975,8 +1975,61 @@ func launchInteractiveModel(cmd *cobra.Command, modelName string) error {
Options: map[string]any{}, Options: map[string]any{},
ShowConnect: true, ShowConnect: true,
} }
// loadOrUnloadModel is cloud-safe here: remote/cloud models skip local preload
// and only validate auth/connectivity before interactive chat starts. client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
requestedCloud := modelref.HasExplicitCloudSource(modelName)
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: modelName}
info, err := client.Show(cmd.Context(), showReq)
var se api.StatusError
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
if requestedCloud {
return nil, err
}
if err := PullHandler(cmd, []string{modelName}); err != nil {
return nil, err
}
return client.Show(cmd.Context(), &api.ShowRequest{Name: modelName})
}
return info, err
}()
if err != nil {
if handleCloudAuthorizationError(err) {
return nil
}
return err
}
ensureCloudStub(cmd.Context(), client, modelName)
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, false)
if err != nil {
return err
}
audioCapable := slices.Contains(info.Capabilities, model.CapabilityAudio)
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) || audioCapable
// TODO: remove the projector info and vision info checks below,
// these are left in for backwards compatibility with older servers
// that don't have the capabilities field in the model info
if len(info.ProjectorInfo) != 0 {
opts.MultiModal = true
}
for k := range info.ModelInfo {
if strings.Contains(k, ".vision.") {
opts.MultiModal = true
break
}
}
applyShowResponseToRunOptions(&opts, info)
if err := loadOrUnloadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
return fmt.Errorf("error loading model: %w", err) return fmt.Errorf("error loading model: %w", err)
} }

View File

@@ -1,13 +1,20 @@
package launch package launch
import ( import (
"context"
"encoding/json"
"fmt" "fmt"
"net/http"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"slices"
"strconv"
"strings" "strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
) )
@@ -32,7 +39,7 @@ func (c *Codex) Run(model string, args []string) error {
return err return err
} }
if err := ensureCodexConfig(); err != nil { if err := ensureCodexConfig(model); err != nil {
return fmt.Errorf("failed to configure codex: %w", err) return fmt.Errorf("failed to configure codex: %w", err)
} }
@@ -46,9 +53,9 @@ func (c *Codex) Run(model string, args []string) error {
return cmd.Run() return cmd.Run()
} }
// ensureCodexConfig writes a [profiles.ollama-launch] section to ~/.codex/config.toml // ensureCodexConfig writes a Codex profile and model catalog so Codex uses the
// with openai_base_url pointing to the local Ollama server. // local Ollama server and has model metadata available.
func ensureCodexConfig() error { func ensureCodexConfig(modelName string) error {
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
return err return err
@@ -59,13 +66,18 @@ func ensureCodexConfig() error {
return err return err
} }
catalogPath := filepath.Join(codexDir, "model.json")
if err := writeCodexModelCatalog(catalogPath, modelName); err != nil {
return err
}
configPath := filepath.Join(codexDir, "config.toml") configPath := filepath.Join(codexDir, "config.toml")
return writeCodexProfile(configPath) return writeCodexProfile(configPath, catalogPath)
} }
// writeCodexProfile ensures ~/.codex/config.toml has the ollama-launch profile // writeCodexProfile ensures ~/.codex/config.toml has the ollama-launch profile
// and model provider sections with the correct base URL. // and model provider sections with the correct base URL.
func writeCodexProfile(configPath string) error { func writeCodexProfile(configPath, catalogPath string) error {
baseURL := envconfig.Host().String() + "/v1/" baseURL := envconfig.Host().String() + "/v1/"
sections := []struct { sections := []struct {
@@ -78,6 +90,7 @@ func writeCodexProfile(configPath string) error {
fmt.Sprintf("openai_base_url = %q", baseURL), fmt.Sprintf("openai_base_url = %q", baseURL),
`forced_login_method = "api"`, `forced_login_method = "api"`,
fmt.Sprintf("model_provider = %q", codexProfileName), fmt.Sprintf("model_provider = %q", codexProfileName),
fmt.Sprintf("model_catalog_json = %q", catalogPath),
}, },
}, },
{ {
@@ -121,6 +134,110 @@ func writeCodexProfile(configPath string) error {
return os.WriteFile(configPath, []byte(text), 0o644) return os.WriteFile(configPath, []byte(text), 0o644)
} }
// writeCodexModelCatalog renders a single-model Codex catalog file at
// catalogPath describing modelName, pretty-printed as JSON.
func writeCodexModelCatalog(catalogPath, modelName string) error {
	doc := map[string]any{
		"models": []any{buildCodexModelEntry(modelName)},
	}
	payload, err := json.MarshalIndent(doc, "", " ")
	if err != nil {
		return err
	}
	return os.WriteFile(catalogPath, payload, 0o644)
}
// buildCodexModelEntry assembles the Codex model-catalog entry for modelName.
// It best-effort queries the local Ollama server via Show; if Show fails, the
// defaults below stand (no vision/thinking, empty system prompt, and whatever
// context window the cloud limits table supplied).
//
// Context-window precedence for local (non-cloud) models, last write wins:
// architectural context_length from ModelInfo, then OLLAMA_CONTEXT_LENGTH,
// then a num_ctx model parameter. For safetensors-format models only the
// architectural value is used. Cloud models use the hardcoded limits table.
func buildCodexModelEntry(modelName string) map[string]any {
	contextWindow := 0
	hasVision := false
	hasThinking := false
	systemPrompt := ""
	// Cloud models: seed the context window from the hardcoded limits table.
	if l, ok := lookupCloudModelLimit(modelName); ok {
		contextWindow = l.Context
	}
	// NOTE(review): http.DefaultClient has no timeout, so a hung server can
	// stall this call indefinitely — consider a client with a Timeout.
	client := api.NewClient(envconfig.Host(), http.DefaultClient)
	resp, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName})
	if err == nil { // best-effort: on error, fall through with defaults
		systemPrompt = resp.System
		if slices.Contains(resp.Capabilities, model.CapabilityVision) {
			hasVision = true
		}
		if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
			hasThinking = true
		}
		if !isCloudModelName(modelName) {
			// Architectural context length reported by the model file.
			if n, ok := modelInfoContextLength(resp.ModelInfo); ok {
				contextWindow = n
			}
			// Overrides apply only to non-safetensors formats.
			if resp.Details.Format != "safetensors" {
				if ctxLen := envconfig.ContextLength(); ctxLen > 0 {
					contextWindow = int(ctxLen)
				}
				// num_ctx wins over the environment override.
				if numCtx := parseNumCtx(resp.Parameters); numCtx > 0 {
					contextWindow = numCtx
				}
			}
		}
	}
	modalities := []string{"text"}
	if hasVision {
		modalities = append(modalities, "image")
	}
	// Reasoning levels are only advertised for thinking-capable models.
	reasoningLevels := []any{}
	if hasThinking {
		reasoningLevels = []any{
			map[string]any{"effort": "low", "description": "Fast responses with lighter reasoning"},
			map[string]any{"effort": "medium", "description": "Balances speed and reasoning depth"},
			map[string]any{"effort": "high", "description": "Greater reasoning depth for complex problems"},
		}
	}
	// Cloud models truncate by tokens; local models by bytes.
	truncationMode := "bytes"
	if isCloudModelName(modelName) {
		truncationMode = "tokens"
	}
	return map[string]any{
		"slug": modelName,
		"display_name": modelName,
		"context_window": contextWindow,
		"apply_patch_tool_type": "function",
		"shell_type": "default",
		"visibility": "list",
		"supported_in_api": true,
		"priority": 0,
		"truncation_policy": map[string]any{"mode": truncationMode, "limit": 10000},
		"input_modalities": modalities,
		"base_instructions": systemPrompt,
		"support_verbosity": true,
		"default_verbosity": "low",
		"supports_parallel_tool_calls": false,
		"supports_reasoning_summaries": hasThinking,
		"supported_reasoning_levels": reasoningLevels,
		"experimental_supported_tools": []any{},
	}
}
// parseNumCtx scans a newline-separated parameters dump for a "num_ctx"
// entry and returns its integer value; it returns 0 when no parsable
// num_ctx line is present. Float values are truncated toward zero.
func parseNumCtx(parameters string) int {
	for _, raw := range strings.Split(parameters, "\n") {
		parts := strings.Fields(raw)
		if len(parts) != 2 || parts[0] != "num_ctx" {
			continue
		}
		if f, err := strconv.ParseFloat(parts[1], 64); err == nil {
			return int(f)
		}
	}
	return 0
}
func checkCodexVersion() error { func checkCodexVersion() error {
if _, err := exec.LookPath("codex"); err != nil { if _, err := exec.LookPath("codex"); err != nil {
return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex") return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")

View File

@@ -1,6 +1,10 @@
package launch package launch
import ( import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
@@ -37,8 +41,9 @@ func TestWriteCodexProfile(t *testing.T) {
t.Run("creates new file when none exists", func(t *testing.T) { t.Run("creates new file when none exists", func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml") configPath := filepath.Join(tmpDir, "config.toml")
catalogPath := filepath.Join(tmpDir, "model.json")
if err := writeCodexProfile(configPath); err != nil { if err := writeCodexProfile(configPath, catalogPath); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -63,6 +68,9 @@ func TestWriteCodexProfile(t *testing.T) {
if !strings.Contains(content, `model_provider = "ollama-launch"`) { if !strings.Contains(content, `model_provider = "ollama-launch"`) {
t.Error("missing model_provider key") t.Error("missing model_provider key")
} }
if !strings.Contains(content, fmt.Sprintf("model_catalog_json = %q", catalogPath)) {
t.Error("missing model_catalog_json key")
}
if !strings.Contains(content, "[model_providers.ollama-launch]") { if !strings.Contains(content, "[model_providers.ollama-launch]") {
t.Error("missing [model_providers.ollama-launch] section") t.Error("missing [model_providers.ollama-launch] section")
} }
@@ -74,10 +82,11 @@ func TestWriteCodexProfile(t *testing.T) {
t.Run("appends profile to existing file without profile", func(t *testing.T) { t.Run("appends profile to existing file without profile", func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml") configPath := filepath.Join(tmpDir, "config.toml")
catalogPath := filepath.Join(tmpDir, "model.json")
existing := "[some_other_section]\nkey = \"value\"\n" existing := "[some_other_section]\nkey = \"value\"\n"
os.WriteFile(configPath, []byte(existing), 0o644) os.WriteFile(configPath, []byte(existing), 0o644)
if err := writeCodexProfile(configPath); err != nil { if err := writeCodexProfile(configPath, catalogPath); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -95,10 +104,11 @@ func TestWriteCodexProfile(t *testing.T) {
t.Run("replaces existing profile section", func(t *testing.T) { t.Run("replaces existing profile section", func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml") configPath := filepath.Join(tmpDir, "config.toml")
catalogPath := filepath.Join(tmpDir, "model.json")
existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n\n[model_providers.ollama-launch]\nname = \"Ollama\"\nbase_url = \"http://old:1234/v1/\"\n" existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n\n[model_providers.ollama-launch]\nname = \"Ollama\"\nbase_url = \"http://old:1234/v1/\"\n"
os.WriteFile(configPath, []byte(existing), 0o644) os.WriteFile(configPath, []byte(existing), 0o644)
if err := writeCodexProfile(configPath); err != nil { if err := writeCodexProfile(configPath, catalogPath); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -119,10 +129,11 @@ func TestWriteCodexProfile(t *testing.T) {
t.Run("replaces profile while preserving following sections", func(t *testing.T) { t.Run("replaces profile while preserving following sections", func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml") configPath := filepath.Join(tmpDir, "config.toml")
catalogPath := filepath.Join(tmpDir, "model.json")
existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n[another_section]\nfoo = \"bar\"\n" existing := "[profiles.ollama-launch]\nopenai_base_url = \"http://old:1234/v1/\"\n[another_section]\nfoo = \"bar\"\n"
os.WriteFile(configPath, []byte(existing), 0o644) os.WriteFile(configPath, []byte(existing), 0o644)
if err := writeCodexProfile(configPath); err != nil { if err := writeCodexProfile(configPath, catalogPath); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -143,10 +154,11 @@ func TestWriteCodexProfile(t *testing.T) {
t.Run("appends newline to file not ending with newline", func(t *testing.T) { t.Run("appends newline to file not ending with newline", func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml") configPath := filepath.Join(tmpDir, "config.toml")
catalogPath := filepath.Join(tmpDir, "model.json")
existing := "[other]\nkey = \"val\"" existing := "[other]\nkey = \"val\""
os.WriteFile(configPath, []byte(existing), 0o644) os.WriteFile(configPath, []byte(existing), 0o644)
if err := writeCodexProfile(configPath); err != nil { if err := writeCodexProfile(configPath, catalogPath); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -166,8 +178,9 @@ func TestWriteCodexProfile(t *testing.T) {
t.Setenv("OLLAMA_HOST", "http://myhost:9999") t.Setenv("OLLAMA_HOST", "http://myhost:9999")
tmpDir := t.TempDir() tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.toml") configPath := filepath.Join(tmpDir, "config.toml")
catalogPath := filepath.Join(tmpDir, "model.json")
if err := writeCodexProfile(configPath); err != nil { if err := writeCodexProfile(configPath, catalogPath); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -185,7 +198,7 @@ func TestEnsureCodexConfig(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
if err := ensureCodexConfig(); err != nil { if err := ensureCodexConfig("llama3.2"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -202,16 +215,25 @@ func TestEnsureCodexConfig(t *testing.T) {
if !strings.Contains(content, "openai_base_url") { if !strings.Contains(content, "openai_base_url") {
t.Error("missing openai_base_url key") t.Error("missing openai_base_url key")
} }
catalogPath := filepath.Join(tmpDir, ".codex", "model.json")
data, err = os.ReadFile(catalogPath)
if err != nil {
t.Fatalf("model.json not created: %v", err)
}
if !strings.Contains(string(data), `"slug": "llama3.2"`) {
t.Error("missing model catalog entry for selected model")
}
}) })
t.Run("is idempotent", func(t *testing.T) { t.Run("is idempotent", func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
setTestHome(t, tmpDir) setTestHome(t, tmpDir)
if err := ensureCodexConfig(); err != nil { if err := ensureCodexConfig("llama3.2"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := ensureCodexConfig(); err != nil { if err := ensureCodexConfig("llama3.2"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -227,3 +249,204 @@ func TestEnsureCodexConfig(t *testing.T) {
} }
}) })
} }
// TestParseNumCtx covers num_ctx extraction across well-formed, missing,
// and malformed parameter dumps, including float-valued entries.
func TestParseNumCtx(t *testing.T) {
	cases := []struct {
		name   string
		params string
		want   int
	}{
		{"num_ctx set", "num_ctx 8192", 8192},
		{"num_ctx with other params", "temperature 0.7\nnum_ctx 4096\ntop_p 0.9", 4096},
		{"no num_ctx", "temperature 0.7\ntop_p 0.9", 0},
		{"empty string", "", 0},
		{"malformed value", "num_ctx abc", 0},
		{"float value", "num_ctx 8192.0", 8192},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := parseNumCtx(tc.params)
			if got != tc.want {
				t.Errorf("parseNumCtx(%q) = %d, want %d", tc.params, got, tc.want)
			}
		})
	}
}
// TestModelInfoContextLength checks context_length lookup over float64 and
// int values, a missing key, and empty/nil maps.
func TestModelInfoContextLength(t *testing.T) {
	cases := []struct {
		name      string
		modelInfo map[string]any
		want      int
	}{
		{"float64 value", map[string]any{"qwen3_5_moe.context_length": float64(262144)}, 262144},
		{"int value", map[string]any{"llama.context_length": 131072}, 131072},
		{"no context_length key", map[string]any{"llama.embedding_length": float64(4096)}, 0},
		{"empty map", map[string]any{}, 0},
		{"nil map", nil, 0},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, _ := modelInfoContextLength(tc.modelInfo)
			if got != tc.want {
				t.Errorf("modelInfoContextLength() = %d, want %d", got, tc.want)
			}
		})
	}
}
// TestBuildCodexModelEntryContextWindow exercises the context-window
// precedence rules in buildCodexModelEntry against a stubbed /api/show
// endpoint: architectural context_length < OLLAMA_CONTEXT_LENGTH < num_ctx
// for local gguf models, architectural-only for safetensors, and the
// hardcoded limits table for cloud models. Some cases double as checks for
// capabilities, system prompt, and truncation policy.
func TestBuildCodexModelEntryContextWindow(t *testing.T) {
	tests := []struct {
		name          string
		modelName     string
		showResponse  string // raw JSON served from the fake /api/show
		envContextLen string // OLLAMA_CONTEXT_LENGTH value; empty means unset
		wantContext   int
	}{
		{
			name:      "architectural context length as fallback",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"details": {"format": "gguf"}
			}`,
			wantContext: 131072,
		},
		{
			name:      "OLLAMA_CONTEXT_LENGTH overrides architectural",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"details": {"format": "gguf"}
			}`,
			envContextLen: "64000",
			wantContext:   64000,
		},
		{
			name:      "num_ctx overrides OLLAMA_CONTEXT_LENGTH",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"parameters": "num_ctx 8192",
				"details": {"format": "gguf"}
			}`,
			envContextLen: "64000",
			wantContext:   8192,
		},
		{
			name:      "num_ctx overrides architectural",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"parameters": "num_ctx 32768",
				"details": {"format": "gguf"}
			}`,
			wantContext: 32768,
		},
		{
			// safetensors models skip both the env and num_ctx overrides.
			name:      "safetensors uses architectural context only",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"parameters": "num_ctx 8192",
				"details": {"format": "safetensors"}
			}`,
			envContextLen: "64000",
			wantContext:   131072,
		},
		{
			// Cloud models ignore the Show-derived values and the env var.
			name:      "cloud model uses hardcoded limits",
			modelName: "qwen3.5:cloud",
			showResponse: `{
				"model_info": {"qwen3_5_moe.context_length": 131072},
				"details": {"format": "gguf"}
			}`,
			envContextLen: "64000",
			wantContext:   262144,
		},
		{
			name:      "vision and thinking capabilities",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"details": {"format": "gguf"},
				"capabilities": ["vision", "thinking"]
			}`,
			wantContext: 131072,
		},
		{
			name:      "system prompt passed through",
			modelName: "llama3.2",
			showResponse: `{
				"model_info": {"llama.context_length": 131072},
				"details": {"format": "gguf"},
				"system": "You are a helpful assistant."
			}`,
			wantContext: 131072,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Fake Ollama server: only /api/show is implemented.
			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				switch r.URL.Path {
				case "/api/show":
					fmt.Fprint(w, tt.showResponse)
				default:
					http.NotFound(w, r)
				}
			}))
			defer srv.Close()
			t.Setenv("OLLAMA_HOST", srv.URL)
			if tt.envContextLen != "" {
				t.Setenv("OLLAMA_CONTEXT_LENGTH", tt.envContextLen)
			} else {
				// NOTE(review): sets the variable to "" rather than unsetting
				// it — assumes envconfig treats empty as unset; confirm.
				t.Setenv("OLLAMA_CONTEXT_LENGTH", "")
			}
			entry := buildCodexModelEntry(tt.modelName)
			gotContext, _ := entry["context_window"].(int)
			if gotContext != tt.wantContext {
				t.Errorf("context_window = %d, want %d", gotContext, tt.wantContext)
			}
			// Case-specific follow-on assertions, keyed by subtest name.
			if tt.name == "vision and thinking capabilities" {
				modalities, _ := entry["input_modalities"].([]string)
				if !slices.Contains(modalities, "image") {
					t.Error("expected image in input_modalities")
				}
				levels, _ := entry["supported_reasoning_levels"].([]any)
				if len(levels) == 0 {
					t.Error("expected non-empty supported_reasoning_levels")
				}
			}
			if tt.name == "system prompt passed through" {
				if got, _ := entry["base_instructions"].(string); got != "You are a helpful assistant." {
					t.Errorf("base_instructions = %q, want %q", got, "You are a helpful assistant.")
				}
			}
			if tt.name == "cloud model uses hardcoded limits" {
				truncationPolicy, _ := entry["truncation_policy"].(map[string]any)
				if mode, _ := truncationPolicy["mode"].(string); mode != "tokens" {
					t.Errorf("truncation_policy mode = %q, want %q", mode, "tokens")
				}
			}
			// Sanity: required catalog keys present and entry serializes.
			requiredKeys := []string{"slug", "display_name", "apply_patch_tool_type", "shell_type"}
			for _, key := range requiredKeys {
				if _, ok := entry[key]; !ok {
					t.Errorf("missing required key %q", key)
				}
			}
			if _, err := json.Marshal(entry); err != nil {
				t.Errorf("entry is not JSON serializable: %v", err)
			}
		})
	}
}

View File

@@ -301,7 +301,7 @@ func TestParseArgs(t *testing.T) {
func TestIsCloudModel(t *testing.T) { func TestIsCloudModel(t *testing.T) {
// isCloudModel now only uses Show API, so nil client always returns false // isCloudModel now only uses Show API, so nil client always returns false
t.Run("nil client returns false", func(t *testing.T) { t.Run("nil client returns false", func(t *testing.T) {
models := []string{"glm-5.1:cloud", "kimi-k2.5:cloud", "local-model"} models := []string{"glm-5.1:cloud", "kimi-k2.6:cloud", "local-model"}
for _, model := range models { for _, model := range models {
if isCloudModel(context.Background(), nil, model) { if isCloudModel(context.Background(), nil, model) {
t.Errorf("isCloudModel(%q) with nil client should return false", model) t.Errorf("isCloudModel(%q) with nil client should return false", model)
@@ -318,10 +318,18 @@ func names(items []ModelItem) []string {
return out return out
} }
// recommendedNames returns the canonical recommended-model names in their
// fixed order, followed by any extra names supplied by the caller.
func recommendedNames(extra ...string) []string {
	names := make([]string, 0, len(recommendedModels)+len(extra))
	for _, rec := range recommendedModels {
		names = append(names, rec.Name)
	}
	names = append(names, extra...)
	return names
}
func TestBuildModelList_NoExistingModels(t *testing.T) { func TestBuildModelList_NoExistingModels(t *testing.T) {
items, _, _, _ := buildModelList(nil, nil, "") items, _, _, _ := buildModelList(nil, nil, "")
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "gemma4", "qwen3.5"} want := recommendedNames()
if diff := cmp.Diff(want, names(items)); diff != "" { if diff := cmp.Diff(want, names(items)); diff != "" {
t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff) t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
} }
@@ -350,7 +358,7 @@ func TestBuildModelList_OnlyLocalModels_CloudRecsStillFirst(t *testing.T) {
// Cloud recs always come first among recommended, regardless of installed inventory. // Cloud recs always come first among recommended, regardless of installed inventory.
// Cloud disablement is handled upstream in loadSelectableModels via filterCloudItems. // Cloud disablement is handled upstream in loadSelectableModels via filterCloudItems.
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "gemma4", "qwen3.5", "llama3.2", "qwen2.5"} want := recommendedNames("llama3.2", "qwen2.5")
if diff := cmp.Diff(want, got); diff != "" { if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("cloud recs pinned first even when no cloud models installed (-want +got):\n%s", diff) t.Errorf("cloud recs pinned first even when no cloud models installed (-want +got):\n%s", diff)
} }
@@ -366,13 +374,13 @@ func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
got := names(items) got := names(items)
// All recs pinned at top (cloud before local in mixed case), then non-recs // All recs pinned at top (cloud before local in mixed case), then non-recs
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "gemma4", "qwen3.5", "llama3.2"} want := recommendedNames("llama3.2")
if diff := cmp.Diff(want, got); diff != "" { if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("recs pinned at top, cloud recs first in mixed case (-want +got):\n%s", diff) t.Errorf("recs pinned at top, cloud recs first in mixed case (-want +got):\n%s", diff)
} }
} }
func TestBuildModelList_PreCheckedFirst(t *testing.T) { func TestBuildModelList_PreCheckedNonRecommendedFirstInMore(t *testing.T) {
existing := []modelInfo{ existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false}, {Name: "llama3.2:latest", Remote: false},
{Name: "glm-5.1:cloud", Remote: true}, {Name: "glm-5.1:cloud", Remote: true},
@@ -381,8 +389,9 @@ func TestBuildModelList_PreCheckedFirst(t *testing.T) {
items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "") items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
got := names(items) got := names(items)
if got[0] != "llama3.2" { want := recommendedNames("llama3.2")
t.Errorf("pre-checked model should be first, got %v", got) if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("recommended block should stay fixed while checked non-recommended models lead More (-want +got):\n%s", diff)
} }
} }
@@ -437,7 +446,7 @@ func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
if !strings.HasSuffix(item.Description, "(not downloaded)") { if !strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description) t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
} }
case "minimax-m2.7:cloud", "kimi-k2.5:cloud", "qwen3.5:cloud": case "minimax-m2.7:cloud", "kimi-k2.6:cloud", "qwen3.5:cloud":
if strings.HasSuffix(item.Description, "(not downloaded)") { if strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description) t.Errorf("cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
} }
@@ -455,9 +464,9 @@ func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
got := names(items) got := names(items)
// gemma4 and glm-5.1:cloud are installed so they sort normally; // gemma4 and glm-5.1:cloud are installed so they sort normally;
// kimi-k2.5:cloud, qwen3.5:cloud, and qwen3.5 are not installed so they go to the bottom // qwen3.5:cloud and qwen3.5 are not installed so they go to the bottom
// All recs: cloud first in mixed case, then local, in rec order within each // All recs: cloud first in mixed case, then local, in rec order within each
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "gemma4", "qwen3.5"} want := recommendedNames()
if diff := cmp.Diff(want, got); diff != "" { if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("all recs, cloud first in mixed case (-want +got):\n%s", diff) t.Errorf("all recs, cloud first in mixed case (-want +got):\n%s", diff)
} }
@@ -466,23 +475,23 @@ func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *testing.T) { func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *testing.T) {
existing := []modelInfo{ existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false}, {Name: "llama3.2:latest", Remote: false},
{Name: "kimi-k2.5:cloud", Remote: true}, {Name: "kimi-k2.6:cloud", Remote: true},
} }
items, _, _, _ := buildModelList(existing, nil, "") items, _, _, _ := buildModelList(existing, nil, "")
got := names(items) got := names(items)
// kimi-k2.5:cloud is installed so it sorts normally; // kimi-k2.6:cloud is installed so it sorts normally;
// the rest of the recommendations are not installed so they go to the bottom // the rest of the recommendations are not installed so they go to the bottom
// All recs pinned at top (cloud first in mixed case), then non-recs // All recs pinned at top (cloud first in mixed case), then non-recs
want := []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud", "gemma4", "qwen3.5", "llama3.2"} want := recommendedNames("llama3.2")
if diff := cmp.Diff(want, got); diff != "" { if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("recs pinned at top, cloud first in mixed case (-want +got):\n%s", diff) t.Errorf("recs pinned at top, cloud first in mixed case (-want +got):\n%s", diff)
} }
for _, item := range items { for _, item := range items {
isCloud := strings.HasSuffix(item.Name, ":cloud") isCloud := strings.HasSuffix(item.Name, ":cloud")
isInstalled := slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) isInstalled := slices.Contains([]string{"kimi-k2.6:cloud", "llama3.2"}, item.Name)
if isInstalled || isCloud { if isInstalled || isCloud {
if strings.HasSuffix(item.Description, "(not downloaded)") { if strings.HasSuffix(item.Description, "(not downloaded)") {
t.Errorf("installed or cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description) t.Errorf("installed or cloud model %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
@@ -549,8 +558,8 @@ func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
if !cloudModels["glm-5.1:cloud"] { if !cloudModels["glm-5.1:cloud"] {
t.Error("glm-5.1:cloud should be in cloudModels") t.Error("glm-5.1:cloud should be in cloudModels")
} }
if !cloudModels["kimi-k2.5:cloud"] { if !cloudModels["kimi-k2.6:cloud"] {
t.Error("kimi-k2.5:cloud should be in cloudModels (recommended cloud)") t.Error("kimi-k2.6:cloud should be in cloudModels (recommended cloud)")
} }
if !cloudModels["qwen3.5:cloud"] { if !cloudModels["qwen3.5:cloud"] {
t.Error("qwen3.5:cloud should be in cloudModels (recommended cloud)") t.Error("qwen3.5:cloud should be in cloudModels (recommended cloud)")
@@ -570,7 +579,7 @@ func TestBuildModelList_RecommendedFieldSet(t *testing.T) {
for _, item := range items { for _, item := range items {
switch item.Name { switch item.Name {
case "gemma4", "qwen3.5", "glm-5.1:cloud", "kimi-k2.5:cloud", "qwen3.5:cloud": case "gemma4", "qwen3.5", "glm-5.1:cloud", "kimi-k2.6:cloud", "qwen3.5:cloud":
if !item.Recommended { if !item.Recommended {
t.Errorf("%q should have Recommended=true", item.Name) t.Errorf("%q should have Recommended=true", item.Name)
} }
@@ -628,7 +637,7 @@ func TestBuildModelList_RecsAboveNonRecs(t *testing.T) {
lastRecIdx := -1 lastRecIdx := -1
firstNonRecIdx := len(got) firstNonRecIdx := len(got)
for i, name := range got { for i, name := range got {
isRec := name == "gemma4" || name == "qwen3.5" || name == "minimax-m2.7:cloud" || name == "glm-5.1:cloud" || name == "kimi-k2.5:cloud" || name == "qwen3.5:cloud" isRec := name == "gemma4" || name == "qwen3.5" || name == "minimax-m2.7:cloud" || name == "glm-5.1:cloud" || name == "kimi-k2.6:cloud" || name == "qwen3.5:cloud"
if isRec && i > lastRecIdx { if isRec && i > lastRecIdx {
lastRecIdx = i lastRecIdx = i
} }
@@ -641,17 +650,32 @@ func TestBuildModelList_RecsAboveNonRecs(t *testing.T) {
} }
} }
func TestBuildModelList_CheckedBeforeRecs(t *testing.T) { func TestBuildModelList_CheckedRecommendedDoesNotReshuffleRecommendedOrder(t *testing.T) {
existing := []modelInfo{ existing := []modelInfo{
{Name: "llama3.2:latest", Remote: false}, {Name: "llama3.2:latest", Remote: false},
{Name: "glm-5.1:cloud", Remote: true}, {Name: "glm-5.1:cloud", Remote: true},
} }
items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "") items, _, _, _ := buildModelList(existing, []string{"qwen3.5:cloud", "glm-5.1:cloud"}, "")
got := names(items) got := names(items)
if got[0] != "llama3.2" { want := recommendedNames("llama3.2")
t.Errorf("checked model should be first even before recs, got %v", got) if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("checked recommended models should not reshuffle the fixed recommended order (-want +got):\n%s", diff)
}
}
func TestBuildModelList_StaleSavedKimiK25DoesNotReshuffleRecommendedOrder(t *testing.T) {
existing := []modelInfo{
{Name: "kimi-k2.5:cloud", Remote: true},
}
items, _, _, _ := buildModelList(existing, []string{"kimi-k2.5:cloud", "qwen3.5:cloud", "glm-5.1:cloud", "minimax-m2.7:cloud"}, "kimi-k2.5:cloud")
got := names(items)
want := recommendedNames("kimi-k2.5:cloud")
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("stale saved kimi-k2.5 should stay in More without reshuffling the fixed recommended order (-want +got):\n%s", diff)
} }
} }

View File

@@ -13,6 +13,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/cmd/config" "github.com/ollama/ollama/cmd/config"
) )
@@ -1219,8 +1220,9 @@ func TestLaunchIntegration_EditorForceConfigure_FloatsCheckedModelsInPicker(t *t
if len(gotItems) == 0 { if len(gotItems) == 0 {
t.Fatal("expected multi selector to receive items") t.Fatal("expected multi selector to receive items")
} }
if gotItems[0] != "qwen3.5:cloud" { wantItems := recommendedNames()
t.Fatalf("expected checked models floated to top with qwen3.5:cloud first, got %v", gotItems) if diff := cmp.Diff(wantItems, gotItems); diff != "" {
t.Fatalf("expected fixed recommended order in selector items (-want +got):\n%s", diff)
} }
if len(gotPreChecked) < 2 { if len(gotPreChecked) < 2 {
t.Fatalf("expected prechecked models to be preserved, got %v", gotPreChecked) t.Fatalf("expected prechecked models to be preserved, got %v", gotPreChecked)

View File

@@ -21,7 +21,7 @@ import (
) )
var recommendedModels = []ModelItem{ var recommendedModels = []ModelItem{
{Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true}, {Name: "kimi-k2.6:cloud", Description: "State-of-the-art coding, long-horizon execution, and multimodal agent swarm capability", Recommended: true},
{Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true}, {Name: "qwen3.5:cloud", Description: "Reasoning, coding, and agentic tool use with vision", Recommended: true},
{Name: "glm-5.1:cloud", Description: "Reasoning and code generation", Recommended: true}, {Name: "glm-5.1:cloud", Description: "Reasoning and code generation", Recommended: true},
{Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true}, {Name: "minimax-m2.7:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
@@ -56,6 +56,7 @@ var cloudModelLimits = map[string]cloudModelLimit{
"gpt-oss:20b": {Context: 131_072, Output: 131_072}, "gpt-oss:20b": {Context: 131_072, Output: 131_072},
"kimi-k2:1t": {Context: 262_144, Output: 262_144}, "kimi-k2:1t": {Context: 262_144, Output: 262_144},
"kimi-k2.5": {Context: 262_144, Output: 262_144}, "kimi-k2.5": {Context: 262_144, Output: 262_144},
"kimi-k2.6": {Context: 262_144, Output: 262_144},
"kimi-k2-thinking": {Context: 262_144, Output: 262_144}, "kimi-k2-thinking": {Context: 262_144, Output: 262_144},
"nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072}, "nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
"qwen3-coder:480b": {Context: 262_144, Output: 65_536}, "qwen3-coder:480b": {Context: 262_144, Output: 65_536},
@@ -360,18 +361,12 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
} }
if hasLocalModel || hasCloudModel { if hasLocalModel || hasCloudModel {
// Keep the Recommended section pinned to recommendedModels order. Checked
// and default-model priority only apply within the More section.
slices.SortStableFunc(items, func(a, b ModelItem) int { slices.SortStableFunc(items, func(a, b ModelItem) int {
ac, bc := checked[a.Name], checked[b.Name] ac, bc := checked[a.Name], checked[b.Name]
aNew, bNew := notInstalled[a.Name], notInstalled[b.Name] aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0 aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name]
if ac != bc {
if ac {
return -1
}
return 1
}
if aRec != bRec { if aRec != bRec {
if aRec { if aRec {
return -1 return -1
@@ -379,14 +374,14 @@ func buildModelList(existing []modelInfo, preChecked []string, current string) (
return 1 return 1
} }
if aRec && bRec { if aRec && bRec {
if aCloud != bCloud {
if aCloud {
return -1
}
return 1
}
return recRank[a.Name] - recRank[b.Name] return recRank[a.Name] - recRank[b.Name]
} }
if ac != bc {
if ac {
return -1
}
return 1
}
// Among checked non-recommended items - put the default first // Among checked non-recommended items - put the default first
if ac && !aRec && current != "" { if ac && !aRec && current != "" {
aCurrent := a.Name == current aCurrent := a.Name == current

View File

@@ -14,8 +14,6 @@ import (
"strings" "strings"
"time" "time"
"golang.org/x/mod/semver"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/cmd/internal/fileutil" "github.com/ollama/ollama/cmd/internal/fileutil"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
@@ -98,9 +96,7 @@ func (c *Openclaw) Run(model string, args []string) error {
patchDeviceScopes() patchDeviceScopes()
} }
if ensureWebSearchPlugin() { configureOllamaWebSearch()
registerWebSearchPlugin()
}
// When extra args are passed through, run exactly what the user asked for // When extra args are passed through, run exactly what the user asked for
// after setup and skip the built-in gateway+TUI convenience flow. // after setup and skip the built-in gateway+TUI convenience flow.
@@ -738,89 +734,13 @@ func clearSessionModelOverride(primary string) {
_ = os.WriteFile(path, out, 0o600) _ = os.WriteFile(path, out, 0o600)
} }
const ( // configureOllamaWebSearch keeps launch-managed OpenClaw installs on the
webSearchNpmPackage = "@ollama/openclaw-web-search" // bundled Ollama web_search provider. Older launch builds installed an
webSearchMinVersion = "0.2.1" // external openclaw-web-search plugin that added custom ollama_web_search and
) // ollama_web_fetch tools. Current OpenClaw versions ship Ollama web_search as
// the bundled "ollama" plugin instead, so we migrate stale config and ensure
// ensureWebSearchPlugin installs the openclaw-web-search extension into the // fresh installs select the bundled provider.
// user-level extensions directory (~/.openclaw/extensions/) if it isn't already func configureOllamaWebSearch() {
// present, or re-installs if the installed version is older than webSearchMinVersion.
// Returns true if the extension is available.
func ensureWebSearchPlugin() bool {
home, err := os.UserHomeDir()
if err != nil {
return false
}
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
if webSearchPluginUpToDate(pluginDir) {
return true
}
npmBin, err := exec.LookPath("npm")
if err != nil {
return false
}
if err := os.MkdirAll(pluginDir, 0o755); err != nil {
return false
}
// Download the tarball via `npm pack`, extract it flat into the plugin dir.
pack := exec.Command(npmBin, "pack", webSearchNpmPackage, "--pack-destination", pluginDir)
out, err := pack.Output()
if err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not download web search plugin: %v%s\n", ansiYellow, err, ansiReset)
return false
}
tgzName := strings.TrimSpace(string(out))
tgzPath := filepath.Join(pluginDir, tgzName)
defer os.Remove(tgzPath)
tar := exec.Command("tar", "xzf", tgzPath, "--strip-components=1", "-C", pluginDir)
if err := tar.Run(); err != nil {
fmt.Fprintf(os.Stderr, "%s Warning: could not extract web search plugin: %v%s\n", ansiYellow, err, ansiReset)
return false
}
fmt.Fprintf(os.Stderr, "%s ✓ Installed Ollama web search %s\n", ansiGreen, ansiReset)
return true
}
// webSearchPluginUpToDate returns true if the plugin is installed and its
// package.json version is >= webSearchMinVersion.
func webSearchPluginUpToDate(pluginDir string) bool {
data, err := os.ReadFile(filepath.Join(pluginDir, "package.json"))
if err != nil {
return false
}
var pkg struct {
Version string `json:"version"`
}
if json.Unmarshal(data, &pkg) != nil || pkg.Version == "" {
return false
}
return !versionLessThan(pkg.Version, webSearchMinVersion)
}
// versionLessThan compares two semver version strings (major.minor.patch).
// Inputs may omit the "v" prefix; it is added automatically for semver.Compare.
func versionLessThan(a, b string) bool {
if !strings.HasPrefix(a, "v") {
a = "v" + a
}
if !strings.HasPrefix(b, "v") {
b = "v" + b
}
return semver.Compare(a, b) < 0
}
// registerWebSearchPlugin adds plugins.entries.openclaw-web-search to the OpenClaw
// config so the gateway activates it on next start. Best-effort; silently returns
// on any error.
func registerWebSearchPlugin() {
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
return return
@@ -835,6 +755,8 @@ func registerWebSearchPlugin() {
return return
} }
stalePluginConfigured := false
plugins, _ := config["plugins"].(map[string]any) plugins, _ := config["plugins"].(map[string]any)
if plugins == nil { if plugins == nil {
plugins = make(map[string]any) plugins = make(map[string]any)
@@ -843,68 +765,100 @@ func registerWebSearchPlugin() {
if entries == nil { if entries == nil {
entries = make(map[string]any) entries = make(map[string]any)
} }
entries["openclaw-web-search"] = map[string]any{"enabled": true}
plugins["entries"] = entries
// Pin trust so the gateway doesn't warn about untracked plugins.
allow, _ := plugins["allow"].([]any)
hasAllow := false
for _, v := range allow {
if s, ok := v.(string); ok && s == "openclaw-web-search" {
hasAllow = true
break
}
}
if !hasAllow {
allow = append(allow, "openclaw-web-search")
}
plugins["allow"] = allow
// Record install provenance so the loader can verify the plugin origin.
installs, _ := plugins["installs"].(map[string]any)
if installs == nil {
installs = make(map[string]any)
}
pluginDir := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
installs["openclaw-web-search"] = map[string]any{
"source": "npm",
"spec": webSearchNpmPackage,
"installPath": pluginDir,
}
plugins["installs"] = installs
config["plugins"] = plugins
// Add plugin tools to tools.alsoAllow so they survive the coding profile's
// policy pipeline (which has an explicit allow list of core tools only).
tools, _ := config["tools"].(map[string]any) tools, _ := config["tools"].(map[string]any)
if tools == nil { if tools == nil {
tools = make(map[string]any) tools = make(map[string]any)
} }
alsoAllow, _ := tools["alsoAllow"].([]any)
needed := []string{"ollama_web_search", "ollama_web_fetch"}
have := make(map[string]bool, len(alsoAllow))
for _, v := range alsoAllow {
if s, ok := v.(string); ok {
have[s] = true
}
}
for _, name := range needed {
if !have[name] {
alsoAllow = append(alsoAllow, name)
}
}
tools["alsoAllow"] = alsoAllow
// Disable built-in web search/fetch since our plugin replaces them.
web, _ := tools["web"].(map[string]any) web, _ := tools["web"].(map[string]any)
if web == nil { if web == nil {
web = make(map[string]any) web = make(map[string]any)
} }
web["search"] = map[string]any{"enabled": false} search, _ := web["search"].(map[string]any)
web["fetch"] = map[string]any{"enabled": false} if search == nil {
search = make(map[string]any)
}
fetch, _ := web["fetch"].(map[string]any)
if fetch == nil {
fetch = make(map[string]any)
}
alsoAllow, _ := tools["alsoAllow"].([]any)
var filteredAlsoAllow []any
for _, v := range alsoAllow {
s, ok := v.(string)
if !ok {
filteredAlsoAllow = append(filteredAlsoAllow, v)
continue
}
if s == "ollama_web_search" || s == "ollama_web_fetch" {
stalePluginConfigured = true
continue
}
filteredAlsoAllow = append(filteredAlsoAllow, v)
}
if len(filteredAlsoAllow) > 0 {
tools["alsoAllow"] = filteredAlsoAllow
} else {
delete(tools, "alsoAllow")
}
if _, ok := entries["openclaw-web-search"]; ok {
delete(entries, "openclaw-web-search")
stalePluginConfigured = true
}
ollamaEntry, _ := entries["ollama"].(map[string]any)
if ollamaEntry == nil {
ollamaEntry = make(map[string]any)
}
ollamaEntry["enabled"] = true
entries["ollama"] = ollamaEntry
plugins["entries"] = entries
if allow, ok := plugins["allow"].([]any); ok {
var nextAllow []any
hasOllama := false
for _, v := range allow {
s, ok := v.(string)
if ok && s == "openclaw-web-search" {
stalePluginConfigured = true
continue
}
if ok && s == "ollama" {
hasOllama = true
}
nextAllow = append(nextAllow, v)
}
if !hasOllama {
nextAllow = append(nextAllow, "ollama")
}
plugins["allow"] = nextAllow
}
if installs, ok := plugins["installs"].(map[string]any); ok {
if _, exists := installs["openclaw-web-search"]; exists {
delete(installs, "openclaw-web-search")
stalePluginConfigured = true
}
if len(installs) > 0 {
plugins["installs"] = installs
} else {
delete(plugins, "installs")
}
}
if stalePluginConfigured || search["provider"] == nil {
search["provider"] = "ollama"
}
if stalePluginConfigured {
fetch["enabled"] = true
}
search["enabled"] = true
web["search"] = search
if len(fetch) > 0 {
web["fetch"] = fetch
}
tools["web"] = web tools["web"] = web
config["plugins"] = plugins
config["tools"] = tools config["tools"] = tools
out, err := json.MarshalIndent(config, "", " ") out, err := json.MarshalIndent(config, "", " ")

View File

@@ -2242,95 +2242,7 @@ func TestIntegrationOnboarded(t *testing.T) {
}) })
} }
func TestVersionLessThan(t *testing.T) { func TestConfigureOllamaWebSearch(t *testing.T) {
tests := []struct {
a, b string
want bool
}{
{"0.1.7", "0.2.1", true},
{"0.2.0", "0.2.1", true},
{"0.2.1", "0.2.1", false},
{"0.2.2", "0.2.1", false},
{"1.0.0", "0.2.1", false},
{"0.2.1", "1.0.0", true},
{"v0.1.7", "0.2.1", true},
{"0.2.1", "v0.2.1", false},
}
for _, tt := range tests {
t.Run(tt.a+"_vs_"+tt.b, func(t *testing.T) {
if got := versionLessThan(tt.a, tt.b); got != tt.want {
t.Errorf("versionLessThan(%q, %q) = %v, want %v", tt.a, tt.b, got, tt.want)
}
})
}
}
func TestWebSearchPluginUpToDate(t *testing.T) {
t.Run("missing directory", func(t *testing.T) {
if webSearchPluginUpToDate(filepath.Join(t.TempDir(), "nonexistent")) {
t.Error("expected false for missing directory")
}
})
t.Run("missing package.json", func(t *testing.T) {
dir := t.TempDir()
if webSearchPluginUpToDate(dir) {
t.Error("expected false for missing package.json")
}
})
t.Run("old version", func(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":"0.1.7"}`), 0o644); err != nil {
t.Fatal(err)
}
if webSearchPluginUpToDate(dir) {
t.Error("expected false for old version 0.1.7")
}
})
t.Run("exact minimum version", func(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":"0.2.1"}`), 0o644); err != nil {
t.Fatal(err)
}
if !webSearchPluginUpToDate(dir) {
t.Error("expected true for exact minimum version 0.2.1")
}
})
t.Run("newer version", func(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":"1.0.0"}`), 0o644); err != nil {
t.Fatal(err)
}
if !webSearchPluginUpToDate(dir) {
t.Error("expected true for newer version 1.0.0")
}
})
t.Run("invalid json", func(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`not json`), 0o644); err != nil {
t.Fatal(err)
}
if webSearchPluginUpToDate(dir) {
t.Error("expected false for invalid json")
}
})
t.Run("empty version", func(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "package.json"), []byte(`{"version":""}`), 0o644); err != nil {
t.Fatal(err)
}
if webSearchPluginUpToDate(dir) {
t.Error("expected false for empty version")
}
})
}
func TestRegisterWebSearchPlugin(t *testing.T) {
home := t.TempDir() home := t.TempDir()
setTestHome(t, home) setTestHome(t, home)
@@ -2345,7 +2257,7 @@ func TestRegisterWebSearchPlugin(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
registerWebSearchPlugin() configureOllamaWebSearch()
data, err := os.ReadFile(configPath) data, err := os.ReadFile(configPath)
if err != nil { if err != nil {
@@ -2361,40 +2273,30 @@ func TestRegisterWebSearchPlugin(t *testing.T) {
t.Fatal("plugins section missing") t.Fatal("plugins section missing")
} }
// Check entries
entries, _ := plugins["entries"].(map[string]any) entries, _ := plugins["entries"].(map[string]any)
entry, _ := entries["openclaw-web-search"].(map[string]any) entry, _ := entries["ollama"].(map[string]any)
if enabled, _ := entry["enabled"].(bool); !enabled { if enabled, _ := entry["enabled"].(bool); !enabled {
t.Error("expected entries.openclaw-web-search.enabled = true") t.Error("expected entries.ollama.enabled = true")
}
if _, ok := entries["openclaw-web-search"]; ok {
t.Error("expected stale openclaw-web-search entry to be absent")
} }
// Check allow list if _, ok := plugins["allow"]; ok {
allow, _ := plugins["allow"].([]any) t.Error("did not expect plugins.allow to be created when no allowlist exists")
found := false
for _, v := range allow {
if s, ok := v.(string); ok && s == "openclaw-web-search" {
found = true
}
} }
if !found { if _, ok := plugins["installs"]; ok {
t.Error("expected plugins.allow to contain openclaw-web-search") t.Error("did not expect plugins.installs to be created")
} }
// Check install provenance tools, _ := config["tools"].(map[string]any)
installs, _ := plugins["installs"].(map[string]any) web, _ := tools["web"].(map[string]any)
record, _ := installs["openclaw-web-search"].(map[string]any) search, _ := web["search"].(map[string]any)
if record == nil { if got, _ := search["provider"].(string); got != "ollama" {
t.Fatal("expected plugins.installs.openclaw-web-search") t.Errorf("search provider = %q, want %q", got, "ollama")
} }
if source, _ := record["source"].(string); source != "npm" { if enabled, _ := search["enabled"].(bool); !enabled {
t.Errorf("install source = %q, want %q", source, "npm") t.Error("expected tools.web.search.enabled = true")
}
if spec, _ := record["spec"].(string); spec != webSearchNpmPackage {
t.Errorf("install spec = %q, want %q", spec, webSearchNpmPackage)
}
expectedPath := filepath.Join(home, ".openclaw", "extensions", "openclaw-web-search")
if installPath, _ := record["installPath"].(string); installPath != expectedPath {
t.Errorf("installPath = %q, want %q", installPath, expectedPath)
} }
}) })
@@ -2403,8 +2305,8 @@ func TestRegisterWebSearchPlugin(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
registerWebSearchPlugin() configureOllamaWebSearch()
registerWebSearchPlugin() configureOllamaWebSearch()
data, err := os.ReadFile(configPath) data, err := os.ReadFile(configPath)
if err != nil { if err != nil {
@@ -2416,30 +2318,39 @@ func TestRegisterWebSearchPlugin(t *testing.T) {
} }
plugins, _ := config["plugins"].(map[string]any) plugins, _ := config["plugins"].(map[string]any)
allow, _ := plugins["allow"].([]any) entries, _ := plugins["entries"].(map[string]any)
count := 0 if len(entries) != 1 {
for _, v := range allow { t.Fatalf("expected only bundled ollama entry, got %v", entries)
if s, ok := v.(string); ok && s == "openclaw-web-search" {
count++
}
} }
if count != 1 { if _, ok := entries["ollama"]; !ok {
t.Errorf("expected exactly 1 openclaw-web-search in allow, got %d", count) t.Fatalf("expected entries.ollama to exist, got %v", entries)
} }
}) })
t.Run("preserves existing config", func(t *testing.T) { t.Run("migrates stale plugin config and preserves unrelated settings", func(t *testing.T) {
initial := map[string]any{ initial := map[string]any{
"plugins": map[string]any{ "plugins": map[string]any{
"allow": []any{"some-other-plugin"}, "allow": []any{"some-other-plugin", "openclaw-web-search"},
"entries": map[string]any{ "entries": map[string]any{
"some-other-plugin": map[string]any{"enabled": true}, "some-other-plugin": map[string]any{"enabled": true},
"openclaw-web-search": map[string]any{"enabled": true},
}, },
"installs": map[string]any{ "installs": map[string]any{
"some-other-plugin": map[string]any{ "some-other-plugin": map[string]any{
"source": "npm", "source": "npm",
"installPath": "/some/path", "installPath": "/some/path",
}, },
"openclaw-web-search": map[string]any{
"source": "npm",
"installPath": "/old/path",
},
},
},
"tools": map[string]any{
"alsoAllow": []any{"ollama_web_search", "ollama_web_fetch", "browser"},
"web": map[string]any{
"search": map[string]any{"enabled": false},
"fetch": map[string]any{"enabled": false},
}, },
}, },
"customField": "preserved", "customField": "preserved",
@@ -2449,7 +2360,7 @@ func TestRegisterWebSearchPlugin(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
registerWebSearchPlugin() configureOllamaWebSearch()
out, err := os.ReadFile(configPath) out, err := os.ReadFile(configPath)
if err != nil { if err != nil {
@@ -2469,28 +2380,61 @@ func TestRegisterWebSearchPlugin(t *testing.T) {
if entries["some-other-plugin"] == nil { if entries["some-other-plugin"] == nil {
t.Error("existing plugin entry was lost") t.Error("existing plugin entry was lost")
} }
if entries["openclaw-web-search"] != nil {
t.Error("stale openclaw-web-search entry should be removed")
}
if ollamaEntry, _ := entries["ollama"].(map[string]any); ollamaEntry == nil {
t.Fatal("expected bundled ollama entry to be enabled")
}
installs, _ := plugins["installs"].(map[string]any) installs, _ := plugins["installs"].(map[string]any)
if installs["some-other-plugin"] == nil { if installs["some-other-plugin"] == nil {
t.Error("existing install record was lost") t.Error("existing install record was lost")
} }
if installs["openclaw-web-search"] != nil {
t.Error("stale openclaw-web-search install record should be removed")
}
allow, _ := plugins["allow"].([]any) allow, _ := plugins["allow"].([]any)
hasOther, hasWebSearch := false, false hasOther, hasStalePlugin, hasOllama := false, false, false
for _, v := range allow { for _, v := range allow {
s, _ := v.(string) s, _ := v.(string)
if s == "some-other-plugin" { if s == "some-other-plugin" {
hasOther = true hasOther = true
} }
if s == "openclaw-web-search" { if s == "openclaw-web-search" {
hasWebSearch = true hasStalePlugin = true
}
if s == "ollama" {
hasOllama = true
} }
} }
if !hasOther { if !hasOther {
t.Error("existing allow entry was lost") t.Error("existing allow entry was lost")
} }
if !hasWebSearch { if hasStalePlugin {
t.Error("openclaw-web-search not added to allow") t.Error("stale openclaw-web-search allow entry should be removed")
}
if !hasOllama {
t.Error("expected plugins.allow to contain bundled ollama plugin")
}
tools, _ := config["tools"].(map[string]any)
alsoAllow, _ := tools["alsoAllow"].([]any)
if len(alsoAllow) != 1 || alsoAllow[0] != "browser" {
t.Errorf("expected stale custom web tools to be removed, got %v", alsoAllow)
}
web, _ := tools["web"].(map[string]any)
search, _ := web["search"].(map[string]any)
fetch, _ := web["fetch"].(map[string]any)
if got, _ := search["provider"].(string); got != "ollama" {
t.Errorf("search provider = %q, want %q", got, "ollama")
}
if enabled, _ := search["enabled"].(bool); !enabled {
t.Error("expected migrated tools.web.search.enabled = true")
}
if enabled, _ := fetch["enabled"].(bool); !enabled {
t.Error("expected migrated tools.web.fetch.enabled = true")
} }
}) })
} }

View File

@@ -2,6 +2,10 @@
title: Structured Outputs title: Structured Outputs
--- ---
<Note>
Ollama's Cloud currently does not support structured outputs.
</Note>
Structured outputs let you enforce a JSON schema on model responses so you can reliably extract structured data, describe images, or keep every reply consistent. Structured outputs let you enforce a JSON schema on model responses so you can reliably extract structured data, describe images, or keep every reply consistent.
## Generating structured JSON ## Generating structured JSON

View File

@@ -15,7 +15,7 @@ Ollama handles everything automatically:
1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm 1. **Install** — If OpenClaw isn't installed, Ollama prompts to install it via npm
2. **Security** — On the first launch, a security notice explains the risks of tool access 2. **Security** — On the first launch, a security notice explains the risks of tool access
3. **Model** — Pick a model from the selector (local or cloud) 3. **Model** — Pick a model from the selector (local or cloud)
4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, sets your model as the primary, and installs the web search and fetch plugin 4. **Onboarding** — Ollama configures the provider, installs the gateway daemon, sets your model as the primary, and enables OpenClaw's bundled Ollama web search
5. **Gateway** — Starts in the background and opens the OpenClaw TUI 5. **Gateway** — Starts in the background and opens the OpenClaw TUI
<Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note> <Note>OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens if using local models. See [Context length](/context-length) for more information.</Note>
@@ -24,19 +24,19 @@ Ollama handles everything automatically:
## Web search and fetch ## Web search and fetch
OpenClaw ships with a web search and fetch plugin that gives local or cloud models the ability to search the web and extract readable page content. OpenClaw ships with a bundled Ollama `web_search` provider that lets local or cloud-backed Ollama setups search the web through the configured Ollama host.
```bash ```bash
ollama launch openclaw ollama launch openclaw
``` ```
Web search and fetch is enabled automatically when launching OpenClaw through Ollama. To install the plugin directly: Ollama web search is enabled automatically when launching OpenClaw through Ollama. To configure it manually:
```bash ```bash
openclaw plugins install @ollama/openclaw-web-search openclaw configure --section web
``` ```
<Note>Web search for local models requires `ollama signin`.</Note> <Note>Ollama web search for local models requires `ollama signin`.</Note>
## Configure without launching ## Configure without launching
@@ -93,4 +93,3 @@ Link WhatsApp, Telegram, Slack, Discord, or iMessage to chat with your local mod
```bash ```bash
openclaw gateway stop openclaw gateway stop
``` ```

View File

@@ -406,10 +406,6 @@ func TestAPIShowModel(t *testing.T) {
} }
func TestAPIGenerateLogprobs(t *testing.T) { func TestAPIGenerateLogprobs(t *testing.T) {
if testModel != "" {
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
t.Skip("logprobs not supported by all runners")
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()
@@ -523,10 +519,6 @@ func TestAPIGenerateLogprobs(t *testing.T) {
} }
func TestAPIChatLogprobs(t *testing.T) { func TestAPIChatLogprobs(t *testing.T) {
if testModel != "" {
// Logprobs requires runner support (e.g. llama.cpp has it, MLX does not).
t.Skip("logprobs not supported by all runners")
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()

View File

@@ -151,22 +151,11 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error {
} }
} }
// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization. type CompletionRequest struct {
type completionRequest struct { Prompt string
Prompt string `json:"prompt"` Options api.Options
Options *completionOpts `json:"options,omitempty"` Logprobs bool
} TopLogprobs int
type completionOpts struct {
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
MinP float32 `json:"min_p,omitempty"`
TopK int `json:"top_k,omitempty"`
RepeatLastN int `json:"repeat_last_n,omitempty"`
RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
} }
type CompletionResponse struct { type CompletionResponse struct {
@@ -179,6 +168,8 @@ type CompletionResponse struct {
EvalCount int EvalCount int
EvalDuration time.Duration EvalDuration time.Duration
Logprobs []llm.Logprob
Error *api.StatusError Error *api.StatusError
} }
@@ -203,21 +194,13 @@ func (c *Client) Close() error {
// Completion implements llm.LlamaServer. // Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
creq := completionRequest{ creq := CompletionRequest{
Prompt: req.Prompt, Prompt: req.Prompt,
Logprobs: req.Logprobs,
TopLogprobs: req.TopLogprobs,
} }
if req.Options != nil { if req.Options != nil {
creq.Options = &completionOpts{ creq.Options = *req.Options
Temperature: req.Options.Temperature,
TopP: req.Options.TopP,
MinP: req.Options.MinP,
TopK: req.Options.TopK,
RepeatLastN: req.Options.RepeatLastN,
RepeatPenalty: req.Options.RepeatPenalty,
PresencePenalty: req.Options.PresencePenalty,
FrequencyPenalty: req.Options.FrequencyPenalty,
NumPredict: req.Options.NumPredict,
}
} }
body, err := json.Marshal(creq) body, err := json.Marshal(creq)
@@ -243,7 +226,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body) respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%s", strings.TrimSpace(string(respBody))) return api.StatusError{StatusCode: resp.StatusCode, ErrorMessage: strings.TrimSpace(string(respBody))}
} }
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
@@ -266,6 +249,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
PromptEvalDuration: raw.PromptEvalDuration, PromptEvalDuration: raw.PromptEvalDuration,
EvalCount: raw.EvalCount, EvalCount: raw.EvalCount,
EvalDuration: raw.EvalDuration, EvalDuration: raw.EvalDuration,
Logprobs: raw.Logprobs,
} }
fn(cresp) fn(cresp)

View File

@@ -10,6 +10,8 @@ import (
"reflect" "reflect"
"sort" "sort"
"strings" "strings"
"sync"
"sync/atomic"
"unsafe" "unsafe"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
@@ -18,20 +20,28 @@ import (
type Array struct { type Array struct {
ctx C.mlx_array ctx C.mlx_array
name string name string
pinned int pinned atomic.Int32
} }
var arrays []*Array var (
arrays []*Array
arraysMu sync.Mutex
)
// constructor utilities // constructor utilities
func New(name string) *Array { func New(name string) *Array {
t := &Array{name: name} t := &Array{name: name}
if tracing { if tracing {
traceScratch = append(traceScratch, t) traceScratch = append(traceScratch, t)
} else { } else {
arraysMu.Lock()
defer arraysMu.Unlock()
arrays = append(arrays, t) arrays = append(arrays, t)
} }
return t return t
} }
@@ -131,7 +141,7 @@ func (t *Array) Clone() *Array {
func Pin(s ...*Array) { func Pin(s ...*Array) {
for _, t := range s { for _, t := range s {
if t != nil { if t != nil {
t.pinned++ t.pinned.Add(1)
} }
} }
} }
@@ -140,8 +150,7 @@ func Pin(s ...*Array) {
func Unpin(s ...*Array) { func Unpin(s ...*Array) {
for _, t := range s { for _, t := range s {
if t != nil { if t != nil {
t.pinned-- if t.pinned.Add(-1) < 0 {
if t.pinned < 0 {
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name)) panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
} }
} }
@@ -151,9 +160,11 @@ func Unpin(s ...*Array) {
// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly // Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
// free them when there are no other references, including dependencies in the graph. // free them when there are no other references, including dependencies in the graph.
func Sweep() { func Sweep() {
arraysMu.Lock()
defer arraysMu.Unlock()
n := 0 n := 0
for _, t := range arrays { for _, t := range arrays {
if t.pinned > 0 && t.Valid() { if t.pinned.Load() > 0 && t.Valid() {
arrays[n] = t arrays[n] = t
n++ n++
} else if t.Valid() { } else if t.Valid() {
@@ -180,7 +191,7 @@ func (t *Array) String() string {
func (t *Array) LogValue() slog.Value { func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{ attrs := []slog.Attr{
slog.String("name", t.name), slog.String("name", t.name),
slog.Int("pinned", t.pinned), slog.Int("pinned", int(t.pinned.Load())),
} }
if t.Valid() { if t.Valid() {
attrs = append(attrs, attrs = append(attrs,
@@ -194,19 +205,19 @@ func (t *Array) LogValue() slog.Value {
// shape utilities // shape utilities
func (t Array) Size() int { func (t *Array) Size() int {
return int(C.mlx_array_size(t.ctx)) return int(C.mlx_array_size(t.ctx))
} }
func (t Array) NumBytes() int { func (t *Array) NumBytes() int {
return int(C.mlx_array_nbytes(t.ctx)) return int(C.mlx_array_nbytes(t.ctx))
} }
func (t Array) NumDims() int { func (t *Array) NumDims() int {
return int(C.mlx_array_ndim(t.ctx)) return int(C.mlx_array_ndim(t.ctx))
} }
func (t Array) Dims() []int { func (t *Array) Dims() []int {
dims := make([]int, t.NumDims()) dims := make([]int, t.NumDims())
for i := range dims { for i := range dims {
dims[i] = t.Dim(i) dims[i] = t.Dim(i)
@@ -215,29 +226,32 @@ func (t Array) Dims() []int {
return dims return dims
} }
func (t Array) Dim(dim int) int { func (t *Array) Dim(dim int) int {
return int(C.mlx_array_dim(t.ctx, C.int(dim))) return int(C.mlx_array_dim(t.ctx, C.int(dim)))
} }
func (t Array) DType() DType { func (t *Array) DType() DType {
return DType(C.mlx_array_dtype(t.ctx)) return DType(C.mlx_array_dtype(t.ctx))
} }
// data utilities // data utilities
func (t Array) Int() int { func (t *Array) Int() int {
var item C.int64_t var item C.int64_t
C.mlx_array_item_int64(&item, t.ctx) C.mlx_array_item_int64(&item, t.ctx)
return int(item) return int(item)
} }
func (t Array) Float() float64 { func (t *Array) Float() float64 {
var item C.double var item C.double
C.mlx_array_item_float64(&item, t.ctx) C.mlx_array_item_float64(&item, t.ctx)
return float64(item) return float64(item)
} }
func (t Array) Ints() []int { func (t *Array) Ints() []int {
if dt := t.DType(); dt != DTypeInt32 {
panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt))
}
ints := make([]int, t.Size()) ints := make([]int, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) { for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
ints[i] = int(f) ints[i] = int(f)
@@ -245,7 +259,10 @@ func (t Array) Ints() []int {
return ints return ints
} }
func (t Array) Floats() []float32 { func (t *Array) Floats() []float32 {
if dt := t.DType(); dt != DTypeFloat32 {
panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt))
}
floats := make([]float32, t.Size()) floats := make([]float32, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) { for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
floats[i] = float32(f) floats[i] = float32(f)
@@ -253,7 +270,7 @@ func (t Array) Floats() []float32 {
return floats return floats
} }
func (t Array) Save(name string) error { func (t *Array) Save(name string) error {
cName := C.CString(name) cName := C.CString(name)
defer C.free(unsafe.Pointer(cName)) defer C.free(unsafe.Pointer(cName))
C.mlx_save(cName, t.ctx) C.mlx_save(cName, t.ctx)
@@ -262,6 +279,8 @@ func (t Array) Save(name string) error {
// LogArrays logs all live arrays, sorted by size // LogArrays logs all live arrays, sorted by size
func LogArrays() { func LogArrays() {
arraysMu.Lock()
defer arraysMu.Unlock()
sort.Slice(arrays, func(i, j int) bool { sort.Slice(arrays, func(i, j int) bool {
return arrays[i].NumBytes() > arrays[j].NumBytes() return arrays[i].NumBytes() > arrays[j].NumBytes()
}) })
@@ -270,7 +289,7 @@ func LogArrays() {
for _, t := range arrays { for _, t := range arrays {
nb := t.NumBytes() nb := t.NumBytes()
total += nb total += nb
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned, t.Dims())) logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned.Load(), t.Dims()))
} }
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory()))) logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory())))
} }

View File

@@ -150,7 +150,7 @@ func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload
traceScratch = nil traceScratch = nil
defer func() { defer func() {
for _, a := range traceScratch { for _, a := range traceScratch {
if a.pinned > 0 { if a.pinned.Load() > 0 {
panic("mlx: traced array was pinned during compilation") panic("mlx: traced array was pinned during compilation")
} }
if a.Valid() { if a.Valid() {

View File

@@ -24,8 +24,8 @@ func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *A
} }
type LayerNorm struct { type LayerNorm struct {
Weight Array `weight:"weight"` Weight *Array `weight:"weight"`
Bias Array `weight:"bias"` Bias *Array `weight:"bias"`
} }
func (r *LayerNorm) Forward(x *Array, eps float32) *Array { func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
@@ -35,10 +35,10 @@ func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
} }
type RMSNorm struct { type RMSNorm struct {
Weight Array `weight:"weight"` Weight *Array `weight:"weight"`
} }
func (r RMSNorm) Forward(x *Array, eps float32) *Array { func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_RMSNORM") out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx) C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
return out return out

View File

@@ -1,12 +1,12 @@
package mlx package mlx
type Linear struct { type Linear struct {
Weight Array `weight:"weight"` Weight *Array `weight:"weight"`
Bias Array `weight:"bias"` Bias *Array `weight:"bias"`
} }
// Forward computes the linear transformation: x @ Weight.T + Bias // Forward computes the linear transformation: x @ Weight.T + Bias
func (m Linear) Forward(x *Array) *Array { func (m *Linear) Forward(x *Array) *Array {
w := m.Weight.Transpose(1, 0) w := m.Weight.Transpose(1, 0)
if m.Bias.Valid() { if m.Bias.Valid() {
return m.Bias.Addmm(x, w, 1.0, 1.0) return m.Bias.Addmm(x, w, 1.0, 1.0)
@@ -15,14 +15,14 @@ func (m Linear) Forward(x *Array) *Array {
return x.Matmul(w) return x.Matmul(w)
} }
func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array { func (m *Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
w := m.Weight.Transpose(0, 2, 1) w := m.Weight.Transpose(0, 2, 1)
// TODO: bias // TODO: bias
return x.GatherMM(w, lhs, rhs, sorted) return x.GatherMM(w, lhs, rhs, sorted)
} }
type Embedding struct { type Embedding struct {
Weight Array `weight:"weight"` Weight *Array `weight:"weight"`
} }
func (e *Embedding) Forward(indices *Array) *Array { func (e *Embedding) Forward(indices *Array) *Array {

View File

@@ -139,6 +139,12 @@ func (t *Array) Less(other *Array) *Array {
return out return out
} }
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
out := New("MAX_AXIS")
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array { func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL") out := New("MATMUL")
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx) C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)

View File

@@ -6,36 +6,59 @@ import (
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
"net/http" "sort"
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/mlx"
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/tokenizer"
) )
func prefillChunkSize() int { func prefillChunkSize() int {
return 2 << 10 return 2 << 10
} }
func (r *Runner) TextGenerationPipeline(request Request) error { // Prepare tokenizes the prompt and validates it against the model's
// context length. It is safe to call from any goroutine. On success it
// populates request.Tokens and adjusts request.Options.NumPredict.
func (r *Runner) Prepare(request *Request) error {
if r.Model == nil { if r.Model == nil {
return errors.New("model not loaded") return errors.New("model not loaded")
} }
tokens := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
if len(tokens) == 0 {
return errors.New("empty prompt")
}
if len(tokens) >= r.contextLength {
return fmt.Errorf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(tokens), r.contextLength)
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(tokens)
if request.Options.NumPredict <= 0 {
request.Options.NumPredict = maxGenerate
} else {
request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate)
}
request.Tokens = tokens
return nil
}
func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) error {
mlx.ResetPeakMemory() mlx.ResetPeakMemory()
ctx := request.Ctx var sample, nextSample sampler.Result
var (
sample *mlx.Array
nextSample *mlx.Array
)
defer func() { defer func() {
if request.Sampler != nil { if request.Sampler != nil {
request.Sampler.Free() request.Sampler.Free()
} }
mlx.Unpin(sample) mlx.Unpin(sample.Arrays()...)
mlx.Unpin(nextSample) mlx.Unpin(nextSample.Arrays()...)
mlx.Sweep() mlx.Sweep()
mlx.ClearCache() mlx.ClearCache()
@@ -46,26 +69,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory())) slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
}() }()
inputs := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS()) inputs := request.Tokens
if len(inputs) == 0 {
return errors.New("empty prompt")
}
if len(inputs) >= r.contextLength {
return api.StatusError{
StatusCode: http.StatusBadRequest,
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
}
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(inputs)
if request.Options.MaxTokens <= 0 {
request.Options.MaxTokens = maxGenerate
} else {
request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate)
}
request.Sampler.ResetHistory(inputs) request.Sampler.ResetHistory(inputs)
session := r.cache.begin(r.Model, inputs) session := r.cache.begin(r.Model, inputs)
@@ -135,40 +139,38 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
mlx.ClearCache() mlx.ClearCache()
} }
step := func(token *mlx.Array) *mlx.Array { step := func(token *mlx.Array) sampler.Result {
fwd := r.Model.Forward(token.ExpandDims(0), caches) fwd := r.Model.Forward(token.ExpandDims(0), caches)
logits := r.Model.Unembed(fwd) logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
sample := request.Sampler.Sample(logits) sample := request.Sampler.Sample(logits)
mlx.Pin(sample.Arrays()...)
mlx.Pin(sample)
mlx.Sweep() mlx.Sweep()
mlx.AsyncEval(sample) mlx.AsyncEval(sample.Arrays()...)
return sample return sample
} }
sample = step(mlx.FromValues(tokens[processed:], total-processed)) sample = step(mlx.FromValues(tokens[processed:], total-processed))
var b bytes.Buffer dec := decoder{tokenizer: r.Tokenizer}
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1} final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1}
for i := range request.Options.MaxTokens { for i := range request.Options.NumPredict {
if err := ctx.Err(); err != nil { if err := ctx.Err(); err != nil {
return err return err
} }
request.Sampler.AppendToken(sample) request.Sampler.AppendToken(sample.Token)
nextSample = step(sample) nextSample = step(sample.Token)
if i == 0 { if i == 0 {
mlx.Eval(sample) mlx.Eval(sample.Arrays()...)
final.PromptEvalDuration = time.Since(now) final.PromptEvalDuration = time.Since(now)
now = time.Now() now = time.Now()
} }
output := int32(sample.Int()) output := int32(sample.Token.Int())
session.outputs = append(session.outputs, output) session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) { if r.Tokenizer.IsEOS(output) {
@@ -177,17 +179,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
break break
} }
select { if resp, ok := dec.decode(sample); ok {
case <-ctx.Done(): select {
return ctx.Err() case <-ctx.Done():
case request.Responses <- CompletionResponse{ return ctx.Err()
Content: r.Decode(output, &b), case request.Responses <- resp:
}: }
} }
mlx.Unpin(sample) mlx.Unpin(sample.Arrays()...)
sample = nextSample sample, nextSample = nextSample, sampler.Result{}
nextSample = nil
if i%256 == 0 { if i%256 == 0 {
mlx.ClearCache() mlx.ClearCache()
@@ -203,13 +204,57 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
} }
} }
func (r Runner) Decode(sample int32, b *bytes.Buffer) string { // decoder serializes sampled tokens into response chunks, holding bytes
token := r.Tokenizer.Decode([]int32{sample}) // whose UTF-8 sequence hasn't completed yet and the logprobs that belong
// with those bytes so Content and Logprobs stay aligned when a chunk does
// flush.
type decoder struct {
tokenizer *tokenizer.Tokenizer
buf bytes.Buffer
logprobs []llm.Logprob
}
if _, err := b.WriteString(token); err != nil { func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
slog.Error("Failed to write token to buffer", "error", err) output := int32(res.Token.Int())
return "" d.buf.WriteString(d.tokenizer.Decode([]int32{output}))
d.logprobs = append(d.logprobs, buildLogprob(res, d.tokenizer.Decode)...)
content := flushValidUTF8Prefix(&d.buf)
if content == "" {
return CompletionResponse{}, false
}
resp := CompletionResponse{Content: content, Logprobs: d.logprobs}
d.logprobs = nil
return resp, true
}
func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logprob {
if sample.Logprob == nil {
return nil
}
tok := func(id int32) string { return decode([]int32{id}) }
out := llm.Logprob{
TokenLogprob: llm.TokenLogprob{
Token: tok(int32(sample.Token.Int())),
Logprob: float64(sample.Logprob.Floats()[0]),
},
} }
return flushValidUTF8Prefix(b) if sample.TopTokens != nil {
ids := sample.TopTokens.Ints()
vals := sample.TopLogprobs.Floats()
pairs := make([]llm.TokenLogprob, len(ids))
for i, id := range ids {
pairs[i] = llm.TokenLogprob{
Token: tok(int32(id)),
Logprob: float64(vals[i]),
}
}
sort.Slice(pairs, func(i, j int) bool {
return pairs[i].Logprob > pairs[j].Logprob
})
out.TopLogprobs = pairs
}
return []llm.Logprob{out}
} }

View File

@@ -18,34 +18,20 @@ import (
"github.com/ollama/ollama/x/tokenizer" "github.com/ollama/ollama/x/tokenizer"
) )
// Request is a short-lived struct that carries a completion request through
// a channel from the HTTP handler to the runner goroutine. The ctx field
// must travel with the request so that cancellation propagates across the
// channel boundary.
type Request struct { type Request struct {
TextCompletionsRequest CompletionRequest
Responses chan CompletionResponse Responses chan CompletionResponse
Pipeline func(Request) error Pipeline func(context.Context, Request) error
Ctx context.Context
Ctx context.Context //nolint:containedctx
Tokens []int32
Sampler *sample.Sampler Sampler *sample.Sampler
} }
type TextCompletionsRequest struct {
Prompt string `json:"prompt"`
Options struct {
Temperature float32 `json:"temperature"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TopK int `json:"top_k"`
RepeatLastN int `json:"repeat_last_n"`
RepeatPenalty float32 `json:"repeat_penalty"`
PresencePenalty float32 `json:"presence_penalty"`
FrequencyPenalty float32 `json:"frequency_penalty"`
MaxTokens int `json:"max_tokens"`
// Deprecated: use MaxTokens instead
NumPredict int `json:"num_predict"`
} `json:"options"`
}
type Runner struct { type Runner struct {
Model base.Model Model base.Model
Tokenizer *tokenizer.Tokenizer Tokenizer *tokenizer.Tokenizer
@@ -149,7 +135,7 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
case request := <-r.Requests: case request := <-r.Requests:
if err := request.Pipeline(request); err != nil { if err := request.Pipeline(request.Ctx, request); err != nil {
slog.Info("Request terminated", "error", err) slog.Info("Request terminated", "error", err)
var statusErr api.StatusError var statusErr api.StatusError
if !errors.As(err, &statusErr) { if !errors.As(err, &statusErr) {

View File

@@ -0,0 +1,249 @@
//go:build mlx
package sample
import (
"math"
"sort"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// logprobEntry is the (token id, logprob) pair returned by the sampler's
// top-K extraction, used after the test-side descending sort.
type logprobEntry struct {
id int
logprob float64
}
// runSampleLogprobs drives Sample on a fresh Sampler configured for logprobs
// and returns the greedily-sampled token id, its logprob, and the top-K
// entries sorted descending by logprob. Logits must be a [vocab]-shaped
// slice; the helper reshapes it to [1, vocab] before calling the sampler.
func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, []logprobEntry) {
t.Helper()
s := New(Options{Logprobs: true, TopLogprobs: topK})
defer func() {
s.Free()
mlx.Sweep()
}()
tensor := mlx.FromValues(logits, 1, len(logits))
res := s.Sample(tensor)
mlx.Pin(res.Arrays()...)
defer mlx.Unpin(res.Arrays()...)
mlx.Sweep()
mlx.Eval(res.Arrays()...)
selected := res.Token.Int()
selLP := float64(res.Logprob.Floats()[0])
var top []logprobEntry
if topK > 0 && res.TopTokens != nil {
ids := res.TopTokens.Ints()
vals := res.TopLogprobs.Floats()
top = make([]logprobEntry, len(ids))
for i, id := range ids {
top[i] = logprobEntry{id: id, logprob: float64(vals[i])}
}
sort.Slice(top, func(i, j int) bool { return top[i].logprob > top[j].logprob })
}
return selected, selLP, top
}
func TestSampleLogprobsBasic(t *testing.T) {
tests := []struct {
name string
logits []float32
topK int
wantSelectedID int
wantTopLen int
}{
{
name: "single token without top logprobs",
logits: []float32{1.0, 0.5, 0.3, 0.1},
topK: 0,
wantSelectedID: 0,
wantTopLen: 0,
},
{
name: "single token with top logprobs",
logits: []float32{1.0, 0.5, 0.3, 0.1},
topK: 3,
wantSelectedID: 0,
wantTopLen: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
selected, _, top := runSampleLogprobs(t, tt.logits, tt.topK)
if selected != tt.wantSelectedID {
t.Errorf("selected = %d, want %d", selected, tt.wantSelectedID)
}
if len(top) != tt.wantTopLen {
t.Errorf("top-K length = %d, want %d", len(top), tt.wantTopLen)
}
})
}
}
func TestSampleLogprobsNumericalStability(t *testing.T) {
logits := []float32{1000.0, 999.0, 998.0}
_, selLP, top := runSampleLogprobs(t, logits, 3)
if math.IsInf(selLP, 0) || math.IsNaN(selLP) {
t.Errorf("selected logprob is not finite: %f", selLP)
}
for i, e := range top {
if math.IsInf(e.logprob, 0) || math.IsNaN(e.logprob) {
t.Errorf("top[%d] logprob is not finite: %f", i, e.logprob)
}
}
for i := 1; i < len(top); i++ {
if top[i].logprob > top[i-1].logprob {
t.Errorf("top logprobs not descending: %f > %f", top[i].logprob, top[i-1].logprob)
}
}
}
func TestSampleLogprobsProbabilityCorrectness(t *testing.T) {
tests := []struct {
name string
logits []float32
}{
{"uniform", []float32{1.0, 1.0, 1.0, 1.0}},
{"different", []float32{2.0, 1.0, 0.5, 0.1}},
{"negative", []float32{-1.0, -2.0, -3.0, -4.0}},
{"mixed", []float32{5.0, -5.0, 0.0, 2.5}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
selected, selLP, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
if selLP > 0 {
t.Errorf("selected logprob should be <= 0, got %f", selLP)
}
for i, e := range top {
if e.logprob > 0 {
t.Errorf("top[%d] logprob should be <= 0, got %f", i, e.logprob)
}
}
if tt.name == "uniform" {
want := 1.0 / float64(len(tt.logits))
got := math.Exp(selLP)
if math.Abs(got-want) > 1e-6 {
t.Errorf("uniform logits: selected prob = %f, want %f", got, want)
}
}
for i := 1; i < len(top); i++ {
if top[i].logprob > top[i-1].logprob {
t.Errorf("top logprobs not descending at %d: %f > %f",
i, top[i].logprob, top[i-1].logprob)
}
}
found := false
for _, e := range top {
if e.id == selected {
found = true
if math.Abs(e.logprob-selLP) > 1e-6 {
t.Errorf("selected logprob mismatch: selLP=%f top=%f", selLP, e.logprob)
}
break
}
}
if !found {
t.Errorf("selected token %d not present in top-K", selected)
}
})
}
}
func TestSampleLogprobsSoftmaxCorrectness(t *testing.T) {
tests := []struct {
name string
logits []float32
}{
{"small vocabulary", []float32{1.0, 2.0, 3.0}},
{"large differences", []float32{10.0, 0.0, -10.0}},
{"all equal", []float32{5.0, 5.0, 5.0, 5.0, 5.0}},
{"very large values", []float32{500.0, 499.0, 498.0}},
{"very small values", []float32{-500.0, -499.0, -498.0}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, _, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
if len(top) != len(tt.logits) {
t.Fatalf("top-K length = %d, want %d", len(top), len(tt.logits))
}
var sum float64
for _, e := range top {
p := math.Exp(e.logprob)
if p < 0 || p > 1 {
t.Errorf("token %d: probability %f out of [0,1]", e.id, p)
}
sum += p
}
if math.Abs(sum-1.0) > 1e-5 {
t.Errorf("probabilities sum = %f, want 1.0", sum)
}
})
}
}
func TestSampleLogprobsSelectedTokenCorrectness(t *testing.T) {
logits := []float32{3.0, 1.0, 2.0, 0.5}
maxIdx := 0
for i, v := range logits[1:] {
if v > logits[maxIdx] {
maxIdx = i + 1
}
}
selected, selLP, top := runSampleLogprobs(t, logits, len(logits))
if selected != maxIdx {
t.Errorf("selected = %d, want argmax %d", selected, maxIdx)
}
if top[0].id != maxIdx {
t.Errorf("top[0].id = %d, want argmax %d", top[0].id, maxIdx)
}
if math.Abs(top[0].logprob-selLP) > 1e-6 {
t.Errorf("top[0].logprob = %f, want selected %f", top[0].logprob, selLP)
}
}
func TestSampleLogprobsTopKOrdering(t *testing.T) {
// Logits chosen so argmax order differs from index order.
logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0}
wantOrder := []int{1, 3, 4, 0, 2}
_, _, top := runSampleLogprobs(t, logits, len(logits))
if len(top) != len(wantOrder) {
t.Fatalf("top-K length = %d, want %d", len(top), len(wantOrder))
}
for i, e := range top {
if e.id != wantOrder[i] {
t.Errorf("top[%d].id = %d, want %d", i, e.id, wantOrder[i])
}
}
for i := 1; i < len(top); i++ {
if top[i].logprob > top[i-1].logprob {
t.Errorf("top[%d].logprob (%f) > top[%d].logprob (%f)",
i, top[i].logprob, i-1, top[i-1].logprob)
}
}
}

View File

@@ -8,7 +8,7 @@ import (
type Transform func(*Sampler, *mlx.Array) *mlx.Array type Transform func(*Sampler, *mlx.Array) *mlx.Array
type Sampler struct { type Options struct {
Temperature float32 Temperature float32
TopP float32 TopP float32
MinP float32 MinP float32
@@ -18,45 +18,66 @@ type Sampler struct {
PresencePenalty float32 PresencePenalty float32
FrequencyPenalty float32 FrequencyPenalty float32
// Logprobs causes Sample to populate Result.Logprob with the selected
// token's log-probability. TopLogprobs (when > 0) adds top-K pairs.
Logprobs bool
TopLogprobs int
}
type Sampler struct {
Options
history *mlx.Array history *mlx.Array
historyLen int historyLen int
transforms []Transform transforms []Transform
} }
func New(temp, top_p, min_p float32, top_k, repeatLastN int, repeatPenalty, presencePenalty, frequencyPenalty float32) *Sampler { // Result bundles the outputs of one decode step. The logprob tensors are
if repeatPenalty <= 0 { // populated only when the sampler is configured to report them.
repeatPenalty = 1 type Result struct {
Token *mlx.Array // sampled token id, shape [B]
Logprob *mlx.Array // sampled-token logprob, shape [B,1]; nil unless Logprobs
TopTokens *mlx.Array // top-K token ids, shape [B,K]; nil unless TopLogprobs > 0
TopLogprobs *mlx.Array // top-K logprobs, shape [B,K]; nil unless TopLogprobs > 0
}
// Arrays returns the tensor fields as a slice so callers can drive the mlx
// lifecycle verbs (Pin, Unpin, Eval, AsyncEval) over the whole group. Unset
// fields stay nil; the mlx helpers skip them.
func (r Result) Arrays() []*mlx.Array {
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
}
func New(opts Options) *Sampler {
if opts.RepeatPenalty <= 0 {
opts.RepeatPenalty = 1
} }
s := &Sampler{ s := &Sampler{Options: opts}
Temperature: temp,
TopP: top_p,
MinP: min_p,
TopK: top_k,
RepeatLastN: repeatLastN,
RepeatPenalty: repeatPenalty,
PresencePenalty: presencePenalty,
FrequencyPenalty: frequencyPenalty,
}
var transforms []Transform var transforms []Transform
if s.usesHistory() { if s.usesHistory() {
transforms = append(transforms, penalty) transforms = append(transforms, penalty)
} }
if top_p > 0 && top_p < 1 { hasTopP := opts.TopP > 0 && opts.TopP < 1
transforms = append(transforms, topP) hasTopK := opts.TopK > 0
} switch {
case hasTopP:
if min_p != 0 { // topKTopP always does a full descending sort for the top-P
transforms = append(transforms, minP) // cumulative mask and opportunistically masks top-K during the
} // same pass when it is also configured.
transforms = append(transforms, topKTopP)
if top_k > 0 { case hasTopK:
// Argpartition (partial sort) is cheaper than a full sort.
transforms = append(transforms, topK) transforms = append(transforms, topK)
} }
if temp == 0 { if opts.MinP != 0 {
transforms = append(transforms, minP)
}
if opts.Temperature == 0 {
transforms = append(transforms, greedy) transforms = append(transforms, greedy)
} else { } else {
transforms = append(transforms, temperature) transforms = append(transforms, temperature)
@@ -123,76 +144,121 @@ func (s *Sampler) Free() {
s.setHistory(nil, 0) s.setHistory(nil, 0)
} }
func (s *Sampler) Sample(logits *mlx.Array) *mlx.Array { // Sample runs the configured transform chain on the raw per-token logits
// and returns the sampled token id plus, when configured, the reported
// log-probability tensors for the selected token and the top-K tokens.
func (s *Sampler) Sample(logits *mlx.Array) Result {
scores := logits
for _, transform := range s.transforms { for _, transform := range s.transforms {
logits = transform(s, logits) scores = transform(s, scores)
} }
return logits res := Result{Token: scores}
}
func greedy(_ *Sampler, logits *mlx.Array) *mlx.Array { if s.Logprobs {
return logits.Argmax(-1, false) // Compute log_softmax in fp32 and subtract the max before
} // logsumexp so the final subtraction stays on small values.
// Otherwise it cancels two large numbers and loses precision.
func temperature(s *Sampler, logits *mlx.Array) *mlx.Array { lp := logits.AsType(mlx.DTypeFloat32)
return mlx.DivScalar(logits, s.Temperature).Categorical(-1) lp = lp.Subtract(lp.MaxAxis(-1, true))
} lp = lp.Subtract(lp.Logsumexp(true))
res.Logprob = lp.TakeAlongAxis(res.Token.ExpandDims(-1), -1)
func topP(s *Sampler, logits *mlx.Array) *mlx.Array { if k := s.TopLogprobs; k > 0 {
if s.TopP <= 0 || s.TopP >= 1 { if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
return logits k = vocab
}
// Argpartition on the negated values places the K largest
// (unsorted) in positions [0:K].
idx := lp.Negative().ArgpartitionAxis(k-1, -1).Slice(mlx.Slice(), mlx.Slice(0, k))
res.TopTokens = idx.AsType(mlx.DTypeInt32)
res.TopLogprobs = lp.TakeAlongAxis(idx, -1)
}
} }
return res
}
order := logits.Negative().ArgsortAxis(-1) func greedy(_ *Sampler, scores *mlx.Array) *mlx.Array {
sortedLogits := logits.TakeAlongAxis(order, -1) return scores.Argmax(-1, false)
sortedProbs := mlx.SoftmaxAxis(sortedLogits, -1, true) }
prevCumProbs := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
func temperature(s *Sampler, scores *mlx.Array) *mlx.Array {
return mlx.DivScalar(scores, s.Temperature).Categorical(-1)
}
// topKTopP applies top-P in a descending sort pass and, when top-K is also
// configured, masks any surviving value below the K-th largest in the same
// pass. Callers dispatch here whenever top-P is enabled — the top-K-only
// case uses a cheaper partial sort via the topK transform.
func topKTopP(s *Sampler, scores *mlx.Array) *mlx.Array {
vocab := scores.Dim(scores.NumDims() - 1)
applyTopK := s.TopK > 0 && s.TopK < vocab
order := scores.Negative().ArgsortAxis(-1)
sorted := scores.TakeAlongAxis(order, -1)
negInf := mlx.FromValue(float32(math.Inf(-1)))
// Top-P: in descending order, keep tokens whose exclusive cumulative
// probability is still below s.TopP.
probs := mlx.SoftmaxAxis(sorted, -1, true)
prevCumProbs := probs.Cumsum(-1, false, true).Subtract(probs)
keep := prevCumProbs.Less(mlx.FromValue(s.TopP)) keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
filtered := mlx.Where(keep, sortedLogits, mlx.FromValue(float32(math.Inf(-1)))) sorted = mlx.Where(keep, sorted, negInf)
return logits.PutAlongAxis(order, filtered, -1)
}
func minP(s *Sampler, logits *mlx.Array) *mlx.Array { out := scores.PutAlongAxis(order, sorted, -1)
if s.MinP <= 0 || s.MinP > 1 {
return logits // Top-K: sorted is already in descending order, so positions [K, V)
// are the ones to drop. Scatter -inf through their original-layout
// indices (order[K:]). Positional (not value-based) so exactly K
// tokens survive — ties at the K-th logit get broken by the sort
// order rather than promoted through the filter.
if applyTopK {
dropOrder := order.Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
out = out.PutAlongAxis(dropOrder, negInf, -1)
} }
maxLogits := logits.TakeAlongAxis(logits.Argmax(-1, true), -1) return out
minLogits := mlx.AddScalar(maxLogits, float32(math.Log(float64(s.MinP)))) }
func minP(s *Sampler, scores *mlx.Array) *mlx.Array {
if s.MinP <= 0 || s.MinP > 1 {
return scores
}
maxScore := scores.MaxAxis(-1, true)
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(s.MinP))))
return mlx.Where( return mlx.Where(
logits.Less(minLogits), scores.Less(threshold),
mlx.FromValue(float32(math.Inf(-1))), mlx.FromValue(float32(math.Inf(-1))),
logits, scores,
) )
} }
func topK(s *Sampler, logits *mlx.Array) *mlx.Array { func topK(s *Sampler, scores *mlx.Array) *mlx.Array {
if s.TopK <= 0 { if s.TopK <= 0 {
return logits return scores
} }
vocab := logits.Dim(logits.NumDims() - 1) vocab := scores.Dim(scores.NumDims() - 1)
if s.TopK >= vocab { if s.TopK >= vocab {
return logits return scores
} }
mask := logits.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End)) mask := scores.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
return logits.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1) return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
} }
func penalty(s *Sampler, logits *mlx.Array) *mlx.Array { func penalty(s *Sampler, scores *mlx.Array) *mlx.Array {
if s.historyLen == 0 { if s.historyLen == 0 {
return logits return scores
} }
tokenIndices := s.history tokenIndices := s.history
if logits.NumDims() > 1 { if scores.NumDims() > 1 {
tokenIndices = tokenIndices.ExpandDims(0) tokenIndices = tokenIndices.ExpandDims(0)
} }
if s.RepeatPenalty != 1 || s.PresencePenalty != 0 { if s.RepeatPenalty != 1 || s.PresencePenalty != 0 {
adjusted := logits.TakeAlongAxis(tokenIndices, -1) adjusted := scores.TakeAlongAxis(tokenIndices, -1)
if s.RepeatPenalty != 1 { if s.RepeatPenalty != 1 {
factor := mlx.Where( factor := mlx.Where(
adjusted.Less(mlx.FromValue(float32(0))), adjusted.Less(mlx.FromValue(float32(0))),
@@ -204,12 +270,12 @@ func penalty(s *Sampler, logits *mlx.Array) *mlx.Array {
if s.PresencePenalty != 0 { if s.PresencePenalty != 0 {
adjusted = mlx.AddScalar(adjusted, -s.PresencePenalty) adjusted = mlx.AddScalar(adjusted, -s.PresencePenalty)
} }
logits = logits.PutAlongAxis(tokenIndices, adjusted, -1) scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
} }
if s.FrequencyPenalty != 0 { if s.FrequencyPenalty != 0 {
logits = logits.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1) scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1)
} }
return logits return scores
} }

View File

@@ -10,8 +10,7 @@ import (
) )
func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) { func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
// RepeatLastN = 1, PresencePenalty = 6 s := New(Options{RepeatLastN: 1, PresencePenalty: 6})
s := New(0, 0, 0, 0, 1, 1, 6, 0)
defer func() { defer func() {
s.Free() s.Free()
mlx.Sweep() mlx.Sweep()
@@ -21,7 +20,7 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1})) s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1}))
logits := mlx.FromValues([]float32{0, 5, 4}, 3) logits := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logits) got := s.Sample(logits).Token
mlx.Eval(got) mlx.Eval(got)
// logits will be [0, -1, 4] after the penalty // logits will be [0, -1, 4] after the penalty
@@ -33,7 +32,7 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) {
} }
func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) { func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) {
s := New(0, 0, 0, 0, 1, 2, 0, 0) s := New(Options{RepeatLastN: 1, RepeatPenalty: 2})
defer func() { defer func() {
s.Free() s.Free()
mlx.Sweep() mlx.Sweep()
@@ -42,7 +41,7 @@ func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) {
s.ResetHistory([]int32{1}) s.ResetHistory([]int32{1})
logits := mlx.FromValues([]float32{0, 5, 4}, 3) logits := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logits) got := s.Sample(logits).Token
mlx.Eval(got) mlx.Eval(got)
// token 1 is repeated and positive, so 5 / 2 falls below token 2. // token 1 is repeated and positive, so 5 / 2 falls below token 2.
@@ -53,7 +52,7 @@ func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) {
} }
func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) { func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) {
s := New(0, 0, 0, 0, 4, 1, 0, 2) s := New(Options{RepeatLastN: 4, FrequencyPenalty: 2})
defer func() { defer func() {
s.Free() s.Free()
mlx.Sweep() mlx.Sweep()
@@ -62,7 +61,7 @@ func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) {
s.ResetHistory([]int32{1, 1}) s.ResetHistory([]int32{1, 1})
logits := mlx.FromValues([]float32{0, 5, 4}, 3) logits := mlx.FromValues([]float32{0, 5, 4}, 3)
got := s.Sample(logits) got := s.Sample(logits).Token
mlx.Eval(got) mlx.Eval(got)
// token 1 appears twice, so 5 - (2 * 2) falls below token 2. // token 1 appears twice, so 5 - (2 * 2) falls below token 2.
@@ -73,7 +72,7 @@ func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) {
} }
func TestMinPMasksTokensBelowThreshold(t *testing.T) { func TestMinPMasksTokensBelowThreshold(t *testing.T) {
s := New(0, 0, 0.5, 0, 0, 1, 0, 0) s := New(Options{MinP: 0.5})
defer func() { defer func() {
s.Free() s.Free()
mlx.Sweep() mlx.Sweep()

View File

@@ -2,7 +2,6 @@ package mlxrunner
import ( import (
"bytes" "bytes"
"cmp"
"context" "context"
"encoding/json" "encoding/json"
"flag" "flag"
@@ -87,25 +86,30 @@ func Execute(args []string) error {
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan CompletionResponse)} request := Request{Responses: make(chan CompletionResponse)}
if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil { if err := json.NewDecoder(r.Body).Decode(&request.CompletionRequest); err != nil {
slog.Error("Failed to decode request", "error", err) slog.Error("Failed to decode request", "error", err)
http.Error(w, "Bad Request", http.StatusBadRequest) http.Error(w, "Bad Request", http.StatusBadRequest)
return return
} }
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
request.Pipeline = runner.TextGenerationPipeline request.Pipeline = runner.TextGenerationPipeline
request.Sampler = sample.New( request.Sampler = sample.New(sample.Options{
request.Options.Temperature, Temperature: request.Options.Temperature,
request.Options.TopP, TopP: request.Options.TopP,
request.Options.MinP, MinP: request.Options.MinP,
request.Options.TopK, TopK: request.Options.TopK,
request.Options.RepeatLastN, RepeatLastN: request.Options.RepeatLastN,
request.Options.RepeatPenalty, RepeatPenalty: request.Options.RepeatPenalty,
request.Options.PresencePenalty, PresencePenalty: request.Options.PresencePenalty,
request.Options.FrequencyPenalty, FrequencyPenalty: request.Options.FrequencyPenalty,
) Logprobs: request.Logprobs,
TopLogprobs: request.TopLogprobs,
})
if err := runner.Prepare(&request); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var cancel context.CancelFunc var cancel context.CancelFunc
request.Ctx, cancel = context.WithCancel(r.Context()) request.Ctx, cancel = context.WithCancel(r.Context())

View File

@@ -144,6 +144,8 @@ func TestRouterForwardMatchesLegacy(t *testing.T) {
gotScores, gotInds := r.Forward(x, cfg) gotScores, gotInds := r.Forward(x, cfg)
wantScores, wantInds := legacyRouterForward(r, x, cfg) wantScores, wantInds := legacyRouterForward(r, x, cfg)
gotInds = gotInds.AsType(mlx.DTypeInt32)
wantInds = wantInds.AsType(mlx.DTypeInt32)
mlx.Eval(gotScores, gotInds, wantScores, wantInds) mlx.Eval(gotScores, gotInds, wantScores, wantInds)
if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) { if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {

View File

@@ -169,8 +169,8 @@ func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4") dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
mlx.Eval(dequantizedWeight) mlx.Eval(dequantizedWeight)
qOut := ql.Forward(input) qOut := ql.Forward(input).AsType(mlx.DTypeFloat32)
dOut := NewLinear(dequantizedWeight, nil).Forward(input) dOut := NewLinear(dequantizedWeight, nil).Forward(input).AsType(mlx.DTypeFloat32)
mlx.Eval(qOut, dOut) mlx.Eval(qOut, dOut)
got := qOut.Floats() got := qOut.Floats()